Merge pull request #4 from ljensen505/add-docstrings

added docstrings to event controller and moved constants
This commit is contained in:
Lucas Jensen
2024-05-02 18:34:22 -07:00
committed by GitHub
16 changed files with 100 additions and 42 deletions

View File

@@ -0,0 +1,9 @@
ALLOWED_FILES_TYPES = ["image/jpeg", "image/png"]
ONE_MB = 1000000
# sql table names
SERIES_TABLE = "series"
EVENT_TABLE = "events"
GROUP_TABLE = "group_table"
MUSICIAN_TABLE = "musicians"
USER_TABLE = "users"

View File

@@ -1,3 +1,3 @@
from .controller import Controller from .controller import MainController
controller = Controller() controller = MainController()

View File

@@ -6,20 +6,25 @@ from pathlib import Path
from fastapi import HTTPException, UploadFile, status from fastapi import HTTPException, UploadFile, status
from icecream import ic from icecream import ic
from app.constants import ALLOWED_FILES_TYPES, ONE_MB
from app.db.base_queries import BaseQueries from app.db.base_queries import BaseQueries
ALLOWED_FILES_TYPES = ["image/jpeg", "image/png"]
MAX_FILE_SIZE = 1000000 # 1 MB
class BaseController: class BaseController:
"""
A generic controller class which includes logging, image verification, and other common methods.
Model-specific controllers should inherit from this class and this class should not be instantiated directly.
"""
def __init__(self) -> None: def __init__(self) -> None:
self.db: BaseQueries = None # type: ignore self.db: BaseQueries = None # type: ignore
self.ALL_FILES = ALLOWED_FILES_TYPES self.ALL_FILES = ALLOWED_FILES_TYPES
self.MAX_FILE_SIZE = MAX_FILE_SIZE self.MAX_FILE_SIZE = ONE_MB
async def verify_image(self, file: UploadFile) -> bytes: async def verify_image(self, file: UploadFile) -> bytes:
print("verifying image") """
Verifies that the file is an image and is within the maximum file size.
"""
if file.content_type not in self.ALL_FILES: if file.content_type not in self.ALL_FILES:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@@ -34,6 +39,9 @@ class BaseController:
return image_file return image_file
def log_error(self, e: Exception) -> None: def log_error(self, e: Exception) -> None:
"""
Logs an error to a timestamped text file in the logs directory.
"""
curr_dir = Path(__file__).parent curr_dir = Path(__file__).parent
log_dir = curr_dir / "logs" log_dir = curr_dir / "logs"
log_dir.mkdir(exist_ok=True) log_dir.mkdir(exist_ok=True)

View File

@@ -13,7 +13,15 @@ from app.models.musician import Musician
from app.models.user import User from app.models.user import User
class Controller: class MainController:
"""
The main controller and entry point for all API requests.
All methods are either pass-throughs to the appropriate controller or
are used to coordinate multiple controllers.
token-based authentication is handled here as needed per the nature of the data being accessed.
"""
def __init__(self) -> None: def __init__(self) -> None:
self.event_controller = EventController() self.event_controller = EventController()
self.musician_controller = MusicianController() self.musician_controller = MusicianController()

View File

