Merge pull request #8 from ljensen505/refactor-async

removed many unneeded async definitions. The main controller remains …
This commit is contained in:
Lucas Jensen
2024-05-03 17:28:41 -07:00
committed by GitHub
13 changed files with 135 additions and 145 deletions

View File

@@ -21,7 +21,7 @@ class BaseController:
self.ALL_FILES = ALLOWED_FILES_TYPES self.ALL_FILES = ALLOWED_FILES_TYPES
self.MAX_FILE_SIZE = ONE_MB self.MAX_FILE_SIZE = ONE_MB
async def verify_image(self, file: UploadFile) -> bytes: def verify_image(self, file: UploadFile) -> bytes:
"""Verifies that the file is an image and is within the maximum file size. """Verifies that the file is an image and is within the maximum file size.
Args: Args:
@@ -38,7 +38,8 @@ class BaseController:
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File type {file.content_type} not allowed. Allowed file types are {self.ALL_FILES}", detail=f"File type {file.content_type} not allowed. Allowed file types are {self.ALL_FILES}",
) )
image_file = await file.read() with file.file as f:
image_file = f.read()
if len(image_file) > self.MAX_FILE_SIZE: if len(image_file) > self.MAX_FILE_SIZE:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,

View File

@@ -19,6 +19,8 @@ class MainController:
All methods are either pass-throughs to the appropriate controller or All methods are either pass-throughs to the appropriate controller or
are used to coordinate multiple controllers. are used to coordinate multiple controllers.
All methods are asynchronous to facilitate asynchronous calls from the Router layer.
token-based authentication is handled here as needed per the nature of the data being accessed. token-based authentication is handled here as needed per the nature of the data being accessed.
""" """
@@ -29,10 +31,10 @@ class MainController:
self.group_controller = GroupController() self.group_controller = GroupController()
async def get_musicians(self) -> list[Musician]: async def get_musicians(self) -> list[Musician]:
return await self.musician_controller.get_musicians() return self.musician_controller.get_musicians()
async def get_musician(self, id: int) -> Musician: async def get_musician(self, id: int) -> Musician:
return await self.musician_controller.get_musician(id) return self.musician_controller.get_musician(id)
async def update_musician( async def update_musician(
self, self,
@@ -48,60 +50,60 @@ class MainController:
detail="ID in URL does not match ID in request body", detail="ID in URL does not match ID in request body",
) )
_, sub = oauth_token.email_and_sub(token) _, sub = oauth_token.email_and_sub(token)
await self.user_controller.get_user_by_sub(sub) self.user_controller.get_user_by_sub(sub)
return await self.musician_controller.update_musician( return self.musician_controller.update_musician(
musician_id=musician.id, musician_id=musician.id,
new_bio=musician.bio, new_bio=musician.bio,
file=file, file=file,
) )
async def get_events(self) -> list[EventSeries]: async def get_events(self) -> list[EventSeries]:
return await self.event_controller.get_all_series() return self.event_controller.get_all_series()
async def get_event(self, id: int) -> EventSeries: async def get_event(self, id: int) -> EventSeries:
return await self.event_controller.get_one_series_by_id(id) return self.event_controller.get_one_series_by_id(id)
async def create_event( async def create_event(
self, series: NewEventSeries, token: HTTPAuthorizationCredentials self, series: NewEventSeries, token: HTTPAuthorizationCredentials
) -> EventSeries: ) -> EventSeries:
_, sub = oauth_token.email_and_sub(token) _, sub = oauth_token.email_and_sub(token)
await self.user_controller.get_user_by_sub(sub) self.user_controller.get_user_by_sub(sub)
return await self.event_controller.create_series(series) return self.event_controller.create_series(series)
async def add_series_poster( async def add_series_poster(
self, series_id: int, poster: UploadFile, token: HTTPAuthorizationCredentials self, series_id: int, poster: UploadFile, token: HTTPAuthorizationCredentials
) -> EventSeries: ) -> EventSeries:
_, sub = oauth_token.email_and_sub(token) _, sub = oauth_token.email_and_sub(token)
await self.user_controller.get_user_by_sub(sub) self.user_controller.get_user_by_sub(sub)
return await self.event_controller.add_series_poster(series_id, poster) return self.event_controller.add_series_poster(series_id, poster)
async def delete_series(self, id: int, token: HTTPAuthorizationCredentials) -> None: async def delete_series(self, id: int, token: HTTPAuthorizationCredentials) -> None:
_, sub = oauth_token.email_and_sub(token) _, sub = oauth_token.email_and_sub(token)
await self.user_controller.get_user_by_sub(sub) self.user_controller.get_user_by_sub(sub)
await self.event_controller.delete_series(id) self.event_controller.delete_series(id)
async def update_series( async def update_series(
self, route_id: int, series: EventSeries, token: HTTPAuthorizationCredentials self, route_id: int, series: EventSeries, token: HTTPAuthorizationCredentials
) -> EventSeries: ) -> EventSeries:
_, sub = oauth_token.email_and_sub(token) _, sub = oauth_token.email_and_sub(token)
await self.user_controller.get_user_by_sub(sub) self.user_controller.get_user_by_sub(sub)
return await self.event_controller.update_series(route_id, series) return self.event_controller.update_series(route_id, series)
async def get_users(self) -> list[User]: async def get_users(self) -> list[User]:
return await self.user_controller.get_users() return self.user_controller.get_users()
async def get_user(self, id: int) -> User: async def get_user(self, id: int) -> User:
return await self.user_controller.get_user_by_id(id) return self.user_controller.get_user_by_id(id)
async def create_user(self, token: HTTPAuthorizationCredentials) -> User: async def create_user(self, token: HTTPAuthorizationCredentials) -> User:
return await self.user_controller.create_user(token) return self.user_controller.create_user(token)
async def get_group(self) -> Group: async def get_group(self) -> Group:
return await self.group_controller.get_group() return self.group_controller.get_group()
async def update_group_bio( async def update_group_bio(
self, bio: str, token: HTTPAuthorizationCredentials self, bio: str, token: HTTPAuthorizationCredentials
) -> Group: ) -> Group:
_, sub = oauth_token.email_and_sub(token) _, sub = oauth_token.email_and_sub(token)
await self.user_controller.get_user_by_sub(sub) self.user_controller.get_user_by_sub(sub)
return await self.group_controller.update_group_bio(bio) return self.group_controller.update_group_bio(bio)

