Refactor tests for improved consistency and readability
- Updated test cases in `test_auth_endpoints.py` to ensure consistent formatting and style. - Enhanced `test_socket_endpoints.py` with consistent parameter formatting and improved readability. - Cleaned up `conftest.py` by ensuring consistent parameter formatting in fixtures. - Added comprehensive tests for API token dependencies in `test_api_token_dependencies.py`. - Refactored `test_auth_service.py` to maintain consistent parameter formatting. - Cleaned up `test_oauth_service.py` by removing unnecessary imports. - Improved `test_socket_service.py` with consistent formatting and readability. - Enhanced `test_cookies.py` by ensuring consistent formatting and readability. - Introduced new tests for token utilities in `test_token_utils.py` to validate token generation and expiration logic.
This commit is contained in:
@@ -11,14 +11,22 @@ from app.core.config import settings
|
||||
from app.core.dependencies import (
|
||||
get_auth_service,
|
||||
get_current_active_user,
|
||||
get_current_active_user_flexible,
|
||||
get_oauth_service,
|
||||
)
|
||||
from app.core.logging import get_logger
|
||||
from app.models.user import User
|
||||
from app.schemas.auth import UserLoginRequest, UserRegisterRequest, UserResponse
|
||||
from app.schemas.auth import (
|
||||
ApiTokenRequest,
|
||||
ApiTokenResponse,
|
||||
ApiTokenStatusResponse,
|
||||
UserLoginRequest,
|
||||
UserRegisterRequest,
|
||||
UserResponse,
|
||||
)
|
||||
from app.services.auth import AuthService
|
||||
from app.services.oauth import OAuthService
|
||||
from app.utils.auth import JWTUtils
|
||||
from app.utils.auth import JWTUtils, TokenUtils
|
||||
|
||||
router = APIRouter()
|
||||
logger = get_logger(__name__)
|
||||
@@ -131,7 +139,7 @@ async def login(
|
||||
|
||||
@router.get("/me")
|
||||
async def get_current_user_info(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
) -> UserResponse:
|
||||
"""Get current user information."""
|
||||
@@ -426,3 +434,72 @@ async def exchange_oauth_token(
|
||||
user_id = token_data["user_id"]
|
||||
logger.info("OAuth tokens exchanged successfully for user: %s", user_id)
|
||||
return {"message": "Tokens set successfully", "user_id": str(user_id)}
|
||||
|
||||
|
||||
# API Token endpoints
|
||||
@router.post("/api-token")
|
||||
async def generate_api_token(
|
||||
request: ApiTokenRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
) -> ApiTokenResponse:
|
||||
"""Generate a new API token for the current user."""
|
||||
try:
|
||||
api_token = await auth_service.generate_api_token(
|
||||
current_user,
|
||||
expires_days=request.expires_days,
|
||||
)
|
||||
|
||||
# Refresh user to get updated token info
|
||||
await auth_service.session.refresh(current_user)
|
||||
|
||||
return ApiTokenResponse(
|
||||
api_token=api_token,
|
||||
expires_at=current_user.api_token_expires_at,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to generate API token for user: %s", current_user.email,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to generate API token",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/api-token/status")
|
||||
async def get_api_token_status(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> ApiTokenStatusResponse:
|
||||
"""Get the current user's API token status."""
|
||||
has_token = current_user.api_token is not None
|
||||
is_expired = False
|
||||
|
||||
if has_token and current_user.api_token_expires_at:
|
||||
is_expired = TokenUtils.is_token_expired(current_user.api_token_expires_at)
|
||||
|
||||
return ApiTokenStatusResponse(
|
||||
has_token=has_token,
|
||||
expires_at=current_user.api_token_expires_at,
|
||||
is_expired=is_expired,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/api-token")
|
||||
async def revoke_api_token(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
) -> dict[str, str]:
|
||||
"""Revoke the current user's API token."""
|
||||
try:
|
||||
await auth_service.revoke_api_token(current_user)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to revoke API token for user: %s", current_user.email,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to revoke API token",
|
||||
) from e
|
||||
else:
|
||||
return {"message": "API token revoked successfully"}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Annotated, cast
|
||||
|
||||
from fastapi import Cookie, Depends, HTTPException, status
|
||||
from fastapi import Cookie, Depends, Header, HTTPException, status
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
@@ -10,7 +10,7 @@ from app.core.logging import get_logger
|
||||
from app.models.user import User
|
||||
from app.services.auth import AuthService
|
||||
from app.services.oauth import OAuthService
|
||||
from app.utils.auth import JWTUtils
|
||||
from app.utils.auth import JWTUtils, TokenUtils
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -92,6 +92,95 @@ async def get_current_active_user(
|
||||
return current_user
|
||||
|
||||
|
||||
async def get_current_user_api_token(
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
authorization: Annotated[str | None, Header()] = None,
|
||||
) -> User:
|
||||
"""Get the current authenticated user from API token in Authorization header."""
|
||||
try:
|
||||
# Check if Authorization header exists
|
||||
if not authorization:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authorization header required",
|
||||
)
|
||||
|
||||
# Check if it's a Bearer token
|
||||
if not authorization.startswith("Bearer "):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authorization header format",
|
||||
)
|
||||
|
||||
# Extract the API token
|
||||
api_token = authorization[7:] # Remove "Bearer " prefix
|
||||
if not api_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="API token required",
|
||||
)
|
||||
|
||||
# Get the user by API token
|
||||
user = await auth_service.get_user_by_api_token(api_token)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API token",
|
||||
)
|
||||
|
||||
# Check if token is expired
|
||||
if TokenUtils.is_token_expired(user.api_token_expires_at):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="API token has expired",
|
||||
)
|
||||
|
||||
# Check if user is active
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Account is deactivated",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions without wrapping them
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to authenticate user with API token")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate API token",
|
||||
) from e
|
||||
|
||||
|
||||
async def get_current_user_flexible(
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
access_token: Annotated[str | None, Cookie()] = None,
|
||||
authorization: Annotated[str | None, Header()] = None,
|
||||
) -> User:
|
||||
"""Get the current authenticated user from either JWT cookie or API token."""
|
||||
# Try API token first if Authorization header is present
|
||||
if authorization:
|
||||
return await get_current_user_api_token(auth_service, authorization)
|
||||
|
||||
# Fall back to JWT cookie authentication
|
||||
return await get_current_user(auth_service, access_token)
|
||||
|
||||
|
||||
async def get_current_active_user_flexible(
|
||||
current_user: Annotated[User, Depends(get_current_user_flexible)],
|
||||
) -> User:
|
||||
"""Get the current authenticated and active user using flexible auth."""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Account is deactivated",
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
async def get_admin_user(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> User:
|
||||
|
||||
@@ -51,3 +51,29 @@ class AuthResponse(BaseModel):
|
||||
|
||||
user: UserResponse = Field(..., description="User information")
|
||||
token: TokenResponse = Field(..., description="Authentication token")
|
||||
|
||||
|
||||
class ApiTokenRequest(BaseModel):
|
||||
"""Schema for API token generation request."""
|
||||
|
||||
expires_days: int = Field(
|
||||
default=365,
|
||||
ge=1,
|
||||
le=3650,
|
||||
description="Number of days until token expires (1-3650 days)",
|
||||
)
|
||||
|
||||
|
||||
class ApiTokenResponse(BaseModel):
|
||||
"""Schema for API token response."""
|
||||
|
||||
api_token: str = Field(..., description="Generated API token")
|
||||
expires_at: datetime = Field(..., description="Token expiration timestamp")
|
||||
|
||||
|
||||
class ApiTokenStatusResponse(BaseModel):
|
||||
"""Schema for API token status response."""
|
||||
|
||||
has_token: bool = Field(..., description="Whether user has an active API token")
|
||||
expires_at: datetime | None = Field(None, description="Token expiration timestamp")
|
||||
is_expired: bool = Field(..., description="Whether the token is expired")
|
||||
|
||||
@@ -19,7 +19,7 @@ from app.schemas.auth import (
|
||||
UserResponse,
|
||||
)
|
||||
from app.services.oauth import OAuthUserInfo
|
||||
from app.utils.auth import JWTUtils, PasswordUtils
|
||||
from app.utils.auth import JWTUtils, PasswordUtils, TokenUtils
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -123,6 +123,37 @@ class AuthService:
|
||||
|
||||
return user
|
||||
|
||||
async def get_user_by_api_token(self, api_token: str) -> User | None:
|
||||
"""Get a user by their API token."""
|
||||
return await self.user_repo.get_by_api_token(api_token)
|
||||
|
||||
async def generate_api_token(self, user: User, expires_days: int = 365) -> str:
|
||||
"""Generate a new API token for a user."""
|
||||
# Generate a secure random token
|
||||
api_token = TokenUtils.generate_api_token()
|
||||
|
||||
# Set expiration date
|
||||
expires_at = datetime.now(UTC) + timedelta(days=expires_days)
|
||||
|
||||
# Update user with new API token
|
||||
update_data = {
|
||||
"api_token": api_token,
|
||||
"api_token_expires_at": expires_at,
|
||||
}
|
||||
await self.user_repo.update(user, update_data)
|
||||
|
||||
logger.info("Generated new API token for user: %s", user.email)
|
||||
return api_token
|
||||
|
||||
async def revoke_api_token(self, user: User) -> None:
|
||||
"""Revoke a user's API token."""
|
||||
update_data = {
|
||||
"api_token": None,
|
||||
"api_token_expires_at": None,
|
||||
}
|
||||
await self.user_repo.update(user, update_data)
|
||||
logger.info("Revoked API token for user: %s", user.email)
|
||||
|
||||
def _create_access_token(self, user: User) -> TokenResponse:
|
||||
"""Create an access token for a user."""
|
||||
access_token_expires = timedelta(
|
||||
|
||||
@@ -104,9 +104,8 @@ class SocketManager:
|
||||
await self.sio.emit(event, data, room=room_id)
|
||||
logger.debug(f"Sent {event} to user {user_id} in room {room_id}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"User {user_id} not found in any room")
|
||||
return False
|
||||
logger.warning(f"User {user_id} not found in any room")
|
||||
return False
|
||||
|
||||
async def broadcast_to_all(self, event: str, data: dict):
|
||||
"""Broadcast a message to all connected users."""
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Cookie parsing utilities for WebSocket authentication."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def parse_cookies(cookie_header: str) -> dict[str, str]:
|
||||
@@ -18,7 +17,7 @@ def parse_cookies(cookie_header: str) -> dict[str, str]:
|
||||
return cookies
|
||||
|
||||
|
||||
def extract_access_token_from_cookies(cookie_header: str) -> Optional[str]:
|
||||
def extract_access_token_from_cookies(cookie_header: str) -> str | None:
|
||||
"""Extract access token from HTTP cookies."""
|
||||
cookies = parse_cookies(cookie_header)
|
||||
return cookies.get("access_token")
|
||||
|
||||
Reference in New Issue
Block a user