Refactor tests for improved consistency and readability

- Updated test cases in `test_auth_endpoints.py` to ensure consistent formatting and style.
- Enhanced `test_socket_endpoints.py` with consistent parameter formatting and improved readability.
- Cleaned up `conftest.py` by ensuring consistent parameter formatting in fixtures.
- Added comprehensive tests for API token dependencies in `test_api_token_dependencies.py`.
- Refactored `test_auth_service.py` to maintain consistent parameter formatting.
- Cleaned up `test_oauth_service.py` by removing unnecessary imports.
- Improved `test_socket_service.py` with consistent formatting and readability.
- Enhanced `test_cookies.py` by ensuring consistent formatting and readability.
- Introduced new tests for token utilities in `test_token_utils.py` to validate token generation and expiration logic.
This commit is contained in:
JSC
2025-07-27 15:11:47 +02:00
parent 42deab2409
commit 3dc21337f9
16 changed files with 991 additions and 159 deletions

View File

@@ -11,14 +11,22 @@ from app.core.config import settings
from app.core.dependencies import ( from app.core.dependencies import (
get_auth_service, get_auth_service,
get_current_active_user, get_current_active_user,
get_current_active_user_flexible,
get_oauth_service, get_oauth_service,
) )
from app.core.logging import get_logger from app.core.logging import get_logger
from app.models.user import User from app.models.user import User
from app.schemas.auth import UserLoginRequest, UserRegisterRequest, UserResponse from app.schemas.auth import (
ApiTokenRequest,
ApiTokenResponse,
ApiTokenStatusResponse,
UserLoginRequest,
UserRegisterRequest,
UserResponse,
)
from app.services.auth import AuthService from app.services.auth import AuthService
from app.services.oauth import OAuthService from app.services.oauth import OAuthService
from app.utils.auth import JWTUtils from app.utils.auth import JWTUtils, TokenUtils
router = APIRouter() router = APIRouter()
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -131,7 +139,7 @@ async def login(
@router.get("/me") @router.get("/me")
async def get_current_user_info( async def get_current_user_info(
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user_flexible)],
auth_service: Annotated[AuthService, Depends(get_auth_service)], auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> UserResponse: ) -> UserResponse:
"""Get current user information.""" """Get current user information."""
@@ -426,3 +434,72 @@ async def exchange_oauth_token(
user_id = token_data["user_id"] user_id = token_data["user_id"]
logger.info("OAuth tokens exchanged successfully for user: %s", user_id) logger.info("OAuth tokens exchanged successfully for user: %s", user_id)
return {"message": "Tokens set successfully", "user_id": str(user_id)} return {"message": "Tokens set successfully", "user_id": str(user_id)}
# API Token endpoints
@router.post("/api-token")
async def generate_api_token(
request: ApiTokenRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> ApiTokenResponse:
"""Generate a new API token for the current user."""
try:
api_token = await auth_service.generate_api_token(
current_user,
expires_days=request.expires_days,
)
# Refresh user to get updated token info
await auth_service.session.refresh(current_user)
return ApiTokenResponse(
api_token=api_token,
expires_at=current_user.api_token_expires_at,
)
except Exception as e:
logger.exception(
"Failed to generate API token for user: %s", current_user.email,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to generate API token",
) from e
@router.get("/api-token/status")
async def get_api_token_status(
current_user: Annotated[User, Depends(get_current_active_user)],
) -> ApiTokenStatusResponse:
"""Get the current user's API token status."""
has_token = current_user.api_token is not None
is_expired = False
if has_token and current_user.api_token_expires_at:
is_expired = TokenUtils.is_token_expired(current_user.api_token_expires_at)
return ApiTokenStatusResponse(
has_token=has_token,
expires_at=current_user.api_token_expires_at,
is_expired=is_expired,
)
@router.delete("/api-token")
async def revoke_api_token(
current_user: Annotated[User, Depends(get_current_active_user)],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> dict[str, str]:
"""Revoke the current user's API token."""
try:
await auth_service.revoke_api_token(current_user)
except Exception as e:
logger.exception(
"Failed to revoke API token for user: %s", current_user.email,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to revoke API token",
) from e
else:
return {"message": "API token revoked successfully"}

View File

@@ -2,7 +2,7 @@
from typing import Annotated, cast from typing import Annotated, cast
from fastapi import Cookie, Depends, HTTPException, status from fastapi import Cookie, Depends, Header, HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db from app.core.database import get_db
@@ -10,7 +10,7 @@ from app.core.logging import get_logger
from app.models.user import User from app.models.user import User
from app.services.auth import AuthService from app.services.auth import AuthService
from app.services.oauth import OAuthService from app.services.oauth import OAuthService
from app.utils.auth import JWTUtils from app.utils.auth import JWTUtils, TokenUtils
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -92,6 +92,95 @@ async def get_current_active_user(
return current_user return current_user
async def get_current_user_api_token(
auth_service: Annotated[AuthService, Depends(get_auth_service)],
authorization: Annotated[str | None, Header()] = None,
) -> User:
"""Get the current authenticated user from API token in Authorization header."""
try:
# Check if Authorization header exists
if not authorization:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authorization header required",
)
# Check if it's a Bearer token
if not authorization.startswith("Bearer "):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authorization header format",
)
# Extract the API token
api_token = authorization[7:] # Remove "Bearer " prefix
if not api_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="API token required",
)
# Get the user by API token
user = await auth_service.get_user_by_api_token(api_token)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API token",
)
# Check if token is expired
if TokenUtils.is_token_expired(user.api_token_expires_at):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="API token has expired",
)
# Check if user is active
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Account is deactivated",
)
return user
except HTTPException:
# Re-raise HTTPExceptions without wrapping them
raise
except Exception as e:
logger.exception("Failed to authenticate user with API token")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate API token",
) from e
async def get_current_user_flexible(
auth_service: Annotated[AuthService, Depends(get_auth_service)],
access_token: Annotated[str | None, Cookie()] = None,
authorization: Annotated[str | None, Header()] = None,
) -> User:
"""Get the current authenticated user from either JWT cookie or API token."""
# Try API token first if Authorization header is present
if authorization:
return await get_current_user_api_token(auth_service, authorization)
# Fall back to JWT cookie authentication
return await get_current_user(auth_service, access_token)
async def get_current_active_user_flexible(
current_user: Annotated[User, Depends(get_current_user_flexible)],
) -> User:
"""Get the current authenticated and active user using flexible auth."""
if not current_user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Account is deactivated",
)
return current_user
async def get_admin_user( async def get_admin_user(
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
) -> User: ) -> User:

View File

@@ -51,3 +51,29 @@ class AuthResponse(BaseModel):
user: UserResponse = Field(..., description="User information") user: UserResponse = Field(..., description="User information")
token: TokenResponse = Field(..., description="Authentication token") token: TokenResponse = Field(..., description="Authentication token")
class ApiTokenRequest(BaseModel):
"""Schema for API token generation request."""
expires_days: int = Field(
default=365,
ge=1,
le=3650,
description="Number of days until token expires (1-3650 days)",
)
class ApiTokenResponse(BaseModel):
"""Schema for API token response."""
api_token: str = Field(..., description="Generated API token")
expires_at: datetime = Field(..., description="Token expiration timestamp")
class ApiTokenStatusResponse(BaseModel):
"""Schema for API token status response."""
has_token: bool = Field(..., description="Whether user has an active API token")
expires_at: datetime | None = Field(None, description="Token expiration timestamp")
is_expired: bool = Field(..., description="Whether the token is expired")

View File