View File

@@ -43,7 +43,7 @@ class EventController(BaseController):
return all_series return all_series
async def get_all_series(self) -> list[EventSeries]: def get_all_series(self) -> list[EventSeries]:
"""Retrieves all EventSeries objects from the database and returns them as a list. """Retrieves all EventSeries objects from the database and returns them as a list.
Raises: Raises:
@@ -52,7 +52,7 @@ class EventController(BaseController):
Returns: Returns:
list[EventSeries]: A list of EventSeries objects which are suitable for a response body list[EventSeries]: A list of EventSeries objects which are suitable for a response body
""" """
series_data = await self.db.select_all_series() series_data = self.db.select_all_series()
try: try:
return [series for series in self._all_series(series_data).values()] return [series for series in self._all_series(series_data).values()]
@@ -63,7 +63,7 @@ class EventController(BaseController):
detail=f"Error retrieving event objects: {e}", detail=f"Error retrieving event objects: {e}",
) )
async def get_one_series_by_id(self, series_id: int) -> EventSeries: def get_one_series_by_id(self, series_id: int) -> EventSeries:
"""Builds and returns a single EventSeries object by its numeric ID. """Builds and returns a single EventSeries object by its numeric ID.
Args: Args:
@@ -76,7 +76,7 @@ class EventController(BaseController):
Returns: Returns:
EventSeries: A single EventSeries object EventSeries: A single EventSeries object
""" """
if not (rows := await self.db.select_one_by_id(series_id)): if not (rows := self.db.select_one_by_id(series_id)):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Event not found" status_code=status.HTTP_404_NOT_FOUND, detail="Event not found"
) )
@@ -90,7 +90,7 @@ class EventController(BaseController):
detail=f"Error creating event object: {e}", detail=f"Error creating event object: {e}",
) )
async def create_series(self, series: NewEventSeries) -> EventSeries: def create_series(self, series: NewEventSeries) -> EventSeries:
"""Takes a NewEventSeries object and passes it to the database layer for insertion. """Takes a NewEventSeries object and passes it to the database layer for insertion.
Args: Args:
@@ -103,19 +103,17 @@ class EventController(BaseController):
EventSeries: The newly created EventSeries object with an ID EventSeries: The newly created EventSeries object with an ID
""" """
try: try:
inserted_id = await self.db.insert_one_series(series) inserted_id = self.db.insert_one_series(series)
for new_event in series.events: for new_event in series.events:
await self.db.insert_one_event(new_event, inserted_id) self.db.insert_one_event(new_event, inserted_id)
return await self.get_one_series_by_id(inserted_id) return self.get_one_series_by_id(inserted_id)
except IntegrityError as e: except IntegrityError as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Series name already exists. Each series must have a unique name.\n{e}", detail=f"Series name already exists. Each series must have a unique name.\n{e}",
) )
async def add_series_poster( def add_series_poster(self, series_id: int, poster: UploadFile) -> EventSeries:
self, series_id: int, poster: UploadFile
) -> EventSeries:
"""Adds (or updates) a poster image to a series. """Adds (or updates) a poster image to a series.
Args: Args:
@@ -125,12 +123,12 @@ class EventController(BaseController):
Returns: Returns:
EventSeries: The updated EventSeries object EventSeries: The updated EventSeries object
""" """
series = await self.get_one_series_by_id(series_id) series = self.get_one_series_by_id(series_id)
series.poster_id = await self._upload_poster(poster) series.poster_id = self._upload_poster(poster)
await self.db.update_series_poster(series) self.db.update_series_poster(series)
return await self.get_one_series_by_id(series.series_id) return self.get_one_series_by_id(series.series_id)
async def _upload_poster(self, poster: UploadFile) -> str: def _upload_poster(self, poster: UploadFile) -> str:
"""Uploads a poster image to Cloudinary and returns the public ID for storage in the database. """Uploads a poster image to Cloudinary and returns the public ID for storage in the database.
Should only be used internally. Should only be used internally.
@@ -143,7 +141,7 @@ class EventController(BaseController):
Returns: Returns:
str: The public ID of the uploaded image str: The public ID of the uploaded image
""" """
image_file = await self.verify_image(poster) image_file = self.verify_image(poster)
try: try:
data = uploader.upload(image_file) data = uploader.upload(image_file)
return data.get("public_id") return data.get("public_id")
@@ -153,16 +151,16 @@ class EventController(BaseController):
detail=f"Error uploading image: {e}", detail=f"Error uploading image: {e}",
) )
async def delete_series(self, id: int) -> None: def delete_series(self, id: int) -> None:
"""Ensures an EventSeries object exists and then deletes it from the database. """Ensures an EventSeries object exists and then deletes it from the database.
Args: Args:
id (int): The numeric ID of the series to delete id (int): The numeric ID of the series to delete
""" """
series = await self.get_one_series_by_id(id) series = self.get_one_series_by_id(id)
await self.db.delete_one_series(series) self.db.delete_one_series(series)
async def update_series(self, route_id: int, series: EventSeries) -> EventSeries: def update_series(self, route_id: int, series: EventSeries) -> EventSeries:
"""Updates an existing EventSeries object in the database. """Updates an existing EventSeries object in the database.
Args: Args:
@@ -181,14 +179,14 @@ class EventController(BaseController):
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="ID in URL does not match ID in request body", detail="ID in URL does not match ID in request body",
) )
prev_series = await self.get_one_series_by_id(series.series_id) prev_series = self.get_one_series_by_id(series.series_id)
if series.poster_id != prev_series.poster_id: if series.poster_id != prev_series.poster_id:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Poster ID cannot be updated directly. Use the /poster endpoint instead.", detail="Poster ID cannot be updated directly. Use the /poster endpoint instead.",
) )
await self.db.delete_events_by_series(series) self.db.delete_events_by_series(series)
await self.db.replace_series(series) self.db.replace_series(series)
for event in series.events: for event in series.events:
await self.db.insert_one_event(event, series.series_id) self.db.insert_one_event(event, series.series_id)
return await self.get_one_series_by_id(series.series_id) return self.get_one_series_by_id(series.series_id)