@@ -10,11 +10,25 @@ from app.models.event import Event, EventSeries, NewEventSeries
class EventController(BaseController): class EventController(BaseController):
"""
Handles all event-related operations and serves as an intermediate controller between
the main controller and the model layer.
Inherits from BaseController, which provides logging and other generic methods.
Testing: pass a mocked EventQueries object to the constructor.
"""
def __init__(self, eq=event_queries) -> None: def __init__(self, eq=event_queries) -> None:
super().__init__() super().__init__()
self.db: EventQueries = eq self.db: EventQueries = eq
def _all_series(self, data: list[dict]) -> list[EventSeries]: def _all_series(self, data: list[dict]) -> dict[str, EventSeries]:
"""
Helper method to instantiate EventSeries objects from sql rows (a list of dictionaries).
Instantiation is done by destructuring the dictionary into the EventSeries constructor.
Should not be called directly; use get_all_series() instead.
series.name is a required and unique field and can reliably be used as a key in a dictionary.
"""
all_series: dict[str, EventSeries] = {} all_series: dict[str, EventSeries] = {}
for event_series_row in data: for event_series_row in data:
@@ -24,13 +38,20 @@ class EventController(BaseController):
if event_series_row.get("event_id"): if event_series_row.get("event_id"):
all_series[series_name].events.append(Event(**event_series_row)) all_series[series_name].events.append(Event(**event_series_row))
return [series for series in all_series.values()] return all_series
async def get_all_series(self) -> list[EventSeries]: async def get_all_series(self) -> list[EventSeries]:
"""
Attempts to create and return a list of EventSeries objects and is consumed by the main controller.
Will trigger a 500 status code if any exception is raised, and log the error to a timestamped text file.
The list of EventSeries is created by calling the _all_series() helper method, which provided this data
as a list of dicts.
"""
series_data = await self.db.select_all_series() series_data = await self.db.select_all_series()
try: try:
return self._all_series(series_data) return [series for series in self._all_series(series_data).values()]
except Exception as e: except Exception as e:
self.log_error(e) self.log_error(e)
raise HTTPException( raise HTTPException(
@@ -39,6 +60,9 @@ class EventController(BaseController):
) )
async def get_one_series_by_id(self, series_id: int) -> EventSeries: async def get_one_series_by_id(self, series_id: int) -> EventSeries:
"""
Builds and returns a single EventSeries object by its numeric ID.
"""
if not (data := await self.db.select_one_series_by_id(series_id)): if not (data := await self.db.select_one_series_by_id(series_id)):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Event not found" status_code=status.HTTP_404_NOT_FOUND, detail="Event not found"
@@ -54,6 +78,9 @@ class EventController(BaseController):
) )
async def create_series(self, series: NewEventSeries) -> EventSeries: async def create_series(self, series: NewEventSeries) -> EventSeries:
"""
Takes a NewEventSeries object and passes it to the database layer for insertion.
"""
try: try:
inserted_id = await self.db.insert_one_series(series) inserted_id = await self.db.insert_one_series(series)
for new_event in series.events: for new_event in series.events:
@@ -66,12 +93,19 @@ class EventController(BaseController):
) )
async def add_series_poster(self, series_id, poster: UploadFile) -> EventSeries: async def add_series_poster(self, series_id, poster: UploadFile) -> EventSeries:
"""
Adds (or updates) a poster image to a series.
Actual image storage is done with Cloudinary and the public ID is stored in the database.
"""
series = await self.get_one_series_by_id(series_id) series = await self.get_one_series_by_id(series_id)
series.poster_id = await self._upload_poster(poster) series.poster_id = await self._upload_poster(poster)
await self.db.update_series_poster(series) await self.db.update_series_poster(series)
return await self.get_one_series_by_id(series.series_id) return await self.get_one_series_by_id(series.series_id)
async def _upload_poster(self, poster: UploadFile) -> str: async def _upload_poster(self, poster: UploadFile) -> str:
"""
Uploads a poster image to Cloudinary and returns the public ID for storage in the database.
"""
image_file = await self.verify_image(poster) image_file = await self.verify_image(poster)
try: try:
data = uploader.upload(image_file) data = uploader.upload(image_file)
@@ -83,12 +117,17 @@ class EventController(BaseController):
) )
async def delete_series(self, id: int) -> None: async def delete_series(self, id: int) -> None:
"""
Ensures an EventSeries object exists and then deletes it from the database
"""
series = await self.get_one_series_by_id(id) series = await self.get_one_series_by_id(id)
await self.db.delete_one_series(series) await self.db.delete_one_series(series)
async def update_series(self, route_id: int, series: EventSeries) -> EventSeries: async def update_series(self, route_id: int, series: EventSeries) -> EventSeries:
"""
Updates an existing EventSeries object in the database.
"""
if route_id != series.series_id: if route_id != series.series_id:
print("error")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="ID in URL does not match ID in request body", detail="ID in URL does not match ID in request body",

View File

@@ -9,6 +9,12 @@ class DBException(Exception):
def connect_db() -> mysql.connector.MySQLConnection: def connect_db() -> mysql.connector.MySQLConnection:
"""
Connects to the MySQL database using credentials from the .env file.
Returns a MySQLConnection object which can be used by the database query layer.
Credential values are validated and an exception is raised if any are missing.
"""
load_dotenv() load_dotenv()
host = os.getenv("DB_HOST") host = os.getenv("DB_HOST")
user = os.getenv("DB_USER") user = os.getenv("DB_USER")

