feat: Update API token handling to use API-TOKEN header and improve related tests

This commit is contained in:
JSC
2025-07-27 22:15:23 +02:00
parent 3dc21337f9
commit 58030914e6
5 changed files with 80 additions and 85 deletions

View File

@@ -239,7 +239,7 @@ class TestApiTokenEndpoints:
api_token = token_response.json()["api_token"]
# Use API token to authenticate
headers = {"Authorization": f"Bearer {api_token}"}
headers = {"API-TOKEN": api_token}
response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 200
@@ -250,7 +250,7 @@ class TestApiTokenEndpoints:
@pytest.mark.asyncio
async def test_api_token_authentication_invalid_token(self, client: AsyncClient):
"""Test authentication with invalid API token."""
headers = {"Authorization": "Bearer invalid_token"}
headers = {"API-TOKEN": "invalid_token"}
response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 401
@@ -271,7 +271,7 @@ class TestApiTokenEndpoints:
# Mock expired token
with patch("app.utils.auth.TokenUtils.is_token_expired", return_value=True):
headers = {"Authorization": f"Bearer {api_token}"}
headers = {"API-TOKEN": api_token}
response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 401
@@ -279,18 +279,18 @@ class TestApiTokenEndpoints:
assert "API token has expired" in data["detail"]
@pytest.mark.asyncio
async def test_api_token_authentication_malformed_header(self, client: AsyncClient):
"""Test authentication with malformed Authorization header."""
# Missing Bearer prefix
headers = {"Authorization": "invalid_format"}
async def test_api_token_authentication_empty_token(self, client: AsyncClient):
"""Test authentication with empty API-TOKEN header."""
# Empty token
headers = {"API-TOKEN": ""}
response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 401
data = response.json()
assert "Invalid authorization header format" in data["detail"]
assert "Could not validate credentials" in data["detail"]
# Empty token
headers = {"Authorization": "Bearer "}
# Whitespace only token
headers = {"API-TOKEN": " "}
response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 401
@@ -313,7 +313,7 @@ class TestApiTokenEndpoints:
authenticated_user.is_active = False
# Try to authenticate with API token
headers = {"Authorization": f"Bearer {api_token}"}
headers = {"API-TOKEN": api_token}
response = await client.get("/api/v1/auth/me", headers=headers)
assert response.status_code == 401
@@ -332,9 +332,9 @@ class TestApiTokenEndpoints:
)
api_token = token_response.json()["api_token"]
# Set both cookies and Authorization header
# Set both cookies and API-TOKEN header
client.cookies.update(auth_cookies)
headers = {"Authorization": f"Bearer {api_token}"}
headers = {"API-TOKEN": api_token}
# This should use API token authentication
response = await client.get("/api/v1/auth/me", headers=headers)

View File

@@ -41,40 +41,40 @@ class TestApiTokenDependencies:
"""Test successful API token authentication."""
mock_auth_service.get_user_by_api_token.return_value = test_user
authorization = "Bearer test_api_token_123"
api_token_header = "test_api_token_123"
result = await get_current_user_api_token(mock_auth_service, authorization)
result = await get_current_user_api_token(mock_auth_service, api_token_header)
assert result == test_user
mock_auth_service.get_user_by_api_token.assert_called_once_with("test_api_token_123")
@pytest.mark.asyncio
async def test_get_current_user_api_token_no_header(self, mock_auth_service):
"""Test API token authentication without Authorization header."""
"""Test API token authentication without API-TOKEN header."""
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, None)
assert exc_info.value.status_code == 401
assert "Authorization header required" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_current_user_api_token_invalid_format(self, mock_auth_service):
"""Test API token authentication with invalid header format."""
authorization = "Invalid format"
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization)
assert exc_info.value.status_code == 401
assert "Invalid authorization header format" in exc_info.value.detail
assert "API-TOKEN header required" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_current_user_api_token_empty_token(self, mock_auth_service):
"""Test API token authentication with empty token."""
authorization = "Bearer "
api_token_header = " "
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization)
await get_current_user_api_token(mock_auth_service, api_token_header)
assert exc_info.value.status_code == 401
assert "API token required" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_current_user_api_token_whitespace_token(self, mock_auth_service):
"""Test API token authentication with whitespace-only token."""
api_token_header = " "
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, api_token_header)
assert exc_info.value.status_code == 401
assert "API token required" in exc_info.value.detail
@@ -84,10 +84,10 @@ class TestApiTokenDependencies:
"""Test API token authentication with invalid token."""
mock_auth_service.get_user_by_api_token.return_value = None
authorization = "Bearer invalid_token"
api_token_header = "invalid_token"
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization)
await get_current_user_api_token(mock_auth_service, api_token_header)
assert exc_info.value.status_code == 401
assert "Invalid API token" in exc_info.value.detail
@@ -101,10 +101,10 @@ class TestApiTokenDependencies:
test_user.api_token_expires_at = datetime.now(UTC) - timedelta(days=1)
mock_auth_service.get_user_by_api_token.return_value = test_user
authorization = "Bearer expired_token"
api_token_header = "expired_token"
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization)
await get_current_user_api_token(mock_auth_service, api_token_header)
assert exc_info.value.status_code == 401
assert "API token has expired" in exc_info.value.detail
@@ -117,10 +117,10 @@ class TestApiTokenDependencies:
test_user.is_active = False
mock_auth_service.get_user_by_api_token.return_value = test_user
authorization = "Bearer test_token"
api_token_header = "test_token"
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization)
await get_current_user_api_token(mock_auth_service, api_token_header)
assert exc_info.value.status_code == 401
assert "Account is deactivated" in exc_info.value.detail
@@ -130,10 +130,10 @@ class TestApiTokenDependencies:
"""Test API token authentication with service exception."""
mock_auth_service.get_user_by_api_token.side_effect = Exception("Database error")
authorization = "Bearer test_token"
api_token_header = "test_token"
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization)
await get_current_user_api_token(mock_auth_service, api_token_header)
assert exc_info.value.status_code == 401
assert "Could not validate API token" in exc_info.value.detail
@@ -145,11 +145,11 @@ class TestApiTokenDependencies:
"""Test flexible authentication uses API token when available."""
mock_auth_service.get_user_by_api_token.return_value = test_user
authorization = "Bearer test_api_token_123"
api_token_header = "test_api_token_123"
access_token = "jwt_token"
result = await get_current_user_flexible(
mock_auth_service, access_token, authorization,
mock_auth_service, access_token, api_token_header,
)
assert result == test_user
@@ -170,22 +170,20 @@ class TestApiTokenDependencies:
test_user.api_token_expires_at = None
mock_auth_service.get_user_by_api_token.return_value = test_user
authorization = "Bearer test_token"
api_token_header = "test_token"
result = await get_current_user_api_token(mock_auth_service, authorization)
result = await get_current_user_api_token(mock_auth_service, api_token_header)
assert result == test_user
@pytest.mark.asyncio
async def test_api_token_bearer_case_insensitive(self, mock_auth_service, test_user):
"""Test that Bearer prefix is case-sensitive (as per OAuth2 spec)."""
async def test_api_token_with_whitespace(self, mock_auth_service, test_user):
"""Test API token with leading/trailing whitespace is handled correctly."""
mock_auth_service.get_user_by_api_token.return_value = test_user
# lowercase bearer should fail
authorization = "bearer test_token"
api_token_header = " test_token "
with pytest.raises(HTTPException) as exc_info:
await get_current_user_api_token(mock_auth_service, authorization)
result = await get_current_user_api_token(mock_auth_service, api_token_header)
assert exc_info.value.status_code == 401
assert "Invalid authorization header format" in exc_info.value.detail
assert result == test_user
mock_auth_service.get_user_by_api_token.assert_called_once_with("test_token")