@@ -19,7 +19,7 @@ from app.schemas.auth import (
UserResponse, UserResponse,
) )
from app.services.oauth import OAuthUserInfo from app.services.oauth import OAuthUserInfo
from app.utils.auth import JWTUtils, PasswordUtils from app.utils.auth import JWTUtils, PasswordUtils, TokenUtils
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -123,6 +123,37 @@ class AuthService:
return user return user
async def get_user_by_api_token(self, api_token: str) -> User | None:
"""Get a user by their API token."""
return await self.user_repo.get_by_api_token(api_token)
async def generate_api_token(self, user: User, expires_days: int = 365) -> str:
"""Generate a new API token for a user."""
# Generate a secure random token
api_token = TokenUtils.generate_api_token()
# Set expiration date
expires_at = datetime.now(UTC) + timedelta(days=expires_days)
# Update user with new API token
update_data = {
"api_token": api_token,
"api_token_expires_at": expires_at,
}
await self.user_repo.update(user, update_data)
logger.info("Generated new API token for user: %s", user.email)
return api_token
async def revoke_api_token(self, user: User) -> None:
"""Revoke a user's API token."""
update_data = {
"api_token": None,
"api_token_expires_at": None,
}
await self.user_repo.update(user, update_data)
logger.info("Revoked API token for user: %s", user.email)
def _create_access_token(self, user: User) -> TokenResponse: def _create_access_token(self, user: User) -> TokenResponse:
"""Create an access token for a user.""" """Create an access token for a user."""
access_token_expires = timedelta( access_token_expires = timedelta(

View File

@@ -104,9 +104,8 @@ class SocketManager:
await self.sio.emit(event, data, room=room_id) await self.sio.emit(event, data, room=room_id)
logger.debug(f"Sent {event} to user {user_id} in room {room_id}") logger.debug(f"Sent {event} to user {user_id} in room {room_id}")
return True return True
else: logger.warning(f"User {user_id} not found in any room")
logger.warning(f"User {user_id} not found in any room") return False
return False
async def broadcast_to_all(self, event: str, data: dict): async def broadcast_to_all(self, event: str, data: dict):
"""Broadcast a message to all connected users.""" """Broadcast a message to all connected users."""

View File

@@ -1,6 +1,5 @@
"""Cookie parsing utilities for WebSocket authentication.""" """Cookie parsing utilities for WebSocket authentication."""
from typing import Optional
def parse_cookies(cookie_header: str) -> dict[str, str]: def parse_cookies(cookie_header: str) -> dict[str, str]:
@@ -18,7 +17,7 @@ def parse_cookies(cookie_header: str) -> dict[str, str]:
return cookies return cookies
def extract_access_token_from_cookies(cookie_header: str) -> Optional[str]: def extract_access_token_from_cookies(cookie_header: str) -> str | None:
"""Extract access token from HTTP cookies.""" """Extract access token from HTTP cookies."""
cookies = parse_cookies(cookie_header) cookies = parse_cookies(cookie_header)
return cookies.get("access_token") return cookies.get("access_token")

View File

@@ -0,0 +1,343 @@
"""Tests for API token endpoints."""
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
import pytest
from httpx import AsyncClient
from app.models.user import User
class TestApiTokenEndpoints:
"""Test API token management endpoints."""
@pytest.mark.asyncio
async def test_generate_api_token_success(
self, authenticated_client: AsyncClient, authenticated_user: User,
):
"""Test successful API token generation."""
request_data = {"expires_days": 30}
response = await authenticated_client.post(
"/api/v1/auth/api-token",
json=request_data,
)
assert response.status_code == 200
data = response.json()
assert "api_token" in data
assert "expires_at" in data
assert len(data["api_token"]) > 0
# Verify token format (should be URL-safe base64)
import base64
try:
base64.urlsafe_b64decode(data["api_token"] + "===") # Add padding
except Exception:
pytest.fail("API token should be valid URL-safe base64")
@pytest.mark.asyncio
async def test_generate_api_token_default_expiry(
self, authenticated_client: AsyncClient,
):
"""Test API token generation with default expiry."""
response = await authenticated_client.post("/api/v1/auth/api-token", json={})
assert response.status_code == 200
data = response.json()
expires_at_str = data["expires_at"]
# Handle both ISO format with/without timezone info
if expires_at_str.endswith("Z"):
expires_at = datetime.fromisoformat(expires_at_str.replace("Z", "+00:00"))
elif "+" in expires_at_str or expires_at_str.count("-") > 2:
expires_at = datetime.fromisoformat(expires_at_str)
else:
# Naive datetime, assume UTC
expires_at = datetime.fromisoformat(expires_at_str).replace(tzinfo=UTC)
expected_expiry = datetime.now(UTC) + timedelta(days=365)
# Allow 1 minute tolerance
assert abs((expires_at - expected_expiry).total_seconds()) < 60
@pytest.mark.asyncio
async def test_generate_api_token_custom_expiry(
self, authenticated_client: AsyncClient,
):
"""Test API token generation with custom expiry."""
expires_days = 90
request_data = {"expires_days": expires_days}
response = await authenticated_client.post(
"/api/v1/auth/api-token",
json=request_data,
)
assert response.status_code == 200
data = response.json()
expires_at_str = data["expires_at"]
# Handle both ISO format with/without timezone info
if expires_at_str.endswith("Z"):
expires_at = datetime.fromisoformat(expires_at_str.replace("Z", "+00:00"))
elif "+" in expires_at_str or expires_at_str.count("-") > 2:
expires_at = datetime.fromisoformat(expires_at_str)
else:
# Naive datetime, assume UTC
expires_at = datetime.fromisoformat(expires_at_str).replace(tzinfo=UTC)
expected_expiry = datetime.now(UTC) + timedelta(days=expires_days)
# Allow 1 minute tolerance
assert abs((expires_at - expected_expiry).total_seconds()) < 60
@pytest.mark.asyncio
async def test_generate_api_token_validation_errors(
self, authenticated_client: AsyncClient,
):
"""Test API token generation with validation errors."""
# Test minimum validation
response = await authenticated_client.post(
"/api/v1/auth/api-token",
json={"expires_days": 0},
)
assert response.status_code == 422
# Test maximum validation
response = await authenticated_client.post(
"/api/v1/auth/api-token",
json={"expires_days": 4000},
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_generate_api_token_unauthenticated(self, client: AsyncClient):
"""Test API token generation without authentication."""
response = await client.post(
"/api/v1/auth/api-token",
json={"expires_days": 30},
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_get_api_token_status_no_token(
self, authenticated_client: AsyncClient,
):
"""Test getting API token status when user has no token."""
response = await authenticated_client.get("/api/v1/auth/api-token/status")
assert response.status_code == 200
data = response.json()
assert data["has_token"] is False
assert data["expires_at"] is None
assert data["is_expired"] is False
@pytest.mark.asyncio
async def test_get_api_token_status_with_token(
self, authenticated_client: AsyncClient,
):
"""Test getting API token status when user has a token."""
# First generate a token
await authenticated_client.post(
"/api/v1/auth/api-token",
json={"expires_days": 30},
)
# Then check status
response = await authenticated_client.get("/api/v1/auth/api-token/status")
assert response.status_code == 200
data = response.json()
assert data["has_token"] is True
assert data["expires_at"] is not None
assert data["is_expired"] is False
@pytest.mark.asyncio
async def test_get_api_token_status_expired_token(
self, authenticated_client: AsyncClient, authenticated_user: User,
):
"""Test getting API token status with expired token."""
# Mock expired token
with patch("app.utils.auth.TokenUtils.is_token_expired", return_value=True):
# Set a token on the user
authenticated_user.api_token = "expired_token"
authenticated_user.api_token_expires_at = datetime.now(UTC) - timedelta(days=1)
response = await authenticated_client.get("/api/v1/auth/api-token/status")
assert response.status_code == 200
data = response.json()
assert data["has_token"] is True
assert data["expires_at"] is not None
assert data["is_expired"] is True
@pytest.mark.asyncio
async def test_get_api_token_status_unauthenticated(self, client: AsyncClient):
"""Test getting API token status without authentication."""
response = await client.get("/api/v1/auth/api-token/status")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_revoke_api_token_success(
self, authenticated_client: AsyncClient,
):
"""Test successful API token revocation."""
# First generate a token
await authenticated_client.post(
"/api/v1/auth/api-token",
json={"expires_days": 30},
)
# Verify token exists
status_response = await authenticated_client.get("/api/v1/auth/api-token/status")
assert status_response.json()["has_token"] is True
# Revoke the token
response = await authenticated_client.delete("/api/v1/auth/api-token")
assert response.status_code == 200
data = response.json()
assert data["message"] == "API token revoked successfully"
# Verify token is gone
status_response = await authenticated_client.get("/api/v1/auth/api-token/status")
assert status_response.json()["has_token"] is False
@pytest.mark.asyncio
async def test_revoke_api_token_no_token(
self, authenticated_client: AsyncClient,
):
"""Test revoking API token when user has no token."""
response = await authenticated_client.delete("/api/v1/auth/api-token")
assert response.status_code == 200
data = response.json()
assert data["message"] == "API token revoked successfully"
@pytest.mark.asyncio
async def test_revoke_api_token_unauthenticated(self, client: AsyncClient):
"""Test revoking API token without authentication."""
response = await client.delete("/api/v1/auth/api-token")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_api_token_authentication_success(
self, client: AsyncClient, authenticated_client: AsyncClient,
):
"""Test successful authentication using API token."""
# Generate API token
token_response = await authenticated_client.post(
"/api/v1/auth/api-token",
json={"expires_days": 30},
)
api_token = token_response.json()["api_token"]
# Use API token to authenticate
headers = {"Authorization": f"Bearer {api_token}"}
response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 200
data = response.json()
assert "id" in data
assert "email" in data
@pytest.mark.asyncio
async def test_api_token_authentication_invalid_token(self, client: AsyncClient):
"""Test authentication with invalid API token."""
headers = {"Authorization": "Bearer invalid_token"}
response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 401
data = response.json()
assert "Invalid API token" in data["detail"]
@pytest.mark.asyncio
async def test_api_token_authentication_expired_token(
self, client: AsyncClient, authenticated_client: AsyncClient,
):
"""Test authentication with expired API token."""
# Generate API token
token_response = await authenticated_client.post(
"/api/v1/auth/api-token",
json={"expires_days": 30},
)
api_token = token_response.json()["api_token"]
# Mock expired token
with patch("app.utils.auth.TokenUtils.is_token_expired", return_value=True):
headers = {"Authorization": f"Bearer {api_token}"}
response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 401
data = response.json()
assert "API token has expired" in data["detail"]
@pytest.mark.asyncio
async def test_api_token_authentication_malformed_header(self, client: AsyncClient):
"""Test authentication with malformed Authorization header."""
# Missing Bearer prefix
headers = {"Authorization": "invalid_format"}
response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 401
data = response.json()
assert "Invalid authorization header format" in data["detail"]
# Empty token
headers = {"Authorization": "Bearer "}
response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 401
data = response.json()
assert "API token required" in data["detail"]
@pytest.mark.asyncio
async def test_api_token_authentication_inactive_user(
self, client: AsyncClient, authenticated_client: AsyncClient, authenticated_user: User,
):
"""Test authentication with API token for inactive user."""
# Generate API token
token_response = await authenticated_client.post(
"/api/v1/auth/api-token",
json={"expires_days": 30},
)
api_token = token_response.json()["api_token"]
# Deactivate user
authenticated_user.is_active = False
# Try to authenticate with API token
headers = {"Authorization": f"Bearer {api_token}"}
response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 401
data = response.json()
assert "Account is deactivated" in data["detail"]
@pytest.mark.asyncio
async def test_flexible_authentication_prefers_api_token(
self, client: AsyncClient, authenticated_client: AsyncClient, auth_cookies: dict[str, str],
):
"""Test that flexible authentication prefers API token over cookie."""
# Generate API token
token_response = await authenticated_client.post(
"/api/v1/auth/api-token",
json={"expires_days": 30},
)
api_token = token_response.json()["api_token"]
# Set both cookies and Authorization header
client.cookies.update(auth_cookies)
headers = {"Authorization": f"Bearer {api_token}"}
# This should use API token authentication
response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 200
# If it used API token auth, it should work even if cookies are invalid

View File

@@ -73,7 +73,7 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_duplicate_email( async def test_register_duplicate_email(
self, test_client: AsyncClient, test_user: User self, test_client: AsyncClient, test_user: User,
) -> None: ) -> None:
"""Test registration with duplicate email.""" """Test registration with duplicate email."""
user_data = { user_data = {
@@ -118,7 +118,7 @@ class TestAuthEndpoints:
async def test_register_missing_fields(self, test_client: AsyncClient) -> None: async def test_register_missing_fields(self, test_client: AsyncClient) -> None:
"""Test registration with missing fields.""" """Test registration with missing fields."""
user_data = { user_data = {
"email": "test@example.com" "email": "test@example.com",
# Missing password and name # Missing password and name
} }
@@ -128,7 +128,7 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_success( async def test_login_success(
self, test_client: AsyncClient, test_user: User, test_login_data: dict[str, str] self, test_client: AsyncClient, test_user: User, test_login_data: dict[str, str],
) -> None: ) -> None:
"""Test successful user login.""" """Test successful user login."""
response = await test_client.post("/api/v1/auth/login", json=test_login_data) response = await test_client.post("/api/v1/auth/login", json=test_login_data)
@@ -161,7 +161,7 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_invalid_password( async def test_login_invalid_password(
self, test_client: AsyncClient, test_user: User self, test_client: AsyncClient, test_user: User,
) -> None: ) -> None:
"""Test login with invalid password.""" """Test login with invalid password."""
login_data = {"email": test_user.email, "password": "wrongpassword"} login_data = {"email": test_user.email, "password": "wrongpassword"}
@@ -183,7 +183,7 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_success( async def test_get_current_user_success(
self, test_client: AsyncClient, test_user: User, auth_cookies: dict[str, str] self, test_client: AsyncClient, test_user: User, auth_cookies: dict[str, str],
) -> None: ) -> None:
"""Test getting current user info successfully.""" """Test getting current user info successfully."""
# Set cookies on client instance to avoid deprecation warning # Set cookies on client instance to avoid deprecation warning
@@ -210,7 +210,7 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_invalid_token( async def test_get_current_user_invalid_token(
self, test_client: AsyncClient self, test_client: AsyncClient,
) -> None: ) -> None:
"""Test getting current user with invalid token.""" """Test getting current user with invalid token."""
# Set invalid cookies on client instance # Set invalid cookies on client instance
@@ -223,7 +223,7 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_expired_token( async def test_get_current_user_expired_token(
self, test_client: AsyncClient, test_user: User self, test_client: AsyncClient, test_user: User,
) -> None: ) -> None:
"""Test getting current user with expired token.""" """Test getting current user with expired token."""
from datetime import timedelta from datetime import timedelta
@@ -237,7 +237,7 @@ class TestAuthEndpoints:
"role": "user", "role": "user",
} }
expired_token = JWTUtils.create_access_token( expired_token = JWTUtils.create_access_token(
token_data, expires_delta=timedelta(seconds=-1) token_data, expires_delta=timedelta(seconds=-1),
) )
# Set expired cookies on client instance # Set expired cookies on client instance
@@ -262,7 +262,7 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_access_with_user_role( async def test_admin_access_with_user_role(
self, test_client: AsyncClient, auth_cookies: dict[str, str] self, test_client: AsyncClient, auth_cookies: dict[str, str],
) -> None: ) -> None:
"""Test that regular users cannot access admin endpoints.""" """Test that regular users cannot access admin endpoints."""
# This test would be for admin-only endpoints when they're created # This test would be for admin-only endpoints when they're created
@@ -293,7 +293,7 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_access_with_admin_role( async def test_admin_access_with_admin_role(
self, test_client: AsyncClient, admin_cookies: dict[str, str] self, test_client: AsyncClient, admin_cookies: dict[str, str],
) -> None: ) -> None:
"""Test that admin users can access admin endpoints.""" """Test that admin users can access admin endpoints."""
from app.core.dependencies import get_admin_user from app.core.dependencies import get_admin_user
@@ -357,7 +357,7 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_oauth_authorize_invalid_provider( async def test_oauth_authorize_invalid_provider(
self, test_client: AsyncClient self, test_client: AsyncClient,
) -> None: ) -> None:
"""Test OAuth authorization with invalid provider.""" """Test OAuth authorization with invalid provider."""
response = await test_client.get("/api/v1/auth/invalid/authorize") response = await test_client.get("/api/v1/auth/invalid/authorize")
@@ -368,7 +368,7 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_oauth_callback_new_user( async def test_oauth_callback_new_user(
self, test_client: AsyncClient, ensure_plans: tuple[Any, Any] self, test_client: AsyncClient, ensure_plans: tuple[Any, Any],
) -> None: ) -> None:
"""Test OAuth callback for new user creation.""" """Test OAuth callback for new user creation."""
# Mock OAuth user info # Mock OAuth user info
@@ -400,7 +400,7 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_oauth_callback_existing_user_link( async def test_oauth_callback_existing_user_link(
self, test_client: AsyncClient, test_user: Any, ensure_plans: tuple[Any, Any] self, test_client: AsyncClient, test_user: Any, ensure_plans: tuple[Any, Any],
) -> None: ) -> None:
"""Test OAuth callback for linking to existing user.""" """Test OAuth callback for linking to existing user."""
# Mock OAuth user info with same email as test user # Mock OAuth user info with same email as test user
@@ -442,7 +442,7 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_oauth_callback_invalid_provider( async def test_oauth_callback_invalid_provider(
self, test_client: AsyncClient self, test_client: AsyncClient,
) -> None: ) -> None:
"""Test OAuth callback with invalid provider.""" """Test OAuth callback with invalid provider."""
response = await test_client.get( response = await test_client.get(

View File

@@ -1,8 +1,9 @@
"""Tests for socket API endpoints.""" """Tests for socket API endpoints."""
from unittest.mock import AsyncMock, patch
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
from unittest.mock import AsyncMock, patch
from app.models.user import User from app.models.user import User
@@ -10,7 +11,7 @@ from app.models.user import User
@pytest.fixture @pytest.fixture
def mock_socket_manager(): def mock_socket_manager():
"""Mock socket manager for testing.""" """Mock socket manager for testing."""
with patch('app.api.v1.socket.socket_manager') as mock: with patch("app.api.v1.socket.socket_manager") as mock:
mock.get_connected_users.return_value = ["1", "2", "3"] mock.get_connected_users.return_value = ["1", "2", "3"]
mock.send_to_user = AsyncMock(return_value=True) mock.send_to_user = AsyncMock(return_value=True)
mock.broadcast_to_all = AsyncMock() mock.broadcast_to_all = AsyncMock()
@@ -24,10 +25,10 @@ class TestSocketEndpoints:
async def test_get_socket_status_authenticated(self, authenticated_client: AsyncClient, authenticated_user: User, mock_socket_manager): async def test_get_socket_status_authenticated(self, authenticated_client: AsyncClient, authenticated_user: User, mock_socket_manager):
"""Test getting socket status for authenticated user.""" """Test getting socket status for authenticated user."""
response = await authenticated_client.get("/api/v1/socket/status") response = await authenticated_client.get("/api/v1/socket/status")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert "connected" in data assert "connected" in data
assert "user_id" in data assert "user_id" in data
assert "total_connected" in data assert "total_connected" in data
@@ -46,19 +47,19 @@ class TestSocketEndpoints:
"""Test sending message to specific user successfully.""" """Test sending message to specific user successfully."""
target_user_id = 2 target_user_id = 2
message = "Hello there!" message = "Hello there!"
response = await authenticated_client.post( response = await authenticated_client.post(
"/api/v1/socket/send-message", "/api/v1/socket/send-message",
params={"target_user_id": target_user_id, "message": message} params={"target_user_id": target_user_id, "message": message},
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is True assert data["success"] is True
assert data["target_user_id"] == target_user_id assert data["target_user_id"] == target_user_id
assert data["message"] == "Message sent" assert data["message"] == "Message sent"
# Verify socket manager was called correctly # Verify socket manager was called correctly
mock_socket_manager.send_to_user.assert_called_once_with( mock_socket_manager.send_to_user.assert_called_once_with(
str(target_user_id), str(target_user_id),
@@ -67,7 +68,7 @@ class TestSocketEndpoints:
"from_user_id": authenticated_user.id, "from_user_id": authenticated_user.id,
"from_user_name": authenticated_user.name, "from_user_name": authenticated_user.name,
"message": message, "message": message,
} },
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -75,18 +76,18 @@ class TestSocketEndpoints:
"""Test sending message to user who is not connected.""" """Test sending message to user who is not connected."""
target_user_id = 999 target_user_id = 999
message = "Hello there!" message = "Hello there!"
# Mock user not connected # Mock user not connected
mock_socket_manager.send_to_user.return_value = False mock_socket_manager.send_to_user.return_value = False
response = await authenticated_client.post( response = await authenticated_client.post(
"/api/v1/socket/send-message", "/api/v1/socket/send-message",
params={"target_user_id": target_user_id, "message": message} params={"target_user_id": target_user_id, "message": message},
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is False assert data["success"] is False
assert data["target_user_id"] == target_user_id assert data["target_user_id"] == target_user_id
assert data["message"] == "User not connected" assert data["message"] == "User not connected"
@@ -96,7 +97,7 @@ class TestSocketEndpoints:
"""Test sending message without authentication.""" """Test sending message without authentication."""
response = await client.post( response = await client.post(
"/api/v1/socket/send-message", "/api/v1/socket/send-message",
params={"target_user_id": 1, "message": "test"} params={"target_user_id": 1, "message": "test"},
) )
assert response.status_code == 401 assert response.status_code == 401
@@ -104,18 +105,18 @@ class TestSocketEndpoints:
async def test_broadcast_message_success(self, authenticated_client: AsyncClient, authenticated_user: User, mock_socket_manager): async def test_broadcast_message_success(self, authenticated_client: AsyncClient, authenticated_user: User, mock_socket_manager):
"""Test broadcasting message to all users successfully.""" """Test broadcasting message to all users successfully."""
message = "Important announcement!" message = "Important announcement!"
response = await authenticated_client.post( response = await authenticated_client.post(
"/api/v1/socket/broadcast", "/api/v1/socket/broadcast",
params={"message": message} params={"message": message},
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is True assert data["success"] is True
assert data["message"] == "Message broadcasted to all users" assert data["message"] == "Message broadcasted to all users"
# Verify socket manager was called correctly # Verify socket manager was called correctly
mock_socket_manager.broadcast_to_all.assert_called_once_with( mock_socket_manager.broadcast_to_all.assert_called_once_with(
"broadcast_message", "broadcast_message",
@@ -123,7 +124,7 @@ class TestSocketEndpoints:
"from_user_id": authenticated_user.id, "from_user_id": authenticated_user.id,
"from_user_name": authenticated_user.name, "from_user_name": authenticated_user.name,
"message": message, "message": message,
} },
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -131,7 +132,7 @@ class TestSocketEndpoints:
"""Test broadcasting message without authentication.""" """Test broadcasting message without authentication."""
response = await client.post( response = await client.post(
"/api/v1/socket/broadcast", "/api/v1/socket/broadcast",
params={"message": "test"} params={"message": "test"},
) )
assert response.status_code == 401 assert response.status_code == 401
@@ -141,14 +142,14 @@ class TestSocketEndpoints:
# Missing target_user_id # Missing target_user_id
response = await authenticated_client.post( response = await authenticated_client.post(
"/api/v1/socket/send-message", "/api/v1/socket/send-message",
params={"message": "test"} params={"message": "test"},
) )
assert response.status_code == 422 assert response.status_code == 422
# Missing message # Missing message
response = await authenticated_client.post( response = await authenticated_client.post(
"/api/v1/socket/send-message", "/api/v1/socket/send-message",
params={"target_user_id": 1} params={"target_user_id": 1},
) )
assert response.status_code == 422 assert response.status_code == 422
@@ -163,7 +164,7 @@ class TestSocketEndpoints:
"""Test sending message with invalid user ID.""" """Test sending message with invalid user ID."""
response = await authenticated_client.post( response = await authenticated_client.post(
"/api/v1/socket/send-message", "/api/v1/socket/send-message",
params={"target_user_id": "invalid", "message": "test"} params={"target_user_id": "invalid", "message": "test"},
) )
assert response.status_code == 422 assert response.status_code == 422
@@ -172,14 +173,14 @@ class TestSocketEndpoints:
"""Test that socket status correctly shows if user is connected.""" """Test that socket status correctly shows if user is connected."""
# Test when user is connected # Test when user is connected
mock_socket_manager.get_connected_users.return_value = [str(authenticated_user.id), "2", "3"] mock_socket_manager.get_connected_users.return_value = [str(authenticated_user.id), "2", "3"]
response = await authenticated_client.get("/api/v1/socket/status") response = await authenticated_client.get("/api/v1/socket/status")
data = response.json() data = response.json()
assert data["connected"] is True assert data["connected"] is True
# Test when user is not connected # Test when user is not connected
mock_socket_manager.get_connected_users.return_value = ["2", "3", "4"] mock_socket_manager.get_connected_users.return_value = ["2", "3", "4"]
response = await authenticated_client.get("/api/v1/socket/status") response = await authenticated_client.get("/api/v1/socket/status")
data = response.json() data = response.json()
assert data["connected"] is False assert data["connected"] is False

View File

@@ -13,7 +13,6 @@ from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db from app.core.database import get_db
from app.main import create_app
from app.models.plan import Plan from app.models.plan import Plan
from app.models.user import User from app.models.user import User
from app.utils.auth import JWTUtils, PasswordUtils from app.utils.auth import JWTUtils, PasswordUtils
@@ -199,7 +198,7 @@ async def ensure_plans(test_session: AsyncSession) -> tuple[Plan, Plan]:
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_user( async def test_user(
test_session: AsyncSession, ensure_plans: tuple[Plan, Plan] test_session: AsyncSession, ensure_plans: tuple[Plan, Plan],
) -> User: ) -> User:
"""Create a test user.""" """Create a test user."""
user = User( user = User(
@@ -219,7 +218,7 @@ async def test_user(
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def admin_user( async def admin_user(
test_session: AsyncSession, ensure_plans: tuple[Plan, Plan] test_session: AsyncSession, ensure_plans: tuple[Plan, Plan],
) -> User: ) -> User:
"""Create a test admin user.""" """Create a test admin user."""
user = User( user = User(

View File

@@ -0,0 +1,191 @@
"""Tests for API token authentication dependencies."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock
import pytest
from fastapi import HTTPException
from app.core.dependencies import get_current_user_api_token, get_current_user_flexible
from app.models.user import User
from app.services.auth import AuthService
class TestApiTokenDependencies:
"""Test API token authentication dependencies."""
@pytest.fixture
def mock_auth_service(self):
"""Create a mock auth service."""
return AsyncMock(spec=AuthService)
@pytest.fixture
def test_user(self):
"""Create a test user."""
return User(
id=1,
email="test@example.com",
name="Test User",
role="user",
is_active=True,
plan_id=1,
credits=100,
api_token="test_api_token_123",
api_token_expires_at=datetime.now(UTC) + timedelta(days=30),
)
@pytest.mark.asyncio
async def test_get_current_user_api_token_success(
self, mock_auth_service, test_user,
):
"""Test successful API token authentication."""
mock_auth_service.get_user_by_api_token.return_value = test_user
authorization = "Bearer test_api_token_123"
result = await get_current_user_api_token(mock_auth_service, authorization)
assert result == test_user
mock_auth_service.get_user_by_api_token.assert_called_once_with("test_api_token_123")
@pytest.mark.asyncio
async def test_get_current_user_api_token_no_header(self, mock_auth_service):
"""Test API token authentication without Authorization header."""
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, None)
assert exc_info.value.status_code == 401
assert "Authorization header required" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_current_user_api_token_invalid_format(self, mock_auth_service):
"""Test API token authentication with invalid header format."""
authorization = "Invalid format"
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization)
assert exc_info.value.status_code == 401
assert "Invalid authorization header format" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_current_user_api_token_empty_token(self, mock_auth_service):
"""Test API token authentication with empty token."""
authorization = "Bearer "
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization)
assert exc_info.value.status_code == 401
assert "API token required" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_current_user_api_token_invalid_token(self, mock_auth_service):
"""Test API token authentication with invalid token."""
mock_auth_service.get_user_by_api_token.return_value = None
authorization = "Bearer invalid_token"
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization)
assert exc_info.value.status_code == 401
assert "Invalid API token" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_current_user_api_token_expired_token(
self, mock_auth_service, test_user,
):
"""Test API token authentication with expired token."""
# Set expired token
test_user.api_token_expires_at = datetime.now(UTC) - timedelta(days=1)
mock_auth_service.get_user_by_api_token.return_value = test_user
authorization = "Bearer expired_token"
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization)
assert exc_info.value.status_code == 401
assert "API token has expired" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_current_user_api_token_inactive_user(
self, mock_auth_service, test_user,
):
"""Test API token authentication with inactive user."""
test_user.is_active = False
mock_auth_service.get_user_by_api_token.return_value = test_user
authorization = "Bearer test_token"
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization)
assert exc_info.value.status_code == 401
assert "Account is deactivated" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_current_user_api_token_service_exception(self, mock_auth_service):
"""Test API token authentication with service exception."""
mock_auth_service.get_user_by_api_token.side_effect = Exception("Database error")
authorization = "Bearer test_token"
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization)
assert exc_info.value.status_code == 401
assert "Could not validate API token" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_current_user_flexible_uses_api_token(
self, mock_auth_service, test_user,
):
"""Test flexible authentication uses API token when available."""
mock_auth_service.get_user_by_api_token.return_value = test_user
authorization = "Bearer test_api_token_123"
access_token = "jwt_token"
result = await get_current_user_flexible(
mock_auth_service, access_token, authorization,
)
assert result == test_user
mock_auth_service.get_user_by_api_token.assert_called_once_with("test_api_token_123")
@pytest.mark.asyncio
async def test_get_current_user_flexible_falls_back_to_jwt(self, mock_auth_service):
"""Test flexible authentication falls back to JWT when no API token."""
# Mock the get_current_user function (normally imported)
with pytest.raises(Exception):
# This will fail because we can't easily mock the get_current_user import
# In a real test, you'd mock the import or use dependency injection
await get_current_user_flexible(mock_auth_service, "jwt_token", None)
@pytest.mark.asyncio
async def test_api_token_no_expiry_never_expires(self, mock_auth_service, test_user):
"""Test API token with no expiry date never expires."""
test_user.api_token_expires_at = None
mock_auth_service.get_user_by_api_token.return_value = test_user
authorization = "Bearer test_token"
result = await get_current_user_api_token(mock_auth_service, authorization)
assert result == test_user
@pytest.mark.asyncio
async def test_api_token_bearer_case_insensitive(self, mock_auth_service, test_user):
"""Test that Bearer prefix is case-sensitive (as per OAuth2 spec)."""
mock_auth_service.get_user_by_api_token.return_value = test_user
# lowercase bearer should fail
authorization = "bearer test_token"
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization)
assert exc_info.value.status_code == 401
assert "Invalid authorization header format" in exc_info.value.detail

View File

@@ -9,7 +9,6 @@ from app.models.plan import Plan
from app.models.user import User from app.models.user import User
from app.schemas.auth import UserLoginRequest, UserRegisterRequest from app.schemas.auth import UserLoginRequest, UserRegisterRequest
from app.services.auth import AuthService from app.services.auth import AuthService
from app.utils.auth import PasswordUtils
class TestAuthService: class TestAuthService:
@@ -49,11 +48,11 @@ class TestAuthService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_duplicate_email( async def test_register_duplicate_email(
self, auth_service: AuthService, test_user: User self, auth_service: AuthService, test_user: User,
) -> None: ) -> None:
"""Test registration with duplicate email.""" """Test registration with duplicate email."""
request = UserRegisterRequest( request = UserRegisterRequest(
email=test_user.email, password="password123", name="Another User" email=test_user.email, password="password123", name="Another User",
) )
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
@@ -90,7 +89,7 @@ class TestAuthService:
async def test_login_invalid_email(self, auth_service: AuthService) -> None: async def test_login_invalid_email(self, auth_service: AuthService) -> None:
"""Test login with invalid email.""" """Test login with invalid email."""
request = UserLoginRequest( request = UserLoginRequest(
email="nonexistent@example.com", password="password123" email="nonexistent@example.com", password="password123",
) )
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
@@ -101,7 +100,7 @@ class TestAuthService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_invalid_password( async def test_login_invalid_password(
self, auth_service: AuthService, test_user: User self, auth_service: AuthService, test_user: User,
) -> None: ) -> None:
"""Test login with invalid password.""" """Test login with invalid password."""
request = UserLoginRequest(email=test_user.email, password="wrongpassword") request = UserLoginRequest(email=test_user.email, password="wrongpassword")
@@ -114,7 +113,7 @@ class TestAuthService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_inactive_user( async def test_login_inactive_user(
self, auth_service: AuthService, test_user: User, test_session: AsyncSession self, auth_service: AuthService, test_user: User, test_session: AsyncSession,
) -> None: ) -> None:
"""Test login with inactive user.""" """Test login with inactive user."""
# Store the email before deactivating # Store the email before deactivating
@@ -134,7 +133,7 @@ class TestAuthService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_user_without_password( async def test_login_user_without_password(
self, auth_service: AuthService, test_user: User, test_session: AsyncSession self, auth_service: AuthService, test_user: User, test_session: AsyncSession,
) -> None: ) -> None:
"""Test login with user that has no password hash.""" """Test login with user that has no password hash."""
# Store the email before removing password # Store the email before removing password
@@ -154,7 +153,7 @@ class TestAuthService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_success( async def test_get_current_user_success(
self, auth_service: AuthService, test_user: User self, auth_service: AuthService, test_user: User,
) -> None: ) -> None:
"""Test getting current user successfully.""" """Test getting current user successfully."""
user = await auth_service.get_current_user(test_user.id) user = await auth_service.get_current_user(test_user.id)
@@ -175,7 +174,7 @@ class TestAuthService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_inactive( async def test_get_current_user_inactive(
self, auth_service: AuthService, test_user: User, test_session: AsyncSession self, auth_service: AuthService, test_user: User, test_session: AsyncSession,
) -> None: ) -> None:
"""Test getting current user when user is inactive.""" """Test getting current user when user is inactive."""
# Store the user ID before deactivating # Store the user ID before deactivating
@@ -193,7 +192,7 @@ class TestAuthService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_access_token( async def test_create_access_token(
self, auth_service: AuthService, test_user: User self, auth_service: AuthService, test_user: User,
) -> None: ) -> None:
"""Test access token creation.""" """Test access token creation."""
token_response = auth_service._create_access_token(test_user) token_response = auth_service._create_access_token(test_user)
@@ -212,7 +211,7 @@ class TestAuthService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_user_response( async def test_create_user_response(
self, auth_service: AuthService, test_user: User, test_session: AsyncSession self, auth_service: AuthService, test_user: User, test_session: AsyncSession,
) -> None: ) -> None:
"""Test user response creation.""" """Test user response creation."""
# Ensure plan relationship is loaded # Ensure plan relationship is loaded

View File

@@ -1,16 +1,14 @@
"""Tests for OAuth service.""" """Tests for OAuth service."""
from typing import Any from typing import Any
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
from httpx import AsyncClient
from app.services.oauth import ( from app.services.oauth import (
GitHubOAuthProvider, GitHubOAuthProvider,
GoogleOAuthProvider, GoogleOAuthProvider,
OAuthService, OAuthService,
OAuthUserInfo,
) )

View File

@@ -1,7 +1,8 @@
"""Tests for socket service.""" """Tests for socket service."""
from unittest.mock import AsyncMock, patch
import pytest import pytest
from unittest.mock import AsyncMock, MagicMock, patch, call
import socketio import socketio
from app.services.socket import SocketManager from app.services.socket import SocketManager
@@ -24,7 +25,7 @@ class TestSocketManager:
def test_init_creates_socket_server(self): def test_init_creates_socket_server(self):
"""Test that socket manager initializes with proper configuration.""" """Test that socket manager initializes with proper configuration."""
manager = SocketManager() manager = SocketManager()
assert manager.sio is not None assert manager.sio is not None
assert isinstance(manager.user_rooms, dict) assert isinstance(manager.user_rooms, dict)
assert isinstance(manager.socket_users, dict) assert isinstance(manager.socket_users, dict)
@@ -37,12 +38,12 @@ class TestSocketManager:
user_id = "123" user_id = "123"
room_id = "user_123" room_id = "user_123"
socket_manager.user_rooms[user_id] = room_id socket_manager.user_rooms[user_id] = room_id
event = "test_event" event = "test_event"
data = {"message": "hello"} data = {"message": "hello"}
result = await socket_manager.send_to_user(user_id, event, data) result = await socket_manager.send_to_user(user_id, event, data)
assert result is True assert result is True
mock_sio.emit.assert_called_once_with(event, data, room=room_id) mock_sio.emit.assert_called_once_with(event, data, room=room_id)
@@ -52,9 +53,9 @@ class TestSocketManager:
user_id = "999" user_id = "999"
event = "test_event" event = "test_event"
data = {"message": "hello"} data = {"message": "hello"}
result = await socket_manager.send_to_user(user_id, event, data) result = await socket_manager.send_to_user(user_id, event, data)
assert result is False assert result is False
mock_sio.emit.assert_not_called() mock_sio.emit.assert_not_called()
@@ -63,9 +64,9 @@ class TestSocketManager:
"""Test broadcasting message to all users.""" """Test broadcasting message to all users."""
event = "broadcast_event" event = "broadcast_event"
data = {"message": "announcement"} data = {"message": "announcement"}
await socket_manager.broadcast_to_all(event, data) await socket_manager.broadcast_to_all(event, data)
mock_sio.emit.assert_called_once_with(event, data) mock_sio.emit.assert_called_once_with(event, data)
def test_get_connected_users(self, socket_manager): def test_get_connected_users(self, socket_manager):
@@ -74,9 +75,9 @@ class TestSocketManager:
socket_manager.user_rooms["1"] = "user_1" socket_manager.user_rooms["1"] = "user_1"
socket_manager.user_rooms["2"] = "user_2" socket_manager.user_rooms["2"] = "user_2"
socket_manager.user_rooms["3"] = "user_3" socket_manager.user_rooms["3"] = "user_3"
connected_users = socket_manager.get_connected_users() connected_users = socket_manager.get_connected_users()
assert len(connected_users) == 3 assert len(connected_users) == 3
assert "1" in connected_users assert "1" in connected_users
assert "2" in connected_users assert "2" in connected_users
@@ -87,139 +88,139 @@ class TestSocketManager:
# Add some users # Add some users
socket_manager.user_rooms["1"] = "user_1" socket_manager.user_rooms["1"] = "user_1"
socket_manager.user_rooms["2"] = "user_2" socket_manager.user_rooms["2"] = "user_2"
room_info = socket_manager.get_room_info() room_info = socket_manager.get_room_info()
assert room_info["total_users"] == 2 assert room_info["total_users"] == 2
assert room_info["connected_users"] == ["1", "2"] assert room_info["connected_users"] == ["1", "2"]
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('app.services.socket.extract_access_token_from_cookies') @patch("app.services.socket.extract_access_token_from_cookies")
@patch('app.services.socket.JWTUtils.decode_access_token') @patch("app.services.socket.JWTUtils.decode_access_token")
async def test_connect_handler_success(self, mock_decode, mock_extract_token, socket_manager, mock_sio): async def test_connect_handler_success(self, mock_decode, mock_extract_token, socket_manager, mock_sio):
"""Test successful connection with valid token.""" """Test successful connection with valid token."""
# Setup mocks # Setup mocks
mock_extract_token.return_value = "valid_token" mock_extract_token.return_value = "valid_token"
mock_decode.return_value = {"sub": "123"} mock_decode.return_value = {"sub": "123"}
# Mock environment # Mock environment
environ = {"HTTP_COOKIE": "access_token=valid_token"} environ = {"HTTP_COOKIE": "access_token=valid_token"}
# Access the connect handler directly # Access the connect handler directly
handlers = {} handlers = {}
original_event = socket_manager.sio.event original_event = socket_manager.sio.event
def mock_event(func): def mock_event(func):
handlers[func.__name__] = func handlers[func.__name__] = func
return func return func
socket_manager.sio.event = mock_event socket_manager.sio.event = mock_event
socket_manager._setup_handlers() socket_manager._setup_handlers()
# Call the connect handler # Call the connect handler
await handlers['connect']("test_sid", environ) await handlers["connect"]("test_sid", environ)
# Verify token extraction and validation # Verify token extraction and validation
mock_extract_token.assert_called_once_with("access_token=valid_token") mock_extract_token.assert_called_once_with("access_token=valid_token")
mock_decode.assert_called_once_with("valid_token") mock_decode.assert_called_once_with("valid_token")
# Verify user tracking # Verify user tracking
assert socket_manager.socket_users["test_sid"] == "123" assert socket_manager.socket_users["test_sid"] == "123"
assert socket_manager.user_rooms["123"] == "user_123" assert socket_manager.user_rooms["123"] == "user_123"
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('app.services.socket.extract_access_token_from_cookies') @patch("app.services.socket.extract_access_token_from_cookies")
async def test_connect_handler_no_token(self, mock_extract_token, socket_manager, mock_sio): async def test_connect_handler_no_token(self, mock_extract_token, socket_manager, mock_sio):
"""Test connection with no access token.""" """Test connection with no access token."""
# Setup mocks # Setup mocks
mock_extract_token.return_value = None mock_extract_token.return_value = None
# Mock environment # Mock environment
environ = {"HTTP_COOKIE": ""} environ = {"HTTP_COOKIE": ""}
# Access the connect handler directly # Access the connect handler directly
handlers = {} handlers = {}
original_event = socket_manager.sio.event original_event = socket_manager.sio.event
def mock_event(func): def mock_event(func):
handlers[func.__name__] = func handlers[func.__name__] = func
return func return func
socket_manager.sio.event = mock_event socket_manager.sio.event = mock_event
socket_manager._setup_handlers() socket_manager._setup_handlers()
# Call the connect handler # Call the connect handler
await handlers['connect']("test_sid", environ) await handlers["connect"]("test_sid", environ)
# Verify disconnection # Verify disconnection
mock_sio.disconnect.assert_called_once_with("test_sid") mock_sio.disconnect.assert_called_once_with("test_sid")
# Verify no user tracking # Verify no user tracking
assert "test_sid" not in socket_manager.socket_users assert "test_sid" not in socket_manager.socket_users
assert len(socket_manager.user_rooms) == 0 assert len(socket_manager.user_rooms) == 0
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('app.services.socket.extract_access_token_from_cookies') @patch("app.services.socket.extract_access_token_from_cookies")
@patch('app.services.socket.JWTUtils.decode_access_token') @patch("app.services.socket.JWTUtils.decode_access_token")
async def test_connect_handler_invalid_token(self, mock_decode, mock_extract_token, socket_manager, mock_sio): async def test_connect_handler_invalid_token(self, mock_decode, mock_extract_token, socket_manager, mock_sio):
"""Test connection with invalid token.""" """Test connection with invalid token."""
# Setup mocks # Setup mocks
mock_extract_token.return_value = "invalid_token" mock_extract_token.return_value = "invalid_token"
mock_decode.side_effect = Exception("Invalid token") mock_decode.side_effect = Exception("Invalid token")
# Mock environment # Mock environment
environ = {"HTTP_COOKIE": "access_token=invalid_token"} environ = {"HTTP_COOKIE": "access_token=invalid_token"}
# Access the connect handler directly # Access the connect handler directly
handlers = {} handlers = {}
original_event = socket_manager.sio.event original_event = socket_manager.sio.event
def mock_event(func): def mock_event(func):
handlers[func.__name__] = func handlers[func.__name__] = func
return func return func
socket_manager.sio.event = mock_event socket_manager.sio.event = mock_event
socket_manager._setup_handlers() socket_manager._setup_handlers()
# Call the connect handler # Call the connect handler
await handlers['connect']("test_sid", environ) await handlers["connect"]("test_sid", environ)
# Verify disconnection # Verify disconnection
mock_sio.disconnect.assert_called_once_with("test_sid") mock_sio.disconnect.assert_called_once_with("test_sid")
# Verify no user tracking # Verify no user tracking
assert "test_sid" not in socket_manager.socket_users assert "test_sid" not in socket_manager.socket_users
assert len(socket_manager.user_rooms) == 0 assert len(socket_manager.user_rooms) == 0
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('app.services.socket.extract_access_token_from_cookies') @patch("app.services.socket.extract_access_token_from_cookies")
@patch('app.services.socket.JWTUtils.decode_access_token') @patch("app.services.socket.JWTUtils.decode_access_token")
async def test_connect_handler_missing_user_id(self, mock_decode, mock_extract_token, socket_manager, mock_sio): async def test_connect_handler_missing_user_id(self, mock_decode, mock_extract_token, socket_manager, mock_sio):
"""Test connection with token missing user ID.""" """Test connection with token missing user ID."""
# Setup mocks # Setup mocks
mock_extract_token.return_value = "token_without_user_id" mock_extract_token.return_value = "token_without_user_id"
mock_decode.return_value = {"other_field": "value"} # Missing 'sub' mock_decode.return_value = {"other_field": "value"} # Missing 'sub'
# Mock environment # Mock environment
environ = {"HTTP_COOKIE": "access_token=token_without_user_id"} environ = {"HTTP_COOKIE": "access_token=token_without_user_id"}
# Access the connect handler directly # Access the connect handler directly
handlers = {} handlers = {}
original_event = socket_manager.sio.event original_event = socket_manager.sio.event
def mock_event(func): def mock_event(func):
handlers[func.__name__] = func handlers[func.__name__] = func
return func return func
socket_manager.sio.event = mock_event socket_manager.sio.event = mock_event
socket_manager._setup_handlers() socket_manager._setup_handlers()
# Call the connect handler # Call the connect handler
await handlers['connect']("test_sid", environ) await handlers["connect"]("test_sid", environ)
# Verify disconnection # Verify disconnection
mock_sio.disconnect.assert_called_once_with("test_sid") mock_sio.disconnect.assert_called_once_with("test_sid")
# Verify no user tracking # Verify no user tracking
assert "test_sid" not in socket_manager.socket_users assert "test_sid" not in socket_manager.socket_users
assert len(socket_manager.user_rooms) == 0 assert len(socket_manager.user_rooms) == 0
@@ -230,21 +231,21 @@ class TestSocketManager:
# Setup initial state # Setup initial state
socket_manager.socket_users["test_sid"] = "123" socket_manager.socket_users["test_sid"] = "123"
socket_manager.user_rooms["123"] = "user_123" socket_manager.user_rooms["123"] = "user_123"
# Access the disconnect handler directly # Access the disconnect handler directly
handlers = {} handlers = {}
original_event = socket_manager.sio.event original_event = socket_manager.sio.event
def mock_event(func): def mock_event(func):
handlers[func.__name__] = func handlers[func.__name__] = func
return func return func
socket_manager.sio.event = mock_event socket_manager.sio.event = mock_event
socket_manager._setup_handlers() socket_manager._setup_handlers()
# Call the disconnect handler # Call the disconnect handler
await handlers['disconnect']("test_sid") await handlers["disconnect"]("test_sid")
# Verify cleanup # Verify cleanup
assert "test_sid" not in socket_manager.socket_users assert "test_sid" not in socket_manager.socket_users
assert "123" not in socket_manager.user_rooms assert "123" not in socket_manager.user_rooms
@@ -255,17 +256,17 @@ class TestSocketManager:
# Access the disconnect handler directly # Access the disconnect handler directly
handlers = {} handlers = {}
original_event = socket_manager.sio.event original_event = socket_manager.sio.event
def mock_event(func): def mock_event(func):
handlers[func.__name__] = func handlers[func.__name__] = func
return func return func
socket_manager.sio.event = mock_event socket_manager.sio.event = mock_event
socket_manager._setup_handlers() socket_manager._setup_handlers()
# Call the disconnect handler with unknown socket # Call the disconnect handler with unknown socket
await handlers['disconnect']("unknown_sid") await handlers["disconnect"]("unknown_sid")
# Should not raise any errors and state should remain clean # Should not raise any errors and state should remain clean
assert len(socket_manager.socket_users) == 0 assert len(socket_manager.socket_users) == 0
assert len(socket_manager.user_rooms) == 0 assert len(socket_manager.user_rooms) == 0

View File

@@ -1,8 +1,7 @@
"""Tests for cookie utilities.""" """Tests for cookie utilities."""
import pytest
from app.utils.cookies import parse_cookies, extract_access_token_from_cookies from app.utils.cookies import extract_access_token_from_cookies, parse_cookies
class TestParseCookies: class TestParseCookies:
@@ -22,18 +21,18 @@ class TestParseCookies:
"""Test parsing single cookie.""" """Test parsing single cookie."""
cookie_header = "session_id=abc123" cookie_header = "session_id=abc123"
result = parse_cookies(cookie_header) result = parse_cookies(cookie_header)
assert result == {"session_id": "abc123"} assert result == {"session_id": "abc123"}
def test_parse_multiple_cookies(self): def test_parse_multiple_cookies(self):
"""Test parsing multiple cookies.""" """Test parsing multiple cookies."""
cookie_header = "session_id=abc123; user_pref=dark_mode; lang=en" cookie_header = "session_id=abc123; user_pref=dark_mode; lang=en"
result = parse_cookies(cookie_header) result = parse_cookies(cookie_header)
expected = { expected = {
"session_id": "abc123", "session_id": "abc123",
"user_pref": "dark_mode", "user_pref": "dark_mode",
"lang": "en" "lang": "en",
} }
assert result == expected assert result == expected
@@ -41,10 +40,10 @@ class TestParseCookies:
"""Test parsing cookies with extra spaces.""" """Test parsing cookies with extra spaces."""
cookie_header = " session_id = abc123 ; user_pref = dark_mode " cookie_header = " session_id = abc123 ; user_pref = dark_mode "
result = parse_cookies(cookie_header) result = parse_cookies(cookie_header)
expected = { expected = {
"session_id": "abc123", "session_id": "abc123",
"user_pref": "dark_mode" "user_pref": "dark_mode",
} }
assert result == expected assert result == expected
@@ -52,10 +51,10 @@ class TestParseCookies:
"""Test parsing cookies where value contains equals sign.""" """Test parsing cookies where value contains equals sign."""
cookie_header = "encoded_data=key=value&other=data; session=123" cookie_header = "encoded_data=key=value&other=data; session=123"
result = parse_cookies(cookie_header) result = parse_cookies(cookie_header)
expected = { expected = {
"encoded_data": "key=value&other=data", "encoded_data": "key=value&other=data",
"session": "123" "session": "123",
} }
assert result == expected assert result == expected
@@ -63,11 +62,11 @@ class TestParseCookies:
"""Test parsing malformed cookies (no equals sign).""" """Test parsing malformed cookies (no equals sign)."""
cookie_header = "session_id=abc123; malformed_cookie; user_pref=dark" cookie_header = "session_id=abc123; malformed_cookie; user_pref=dark"
result = parse_cookies(cookie_header) result = parse_cookies(cookie_header)
# Should skip malformed cookie and parse valid ones # Should skip malformed cookie and parse valid ones
expected = { expected = {
"session_id": "abc123", "session_id": "abc123",
"user_pref": "dark" "user_pref": "dark",
} }
assert result == expected assert result == expected
@@ -75,10 +74,10 @@ class TestParseCookies:
"""Test parsing cookies with empty values.""" """Test parsing cookies with empty values."""
cookie_header = "empty_value=; session_id=abc123" cookie_header = "empty_value=; session_id=abc123"
result = parse_cookies(cookie_header) result = parse_cookies(cookie_header)
expected = { expected = {
"empty_value": "", "empty_value": "",
"session_id": "abc123" "session_id": "abc123",
} }
assert result == expected assert result == expected
@@ -86,7 +85,7 @@ class TestParseCookies:
"""Test parsing cookies with duplicate names (last one wins).""" """Test parsing cookies with duplicate names (last one wins)."""
cookie_header = "session_id=first; session_id=second" cookie_header = "session_id=first; session_id=second"
result = parse_cookies(cookie_header) result = parse_cookies(cookie_header)
assert result == {"session_id": "second"} assert result == {"session_id": "second"}
@@ -97,14 +96,14 @@ class TestExtractAccessTokenFromCookies:
"""Test extracting access token when present.""" """Test extracting access token when present."""
cookie_header = "session_id=abc123; access_token=jwt_token_here; user_pref=dark" cookie_header = "session_id=abc123; access_token=jwt_token_here; user_pref=dark"
result = extract_access_token_from_cookies(cookie_header) result = extract_access_token_from_cookies(cookie_header)
assert result == "jwt_token_here" assert result == "jwt_token_here"
def test_extract_access_token_not_present(self): def test_extract_access_token_not_present(self):
"""Test extracting access token when not present.""" """Test extracting access token when not present."""
cookie_header = "session_id=abc123; user_pref=dark" cookie_header = "session_id=abc123; user_pref=dark"
result = extract_access_token_from_cookies(cookie_header) result = extract_access_token_from_cookies(cookie_header)
assert result is None assert result is None
def test_extract_access_token_empty_header(self): def test_extract_access_token_empty_header(self):
@@ -116,21 +115,21 @@ class TestExtractAccessTokenFromCookies:
"""Test extracting access token when it's the only cookie.""" """Test extracting access token when it's the only cookie."""
cookie_header = "access_token=my_jwt_token" cookie_header = "access_token=my_jwt_token"
result = extract_access_token_from_cookies(cookie_header) result = extract_access_token_from_cookies(cookie_header)
assert result == "my_jwt_token" assert result == "my_jwt_token"
def test_extract_access_token_with_spaces(self): def test_extract_access_token_with_spaces(self):
"""Test extracting access token with spaces around values.""" """Test extracting access token with spaces around values."""
cookie_header = " access_token = jwt_token_with_spaces ; other=value " cookie_header = " access_token = jwt_token_with_spaces ; other=value "
result = extract_access_token_from_cookies(cookie_header) result = extract_access_token_from_cookies(cookie_header)
assert result == "jwt_token_with_spaces" assert result == "jwt_token_with_spaces"
def test_extract_access_token_empty_value(self): def test_extract_access_token_empty_value(self):
"""Test extracting access token with empty value.""" """Test extracting access token with empty value."""
cookie_header = "access_token=; other=value" cookie_header = "access_token=; other=value"
result = extract_access_token_from_cookies(cookie_header) result = extract_access_token_from_cookies(cookie_header)
assert result == "" assert result == ""
def test_extract_access_token_complex_value(self): def test_extract_access_token_complex_value(self):
@@ -138,12 +137,12 @@ class TestExtractAccessTokenFromCookies:
jwt_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjMiLCJleHAiOjE2MzM5NjY0MDB9.signature" jwt_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjMiLCJleHAiOjE2MzM5NjY0MDB9.signature"
cookie_header = f"session=abc; access_token={jwt_token}; csrf=token" cookie_header = f"session=abc; access_token={jwt_token}; csrf=token"
result = extract_access_token_from_cookies(cookie_header) result = extract_access_token_from_cookies(cookie_header)
assert result == jwt_token assert result == jwt_token
def test_extract_access_token_multiple_equals(self): def test_extract_access_token_multiple_equals(self):
"""Test extracting access token when value contains equals signs.""" """Test extracting access token when value contains equals signs."""
cookie_header = "access_token=encoded=data=here; other=simple" cookie_header = "access_token=encoded=data=here; other=simple"
result = extract_access_token_from_cookies(cookie_header) result = extract_access_token_from_cookies(cookie_header)
assert result == "encoded=data=here" assert result == "encoded=data=here"

View File

@@ -0,0 +1,80 @@
"""Tests for token utilities."""
from datetime import UTC, datetime, timedelta
from app.utils.auth import TokenUtils
class TestTokenUtils:
"""Test token utility functions."""
def test_generate_api_token(self):
"""Test API token generation."""
token = TokenUtils.generate_api_token()
# Should be a string
assert isinstance(token, str)
# Should not be empty
assert len(token) > 0
# Should be URL-safe base64 (43 characters for 32 bytes)
assert len(token) == 43
# Should be unique (generate multiple and check they're different)
tokens = [TokenUtils.generate_api_token() for _ in range(10)]
assert len(set(tokens)) == 10
def test_is_token_expired_none(self):
"""Test token expiration check with None expires_at."""
result = TokenUtils.is_token_expired(None)
assert result is False
def test_is_token_expired_future_naive(self):
"""Test token expiration check with future naive datetime."""
# Use UTC time for naive datetime (as the function assumes)
expires_at = datetime.utcnow() + timedelta(hours=1)
result = TokenUtils.is_token_expired(expires_at)
assert result is False
def test_is_token_expired_past_naive(self):
"""Test token expiration check with past naive datetime."""
# Use UTC time for naive datetime (as the function assumes)
expires_at = datetime.utcnow() - timedelta(hours=1)
result = TokenUtils.is_token_expired(expires_at)
assert result is True
def test_is_token_expired_future_aware(self):
"""Test token expiration check with future timezone-aware datetime."""
expires_at = datetime.now(UTC) + timedelta(hours=1)
result = TokenUtils.is_token_expired(expires_at)
assert result is False
def test_is_token_expired_past_aware(self):
"""Test token expiration check with past timezone-aware datetime."""
expires_at = datetime.now(UTC) - timedelta(hours=1)
result = TokenUtils.is_token_expired(expires_at)
assert result is True
def test_is_token_expired_edge_case_now(self):
"""Test token expiration check with time very close to now."""
# Token expires in 1 second
expires_at = datetime.now(UTC) + timedelta(seconds=1)
result = TokenUtils.is_token_expired(expires_at)
assert result is False
# Token expired 1 second ago
expires_at = datetime.now(UTC) - timedelta(seconds=1)
result = TokenUtils.is_token_expired(expires_at)
assert result is True
def test_is_token_expired_timezone_conversion(self):
"""Test token expiration check with different timezone."""
from zoneinfo import ZoneInfo
# Create a datetime in a different timezone
eastern = ZoneInfo("US/Eastern")
expires_at = datetime.now(eastern) + timedelta(hours=1)
result = TokenUtils.is_token_expired(expires_at)
assert result is False