feat: Update API token handling to use API-TOKEN header and improve related tests
This commit is contained in:
@@ -94,26 +94,19 @@ async def get_current_active_user(
|
||||
|
||||
async def get_current_user_api_token(
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
authorization: Annotated[str | None, Header()] = None,
|
||||
api_token_header: Annotated[str | None, Header(alias="API-TOKEN")] = None,
|
||||
) -> User:
|
||||
"""Get the current authenticated user from API token in Authorization header."""
|
||||
"""Get the current authenticated user from API token in API-TOKEN header."""
|
||||
try:
|
||||
# Check if Authorization header exists
|
||||
if not authorization:
|
||||
# Check if API-TOKEN header exists
|
||||
if not api_token_header:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authorization header required",
|
||||
detail="API-TOKEN header required",
|
||||
)
|
||||
|
||||
# Check if it's a Bearer token
|
||||
if not authorization.startswith("Bearer "):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authorization header format",
|
||||
)
|
||||
|
||||
# Extract the API token
|
||||
api_token = authorization[7:] # Remove "Bearer " prefix
|
||||
# Use the API token directly
|
||||
api_token = api_token_header.strip()
|
||||
if not api_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -158,12 +151,12 @@ async def get_current_user_api_token(
|
||||
async def get_current_user_flexible(
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
access_token: Annotated[str | None, Cookie()] = None,
|
||||
authorization: Annotated[str | None, Header()] = None,
|
||||
api_token_header: Annotated[str | None, Header(alias="API-TOKEN")] = None,
|
||||
) -> User:
|
||||
"""Get the current authenticated user from either JWT cookie or API token."""
|
||||
# Try API token first if Authorization header is present
|
||||
if authorization:
|
||||
return await get_current_user_api_token(auth_service, authorization)
|
||||
# Try API token first if API-TOKEN header is present
|
||||
if api_token_header:
|
||||
return await get_current_user_api_token(auth_service, api_token_header)
|
||||
|
||||
# Fall back to JWT cookie authentication
|
||||
return await get_current_user(auth_service, access_token)
|
||||
|
||||
@@ -11,7 +11,9 @@ class UserRegisterRequest(BaseModel):
|
||||
|
||||
email: EmailStr = Field(..., description="User email address")
|
||||
password: str = Field(
|
||||
..., min_length=8, description="User password (minimum 8 characters)",
|
||||
...,
|
||||
min_length=8,
|
||||
description="User password (minimum 8 characters)",
|
||||
)
|
||||
name: str = Field(..., min_length=1, max_length=100, description="User full name")
|
||||
|
||||
@@ -68,7 +70,7 @@ class ApiTokenResponse(BaseModel):
|
||||
"""Schema for API token response."""
|
||||
|
||||
api_token: str = Field(..., description="Generated API token")
|
||||
expires_at: datetime = Field(..., description="Token expiration timestamp")
|
||||
expires_at: datetime | None = Field(None, description="Token expiration timestamp")
|
||||
|
||||
|
||||
class ApiTokenStatusResponse(BaseModel):
|
||||
|
||||
@@ -239,7 +239,7 @@ class TestApiTokenEndpoints:
|
||||
api_token = token_response.json()["api_token"]
|
||||
|
||||
# Use API token to authenticate
|
||||
headers = {"Authorization": f"Bearer {api_token}"}
|
||||
headers = {"API-TOKEN": api_token}
|
||||
response = await client.get("/api/v1/auth/me", headers=headers)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -250,7 +250,7 @@ class TestApiTokenEndpoints:
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_token_authentication_invalid_token(self, client: AsyncClient):
|
||||
"""Test authentication with invalid API token."""
|
||||
headers = {"Authorization": "Bearer invalid_token"}
|
||||
headers = {"API-TOKEN": "invalid_token"}
|
||||
response = await client.get("/api/v1/auth/me", headers=headers)
|
||||
|
||||
assert response.status_code == 401
|
||||
@@ -271,7 +271,7 @@ class TestApiTokenEndpoints:
|
||||
|
||||
# Mock expired token
|
||||
with patch("app.utils.auth.TokenUtils.is_token_expired", return_value=True):
|
||||
headers = {"Authorization": f"Bearer {api_token}"}
|
||||
headers = {"API-TOKEN": api_token}
|
||||
response = await client.get("/api/v1/auth/me", headers=headers)
|
||||
|
||||
assert response.status_code == 401
|
||||
@@ -279,18 +279,18 @@ class TestApiTokenEndpoints:
|
||||
assert "API token has expired" in data["detail"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_token_authentication_malformed_header(self, client: AsyncClient):
|
||||
"""Test authentication with malformed Authorization header."""
|
||||
# Missing Bearer prefix
|
||||
headers = {"Authorization": "invalid_format"}
|
||||
async def test_api_token_authentication_empty_token(self, client: AsyncClient):
|
||||
"""Test authentication with empty API-TOKEN header."""
|
||||
# Empty token
|
||||
headers = {"API-TOKEN": ""}
|
||||
response = await client.get("/api/v1/auth/me", headers=headers)
|
||||
|
||||
assert response.status_code == 401
|
||||
data = response.json()
|
||||
assert "Invalid authorization header format" in data["detail"]
|
||||
assert "Could not validate credentials" in data["detail"]
|
||||
|
||||
# Empty token
|
||||
headers = {"Authorization": "Bearer "}
|
||||
# Whitespace only token
|
||||
headers = {"API-TOKEN": " "}
|
||||
response = await client.get("/api/v1/auth/me", headers=headers)
|
||||
|
||||
assert response.status_code == 401
|
||||
@@ -313,7 +313,7 @@ class TestApiTokenEndpoints:
|
||||
authenticated_user.is_active = False
|
||||
|
||||
# Try to authenticate with API token
|
||||
headers = {"Authorization": f"Bearer {api_token}"}
|
||||
headers = {"API-TOKEN": api_token}
|
||||
response = await client.get("/api/v1/auth/me", headers=headers)
|
||||
|
||||
assert response.status_code == 401
|
||||
@@ -332,9 +332,9 @@ class TestApiTokenEndpoints:
|
||||
)
|
||||
api_token = token_response.json()["api_token"]
|
||||
|
||||
# Set both cookies and Authorization header
|
||||
# Set both cookies and API-TOKEN header
|
||||
client.cookies.update(auth_cookies)
|
||||
headers = {"Authorization": f"Bearer {api_token}"}
|
||||
headers = {"API-TOKEN": api_token}
|
||||
|
||||
# This should use API token authentication
|
||||
response = await client.get("/api/v1/auth/me", headers=headers)
|
||||
|
||||
@@ -41,40 +41,40 @@ class TestApiTokenDependencies:
|
||||
"""Test successful API token authentication."""
|
||||
mock_auth_service.get_user_by_api_token.return_value = test_user
|
||||
|
||||
authorization = "Bearer test_api_token_123"
|
||||
api_token_header = "test_api_token_123"
|
||||
|
||||
result = await get_current_user_api_token(mock_auth_service, authorization)
|
||||
result = await get_current_user_api_token(mock_auth_service, api_token_header)
|
||||
|
||||
assert result == test_user
|
||||
mock_auth_service.get_user_by_api_token.assert_called_once_with("test_api_token_123")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_api_token_no_header(self, mock_auth_service):
|
||||
"""Test API token authentication without Authorization header."""
|
||||
"""Test API token authentication without API-TOKEN header."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_api_token(mock_auth_service, None)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Authorization header required" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_api_token_invalid_format(self, mock_auth_service):
|
||||
"""Test API token authentication with invalid header format."""
|
||||
authorization = "Invalid format"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_api_token(mock_auth_service, authorization)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid authorization header format" in exc_info.value.detail
|
||||
assert "API-TOKEN header required" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_api_token_empty_token(self, mock_auth_service):
|
||||
"""Test API token authentication with empty token."""
|
||||
authorization = "Bearer "
|
||||
api_token_header = " "
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_api_token(mock_auth_service, authorization)
|
||||
await get_current_user_api_token(mock_auth_service, api_token_header)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "API token required" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_api_token_whitespace_token(self, mock_auth_service):
|
||||
"""Test API token authentication with whitespace-only token."""
|
||||
api_token_header = " "
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_api_token(mock_auth_service, api_token_header)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "API token required" in exc_info.value.detail
|
||||
@@ -84,10 +84,10 @@ class TestApiTokenDependencies:
|
||||
"""Test API token authentication with invalid token."""
|
||||
mock_auth_service.get_user_by_api_token.return_value = None
|
||||
|
||||
authorization = "Bearer invalid_token"
|
||||
api_token_header = "invalid_token"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_api_token(mock_auth_service, authorization)
|
||||
await get_current_user_api_token(mock_auth_service, api_token_header)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid API token" in exc_info.value.detail
|
||||
@@ -101,10 +101,10 @@ class TestApiTokenDependencies:
|
||||
test_user.api_token_expires_at = datetime.now(UTC) - timedelta(days=1)
|
||||
mock_auth_service.get_user_by_api_token.return_value = test_user
|
||||
|
||||
authorization = "Bearer expired_token"
|
||||
api_token_header = "expired_token"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_api_token(mock_auth_service, authorization)
|
||||
await get_current_user_api_token(mock_auth_service, api_token_header)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "API token has expired" in exc_info.value.detail
|
||||
@@ -117,10 +117,10 @@ class TestApiTokenDependencies:
|
||||
test_user.is_active = False
|
||||
mock_auth_service.get_user_by_api_token.return_value = test_user
|
||||
|
||||
authorization = "Bearer test_token"
|
||||
api_token_header = "test_token"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_api_token(mock_auth_service, authorization)
|
||||
await get_current_user_api_token(mock_auth_service, api_token_header)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Account is deactivated" in exc_info.value.detail
|
||||
@@ -130,10 +130,10 @@ class TestApiTokenDependencies:
|
||||
"""Test API token authentication with service exception."""
|
||||
mock_auth_service.get_user_by_api_token.side_effect = Exception("Database error")
|
||||
|
||||
authorization = "Bearer test_token"
|
||||
api_token_header = "test_token"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_api_token(mock_auth_service, authorization)
|
||||
await get_current_user_api_token(mock_auth_service, api_token_header)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Could not validate API token" in exc_info.value.detail
|
||||
@@ -145,11 +145,11 @@ class TestApiTokenDependencies:
|
||||
"""Test flexible authentication uses API token when available."""
|
||||
mock_auth_service.get_user_by_api_token.return_value = test_user
|
||||
|
||||
authorization = "Bearer test_api_token_123"
|
||||
api_token_header = "test_api_token_123"
|
||||
access_token = "jwt_token"
|
||||
|
||||
result = await get_current_user_flexible(
|
||||
mock_auth_service, access_token, authorization,
|
||||
mock_auth_service, access_token, api_token_header,
|
||||
)
|
||||
|
||||
assert result == test_user
|
||||
@@ -170,22 +170,20 @@ class TestApiTokenDependencies:
|
||||
test_user.api_token_expires_at = None
|
||||
mock_auth_service.get_user_by_api_token.return_value = test_user
|
||||
|
||||
authorization = "Bearer test_token"
|
||||
api_token_header = "test_token"
|
||||
|
||||
result = await get_current_user_api_token(mock_auth_service, authorization)
|
||||
result = await get_current_user_api_token(mock_auth_service, api_token_header)
|
||||
|
||||
assert result == test_user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_token_bearer_case_insensitive(self, mock_auth_service, test_user):
|
||||
"""Test that Bearer prefix is case-sensitive (as per OAuth2 spec)."""
|
||||
async def test_api_token_with_whitespace(self, mock_auth_service, test_user):
|
||||
"""Test API token with leading/trailing whitespace is handled correctly."""
|
||||
mock_auth_service.get_user_by_api_token.return_value = test_user
|
||||
|
||||
# lowercase bearer should fail
|
||||
authorization = "bearer test_token"
|
||||
api_token_header = " test_token "
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_api_token(mock_auth_service, authorization)
|
||||
result = await get_current_user_api_token(mock_auth_service, api_token_header)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid authorization header format" in exc_info.value.detail
|
||||
assert result == test_user
|
||||
mock_auth_service.get_user_by_api_token.assert_called_once_with("test_token")
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
"""Tests for token utilities."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from app.utils.auth import TokenUtils
|
||||
|
||||
TOKEN_LENGTH = 43 # Length of URL-safe base64 encoded 32-byte token
|
||||
UNIQUE_TOKENS_COUNT = 10 # Number of unique tokens to generate for uniqueness test
|
||||
|
||||
|
||||
class TestTokenUtils:
|
||||
"""Test token utility functions."""
|
||||
|
||||
def test_generate_api_token(self):
|
||||
def test_generate_api_token(self) -> None:
|
||||
"""Test API token generation."""
|
||||
token = TokenUtils.generate_api_token()
|
||||
|
||||
@@ -19,44 +23,44 @@ class TestTokenUtils:
|
||||
assert len(token) > 0
|
||||
|
||||
# Should be URL-safe base64 (43 characters for 32 bytes)
|
||||
assert len(token) == 43
|
||||
assert len(token) == TOKEN_LENGTH
|
||||
|
||||
# Should be unique (generate multiple and check they're different)
|
||||
tokens = [TokenUtils.generate_api_token() for _ in range(10)]
|
||||
assert len(set(tokens)) == 10
|
||||
tokens = [TokenUtils.generate_api_token() for _ in range(UNIQUE_TOKENS_COUNT)]
|
||||
assert len(set(tokens)) == UNIQUE_TOKENS_COUNT
|
||||
|
||||
def test_is_token_expired_none(self):
|
||||
def test_is_token_expired_none(self) -> None:
|
||||
"""Test token expiration check with None expires_at."""
|
||||
result = TokenUtils.is_token_expired(None)
|
||||
assert result is False
|
||||
|
||||
def test_is_token_expired_future_naive(self):
|
||||
def test_is_token_expired_future_naive(self) -> None:
|
||||
"""Test token expiration check with future naive datetime."""
|
||||
# Use UTC time for naive datetime (as the function assumes)
|
||||
expires_at = datetime.utcnow() + timedelta(hours=1)
|
||||
expires_at = datetime.now(UTC) + timedelta(hours=1)
|
||||
result = TokenUtils.is_token_expired(expires_at)
|
||||
assert result is False
|
||||
|
||||
def test_is_token_expired_past_naive(self):
|
||||
def test_is_token_expired_past_naive(self) -> None:
|
||||
"""Test token expiration check with past naive datetime."""
|
||||
# Use UTC time for naive datetime (as the function assumes)
|
||||
expires_at = datetime.utcnow() - timedelta(hours=1)
|
||||
expires_at = datetime.now(UTC) - timedelta(hours=1)
|
||||
result = TokenUtils.is_token_expired(expires_at)
|
||||
assert result is True
|
||||
|
||||
def test_is_token_expired_future_aware(self):
|
||||
def test_is_token_expired_future_aware(self) -> None:
|
||||
"""Test token expiration check with future timezone-aware datetime."""
|
||||
expires_at = datetime.now(UTC) + timedelta(hours=1)
|
||||
result = TokenUtils.is_token_expired(expires_at)
|
||||
assert result is False
|
||||
|
||||
def test_is_token_expired_past_aware(self):
|
||||
def test_is_token_expired_past_aware(self) -> None:
|
||||
"""Test token expiration check with past timezone-aware datetime."""
|
||||
expires_at = datetime.now(UTC) - timedelta(hours=1)
|
||||
result = TokenUtils.is_token_expired(expires_at)
|
||||
assert result is True
|
||||
|
||||
def test_is_token_expired_edge_case_now(self):
|
||||
def test_is_token_expired_edge_case_now(self) -> None:
|
||||
"""Test token expiration check with time very close to now."""
|
||||
# Token expires in 1 second
|
||||
expires_at = datetime.now(UTC) + timedelta(seconds=1)
|
||||
@@ -68,10 +72,8 @@ class TestTokenUtils:
|
||||
result = TokenUtils.is_token_expired(expires_at)
|
||||
assert result is True
|
||||
|
||||
def test_is_token_expired_timezone_conversion(self):
|
||||
def test_is_token_expired_timezone_conversion(self) -> None:
|
||||
"""Test token expiration check with different timezone."""
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
# Create a datetime in a different timezone
|
||||
eastern = ZoneInfo("US/Eastern")
|
||||
expires_at = datetime.now(eastern) + timedelta(hours=1)
|
||||
|
||||
Reference in New Issue
Block a user