View File

@@ -11,8 +11,8 @@ class GroupController(BaseController):
super().__init__() super().__init__()
self.db: GroupQueries = group_queries self.db: GroupQueries = group_queries
async def get_group(self) -> Group: def get_group(self) -> Group:
if (data := await self.db.select_one_by_id()) is None: if (data := self.db.select_one_by_id()) is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found" status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
) )
@@ -24,12 +24,12 @@ class GroupController(BaseController):
detail=f"Error creating group object: {e}", detail=f"Error creating group object: {e}",
) )
async def update_group_bio(self, bio: str) -> Group: def update_group_bio(self, bio: str) -> Group:
try: try:
await self.db.update_group_bio(bio) self.db.update_group_bio(bio)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error updating group bio: {e}", detail=f"Error updating group bio: {e}",
) )
return await self.get_group() return self.get_group()

View File

@@ -21,7 +21,7 @@ class MusicianController(BaseController):
super().__init__() super().__init__()
self.db: MusicianQueries = musician_queries self.db: MusicianQueries = musician_queries
async def get_musicians(self) -> list[Musician]: def get_musicians(self) -> list[Musician]:
"""Retrieves all musicians from the database and returns them as a list of Musician objects. """Retrieves all musicians from the database and returns them as a list of Musician objects.
Raises: Raises:
@@ -30,7 +30,7 @@ class MusicianController(BaseController):
Returns: Returns:
list[Musician]: A list of Musician objects which are suitable for a response body list[Musician]: A list of Musician objects which are suitable for a response body
""" """
data = await self.db.select_all_series() data = self.db.select_all_series()
try: try:
return [Musician(**m) for m in data] return [Musician(**m) for m in data]
except Exception as e: except Exception as e:
@@ -39,7 +39,7 @@ class MusicianController(BaseController):
detail=f"Error creating musician objects: {e}", detail=f"Error creating musician objects: {e}",
) )
async def get_musician(self, musician_id: int) -> Musician: def get_musician(self, musician_id: int) -> Musician:
"""Retrieves a single musician from the database and returns it as a Musician object. """Retrieves a single musician from the database and returns it as a Musician object.
Args: Args:
@@ -52,7 +52,7 @@ class MusicianController(BaseController):
Returns: Returns:
Musician: A Musician object which is suitable for a response body Musician: A Musician object which is suitable for a response body
""" """
if (data := await self.db.select_one_by_id(musician_id)) is None: if (data := self.db.select_one_by_id(musician_id)) is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Musician not found" status_code=status.HTTP_404_NOT_FOUND, detail="Musician not found"
) )
@@ -64,7 +64,7 @@ class MusicianController(BaseController):
detail=f"Error creating musician object: {e}", detail=f"Error creating musician object: {e}",
) )
async def update_musician( def update_musician(
self, self,
musician_id: int, musician_id: int,
new_bio: str, new_bio: str,
@@ -83,18 +83,18 @@ class MusicianController(BaseController):
Returns: Returns:
Musician: The updated Musician object Musician: The updated Musician object
""" """
musician = await self.get_musician(musician_id) musician = self.get_musician(musician_id)
if new_bio != musician.bio: if new_bio != musician.bio:
return await self._update_musician_bio(musician, new_bio) return self._update_musician_bio(musician, new_bio)
if file is not None: if file is not None:
return await self._upload_headshot(musician, file) return self._upload_headshot(musician, file)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Update operation not implemented. Neither the bio or headshot was updated.", detail="Update operation not implemented. Neither the bio or headshot was updated.",
) )
async def _update_musician_headshot( def _update_musician_headshot(
self, musician: Musician, headshot_id: str self, musician: Musician, headshot_id: str
) -> Musician: ) -> Musician:
"""Updates a musician's headshot in the database. """Updates a musician's headshot in the database.
@@ -110,15 +110,15 @@ class MusicianController(BaseController):
Musician: The updated Musician object Musician: The updated Musician object
""" """
try: try:
await self.db.update_headshot(musician, headshot_id) self.db.update_headshot(musician, headshot_id)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error updating musician headshot: {e}", detail=f"Error updating musician headshot: {e}",
) )
return await self.get_musician(musician.id) return self.get_musician(musician.id)
async def _update_musician_bio(self, musician: Musician, bio: str) -> Musician: def _update_musician_bio(self, musician: Musician, bio: str) -> Musician:
"""Updates a musician's bio in the database. """Updates a musician's bio in the database.
Args: Args:
@@ -132,15 +132,15 @@ class MusicianController(BaseController):
Musician: The updated Musician object Musician: The updated Musician object
""" """
try: try:
await self.db.update_bio(musician, bio) self.db.update_bio(musician, bio)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error updating musician bio: {e}", detail=f"Error updating musician bio: {e}",
) )
return await self.get_musician(musician.id) return self.get_musician(musician.id)
async def _upload_headshot(self, musician: Musician, file: UploadFile) -> Musician: def _upload_headshot(self, musician: Musician, file: UploadFile) -> Musician:
"""Uploads a new headshot for a musician and updates the database with the new public ID. """Uploads a new headshot for a musician and updates the database with the new public ID.
Args: Args:
@@ -153,7 +153,7 @@ class MusicianController(BaseController):
Returns: Returns:
Musician: The updated Musician object Musician: The updated Musician object
""" """
image_file = await self.verify_image(file) image_file = self.verify_image(file)
data = uploader.upload(image_file) data = uploader.upload(image_file)
public_id = data.get("public_id") public_id = data.get("public_id")
if public_id is None: if public_id is None:
@@ -161,6 +161,6 @@ class MusicianController(BaseController):
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to upload image", detail="Failed to upload image",
) )
await self._update_musician_headshot(musician, public_id) self._update_musician_headshot(musician, public_id)
return await self.get_musician(musician.id) return self.get_musician(musician.id)

