feat: Update API token handling to use API-TOKEN header and improve related tests

This commit is contained in:
JSC
2025-07-27 22:15:23 +02:00
parent 3dc21337f9
commit 58030914e6
5 changed files with 80 additions and 85 deletions

View File

@@ -94,26 +94,19 @@ async def get_current_active_user(
async def get_current_user_api_token( async def get_current_user_api_token(
auth_service: Annotated[AuthService, Depends(get_auth_service)], auth_service: Annotated[AuthService, Depends(get_auth_service)],
authorization: Annotated[str | None, Header()] = None, api_token_header: Annotated[str | None, Header(alias="API-TOKEN")] = None,
) -> User: ) -> User:
"""Get the current authenticated user from API token in Authorization header.""" """Get the current authenticated user from API token in API-TOKEN header."""
try: try:
# Check if Authorization header exists # Check if API-TOKEN header exists
if not authorization: if not api_token_header:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authorization header required", detail="API-TOKEN header required",
) )
# Check if it's a Bearer token # Use the API token directly
if not authorization.startswith("Bearer "): api_token = api_token_header.strip()
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authorization header format",
)
# Extract the API token
api_token = authorization[7:] # Remove "Bearer " prefix
if not api_token: if not api_token:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@@ -158,12 +151,12 @@ async def get_current_user_api_token(
async def get_current_user_flexible( async def get_current_user_flexible(
auth_service: Annotated[AuthService, Depends(get_auth_service)], auth_service: Annotated[AuthService, Depends(get_auth_service)],
access_token: Annotated[str | None, Cookie()] = None, access_token: Annotated[str | None, Cookie()] = None,
authorization: Annotated[str | None, Header()] = None, api_token_header: Annotated[str | None, Header(alias="API-TOKEN")] = None,
) -> User: ) -> User:
"""Get the current authenticated user from either JWT cookie or API token.""" """Get the current authenticated user from either JWT cookie or API token."""
# Try API token first if Authorization header is present # Try API token first if API-TOKEN header is present
if authorization: if api_token_header:
return await get_current_user_api_token(auth_service, authorization) return await get_current_user_api_token(auth_service, api_token_header)
# Fall back to JWT cookie authentication # Fall back to JWT cookie authentication
return await get_current_user(auth_service, access_token) return await get_current_user(auth_service, access_token)

View File

@@ -11,7 +11,9 @@ class UserRegisterRequest(BaseModel):
email: EmailStr = Field(..., description="User email address") email: EmailStr = Field(..., description="User email address")
password: str = Field( password: str = Field(
..., min_length=8, description="User password (minimum 8 characters)", ...,
min_length=8,
description="User password (minimum 8 characters)",
) )
name: str = Field(..., min_length=1, max_length=100, description="User full name") name: str = Field(..., min_length=1, max_length=100, description="User full name")
@@ -68,7 +70,7 @@ class ApiTokenResponse(BaseModel):
"""Schema for API token response.""" """Schema for API token response."""
api_token: str = Field(..., description="Generated API token") api_token: str = Field(..., description="Generated API token")
expires_at: datetime = Field(..., description="Token expiration timestamp") expires_at: datetime | None = Field(None, description="Token expiration timestamp")
class ApiTokenStatusResponse(BaseModel): class ApiTokenStatusResponse(BaseModel):

View File

@@ -239,7 +239,7 @@ class TestApiTokenEndpoints:
api_token = token_response.json()["api_token"] api_token = token_response.json()["api_token"]
# Use API token to authenticate # Use API token to authenticate
headers = {"Authorization": f"Bearer {api_token}"} headers = {"API-TOKEN": api_token}
response = await client.get("/api/v1/auth/me", headers=headers) response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 200 assert response.status_code == 200
@@ -250,7 +250,7 @@ class TestApiTokenEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_token_authentication_invalid_token(self, client: AsyncClient): async def test_api_token_authentication_invalid_token(self, client: AsyncClient):
"""Test authentication with invalid API token.""" """Test authentication with invalid API token."""
headers = {"Authorization": "Bearer invalid_token"} headers = {"API-TOKEN": "invalid_token"}
response = await client.get("/api/v1/auth/me", headers=headers) response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 401 assert response.status_code == 401
@@ -271,7 +271,7 @@ class TestApiTokenEndpoints:
# Mock expired token # Mock expired token
with patch("app.utils.auth.TokenUtils.is_token_expired", return_value=True): with patch("app.utils.auth.TokenUtils.is_token_expired", return_value=True):
headers = {"Authorization": f"Bearer {api_token}"} headers = {"API-TOKEN": api_token}
response = await client.get("/api/v1/auth/me", headers=headers) response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 401 assert response.status_code == 401
@@ -279,18 +279,18 @@ class TestApiTokenEndpoints:
assert "API token has expired" in data["detail"] assert "API token has expired" in data["detail"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_token_authentication_malformed_header(self, client: AsyncClient): async def test_api_token_authentication_empty_token(self, client: AsyncClient):
"""Test authentication with malformed Authorization header.""" """Test authentication with empty API-TOKEN header."""
# Missing Bearer prefix # Empty token
headers = {"Authorization": "invalid_format"} headers = {"API-TOKEN": ""}
response = await client.get("/api/v1/auth/me", headers=headers) response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 401 assert response.status_code == 401
data = response.json() data = response.json()
assert "Invalid authorization header format" in data["detail"] assert "Could not validate credentials" in data["detail"]
# Empty token # Whitespace only token
headers = {"Authorization": "Bearer "} headers = {"API-TOKEN": " "}
response = await client.get("/api/v1/auth/me", headers=headers) response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 401 assert response.status_code == 401
@@ -313,7 +313,7 @@ class TestApiTokenEndpoints:
authenticated_user.is_active = False authenticated_user.is_active = False
# Try to authenticate with API token # Try to authenticate with API token
headers = {"Authorization": f"Bearer {api_token}"} headers = {"API-TOKEN": api_token}
response = await client.get("/api/v1/auth/me", headers=headers) response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 401 assert response.status_code == 401
@@ -332,9 +332,9 @@ class TestApiTokenEndpoints:
) )
api_token = token_response.json()["api_token"] api_token = token_response.json()["api_token"]
# Set both cookies and Authorization header # Set both cookies and API-TOKEN header
client.cookies.update(auth_cookies) client.cookies.update(auth_cookies)
headers = {"Authorization": f"Bearer {api_token}"} headers = {"API-TOKEN": api_token}
# This should use API token authentication # This should use API token authentication
response = await client.get("/api/v1/auth/me", headers=headers) response = await client.get("/api/v1/auth/me", headers=headers)

