From 5bb28ed9237a14aa4185ec8c39fd33f899ef5e70 Mon Sep 17 00:00:00 2001 From: Lucas Jensen Date: Fri, 3 May 2024 13:21:47 -0700 Subject: [PATCH 1/2] removed many unneeded async definitions. The main controller remains async --- server/app/controllers/base_controller.py | 5 +- server/app/controllers/controller.py | 42 ++++++++-------- server/app/controllers/events.py | 50 +++++++++---------- server/app/controllers/group.py | 10 ++-- server/app/controllers/musicians.py | 36 ++++++------- server/app/controllers/users.py | 24 ++++----- server/app/db/base_queries.py | 4 +- server/app/db/events.py | 18 +++---- server/app/db/group.py | 6 +-- server/app/db/musicians.py | 4 +- server/app/db/users.py | 6 +-- server/app/main.py | 2 - server/app/routers/users.py | 9 ++++ .../controllers/test_event_controller.py | 49 ++++++++---------- .../controllers/test_musician_controller.py | 26 ++++------ 15 files changed, 144 insertions(+), 147 deletions(-) diff --git a/server/app/controllers/base_controller.py b/server/app/controllers/base_controller.py index dc40dfe..b21079b 100644 --- a/server/app/controllers/base_controller.py +++ b/server/app/controllers/base_controller.py @@ -21,7 +21,7 @@ class BaseController: self.ALL_FILES = ALLOWED_FILES_TYPES self.MAX_FILE_SIZE = ONE_MB - async def verify_image(self, file: UploadFile) -> bytes: + def verify_image(self, file: UploadFile) -> bytes: """Verifies that the file is an image and is within the maximum file size. Args: @@ -38,7 +38,8 @@ class BaseController: status_code=status.HTTP_400_BAD_REQUEST, detail=f"File type {file.content_type} not allowed. Allowed file types are {self.ALL_FILES}", ) - image_file = await file.read() + with file.file as f: + image_file = f.read() if len(image_file) > self.MAX_FILE_SIZE: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/server/app/controllers/controller.py b/server/app/controllers/controller.py index ea5875b..6296541 100644 --- a/server/app/controllers/controller.py +++ b/server/app/controllers/controller.py @@ -19,6 +19,8 @@ class MainController: All methods are either pass-throughs to the appropriate controller or are used to coordinate multiple controllers. + All methods are asynchronous to facilitate asynchronous calls from the Router layer. + token-based authentication is handled here as needed per the nature of the data being accessed. """ @@ -29,10 +31,10 @@ class MainController: self.group_controller = GroupController() async def get_musicians(self) -> list[Musician]: - return await self.musician_controller.get_musicians() + return self.musician_controller.get_musicians() async def get_musician(self, id: int) -> Musician: - return await self.musician_controller.get_musician(id) + return self.musician_controller.get_musician(id) async def update_musician( self, @@ -48,60 +50,60 @@ class MainController: detail="ID in URL does not match ID in request body", ) _, sub = oauth_token.email_and_sub(token) - await self.user_controller.get_user_by_sub(sub) - return await self.musician_controller.update_musician( + self.user_controller.get_user_by_sub(sub) + return self.musician_controller.update_musician( musician_id=musician.id, new_bio=musician.bio, file=file, ) async def get_events(self) -> list[EventSeries]: - return await self.event_controller.get_all_series() + return self.event_controller.get_all_series() async def get_event(self, id: int) -> EventSeries: - return await self.event_controller.get_one_series_by_id(id) + return self.event_controller.get_one_series_by_id(id) async def create_event( self, series: NewEventSeries, token: HTTPAuthorizationCredentials ) -> EventSeries: _, sub = oauth_token.email_and_sub(token) - await self.user_controller.get_user_by_sub(sub) - return await self.event_controller.create_series(series) + self.user_controller.get_user_by_sub(sub) + return self.event_controller.create_series(series) async def add_series_poster( self, series_id: int, poster: UploadFile, token: HTTPAuthorizationCredentials ) -> EventSeries: _, sub = oauth_token.email_and_sub(token) - await self.user_controller.get_user_by_sub(sub) - return await self.event_controller.add_series_poster(series_id, poster) + self.user_controller.get_user_by_sub(sub) + return self.event_controller.add_series_poster(series_id, poster) async def delete_series(self, id: int, token: HTTPAuthorizationCredentials) -> None: _, sub = oauth_token.email_and_sub(token) - await self.user_controller.get_user_by_sub(sub) - await self.event_controller.delete_series(id) + self.user_controller.get_user_by_sub(sub) + self.event_controller.delete_series(id) async def update_series( self, route_id: int, series: EventSeries, token: HTTPAuthorizationCredentials ) -> EventSeries: _, sub = oauth_token.email_and_sub(token) - await self.user_controller.get_user_by_sub(sub) - return await self.event_controller.update_series(route_id, series) + self.user_controller.get_user_by_sub(sub) + return self.event_controller.update_series(route_id, series) async def get_users(self) -> list[User]: - return await self.user_controller.get_users() + return self.user_controller.get_users() async def get_user(self, id: int) -> User: - return await self.user_controller.get_user_by_id(id) + return self.user_controller.get_user_by_id(id) async def create_user(self, token: HTTPAuthorizationCredentials) -> User: - return await self.user_controller.create_user(token) + return self.user_controller.create_user(token) async def get_group(self) -> Group: - return await self.group_controller.get_group() + return self.group_controller.get_group() async def update_group_bio( self, bio: str, token: HTTPAuthorizationCredentials ) -> Group: _, sub = oauth_token.email_and_sub(token) - await self.user_controller.get_user_by_sub(sub) - return await self.group_controller.update_group_bio(bio) + self.user_controller.get_user_by_sub(sub) + return self.group_controller.update_group_bio(bio) diff --git a/server/app/controllers/events.py b/server/app/controllers/events.py index 99cdef3..69eee4f 100644 --- a/server/app/controllers/events.py +++ b/server/app/controllers/events.py @@ -43,7 +43,7 @@ class EventController(BaseController): return all_series - async def get_all_series(self) -> list[EventSeries]: + def get_all_series(self) -> list[EventSeries]: """Retrieves all EventSeries objects from the database and returns them as a list. Raises: @@ -52,7 +52,7 @@ class EventController(BaseController): Returns: list[EventSeries]: A list of EventSeries objects which are suitable for a response body """ - series_data = await self.db.select_all_series() + series_data = self.db.select_all_series() try: return [series for series in self._all_series(series_data).values()] @@ -63,7 +63,7 @@ class EventController(BaseController): detail=f"Error retrieving event objects: {e}", ) - async def get_one_series_by_id(self, series_id: int) -> EventSeries: + def get_one_series_by_id(self, series_id: int) -> EventSeries: """Builds and returns a single EventSeries object by its numeric ID. Args: @@ -76,7 +76,7 @@ class EventController(BaseController): Returns: EventSeries: A single EventSeries object """ - if not (rows := await self.db.select_one_by_id(series_id)): + if not (rows := self.db.select_one_by_id(series_id)): raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Event not found" ) @@ -90,7 +90,7 @@ class EventController(BaseController): detail=f"Error creating event object: {e}", ) - async def create_series(self, series: NewEventSeries) -> EventSeries: + def create_series(self, series: NewEventSeries) -> EventSeries: """Takes a NewEventSeries object and passes it to the database layer for insertion. Args: @@ -103,19 +103,17 @@ class EventController(BaseController): EventSeries: The newly created EventSeries object with an ID """ try: - inserted_id = await self.db.insert_one_series(series) + inserted_id = self.db.insert_one_series(series) for new_event in series.events: - await self.db.insert_one_event(new_event, inserted_id) - return await self.get_one_series_by_id(inserted_id) + self.db.insert_one_event(new_event, inserted_id) + return self.get_one_series_by_id(inserted_id) except IntegrityError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Series name already exists. Each series must have a unique name.\n{e}", ) - async def add_series_poster( - self, series_id: int, poster: UploadFile - ) -> EventSeries: + def add_series_poster(self, series_id: int, poster: UploadFile) -> EventSeries: """Adds (or updates) a poster image to a series. Args: @@ -125,12 +123,12 @@ class EventController(BaseController): Returns: EventSeries: The updated EventSeries object """ - series = await self.get_one_series_by_id(series_id) - series.poster_id = await self._upload_poster(poster) - await self.db.update_series_poster(series) - return await self.get_one_series_by_id(series.series_id) + series = self.get_one_series_by_id(series_id) + series.poster_id = self._upload_poster(poster) + self.db.update_series_poster(series) + return self.get_one_series_by_id(series.series_id) - async def _upload_poster(self, poster: UploadFile) -> str: + def _upload_poster(self, poster: UploadFile) -> str: """Uploads a poster image to Cloudinary and returns the public ID for storage in the database. Should only be used internally. @@ -143,7 +141,7 @@ class EventController(BaseController): Returns: str: The public ID of the uploaded image """ - image_file = await self.verify_image(poster) + image_file = self.verify_image(poster) try: data = uploader.upload(image_file) return data.get("public_id") @@ -153,16 +151,16 @@ class EventController(BaseController): detail=f"Error uploading image: {e}", ) - async def delete_series(self, id: int) -> None: + def delete_series(self, id: int) -> None: """Ensures an EventSeries object exists and then deletes it from the database. Args: id (int): The numeric ID of the series to delete """ - series = await self.get_one_series_by_id(id) - await self.db.delete_one_series(series) + series = self.get_one_series_by_id(id) + self.db.delete_one_series(series) - async def update_series(self, route_id: int, series: EventSeries) -> EventSeries: + def update_series(self, route_id: int, series: EventSeries) -> EventSeries: """Updates an existing EventSeries object in the database. Args: @@ -181,14 +179,14 @@ class EventController(BaseController): status_code=status.HTTP_400_BAD_REQUEST, detail="ID in URL does not match ID in request body", ) - prev_series = await self.get_one_series_by_id(series.series_id) + prev_series = self.get_one_series_by_id(series.series_id) if series.poster_id != prev_series.poster_id: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Poster ID cannot be updated directly. Use the /poster endpoint instead.", ) - await self.db.delete_events_by_series(series) - await self.db.replace_series(series) + self.db.delete_events_by_series(series) + self.db.replace_series(series) for event in series.events: - await self.db.insert_one_event(event, series.series_id) - return await self.get_one_series_by_id(series.series_id) + self.db.insert_one_event(event, series.series_id) + return self.get_one_series_by_id(series.series_id) diff --git a/server/app/controllers/group.py b/server/app/controllers/group.py index 1e3dba4..fdc37ff 100644 --- a/server/app/controllers/group.py +++ b/server/app/controllers/group.py @@ -11,8 +11,8 @@ class GroupController(BaseController): super().__init__() self.db: GroupQueries = group_queries - async def get_group(self) -> Group: - if (data := await self.db.select_one_by_id()) is None: + def get_group(self) -> Group: + if (data := self.db.select_one_by_id()) is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Group not found" ) @@ -24,12 +24,12 @@ class GroupController(BaseController): detail=f"Error creating group object: {e}", ) - async def update_group_bio(self, bio: str) -> Group: + def update_group_bio(self, bio: str) -> Group: try: - await self.db.update_group_bio(bio) + self.db.update_group_bio(bio) except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error updating group bio: {e}", ) - return await self.get_group() + return self.get_group() diff --git a/server/app/controllers/musicians.py b/server/app/controllers/musicians.py index 3d9f583..7f843c8 100644 --- a/server/app/controllers/musicians.py +++ b/server/app/controllers/musicians.py @@ -21,7 +21,7 @@ class MusicianController(BaseController): super().__init__() self.db: MusicianQueries = musician_queries - async def get_musicians(self) -> list[Musician]: + def get_musicians(self) -> list[Musician]: """Retrieves all musicians from the database and returns them as a list of Musician objects. Raises: @@ -30,7 +30,7 @@ class MusicianController(BaseController): Returns: list[Musician]: A list of Musician objects which are suitable for a response body """ - data = await self.db.select_all_series() + data = self.db.select_all_series() try: return [Musician(**m) for m in data] except Exception as e: @@ -39,7 +39,7 @@ class MusicianController(BaseController): detail=f"Error creating musician objects: {e}", ) - async def get_musician(self, musician_id: int) -> Musician: + def get_musician(self, musician_id: int) -> Musician: """Retrieves a single musician from the database and returns it as a Musician object. Args: @@ -52,7 +52,7 @@ class MusicianController(BaseController): Returns: Musician: A Musician object which is suitable for a response body """ - if (data := await self.db.select_one_by_id(musician_id)) is None: + if (data := self.db.select_one_by_id(musician_id)) is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Musician not found" ) @@ -64,7 +64,7 @@ class MusicianController(BaseController): detail=f"Error creating musician object: {e}", ) - async def update_musician( + def update_musician( self, musician_id: int, new_bio: str, @@ -83,18 +83,18 @@ class MusicianController(BaseController): Returns: Musician: The updated Musician object """ - musician = await self.get_musician(musician_id) + musician = self.get_musician(musician_id) if new_bio != musician.bio: - return await self._update_musician_bio(musician, new_bio) + return self._update_musician_bio(musician, new_bio) if file is not None: - return await self._upload_headshot(musician, file) + return self._upload_headshot(musician, file) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Update operation not implemented. Neither the bio or headshot was updated.", ) - async def _update_musician_headshot( + def _update_musician_headshot( self, musician: Musician, headshot_id: str ) -> Musician: """Updates a musician's headshot in the database. @@ -110,15 +110,15 @@ class MusicianController(BaseController): Musician: The updated Musician object """ try: - await self.db.update_headshot(musician, headshot_id) + self.db.update_headshot(musician, headshot_id) except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error updating musician headshot: {e}", ) - return await self.get_musician(musician.id) + return self.get_musician(musician.id) - async def _update_musician_bio(self, musician: Musician, bio: str) -> Musician: + def _update_musician_bio(self, musician: Musician, bio: str) -> Musician: """Updates a musician's bio in the database. Args: @@ -132,15 +132,15 @@ class MusicianController(BaseController): Musician: The updated Musician object """ try: - await self.db.update_bio(musician, bio) + self.db.update_bio(musician, bio) except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error updating musician bio: {e}", ) - return await self.get_musician(musician.id) + return self.get_musician(musician.id) - async def _upload_headshot(self, musician: Musician, file: UploadFile) -> Musician: + def _upload_headshot(self, musician: Musician, file: UploadFile) -> Musician: """Uploads a new headshot for a musician and updates the database with the new public ID. Args: @@ -153,7 +153,7 @@ class MusicianController(BaseController): Returns: Musician: The updated Musician object """ - image_file = await self.verify_image(file) + image_file = self.verify_image(file) data = uploader.upload(image_file) public_id = data.get("public_id") if public_id is None: @@ -161,6 +161,6 @@ class MusicianController(BaseController): status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to upload image", ) - await self._update_musician_headshot(musician, public_id) + self._update_musician_headshot(musician, public_id) - return await self.get_musician(musician.id) + return self.get_musician(musician.id) diff --git a/server/app/controllers/users.py b/server/app/controllers/users.py index df59531..efc8a17 100644 --- a/server/app/controllers/users.py +++ b/server/app/controllers/users.py @@ -13,8 +13,8 @@ class UserController(BaseController): super().__init__() self.db: UserQueries = user_queries - async def get_users(self) -> list[User]: - data = await self.db.select_all_series() + def get_users(self) -> list[User]: + data = self.db.select_all_series() try: return [User(**e) for e in data] except Exception as e: @@ -23,8 +23,8 @@ class UserController(BaseController): detail=f"Error creating user objects: {e}", ) - async def get_user_by_id(self, id: int) -> User: - if (data := await self.db.select_one_by_id(id)) is None: + def get_user_by_id(self, id: int) -> User: + if (data := self.db.select_one_by_id(id)) is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="User not found" ) @@ -36,8 +36,8 @@ class UserController(BaseController): detail=f"Error creating user object: {e}", ) - async def get_user_by_email(self, email: str) -> User: - if (data := await self.db.get_one_by_email(email)) is None: + def get_user_by_email(self, email: str) -> User: + if (data := self.db.get_one_by_email(email)) is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="User does not exist" ) @@ -49,8 +49,8 @@ class UserController(BaseController): detail=f"Error creating user object: {e}", ) - async def get_user_by_sub(self, sub: str) -> User: - if (data := await self.db.get_one_by_sub(sub)) is None: + def get_user_by_sub(self, sub: str) -> User: + if (data := self.db.get_one_by_sub(sub)) is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="User not found" ) @@ -62,9 +62,9 @@ class UserController(BaseController): detail=f"Error creating user object: {e}", ) - async def create_user(self, token: HTTPAuthorizationCredentials) -> User: + def create_user(self, token: HTTPAuthorizationCredentials) -> User: email, sub = oauth_token.email_and_sub(token) - user: User = await self.get_user_by_email(email) + user: User = self.get_user_by_email(email) if user.sub is None: - await self.db.update_sub(user.email, sub) - return await self.get_user_by_sub(sub) + self.db.update_sub(user.email, sub) + return self.get_user_by_sub(sub) diff --git a/server/app/db/base_queries.py b/server/app/db/base_queries.py index e2ec3ed..18344ac 100644 --- a/server/app/db/base_queries.py +++ b/server/app/db/base_queries.py @@ -17,7 +17,7 @@ class BaseQueries: self.table: str = None # type: ignore self.connect_db: Callable = connect_db - async def select_all_series(self) -> list[dict]: + def select_all_series(self) -> list[dict]: query = f"SELECT * FROM {self.table}" db = connect_db() cursor = db.cursor(dictionary=True) @@ -27,7 +27,7 @@ class BaseQueries: db.close() return data # type: ignore - async def select_one_by_id(self, id: int) -> dict | None: + def select_one_by_id(self, id: int) -> dict | None: query = f"SELECT * FROM {self.table} WHERE id = %s" db = self.connect_db() cursor = db.cursor(dictionary=True) diff --git a/server/app/db/events.py b/server/app/db/events.py index 09acff2..e2930be 100644 --- a/server/app/db/events.py +++ b/server/app/db/events.py @@ -17,7 +17,7 @@ class EventQueries(BaseQueries): super().__init__() self.table = SERIES_TABLE - async def select_one_by_id(self, series_id: int) -> list[dict] | None: + def select_one_by_id(self, series_id: int) -> list[dict] | None: query = f""" SELECT s.series_id , s.name , s.description , s.poster_id , e.event_id , e.location , e.`time` , e.ticket_url , e.map_url FROM {SERIES_TABLE} s @@ -33,7 +33,7 @@ class EventQueries(BaseQueries): db.close() return data - async def select_all_series(self) -> list[dict]: + def select_all_series(self) -> list[dict]: """ Queries for all Series and Event info and returns a list of dictionaries. Data is gathered with a LEFT JOIN on the Event table to ensure all Series are returned. @@ -54,7 +54,7 @@ class EventQueries(BaseQueries): db.close() return data - async def insert_one_series(self, series: NewEventSeries) -> int: + def insert_one_series(self, series: NewEventSeries) -> int: query = f""" INSERT INTO {self.table} (name, description) VALUES (%s, %s) @@ -74,7 +74,7 @@ class EventQueries(BaseQueries): db.close() return inserted_id - async def insert_one_event(self, event: NewEvent, series_id: int) -> int: + def insert_one_event(self, event: NewEvent, series_id: int) -> int: query = f""" INSERT INTO {EVENT_TABLE} (series_id, location, time, ticket_url, map_url) VALUES (%s, %s, %s, %s, %s) @@ -92,7 +92,7 @@ class EventQueries(BaseQueries): db.close() return iserted_id - async def delete_events_by_series(self, series: EventSeries) -> None: + def delete_events_by_series(self, series: EventSeries) -> None: query = f""" DELETE FROM {EVENT_TABLE} WHERE series_id = %s @@ -103,7 +103,7 @@ class EventQueries(BaseQueries): db.commit() cursor.close() - async def delete_one_series(self, series: EventSeries) -> None: + def delete_one_series(self, series: EventSeries) -> None: query = f""" DELETE FROM {self.table} WHERE series_id = %s @@ -114,7 +114,7 @@ class EventQueries(BaseQueries): db.commit() cursor.close() - async def update_series_poster(self, series: EventSeries) -> None: + def update_series_poster(self, series: EventSeries) -> None: query = f""" UPDATE {self.table} SET poster_id = %s @@ -126,7 +126,7 @@ class EventQueries(BaseQueries): db.commit() cursor.close() - async def replace_event(self, event: Event) -> None: + def replace_event(self, event: Event) -> None: query = f""" UPDATE {EVENT_TABLE} SET location = %s, time = %s, ticket_url = %s, map_url = %s @@ -143,7 +143,7 @@ class EventQueries(BaseQueries): cursor.close() db.close() - async def replace_series(self, series: EventSeries) -> None: + def replace_series(self, series: EventSeries) -> None: query = f""" UPDATE {self.table} SET name = %s, description = %s, poster_id = %s diff --git a/server/app/db/group.py b/server/app/db/group.py index 18ba764..f3e618d 100644 --- a/server/app/db/group.py +++ b/server/app/db/group.py @@ -7,7 +7,7 @@ class GroupQueries(BaseQueries): super().__init__() self.table = GROUP_TABLE - async def select_one_by_id(self) -> dict: + def select_one_by_id(self) -> dict: query = f"SELECT * FROM {self.table}" db = self.connect_db() cursor = db.cursor(dictionary=True) @@ -21,12 +21,12 @@ class GroupQueries(BaseQueries): return data - async def select_all_series(self) -> None: + def select_all_series(self) -> None: raise NotImplementedError( "get_all method not implemented for GroupQueries. There's only one row in the table." ) - async def update_group_bio(self, bio: str) -> None: + def update_group_bio(self, bio: str) -> None: db = self.connect_db() cursor = db.cursor() query = f"UPDATE {self.table} SET bio = %s WHERE id = 1" # only one row in the table diff --git a/server/app/db/musicians.py b/server/app/db/musicians.py index 90627a4..e0ae079 100644 --- a/server/app/db/musicians.py +++ b/server/app/db/musicians.py @@ -11,7 +11,7 @@ class MusicianQueries(BaseQueries): super().__init__() self.table = MUSICIAN_TABLE - async def update_bio(self, musician: Musician, bio: str) -> None: + def update_bio(self, musician: Musician, bio: str) -> None: """Updates a musician's biography in the database. Args: @@ -26,7 +26,7 @@ class MusicianQueries(BaseQueries): cursor.close() db.close() - async def update_headshot(self, musician: Musician, headshot_id: str) -> None: + def update_headshot(self, musician: Musician, headshot_id: str) -> None: """Updates a musician's headshot ID in the database. The image itself is stored with Cloudinary. diff --git a/server/app/db/users.py b/server/app/db/users.py index 6b59e0b..2367ff3 100644 --- a/server/app/db/users.py +++ b/server/app/db/users.py @@ -7,7 +7,7 @@ class UserQueries(BaseQueries): super().__init__() self.table = USER_TABLE - async def get_one_by_email(self, email: str) -> dict | None: + def get_one_by_email(self, email: str) -> dict | None: query = f"SELECT * FROM {self.table} WHERE email = %s" db = self.connect_db() cursor = db.cursor(dictionary=True) @@ -18,7 +18,7 @@ class UserQueries(BaseQueries): return data - async def get_one_by_sub(self, sub: str) -> dict | None: + def get_one_by_sub(self, sub: str) -> dict | None: query = f"SELECT * FROM {self.table} WHERE sub = %s" db = self.connect_db() cursor = db.cursor(dictionary=True) @@ -32,7 +32,7 @@ class UserQueries(BaseQueries): return data - async def update_sub(self, email: str, sub: str) -> None: + def update_sub(self, email: str, sub: str) -> None: query = f"UPDATE {self.table} SET sub = %s WHERE email = %s" db = self.connect_db() cursor = db.cursor() diff --git a/server/app/main.py b/server/app/main.py index 6bfae60..9b0371d 100644 --- a/server/app/main.py +++ b/server/app/main.py @@ -9,7 +9,6 @@ from app.routers.contact import router as contact_router from app.routers.events import router as event_router from app.routers.group import router as group_router from app.routers.musicians import router as musician_router -from app.routers.users import router as user_router from app.scripts.version import get_version app = FastAPI( @@ -21,7 +20,6 @@ app.include_router(musician_router) app.include_router(group_router) app.include_router(contact_router) app.include_router(event_router) -app.include_router(user_router) controller = MainController() diff --git a/server/app/routers/users.py b/server/app/routers/users.py index c2c1dc6..c44bcc3 100644 --- a/server/app/routers/users.py +++ b/server/app/routers/users.py @@ -12,6 +12,15 @@ router = APIRouter( ) +""" +Note: this router is not currently registered in the main FastAPI app. +This is to facilitate not exposing the user routes to the public API, +but this may change in a future version. + +The file remains to ease the future addition of user routes. +""" + + @router.get("/", status_code=status.HTTP_200_OK) async def get_users() -> list[User]: return await controller.get_users() diff --git a/server/tests/controllers/test_event_controller.py b/server/tests/controllers/test_event_controller.py index a4b9897..2db327c 100644 --- a/server/tests/controllers/test_event_controller.py +++ b/server/tests/controllers/test_event_controller.py @@ -22,23 +22,21 @@ def test_type(): assert isinstance(ec, EventController) -@pytest.mark.asyncio -async def test_all_series_with_no_data(): +def test_all_series_with_no_data(): """Tests with absent data.""" - async def no_series() -> list[dict]: + def no_series() -> list[dict]: return [] mock_queries.select_all_series = no_series - result = await ec.get_all_series() + result = ec.get_all_series() assert result == [] -@pytest.mark.asyncio -async def test_all_series_with_basic_data(): +def test_all_series_with_basic_data(): """Tests a single valid row with no event info""" - async def one_series_with_no_events() -> list[dict]: + def one_series_with_no_events() -> list[dict]: return [ { "name": "Test Series", @@ -49,7 +47,7 @@ async def test_all_series_with_basic_data(): ] mock_queries.select_all_series = one_series_with_no_events - result = await ec.get_all_series() + result = ec.get_all_series() assert isinstance(result, list) assert len(result) == 1 series = result[0] @@ -61,8 +59,7 @@ async def test_all_series_with_basic_data(): assert series.events == [] -@pytest.mark.asyncio -async def test_all_series_with_detailed_data(): +def test_all_series_with_detailed_data(): """Tests a single valid row with event info""" series_id = 1 @@ -70,7 +67,7 @@ async def test_all_series_with_detailed_data(): series_description = "Pieces by Danzi, Gomez, Gounod, Grant, and Rusnak!" poster_id = "The_Grapefruits_Present_qhng6y" - async def one_series_with_events() -> list[dict]: + def one_series_with_events() -> list[dict]: row_1 = { "series_id": series_id, @@ -108,7 +105,7 @@ async def test_all_series_with_detailed_data(): ] mock_queries.select_all_series = one_series_with_events - result = await ec.get_all_series() + result = ec.get_all_series() assert isinstance(result, list) assert len(result) == 1 series = result[0] @@ -148,11 +145,10 @@ async def test_all_series_with_detailed_data(): assert e3.time == datetime(2024, 6, 23, 15, 0) -@pytest.mark.asyncio -async def test_all_series_with_many_series(): +def test_all_series_with_many_series(): """Tests multiple series with no events.""" - async def many_series() -> list[dict]: + def many_series() -> list[dict]: return [ { "name": "Test Series", @@ -175,7 +171,7 @@ async def test_all_series_with_many_series(): ] mock_queries.select_all_series = many_series - result = await ec.get_all_series() + result = ec.get_all_series() assert isinstance(result, list) assert len(result) == 3 for series in result: @@ -183,14 +179,13 @@ async def test_all_series_with_many_series(): assert series.events == [] -@pytest.mark.asyncio -async def test_all_series_with_error(): +def test_all_series_with_error(): """Tests an error during the retrieval process.""" mock_log_error = Mock() ec.log_error = mock_log_error - async def invalid_series() -> list[dict]: + def invalid_series() -> list[dict]: # no series name return [ { @@ -202,13 +197,12 @@ async def test_all_series_with_error(): mock_queries.select_all_series = invalid_series with pytest.raises(Exception): - await ec.get_all_series() + ec.get_all_series() Mock.assert_called_once(mock_log_error) -@pytest.mark.asyncio -async def test_one_series(): - async def one_series_no_events(series_id: int) -> list[dict]: +def test_one_series(): + def one_series_no_events(series_id: int) -> list[dict]: return [ { "series_id": series_id, @@ -219,7 +213,7 @@ async def test_one_series(): ] mock_queries.select_one_by_id = one_series_no_events - series = await ec.get_one_series_by_id(1) + series = ec.get_one_series_by_id(1) assert isinstance(series, EventSeries) assert series.name == "Test Series" assert series.description == "Test Description" @@ -228,9 +222,8 @@ async def test_one_series(): assert series.events == [] -@pytest.mark.asyncio -async def test_one_series_with_events(): - async def one_series_with_events(series_id: int) -> list[dict]: +def test_one_series_with_events(): + def one_series_with_events(series_id: int) -> list[dict]: return [ { "series_id": series_id, @@ -264,7 +257,7 @@ async def test_one_series_with_events(): ] mock_queries.select_one_by_id = one_series_with_events - series = await ec.get_one_series_by_id(1) + series = ec.get_one_series_by_id(1) assert isinstance(series, EventSeries) assert series.name == "Test Series" assert series.description == "Test Description" diff --git a/server/tests/controllers/test_musician_controller.py b/server/tests/controllers/test_musician_controller.py index f8b5e0d..3373001 100644 --- a/server/tests/controllers/test_musician_controller.py +++ b/server/tests/controllers/test_musician_controller.py @@ -35,15 +35,15 @@ bad_data = [ ] -async def mock_select_all_series(): +def mock_select_all_series(): return sample_data -async def mock_select_all_series_sad(): +def mock_select_all_series_sad(): return bad_data -async def mock_select_one_by_id(musician_id: int): +def mock_select_one_by_id(musician_id: int): for musician in sample_data: if musician.get("id") == musician_id: return musician @@ -67,9 +67,8 @@ TODO: write tests for following methods: """ -@pytest.mark.asyncio -async def test_happy_get_musicians(): - musicians = await ec.get_musicians() +def test_happy_get_musicians(): + musicians = ec.get_musicians() assert isinstance(musicians, list) assert len(musicians) == 2 for musician in musicians: @@ -85,19 +84,17 @@ async def test_happy_get_musicians(): assert m2.headshot_id == "headshotABC" -@pytest.mark.asyncio -async def test_sad_get_musicians(): +def test_sad_get_musicians(): mock_queries.select_all_series = mock_select_all_series_sad with pytest.raises(HTTPException) as e: - await ec.get_musicians() + ec.get_musicians() mock_queries.select_all_series = mock_select_all_series assert isinstance(e.value, HTTPException) assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR -@pytest.mark.asyncio -async def test_happy_get_musician(): - musician = await ec.get_musician(1) +def test_happy_get_musician(): + musician = ec.get_musician(1) assert isinstance(musician, Musician) assert musician.id == 1 assert musician.name == "John Doe" @@ -105,10 +102,9 @@ async def test_happy_get_musician(): assert musician.headshot_id == "headshot123" -@pytest.mark.asyncio -async def test_musician_not_found(): +def test_musician_not_found(): with pytest.raises(HTTPException) as e: - await ec.get_musician(3) + ec.get_musician(3) assert isinstance(e.value, HTTPException) assert e.value.status_code == status.HTTP_404_NOT_FOUND assert e.value.detail == "Musician not found" From 971073c19f728aacba120906721f86aee8d5d69f Mon Sep 17 00:00:00 2001 From: Lucas Jensen Date: Fri, 3 May 2024 13:46:01 -0700 Subject: [PATCH 2/2] fixed user route registration --- server/app/main.py | 2 ++ server/app/routers/users.py | 9 --------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/server/app/main.py b/server/app/main.py index 9b0371d..6bfae60 100644 --- a/server/app/main.py +++ b/server/app/main.py @@ -9,6 +9,7 @@ from app.routers.contact import router as contact_router from app.routers.events import router as event_router from app.routers.group import router as group_router from app.routers.musicians import router as musician_router +from app.routers.users import router as user_router from app.scripts.version import get_version app = FastAPI( @@ -20,6 +21,7 @@ app.include_router(musician_router) app.include_router(group_router) app.include_router(contact_router) app.include_router(event_router) +app.include_router(user_router) controller = MainController() diff --git a/server/app/routers/users.py b/server/app/routers/users.py index c44bcc3..c2c1dc6 100644 --- a/server/app/routers/users.py +++ b/server/app/routers/users.py @@ -12,15 +12,6 @@ router = APIRouter( ) -""" -Note: this router is not currently registered in the main FastAPI app. -This is to facilitate not exposing the user routes to the public API, -but this may change in a future version. - -The file remains to ease the future addition of user routes. -""" - - @router.get("/", status_code=status.HTTP_200_OK) async def get_users() -> list[User]: return await controller.get_users()