View File

@@ -13,8 +13,8 @@ class UserController(BaseController):
super().__init__() super().__init__()
self.db: UserQueries = user_queries self.db: UserQueries = user_queries
async def get_users(self) -> list[User]: def get_users(self) -> list[User]:
data = await self.db.select_all_series() data = self.db.select_all_series()
try: try:
return [User(**e) for e in data] return [User(**e) for e in data]
except Exception as e: except Exception as e:
@@ -23,8 +23,8 @@ class UserController(BaseController):
detail=f"Error creating user objects: {e}", detail=f"Error creating user objects: {e}",
) )
async def get_user_by_id(self, id: int) -> User: def get_user_by_id(self, id: int) -> User:
if (data := await self.db.select_one_by_id(id)) is None: if (data := self.db.select_one_by_id(id)) is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found" status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
) )
@@ -36,8 +36,8 @@ class UserController(BaseController):
detail=f"Error creating user object: {e}", detail=f"Error creating user object: {e}",
) )
async def get_user_by_email(self, email: str) -> User: def get_user_by_email(self, email: str) -> User:
if (data := await self.db.get_one_by_email(email)) is None: if (data := self.db.get_one_by_email(email)) is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User does not exist" status_code=status.HTTP_404_NOT_FOUND, detail="User does not exist"
) )
@@ -49,8 +49,8 @@ class UserController(BaseController):
detail=f"Error creating user object: {e}", detail=f"Error creating user object: {e}",
) )
async def get_user_by_sub(self, sub: str) -> User: def get_user_by_sub(self, sub: str) -> User:
if (data := await self.db.get_one_by_sub(sub)) is None: if (data := self.db.get_one_by_sub(sub)) is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found" status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
) )
@@ -62,9 +62,9 @@ class UserController(BaseController):
detail=f"Error creating user object: {e}", detail=f"Error creating user object: {e}",
) )
async def create_user(self, token: HTTPAuthorizationCredentials) -> User: def create_user(self, token: HTTPAuthorizationCredentials) -> User:
email, sub = oauth_token.email_and_sub(token) email, sub = oauth_token.email_and_sub(token)
user: User = await self.get_user_by_email(email) user: User = self.get_user_by_email(email)
if user.sub is None: if user.sub is None:
await self.db.update_sub(user.email, sub) self.db.update_sub(user.email, sub)
return await self.get_user_by_sub(sub) return self.get_user_by_sub(sub)