View File

@@ -2,15 +2,9 @@ from asyncio import gather
from icecream import ic from icecream import ic
from app.constants import EVENT_TABLE, SERIES_TABLE
from app.db.base_queries import BaseQueries from app.db.base_queries import BaseQueries
from app.models.event import ( from app.models.event import Event, EventSeries, NewEvent, NewEventSeries
EVENT_TABLE,
SERIES_TABLE,
Event,
EventSeries,
NewEvent,
NewEventSeries,
)
class EventQueries(BaseQueries): class EventQueries(BaseQueries):

View File

@@ -1,5 +1,5 @@
from app.constants import GROUP_TABLE
from app.db.base_queries import BaseQueries from app.db.base_queries import BaseQueries
from app.models.group import GROUP_TABLE
class GroupQueries(BaseQueries): class GroupQueries(BaseQueries):

View File

@@ -1,8 +1,8 @@
from icecream import ic from icecream import ic
from app.constants import MUSICIAN_TABLE
from app.db.base_queries import BaseQueries from app.db.base_queries import BaseQueries
from app.db.conn import connect_db from app.db.conn import connect_db
from app.models.musician import MUSICIAN_TABLE
class MusicianQueries(BaseQueries): class MusicianQueries(BaseQueries):

View File

@@ -1,5 +1,5 @@
from app.constants import USER_TABLE
from app.db.base_queries import BaseQueries from app.db.base_queries import BaseQueries
from app.models.user import USER_TABLE
class UserQueries(BaseQueries): class UserQueries(BaseQueries):

View File

@@ -3,7 +3,7 @@ from asyncio import gather
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from app.controllers import Controller from app.controllers import MainController
from app.models.tgd import TheGrapefruitsDuo from app.models.tgd import TheGrapefruitsDuo
from app.routers.contact import router as contact_router from app.routers.contact import router as contact_router
from app.routers.events import router as event_router from app.routers.events import router as event_router
@@ -23,7 +23,7 @@ app.include_router(contact_router)
app.include_router(event_router) app.include_router(event_router)
app.include_router(user_router) app.include_router(user_router)
controller = Controller() controller = MainController()
origins = [ origins = [
"http://localhost:3000", "http://localhost:3000",

View File

@@ -30,7 +30,3 @@ class EventSeries(NewEventSeries):
series_id: int series_id: int
events: list[Event] events: list[Event]
poster_id: Optional[str] = None poster_id: Optional[str] = None
SERIES_TABLE = "series"
EVENT_TABLE = "events"

View File

@@ -5,6 +5,3 @@ class Group(BaseModel):
name: str name: str
bio: str bio: str
id: int | None = None id: int | None = None
GROUP_TABLE = "group_table"

View File

@@ -11,6 +11,3 @@ class NewMusician(BaseModel):
class Musician(NewMusician): class Musician(NewMusician):
id: int id: int
MUSICIAN_TABLE = "musicians"

View File

@@ -8,6 +8,3 @@ class User(BaseModel):
email: str email: str
sub: Optional[str] = None sub: Optional[str] = None
id: int | None = None id: int | None = None
USER_TABLE = "users"

View File

@@ -2,11 +2,18 @@ from datetime import datetime
from dotenv import load_dotenv from dotenv import load_dotenv
from app.constants import (
EVENT_TABLE,
GROUP_TABLE,
MUSICIAN_TABLE,
SERIES_TABLE,
USER_TABLE,
)
from app.db.conn import connect_db from app.db.conn import connect_db
from app.models.event import EVENT_TABLE, SERIES_TABLE, Event, EventSeries from app.models.event import Event, EventSeries
from app.models.group import GROUP_TABLE, Group from app.models.group import Group
from app.models.musician import MUSICIAN_TABLE, NewMusician from app.models.musician import NewMusician
from app.models.user import USER_TABLE, User from app.models.user import User
margarite: NewMusician = NewMusician( margarite: NewMusician = NewMusician(
name="Margarite Waddell", name="Margarite Waddell",