Files
sdb2-backend/app/repositories/user.py
JSC 5ed19c8f0f Add comprehensive tests for playlist service and refactor socket service tests
- Introduced a new test suite for the PlaylistService covering various functionalities including creation, retrieval, updating, and deletion of playlists.
- Added tests for handling sounds within playlists, ensuring correct behavior when adding/removing sounds and managing current playlists.
- Refactored socket service tests for improved readability by adjusting function signatures.
- Cleaned up unnecessary whitespace in sound normalizer and sound scanner tests for consistency.
- Enhanced audio utility tests to ensure accurate hash and size calculations, including edge cases for nonexistent files.
- Removed redundant blank lines in cookie utility tests for cleaner code.
2025-07-29 19:25:46 +02:00

136 lines
4.6 KiB
Python

"""User repository."""
from typing import Any
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.plan import Plan
from app.models.user import User
# Module-level logger, created via the project's get_logger helper with this module's __name__.
logger = get_logger(__name__)
class UserRepository:
    """Repository for user operations.

    Wraps an ``AsyncSession`` and exposes lookup, create, update and delete
    helpers for ``User`` rows. Every method logs a failure and re-raises the
    original exception; the writing methods also roll the session back first.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the user repository.

        Args:
            session: Async database session used for every query.
        """
        self.session = session

    async def get_by_id(self, user_id: int) -> User | None:
        """Get a user by ID.

        Returns:
            The first matching user, or ``None`` when no row matches.
        """
        try:
            statement = select(User).where(User.id == user_id)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get user by ID: %s", user_id)
            raise

    async def get_by_email(self, email: str) -> User | None:
        """Get a user by email address.

        Returns:
            The first matching user, or ``None`` when no row matches.
        """
        try:
            statement = select(User).where(User.email == email)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get user by email: %s", email)
            raise

    async def get_by_api_token(self, api_token: str) -> User | None:
        """Get a user by API token.

        Returns:
            The first matching user, or ``None`` when no row matches.
        """
        try:
            statement = select(User).where(User.api_token == api_token)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            # The token is deliberately kept out of the log message.
            logger.exception("Failed to get user by API token")
            raise

    async def _get_signup_plan(self, *, is_first_user: bool) -> Plan:
        """Return the plan assigned to a newly created user.

        The first user gets the plan with code ``"pro"``; everyone else gets
        the plan with code ``"free"``.

        Raises:
            ValueError: If the expected plan row does not exist.
        """
        plan_code = "pro" if is_first_user else "free"
        plan_statement = select(Plan).where(Plan.code == plan_code)
        plan_result = await self.session.exec(plan_statement)
        plan = plan_result.first()
        if plan is None:
            msg = "Default plan not found"
            raise ValueError(msg)
        return plan

    async def create(self, user_data: dict[str, Any]) -> User:
        """Create a new user.

        The first user ever created is given the ``"admin"`` role and the
        ``"pro"`` plan; later users get the ``"free"`` plan. ``plan_id`` and
        ``credits`` are always taken from the selected plan.

        Args:
            user_data: Field values for the new ``User``. The mapping is
                copied; the caller's dict is not modified.

        Returns:
            The committed and refreshed user.

        Raises:
            ValueError: If the plan for the new user cannot be found.
            Exception: Any other error is logged and re-raised after the
                session has been rolled back.
        """
        try:
            # Work on a copy so role/plan defaults do not leak into the
            # caller's dictionary.
            data = dict(user_data)

            # Only existence matters here, so fetch at most one row instead
            # of selecting the whole table.
            # NOTE(review): this check and the insert below are separate
            # statements; concurrent sign-ups could presumably both see an
            # empty table — confirm whether that needs guarding.
            any_user_result = await self.session.exec(select(User).limit(1))
            is_first_user = any_user_result.first() is None

            if is_first_user:
                data["role"] = "admin"
                logger.info("Creating first user with admin role and pro plan")

            default_plan = await self._get_signup_plan(is_first_user=is_first_user)

            # Set plan_id and default credits from the selected plan.
            data["plan_id"] = default_plan.id
            data["credits"] = default_plan.credits

            user = User(**data)
            self.session.add(user)
            await self.session.commit()
            await self.session.refresh(user)
        except Exception:
            await self.session.rollback()
            logger.exception("Failed to create user")
            raise
        else:
            logger.info("Created new user with email: %s", user.email)
            return user

    async def update(self, user: User, update_data: dict[str, Any]) -> User:
        """Update a user.

        Args:
            user: The instance to modify.
            update_data: Mapping of attribute name to new value; every entry
                is applied with ``setattr``.

        Returns:
            The same instance after commit and refresh.
        """
        try:
            for field, value in update_data.items():
                setattr(user, field, value)
            await self.session.commit()
            await self.session.refresh(user)
        except Exception:
            await self.session.rollback()
            logger.exception("Failed to update user")
            raise
        else:
            logger.info("Updated user: %s", user.email)
            return user

    async def delete(self, user: User) -> None:
        """Delete a user.

        Args:
            user: The instance to delete.
        """
        try:
            # Read the email up front so the success log does not have to
            # touch the instance after its row has been deleted.
            email = user.email
            await self.session.delete(user)
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            logger.exception("Failed to delete user")
            raise
        else:
            logger.info("Deleted user: %s", email)

    async def email_exists(self, email: str) -> bool:
        """Check if an email address is already registered.

        Returns:
            ``True`` when at least one user has this email, else ``False``.
        """
        try:
            statement = select(User).where(User.email == email)
            result = await self.session.exec(statement)
            return result.first() is not None
        except Exception:
            logger.exception("Failed to check if email exists: %s", email)
            raise