View File

@@ -17,7 +17,7 @@ class BaseQueries:
self.table: str = None # type: ignore self.table: str = None # type: ignore
self.connect_db: Callable = connect_db self.connect_db: Callable = connect_db
async def select_all_series(self) -> list[dict]: def select_all_series(self) -> list[dict]:
query = f"SELECT * FROM {self.table}" query = f"SELECT * FROM {self.table}"
db = connect_db() db = connect_db()
cursor = db.cursor(dictionary=True) cursor = db.cursor(dictionary=True)
@@ -27,7 +27,7 @@ class BaseQueries:
db.close() db.close()
return data # type: ignore return data # type: ignore
async def select_one_by_id(self, id: int) -> dict | None: def select_one_by_id(self, id: int) -> dict | None:
query = f"SELECT * FROM {self.table} WHERE id = %s" query = f"SELECT * FROM {self.table} WHERE id = %s"
db = self.connect_db() db = self.connect_db()
cursor = db.cursor(dictionary=True) cursor = db.cursor(dictionary=True)

View File

@@ -17,7 +17,7 @@ class EventQueries(BaseQueries):
super().__init__() super().__init__()
self.table = SERIES_TABLE self.table = SERIES_TABLE
async def select_one_by_id(self, series_id: int) -> list[dict] | None: def select_one_by_id(self, series_id: int) -> list[dict] | None:
query = f""" query = f"""
SELECT s.series_id , s.name , s.description , s.poster_id , e.event_id , e.location , e.`time` , e.ticket_url , e.map_url SELECT s.series_id , s.name , s.description , s.poster_id , e.event_id , e.location , e.`time` , e.ticket_url , e.map_url
FROM {SERIES_TABLE} s FROM {SERIES_TABLE} s
@@ -33,7 +33,7 @@ class EventQueries(BaseQueries):
db.close() db.close()
return data return data
async def select_all_series(self) -> list[dict]: def select_all_series(self) -> list[dict]:
""" """
Queries for all Series and Event info and returns a list of dictionaries. Queries for all Series and Event info and returns a list of dictionaries.
Data is gathered with a LEFT JOIN on the Event table to ensure all Series are returned. Data is gathered with a LEFT JOIN on the Event table to ensure all Series are returned.
@@ -54,7 +54,7 @@ class EventQueries(BaseQueries):
db.close() db.close()
return data return data
async def insert_one_series(self, series: NewEventSeries) -> int: def insert_one_series(self, series: NewEventSeries) -> int:
query = f""" query = f"""
INSERT INTO {self.table} (name, description) INSERT INTO {self.table} (name, description)
VALUES (%s, %s) VALUES (%s, %s)
@@ -74,7 +74,7 @@ class EventQueries(BaseQueries):
db.close() db.close()
return inserted_id return inserted_id
async def insert_one_event(self, event: NewEvent, series_id: int) -> int: def insert_one_event(self, event: NewEvent, series_id: int) -> int:
query = f""" query = f"""
INSERT INTO {EVENT_TABLE} (series_id, location, time, ticket_url, map_url) INSERT INTO {EVENT_TABLE} (series_id, location, time, ticket_url, map_url)
VALUES (%s, %s, %s, %s, %s) VALUES (%s, %s, %s, %s, %s)
@@ -92,7 +92,7 @@ class EventQueries(BaseQueries):
db.close() db.close()
return iserted_id return iserted_id
async def delete_events_by_series(self, series: EventSeries) -> None: def delete_events_by_series(self, series: EventSeries) -> None:
query = f""" query = f"""
DELETE FROM {EVENT_TABLE} DELETE FROM {EVENT_TABLE}
WHERE series_id = %s WHERE series_id = %s
@@ -103,7 +103,7 @@ class EventQueries(BaseQueries):
db.commit() db.commit()
cursor.close() cursor.close()
async def delete_one_series(self, series: EventSeries) -> None: def delete_one_series(self, series: EventSeries) -> None:
query = f""" query = f"""
DELETE FROM {self.table} DELETE FROM {self.table}
WHERE series_id = %s WHERE series_id = %s
@@ -114,7 +114,7 @@ class EventQueries(BaseQueries):
db.commit() db.commit()
cursor.close() cursor.close()
async def update_series_poster(self, series: EventSeries) -> None: def update_series_poster(self, series: EventSeries) -> None:
query = f""" query = f"""
UPDATE {self.table} UPDATE {self.table}
SET poster_id = %s SET poster_id = %s
@@ -126,7 +126,7 @@ class EventQueries(BaseQueries):
db.commit() db.commit()
cursor.close() cursor.close()
async def replace_event(self, event: Event) -> None: def replace_event(self, event: Event) -> None:
query = f""" query = f"""
UPDATE {EVENT_TABLE} UPDATE {EVENT_TABLE}
SET location = %s, time = %s, ticket_url = %s, map_url = %s SET location = %s, time = %s, ticket_url = %s, map_url = %s
@@ -143,7 +143,7 @@ class EventQueries(BaseQueries):
cursor.close() cursor.close()
db.close() db.close()
async def replace_series(self, series: EventSeries) -> None: def replace_series(self, series: EventSeries) -> None:
query = f""" query = f"""
UPDATE {self.table} UPDATE {self.table}
SET name = %s, description = %s, poster_id = %s SET name = %s, description = %s, poster_id = %s

