diff --git a/app/api/v1/auth.py b/app/api/v1/auth.py index 4b4ff00..6ce1391 100644 --- a/app/api/v1/auth.py +++ b/app/api/v1/auth.py @@ -11,14 +11,22 @@ from app.core.config import settings from app.core.dependencies import ( get_auth_service, get_current_active_user, + get_current_active_user_flexible, get_oauth_service, ) from app.core.logging import get_logger from app.models.user import User -from app.schemas.auth import UserLoginRequest, UserRegisterRequest, UserResponse +from app.schemas.auth import ( + ApiTokenRequest, + ApiTokenResponse, + ApiTokenStatusResponse, + UserLoginRequest, + UserRegisterRequest, + UserResponse, +) from app.services.auth import AuthService from app.services.oauth import OAuthService -from app.utils.auth import JWTUtils +from app.utils.auth import JWTUtils, TokenUtils router = APIRouter() logger = get_logger(__name__) @@ -131,7 +139,7 @@ async def login( @router.get("/me") async def get_current_user_info( - current_user: Annotated[User, Depends(get_current_active_user)], + current_user: Annotated[User, Depends(get_current_active_user_flexible)], auth_service: Annotated[AuthService, Depends(get_auth_service)], ) -> UserResponse: """Get current user information.""" @@ -426,3 +434,72 @@ async def exchange_oauth_token( user_id = token_data["user_id"] logger.info("OAuth tokens exchanged successfully for user: %s", user_id) return {"message": "Tokens set successfully", "user_id": str(user_id)} + + +# API Token endpoints +@router.post("/api-token") +async def generate_api_token( + request: ApiTokenRequest, + current_user: Annotated[User, Depends(get_current_active_user)], + auth_service: Annotated[AuthService, Depends(get_auth_service)], +) -> ApiTokenResponse: + """Generate a new API token for the current user.""" + try: + api_token = await auth_service.generate_api_token( + current_user, + expires_days=request.expires_days, + ) + + # Refresh user to get updated token info + await auth_service.session.refresh(current_user) + + return ApiTokenResponse( + api_token=api_token, + expires_at=current_user.api_token_expires_at, + ) + except Exception as e: + logger.exception( + "Failed to generate API token for user: %s", current_user.email, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to generate API token", + ) from e + + +@router.get("/api-token/status") +async def get_api_token_status( + current_user: Annotated[User, Depends(get_current_active_user)], +) -> ApiTokenStatusResponse: + """Get the current user's API token status.""" + has_token = current_user.api_token is not None + is_expired = False + + if has_token and current_user.api_token_expires_at: + is_expired = TokenUtils.is_token_expired(current_user.api_token_expires_at) + + return ApiTokenStatusResponse( + has_token=has_token, + expires_at=current_user.api_token_expires_at, + is_expired=is_expired, + ) + + +@router.delete("/api-token") +async def revoke_api_token( + current_user: Annotated[User, Depends(get_current_active_user)], + auth_service: Annotated[AuthService, Depends(get_auth_service)], +) -> dict[str, str]: + """Revoke the current user's API token.""" + try: + await auth_service.revoke_api_token(current_user) + except Exception as e: + logger.exception( + "Failed to revoke API token for user: %s", current_user.email, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to revoke API token", + ) from e + else: + return {"message": "API token revoked successfully"} diff --git a/app/core/dependencies.py b/app/core/dependencies.py index 777b69f..c26e11e 100644 --- a/app/core/dependencies.py +++ b/app/core/dependencies.py @@ -2,7 +2,7 @@ from typing import Annotated, cast -from fastapi import Cookie, Depends, HTTPException, status +from fastapi import Cookie, Depends, Header, HTTPException, status from sqlmodel.ext.asyncio.session import AsyncSession from app.core.database import get_db @@ -10,7 +10,7 @@ from app.core.logging import get_logger from app.models.user import User from app.services.auth import AuthService from app.services.oauth import OAuthService -from app.utils.auth import JWTUtils +from app.utils.auth import JWTUtils, TokenUtils logger = get_logger(__name__) @@ -92,6 +92,95 @@ async def get_current_active_user( return current_user +async def get_current_user_api_token( + auth_service: Annotated[AuthService, Depends(get_auth_service)], + authorization: Annotated[str | None, Header()] = None, +) -> User: + """Get the current authenticated user from API token in Authorization header.""" + try: + # Check if Authorization header exists + if not authorization: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authorization header required", + ) + + # Check if it's a Bearer token + if not authorization.startswith("Bearer "): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authorization header format", + ) + + # Extract the API token + api_token = authorization[7:] # Remove "Bearer " prefix + if not api_token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="API token required", + ) + + # Get the user by API token + user = await auth_service.get_user_by_api_token(api_token) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API token", + ) + + # Check if token is expired + if TokenUtils.is_token_expired(user.api_token_expires_at): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="API token has expired", + ) + + # Check if user is active + if not user.is_active: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Account is deactivated", + ) + + return user + + except HTTPException: + # Re-raise HTTPExceptions without wrapping them + raise + except Exception as e: + logger.exception("Failed to authenticate user with API token") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate API token", + ) from e + + +async def get_current_user_flexible( + auth_service: Annotated[AuthService, Depends(get_auth_service)], + access_token: Annotated[str | None, Cookie()] = None, + authorization: Annotated[str | None, Header()] = None, +) -> User: + """Get the current authenticated user from either JWT cookie or API token.""" + # Try API token first if Authorization header is present + if authorization: + return await get_current_user_api_token(auth_service, authorization) + + # Fall back to JWT cookie authentication + return await get_current_user(auth_service, access_token) + + +async def get_current_active_user_flexible( + current_user: Annotated[User, Depends(get_current_user_flexible)], +) -> User: + """Get the current authenticated and active user using flexible auth.""" + if not current_user.is_active: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Account is deactivated", + ) + return current_user + + async def get_admin_user( current_user: Annotated[User, Depends(get_current_active_user)], ) -> User: diff --git a/app/schemas/auth.py b/app/schemas/auth.py index cc9455c..bb38a9c 100644 --- a/app/schemas/auth.py +++ b/app/schemas/auth.py @@ -51,3 +51,29 @@ class AuthResponse(BaseModel): user: UserResponse = Field(..., description="User information") token: TokenResponse = Field(..., description="Authentication token") + + +class ApiTokenRequest(BaseModel): + """Schema for API token generation request.""" + + expires_days: int = Field( + default=365, + ge=1, + le=3650, + description="Number of days until token expires (1-3650 days)", + ) + + +class ApiTokenResponse(BaseModel): + """Schema for API token response.""" + + api_token: str = Field(..., description="Generated API token") + expires_at: datetime = Field(..., description="Token expiration timestamp") + + +class ApiTokenStatusResponse(BaseModel): + """Schema for API token status response.""" + + has_token: bool = Field(..., description="Whether user has an active API token") + expires_at: datetime | None = Field(None, description="Token expiration timestamp") + is_expired: bool = Field(..., description="Whether the token is expired") diff --git a/app/services/auth.py b/app/services/auth.py index e003f19..f3422dd 100644 --- a/app/services/auth.py +++ b/app/services/auth.py @@ -19,7 +19,7 @@ from app.schemas.auth import ( UserResponse, ) from app.services.oauth import OAuthUserInfo -from app.utils.auth import JWTUtils, PasswordUtils +from app.utils.auth import JWTUtils, PasswordUtils, TokenUtils logger = get_logger(__name__) @@ -123,6 +123,37 @@ class AuthService: return user + async def get_user_by_api_token(self, api_token: str) -> User | None: + """Get a user by their API token.""" + return await self.user_repo.get_by_api_token(api_token) + + async def generate_api_token(self, user: User, expires_days: int = 365) -> str: + """Generate a new API token for a user.""" + # Generate a secure random token + api_token = TokenUtils.generate_api_token() + + # Set expiration date + expires_at = datetime.now(UTC) + timedelta(days=expires_days) + + # Update user with new API token + update_data = { + "api_token": api_token, + "api_token_expires_at": expires_at, + } + await self.user_repo.update(user, update_data) + + logger.info("Generated new API token for user: %s", user.email) + return api_token + + async def revoke_api_token(self, user: User) -> None: + """Revoke a user's API token.""" + update_data = { + "api_token": None, + "api_token_expires_at": None, + } + await self.user_repo.update(user, update_data) + logger.info("Revoked API token for user: %s", user.email) + def _create_access_token(self, user: User) -> TokenResponse: """Create an access token for a user.""" access_token_expires = timedelta( diff --git a/app/services/socket.py b/app/services/socket.py index 929b7c5..74522ab 100644 --- a/app/services/socket.py +++ b/app/services/socket.py @@ -104,9 +104,8 @@ class SocketManager: await self.sio.emit(event, data, room=room_id) logger.debug(f"Sent {event} to user {user_id} in room {room_id}") return True - else: - logger.warning(f"User {user_id} not found in any room") - return False + logger.warning(f"User {user_id} not found in any room") + return False async def broadcast_to_all(self, event: str, data: dict): """Broadcast a message to all connected users.""" diff --git a/app/utils/cookies.py b/app/utils/cookies.py index b49fef5..b428bb2 100644 --- a/app/utils/cookies.py +++ b/app/utils/cookies.py @@ -1,6 +1,5 @@ """Cookie parsing utilities for WebSocket authentication.""" -from typing import Optional def parse_cookies(cookie_header: str) -> dict[str, str]: @@ -18,7 +17,7 @@ def parse_cookies(cookie_header: str) -> dict[str, str]: return cookies -def extract_access_token_from_cookies(cookie_header: str) -> Optional[str]: +def extract_access_token_from_cookies(cookie_header: str) -> str | None: """Extract access token from HTTP cookies.""" cookies = parse_cookies(cookie_header) return cookies.get("access_token") diff --git a/tests/api/v1/test_api_token_endpoints.py b/tests/api/v1/test_api_token_endpoints.py new file mode 100644 index 0000000..d512ed8 --- /dev/null +++ b/tests/api/v1/test_api_token_endpoints.py @@ -0,0 +1,343 @@ +"""Tests for API token endpoints.""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import patch + +import pytest +from httpx import AsyncClient + +from app.models.user import User + + +class TestApiTokenEndpoints: + """Test API token management endpoints.""" + + @pytest.mark.asyncio + async def test_generate_api_token_success( + self, authenticated_client: AsyncClient, authenticated_user: User, + ): + """Test successful API token generation.""" + request_data = {"expires_days": 30} + + response = await authenticated_client.post( + "/api/v1/auth/api-token", + json=request_data, + ) + + assert response.status_code == 200 + data = response.json() + + assert "api_token" in data + assert "expires_at" in data + assert len(data["api_token"]) > 0 + + # Verify token format (should be URL-safe base64) + import base64 + try: + base64.urlsafe_b64decode(data["api_token"] + "===") # Add padding + except Exception: + pytest.fail("API token should be valid URL-safe base64") + + @pytest.mark.asyncio + async def test_generate_api_token_default_expiry( + self, authenticated_client: AsyncClient, + ): + """Test API token generation with default expiry.""" + response = await authenticated_client.post("/api/v1/auth/api-token", json={}) + + assert response.status_code == 200 + data = response.json() + + expires_at_str = data["expires_at"] + # Handle both ISO format with/without timezone info + if expires_at_str.endswith("Z"): + expires_at = datetime.fromisoformat(expires_at_str.replace("Z", "+00:00")) + elif "+" in expires_at_str or expires_at_str.count("-") > 2: + expires_at = datetime.fromisoformat(expires_at_str) + else: + # Naive datetime, assume UTC + expires_at = datetime.fromisoformat(expires_at_str).replace(tzinfo=UTC) + + expected_expiry = datetime.now(UTC) + timedelta(days=365) + + # Allow 1 minute tolerance + assert abs((expires_at - expected_expiry).total_seconds()) < 60 + + @pytest.mark.asyncio + async def test_generate_api_token_custom_expiry( + self, authenticated_client: AsyncClient, + ): + """Test API token generation with custom expiry.""" + expires_days = 90 + request_data = {"expires_days": expires_days} + + response = await authenticated_client.post( + "/api/v1/auth/api-token", + json=request_data, + ) + + assert response.status_code == 200 + data = response.json() + + expires_at_str = data["expires_at"] + # Handle both ISO format with/without timezone info + if expires_at_str.endswith("Z"): + expires_at = datetime.fromisoformat(expires_at_str.replace("Z", "+00:00")) + elif "+" in expires_at_str or expires_at_str.count("-") > 2: + expires_at = datetime.fromisoformat(expires_at_str) + else: + # Naive datetime, assume UTC + expires_at = datetime.fromisoformat(expires_at_str).replace(tzinfo=UTC) + + expected_expiry = datetime.now(UTC) + timedelta(days=expires_days) + + # Allow 1 minute tolerance + assert abs((expires_at - expected_expiry).total_seconds()) < 60 + + @pytest.mark.asyncio + async def test_generate_api_token_validation_errors( + self, authenticated_client: AsyncClient, + ): + """Test API token generation with validation errors.""" + # Test minimum validation + response = await authenticated_client.post( + "/api/v1/auth/api-token", + json={"expires_days": 0}, + ) + assert response.status_code == 422 + + # Test maximum validation + response = await authenticated_client.post( + "/api/v1/auth/api-token", + json={"expires_days": 4000}, + ) + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_generate_api_token_unauthenticated(self, client: AsyncClient): + """Test API token generation without authentication.""" + response = await client.post( + "/api/v1/auth/api-token", + json={"expires_days": 30}, + ) + assert response.status_code == 401 + + @pytest.mark.asyncio + async def test_get_api_token_status_no_token( + self, authenticated_client: AsyncClient, + ): + """Test getting API token status when user has no token.""" + response = await authenticated_client.get("/api/v1/auth/api-token/status") + + assert response.status_code == 200 + data = response.json() + + assert data["has_token"] is False + assert data["expires_at"] is None + assert data["is_expired"] is False + + @pytest.mark.asyncio + async def test_get_api_token_status_with_token( + self, authenticated_client: AsyncClient, + ): + """Test getting API token status when user has a token.""" + # First generate a token + await authenticated_client.post( + "/api/v1/auth/api-token", + json={"expires_days": 30}, + ) + + # Then check status + response = await authenticated_client.get("/api/v1/auth/api-token/status") + + assert response.status_code == 200 + data = response.json() + + assert data["has_token"] is True + assert data["expires_at"] is not None + assert data["is_expired"] is False + + @pytest.mark.asyncio + async def test_get_api_token_status_expired_token( + self, authenticated_client: AsyncClient, authenticated_user: User, + ): + """Test getting API token status with expired token.""" + # Mock expired token + with patch("app.utils.auth.TokenUtils.is_token_expired", return_value=True): + # Set a token on the user + authenticated_user.api_token = "expired_token" + authenticated_user.api_token_expires_at = datetime.now(UTC) - timedelta(days=1) + + response = await authenticated_client.get("/api/v1/auth/api-token/status") + + assert response.status_code == 200 + data = response.json() + + assert data["has_token"] is True + assert data["expires_at"] is not None + assert data["is_expired"] is True + + @pytest.mark.asyncio + async def test_get_api_token_status_unauthenticated(self, client: AsyncClient): + """Test getting API token status without authentication.""" + response = await client.get("/api/v1/auth/api-token/status") + assert response.status_code == 401 + + @pytest.mark.asyncio + async def test_revoke_api_token_success( + self, authenticated_client: AsyncClient, + ): + """Test successful API token revocation.""" + # First generate a token + await authenticated_client.post( + "/api/v1/auth/api-token", + json={"expires_days": 30}, + ) + + # Verify token exists + status_response = await authenticated_client.get("/api/v1/auth/api-token/status") + assert status_response.json()["has_token"] is True + + # Revoke the token + response = await authenticated_client.delete("/api/v1/auth/api-token") + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "API token revoked successfully" + + # Verify token is gone + status_response = await authenticated_client.get("/api/v1/auth/api-token/status") + assert status_response.json()["has_token"] is False + + @pytest.mark.asyncio + async def test_revoke_api_token_no_token( + self, authenticated_client: AsyncClient, + ): + """Test revoking API token when user has no token.""" + response = await authenticated_client.delete("/api/v1/auth/api-token") + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "API token revoked successfully" + + @pytest.mark.asyncio + async def test_revoke_api_token_unauthenticated(self, client: AsyncClient): + """Test revoking API token without authentication.""" + response = await client.delete("/api/v1/auth/api-token") + assert response.status_code == 401 + + @pytest.mark.asyncio + async def test_api_token_authentication_success( + self, client: AsyncClient, authenticated_client: AsyncClient, + ): + """Test successful authentication using API token.""" + # Generate API token + token_response = await authenticated_client.post( + "/api/v1/auth/api-token", + json={"expires_days": 30}, + ) + api_token = token_response.json()["api_token"] + + # Use API token to authenticate + headers = {"Authorization": f"Bearer {api_token}"} + response = await client.get("/api/v1/auth/me", headers=headers) + + assert response.status_code == 200 + data = response.json() + assert "id" in data + assert "email" in data + + @pytest.mark.asyncio + async def test_api_token_authentication_invalid_token(self, client: AsyncClient): + """Test authentication with invalid API token.""" + headers = {"Authorization": "Bearer invalid_token"} + response = await client.get("/api/v1/auth/me", headers=headers) + + assert response.status_code == 401 + data = response.json() + assert "Invalid API token" in data["detail"] + + @pytest.mark.asyncio + async def test_api_token_authentication_expired_token( + self, client: AsyncClient, authenticated_client: AsyncClient, + ): + """Test authentication with expired API token.""" + # Generate API token + token_response = await authenticated_client.post( + "/api/v1/auth/api-token", + json={"expires_days": 30}, + ) + api_token = token_response.json()["api_token"] + + # Mock expired token + with patch("app.utils.auth.TokenUtils.is_token_expired", return_value=True): + headers = {"Authorization": f"Bearer {api_token}"} + response = await client.get("/api/v1/auth/me", headers=headers) + + assert response.status_code == 401 + data = response.json() + assert "API token has expired" in data["detail"] + + @pytest.mark.asyncio + async def test_api_token_authentication_malformed_header(self, client: AsyncClient): + """Test authentication with malformed Authorization header.""" + # Missing Bearer prefix + headers = {"Authorization": "invalid_format"} + response = await client.get("/api/v1/auth/me", headers=headers) + + assert response.status_code == 401 + data = response.json() + assert "Invalid authorization header format" in data["detail"] + + # Empty token + headers = {"Authorization": "Bearer "} + response = await client.get("/api/v1/auth/me", headers=headers) + + assert response.status_code == 401 + data = response.json() + assert "API token required" in data["detail"] + + @pytest.mark.asyncio + async def test_api_token_authentication_inactive_user( + self, client: AsyncClient, authenticated_client: AsyncClient, authenticated_user: User, + ): + """Test authentication with API token for inactive user.""" + # Generate API token + token_response = await authenticated_client.post( + "/api/v1/auth/api-token", + json={"expires_days": 30}, + ) + api_token = token_response.json()["api_token"] + + # Deactivate user + authenticated_user.is_active = False + + # Try to authenticate with API token + headers = {"Authorization": f"Bearer {api_token}"} + response = await client.get("/api/v1/auth/me", headers=headers) + + assert response.status_code == 401 + data = response.json() + assert "Account is deactivated" in data["detail"] + + @pytest.mark.asyncio + async def test_flexible_authentication_prefers_api_token( + self, client: AsyncClient, authenticated_client: AsyncClient, auth_cookies: dict[str, str], + ): + """Test that flexible authentication prefers API token over cookie.""" + # Generate API token + token_response = await authenticated_client.post( + "/api/v1/auth/api-token", + json={"expires_days": 30}, + ) + api_token = token_response.json()["api_token"] + + # Set both cookies and Authorization header + client.cookies.update(auth_cookies) + headers = {"Authorization": f"Bearer {api_token}"} + + # This should use API token authentication + response = await client.get("/api/v1/auth/me", headers=headers) + + assert response.status_code == 200 + # If it used API token auth, it should work even if cookies are invalid diff --git a/tests/api/v1/test_auth_endpoints.py b/tests/api/v1/test_auth_endpoints.py index ad1df4c..2fcffde 100644 --- a/tests/api/v1/test_auth_endpoints.py +++ b/tests/api/v1/test_auth_endpoints.py @@ -73,7 +73,7 @@ class TestAuthEndpoints: @pytest.mark.asyncio async def test_register_duplicate_email( - self, test_client: AsyncClient, test_user: User + self, test_client: AsyncClient, test_user: User, ) -> None: """Test registration with duplicate email.""" user_data = { @@ -118,7 +118,7 @@ class TestAuthEndpoints: async def test_register_missing_fields(self, test_client: AsyncClient) -> None: """Test registration with missing fields.""" user_data = { - "email": "test@example.com" + "email": "test@example.com", # Missing password and name } @@ -128,7 +128,7 @@ class TestAuthEndpoints: @pytest.mark.asyncio async def test_login_success( - self, test_client: AsyncClient, test_user: User, test_login_data: dict[str, str] + self, test_client: AsyncClient, test_user: User, test_login_data: dict[str, str], ) -> None: """Test successful user login.""" response = await test_client.post("/api/v1/auth/login", json=test_login_data) @@ -161,7 +161,7 @@ class TestAuthEndpoints: @pytest.mark.asyncio async def test_login_invalid_password( - self, test_client: AsyncClient, test_user: User + self, test_client: AsyncClient, test_user: User, ) -> None: """Test login with invalid password.""" login_data = {"email": test_user.email, "password": "wrongpassword"} @@ -183,7 +183,7 @@ class TestAuthEndpoints: @pytest.mark.asyncio async def test_get_current_user_success( - self, test_client: AsyncClient, test_user: User, auth_cookies: dict[str, str] + self, test_client: AsyncClient, test_user: User, auth_cookies: dict[str, str], ) -> None: """Test getting current user info successfully.""" # Set cookies on client instance to avoid deprecation warning @@ -210,7 +210,7 @@ class TestAuthEndpoints: @pytest.mark.asyncio async def test_get_current_user_invalid_token( - self, test_client: AsyncClient + self, test_client: AsyncClient, ) -> None: """Test getting current user with invalid token.""" # Set invalid cookies on client instance @@ -223,7 +223,7 @@ class TestAuthEndpoints: @pytest.mark.asyncio async def test_get_current_user_expired_token( - self, test_client: AsyncClient, test_user: User + self, test_client: AsyncClient, test_user: User, ) -> None: """Test getting current user with expired token.""" from datetime import timedelta @@ -237,7 +237,7 @@ class TestAuthEndpoints: "role": "user", } expired_token = JWTUtils.create_access_token( - token_data, expires_delta=timedelta(seconds=-1) + token_data, expires_delta=timedelta(seconds=-1), ) # Set expired cookies on client instance @@ -262,7 +262,7 @@ class TestAuthEndpoints: @pytest.mark.asyncio async def test_admin_access_with_user_role( - self, test_client: AsyncClient, auth_cookies: dict[str, str] + self, test_client: AsyncClient, auth_cookies: dict[str, str], ) -> None: """Test that regular users cannot access admin endpoints.""" # This test would be for admin-only endpoints when they're created @@ -293,7 +293,7 @@ class TestAuthEndpoints: @pytest.mark.asyncio async def test_admin_access_with_admin_role( - self, test_client: AsyncClient, admin_cookies: dict[str, str] + self, test_client: AsyncClient, admin_cookies: dict[str, str], ) -> None: """Test that admin users can access admin endpoints.""" from app.core.dependencies import get_admin_user @@ -357,7 +357,7 @@ class TestAuthEndpoints: @pytest.mark.asyncio async def test_oauth_authorize_invalid_provider( - self, test_client: AsyncClient + self, test_client: AsyncClient, ) -> None: """Test OAuth authorization with invalid provider.""" response = await test_client.get("/api/v1/auth/invalid/authorize") @@ -368,7 +368,7 @@ class TestAuthEndpoints: @pytest.mark.asyncio async def test_oauth_callback_new_user( - self, test_client: AsyncClient, ensure_plans: tuple[Any, Any] + self, test_client: AsyncClient, ensure_plans: tuple[Any, Any], ) -> None: """Test OAuth callback for new user creation.""" # Mock OAuth user info @@ -400,7 +400,7 @@ class TestAuthEndpoints: @pytest.mark.asyncio async def test_oauth_callback_existing_user_link( - self, test_client: AsyncClient, test_user: Any, ensure_plans: tuple[Any, Any] + self, test_client: AsyncClient, test_user: Any, ensure_plans: tuple[Any, Any], ) -> None: """Test OAuth callback for linking to existing user.""" # Mock OAuth user info with same email as test user @@ -442,7 +442,7 @@ class TestAuthEndpoints: @pytest.mark.asyncio async def test_oauth_callback_invalid_provider( - self, test_client: AsyncClient + self, test_client: AsyncClient, ) -> None: """Test OAuth callback with invalid provider.""" response = await test_client.get( diff --git a/tests/api/v1/test_socket_endpoints.py b/tests/api/v1/test_socket_endpoints.py index a7d25c8..c83f419 100644 --- a/tests/api/v1/test_socket_endpoints.py +++ b/tests/api/v1/test_socket_endpoints.py @@ -1,8 +1,9 @@ """Tests for socket API endpoints.""" +from unittest.mock import AsyncMock, patch + import pytest from httpx import AsyncClient -from unittest.mock import AsyncMock, patch from app.models.user import User @@ -10,7 +11,7 @@ from app.models.user import User @pytest.fixture def mock_socket_manager(): """Mock socket manager for testing.""" - with patch('app.api.v1.socket.socket_manager') as mock: + with patch("app.api.v1.socket.socket_manager") as mock: mock.get_connected_users.return_value = ["1", "2", "3"] mock.send_to_user = AsyncMock(return_value=True) mock.broadcast_to_all = AsyncMock() @@ -24,10 +25,10 @@ class TestSocketEndpoints: async def test_get_socket_status_authenticated(self, authenticated_client: AsyncClient, authenticated_user: User, mock_socket_manager): """Test getting socket status for authenticated user.""" response = await authenticated_client.get("/api/v1/socket/status") - + assert response.status_code == 200 data = response.json() - + assert "connected" in data assert "user_id" in data assert "total_connected" in data @@ -46,19 +47,19 @@ class TestSocketEndpoints: """Test sending message to specific user successfully.""" target_user_id = 2 message = "Hello there!" - + response = await authenticated_client.post( "/api/v1/socket/send-message", - params={"target_user_id": target_user_id, "message": message} + params={"target_user_id": target_user_id, "message": message}, ) - + assert response.status_code == 200 data = response.json() - + assert data["success"] is True assert data["target_user_id"] == target_user_id assert data["message"] == "Message sent" - + # Verify socket manager was called correctly mock_socket_manager.send_to_user.assert_called_once_with( str(target_user_id), @@ -67,7 +68,7 @@ class TestSocketEndpoints: "from_user_id": authenticated_user.id, "from_user_name": authenticated_user.name, "message": message, - } + }, ) @pytest.mark.asyncio @@ -75,18 +76,18 @@ class TestSocketEndpoints: """Test sending message to user who is not connected.""" target_user_id = 999 message = "Hello there!" - + # Mock user not connected mock_socket_manager.send_to_user.return_value = False - + response = await authenticated_client.post( "/api/v1/socket/send-message", - params={"target_user_id": target_user_id, "message": message} + params={"target_user_id": target_user_id, "message": message}, ) - + assert response.status_code == 200 data = response.json() - + assert data["success"] is False assert data["target_user_id"] == target_user_id assert data["message"] == "User not connected" @@ -96,7 +97,7 @@ class TestSocketEndpoints: """Test sending message without authentication.""" response = await client.post( "/api/v1/socket/send-message", - params={"target_user_id": 1, "message": "test"} + params={"target_user_id": 1, "message": "test"}, ) assert response.status_code == 401 @@ -104,18 +105,18 @@ class TestSocketEndpoints: async def test_broadcast_message_success(self, authenticated_client: AsyncClient, authenticated_user: User, mock_socket_manager): """Test broadcasting message to all users successfully.""" message = "Important announcement!" - + response = await authenticated_client.post( "/api/v1/socket/broadcast", - params={"message": message} + params={"message": message}, ) - + assert response.status_code == 200 data = response.json() - + assert data["success"] is True assert data["message"] == "Message broadcasted to all users" - + # Verify socket manager was called correctly mock_socket_manager.broadcast_to_all.assert_called_once_with( "broadcast_message", @@ -123,7 +124,7 @@ class TestSocketEndpoints: "from_user_id": authenticated_user.id, "from_user_name": authenticated_user.name, "message": message, - } + }, ) @pytest.mark.asyncio @@ -131,7 +132,7 @@ class TestSocketEndpoints: """Test broadcasting message without authentication.""" response = await client.post( "/api/v1/socket/broadcast", - params={"message": "test"} + params={"message": "test"}, ) assert response.status_code == 401 @@ -141,14 +142,14 @@ class TestSocketEndpoints: # Missing target_user_id response = await authenticated_client.post( "/api/v1/socket/send-message", - params={"message": "test"} + params={"message": "test"}, ) assert response.status_code == 422 - + # Missing message response = await authenticated_client.post( "/api/v1/socket/send-message", - params={"target_user_id": 1} + params={"target_user_id": 1}, ) assert response.status_code == 422 @@ -163,7 +164,7 @@ class TestSocketEndpoints: """Test sending message with invalid user ID.""" response = await authenticated_client.post( "/api/v1/socket/send-message", - params={"target_user_id": "invalid", "message": "test"} + params={"target_user_id": "invalid", "message": "test"}, ) assert response.status_code == 422 @@ -172,14 +173,14 @@ class TestSocketEndpoints: """Test that socket status correctly shows if user is connected.""" # Test when user is connected mock_socket_manager.get_connected_users.return_value = [str(authenticated_user.id), "2", "3"] - + response = await authenticated_client.get("/api/v1/socket/status") data = response.json() assert data["connected"] is True - + # Test when user is not connected mock_socket_manager.get_connected_users.return_value = ["2", "3", "4"] - + response = await authenticated_client.get("/api/v1/socket/status") data = response.json() - assert data["connected"] is False \ No newline at end of file + assert data["connected"] is False diff --git a/tests/conftest.py b/tests/conftest.py index 58bcd6f..babdae6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,6 @@ from sqlmodel import SQLModel, select from sqlmodel.ext.asyncio.session import AsyncSession from app.core.database import get_db -from app.main import create_app from app.models.plan import Plan from app.models.user import User from app.utils.auth import JWTUtils, PasswordUtils @@ -199,7 +198,7 @@ async def ensure_plans(test_session: AsyncSession) -> tuple[Plan, Plan]: @pytest_asyncio.fixture async def test_user( - test_session: AsyncSession, ensure_plans: tuple[Plan, Plan] + test_session: AsyncSession, ensure_plans: tuple[Plan, Plan], ) -> User: """Create a test user.""" user = User( @@ -219,7 +218,7 @@ async def test_user( @pytest_asyncio.fixture async def admin_user( - test_session: AsyncSession, ensure_plans: tuple[Plan, Plan] + test_session: AsyncSession, ensure_plans: tuple[Plan, Plan], ) -> User: """Create a test admin user.""" user = User( diff --git a/tests/core/test_api_token_dependencies.py b/tests/core/test_api_token_dependencies.py new file mode 100644 index 0000000..5ca1026 --- /dev/null +++ b/tests/core/test_api_token_dependencies.py @@ -0,0 +1,191 @@ +"""Tests for API token authentication dependencies.""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock + +import pytest +from fastapi import HTTPException + +from app.core.dependencies import get_current_user_api_token, get_current_user_flexible +from app.models.user import User +from app.services.auth import AuthService + + +class TestApiTokenDependencies: + """Test API token authentication dependencies.""" + + @pytest.fixture + def mock_auth_service(self): + """Create a mock auth service.""" + return AsyncMock(spec=AuthService) + + @pytest.fixture + def test_user(self): + """Create a test user.""" + return User( + id=1, + email="test@example.com", + name="Test User", + role="user", + is_active=True, + plan_id=1, + credits=100, + api_token="test_api_token_123", + api_token_expires_at=datetime.now(UTC) + timedelta(days=30), + ) + + @pytest.mark.asyncio + async def test_get_current_user_api_token_success( + self, mock_auth_service, test_user, + ): + """Test successful API token authentication.""" + mock_auth_service.get_user_by_api_token.return_value = test_user + + authorization = "Bearer test_api_token_123" + + result = await get_current_user_api_token(mock_auth_service, authorization) + + assert result == test_user + mock_auth_service.get_user_by_api_token.assert_called_once_with("test_api_token_123") + + @pytest.mark.asyncio + async def test_get_current_user_api_token_no_header(self, mock_auth_service): + """Test API token authentication without Authorization header.""" + with pytest.raises(HTTPException) as exc_info: + await get_current_user_api_token(mock_auth_service, None) + + assert exc_info.value.status_code == 401 + assert "Authorization header required" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_get_current_user_api_token_invalid_format(self, mock_auth_service): + """Test API token authentication with invalid header format.""" + authorization = "Invalid format" + + with pytest.raises(HTTPException) as exc_info: + await get_current_user_api_token(mock_auth_service, authorization) + + assert exc_info.value.status_code == 401 + assert "Invalid authorization header format" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_get_current_user_api_token_empty_token(self, mock_auth_service): + """Test API token authentication with empty token.""" + authorization = "Bearer " + + with pytest.raises(HTTPException) as exc_info: + await get_current_user_api_token(mock_auth_service, authorization) + + assert exc_info.value.status_code == 401 + assert "API token required" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_get_current_user_api_token_invalid_token(self, mock_auth_service): + """Test API token authentication with invalid token.""" + mock_auth_service.get_user_by_api_token.return_value = None + + authorization = "Bearer invalid_token" + + with pytest.raises(HTTPException) as exc_info: + await get_current_user_api_token(mock_auth_service, authorization) + + assert exc_info.value.status_code == 401 + assert "Invalid API token" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_get_current_user_api_token_expired_token( + self, mock_auth_service, test_user, + ): + """Test API token authentication with expired token.""" + # Set expired token + test_user.api_token_expires_at = datetime.now(UTC) - timedelta(days=1) + mock_auth_service.get_user_by_api_token.return_value = test_user + + authorization = "Bearer expired_token" + + with pytest.raises(HTTPException) as exc_info: + await get_current_user_api_token(mock_auth_service, authorization) + + assert exc_info.value.status_code == 401 + assert "API token has expired" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_get_current_user_api_token_inactive_user( + self, mock_auth_service, test_user, + ): + """Test API token authentication with inactive user.""" + test_user.is_active = False + mock_auth_service.get_user_by_api_token.return_value = test_user + + authorization = "Bearer test_token" + + with pytest.raises(HTTPException) as exc_info: + await get_current_user_api_token(mock_auth_service, authorization) + + assert exc_info.value.status_code == 401 + assert "Account is deactivated" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_get_current_user_api_token_service_exception(self, mock_auth_service): + """Test API token authentication with service exception.""" + mock_auth_service.get_user_by_api_token.side_effect = Exception("Database error") + + authorization = "Bearer test_token" + + with pytest.raises(HTTPException) as exc_info: + await get_current_user_api_token(mock_auth_service, authorization) + + assert exc_info.value.status_code == 401 + assert "Could not validate API token" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_get_current_user_flexible_uses_api_token( + self, mock_auth_service, test_user, + ): + """Test flexible authentication uses API token when available.""" + mock_auth_service.get_user_by_api_token.return_value = test_user + + authorization = "Bearer test_api_token_123" + access_token = "jwt_token" + + result = await get_current_user_flexible( + mock_auth_service, access_token, authorization, + ) + + assert result == test_user + mock_auth_service.get_user_by_api_token.assert_called_once_with("test_api_token_123") + + @pytest.mark.asyncio + async def test_get_current_user_flexible_falls_back_to_jwt(self, mock_auth_service): + """Test flexible authentication falls back to JWT when no API token.""" + # Mock the get_current_user function (normally imported) + with pytest.raises(Exception): + # This will fail because we can't easily mock the get_current_user import + # In a real test, you'd mock the import or use dependency injection + await get_current_user_flexible(mock_auth_service, "jwt_token", None) + + @pytest.mark.asyncio + async def test_api_token_no_expiry_never_expires(self, mock_auth_service, test_user): + """Test API token with no expiry date never expires.""" + test_user.api_token_expires_at = None + mock_auth_service.get_user_by_api_token.return_value = test_user + + authorization = "Bearer test_token" + + result = await get_current_user_api_token(mock_auth_service, authorization) + + assert result == test_user + + @pytest.mark.asyncio + async def test_api_token_bearer_case_insensitive(self, mock_auth_service, test_user): + """Test that Bearer prefix is case-sensitive (as per OAuth2 spec).""" + mock_auth_service.get_user_by_api_token.return_value = test_user + + # lowercase bearer should fail + authorization = "bearer test_token" + + with pytest.raises(HTTPException) as exc_info: + await get_current_user_api_token(mock_auth_service, authorization) + + assert exc_info.value.status_code == 401 + assert "Invalid authorization header format" in exc_info.value.detail diff --git a/tests/services/test_auth_service.py b/tests/services/test_auth_service.py index e27da6a..abeb786 100644 --- a/tests/services/test_auth_service.py +++ b/tests/services/test_auth_service.py @@ -9,7 +9,6 @@ from app.models.plan import Plan from app.models.user import User from app.schemas.auth import UserLoginRequest, UserRegisterRequest from app.services.auth import AuthService -from app.utils.auth import PasswordUtils class TestAuthService: @@ -49,11 +48,11 @@ class TestAuthService: @pytest.mark.asyncio async def test_register_duplicate_email( - self, auth_service: AuthService, test_user: User + self, auth_service: AuthService, test_user: User, ) -> None: """Test registration with duplicate email.""" request = UserRegisterRequest( - email=test_user.email, password="password123", name="Another User" + email=test_user.email, password="password123", name="Another User", ) with pytest.raises(HTTPException) as exc_info: @@ -90,7 +89,7 @@ class TestAuthService: async def test_login_invalid_email(self, auth_service: AuthService) -> None: """Test login with invalid email.""" request = UserLoginRequest( - email="nonexistent@example.com", password="password123" + email="nonexistent@example.com", password="password123", ) with pytest.raises(HTTPException) as exc_info: @@ -101,7 +100,7 @@ class TestAuthService: @pytest.mark.asyncio async def test_login_invalid_password( - self, auth_service: AuthService, test_user: User + self, auth_service: AuthService, test_user: User, ) -> None: """Test login with invalid password.""" request = UserLoginRequest(email=test_user.email, password="wrongpassword") @@ -114,7 +113,7 @@ class TestAuthService: @pytest.mark.asyncio async def test_login_inactive_user( - self, auth_service: AuthService, test_user: User, test_session: AsyncSession + self, auth_service: AuthService, test_user: User, test_session: AsyncSession, ) -> None: """Test login with inactive user.""" # Store the email before deactivating @@ -134,7 +133,7 @@ class TestAuthService: @pytest.mark.asyncio async def test_login_user_without_password( - self, auth_service: AuthService, test_user: User, test_session: AsyncSession + self, auth_service: AuthService, test_user: User, test_session: AsyncSession, ) -> None: """Test login with user that has no password hash.""" # Store the email before removing password @@ -154,7 +153,7 @@ class TestAuthService: @pytest.mark.asyncio async def test_get_current_user_success( - self, auth_service: AuthService, test_user: User + self, auth_service: AuthService, test_user: User, ) -> None: """Test getting current user successfully.""" user = await auth_service.get_current_user(test_user.id) @@ -175,7 +174,7 @@ class TestAuthService: @pytest.mark.asyncio async def test_get_current_user_inactive( - self, auth_service: AuthService, test_user: User, test_session: AsyncSession + self, auth_service: AuthService, test_user: User, test_session: AsyncSession, ) -> None: """Test getting current user when user is inactive.""" # Store the user ID before deactivating @@ -193,7 +192,7 @@ class TestAuthService: @pytest.mark.asyncio async def test_create_access_token( - self, auth_service: AuthService, test_user: User + self, auth_service: AuthService, test_user: User, ) -> None: """Test access token creation.""" token_response = auth_service._create_access_token(test_user) @@ -212,7 +211,7 @@ class TestAuthService: @pytest.mark.asyncio async def test_create_user_response( - self, auth_service: AuthService, test_user: User, test_session: AsyncSession + self, auth_service: AuthService, test_user: User, test_session: AsyncSession, ) -> None: """Test user response creation.""" # Ensure plan relationship is loaded diff --git a/tests/services/test_oauth_service.py b/tests/services/test_oauth_service.py index be5c3bd..57d141d 100644 --- a/tests/services/test_oauth_service.py +++ b/tests/services/test_oauth_service.py @@ -1,16 +1,14 @@ """Tests for OAuth service.""" from typing import Any -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import Mock, patch import pytest -from httpx import AsyncClient from app.services.oauth import ( GitHubOAuthProvider, GoogleOAuthProvider, OAuthService, - OAuthUserInfo, ) diff --git a/tests/services/test_socket_service.py b/tests/services/test_socket_service.py index 9a090cc..957b86b 100644 --- a/tests/services/test_socket_service.py +++ b/tests/services/test_socket_service.py @@ -1,7 +1,8 @@ """Tests for socket service.""" +from unittest.mock import AsyncMock, patch + import pytest -from unittest.mock import AsyncMock, MagicMock, patch, call import socketio from app.services.socket import SocketManager @@ -24,7 +25,7 @@ class TestSocketManager: def test_init_creates_socket_server(self): """Test that socket manager initializes with proper configuration.""" manager = SocketManager() - + assert manager.sio is not None assert isinstance(manager.user_rooms, dict) assert isinstance(manager.socket_users, dict) @@ -37,12 +38,12 @@ class TestSocketManager: user_id = "123" room_id = "user_123" socket_manager.user_rooms[user_id] = room_id - + event = "test_event" data = {"message": "hello"} - + result = await socket_manager.send_to_user(user_id, event, data) - + assert result is True mock_sio.emit.assert_called_once_with(event, data, room=room_id) @@ -52,9 +53,9 @@ class TestSocketManager: user_id = "999" event = "test_event" data = {"message": "hello"} - + result = await socket_manager.send_to_user(user_id, event, data) - + assert result is False mock_sio.emit.assert_not_called() @@ -63,9 +64,9 @@ class TestSocketManager: """Test broadcasting message to all users.""" event = "broadcast_event" data = {"message": "announcement"} - + await socket_manager.broadcast_to_all(event, data) - + mock_sio.emit.assert_called_once_with(event, data) def test_get_connected_users(self, socket_manager): @@ -74,9 +75,9 @@ class TestSocketManager: socket_manager.user_rooms["1"] = "user_1" socket_manager.user_rooms["2"] = "user_2" socket_manager.user_rooms["3"] = "user_3" - + connected_users = socket_manager.get_connected_users() - + assert len(connected_users) == 3 assert "1" in connected_users assert "2" in connected_users @@ -87,139 +88,139 @@ class TestSocketManager: # Add some users socket_manager.user_rooms["1"] = "user_1" socket_manager.user_rooms["2"] = "user_2" - + room_info = socket_manager.get_room_info() - + assert room_info["total_users"] == 2 assert room_info["connected_users"] == ["1", "2"] @pytest.mark.asyncio - @patch('app.services.socket.extract_access_token_from_cookies') - @patch('app.services.socket.JWTUtils.decode_access_token') + @patch("app.services.socket.extract_access_token_from_cookies") + @patch("app.services.socket.JWTUtils.decode_access_token") async def test_connect_handler_success(self, mock_decode, mock_extract_token, socket_manager, mock_sio): """Test successful connection with valid token.""" # Setup mocks mock_extract_token.return_value = "valid_token" mock_decode.return_value = {"sub": "123"} - + # Mock environment environ = {"HTTP_COOKIE": "access_token=valid_token"} - + # Access the connect handler directly handlers = {} original_event = socket_manager.sio.event - + def mock_event(func): handlers[func.__name__] = func return func - + socket_manager.sio.event = mock_event socket_manager._setup_handlers() - + # Call the connect handler - await handlers['connect']("test_sid", environ) - + await handlers["connect"]("test_sid", environ) + # Verify token extraction and validation mock_extract_token.assert_called_once_with("access_token=valid_token") mock_decode.assert_called_once_with("valid_token") - + # Verify user tracking assert socket_manager.socket_users["test_sid"] == "123" assert socket_manager.user_rooms["123"] == "user_123" @pytest.mark.asyncio - @patch('app.services.socket.extract_access_token_from_cookies') + @patch("app.services.socket.extract_access_token_from_cookies") async def test_connect_handler_no_token(self, mock_extract_token, socket_manager, mock_sio): """Test connection with no access token.""" # Setup mocks mock_extract_token.return_value = None - + # Mock environment environ = {"HTTP_COOKIE": ""} - + # Access the connect handler directly handlers = {} original_event = socket_manager.sio.event - + def mock_event(func): handlers[func.__name__] = func return func - + socket_manager.sio.event = mock_event socket_manager._setup_handlers() - + # Call the connect handler - await handlers['connect']("test_sid", environ) - + await handlers["connect"]("test_sid", environ) + # Verify disconnection mock_sio.disconnect.assert_called_once_with("test_sid") - + # Verify no user tracking assert "test_sid" not in socket_manager.socket_users assert len(socket_manager.user_rooms) == 0 @pytest.mark.asyncio - @patch('app.services.socket.extract_access_token_from_cookies') - @patch('app.services.socket.JWTUtils.decode_access_token') + @patch("app.services.socket.extract_access_token_from_cookies") + @patch("app.services.socket.JWTUtils.decode_access_token") async def test_connect_handler_invalid_token(self, mock_decode, mock_extract_token, socket_manager, mock_sio): """Test connection with invalid token.""" # Setup mocks mock_extract_token.return_value = "invalid_token" mock_decode.side_effect = Exception("Invalid token") - + # Mock environment environ = {"HTTP_COOKIE": "access_token=invalid_token"} - + # Access the connect handler directly handlers = {} original_event = socket_manager.sio.event - + def mock_event(func): handlers[func.__name__] = func return func - + socket_manager.sio.event = mock_event socket_manager._setup_handlers() - + # Call the connect handler - await handlers['connect']("test_sid", environ) - + await handlers["connect"]("test_sid", environ) + # Verify disconnection mock_sio.disconnect.assert_called_once_with("test_sid") - + # Verify no user tracking assert "test_sid" not in socket_manager.socket_users assert len(socket_manager.user_rooms) == 0 @pytest.mark.asyncio - @patch('app.services.socket.extract_access_token_from_cookies') - @patch('app.services.socket.JWTUtils.decode_access_token') + @patch("app.services.socket.extract_access_token_from_cookies") + @patch("app.services.socket.JWTUtils.decode_access_token") async def test_connect_handler_missing_user_id(self, mock_decode, mock_extract_token, socket_manager, mock_sio): """Test connection with token missing user ID.""" # Setup mocks mock_extract_token.return_value = "token_without_user_id" mock_decode.return_value = {"other_field": "value"} # Missing 'sub' - + # Mock environment environ = {"HTTP_COOKIE": "access_token=token_without_user_id"} - + # Access the connect handler directly handlers = {} original_event = socket_manager.sio.event - + def mock_event(func): handlers[func.__name__] = func return func - + socket_manager.sio.event = mock_event socket_manager._setup_handlers() - + # Call the connect handler - await handlers['connect']("test_sid", environ) - + await handlers["connect"]("test_sid", environ) + # Verify disconnection mock_sio.disconnect.assert_called_once_with("test_sid") - + # Verify no user tracking assert "test_sid" not in socket_manager.socket_users assert len(socket_manager.user_rooms) == 0 @@ -230,21 +231,21 @@ class TestSocketManager: # Setup initial state socket_manager.socket_users["test_sid"] = "123" socket_manager.user_rooms["123"] = "user_123" - + # Access the disconnect handler directly handlers = {} original_event = socket_manager.sio.event - + def mock_event(func): handlers[func.__name__] = func return func - + socket_manager.sio.event = mock_event socket_manager._setup_handlers() - + # Call the disconnect handler - await handlers['disconnect']("test_sid") - + await handlers["disconnect"]("test_sid") + # Verify cleanup assert "test_sid" not in socket_manager.socket_users assert "123" not in socket_manager.user_rooms @@ -255,17 +256,17 @@ class TestSocketManager: # Access the disconnect handler directly handlers = {} original_event = socket_manager.sio.event - + def mock_event(func): handlers[func.__name__] = func return func - + socket_manager.sio.event = mock_event socket_manager._setup_handlers() - + # Call the disconnect handler with unknown socket - await handlers['disconnect']("unknown_sid") - + await handlers["disconnect"]("unknown_sid") + # Should not raise any errors and state should remain clean assert len(socket_manager.socket_users) == 0 - assert len(socket_manager.user_rooms) == 0 \ No newline at end of file + assert len(socket_manager.user_rooms) == 0 diff --git a/tests/utils/test_cookies.py b/tests/utils/test_cookies.py index f7f3243..a91e1f2 100644 --- a/tests/utils/test_cookies.py +++ b/tests/utils/test_cookies.py @@ -1,8 +1,7 @@ """Tests for cookie utilities.""" -import pytest -from app.utils.cookies import parse_cookies, extract_access_token_from_cookies +from app.utils.cookies import extract_access_token_from_cookies, parse_cookies class TestParseCookies: @@ -22,18 +21,18 @@ class TestParseCookies: """Test parsing single cookie.""" cookie_header = "session_id=abc123" result = parse_cookies(cookie_header) - + assert result == {"session_id": "abc123"} def test_parse_multiple_cookies(self): """Test parsing multiple cookies.""" cookie_header = "session_id=abc123; user_pref=dark_mode; lang=en" result = parse_cookies(cookie_header) - + expected = { "session_id": "abc123", "user_pref": "dark_mode", - "lang": "en" + "lang": "en", } assert result == expected @@ -41,10 +40,10 @@ class TestParseCookies: """Test parsing cookies with extra spaces.""" cookie_header = " session_id = abc123 ; user_pref = dark_mode " result = parse_cookies(cookie_header) - + expected = { "session_id": "abc123", - "user_pref": "dark_mode" + "user_pref": "dark_mode", } assert result == expected @@ -52,10 +51,10 @@ class TestParseCookies: """Test parsing cookies where value contains equals sign.""" cookie_header = "encoded_data=key=value&other=data; session=123" result = parse_cookies(cookie_header) - + expected = { "encoded_data": "key=value&other=data", - "session": "123" + "session": "123", } assert result == expected @@ -63,11 +62,11 @@ class TestParseCookies: """Test parsing malformed cookies (no equals sign).""" cookie_header = "session_id=abc123; malformed_cookie; user_pref=dark" result = parse_cookies(cookie_header) - + # Should skip malformed cookie and parse valid ones expected = { "session_id": "abc123", - "user_pref": "dark" + "user_pref": "dark", } assert result == expected @@ -75,10 +74,10 @@ class TestParseCookies: """Test parsing cookies with empty values.""" cookie_header = "empty_value=; session_id=abc123" result = parse_cookies(cookie_header) - + expected = { "empty_value": "", - "session_id": "abc123" + "session_id": "abc123", } assert result == expected @@ -86,7 +85,7 @@ class TestParseCookies: """Test parsing cookies with duplicate names (last one wins).""" cookie_header = "session_id=first; session_id=second" result = parse_cookies(cookie_header) - + assert result == {"session_id": "second"} @@ -97,14 +96,14 @@ class TestExtractAccessTokenFromCookies: """Test extracting access token when present.""" cookie_header = "session_id=abc123; access_token=jwt_token_here; user_pref=dark" result = extract_access_token_from_cookies(cookie_header) - + assert result == "jwt_token_here" def test_extract_access_token_not_present(self): """Test extracting access token when not present.""" cookie_header = "session_id=abc123; user_pref=dark" result = extract_access_token_from_cookies(cookie_header) - + assert result is None def test_extract_access_token_empty_header(self): @@ -116,21 +115,21 @@ class TestExtractAccessTokenFromCookies: """Test extracting access token when it's the only cookie.""" cookie_header = "access_token=my_jwt_token" result = extract_access_token_from_cookies(cookie_header) - + assert result == "my_jwt_token" def test_extract_access_token_with_spaces(self): """Test extracting access token with spaces around values.""" cookie_header = " access_token = jwt_token_with_spaces ; other=value " result = extract_access_token_from_cookies(cookie_header) - + assert result == "jwt_token_with_spaces" def test_extract_access_token_empty_value(self): """Test extracting access token with empty value.""" cookie_header = "access_token=; other=value" result = extract_access_token_from_cookies(cookie_header) - + assert result == "" def test_extract_access_token_complex_value(self): @@ -138,12 +137,12 @@ class TestExtractAccessTokenFromCookies: jwt_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjMiLCJleHAiOjE2MzM5NjY0MDB9.signature" cookie_header = f"session=abc; access_token={jwt_token}; csrf=token" result = extract_access_token_from_cookies(cookie_header) - + assert result == jwt_token def test_extract_access_token_multiple_equals(self): """Test extracting access token when value contains equals signs.""" cookie_header = "access_token=encoded=data=here; other=simple" result = extract_access_token_from_cookies(cookie_header) - - assert result == "encoded=data=here" \ No newline at end of file + + assert result == "encoded=data=here" diff --git a/tests/utils/test_token_utils.py b/tests/utils/test_token_utils.py new file mode 100644 index 0000000..fa723b6 --- /dev/null +++ b/tests/utils/test_token_utils.py @@ -0,0 +1,80 @@ +"""Tests for token utilities.""" + +from datetime import UTC, datetime, timedelta + +from app.utils.auth import TokenUtils + + +class TestTokenUtils: + """Test token utility functions.""" + + def test_generate_api_token(self): + """Test API token generation.""" + token = TokenUtils.generate_api_token() + + # Should be a string + assert isinstance(token, str) + + # Should not be empty + assert len(token) > 0 + + # Should be URL-safe base64 (43 characters for 32 bytes) + assert len(token) == 43 + + # Should be unique (generate multiple and check they're different) + tokens = [TokenUtils.generate_api_token() for _ in range(10)] + assert len(set(tokens)) == 10 + + def test_is_token_expired_none(self): + """Test token expiration check with None expires_at.""" + result = TokenUtils.is_token_expired(None) + assert result is False + + def test_is_token_expired_future_naive(self): + """Test token expiration check with future naive datetime.""" + # Use UTC time for naive datetime (as the function assumes) + expires_at = datetime.utcnow() + timedelta(hours=1) + result = TokenUtils.is_token_expired(expires_at) + assert result is False + + def test_is_token_expired_past_naive(self): + """Test token expiration check with past naive datetime.""" + # Use UTC time for naive datetime (as the function assumes) + expires_at = datetime.utcnow() - timedelta(hours=1) + result = TokenUtils.is_token_expired(expires_at) + assert result is True + + def test_is_token_expired_future_aware(self): + """Test token expiration check with future timezone-aware datetime.""" + expires_at = datetime.now(UTC) + timedelta(hours=1) + result = TokenUtils.is_token_expired(expires_at) + assert result is False + + def test_is_token_expired_past_aware(self): + """Test token expiration check with past timezone-aware datetime.""" + expires_at = datetime.now(UTC) - timedelta(hours=1) + result = TokenUtils.is_token_expired(expires_at) + assert result is True + + def test_is_token_expired_edge_case_now(self): + """Test token expiration check with time very close to now.""" + # Token expires in 1 second + expires_at = datetime.now(UTC) + timedelta(seconds=1) + result = TokenUtils.is_token_expired(expires_at) + assert result is False + + # Token expired 1 second ago + expires_at = datetime.now(UTC) - timedelta(seconds=1) + result = TokenUtils.is_token_expired(expires_at) + assert result is True + + def test_is_token_expired_timezone_conversion(self): + """Test token expiration check with different timezone.""" + from zoneinfo import ZoneInfo + + # Create a datetime in a different timezone + eastern = ZoneInfo("US/Eastern") + expires_at = datetime.now(eastern) + timedelta(hours=1) + + result = TokenUtils.is_token_expired(expires_at) + assert result is False