View File

@@ -41,40 +41,40 @@ class TestApiTokenDependencies:
"""Test successful API token authentication.""" """Test successful API token authentication."""
mock_auth_service.get_user_by_api_token.return_value = test_user mock_auth_service.get_user_by_api_token.return_value = test_user
authorization = "Bearer test_api_token_123" api_token_header = "test_api_token_123"
result = await get_current_user_api_token(mock_auth_service, authorization) result = await get_current_user_api_token(mock_auth_service, api_token_header)
assert result == test_user assert result == test_user
mock_auth_service.get_user_by_api_token.assert_called_once_with("test_api_token_123") mock_auth_service.get_user_by_api_token.assert_called_once_with("test_api_token_123")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_api_token_no_header(self, mock_auth_service): async def test_get_current_user_api_token_no_header(self, mock_auth_service):
"""Test API token authentication without Authorization header.""" """Test API token authentication without API-TOKEN header."""
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, None) await get_current_user_api_token(mock_auth_service, None)
assert exc_info.value.status_code == 401 assert exc_info.value.status_code == 401
assert "Authorization header required" in exc_info.value.detail assert "API-TOKEN header required" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_current_user_api_token_invalid_format(self, mock_auth_service):
"""Test API token authentication with invalid header format."""
authorization = "Invalid format"
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization)
assert exc_info.value.status_code == 401
assert "Invalid authorization header format" in exc_info.value.detail
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_api_token_empty_token(self, mock_auth_service): async def test_get_current_user_api_token_empty_token(self, mock_auth_service):
"""Test API token authentication with empty token.""" """Test API token authentication with empty token."""
authorization = "Bearer " api_token_header = " "
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization) await get_current_user_api_token(mock_auth_service, api_token_header)
assert exc_info.value.status_code == 401
assert "API token required" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_current_user_api_token_whitespace_token(self, mock_auth_service):
"""Test API token authentication with whitespace-only token."""
api_token_header = " "
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, api_token_header)
assert exc_info.value.status_code == 401 assert exc_info.value.status_code == 401
assert "API token required" in exc_info.value.detail assert "API token required" in exc_info.value.detail
@@ -84,10 +84,10 @@ class TestApiTokenDependencies:
"""Test API token authentication with invalid token.""" """Test API token authentication with invalid token."""
mock_auth_service.get_user_by_api_token.return_value = None mock_auth_service.get_user_by_api_token.return_value = None
authorization = "Bearer invalid_token" api_token_header = "invalid_token"
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization) await get_current_user_api_token(mock_auth_service, api_token_header)
assert exc_info.value.status_code == 401 assert exc_info.value.status_code == 401
assert "Invalid API token" in exc_info.value.detail assert "Invalid API token" in exc_info.value.detail
@@ -101,10 +101,10 @@ class TestApiTokenDependencies:
test_user.api_token_expires_at = datetime.now(UTC) - timedelta(days=1) test_user.api_token_expires_at = datetime.now(UTC) - timedelta(days=1)
mock_auth_service.get_user_by_api_token.return_value = test_user mock_auth_service.get_user_by_api_token.return_value = test_user
authorization = "Bearer expired_token" api_token_header = "expired_token"
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization) await get_current_user_api_token(mock_auth_service, api_token_header)
assert exc_info.value.status_code == 401 assert exc_info.value.status_code == 401
assert "API token has expired" in exc_info.value.detail assert "API token has expired" in exc_info.value.detail
@@ -117,10 +117,10 @@ class TestApiTokenDependencies:
test_user.is_active = False test_user.is_active = False
mock_auth_service.get_user_by_api_token.return_value = test_user mock_auth_service.get_user_by_api_token.return_value = test_user
authorization = "Bearer test_token" api_token_header = "test_token"
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization) await get_current_user_api_token(mock_auth_service, api_token_header)
assert exc_info.value.status_code == 401 assert exc_info.value.status_code == 401
assert "Account is deactivated" in exc_info.value.detail assert "Account is deactivated" in exc_info.value.detail
@@ -130,10 +130,10 @@ class TestApiTokenDependencies:
"""Test API token authentication with service exception.""" """Test API token authentication with service exception."""
mock_auth_service.get_user_by_api_token.side_effect = Exception("Database error") mock_auth_service.get_user_by_api_token.side_effect = Exception("Database error")
authorization = "Bearer test_token" api_token_header = "test_token"
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization) await get_current_user_api_token(mock_auth_service, api_token_header)
assert exc_info.value.status_code == 401 assert exc_info.value.status_code == 401
assert "Could not validate API token" in exc_info.value.detail assert "Could not validate API token" in exc_info.value.detail
@@ -145,11 +145,11 @@ class TestApiTokenDependencies:
"""Test flexible authentication uses API token when available.""" """Test flexible authentication uses API token when available."""
mock_auth_service.get_user_by_api_token.return_value = test_user mock_auth_service.get_user_by_api_token.return_value = test_user
authorization = "Bearer test_api_token_123" api_token_header = "test_api_token_123"
access_token = "jwt_token" access_token = "jwt_token"
result = await get_current_user_flexible( result = await get_current_user_flexible(
mock_auth_service, access_token, authorization, mock_auth_service, access_token, api_token_header,
) )
assert result == test_user assert result == test_user
@@ -170,22 +170,20 @@ class TestApiTokenDependencies:
test_user.api_token_expires_at = None test_user.api_token_expires_at = None
mock_auth_service.get_user_by_api_token.return_value = test_user mock_auth_service.get_user_by_api_token.return_value = test_user
authorization = "Bearer test_token" api_token_header = "test_token"
result = await get_current_user_api_token(mock_auth_service, authorization) result = await get_current_user_api_token(mock_auth_service, api_token_header)
assert result == test_user assert result == test_user
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_token_bearer_case_insensitive(self, mock_auth_service, test_user): async def test_api_token_with_whitespace(self, mock_auth_service, test_user):
"""Test that Bearer prefix is case-sensitive (as per OAuth2 spec).""" """Test API token with leading/trailing whitespace is handled correctly."""
mock_auth_service.get_user_by_api_token.return_value = test_user mock_auth_service.get_user_by_api_token.return_value = test_user
# lowercase bearer should fail api_token_header = " test_token "
authorization = "bearer test_token"
with pytest.raises(HTTPException) as exc_info: result = await get_current_user_api_token(mock_auth_service, api_token_header)
await get_current_user_api_token(mock_auth_service, authorization)
assert exc_info.value.status_code == 401 assert result == test_user
assert "Invalid authorization header format" in exc_info.value.detail mock_auth_service.get_user_by_api_token.assert_called_once_with("test_token")