View File

@@ -7,7 +7,7 @@ class GroupQueries(BaseQueries):
super().__init__() super().__init__()
self.table = GROUP_TABLE self.table = GROUP_TABLE
async def select_one_by_id(self) -> dict: def select_one_by_id(self) -> dict:
query = f"SELECT * FROM {self.table}" query = f"SELECT * FROM {self.table}"
db = self.connect_db() db = self.connect_db()
cursor = db.cursor(dictionary=True) cursor = db.cursor(dictionary=True)
@@ -21,12 +21,12 @@ class GroupQueries(BaseQueries):
return data return data
async def select_all_series(self) -> None: def select_all_series(self) -> None:
raise NotImplementedError( raise NotImplementedError(
"get_all method not implemented for GroupQueries. There's only one row in the table." "get_all method not implemented for GroupQueries. There's only one row in the table."
) )
async def update_group_bio(self, bio: str) -> None: def update_group_bio(self, bio: str) -> None:
db = self.connect_db() db = self.connect_db()
cursor = db.cursor() cursor = db.cursor()
query = f"UPDATE {self.table} SET bio = %s WHERE id = 1" # only one row in the table query = f"UPDATE {self.table} SET bio = %s WHERE id = 1" # only one row in the table

View File

@@ -11,7 +11,7 @@ class MusicianQueries(BaseQueries):
super().__init__() super().__init__()
self.table = MUSICIAN_TABLE self.table = MUSICIAN_TABLE
async def update_bio(self, musician: Musician, bio: str) -> None: def update_bio(self, musician: Musician, bio: str) -> None:
"""Updates a musician's biography in the database. """Updates a musician's biography in the database.
Args: Args:
@@ -26,7 +26,7 @@ class MusicianQueries(BaseQueries):
cursor.close() cursor.close()
db.close() db.close()
async def update_headshot(self, musician: Musician, headshot_id: str) -> None: def update_headshot(self, musician: Musician, headshot_id: str) -> None:
"""Updates a musician's headshot ID in the database. """Updates a musician's headshot ID in the database.
The image itself is stored with Cloudinary. The image itself is stored with Cloudinary.

View File

