Compare commits
2 Commits
502feea035
...
dc29915fbc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc29915fbc | ||
|
|
389cfe2d6a |
1
tests/core/__init__.py
Normal file
1
tests/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for core module."""
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Tests for API token authentication dependencies."""
|
||||
# ruff: noqa: S106
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock
|
||||
@@ -10,17 +11,20 @@ from app.core.dependencies import get_current_user_api_token, get_current_user_f
|
||||
from app.models.user import User
|
||||
from app.services.auth import AuthService
|
||||
|
||||
# Constants
|
||||
HTTP_401_UNAUTHORIZED = 401
|
||||
|
||||
|
||||
class TestApiTokenDependencies:
|
||||
"""Test API token authentication dependencies."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_service(self):
|
||||
def mock_auth_service(self) -> AsyncMock:
|
||||
"""Create a mock auth service."""
|
||||
return AsyncMock(spec=AuthService)
|
||||
|
||||
@pytest.fixture
|
||||
def test_user(self):
|
||||
def test_user(self) -> User:
|
||||
"""Create a test user."""
|
||||
return User(
|
||||
id=1,
|
||||
@@ -37,9 +41,9 @@ class TestApiTokenDependencies:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_api_token_success(
|
||||
self,
|
||||
mock_auth_service,
|
||||
test_user,
|
||||
):
|
||||
mock_auth_service: AsyncMock,
|
||||
test_user: User,
|
||||
) -> None:
|
||||
"""Test successful API token authentication."""
|
||||
mock_auth_service.get_user_by_api_token.return_value = test_user
|
||||
|
||||
@@ -53,38 +57,46 @@ class TestApiTokenDependencies:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_api_token_no_header(self, mock_auth_service):
|
||||
async def test_get_current_user_api_token_no_header(
|
||||
self, mock_auth_service: AsyncMock,
|
||||
) -> None:
|
||||
"""Test API token authentication without API-TOKEN header."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_api_token(mock_auth_service, None)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert "API-TOKEN header required" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_api_token_empty_token(self, mock_auth_service):
|
||||
async def test_get_current_user_api_token_empty_token(
|
||||
self, mock_auth_service: AsyncMock,
|
||||
) -> None:
|
||||
"""Test API token authentication with empty token."""
|
||||
api_token_header = " "
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_api_token(mock_auth_service, api_token_header)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert "API token required" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_api_token_whitespace_token(self, mock_auth_service):
|
||||
async def test_get_current_user_api_token_whitespace_token(
|
||||
self, mock_auth_service: AsyncMock,
|
||||
) -> None:
|
||||
"""Test API token authentication with whitespace-only token."""
|
||||
api_token_header = " "
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_api_token(mock_auth_service, api_token_header)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert "API token required" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_api_token_invalid_token(self, mock_auth_service):
|
||||
async def test_get_current_user_api_token_invalid_token(
|
||||
self, mock_auth_service: AsyncMock,
|
||||
) -> None:
|
||||
"""Test API token authentication with invalid token."""
|
||||
mock_auth_service.get_user_by_api_token.return_value = None
|
||||
|
||||
@@ -93,15 +105,15 @@ class TestApiTokenDependencies:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_api_token(mock_auth_service, api_token_header)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert "Invalid API token" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_api_token_expired_token(
|
||||
self,
|
||||
mock_auth_service,
|
||||
test_user,
|
||||
):
|
||||
mock_auth_service: AsyncMock,
|
||||
test_user: User,
|
||||
) -> None:
|
||||
"""Test API token authentication with expired token."""
|
||||
# Set expired token
|
||||
test_user.api_token_expires_at = datetime.now(UTC) - timedelta(days=1)
|
||||
@@ -112,15 +124,15 @@ class TestApiTokenDependencies:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_api_token(mock_auth_service, api_token_header)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert "API token has expired" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_api_token_inactive_user(
|
||||
self,
|
||||
mock_auth_service,
|
||||
test_user,
|
||||
):
|
||||
mock_auth_service: AsyncMock,
|
||||
test_user: User,
|
||||
) -> None:
|
||||
"""Test API token authentication with inactive user."""
|
||||
test_user.is_active = False
|
||||
mock_auth_service.get_user_by_api_token.return_value = test_user
|
||||
@@ -130,13 +142,13 @@ class TestApiTokenDependencies:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_api_token(mock_auth_service, api_token_header)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert "Account is deactivated" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_api_token_service_exception(
|
||||
self, mock_auth_service,
|
||||
):
|
||||
self, mock_auth_service: AsyncMock,
|
||||
) -> None:
|
||||
"""Test API token authentication with service exception."""
|
||||
mock_auth_service.get_user_by_api_token.side_effect = Exception(
|
||||
"Database error",
|
||||
@@ -147,15 +159,15 @@ class TestApiTokenDependencies:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_api_token(mock_auth_service, api_token_header)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert "Could not validate API token" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_flexible_uses_api_token(
|
||||
self,
|
||||
mock_auth_service,
|
||||
test_user,
|
||||
):
|
||||
mock_auth_service: AsyncMock,
|
||||
test_user: User,
|
||||
) -> None:
|
||||
"""Test flexible authentication uses API token when available."""
|
||||
mock_auth_service.get_user_by_api_token.return_value = test_user
|
||||
|
||||
@@ -174,18 +186,20 @@ class TestApiTokenDependencies:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_flexible_falls_back_to_jwt(self, mock_auth_service):
|
||||
async def test_get_current_user_flexible_falls_back_to_jwt(
|
||||
self, mock_auth_service: AsyncMock,
|
||||
) -> None:
|
||||
"""Test flexible authentication falls back to JWT when no API token."""
|
||||
# Mock the get_current_user function (normally imported)
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(Exception, match="Database error|Could not validate"):
|
||||
# This will fail because we can't easily mock the get_current_user import
|
||||
# In a real test, you'd mock the import or use dependency injection
|
||||
await get_current_user_flexible(mock_auth_service, "jwt_token", None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_token_no_expiry_never_expires(
|
||||
self, mock_auth_service, test_user,
|
||||
):
|
||||
self, mock_auth_service: AsyncMock, test_user: User,
|
||||
) -> None:
|
||||
"""Test API token with no expiry date never expires."""
|
||||
test_user.api_token_expires_at = None
|
||||
mock_auth_service.get_user_by_api_token.return_value = test_user
|
||||
@@ -197,7 +211,9 @@ class TestApiTokenDependencies:
|
||||
assert result == test_user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_token_with_whitespace(self, mock_auth_service, test_user):
|
||||
async def test_api_token_with_whitespace(
|
||||
self, mock_auth_service: AsyncMock, test_user: User,
|
||||
) -> None:
|
||||
"""Test API token with leading/trailing whitespace is handled correctly."""
|
||||
mock_auth_service.get_user_by_api_token.return_value = test_user
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Tests for credit transaction repository."""
|
||||
# ruff: noqa: ARG002, E501
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -11,6 +12,18 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from app.models.credit_transaction import CreditTransaction
|
||||
from app.models.user import User
|
||||
from app.repositories.credit_transaction import CreditTransactionRepository
|
||||
from app.repositories.user import UserRepository
|
||||
|
||||
# Constants
|
||||
EXPECTED_TRANSACTION_COUNT = 4
|
||||
PAGE_SIZE = 2
|
||||
MIN_VLC_TRANSACTIONS = 2
|
||||
MIN_SUCCESSFUL_TRANSACTIONS = 3
|
||||
SUCCESSFUL_TRANSACTION_COUNT = 3
|
||||
MIN_ALL_TRANSACTIONS = 5
|
||||
TEST_AMOUNT = -10
|
||||
TEST_BALANCE_BEFORE = 100
|
||||
TEST_BALANCE_AFTER = 90
|
||||
|
||||
|
||||
class TestCreditTransactionRepository:
|
||||
@@ -102,11 +115,9 @@ class TestCreditTransactionRepository:
|
||||
async def other_user_transaction(
|
||||
self,
|
||||
test_session: AsyncSession,
|
||||
ensure_plans: tuple[Any, ...], # noqa: ARG002
|
||||
ensure_plans: tuple[Any, ...],
|
||||
) -> AsyncGenerator[CreditTransaction, None]:
|
||||
"""Create a transaction for a different user."""
|
||||
from app.repositories.user import UserRepository
|
||||
|
||||
# Create another user
|
||||
user_repo = UserRepository(test_session)
|
||||
other_user_data = {
|
||||
@@ -174,7 +185,7 @@ class TestCreditTransactionRepository:
|
||||
transactions = await credit_transaction_repository.get_by_user_id(test_user_id)
|
||||
|
||||
# Should return all transactions for test_user
|
||||
assert len(transactions) == 4
|
||||
assert len(transactions) == EXPECTED_TRANSACTION_COUNT
|
||||
# Should be ordered by created_at desc (newest first)
|
||||
assert all(t.user_id == test_user_id for t in transactions)
|
||||
|
||||
@@ -194,13 +205,13 @@ class TestCreditTransactionRepository:
|
||||
first_page = await credit_transaction_repository.get_by_user_id(
|
||||
test_user_id, limit=2, offset=0,
|
||||
)
|
||||
assert len(first_page) == 2
|
||||
assert len(first_page) == PAGE_SIZE
|
||||
|
||||
# Get next 2 transactions
|
||||
second_page = await credit_transaction_repository.get_by_user_id(
|
||||
test_user_id, limit=2, offset=2,
|
||||
)
|
||||
assert len(second_page) == 2
|
||||
assert len(second_page) == PAGE_SIZE
|
||||
|
||||
# Should not overlap
|
||||
first_page_ids = {t.id for t in first_page}
|
||||
@@ -219,11 +230,13 @@ class TestCreditTransactionRepository:
|
||||
)
|
||||
|
||||
# Should return 2 VLC transactions (1 successful, 1 failed)
|
||||
assert len(vlc_transactions) >= 2
|
||||
assert len(vlc_transactions) >= MIN_VLC_TRANSACTIONS
|
||||
assert all(t.action_type == "vlc_play_sound" for t in vlc_transactions)
|
||||
|
||||
extraction_transactions = await credit_transaction_repository.get_by_action_type(
|
||||
"audio_extraction",
|
||||
extraction_transactions = (
|
||||
await credit_transaction_repository.get_by_action_type(
|
||||
"audio_extraction",
|
||||
)
|
||||
)
|
||||
|
||||
# Should return 1 extraction transaction
|
||||
@@ -262,7 +275,7 @@ class TestCreditTransactionRepository:
|
||||
# Should only return successful transactions
|
||||
assert all(t.success is True for t in successful_transactions)
|
||||
# Should be at least 3 (vlc_play_sound, audio_extraction, credit_addition)
|
||||
assert len(successful_transactions) >= 3
|
||||
assert len(successful_transactions) >= MIN_SUCCESSFUL_TRANSACTIONS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_successful_transactions_by_user(
|
||||
@@ -281,7 +294,7 @@ class TestCreditTransactionRepository:
|
||||
assert all(t.success is True for t in successful_transactions)
|
||||
assert all(t.user_id == test_user_id for t in successful_transactions)
|
||||
# Should be 3 successful transactions for test_user
|
||||
assert len(successful_transactions) == 3
|
||||
assert len(successful_transactions) == SUCCESSFUL_TRANSACTION_COUNT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_successful_transactions_with_pagination(
|
||||
@@ -295,7 +308,7 @@ class TestCreditTransactionRepository:
|
||||
first_page = await credit_transaction_repository.get_successful_transactions(
|
||||
user_id=test_user_id, limit=2, offset=0,
|
||||
)
|
||||
assert len(first_page) == 2
|
||||
assert len(first_page) == PAGE_SIZE
|
||||
assert all(t.success is True for t in first_page)
|
||||
|
||||
# Get next successful transaction
|
||||
@@ -316,7 +329,7 @@ class TestCreditTransactionRepository:
|
||||
all_transactions = await credit_transaction_repository.get_all()
|
||||
|
||||
# Should return all transactions
|
||||
assert len(all_transactions) >= 5 # 4 from test_transactions + 1 other_user_transaction
|
||||
assert len(all_transactions) >= MIN_ALL_TRANSACTIONS # 4 from test_transactions + 1 other_user_transaction
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_transaction(
|
||||
@@ -341,9 +354,9 @@ class TestCreditTransactionRepository:
|
||||
assert transaction.id is not None
|
||||
assert transaction.user_id == test_user_id
|
||||
assert transaction.action_type == "test_action"
|
||||
assert transaction.amount == -10
|
||||
assert transaction.balance_before == 100
|
||||
assert transaction.balance_after == 90
|
||||
assert transaction.amount == TEST_AMOUNT
|
||||
assert transaction.balance_before == TEST_BALANCE_BEFORE
|
||||
assert transaction.balance_after == TEST_BALANCE_AFTER
|
||||
assert transaction.success is True
|
||||
assert transaction.metadata_json is not None
|
||||
assert json.loads(transaction.metadata_json) == {"test": "data"}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Tests for extraction repository."""
|
||||
# ruff: noqa: ANN001, ANN201
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
@@ -8,6 +9,9 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from app.models.extraction import Extraction
|
||||
from app.repositories.extraction import ExtractionRepository
|
||||
|
||||
# Constants
|
||||
TEST_SOUND_ID = 42
|
||||
|
||||
|
||||
class TestExtractionRepository:
|
||||
"""Test extraction repository."""
|
||||
@@ -123,6 +127,6 @@ class TestExtractionRepository:
|
||||
result = await extraction_repo.update(extraction, update_data)
|
||||
|
||||
assert result.status == "completed"
|
||||
assert result.sound_id == 42
|
||||
assert result.sound_id == TEST_SOUND_ID
|
||||
extraction_repo.session.commit.assert_called_once()
|
||||
extraction_repo.session.refresh.assert_called_once_with(extraction)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Tests for playlist repository."""
|
||||
# ruff: noqa: PLR2004, ANN401
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
@@ -10,6 +12,16 @@ from app.models.playlist import Playlist
|
||||
from app.models.sound import Sound
|
||||
from app.models.user import User
|
||||
from app.repositories.playlist import PlaylistRepository
|
||||
from app.utils.auth import PasswordUtils
|
||||
|
||||
# Constants
|
||||
TEST_POSITION = 5
|
||||
TEST_TOTAL_SOUNDS = 3
|
||||
ONE_PLAYLIST = 1
|
||||
ZERO_SOUNDS = 0
|
||||
ONE_SOUND = 1
|
||||
TWO_SOUNDS = 2
|
||||
DEFAULT_POSITION = 0
|
||||
|
||||
|
||||
class TestPlaylistRepository:
|
||||
@@ -134,11 +146,10 @@ class TestPlaylistRepository:
|
||||
self,
|
||||
playlist_repository: PlaylistRepository,
|
||||
test_session: AsyncSession,
|
||||
ensure_plans,
|
||||
ensure_plans: Any,
|
||||
) -> None:
|
||||
"""Test getting playlists by user ID."""
|
||||
# Create test user within this test
|
||||
from app.utils.auth import PasswordUtils
|
||||
user = User(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
@@ -172,7 +183,7 @@ class TestPlaylistRepository:
|
||||
playlists = await playlist_repository.get_by_user_id(user_id)
|
||||
|
||||
# Should only return user's playlists, not the main playlist (user_id=None)
|
||||
assert len(playlists) == 1
|
||||
assert len(playlists) == ONE_PLAYLIST
|
||||
assert playlists[0].name == "Test Playlist"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -210,11 +221,10 @@ class TestPlaylistRepository:
|
||||
self,
|
||||
playlist_repository: PlaylistRepository,
|
||||
test_session: AsyncSession,
|
||||
ensure_plans,
|
||||
ensure_plans: Any,
|
||||
) -> None:
|
||||
"""Test getting current playlist when none is set."""
|
||||
# Create test user within this test
|
||||
from app.utils.auth import PasswordUtils
|
||||
user = User(
|
||||
email="test2@example.com",
|
||||
name="Test User 2",
|
||||
@@ -302,11 +312,10 @@ class TestPlaylistRepository:
|
||||
self,
|
||||
playlist_repository: PlaylistRepository,
|
||||
test_session: AsyncSession,
|
||||
ensure_plans,
|
||||
ensure_plans: Any,
|
||||
) -> None:
|
||||
"""Test searching playlists by name."""
|
||||
# Create test user within this test
|
||||
from app.utils.auth import PasswordUtils
|
||||
user = User(
|
||||
email="test3@example.com",
|
||||
name="Test User 3",
|
||||
@@ -353,11 +362,12 @@ class TestPlaylistRepository:
|
||||
|
||||
# Search with user filter
|
||||
user_results = await playlist_repository.search_by_name("playlist", user_id)
|
||||
assert len(user_results) == 1 # Only user's playlists, not main playlist
|
||||
# Only user's playlists, not main playlist
|
||||
assert len(user_results) == ONE_PLAYLIST
|
||||
|
||||
# Search for specific playlist
|
||||
test_results = await playlist_repository.search_by_name("test", user_id)
|
||||
assert len(test_results) == 1
|
||||
assert len(test_results) == ONE_PLAYLIST
|
||||
assert test_results[0].name == "Test Playlist"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -365,11 +375,10 @@ class TestPlaylistRepository:
|
||||
self,
|
||||
playlist_repository: PlaylistRepository,
|
||||
test_session: AsyncSession,
|
||||
ensure_plans,
|
||||
ensure_plans: Any,
|
||||
) -> None:
|
||||
"""Test adding a sound to a playlist."""
|
||||
# Create test user within this test
|
||||
from app.utils.auth import PasswordUtils
|
||||
user = User(
|
||||
email="test4@example.com",
|
||||
name="Test User 4",
|
||||
@@ -421,18 +430,17 @@ class TestPlaylistRepository:
|
||||
|
||||
assert playlist_sound.playlist_id == playlist_id
|
||||
assert playlist_sound.sound_id == sound_id
|
||||
assert playlist_sound.position == 0
|
||||
assert playlist_sound.position == DEFAULT_POSITION
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_sound_to_playlist_with_position(
|
||||
self,
|
||||
playlist_repository: PlaylistRepository,
|
||||
test_session: AsyncSession,
|
||||
ensure_plans,
|
||||
ensure_plans: Any,
|
||||
) -> None:
|
||||
"""Test adding a sound to a playlist with specific position."""
|
||||
# Create test user within this test
|
||||
from app.utils.auth import PasswordUtils
|
||||
user = User(
|
||||
email="test5@example.com",
|
||||
name="Test User 5",
|
||||
@@ -485,18 +493,17 @@ class TestPlaylistRepository:
|
||||
playlist_id, sound_id, position=5,
|
||||
)
|
||||
|
||||
assert playlist_sound.position == 5
|
||||
assert playlist_sound.position == TEST_POSITION
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_sound_from_playlist(
|
||||
self,
|
||||
playlist_repository: PlaylistRepository,
|
||||
test_session: AsyncSession,
|
||||
ensure_plans,
|
||||
ensure_plans: Any,
|
||||
) -> None:
|
||||
"""Test removing a sound from a playlist."""
|
||||
# Create objects within this test
|
||||
from app.utils.auth import PasswordUtils
|
||||
user = User(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
@@ -564,11 +571,10 @@ class TestPlaylistRepository:
|
||||
self,
|
||||
playlist_repository: PlaylistRepository,
|
||||
test_session: AsyncSession,
|
||||
ensure_plans,
|
||||
ensure_plans: Any,
|
||||
) -> None:
|
||||
"""Test getting sounds in a playlist."""
|
||||
# Create objects within this test
|
||||
from app.utils.auth import PasswordUtils
|
||||
user = User(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
@@ -615,14 +621,14 @@ class TestPlaylistRepository:
|
||||
|
||||
# Initially empty
|
||||
sounds = await playlist_repository.get_playlist_sounds(playlist_id)
|
||||
assert len(sounds) == 0
|
||||
assert len(sounds) == ZERO_SOUNDS
|
||||
|
||||
# Add sound
|
||||
await playlist_repository.add_sound_to_playlist(playlist_id, sound_id)
|
||||
|
||||
# Check sounds
|
||||
sounds = await playlist_repository.get_playlist_sounds(playlist_id)
|
||||
assert len(sounds) == 1
|
||||
assert len(sounds) == ONE_SOUND
|
||||
assert sounds[0].id == sound_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -630,11 +636,10 @@ class TestPlaylistRepository:
|
||||
self,
|
||||
playlist_repository: PlaylistRepository,
|
||||
test_session: AsyncSession,
|
||||
ensure_plans,
|
||||
ensure_plans: Any,
|
||||
) -> None:
|
||||
"""Test getting sound count in a playlist."""
|
||||
# Create objects within this test
|
||||
from app.utils.auth import PasswordUtils
|
||||
user = User(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
@@ -695,11 +700,10 @@ class TestPlaylistRepository:
|
||||
self,
|
||||
playlist_repository: PlaylistRepository,
|
||||
test_session: AsyncSession,
|
||||
ensure_plans,
|
||||
ensure_plans: Any,
|
||||
) -> None:
|
||||
"""Test checking if sound is in playlist."""
|
||||
# Create objects within this test
|
||||
from app.utils.auth import PasswordUtils
|
||||
user = User(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
@@ -762,11 +766,10 @@ class TestPlaylistRepository:
|
||||
self,
|
||||
playlist_repository: PlaylistRepository,
|
||||
test_session: AsyncSession,
|
||||
ensure_plans,
|
||||
ensure_plans: Any,
|
||||
) -> None:
|
||||
"""Test reordering sounds in a playlist."""
|
||||
# Create objects within this test
|
||||
from app.utils.auth import PasswordUtils
|
||||
user = User(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
@@ -823,6 +826,6 @@ class TestPlaylistRepository:
|
||||
|
||||
# Verify new order
|
||||
sounds = await playlist_repository.get_playlist_sounds(playlist_id)
|
||||
assert len(sounds) == 2
|
||||
assert len(sounds) == TWO_SOUNDS
|
||||
assert sounds[0].id == sound2_id # sound2 now at position 5
|
||||
assert sounds[1].id == sound1_id # sound1 now at position 10
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
"""Tests for sound repository."""
|
||||
# ruff: noqa: ARG002, PLR2004
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.sound import Sound
|
||||
from app.repositories.sound import SoundRepository
|
||||
|
||||
# Constants
|
||||
MIN_POPULAR_SOUNDS = 3
|
||||
|
||||
|
||||
class TestSoundRepository:
|
||||
"""Test sound repository operations."""
|
||||
@@ -306,7 +311,7 @@ class TestSoundRepository:
|
||||
# Get popular sounds
|
||||
popular_sounds = await sound_repository.get_popular_sounds(limit=10)
|
||||
|
||||
assert len(popular_sounds) >= 3
|
||||
assert len(popular_sounds) >= MIN_POPULAR_SOUNDS
|
||||
# Should be ordered by play_count desc
|
||||
assert popular_sounds[0].play_count >= popular_sounds[1].play_count
|
||||
# The highest play count sound should be first
|
||||
@@ -372,5 +377,5 @@ class TestSoundRepository:
|
||||
}
|
||||
|
||||
# Should fail due to unique constraint on hash
|
||||
with pytest.raises(Exception): # SQLAlchemy IntegrityError or similar
|
||||
with pytest.raises(IntegrityError, match="UNIQUE constraint failed"):
|
||||
await sound_repository.create(duplicate_sound_data)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Tests for user repository."""
|
||||
# ruff: noqa: ARG002
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
@@ -174,24 +175,24 @@ class TestUserRepository:
|
||||
test_user: User,
|
||||
) -> None:
|
||||
"""Test updating a user."""
|
||||
UPDATED_CREDITS = 200
|
||||
updated_credits = 200
|
||||
update_data = {
|
||||
"name": "Updated Name",
|
||||
"credits": UPDATED_CREDITS,
|
||||
"credits": updated_credits,
|
||||
}
|
||||
|
||||
updated_user = await user_repository.update(test_user, update_data)
|
||||
|
||||
assert updated_user.id == test_user.id
|
||||
assert updated_user.name == "Updated Name"
|
||||
assert updated_user.credits == UPDATED_CREDITS
|
||||
assert updated_user.credits == updated_credits
|
||||
assert updated_user.email == test_user.email # Unchanged
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_user(
|
||||
self,
|
||||
user_repository: UserRepository,
|
||||
ensure_plans: tuple[Plan, Plan], # noqa: ARG002
|
||||
ensure_plans: tuple[Plan, Plan],
|
||||
test_session: AsyncSession,
|
||||
) -> None:
|
||||
"""Test deleting a user."""
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Tests for user OAuth repository."""
|
||||
# ruff: noqa: ARG002
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.user import User
|
||||
@@ -156,7 +158,9 @@ class TestUserOauthRepository:
|
||||
assert updated_oauth.name == "Updated User Name"
|
||||
assert updated_oauth.picture == "https://example.com/photo.jpg"
|
||||
assert updated_oauth.provider == test_oauth.provider # Unchanged
|
||||
assert updated_oauth.provider_user_id == test_oauth.provider_user_id # Unchanged
|
||||
assert (
|
||||
updated_oauth.provider_user_id == test_oauth.provider_user_id
|
||||
) # Unchanged
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_oauth(
|
||||
@@ -176,7 +180,7 @@ class TestUserOauthRepository:
|
||||
"picture": None,
|
||||
}
|
||||
oauth = await user_oauth_repository.create(oauth_data)
|
||||
oauth_id = oauth.id
|
||||
_ = oauth.id # Store ID but don't use it
|
||||
|
||||
# Delete the OAuth record
|
||||
await user_oauth_repository.delete(oauth)
|
||||
@@ -206,7 +210,7 @@ class TestUserOauthRepository:
|
||||
}
|
||||
|
||||
# This should fail due to unique constraint
|
||||
with pytest.raises(Exception): # SQLAlchemy IntegrityError or similar
|
||||
with pytest.raises(IntegrityError, match="UNIQUE constraint failed"):
|
||||
await user_oauth_repository.create(duplicate_oauth_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -225,7 +229,7 @@ class TestUserOauthRepository:
|
||||
"name": "Test User Google",
|
||||
"picture": None,
|
||||
}
|
||||
google_oauth = await user_oauth_repository.create(google_oauth_data)
|
||||
_ = await user_oauth_repository.create(google_oauth_data)
|
||||
|
||||
# Create GitHub OAuth for the same user
|
||||
github_oauth_data = {
|
||||
@@ -236,7 +240,7 @@ class TestUserOauthRepository:
|
||||
"name": "Test User GitHub",
|
||||
"picture": None,
|
||||
}
|
||||
github_oauth = await user_oauth_repository.create(github_oauth_data)
|
||||
_ = await user_oauth_repository.create(github_oauth_data)
|
||||
|
||||
# Verify both exist by querying back from database
|
||||
found_google = await user_oauth_repository.get_by_user_id_and_provider(
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import hashlib
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -15,11 +15,18 @@ from app.utils.audio import (
|
||||
get_sound_file_path,
|
||||
)
|
||||
|
||||
# Constants
|
||||
SHA256_HASH_LENGTH = 64
|
||||
BINARY_FILE_SIZE = 700
|
||||
EXPECTED_DURATION_MS_1 = 123456 # 123.456 seconds * 1000
|
||||
EXPECTED_DURATION_MS_2 = 60000 # 60 seconds * 1000
|
||||
EXPECTED_DURATION_MS_3 = 45123 # 45.123 seconds * 1000
|
||||
|
||||
|
||||
class TestAudioUtils:
|
||||
"""Test audio utility functions."""
|
||||
|
||||
def test_get_file_hash(self):
|
||||
def test_get_file_hash(self) -> None:
|
||||
"""Test file hash calculation."""
|
||||
# Create a temporary file with known content
|
||||
test_content = "test content for hashing"
|
||||
@@ -36,13 +43,13 @@ class TestAudioUtils:
|
||||
|
||||
# Verify the hash is correct
|
||||
assert result_hash == expected_hash
|
||||
assert len(result_hash) == 64 # SHA-256 hash length
|
||||
assert len(result_hash) == SHA256_HASH_LENGTH # SHA-256 hash length
|
||||
assert isinstance(result_hash, str)
|
||||
|
||||
finally:
|
||||
temp_path.unlink()
|
||||
|
||||
def test_get_file_hash_binary_content(self):
|
||||
def test_get_file_hash_binary_content(self) -> None:
|
||||
"""Test file hash calculation with binary content."""
|
||||
# Create a temporary file with binary content
|
||||
test_bytes = b"\x00\x01\x02\x03\xff\xfe\xfd"
|
||||
@@ -59,13 +66,13 @@ class TestAudioUtils:
|
||||
|
||||
# Verify the hash is correct
|
||||
assert result_hash == expected_hash
|
||||
assert len(result_hash) == 64 # SHA-256 hash length
|
||||
assert len(result_hash) == SHA256_HASH_LENGTH # SHA-256 hash length
|
||||
assert isinstance(result_hash, str)
|
||||
|
||||
finally:
|
||||
temp_path.unlink()
|
||||
|
||||
def test_get_file_hash_empty_file(self):
|
||||
def test_get_file_hash_empty_file(self) -> None:
|
||||
"""Test file hash calculation for empty file."""
|
||||
# Create an empty temporary file
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
@@ -80,13 +87,13 @@ class TestAudioUtils:
|
||||
|
||||
# Verify the hash is correct
|
||||
assert result_hash == expected_hash
|
||||
assert len(result_hash) == 64 # SHA-256 hash length
|
||||
assert len(result_hash) == SHA256_HASH_LENGTH # SHA-256 hash length
|
||||
assert isinstance(result_hash, str)
|
||||
|
||||
finally:
|
||||
temp_path.unlink()
|
||||
|
||||
def test_get_file_hash_large_file(self):
|
||||
def test_get_file_hash_large_file(self) -> None:
|
||||
"""Test file hash calculation for large file (tests chunked reading)."""
|
||||
# Create a large temporary file (larger than 4096 bytes chunk size)
|
||||
test_content = "A" * 10000 # 10KB of 'A' characters
|
||||
@@ -103,13 +110,13 @@ class TestAudioUtils:
|
||||
|
||||
# Verify the hash is correct
|
||||
assert result_hash == expected_hash
|
||||
assert len(result_hash) == 64 # SHA-256 hash length
|
||||
assert len(result_hash) == SHA256_HASH_LENGTH # SHA-256 hash length
|
||||
assert isinstance(result_hash, str)
|
||||
|
||||
finally:
|
||||
temp_path.unlink()
|
||||
|
||||
def test_get_file_size(self):
|
||||
def test_get_file_size(self) -> None:
|
||||
"""Test file size calculation."""
|
||||
# Create a temporary file with known content
|
||||
test_content = "test content for size calculation"
|
||||
@@ -132,7 +139,7 @@ class TestAudioUtils:
|
||||
finally:
|
||||
temp_path.unlink()
|
||||
|
||||
def test_get_file_size_empty_file(self):
|
||||
def test_get_file_size_empty_file(self) -> None:
|
||||
"""Test file size calculation for empty file."""
|
||||
# Create an empty temporary file
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
@@ -149,7 +156,7 @@ class TestAudioUtils:
|
||||
finally:
|
||||
temp_path.unlink()
|
||||
|
||||
def test_get_file_size_binary_file(self):
|
||||
def test_get_file_size_binary_file(self) -> None:
|
||||
"""Test file size calculation for binary file."""
|
||||
# Create a temporary file with binary content
|
||||
test_bytes = b"\x00\x01\x02\x03\xff\xfe\xfd" * 100 # 700 bytes
|
||||
@@ -163,14 +170,14 @@ class TestAudioUtils:
|
||||
|
||||
# Verify the size is correct
|
||||
assert result_size == len(test_bytes)
|
||||
assert result_size == 700
|
||||
assert result_size == BINARY_FILE_SIZE
|
||||
assert isinstance(result_size, int)
|
||||
|
||||
finally:
|
||||
temp_path.unlink()
|
||||
|
||||
@patch("app.utils.audio.ffmpeg.probe")
|
||||
def test_get_audio_duration_success(self, mock_probe):
|
||||
def test_get_audio_duration_success(self, mock_probe: MagicMock) -> None:
|
||||
"""Test successful audio duration extraction."""
|
||||
# Mock ffmpeg.probe to return duration
|
||||
mock_probe.return_value = {"format": {"duration": "123.456"}}
|
||||
@@ -179,12 +186,12 @@ class TestAudioUtils:
|
||||
duration = get_audio_duration(temp_path)
|
||||
|
||||
# Verify duration is converted correctly (seconds to milliseconds)
|
||||
assert duration == 123456 # 123.456 seconds * 1000 = 123456 ms
|
||||
assert duration == EXPECTED_DURATION_MS_1 # 123.456 seconds * 1000 = 123456 ms
|
||||
assert isinstance(duration, int)
|
||||
mock_probe.assert_called_once_with(str(temp_path))
|
||||
|
||||
@patch("app.utils.audio.ffmpeg.probe")
|
||||
def test_get_audio_duration_integer_duration(self, mock_probe):
|
||||
def test_get_audio_duration_integer_duration(self, mock_probe: MagicMock) -> None:
|
||||
"""Test audio duration extraction with integer duration."""
|
||||
# Mock ffmpeg.probe to return integer duration
|
||||
mock_probe.return_value = {"format": {"duration": "60"}}
|
||||
@@ -193,12 +200,12 @@ class TestAudioUtils:
|
||||
duration = get_audio_duration(temp_path)
|
||||
|
||||
# Verify duration is converted correctly
|
||||
assert duration == 60000 # 60 seconds * 1000 = 60000 ms
|
||||
assert duration == EXPECTED_DURATION_MS_2 # 60 seconds * 1000 = 60000 ms
|
||||
assert isinstance(duration, int)
|
||||
mock_probe.assert_called_once_with(str(temp_path))
|
||||
|
||||
@patch("app.utils.audio.ffmpeg.probe")
|
||||
def test_get_audio_duration_zero_duration(self, mock_probe):
|
||||
def test_get_audio_duration_zero_duration(self, mock_probe: MagicMock) -> None:
|
||||
"""Test audio duration extraction with zero duration."""
|
||||
# Mock ffmpeg.probe to return zero duration
|
||||
mock_probe.return_value = {"format": {"duration": "0.0"}}
|
||||
@@ -212,7 +219,9 @@ class TestAudioUtils:
|
||||
mock_probe.assert_called_once_with(str(temp_path))
|
||||
|
||||
@patch("app.utils.audio.ffmpeg.probe")
|
||||
def test_get_audio_duration_fractional_duration(self, mock_probe):
|
||||
def test_get_audio_duration_fractional_duration(
|
||||
self, mock_probe: MagicMock,
|
||||
) -> None:
|
||||
"""Test audio duration extraction with fractional seconds."""
|
||||
# Mock ffmpeg.probe to return fractional duration
|
||||
mock_probe.return_value = {"format": {"duration": "45.123"}}
|
||||
@@ -221,12 +230,12 @@ class TestAudioUtils:
|
||||
duration = get_audio_duration(temp_path)
|
||||
|
||||
# Verify duration is converted and rounded correctly
|
||||
assert duration == 45123 # 45.123 seconds * 1000 = 45123 ms
|
||||
assert duration == EXPECTED_DURATION_MS_3 # 45.123 seconds * 1000 = 45123 ms
|
||||
assert isinstance(duration, int)
|
||||
mock_probe.assert_called_once_with(str(temp_path))
|
||||
|
||||
@patch("app.utils.audio.ffmpeg.probe")
|
||||
def test_get_audio_duration_ffmpeg_error(self, mock_probe):
|
||||
def test_get_audio_duration_ffmpeg_error(self, mock_probe: MagicMock) -> None:
|
||||
"""Test audio duration extraction when ffmpeg fails."""
|
||||
# Mock ffmpeg.probe to raise an exception
|
||||
mock_probe.side_effect = Exception("FFmpeg error: file not found")
|
||||
@@ -240,7 +249,7 @@ class TestAudioUtils:
|
||||
mock_probe.assert_called_once_with(str(temp_path))
|
||||
|
||||
@patch("app.utils.audio.ffmpeg.probe")
|
||||
def test_get_audio_duration_missing_format(self, mock_probe):
|
||||
def test_get_audio_duration_missing_format(self, mock_probe: MagicMock) -> None:
|
||||
"""Test audio duration extraction when format info is missing."""
|
||||
# Mock ffmpeg.probe to return data without format info
|
||||
mock_probe.return_value = {"streams": []}
|
||||
@@ -254,7 +263,7 @@ class TestAudioUtils:
|
||||
mock_probe.assert_called_once_with(str(temp_path))
|
||||
|
||||
@patch("app.utils.audio.ffmpeg.probe")
|
||||
def test_get_audio_duration_missing_duration(self, mock_probe):
|
||||
def test_get_audio_duration_missing_duration(self, mock_probe: MagicMock) -> None:
|
||||
"""Test audio duration extraction when duration is missing."""
|
||||
# Mock ffmpeg.probe to return format without duration
|
||||
mock_probe.return_value = {"format": {"size": "1024"}}
|
||||
@@ -268,7 +277,7 @@ class TestAudioUtils:
|
||||
mock_probe.assert_called_once_with(str(temp_path))
|
||||
|
||||
@patch("app.utils.audio.ffmpeg.probe")
|
||||
def test_get_audio_duration_invalid_duration(self, mock_probe):
|
||||
def test_get_audio_duration_invalid_duration(self, mock_probe: MagicMock) -> None:
|
||||
"""Test audio duration extraction with invalid duration value."""
|
||||
# Mock ffmpeg.probe to return invalid duration
|
||||
mock_probe.return_value = {"format": {"duration": "invalid"}}
|
||||
@@ -281,7 +290,7 @@ class TestAudioUtils:
|
||||
assert isinstance(duration, int)
|
||||
mock_probe.assert_called_once_with(str(temp_path))
|
||||
|
||||
def test_get_file_hash_nonexistent_file(self):
|
||||
def test_get_file_hash_nonexistent_file(self) -> None:
|
||||
"""Test file hash calculation for nonexistent file."""
|
||||
nonexistent_path = Path("/fake/nonexistent/file.mp3")
|
||||
|
||||
@@ -289,7 +298,7 @@ class TestAudioUtils:
|
||||
with pytest.raises(FileNotFoundError):
|
||||
get_file_hash(nonexistent_path)
|
||||
|
||||
def test_get_file_size_nonexistent_file(self):
|
||||
def test_get_file_size_nonexistent_file(self) -> None:
|
||||
"""Test file size calculation for nonexistent file."""
|
||||
nonexistent_path = Path("/fake/nonexistent/file.mp3")
|
||||
|
||||
@@ -297,7 +306,7 @@ class TestAudioUtils:
|
||||
with pytest.raises(FileNotFoundError):
|
||||
get_file_size(nonexistent_path)
|
||||
|
||||
def test_get_sound_file_path_sdb_original(self):
|
||||
def test_get_sound_file_path_sdb_original(self) -> None:
|
||||
"""Test getting sound file path for SDB type original file."""
|
||||
sound = Sound(
|
||||
id=1,
|
||||
@@ -311,7 +320,7 @@ class TestAudioUtils:
|
||||
expected = Path("sounds/originals/soundboard/test.mp3")
|
||||
assert result == expected
|
||||
|
||||
def test_get_sound_file_path_sdb_normalized(self):
|
||||
def test_get_sound_file_path_sdb_normalized(self) -> None:
|
||||
"""Test getting sound file path for SDB type normalized file."""
|
||||
sound = Sound(
|
||||
id=1,
|
||||
@@ -326,7 +335,7 @@ class TestAudioUtils:
|
||||
expected = Path("sounds/normalized/soundboard/normalized.mp3")
|
||||
assert result == expected
|
||||
|
||||
def test_get_sound_file_path_tts_original(self):
|
||||
def test_get_sound_file_path_tts_original(self) -> None:
|
||||
"""Test getting sound file path for TTS type original file."""
|
||||
sound = Sound(
|
||||
id=2,
|
||||
@@ -340,7 +349,7 @@ class TestAudioUtils:
|
||||
expected = Path("sounds/originals/text_to_speech/tts_file.wav")
|
||||
assert result == expected
|
||||
|
||||
def test_get_sound_file_path_tts_normalized(self):
|
||||
def test_get_sound_file_path_tts_normalized(self) -> None:
|
||||
"""Test getting sound file path for TTS type normalized file."""
|
||||
sound = Sound(
|
||||
id=2,
|
||||
@@ -355,7 +364,7 @@ class TestAudioUtils:
|
||||
expected = Path("sounds/normalized/text_to_speech/normalized.mp3")
|
||||
assert result == expected
|
||||
|
||||
def test_get_sound_file_path_ext_original(self):
|
||||
def test_get_sound_file_path_ext_original(self) -> None:
|
||||
"""Test getting sound file path for EXT type original file."""
|
||||
sound = Sound(
|
||||
id=3,
|
||||
@@ -369,7 +378,7 @@ class TestAudioUtils:
|
||||
expected = Path("sounds/originals/extracted/extracted.mp3")
|
||||
assert result == expected
|
||||
|
||||
def test_get_sound_file_path_ext_normalized(self):
|
||||
def test_get_sound_file_path_ext_normalized(self) -> None:
|
||||
"""Test getting sound file path for EXT type normalized file."""
|
||||
sound = Sound(
|
||||
id=3,
|
||||
@@ -384,7 +393,7 @@ class TestAudioUtils:
|
||||
expected = Path("sounds/normalized/extracted/normalized.mp3")
|
||||
assert result == expected
|
||||
|
||||
def test_get_sound_file_path_unknown_type_fallback(self):
|
||||
def test_get_sound_file_path_unknown_type_fallback(self) -> None:
|
||||
"""Test getting sound file path for unknown type falls back to lowercase."""
|
||||
sound = Sound(
|
||||
id=4,
|
||||
@@ -398,7 +407,7 @@ class TestAudioUtils:
|
||||
expected = Path("sounds/originals/custom/unknown.mp3")
|
||||
assert result == expected
|
||||
|
||||
def test_get_sound_file_path_normalized_without_filename(self):
|
||||
def test_get_sound_file_path_normalized_without_filename(self) -> None:
|
||||
"""Test getting sound file path when normalized but no normalized_filename."""
|
||||
sound = Sound(
|
||||
id=5,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Tests for cookie utilities."""
|
||||
# ruff: noqa: ANN201, E501
|
||||
|
||||
from app.utils.cookies import extract_access_token_from_cookies, parse_cookies
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""Tests for credit decorators."""
|
||||
# ruff: noqa: ARG001, ANN001, E501, PT012
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Never
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
@@ -17,7 +20,7 @@ class TestRequiresCreditsDecorator:
|
||||
"""Test requires_credits decorator."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credit_service(self):
|
||||
def mock_credit_service(self) -> AsyncMock:
|
||||
"""Create a mock credit service."""
|
||||
service = AsyncMock(spec=CreditService)
|
||||
service.validate_and_reserve_credits = AsyncMock()
|
||||
@@ -25,12 +28,14 @@ class TestRequiresCreditsDecorator:
|
||||
return service
|
||||
|
||||
@pytest.fixture
|
||||
def credit_service_factory(self, mock_credit_service):
|
||||
def credit_service_factory(self, mock_credit_service: AsyncMock) -> Callable[[], AsyncMock]:
|
||||
"""Create a credit service factory."""
|
||||
return lambda: mock_credit_service
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_success(self, credit_service_factory, mock_credit_service):
|
||||
async def test_decorator_success(
|
||||
self, credit_service_factory: Callable[[], AsyncMock], mock_credit_service: AsyncMock,
|
||||
) -> None:
|
||||
"""Test decorator with successful action."""
|
||||
|
||||
@requires_credits(
|
||||
@@ -52,10 +57,12 @@ class TestRequiresCreditsDecorator:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_with_metadata(self, credit_service_factory, mock_credit_service):
|
||||
async def test_decorator_with_metadata(
|
||||
self, credit_service_factory: Callable[[], AsyncMock], mock_credit_service: AsyncMock,
|
||||
) -> None:
|
||||
"""Test decorator with metadata extraction."""
|
||||
|
||||
def extract_metadata(user_id: int, sound_name: str) -> dict:
|
||||
def extract_metadata(user_id: int, sound_name: str) -> dict[str, str]:
|
||||
return {"sound_name": sound_name}
|
||||
|
||||
@requires_credits(
|
||||
@@ -77,7 +84,7 @@ class TestRequiresCreditsDecorator:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_failed_action(self, credit_service_factory, mock_credit_service):
|
||||
async def test_decorator_failed_action(self, credit_service_factory, mock_credit_service) -> None:
|
||||
"""Test decorator with failed action."""
|
||||
|
||||
@requires_credits(
|
||||
@@ -96,7 +103,7 @@ class TestRequiresCreditsDecorator:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_exception_in_action(self, credit_service_factory, mock_credit_service):
|
||||
async def test_decorator_exception_in_action(self, credit_service_factory, mock_credit_service) -> None:
|
||||
"""Test decorator when action raises exception."""
|
||||
|
||||
@requires_credits(
|
||||
@@ -105,7 +112,8 @@ class TestRequiresCreditsDecorator:
|
||||
user_id_param="user_id",
|
||||
)
|
||||
async def test_action(user_id: int) -> str:
|
||||
raise ValueError("Test error")
|
||||
msg = "Test error"
|
||||
raise ValueError(msg)
|
||||
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
await test_action(user_id=123)
|
||||
@@ -115,7 +123,7 @@ class TestRequiresCreditsDecorator:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_insufficient_credits(self, credit_service_factory, mock_credit_service):
|
||||
async def test_decorator_insufficient_credits(self, credit_service_factory, mock_credit_service) -> None:
|
||||
"""Test decorator with insufficient credits."""
|
||||
mock_credit_service.validate_and_reserve_credits.side_effect = InsufficientCreditsError(1, 0)
|
||||
|
||||
@@ -134,7 +142,7 @@ class TestRequiresCreditsDecorator:
|
||||
mock_credit_service.deduct_credits.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_user_id_in_args(self, credit_service_factory, mock_credit_service):
|
||||
async def test_decorator_user_id_in_args(self, credit_service_factory, mock_credit_service) -> None:
|
||||
"""Test decorator extracting user_id from positional args."""
|
||||
|
||||
@requires_credits(
|
||||
@@ -153,7 +161,7 @@ class TestRequiresCreditsDecorator:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_missing_user_id(self, credit_service_factory):
|
||||
async def test_decorator_missing_user_id(self, credit_service_factory) -> None:
|
||||
"""Test decorator when user_id cannot be extracted."""
|
||||
|
||||
@requires_credits(
|
||||
@@ -172,19 +180,19 @@ class TestValidateCreditsOnlyDecorator:
|
||||
"""Test validate_credits_only decorator."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credit_service(self):
|
||||
def mock_credit_service(self) -> AsyncMock:
|
||||
"""Create a mock credit service."""
|
||||
service = AsyncMock(spec=CreditService)
|
||||
service.validate_and_reserve_credits = AsyncMock()
|
||||
return service
|
||||
|
||||
@pytest.fixture
|
||||
def credit_service_factory(self, mock_credit_service):
|
||||
def credit_service_factory(self, mock_credit_service: AsyncMock) -> Callable[[], AsyncMock]:
|
||||
"""Create a credit service factory."""
|
||||
return lambda: mock_credit_service
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_only_decorator(self, credit_service_factory, mock_credit_service):
|
||||
async def test_validate_only_decorator(self, credit_service_factory, mock_credit_service) -> None:
|
||||
"""Test validate_credits_only decorator."""
|
||||
|
||||
@validate_credits_only(
|
||||
@@ -209,7 +217,7 @@ class TestCreditManager:
|
||||
"""Test CreditManager context manager."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credit_service(self):
|
||||
def mock_credit_service(self) -> AsyncMock:
|
||||
"""Create a mock credit service."""
|
||||
service = AsyncMock(spec=CreditService)
|
||||
service.validate_and_reserve_credits = AsyncMock()
|
||||
@@ -217,7 +225,7 @@ class TestCreditManager:
|
||||
return service
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_credit_manager_success(self, mock_credit_service):
|
||||
async def test_credit_manager_success(self, mock_credit_service) -> None:
|
||||
"""Test CreditManager with successful operation."""
|
||||
async with CreditManager(
|
||||
mock_credit_service,
|
||||
@@ -235,7 +243,7 @@ class TestCreditManager:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_credit_manager_failure(self, mock_credit_service):
|
||||
async def test_credit_manager_failure(self, mock_credit_service) -> None:
|
||||
"""Test CreditManager with failed operation."""
|
||||
async with CreditManager(
|
||||
mock_credit_service,
|
||||
@@ -250,7 +258,7 @@ class TestCreditManager:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_credit_manager_exception(self, mock_credit_service):
|
||||
async def test_credit_manager_exception(self, mock_credit_service) -> Never:
|
||||
"""Test CreditManager when exception occurs."""
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
async with CreditManager(
|
||||
@@ -258,14 +266,15 @@ class TestCreditManager:
|
||||
123,
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
):
|
||||
raise ValueError("Test error")
|
||||
msg = "Test error"
|
||||
raise ValueError(msg)
|
||||
|
||||
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, success=False, metadata=None,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_credit_manager_validation_failure(self, mock_credit_service):
|
||||
async def test_credit_manager_validation_failure(self, mock_credit_service) -> None:
|
||||
"""Test CreditManager when validation fails."""
|
||||
mock_credit_service.validate_and_reserve_credits.side_effect = InsufficientCreditsError(1, 0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user