diff --git a/app/core/dependencies.py b/app/core/dependencies.py index c26e11e..292ded3 100644 --- a/app/core/dependencies.py +++ b/app/core/dependencies.py @@ -94,26 +94,19 @@ async def get_current_active_user( async def get_current_user_api_token( auth_service: Annotated[AuthService, Depends(get_auth_service)], - authorization: Annotated[str | None, Header()] = None, + api_token_header: Annotated[str | None, Header(alias="API-TOKEN")] = None, ) -> User: - """Get the current authenticated user from API token in Authorization header.""" + """Get the current authenticated user from API token in API-TOKEN header.""" try: - # Check if Authorization header exists - if not authorization: + # Check if API-TOKEN header exists + if not api_token_header: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Authorization header required", + detail="API-TOKEN header required", ) - # Check if it's a Bearer token - if not authorization.startswith("Bearer "): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid authorization header format", - ) - - # Extract the API token - api_token = authorization[7:] # Remove "Bearer " prefix + # Use the API token directly + api_token = api_token_header.strip() if not api_token: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -158,12 +151,12 @@ async def get_current_user_api_token( async def get_current_user_flexible( auth_service: Annotated[AuthService, Depends(get_auth_service)], access_token: Annotated[str | None, Cookie()] = None, - authorization: Annotated[str | None, Header()] = None, + api_token_header: Annotated[str | None, Header(alias="API-TOKEN")] = None, ) -> User: """Get the current authenticated user from either JWT cookie or API token.""" - # Try API token first if Authorization header is present - if authorization: - return await get_current_user_api_token(auth_service, authorization) + # Try API token first if API-TOKEN header is present + if api_token_header: + return await get_current_user_api_token(auth_service, api_token_header) # Fall back to JWT cookie authentication return await get_current_user(auth_service, access_token) diff --git a/app/schemas/auth.py b/app/schemas/auth.py index bb38a9c..b506803 100644 --- a/app/schemas/auth.py +++ b/app/schemas/auth.py @@ -11,7 +11,9 @@ class UserRegisterRequest(BaseModel): email: EmailStr = Field(..., description="User email address") password: str = Field( - ..., min_length=8, description="User password (minimum 8 characters)", + ..., + min_length=8, + description="User password (minimum 8 characters)", ) name: str = Field(..., min_length=1, max_length=100, description="User full name") @@ -68,7 +70,7 @@ class ApiTokenResponse(BaseModel): """Schema for API token response.""" api_token: str = Field(..., description="Generated API token") - expires_at: datetime = Field(..., description="Token expiration timestamp") + expires_at: datetime | None = Field(None, description="Token expiration timestamp") class ApiTokenStatusResponse(BaseModel): diff --git a/tests/api/v1/test_api_token_endpoints.py b/tests/api/v1/test_api_token_endpoints.py index d512ed8..09ed404 100644 --- a/tests/api/v1/test_api_token_endpoints.py +++ b/tests/api/v1/test_api_token_endpoints.py @@ -239,7 +239,7 @@ class TestApiTokenEndpoints: api_token = token_response.json()["api_token"] # Use API token to authenticate - headers = {"Authorization": f"Bearer {api_token}"} + headers = {"API-TOKEN": api_token} response = await client.get("/api/v1/auth/me", headers=headers) assert response.status_code == 200 @@ -250,7 +250,7 @@ class TestApiTokenEndpoints: @pytest.mark.asyncio async def test_api_token_authentication_invalid_token(self, client: AsyncClient): """Test authentication with invalid API token.""" - headers = {"Authorization": "Bearer invalid_token"} + headers = {"API-TOKEN": "invalid_token"} response = await client.get("/api/v1/auth/me", headers=headers) assert response.status_code == 401 @@ -271,7 +271,7 @@ class TestApiTokenEndpoints: # Mock expired token with patch("app.utils.auth.TokenUtils.is_token_expired", return_value=True): - headers = {"Authorization": f"Bearer {api_token}"} + headers = {"API-TOKEN": api_token} response = await client.get("/api/v1/auth/me", headers=headers) assert response.status_code == 401 @@ -279,18 +279,18 @@ class TestApiTokenEndpoints: assert "API token has expired" in data["detail"] @pytest.mark.asyncio - async def test_api_token_authentication_malformed_header(self, client: AsyncClient): - """Test authentication with malformed Authorization header.""" - # Missing Bearer prefix - headers = {"Authorization": "invalid_format"} + async def test_api_token_authentication_empty_token(self, client: AsyncClient): + """Test authentication with empty API-TOKEN header.""" + # Empty token + headers = {"API-TOKEN": ""} response = await client.get("/api/v1/auth/me", headers=headers) assert response.status_code == 401 data = response.json() - assert "Invalid authorization header format" in data["detail"] + assert "Could not validate credentials" in data["detail"] - # Empty token - headers = {"Authorization": "Bearer "} + # Whitespace only token + headers = {"API-TOKEN": " "} response = await client.get("/api/v1/auth/me", headers=headers) assert response.status_code == 401 @@ -313,7 +313,7 @@ class TestApiTokenEndpoints: authenticated_user.is_active = False # Try to authenticate with API token - headers = {"Authorization": f"Bearer {api_token}"} + headers = {"API-TOKEN": api_token} response = await client.get("/api/v1/auth/me", headers=headers) assert response.status_code == 401 @@ -332,9 +332,9 @@ class TestApiTokenEndpoints: ) api_token = token_response.json()["api_token"] - # Set both cookies and Authorization header + # Set both cookies and API-TOKEN header client.cookies.update(auth_cookies) - headers = {"Authorization": f"Bearer {api_token}"} + headers = {"API-TOKEN": api_token} # This should use API token authentication response = await client.get("/api/v1/auth/me", headers=headers) diff --git a/tests/core/test_api_token_dependencies.py b/tests/core/test_api_token_dependencies.py index 5ca1026..29d3d2f 100644 --- a/tests/core/test_api_token_dependencies.py +++ b/tests/core/test_api_token_dependencies.py @@ -41,40 +41,40 @@ class TestApiTokenDependencies: """Test successful API token authentication.""" mock_auth_service.get_user_by_api_token.return_value = test_user - authorization = "Bearer test_api_token_123" + api_token_header = "test_api_token_123" - result = await get_current_user_api_token(mock_auth_service, authorization) + result = await get_current_user_api_token(mock_auth_service, api_token_header) assert result == test_user mock_auth_service.get_user_by_api_token.assert_called_once_with("test_api_token_123") @pytest.mark.asyncio async def test_get_current_user_api_token_no_header(self, mock_auth_service): - """Test API token authentication without Authorization header.""" + """Test API token authentication without API-TOKEN header.""" with pytest.raises(HTTPException) as exc_info: await get_current_user_api_token(mock_auth_service, None) assert exc_info.value.status_code == 401 - assert "Authorization header required" in exc_info.value.detail - - @pytest.mark.asyncio - async def test_get_current_user_api_token_invalid_format(self, mock_auth_service): - """Test API token authentication with invalid header format.""" - authorization = "Invalid format" - - with pytest.raises(HTTPException) as exc_info: - await get_current_user_api_token(mock_auth_service, authorization) - - assert exc_info.value.status_code == 401 - assert "Invalid authorization header format" in exc_info.value.detail + assert "API-TOKEN header required" in exc_info.value.detail @pytest.mark.asyncio async def test_get_current_user_api_token_empty_token(self, mock_auth_service): """Test API token authentication with empty token.""" - authorization = "Bearer " + api_token_header = " " with pytest.raises(HTTPException) as exc_info: - await get_current_user_api_token(mock_auth_service, authorization) + await get_current_user_api_token(mock_auth_service, api_token_header) + + assert exc_info.value.status_code == 401 + assert "API token required" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_get_current_user_api_token_whitespace_token(self, mock_auth_service): + """Test API token authentication with whitespace-only token.""" + api_token_header = " " + + with pytest.raises(HTTPException) as exc_info: + await get_current_user_api_token(mock_auth_service, api_token_header) assert exc_info.value.status_code == 401 assert "API token required" in exc_info.value.detail @@ -84,10 +84,10 @@ class TestApiTokenDependencies: """Test API token authentication with invalid token.""" mock_auth_service.get_user_by_api_token.return_value = None - authorization = "Bearer invalid_token" + api_token_header = "invalid_token" with pytest.raises(HTTPException) as exc_info: - await get_current_user_api_token(mock_auth_service, authorization) + await get_current_user_api_token(mock_auth_service, api_token_header) assert exc_info.value.status_code == 401 assert "Invalid API token" in exc_info.value.detail @@ -101,10 +101,10 @@ class TestApiTokenDependencies: test_user.api_token_expires_at = datetime.now(UTC) - timedelta(days=1) mock_auth_service.get_user_by_api_token.return_value = test_user - authorization = "Bearer expired_token" + api_token_header = "expired_token" with pytest.raises(HTTPException) as exc_info: - await get_current_user_api_token(mock_auth_service, authorization) + await get_current_user_api_token(mock_auth_service, api_token_header) assert exc_info.value.status_code == 401 assert "API token has expired" in exc_info.value.detail @@ -117,10 +117,10 @@ class TestApiTokenDependencies: test_user.is_active = False mock_auth_service.get_user_by_api_token.return_value = test_user - authorization = "Bearer test_token" + api_token_header = "test_token" with pytest.raises(HTTPException) as exc_info: - await get_current_user_api_token(mock_auth_service, authorization) + await get_current_user_api_token(mock_auth_service, api_token_header) assert exc_info.value.status_code == 401 assert "Account is deactivated" in exc_info.value.detail @@ -130,10 +130,10 @@ class TestApiTokenDependencies: """Test API token authentication with service exception.""" mock_auth_service.get_user_by_api_token.side_effect = Exception("Database error") - authorization = "Bearer test_token" + api_token_header = "test_token" with pytest.raises(HTTPException) as exc_info: - await get_current_user_api_token(mock_auth_service, authorization) + await get_current_user_api_token(mock_auth_service, api_token_header) assert exc_info.value.status_code == 401 assert "Could not validate API token" in exc_info.value.detail @@ -145,11 +145,11 @@ class TestApiTokenDependencies: """Test flexible authentication uses API token when available.""" mock_auth_service.get_user_by_api_token.return_value = test_user - authorization = "Bearer test_api_token_123" + api_token_header = "test_api_token_123" access_token = "jwt_token" result = await get_current_user_flexible( - mock_auth_service, access_token, authorization, + mock_auth_service, access_token, api_token_header, ) assert result == test_user @@ -170,22 +170,20 @@ class TestApiTokenDependencies: test_user.api_token_expires_at = None mock_auth_service.get_user_by_api_token.return_value = test_user - authorization = "Bearer test_token" + api_token_header = "test_token" - result = await get_current_user_api_token(mock_auth_service, authorization) + result = await get_current_user_api_token(mock_auth_service, api_token_header) assert result == test_user @pytest.mark.asyncio - async def test_api_token_bearer_case_insensitive(self, mock_auth_service, test_user): - """Test that Bearer prefix is case-sensitive (as per OAuth2 spec).""" + async def test_api_token_with_whitespace(self, mock_auth_service, test_user): + """Test API token with leading/trailing whitespace is handled correctly.""" mock_auth_service.get_user_by_api_token.return_value = test_user - # lowercase bearer should fail - authorization = "bearer test_token" + api_token_header = " test_token " - with pytest.raises(HTTPException) as exc_info: - await get_current_user_api_token(mock_auth_service, authorization) + result = await get_current_user_api_token(mock_auth_service, api_token_header) - assert exc_info.value.status_code == 401 - assert "Invalid authorization header format" in exc_info.value.detail + assert result == test_user + mock_auth_service.get_user_by_api_token.assert_called_once_with("test_token") diff --git a/tests/utils/test_token_utils.py b/tests/utils/test_token_utils.py index fa723b6..5ef485b 100644 --- a/tests/utils/test_token_utils.py +++ b/tests/utils/test_token_utils.py @@ -1,14 +1,18 @@ """Tests for token utilities.""" from datetime import UTC, datetime, timedelta +from zoneinfo import ZoneInfo from app.utils.auth import TokenUtils +TOKEN_LENGTH = 43 # Length of URL-safe base64 encoded 32-byte token +UNIQUE_TOKENS_COUNT = 10 # Number of unique tokens to generate for uniqueness test + class TestTokenUtils: """Test token utility functions.""" - def test_generate_api_token(self): + def test_generate_api_token(self) -> None: """Test API token generation.""" token = TokenUtils.generate_api_token() @@ -19,44 +23,44 @@ class TestTokenUtils: assert len(token) > 0 # Should be URL-safe base64 (43 characters for 32 bytes) - assert len(token) == 43 + assert len(token) == TOKEN_LENGTH # Should be unique (generate multiple and check they're different) - tokens = [TokenUtils.generate_api_token() for _ in range(10)] - assert len(set(tokens)) == 10 + tokens = [TokenUtils.generate_api_token() for _ in range(UNIQUE_TOKENS_COUNT)] + assert len(set(tokens)) == UNIQUE_TOKENS_COUNT - def test_is_token_expired_none(self): + def test_is_token_expired_none(self) -> None: """Test token expiration check with None expires_at.""" result = TokenUtils.is_token_expired(None) assert result is False - def test_is_token_expired_future_naive(self): + def test_is_token_expired_future_naive(self) -> None: """Test token expiration check with future naive datetime.""" # Use UTC time for naive datetime (as the function assumes) - expires_at = datetime.utcnow() + timedelta(hours=1) + expires_at = datetime.now(UTC) + timedelta(hours=1) result = TokenUtils.is_token_expired(expires_at) assert result is False - def test_is_token_expired_past_naive(self): + def test_is_token_expired_past_naive(self) -> None: """Test token expiration check with past naive datetime.""" # Use UTC time for naive datetime (as the function assumes) - expires_at = datetime.utcnow() - timedelta(hours=1) + expires_at = datetime.now(UTC) - timedelta(hours=1) result = TokenUtils.is_token_expired(expires_at) assert result is True - def test_is_token_expired_future_aware(self): + def test_is_token_expired_future_aware(self) -> None: """Test token expiration check with future timezone-aware datetime.""" expires_at = datetime.now(UTC) + timedelta(hours=1) result = TokenUtils.is_token_expired(expires_at) assert result is False - def test_is_token_expired_past_aware(self): + def test_is_token_expired_past_aware(self) -> None: """Test token expiration check with past timezone-aware datetime.""" expires_at = datetime.now(UTC) - timedelta(hours=1) result = TokenUtils.is_token_expired(expires_at) assert result is True - def test_is_token_expired_edge_case_now(self): + def test_is_token_expired_edge_case_now(self) -> None: """Test token expiration check with time very close to now.""" # Token expires in 1 second expires_at = datetime.now(UTC) + timedelta(seconds=1) @@ -68,10 +72,8 @@ class TestTokenUtils: result = TokenUtils.is_token_expired(expires_at) assert result is True - def test_is_token_expired_timezone_conversion(self): + def test_is_token_expired_timezone_conversion(self) -> None: """Test token expiration check with different timezone.""" - from zoneinfo import ZoneInfo - # Create a datetime in a different timezone eastern = ZoneInfo("US/Eastern") expires_at = datetime.now(eastern) + timedelta(hours=1)