@@ -7,7 +7,7 @@ class UserQueries(BaseQueries):
super().__init__() super().__init__()
self.table = USER_TABLE self.table = USER_TABLE
async def get_one_by_email(self, email: str) -> dict | None: def get_one_by_email(self, email: str) -> dict | None:
query = f"SELECT * FROM {self.table} WHERE email = %s" query = f"SELECT * FROM {self.table} WHERE email = %s"
db = self.connect_db() db = self.connect_db()
cursor = db.cursor(dictionary=True) cursor = db.cursor(dictionary=True)
@@ -18,7 +18,7 @@ class UserQueries(BaseQueries):
return data return data
async def get_one_by_sub(self, sub: str) -> dict | None: def get_one_by_sub(self, sub: str) -> dict | None:
query = f"SELECT * FROM {self.table} WHERE sub = %s" query = f"SELECT * FROM {self.table} WHERE sub = %s"
db = self.connect_db() db = self.connect_db()
cursor = db.cursor(dictionary=True) cursor = db.cursor(dictionary=True)
@@ -32,7 +32,7 @@ class UserQueries(BaseQueries):
return data return data
async def update_sub(self, email: str, sub: str) -> None: def update_sub(self, email: str, sub: str) -> None:
query = f"UPDATE {self.table} SET sub = %s WHERE email = %s" query = f"UPDATE {self.table} SET sub = %s WHERE email = %s"
db = self.connect_db() db = self.connect_db()
cursor = db.cursor() cursor = db.cursor()

View File