View File

@@ -1,14 +1,18 @@
"""Tests for token utilities."""
from datetime import UTC, datetime, timedelta
from zoneinfo import ZoneInfo
from app.utils.auth import TokenUtils
TOKEN_LENGTH = 43 # Length of URL-safe base64 encoded 32-byte token
UNIQUE_TOKENS_COUNT = 10 # Number of unique tokens to generate for uniqueness test
class TestTokenUtils:
"""Test token utility functions."""
def test_generate_api_token(self):
def test_generate_api_token(self) -> None:
"""Test API token generation."""
token = TokenUtils.generate_api_token()
@@ -19,44 +23,44 @@ class TestTokenUtils:
assert len(token) > 0
# Should be URL-safe base64 (43 characters for 32 bytes)
assert len(token) == 43
assert len(token) == TOKEN_LENGTH
# Should be unique (generate multiple and check they're different)
tokens = [TokenUtils.generate_api_token() for _ in range(10)]
assert len(set(tokens)) == 10
tokens = [TokenUtils.generate_api_token() for _ in range(UNIQUE_TOKENS_COUNT)]
assert len(set(tokens)) == UNIQUE_TOKENS_COUNT
def test_is_token_expired_none(self):
def test_is_token_expired_none(self) -> None:
"""Test token expiration check with None expires_at."""
result = TokenUtils.is_token_expired(None)
assert result is False
def test_is_token_expired_future_naive(self):
def test_is_token_expired_future_naive(self) -> None:
"""Test token expiration check with future naive datetime."""
# Use UTC time for naive datetime (as the function assumes)
expires_at = datetime.utcnow() + timedelta(hours=1)
expires_at = datetime.now(UTC) + timedelta(hours=1)
result = TokenUtils.is_token_expired(expires_at)
assert result is False
def test_is_token_expired_past_naive(self):
def test_is_token_expired_past_naive(self) -> None:
"""Test token expiration check with past naive datetime."""
# Use UTC time for naive datetime (as the function assumes)
expires_at = datetime.utcnow() - timedelta(hours=1)
expires_at = datetime.now(UTC) - timedelta(hours=1)
result = TokenUtils.is_token_expired(expires_at)
assert result is True
def test_is_token_expired_future_aware(self):
def test_is_token_expired_future_aware(self) -> None:
"""Test token expiration check with future timezone-aware datetime."""
expires_at = datetime.now(UTC) + timedelta(hours=1)
result = TokenUtils.is_token_expired(expires_at)
assert result is False
def test_is_token_expired_past_aware(self):
def test_is_token_expired_past_aware(self) -> None:
"""Test token expiration check with past timezone-aware datetime."""
expires_at = datetime.now(UTC) - timedelta(hours=1)
result = TokenUtils.is_token_expired(expires_at)
assert result is True
def test_is_token_expired_edge_case_now(self):
def test_is_token_expired_edge_case_now(self) -> None:
"""Test token expiration check with time very close to now."""
# Token expires in 1 second
expires_at = datetime.now(UTC) + timedelta(seconds=1)
@@ -68,10 +72,8 @@ class TestTokenUtils:
result = TokenUtils.is_token_expired(expires_at)
assert result is True
def test_is_token_expired_timezone_conversion(self):
def test_is_token_expired_timezone_conversion(self) -> None:
"""Test token expiration check with different timezone."""
from zoneinfo import ZoneInfo
# Create a datetime in a different timezone
eastern = ZoneInfo("US/Eastern")
expires_at = datetime.now(eastern) + timedelta(hours=1)