"""FastAPI dependencies.""" from typing import Annotated, cast from fastapi import Cookie, Depends, Header, HTTPException, status from sqlmodel.ext.asyncio.session import AsyncSession from app.core.database import get_db from app.core.logging import get_logger from app.models.user import User from app.repositories.sound import SoundRepository from app.repositories.user import UserRepository from app.services.auth import AuthService from app.services.dashboard import DashboardService from app.services.oauth import OAuthService from app.utils.auth import JWTUtils, TokenUtils logger = get_logger(__name__) async def get_auth_service( session: Annotated[AsyncSession, Depends(get_db)], ) -> AuthService: """Get the authentication service.""" return AuthService(session) async def get_oauth_service( session: Annotated[AsyncSession, Depends(get_db)], ) -> OAuthService: """Get the OAuth service.""" return OAuthService(session) async def get_current_user( auth_service: Annotated[AuthService, Depends(get_auth_service)], access_token: Annotated[str | None, Cookie()] = None, ) -> User: """Get the current authenticated user from JWT token in HTTP-only cookie.""" try: # Check if access token cookie exists if not access_token: logger.warning("No access token cookie found") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", ) # Decode the JWT token payload = JWTUtils.decode_access_token(access_token) # Extract user ID from token user_id_str = payload.get("sub") if not user_id_str: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload", ) # At this point user_id_str is guaranteed to be truthy, safe to cast user_id_str = cast("str", user_id_str) try: user_id = int(user_id_str) except (ValueError, TypeError) as e: logger.warning("Invalid user ID in token: %s", user_id_str) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload", ) from e # Get the user return await auth_service.get_current_user(user_id) except HTTPException: # Re-raise HTTPExceptions without wrapping them raise except Exception as e: logger.exception("Failed to authenticate user") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", ) from e async def get_current_active_user( current_user: Annotated[User, Depends(get_current_user)], ) -> User: """Get the current authenticated and active user.""" if not current_user.is_active: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Account is deactivated", ) return current_user async def get_current_user_api_token( auth_service: Annotated[AuthService, Depends(get_auth_service)], api_token_header: Annotated[str | None, Header(alias="API-TOKEN")] = None, ) -> User: """Get the current authenticated user from API token in API-TOKEN header.""" try: # Check if API-TOKEN header exists if not api_token_header: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="API-TOKEN header required", ) # Use the API token directly api_token = api_token_header.strip() if not api_token: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="API token required", ) # Get the user by API token user = await auth_service.get_user_by_api_token(api_token) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API token", ) # Check if token is expired if TokenUtils.is_token_expired(user.api_token_expires_at): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="API token has expired", ) # Check if user is active if not user.is_active: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Account is deactivated", ) except HTTPException: # Re-raise HTTPExceptions without wrapping them raise except Exception as e: logger.exception("Failed to authenticate user with API token") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate API token", ) from e else: return user async def get_current_user_flexible( auth_service: Annotated[AuthService, Depends(get_auth_service)], access_token: Annotated[str | None, Cookie()] = None, api_token_header: Annotated[str | None, Header(alias="API-TOKEN")] = None, ) -> User: """Get the current authenticated user from either JWT cookie or API token.""" # Try API token first if API-TOKEN header is present if api_token_header: return await get_current_user_api_token(auth_service, api_token_header) # Fall back to JWT cookie authentication return await get_current_user(auth_service, access_token) async def get_current_active_user_flexible( current_user: Annotated[User, Depends(get_current_user_flexible)], ) -> User: """Get the current authenticated and active user using flexible auth.""" if not current_user.is_active: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Account is deactivated", ) return current_user async def get_admin_user( current_user: Annotated[User, Depends(get_current_active_user)], ) -> User: """Get the current authenticated admin user.""" if current_user.role not in ["admin", "superadmin"]: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions", ) return current_user async def get_dashboard_service( session: Annotated[AsyncSession, Depends(get_db)], ) -> DashboardService: """Get the dashboard service.""" sound_repository = SoundRepository(session) user_repository = UserRepository(session) return DashboardService(sound_repository, user_repository)