@@ -22,23 +22,21 @@ def test_type():
assert isinstance(ec, EventController) assert isinstance(ec, EventController)
@pytest.mark.asyncio def test_all_series_with_no_data():
async def test_all_series_with_no_data():
"""Tests with absent data.""" """Tests with absent data."""
async def no_series() -> list[dict]: def no_series() -> list[dict]:
return [] return []
mock_queries.select_all_series = no_series mock_queries.select_all_series = no_series
result = await ec.get_all_series() result = ec.get_all_series()
assert result == [] assert result == []
@pytest.mark.asyncio def test_all_series_with_basic_data():
async def test_all_series_with_basic_data():
"""Tests a single valid row with no event info""" """Tests a single valid row with no event info"""
async def one_series_with_no_events() -> list[dict]: def one_series_with_no_events() -> list[dict]:
return [ return [
{ {
"name": "Test Series", "name": "Test Series",
@@ -49,7 +47,7 @@ async def test_all_series_with_basic_data():
] ]
mock_queries.select_all_series = one_series_with_no_events mock_queries.select_all_series = one_series_with_no_events
result = await ec.get_all_series() result = ec.get_all_series()
assert isinstance(result, list) assert isinstance(result, list)
assert len(result) == 1 assert len(result) == 1
series = result[0] series = result[0]
@@ -61,8 +59,7 @@ async def test_all_series_with_basic_data():
assert series.events == [] assert series.events == []
@pytest.mark.asyncio def test_all_series_with_detailed_data():
async def test_all_series_with_detailed_data():
"""Tests a single valid row with event info""" """Tests a single valid row with event info"""
series_id = 1 series_id = 1
@@ -70,7 +67,7 @@ async def test_all_series_with_detailed_data():
series_description = "Pieces by Danzi, Gomez, Gounod, Grant, and Rusnak!" series_description = "Pieces by Danzi, Gomez, Gounod, Grant, and Rusnak!"
poster_id = "The_Grapefruits_Present_qhng6y" poster_id = "The_Grapefruits_Present_qhng6y"
async def one_series_with_events() -> list[dict]: def one_series_with_events() -> list[dict]:
row_1 = { row_1 = {
"series_id": series_id, "series_id": series_id,
@@ -108,7 +105,7 @@ async def test_all_series_with_detailed_data():
] ]
mock_queries.select_all_series = one_series_with_events mock_queries.select_all_series = one_series_with_events
result = await ec.get_all_series() result = ec.get_all_series()
assert isinstance(result, list) assert isinstance(result, list)
assert len(result) == 1 assert len(result) == 1
series = result[0] series = result[0]
@@ -148,11 +145,10 @@ async def test_all_series_with_detailed_data():
assert e3.time == datetime(2024, 6, 23, 15, 0) assert e3.time == datetime(2024, 6, 23, 15, 0)
@pytest.mark.asyncio def test_all_series_with_many_series():
async def test_all_series_with_many_series():
"""Tests multiple series with no events.""" """Tests multiple series with no events."""
async def many_series() -> list[dict]: def many_series() -> list[dict]:
return [ return [
{ {
"name": "Test Series", "name": "Test Series",
@@ -175,7 +171,7 @@ async def test_all_series_with_many_series():
] ]
mock_queries.select_all_series = many_series mock_queries.select_all_series = many_series
result = await ec.get_all_series() result = ec.get_all_series()
assert isinstance(result, list) assert isinstance(result, list)
assert len(result) == 3 assert len(result) == 3
for series in result: for series in result:
@@ -183,14 +179,13 @@ async def test_all_series_with_many_series():
assert series.events == [] assert series.events == []
@pytest.mark.asyncio def test_all_series_with_error():
async def test_all_series_with_error():
"""Tests an error during the retrieval process.""" """Tests an error during the retrieval process."""
mock_log_error = Mock() mock_log_error = Mock()
ec.log_error = mock_log_error ec.log_error = mock_log_error
async def invalid_series() -> list[dict]: def invalid_series() -> list[dict]:
# no series name # no series name
return [ return [
{ {
@@ -202,13 +197,12 @@ async def test_all_series_with_error():
mock_queries.select_all_series = invalid_series mock_queries.select_all_series = invalid_series
with pytest.raises(Exception): with pytest.raises(Exception):
await ec.get_all_series() ec.get_all_series()
Mock.assert_called_once(mock_log_error) Mock.assert_called_once(mock_log_error)
@pytest.mark.asyncio def test_one_series():
async def test_one_series(): def one_series_no_events(series_id: int) -> list[dict]:
async def one_series_no_events(series_id: int) -> list[dict]:
return [ return [
{ {
"series_id": series_id, "series_id": series_id,
@@ -219,7 +213,7 @@ async def test_one_series():
] ]
mock_queries.select_one_by_id = one_series_no_events mock_queries.select_one_by_id = one_series_no_events
series = await ec.get_one_series_by_id(1) series = ec.get_one_series_by_id(1)
assert isinstance(series, EventSeries) assert isinstance(series, EventSeries)
assert series.name == "Test Series" assert series.name == "Test Series"
assert series.description == "Test Description" assert series.description == "Test Description"
@@ -228,9 +222,8 @@ async def test_one_series():
assert series.events == [] assert series.events == []
@pytest.mark.asyncio def test_one_series_with_events():
async def test_one_series_with_events(): def one_series_with_events(series_id: int) -> list[dict]:
async def one_series_with_events(series_id: int) -> list[dict]:
return [ return [
{ {
"series_id": series_id, "series_id": series_id,
@@ -264,7 +257,7 @@ async def test_one_series_with_events():
] ]
mock_queries.select_one_by_id = one_series_with_events mock_queries.select_one_by_id = one_series_with_events
series = await ec.get_one_series_by_id(1) series = ec.get_one_series_by_id(1)
assert isinstance(series, EventSeries) assert isinstance(series, EventSeries)
assert series.name == "Test Series" assert series.name == "Test Series"
assert series.description == "Test Description" assert series.description == "Test Description"

View File

@@ -35,15 +35,15 @@ bad_data = [
] ]
async def mock_select_all_series(): def mock_select_all_series():
return sample_data return sample_data
async def mock_select_all_series_sad(): def mock_select_all_series_sad():
return bad_data return bad_data
async def mock_select_one_by_id(musician_id: int): def mock_select_one_by_id(musician_id: int):
for musician in sample_data: for musician in sample_data:
if musician.get("id") == musician_id: if musician.get("id") == musician_id:
return musician return musician
@@ -67,9 +67,8 @@ TODO: write tests for following methods:
""" """
@pytest.mark.asyncio def test_happy_get_musicians():
async def test_happy_get_musicians(): musicians = ec.get_musicians()
musicians = await ec.get_musicians()
assert isinstance(musicians, list) assert isinstance(musicians, list)
assert len(musicians) == 2 assert len(musicians) == 2
for musician in musicians: for musician in musicians:
@@ -85,19 +84,17 @@ async def test_happy_get_musicians():
assert m2.headshot_id == "headshotABC" assert m2.headshot_id == "headshotABC"
@pytest.mark.asyncio def test_sad_get_musicians():
async def test_sad_get_musicians():
mock_queries.select_all_series = mock_select_all_series_sad mock_queries.select_all_series = mock_select_all_series_sad
with pytest.raises(HTTPException) as e: with pytest.raises(HTTPException) as e:
await ec.get_musicians() ec.get_musicians()
mock_queries.select_all_series = mock_select_all_series mock_queries.select_all_series = mock_select_all_series
assert isinstance(e.value, HTTPException) assert isinstance(e.value, HTTPException)
assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@pytest.mark.asyncio def test_happy_get_musician():
async def test_happy_get_musician(): musician = ec.get_musician(1)
musician = await ec.get_musician(1)
assert isinstance(musician, Musician) assert isinstance(musician, Musician)
assert musician.id == 1 assert musician.id == 1
assert musician.name == "John Doe" assert musician.name == "John Doe"
@@ -105,10 +102,9 @@ async def test_happy_get_musician():
assert musician.headshot_id == "headshot123" assert musician.headshot_id == "headshot123"
@pytest.mark.asyncio def test_musician_not_found():
async def test_musician_not_found():
with pytest.raises(HTTPException) as e: with pytest.raises(HTTPException) as e:
await ec.get_musician(3) ec.get_musician(3)
assert isinstance(e.value, HTTPException) assert isinstance(e.value, HTTPException)
assert e.value.status_code == status.HTTP_404_NOT_FOUND assert e.value.status_code == status.HTTP_404_NOT_FOUND
assert e.value.detail == "Musician not found" assert e.value.detail == "Musician not found"