View File

@@ -1,14 +1,18 @@
"""Tests for token utilities.""" """Tests for token utilities."""
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from zoneinfo import ZoneInfo
from app.utils.auth import TokenUtils from app.utils.auth import TokenUtils
TOKEN_LENGTH = 43 # Length of URL-safe base64 encoded 32-byte token
UNIQUE_TOKENS_COUNT = 10 # Number of unique tokens to generate for uniqueness test
class TestTokenUtils: class TestTokenUtils:
"""Test token utility functions.""" """Test token utility functions."""
def test_generate_api_token(self): def test_generate_api_token(self) -> None:
"""Test API token generation.""" """Test API token generation."""
token = TokenUtils.generate_api_token() token = TokenUtils.generate_api_token()
@@ -19,44 +23,44 @@ class TestTokenUtils:
assert len(token) > 0 assert len(token) > 0
# Should be URL-safe base64 (43 characters for 32 bytes) # Should be URL-safe base64 (43 characters for 32 bytes)
assert len(token) == 43 assert len(token) == TOKEN_LENGTH
# Should be unique (generate multiple and check they're different) # Should be unique (generate multiple and check they're different)
tokens = [TokenUtils.generate_api_token() for _ in range(10)] tokens = [TokenUtils.generate_api_token() for _ in range(UNIQUE_TOKENS_COUNT)]
assert len(set(tokens)) == 10 assert len(set(tokens)) == UNIQUE_TOKENS_COUNT
def test_is_token_expired_none(self): def test_is_token_expired_none(self) -> None:
"""Test token expiration check with None expires_at.""" """Test token expiration check with None expires_at."""
result = TokenUtils.is_token_expired(None) result = TokenUtils.is_token_expired(None)
assert result is False assert result is False
def test_is_token_expired_future_naive(self): def test_is_token_expired_future_naive(self) -> None:
"""Test token expiration check with future naive datetime.""" """Test token expiration check with future naive datetime."""
# Use UTC time for naive datetime (as the function assumes) # Use UTC time for naive datetime (as the function assumes)
expires_at = datetime.utcnow() + timedelta(hours=1) expires_at = datetime.now(UTC) + timedelta(hours=1)
result = TokenUtils.is_token_expired(expires_at) result = TokenUtils.is_token_expired(expires_at)
assert result is False assert result is False
def test_is_token_expired_past_naive(self): def test_is_token_expired_past_naive(self) -> None:
"""Test token expiration check with past naive datetime.""" """Test token expiration check with past naive datetime."""
# Use UTC time for naive datetime (as the function assumes) # Use UTC time for naive datetime (as the function assumes)
expires_at = datetime.utcnow() - timedelta(hours=1) expires_at = datetime.now(UTC) - timedelta(hours=1)
result = TokenUtils.is_token_expired(expires_at) result = TokenUtils.is_token_expired(expires_at)
assert result is True assert result is True
def test_is_token_expired_future_aware(self): def test_is_token_expired_future_aware(self) -> None:
"""Test token expiration check with future timezone-aware datetime.""" """Test token expiration check with future timezone-aware datetime."""
expires_at = datetime.now(UTC) + timedelta(hours=1) expires_at = datetime.now(UTC) + timedelta(hours=1)
result = TokenUtils.is_token_expired(expires_at) result = TokenUtils.is_token_expired(expires_at)
assert result is False assert result is False
def test_is_token_expired_past_aware(self): def test_is_token_expired_past_aware(self) -> None:
"""Test token expiration check with past timezone-aware datetime.""" """Test token expiration check with past timezone-aware datetime."""
expires_at = datetime.now(UTC) - timedelta(hours=1) expires_at = datetime.now(UTC) - timedelta(hours=1)
result = TokenUtils.is_token_expired(expires_at) result = TokenUtils.is_token_expired(expires_at)
assert result is True assert result is True
def test_is_token_expired_edge_case_now(self): def test_is_token_expired_edge_case_now(self) -> None:
"""Test token expiration check with time very close to now.""" """Test token expiration check with time very close to now."""
# Token expires in 1 second # Token expires in 1 second
expires_at = datetime.now(UTC) + timedelta(seconds=1) expires_at = datetime.now(UTC) + timedelta(seconds=1)
@@ -68,10 +72,8 @@ class TestTokenUtils:
result = TokenUtils.is_token_expired(expires_at) result = TokenUtils.is_token_expired(expires_at)
assert result is True assert result is True
def test_is_token_expired_timezone_conversion(self): def test_is_token_expired_timezone_conversion(self) -> None:
"""Test token expiration check with different timezone.""" """Test token expiration check with different timezone."""
from zoneinfo import ZoneInfo
# Create a datetime in a different timezone # Create a datetime in a different timezone
eastern = ZoneInfo("US/Eastern") eastern = ZoneInfo("US/Eastern")
expires_at = datetime.now(eastern) + timedelta(hours=1) expires_at = datetime.now(eastern) + timedelta(hours=1)