Merge pull request #6 from ljensen505/docstring-formatting

updated more docstrings for various controllers; prepped musician tests
This commit is contained in:
Lucas Jensen
2024-05-03 10:25:27 -07:00
committed by GitHub
5 changed files with 299 additions and 62 deletions

View File

@@ -22,8 +22,16 @@ class BaseController:
self.MAX_FILE_SIZE = ONE_MB self.MAX_FILE_SIZE = ONE_MB
async def verify_image(self, file: UploadFile) -> bytes: async def verify_image(self, file: UploadFile) -> bytes:
""" """Verifies that the file is an image and is within the maximum file size.
Verifies that the file is an image and is within the maximum file size.
Args:
file (UploadFile): The file to be verified
Raises:
HTTPException: If the file is not an image or exceeds the maximum file size (status code 400)
Returns:
bytes: The file contents as bytes
""" """
if file.content_type not in self.ALL_FILES: if file.content_type not in self.ALL_FILES:
raise HTTPException( raise HTTPException(
@@ -39,8 +47,10 @@ class BaseController:
return image_file return image_file
def log_error(self, e: Exception) -> None: def log_error(self, e: Exception) -> None:
""" """Logs an error to a timestamped text file in the logs directory.
Logs an error to a timestamped text file in the logs directory.
Args:
e (Exception): Any exception object
""" """
curr_dir = Path(__file__).parent curr_dir = Path(__file__).parent
log_dir = curr_dir / "logs" log_dir = curr_dir / "logs"

View File

@@ -18,16 +18,19 @@ class EventController(BaseController):
Testing: pass a mocked EventQueries object to the constructor. Testing: pass a mocked EventQueries object to the constructor.
""" """
def __init__(self, eq=event_queries) -> None: def __init__(self, event_queries=event_queries) -> None:
super().__init__() super().__init__()
self.db: EventQueries = eq self.db: EventQueries = event_queries
def _all_series(self, data: list[dict]) -> dict[str, EventSeries]: def _all_series(self, data: list[dict]) -> dict[str, EventSeries]:
""" """Creates and returns a dictionary of EventSeries objects from a list of sql rows (as dicts).
Helper method to instantiate EventSeries objects from sql rows (a list of dictionaries). Should only be used internally.
Instantiation is done by destructuring the dictionary into the EventSeries constructor.
Should not be called directly; use get_all_series() instead. Args:
series.name is a required and unique field and can reliably be used as a key in a dictionary. data (list[dict]): List of dicts, each representing a row from the database. `event_id` may be null
Returns:
dict[str, EventSeries]: A dictionary of EventSeries objects, keyed by series name
""" """
all_series: dict[str, EventSeries] = {} all_series: dict[str, EventSeries] = {}
@@ -41,12 +44,13 @@ class EventController(BaseController):
return all_series return all_series
async def get_all_series(self) -> list[EventSeries]: async def get_all_series(self) -> list[EventSeries]:
""" """Retrieves all EventSeries objects from the database and returns them as a list.
Attempts to create and return a list of EventSeries objects and is consumed by the main controller.
Will trigger a 500 status code if any exception is raised, and log the error to a timestamped text file.
The list of EventSeries is created by calling the _all_series() helper method, which provided this data Raises:
as a list of dicts. HTTPException: If any error occurs during the retrieval process (status code 500)
Returns:
list[EventSeries]: A list of EventSeries objects which are suitable for a response body
""" """
series_data = await self.db.select_all_series() series_data = await self.db.select_all_series()
@@ -60,16 +64,25 @@ class EventController(BaseController):
) )
async def get_one_series_by_id(self, series_id: int) -> EventSeries: async def get_one_series_by_id(self, series_id: int) -> EventSeries:
"""Builds and returns a single EventSeries object by its numeric ID.
Args:
series_id (int): The numeric id of the series to retrieve
Raises:
HTTPException: If the series is not found (status code 404)
HTTPException: If an error occurs (status code 500)
Returns:
EventSeries: A single EventSeries object
""" """
Builds and returns a single EventSeries object by its numeric ID. if not (rows := await self.db.select_one_by_id(series_id)):
"""
if not (data := await self.db.select_one_by_id(series_id)):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Event not found" status_code=status.HTTP_404_NOT_FOUND, detail="Event not found"
) )
try: try:
return EventSeries( return EventSeries(
**data[0], events=[Event(**e) for e in data if e["event_id"]] **rows[0], events=[Event(**row) for row in rows if row.get("event_id")]
) )
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
@@ -78,8 +91,16 @@ class EventController(BaseController):
) )
async def create_series(self, series: NewEventSeries) -> EventSeries: async def create_series(self, series: NewEventSeries) -> EventSeries:
""" """Takes a NewEventSeries object and passes it to the database layer for insertion.
Takes a NewEventSeries object and passes it to the database layer for insertion.
Args:
series (NewEventSeries): A NewEventSeries object which does not yet have an ID
Raises:
HTTPException: If the series name already exists (status code 400)
Returns:
EventSeries: The newly created EventSeries object with an ID
""" """
try: try:
inserted_id = await self.db.insert_one_series(series) inserted_id = await self.db.insert_one_series(series)
@@ -92,10 +113,17 @@ class EventController(BaseController):
detail=f"Series name already exists. Each series must have a unique name.\n{e}", detail=f"Series name already exists. Each series must have a unique name.\n{e}",
) )
async def add_series_poster(self, series_id, poster: UploadFile) -> EventSeries: async def add_series_poster(
""" self, series_id: int, poster: UploadFile
Adds (or updates) a poster image to a series. ) -> EventSeries:
Actual image storage is done with Cloudinary and the public ID is stored in the database. """Adds (or updates) a poster image to a series.
Args:
series_id (int): The numeric ID of the series to update
poster (UploadFile): The image file to upload
Returns:
EventSeries: The updated EventSeries object
""" """
series = await self.get_one_series_by_id(series_id) series = await self.get_one_series_by_id(series_id)
series.poster_id = await self._upload_poster(poster) series.poster_id = await self._upload_poster(poster)
@@ -103,8 +131,17 @@ class EventController(BaseController):
return await self.get_one_series_by_id(series.series_id) return await self.get_one_series_by_id(series.series_id)
async def _upload_poster(self, poster: UploadFile) -> str: async def _upload_poster(self, poster: UploadFile) -> str:
""" """Uploads a poster image to Cloudinary and returns the public ID for storage in the database.
Uploads a poster image to Cloudinary and returns the public ID for storage in the database. Should only be used internally.
Args:
poster (UploadFile): The image file to upload
Raises:
HTTPException: If an error occurs during the upload process (status code 500)
Returns:
str: The public ID of the uploaded image
""" """
image_file = await self.verify_image(poster) image_file = await self.verify_image(poster)
try: try:
@@ -117,15 +154,27 @@ class EventController(BaseController):
) )
async def delete_series(self, id: int) -> None: async def delete_series(self, id: int) -> None:
""" """Ensures an EventSeries object exists and then deletes it from the database.
Ensures an EventSeries object exists and then deletes it from the database
Args:
id (int): The numeric ID of the series to delete
""" """
series = await self.get_one_series_by_id(id) series = await self.get_one_series_by_id(id)
await self.db.delete_one_series(series) await self.db.delete_one_series(series)
async def update_series(self, route_id: int, series: EventSeries) -> EventSeries: async def update_series(self, route_id: int, series: EventSeries) -> EventSeries:
""" """Updates an existing EventSeries object in the database.
Updates an existing EventSeries object in the database.
Args:
route_id (int): The numeric ID of the series in the URL
series (EventSeries): The updated EventSeries object
Raises:
HTTPException: if the ID in the URL does not match the ID in the request body (status code 400)
HTTPException: if the poster ID is updated directly (status code 400)
Returns:
EventSeries: The updated EventSeries object with updated info
""" """
if route_id != series.series_id: if route_id != series.series_id:
raise HTTPException( raise HTTPException(

View File

@@ -9,11 +9,27 @@ from app.models.musician import Musician
class MusicianController(BaseController): class MusicianController(BaseController):
def __init__(self) -> None: """
Handles all musician-related operations and serves as an intermediate controller between
the main controller and the model layer.
Inherits from BaseController, which provides logging and other generic methods.
Testing: pass a mocked MusicianQueries object to the constructor.
"""
def __init__(self, musician_queries=musician_queries) -> None:
super().__init__() super().__init__()
self.db: MusicianQueries = musician_queries self.db: MusicianQueries = musician_queries
async def get_musicians(self) -> list[Musician]: async def get_musicians(self) -> list[Musician]:
"""Retrieves all musicians from the database and returns them as a list of Musician objects.
Raises:
HTTPException: If any error occurs during the retrieval process (status code 500)
Returns:
list[Musician]: A list of Musician objects which are suitable for a response body
"""
data = await self.db.select_all_series() data = await self.db.select_all_series()
try: try:
return [Musician(**m) for m in data] return [Musician(**m) for m in data]
@@ -23,8 +39,20 @@ class MusicianController(BaseController):
detail=f"Error creating musician objects: {e}", detail=f"Error creating musician objects: {e}",
) )
async def get_musician(self, id: int) -> Musician: async def get_musician(self, musician_id: int) -> Musician:
if (data := await self.db.select_one_by_id(id)) is None: """Retrieves a single musician from the database and returns it as a Musician object.
Args:
id (int): The ID of the musician to retrieve
Raises:
HTTPException: If the musician is not found (status code 404)
HTTPException: If any error occurs during the retrieval process (status code 500)
Returns:
Musician: A Musician object which is suitable for a response body
"""
if (data := await self.db.select_one_by_id(musician_id)) is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Musician not found" status_code=status.HTTP_404_NOT_FOUND, detail="Musician not found"
) )
@@ -42,40 +70,91 @@ class MusicianController(BaseController):
new_bio: str, new_bio: str,
file: UploadFile | None = None, file: UploadFile | None = None,
) -> Musician: ) -> Musician:
"""Updates a musician's bio and/or headshot by conditionally calling the appropriate methods.
Args:
musician_id (int): The numeric ID of the musician to update
new_bio (str): The new biography for the musician
file (UploadFile | None, optional): The new headshot file. Defaults to None.
Raises:
HTTPException: If the musician is not found (status code 404)
Returns:
Musician: The updated Musician object
"""
musician = await self.get_musician(musician_id) musician = await self.get_musician(musician_id)
if new_bio != musician.bio: if new_bio != musician.bio:
return await self.update_musician_bio(musician.id, new_bio) return await self._update_musician_bio(musician.id, new_bio)
if file is not None: if file is not None:
return await self.upload_headshot(musician.id, file) return await self._upload_headshot(musician.id, file)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Update operation not implemented. Neither the bio or headshot was updated.", detail="Update operation not implemented. Neither the bio or headshot was updated.",
) )
async def update_musician_headshot(self, id: int, headshot_id: str) -> Musician: async def update_musician_headshot(
await self.get_musician(id) self, musician_id: int, headshot_id: str
) -> Musician:
"""Updates a musician's headshot in the database.
Args:
id (int): The numeric ID of the musician to update
headshot_id (str): The public ID of the new headshot (as determined by Cloudinary)
Raises:
HTTPException: If any error occurs during the update process (status code 500)
Returns:
Musician: The updated Musician object
"""
await self.get_musician(musician_id)
try: try:
await self.db.update_headshot(id, headshot_id) await self.db.update_headshot(musician_id, headshot_id)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error updating musician headshot: {e}", detail=f"Error updating musician headshot: {e}",
) )
return await self.get_musician(id) return await self.get_musician(musician_id)
async def update_musician_bio(self, id: int, bio: str) -> Musician: async def _update_musician_bio(self, musician_id: int, bio: str) -> Musician:
await self.get_musician(id) # Check if musician exists """Updates a musician's bio in the database.
Args:
id (int): The numeric ID of the musician to update
bio (str): The new biography for the musician
Raises:
HTTPException: If any error occurs during the update process (status code 500)
Returns:
Musician: The updated Musician object
"""
await self.get_musician(musician_id) # Check if musician exists
try: try:
await self.db.update_bio(id, bio) await self.db.update_bio(musician_id, bio)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error updating musician bio: {e}", detail=f"Error updating musician bio: {e}",
) )
return await self.get_musician(id) return await self.get_musician(musician_id)
async def upload_headshot(self, id: int, file: UploadFile) -> Musician: async def _upload_headshot(self, id: int, file: UploadFile) -> Musician:
"""Uploads a new headshot for a musician and updates the database with the new public ID.
Args:
id (int): The numeric ID of the musician to update
file (UploadFile): The new headshot file
Raises:
HTTPException: If the file is not an image or exceeds the maximum file size (status code 400)
Returns:
Musician: The updated Musician object
"""
image_file = await self.verify_image(file) image_file = await self.verify_image(file)
data = uploader.upload(image_file) data = uploader.upload(image_file)
public_id = data.get("public_id") public_id = data.get("public_id")

View File

@@ -2,13 +2,20 @@ from datetime import datetime
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from icecream import ic
from pydantic_core import Url from pydantic_core import Url
from app.controllers.events import EventController from app.controllers.events import EventController
from app.models.event import Event, EventSeries from app.models.event import Event, EventSeries
mock_queries = Mock() mock_queries = Mock()
ec = EventController(eq=mock_queries) ec = EventController(event_queries=mock_queries)
eventbrite_url = "https://www.eventbrite.com/e/the-grapefruits-duo-presents-works-for-horn-and-piano-tickets-1234567890"
medford = "Medford, OR"
newport_church = "First Presbyterian Church Newport"
eugene_church = "First Church of Christ, Scientist, Eugene"
map_url = "https://maps.app.goo.gl/hNfN8X5FBZLg8LDF8"
def test_type(): def test_type():
@@ -17,17 +24,21 @@ def test_type():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_all_series_with_no_data(): async def test_all_series_with_no_data():
async def select_all_series() -> list[dict]: """Tests with absent data."""
async def no_series() -> list[dict]:
return [] return []
mock_queries.select_all_series = select_all_series mock_queries.select_all_series = no_series
result = await ec.get_all_series() result = await ec.get_all_series()
assert result == [] assert result == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_all_series_with_basic_data(): async def test_all_series_with_basic_data():
async def select_all_series() -> list[dict]: """Tests a single valid row with no event info"""
async def one_series_with_no_events() -> list[dict]:
return [ return [
{ {
"name": "Test Series", "name": "Test Series",
@@ -37,7 +48,7 @@ async def test_all_series_with_basic_data():
} }
] ]
mock_queries.select_all_series = select_all_series mock_queries.select_all_series = one_series_with_no_events
result = await ec.get_all_series() result = await ec.get_all_series()
assert isinstance(result, list) assert isinstance(result, list)
assert len(result) == 1 assert len(result) == 1
@@ -52,17 +63,14 @@ async def test_all_series_with_basic_data():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_all_series_with_detailed_data(): async def test_all_series_with_detailed_data():
"""Tests a single valid row with event info"""
series_id = 1 series_id = 1
series_name = "The Grapefruits Duo Presents: Works for Horn and Piano" series_name = "The Grapefruits Duo Presents: Works for Horn and Piano"
series_description = "Pieces by Danzi, Gomez, Gounod, Grant, and Rusnak!" series_description = "Pieces by Danzi, Gomez, Gounod, Grant, and Rusnak!"
poster_id = "The_Grapefruits_Present_qhng6y" poster_id = "The_Grapefruits_Present_qhng6y"
eventbrite_url = "https://www.eventbrite.com/e/the-grapefruits-duo-presents-works-for-horn-and-piano-tickets-1234567890"
medford = "Medford, OR"
newport_church = "First Presbyterian Church Newport"
eugene_church = "First Church of Christ, Scientist, Eugene"
map_url = "https://maps.app.goo.gl/hNfN8X5FBZLg8LDF8"
async def select_all_series() -> list[dict]: async def one_series_with_events() -> list[dict]:
row_1 = { row_1 = {
"series_id": series_id, "series_id": series_id,
@@ -99,7 +107,7 @@ async def test_all_series_with_detailed_data():
row_3, row_3,
] ]
mock_queries.select_all_series = select_all_series mock_queries.select_all_series = one_series_with_events
result = await ec.get_all_series() result = await ec.get_all_series()
assert isinstance(result, list) assert isinstance(result, list)
assert len(result) == 1 assert len(result) == 1
@@ -142,7 +150,9 @@ async def test_all_series_with_detailed_data():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_all_series_with_many_series(): async def test_all_series_with_many_series():
async def select_all_series() -> list[dict]: """Tests multiple series with no events."""
async def many_series() -> list[dict]:
return [ return [
{ {
"name": "Test Series", "name": "Test Series",
@@ -164,7 +174,7 @@ async def test_all_series_with_many_series():
}, },
] ]
mock_queries.select_all_series = select_all_series mock_queries.select_all_series = many_series
result = await ec.get_all_series() result = await ec.get_all_series()
assert isinstance(result, list) assert isinstance(result, list)
assert len(result) == 3 assert len(result) == 3
@@ -175,10 +185,12 @@ async def test_all_series_with_many_series():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_all_series_with_error(): async def test_all_series_with_error():
"""Tests an error during the retrieval process."""
mock_log_error = Mock() mock_log_error = Mock()
ec.log_error = mock_log_error ec.log_error = mock_log_error
async def select_all_series() -> list[dict]: async def invalid_series() -> list[dict]:
# no series name # no series name
return [ return [
{ {
@@ -188,7 +200,77 @@ async def test_all_series_with_error():
} }
] ]
mock_queries.select_all_series = select_all_series mock_queries.select_all_series = invalid_series
with pytest.raises(Exception): with pytest.raises(Exception):
await ec.get_all_series() await ec.get_all_series()
Mock.assert_called_once(mock_log_error) Mock.assert_called_once(mock_log_error)
@pytest.mark.asyncio
async def test_one_series():
async def one_series_no_events(series_id: int) -> list[dict]:
return [
{
"series_id": series_id,
"name": "Test Series",
"description": "Test Description",
"poster_id": "abc123",
}
]
mock_queries.select_one_by_id = one_series_no_events
series = await ec.get_one_series_by_id(1)
assert isinstance(series, EventSeries)
assert series.name == "Test Series"
assert series.description == "Test Description"
assert series.series_id == 1
assert series.poster_id == "abc123"
assert series.events == []
@pytest.mark.asyncio
async def test_one_series_with_events():
async def one_series_with_events(series_id: int) -> list[dict]:
return [
{
"series_id": series_id,
"name": "Test Series",
"description": "Test Description",
"poster_id": "abc123",
"event_id": 1,
"location": medford,
"time": "2024-05-31 19:00:00.000",
"ticket_url": eventbrite_url,
},
{
"series_id": series_id,
"name": "Test Series",
"description": "Test Description",
"poster_id": "abc123",
"event_id": 2,
"location": newport_church,
"time": "2024-06-16 16:00:00.000",
"map_url": map_url,
},
{
"series_id": series_id,
"name": "Test Series",
"description": "Test Description",
"poster_id": "abc123",
"event_id": 3,
"location": eugene_church,
"time": "2024-06-23 15:00:00.000",
},
]
mock_queries.select_one_by_id = one_series_with_events
series = await ec.get_one_series_by_id(1)
assert isinstance(series, EventSeries)
assert series.name == "Test Series"
assert series.description == "Test Description"
assert series.series_id == 1
assert series.poster_id == "abc123"
events = series.events
assert len(events) == 3
for event in events:
assert isinstance(event, Event)

View File

@@ -0,0 +1,17 @@
from unittest.mock import Mock
import pytest
from icecream import ic
from app.controllers.musicians import MusicianController
from app.models.musician import Musician
mock_queries = Mock()
ec = MusicianController(musician_queries=mock_queries)
def test_type():
assert isinstance(ec, MusicianController)
# TODO: Write tests for MusicianController