Add tests for sound repository, user OAuth repository, credit service, and credit decorators
- Implement comprehensive tests for SoundRepository covering CRUD operations and search functionalities. - Create tests for UserOauthRepository to validate OAuth record management. - Develop tests for CreditService to ensure proper credit management, including validation, deduction, and addition of credits. - Add tests for credit-related decorators to verify correct behavior in credit management scenarios.
This commit is contained in:
@@ -13,8 +13,10 @@ from sqlmodel import SQLModel, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.credit_transaction import CreditTransaction # Ensure model is imported for SQLAlchemy
|
||||
from app.models.plan import Plan
|
||||
from app.models.user import User
|
||||
from app.models.user_oauth import UserOauth # Ensure model is imported for SQLAlchemy
|
||||
from app.utils.auth import JWTUtils, PasswordUtils
|
||||
|
||||
|
||||
|
||||
412
tests/repositories/test_credit_transaction.py
Normal file
412
tests/repositories/test_credit_transaction.py
Normal file
@@ -0,0 +1,412 @@
|
||||
"""Tests for credit transaction repository."""
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.credit_transaction import CreditTransaction
|
||||
from app.models.user import User
|
||||
from app.repositories.credit_transaction import CreditTransactionRepository
|
||||
|
||||
|
||||
class TestCreditTransactionRepository:
|
||||
"""Test credit transaction repository operations."""
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def credit_transaction_repository(
|
||||
self,
|
||||
test_session: AsyncSession,
|
||||
) -> AsyncGenerator[CreditTransactionRepository, None]: # type: ignore[misc]
|
||||
"""Create a credit transaction repository instance."""
|
||||
yield CreditTransactionRepository(test_session)
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user_id(
|
||||
self,
|
||||
test_user: User,
|
||||
) -> int:
|
||||
"""Get test user ID to avoid lazy loading issues."""
|
||||
return test_user.id
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_transactions(
|
||||
self,
|
||||
test_session: AsyncSession,
|
||||
test_user_id: int,
|
||||
) -> AsyncGenerator[list[CreditTransaction], None]: # type: ignore[misc]
|
||||
"""Create test credit transactions."""
|
||||
transactions = []
|
||||
user_id = test_user_id
|
||||
|
||||
# Create various types of transactions
|
||||
transaction_data = [
|
||||
{
|
||||
"user_id": user_id,
|
||||
"action_type": "vlc_play_sound",
|
||||
"amount": -1,
|
||||
"balance_before": 10,
|
||||
"balance_after": 9,
|
||||
"description": "Play sound via VLC",
|
||||
"success": True,
|
||||
"metadata_json": json.dumps({"sound_id": 1, "sound_name": "test.mp3"}),
|
||||
},
|
||||
{
|
||||
"user_id": user_id,
|
||||
"action_type": "audio_extraction",
|
||||
"amount": -5,
|
||||
"balance_before": 9,
|
||||
"balance_after": 4,
|
||||
"description": "Extract audio from URL",
|
||||
"success": True,
|
||||
"metadata_json": json.dumps({"url": "https://example.com/video"}),
|
||||
},
|
||||
{
|
||||
"user_id": user_id,
|
||||
"action_type": "vlc_play_sound",
|
||||
"amount": 0,
|
||||
"balance_before": 4,
|
||||
"balance_after": 4,
|
||||
"description": "Play sound via VLC (failed)",
|
||||
"success": False,
|
||||
"metadata_json": json.dumps({"sound_id": 2, "error": "File not found"}),
|
||||
},
|
||||
{
|
||||
"user_id": user_id,
|
||||
"action_type": "credit_addition",
|
||||
"amount": 50,
|
||||
"balance_before": 4,
|
||||
"balance_after": 54,
|
||||
"description": "Bonus credits",
|
||||
"success": True,
|
||||
"metadata_json": json.dumps({"reason": "signup_bonus"}),
|
||||
},
|
||||
]
|
||||
|
||||
for data in transaction_data:
|
||||
transaction = CreditTransaction(**data)
|
||||
test_session.add(transaction)
|
||||
transactions.append(transaction)
|
||||
|
||||
await test_session.commit()
|
||||
for transaction in transactions:
|
||||
await test_session.refresh(transaction)
|
||||
|
||||
yield transactions
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def other_user_transaction(
|
||||
self,
|
||||
test_session: AsyncSession,
|
||||
ensure_plans: tuple, # noqa: ARG002
|
||||
) -> AsyncGenerator[CreditTransaction, None]: # type: ignore[misc]
|
||||
"""Create a transaction for a different user."""
|
||||
from app.models.plan import Plan
|
||||
from app.repositories.user import UserRepository
|
||||
|
||||
# Create another user
|
||||
user_repo = UserRepository(test_session)
|
||||
other_user_data = {
|
||||
"email": "other@example.com",
|
||||
"name": "Other User",
|
||||
"password_hash": "hashed_password",
|
||||
"role": "user",
|
||||
"is_active": True,
|
||||
}
|
||||
other_user = await user_repo.create(other_user_data)
|
||||
|
||||
# Create transaction for the other user
|
||||
transaction_data = {
|
||||
"user_id": other_user.id,
|
||||
"action_type": "vlc_play_sound",
|
||||
"amount": -1,
|
||||
"balance_before": 100,
|
||||
"balance_after": 99,
|
||||
"description": "Other user play sound",
|
||||
"success": True,
|
||||
"metadata_json": None,
|
||||
}
|
||||
transaction = CreditTransaction(**transaction_data)
|
||||
test_session.add(transaction)
|
||||
await test_session.commit()
|
||||
await test_session.refresh(transaction)
|
||||
|
||||
yield transaction
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_existing(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
) -> None:
|
||||
"""Test getting transaction by ID when it exists."""
|
||||
transaction = await credit_transaction_repository.get_by_id(test_transactions[0].id)
|
||||
|
||||
assert transaction is not None
|
||||
assert transaction.id == test_transactions[0].id
|
||||
assert transaction.action_type == "vlc_play_sound"
|
||||
assert transaction.amount == -1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_nonexistent(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
) -> None:
|
||||
"""Test getting transaction by ID when it doesn't exist."""
|
||||
transaction = await credit_transaction_repository.get_by_id(99999)
|
||||
|
||||
assert transaction is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_user_id(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
other_user_transaction: CreditTransaction,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test getting transactions by user ID."""
|
||||
transactions = await credit_transaction_repository.get_by_user_id(test_user_id)
|
||||
|
||||
# Should return all transactions for test_user
|
||||
assert len(transactions) == 4
|
||||
# Should be ordered by created_at desc (newest first)
|
||||
assert all(t.user_id == test_user_id for t in transactions)
|
||||
|
||||
# Should not include other user's transaction
|
||||
other_user_ids = [t.user_id for t in transactions]
|
||||
assert other_user_transaction.user_id not in other_user_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_user_id_with_pagination(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test getting transactions by user ID with pagination."""
|
||||
# Get first 2 transactions
|
||||
first_page = await credit_transaction_repository.get_by_user_id(
|
||||
test_user_id, limit=2, offset=0
|
||||
)
|
||||
assert len(first_page) == 2
|
||||
|
||||
# Get next 2 transactions
|
||||
second_page = await credit_transaction_repository.get_by_user_id(
|
||||
test_user_id, limit=2, offset=2
|
||||
)
|
||||
assert len(second_page) == 2
|
||||
|
||||
# Should not overlap
|
||||
first_page_ids = {t.id for t in first_page}
|
||||
second_page_ids = {t.id for t in second_page}
|
||||
assert first_page_ids.isdisjoint(second_page_ids)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_action_type(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
) -> None:
|
||||
"""Test getting transactions by action type."""
|
||||
vlc_transactions = await credit_transaction_repository.get_by_action_type(
|
||||
"vlc_play_sound"
|
||||
)
|
||||
|
||||
# Should return 2 VLC transactions (1 successful, 1 failed)
|
||||
assert len(vlc_transactions) >= 2
|
||||
assert all(t.action_type == "vlc_play_sound" for t in vlc_transactions)
|
||||
|
||||
extraction_transactions = await credit_transaction_repository.get_by_action_type(
|
||||
"audio_extraction"
|
||||
)
|
||||
|
||||
# Should return 1 extraction transaction
|
||||
assert len(extraction_transactions) >= 1
|
||||
assert all(t.action_type == "audio_extraction" for t in extraction_transactions)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_action_type_with_pagination(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
) -> None:
|
||||
"""Test getting transactions by action type with pagination."""
|
||||
# Test with limit
|
||||
transactions = await credit_transaction_repository.get_by_action_type(
|
||||
"vlc_play_sound", limit=1
|
||||
)
|
||||
assert len(transactions) == 1
|
||||
assert transactions[0].action_type == "vlc_play_sound"
|
||||
|
||||
# Test with offset
|
||||
transactions = await credit_transaction_repository.get_by_action_type(
|
||||
"vlc_play_sound", limit=1, offset=1
|
||||
)
|
||||
assert len(transactions) <= 1 # Might be 0 if only 1 VLC transaction in total
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_successful_transactions(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
) -> None:
|
||||
"""Test getting only successful transactions."""
|
||||
successful_transactions = await credit_transaction_repository.get_successful_transactions()
|
||||
|
||||
# Should only return successful transactions
|
||||
assert all(t.success is True for t in successful_transactions)
|
||||
# Should be at least 3 (vlc_play_sound, audio_extraction, credit_addition)
|
||||
assert len(successful_transactions) >= 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_successful_transactions_by_user(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
other_user_transaction: CreditTransaction,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test getting successful transactions filtered by user."""
|
||||
successful_transactions = await credit_transaction_repository.get_successful_transactions(
|
||||
user_id=test_user_id
|
||||
)
|
||||
|
||||
# Should only return successful transactions for test_user
|
||||
assert all(t.success is True for t in successful_transactions)
|
||||
assert all(t.user_id == test_user_id for t in successful_transactions)
|
||||
# Should be 3 successful transactions for test_user
|
||||
assert len(successful_transactions) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_successful_transactions_with_pagination(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test getting successful transactions with pagination."""
|
||||
# Get first 2 successful transactions
|
||||
first_page = await credit_transaction_repository.get_successful_transactions(
|
||||
user_id=test_user_id, limit=2, offset=0
|
||||
)
|
||||
assert len(first_page) == 2
|
||||
assert all(t.success is True for t in first_page)
|
||||
|
||||
# Get next successful transaction
|
||||
second_page = await credit_transaction_repository.get_successful_transactions(
|
||||
user_id=test_user_id, limit=2, offset=2
|
||||
)
|
||||
assert len(second_page) == 1 # Should be 1 remaining
|
||||
assert all(t.success is True for t in second_page)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_transactions(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
other_user_transaction: CreditTransaction,
|
||||
) -> None:
|
||||
"""Test getting all transactions."""
|
||||
all_transactions = await credit_transaction_repository.get_all()
|
||||
|
||||
# Should return all transactions
|
||||
assert len(all_transactions) >= 5 # 4 from test_transactions + 1 other_user_transaction
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_transaction(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test creating a new transaction."""
|
||||
transaction_data = {
|
||||
"user_id": test_user_id,
|
||||
"action_type": "test_action",
|
||||
"amount": -10,
|
||||
"balance_before": 100,
|
||||
"balance_after": 90,
|
||||
"description": "Test transaction",
|
||||
"success": True,
|
||||
"metadata_json": json.dumps({"test": "data"}),
|
||||
}
|
||||
|
||||
transaction = await credit_transaction_repository.create(transaction_data)
|
||||
|
||||
assert transaction.id is not None
|
||||
assert transaction.user_id == test_user_id
|
||||
assert transaction.action_type == "test_action"
|
||||
assert transaction.amount == -10
|
||||
assert transaction.balance_before == 100
|
||||
assert transaction.balance_after == 90
|
||||
assert transaction.success is True
|
||||
assert json.loads(transaction.metadata_json) == {"test": "data"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_transaction(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
) -> None:
|
||||
"""Test updating a transaction."""
|
||||
transaction = test_transactions[0]
|
||||
update_data = {
|
||||
"description": "Updated description",
|
||||
"metadata_json": json.dumps({"updated": True}),
|
||||
}
|
||||
|
||||
updated_transaction = await credit_transaction_repository.update(
|
||||
transaction, update_data
|
||||
)
|
||||
|
||||
assert updated_transaction.id == transaction.id
|
||||
assert updated_transaction.description == "Updated description"
|
||||
assert json.loads(updated_transaction.metadata_json) == {"updated": True}
|
||||
# Other fields should remain unchanged
|
||||
assert updated_transaction.amount == transaction.amount
|
||||
assert updated_transaction.action_type == transaction.action_type
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_transaction(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_session: AsyncSession,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test deleting a transaction."""
|
||||
# Create a transaction to delete
|
||||
transaction_data = {
|
||||
"user_id": test_user_id,
|
||||
"action_type": "to_delete",
|
||||
"amount": -1,
|
||||
"balance_before": 10,
|
||||
"balance_after": 9,
|
||||
"description": "To be deleted",
|
||||
"success": True,
|
||||
"metadata_json": None,
|
||||
}
|
||||
transaction = await credit_transaction_repository.create(transaction_data)
|
||||
transaction_id = transaction.id
|
||||
|
||||
# Delete the transaction
|
||||
await credit_transaction_repository.delete(transaction)
|
||||
|
||||
# Verify transaction is deleted
|
||||
deleted_transaction = await credit_transaction_repository.get_by_id(transaction_id)
|
||||
assert deleted_transaction is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transaction_ordering(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test that transactions are ordered by created_at desc."""
|
||||
transactions = await credit_transaction_repository.get_by_user_id(test_user_id)
|
||||
|
||||
# Should be ordered by created_at desc (newest first)
|
||||
for i in range(len(transactions) - 1):
|
||||
assert transactions[i].created_at >= transactions[i + 1].created_at
|
||||
376
tests/repositories/test_sound.py
Normal file
376
tests/repositories/test_sound.py
Normal file
@@ -0,0 +1,376 @@
|
||||
"""Tests for sound repository."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.sound import Sound
|
||||
from app.repositories.sound import SoundRepository
|
||||
|
||||
|
||||
class TestSoundRepository:
|
||||
"""Test sound repository operations."""
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sound_repository(
|
||||
self,
|
||||
test_session: AsyncSession,
|
||||
) -> AsyncGenerator[SoundRepository, None]: # type: ignore[misc]
|
||||
"""Create a sound repository instance."""
|
||||
yield SoundRepository(test_session)
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_sound(
|
||||
self,
|
||||
test_session: AsyncSession,
|
||||
) -> AsyncGenerator[Sound, None]: # type: ignore[misc]
|
||||
"""Create a test sound."""
|
||||
sound_data = {
|
||||
"name": "Test Sound",
|
||||
"filename": "test_sound.mp3",
|
||||
"type": "SDB",
|
||||
"duration": 5000,
|
||||
"size": 1024000,
|
||||
"hash": "test_hash_123",
|
||||
"play_count": 0,
|
||||
"is_normalized": False,
|
||||
}
|
||||
sound = Sound(**sound_data)
|
||||
test_session.add(sound)
|
||||
await test_session.commit()
|
||||
await test_session.refresh(sound)
|
||||
yield sound
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def normalized_sound(
|
||||
self,
|
||||
test_session: AsyncSession,
|
||||
) -> AsyncGenerator[Sound, None]: # type: ignore[misc]
|
||||
"""Create a normalized test sound."""
|
||||
sound_data = {
|
||||
"name": "Normalized Sound",
|
||||
"filename": "normalized_sound.mp3",
|
||||
"type": "TTS",
|
||||
"duration": 3000,
|
||||
"size": 512000,
|
||||
"hash": "normalized_hash_456",
|
||||
"play_count": 5,
|
||||
"is_normalized": True,
|
||||
"normalized_filename": "normalized_sound_norm.mp3",
|
||||
"normalized_duration": 3000,
|
||||
"normalized_size": 480000,
|
||||
"normalized_hash": "normalized_hash_norm_456",
|
||||
}
|
||||
sound = Sound(**sound_data)
|
||||
test_session.add(sound)
|
||||
await test_session.commit()
|
||||
await test_session.refresh(sound)
|
||||
yield sound
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_existing(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
) -> None:
|
||||
"""Test getting sound by ID when it exists."""
|
||||
sound = await sound_repository.get_by_id(test_sound.id)
|
||||
|
||||
assert sound is not None
|
||||
assert sound.id == test_sound.id
|
||||
assert sound.name == test_sound.name
|
||||
assert sound.filename == test_sound.filename
|
||||
assert sound.type == test_sound.type
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_nonexistent(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
) -> None:
|
||||
"""Test getting sound by ID when it doesn't exist."""
|
||||
sound = await sound_repository.get_by_id(99999)
|
||||
|
||||
assert sound is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_filename_existing(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
) -> None:
|
||||
"""Test getting sound by filename when it exists."""
|
||||
sound = await sound_repository.get_by_filename(test_sound.filename)
|
||||
|
||||
assert sound is not None
|
||||
assert sound.id == test_sound.id
|
||||
assert sound.filename == test_sound.filename
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_filename_nonexistent(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
) -> None:
|
||||
"""Test getting sound by filename when it doesn't exist."""
|
||||
sound = await sound_repository.get_by_filename("nonexistent.mp3")
|
||||
|
||||
assert sound is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_hash_existing(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
) -> None:
|
||||
"""Test getting sound by hash when it exists."""
|
||||
sound = await sound_repository.get_by_hash(test_sound.hash)
|
||||
|
||||
assert sound is not None
|
||||
assert sound.id == test_sound.id
|
||||
assert sound.hash == test_sound.hash
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_hash_nonexistent(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
) -> None:
|
||||
"""Test getting sound by hash when it doesn't exist."""
|
||||
sound = await sound_repository.get_by_hash("nonexistent_hash")
|
||||
|
||||
assert sound is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_type(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
normalized_sound: Sound,
|
||||
) -> None:
|
||||
"""Test getting sounds by type."""
|
||||
sdb_sounds = await sound_repository.get_by_type("SDB")
|
||||
tts_sounds = await sound_repository.get_by_type("TTS")
|
||||
ext_sounds = await sound_repository.get_by_type("EXT")
|
||||
|
||||
# Should find the SDB sound
|
||||
assert len(sdb_sounds) >= 1
|
||||
assert any(sound.id == test_sound.id for sound in sdb_sounds)
|
||||
|
||||
# Should find the TTS sound
|
||||
assert len(tts_sounds) >= 1
|
||||
assert any(sound.id == normalized_sound.id for sound in tts_sounds)
|
||||
|
||||
# Should not find any EXT sounds
|
||||
assert len(ext_sounds) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_sound(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
) -> None:
|
||||
"""Test creating a new sound."""
|
||||
sound_data = {
|
||||
"name": "New Sound",
|
||||
"filename": "new_sound.wav",
|
||||
"type": "EXT",
|
||||
"duration": 7500,
|
||||
"size": 2048000,
|
||||
"hash": "new_hash_789",
|
||||
"play_count": 0,
|
||||
"is_normalized": False,
|
||||
}
|
||||
|
||||
sound = await sound_repository.create(sound_data)
|
||||
|
||||
assert sound.id is not None
|
||||
assert sound.name == sound_data["name"]
|
||||
assert sound.filename == sound_data["filename"]
|
||||
assert sound.type == sound_data["type"]
|
||||
assert sound.duration == sound_data["duration"]
|
||||
assert sound.size == sound_data["size"]
|
||||
assert sound.hash == sound_data["hash"]
|
||||
assert sound.play_count == 0
|
||||
assert sound.is_normalized is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_sound(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
) -> None:
|
||||
"""Test updating a sound."""
|
||||
update_data = {
|
||||
"name": "Updated Sound Name",
|
||||
"play_count": 10,
|
||||
"is_normalized": True,
|
||||
"normalized_filename": "updated_norm.mp3",
|
||||
}
|
||||
|
||||
updated_sound = await sound_repository.update(test_sound, update_data)
|
||||
|
||||
assert updated_sound.id == test_sound.id
|
||||
assert updated_sound.name == "Updated Sound Name"
|
||||
assert updated_sound.play_count == 10
|
||||
assert updated_sound.is_normalized is True
|
||||
assert updated_sound.normalized_filename == "updated_norm.mp3"
|
||||
assert updated_sound.filename == test_sound.filename # Unchanged
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_sound(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_session: AsyncSession,
|
||||
) -> None:
|
||||
"""Test deleting a sound."""
|
||||
# Create a sound to delete
|
||||
sound_data = {
|
||||
"name": "To Delete",
|
||||
"filename": "to_delete.mp3",
|
||||
"type": "SDB",
|
||||
"duration": 1000,
|
||||
"size": 256000,
|
||||
"hash": "delete_hash",
|
||||
"play_count": 0,
|
||||
"is_normalized": False,
|
||||
}
|
||||
sound = await sound_repository.create(sound_data)
|
||||
sound_id = sound.id
|
||||
|
||||
# Delete the sound
|
||||
await sound_repository.delete(sound)
|
||||
|
||||
# Verify sound is deleted
|
||||
deleted_sound = await sound_repository.get_by_id(sound_id)
|
||||
assert deleted_sound is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_by_name(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
normalized_sound: Sound,
|
||||
) -> None:
|
||||
"""Test searching sounds by name."""
|
||||
# Search for "test" should find test_sound
|
||||
results = await sound_repository.search_by_name("test")
|
||||
assert len(results) >= 1
|
||||
assert any(sound.id == test_sound.id for sound in results)
|
||||
|
||||
# Search for "normalized" should find normalized_sound
|
||||
results = await sound_repository.search_by_name("normalized")
|
||||
assert len(results) >= 1
|
||||
assert any(sound.id == normalized_sound.id for sound in results)
|
||||
|
||||
# Case insensitive search
|
||||
results = await sound_repository.search_by_name("TEST")
|
||||
assert len(results) >= 1
|
||||
assert any(sound.id == test_sound.id for sound in results)
|
||||
|
||||
# Partial match
|
||||
results = await sound_repository.search_by_name("norm")
|
||||
assert len(results) >= 1
|
||||
assert any(sound.id == normalized_sound.id for sound in results)
|
||||
|
||||
# No matches
|
||||
results = await sound_repository.search_by_name("nonexistent")
|
||||
assert len(results) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_popular_sounds(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
normalized_sound: Sound,
|
||||
) -> None:
|
||||
"""Test getting popular sounds."""
|
||||
# Update play counts to test ordering
|
||||
await sound_repository.update(test_sound, {"play_count": 15})
|
||||
await sound_repository.update(normalized_sound, {"play_count": 5})
|
||||
|
||||
# Create another sound with higher play count
|
||||
high_play_sound_data = {
|
||||
"name": "Popular Sound",
|
||||
"filename": "popular.mp3",
|
||||
"type": "SDB",
|
||||
"duration": 2000,
|
||||
"size": 300000,
|
||||
"hash": "popular_hash",
|
||||
"play_count": 25,
|
||||
"is_normalized": False,
|
||||
}
|
||||
high_play_sound = await sound_repository.create(high_play_sound_data)
|
||||
|
||||
# Get popular sounds
|
||||
popular_sounds = await sound_repository.get_popular_sounds(limit=10)
|
||||
|
||||
assert len(popular_sounds) >= 3
|
||||
# Should be ordered by play_count desc
|
||||
assert popular_sounds[0].play_count >= popular_sounds[1].play_count
|
||||
# The highest play count sound should be first
|
||||
assert popular_sounds[0].id == high_play_sound.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_unnormalized_sounds(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
normalized_sound: Sound,
|
||||
) -> None:
|
||||
"""Test getting unnormalized sounds."""
|
||||
unnormalized_sounds = await sound_repository.get_unnormalized_sounds()
|
||||
|
||||
# Should include test_sound (not normalized)
|
||||
assert any(sound.id == test_sound.id for sound in unnormalized_sounds)
|
||||
# Should not include normalized_sound (already normalized)
|
||||
assert not any(sound.id == normalized_sound.id for sound in unnormalized_sounds)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_unnormalized_sounds_by_type(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
normalized_sound: Sound,
|
||||
) -> None:
|
||||
"""Test getting unnormalized sounds by type."""
|
||||
# Get unnormalized SDB sounds
|
||||
sdb_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("SDB")
|
||||
# Should include test_sound (SDB, not normalized)
|
||||
assert any(sound.id == test_sound.id for sound in sdb_unnormalized)
|
||||
|
||||
# Get unnormalized TTS sounds
|
||||
tts_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("TTS")
|
||||
# Should not include normalized_sound (TTS, but already normalized)
|
||||
assert not any(sound.id == normalized_sound.id for sound in tts_unnormalized)
|
||||
|
||||
# Get unnormalized EXT sounds
|
||||
ext_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("EXT")
|
||||
# Should be empty
|
||||
assert len(ext_unnormalized) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_duplicate_hash(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
) -> None:
|
||||
"""Test creating sound with duplicate hash is allowed."""
|
||||
# Store the hash to avoid lazy loading issues
|
||||
original_hash = test_sound.hash
|
||||
|
||||
duplicate_sound_data = {
|
||||
"name": "Duplicate Hash Sound",
|
||||
"filename": "duplicate.mp3",
|
||||
"type": "SDB",
|
||||
"duration": 1000,
|
||||
"size": 100000,
|
||||
"hash": original_hash, # Same hash as test_sound
|
||||
"play_count": 0,
|
||||
"is_normalized": False,
|
||||
}
|
||||
|
||||
# Should succeed - duplicate hashes are allowed
|
||||
duplicate_sound = await sound_repository.create(duplicate_sound_data)
|
||||
|
||||
assert duplicate_sound.id is not None
|
||||
assert duplicate_sound.name == "Duplicate Hash Sound"
|
||||
assert duplicate_sound.hash == original_hash # Same hash is allowed
|
||||
268
tests/repositories/test_user_oauth.py
Normal file
268
tests/repositories/test_user_oauth.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""Tests for user OAuth repository."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.user import User
|
||||
from app.models.user_oauth import UserOauth
|
||||
from app.repositories.user_oauth import UserOauthRepository
|
||||
|
||||
|
||||
class TestUserOauthRepository:
|
||||
"""Test user OAuth repository operations."""
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def user_oauth_repository(
|
||||
self,
|
||||
test_session: AsyncSession,
|
||||
) -> AsyncGenerator[UserOauthRepository, None]: # type: ignore[misc]
|
||||
"""Create a user OAuth repository instance."""
|
||||
yield UserOauthRepository(test_session)
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user_id(
|
||||
self,
|
||||
test_user: User,
|
||||
) -> int:
|
||||
"""Get test user ID to avoid lazy loading issues."""
|
||||
return test_user.id
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_oauth(
|
||||
self,
|
||||
test_session: AsyncSession,
|
||||
test_user_id: int,
|
||||
) -> AsyncGenerator[UserOauth, None]: # type: ignore[misc]
|
||||
"""Create a test OAuth record."""
|
||||
oauth_data = {
|
||||
"user_id": test_user_id,
|
||||
"provider": "google",
|
||||
"provider_user_id": "google_123456",
|
||||
"email": "test@gmail.com",
|
||||
"name": "Test User Google",
|
||||
"picture": None,
|
||||
}
|
||||
oauth = UserOauth(**oauth_data)
|
||||
test_session.add(oauth)
|
||||
await test_session.commit()
|
||||
await test_session.refresh(oauth)
|
||||
yield oauth
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_provider_user_id_existing(
|
||||
self,
|
||||
user_oauth_repository: UserOauthRepository,
|
||||
test_oauth: UserOauth,
|
||||
) -> None:
|
||||
"""Test getting OAuth by provider user ID when it exists."""
|
||||
oauth = await user_oauth_repository.get_by_provider_user_id(
|
||||
"google", "google_123456"
|
||||
)
|
||||
|
||||
assert oauth is not None
|
||||
assert oauth.id == test_oauth.id
|
||||
assert oauth.provider == "google"
|
||||
assert oauth.provider_user_id == "google_123456"
|
||||
assert oauth.user_id == test_oauth.user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_provider_user_id_nonexistent(
|
||||
self,
|
||||
user_oauth_repository: UserOauthRepository,
|
||||
) -> None:
|
||||
"""Test getting OAuth by provider user ID when it doesn't exist."""
|
||||
oauth = await user_oauth_repository.get_by_provider_user_id(
|
||||
"google", "nonexistent_id"
|
||||
)
|
||||
|
||||
assert oauth is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_user_id_and_provider_existing(
|
||||
self,
|
||||
user_oauth_repository: UserOauthRepository,
|
||||
test_oauth: UserOauth,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test getting OAuth by user ID and provider when it exists."""
|
||||
oauth = await user_oauth_repository.get_by_user_id_and_provider(
|
||||
test_user_id, "google"
|
||||
)
|
||||
|
||||
assert oauth is not None
|
||||
assert oauth.id == test_oauth.id
|
||||
assert oauth.provider == "google"
|
||||
assert oauth.user_id == test_user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_user_id_and_provider_nonexistent(
|
||||
self,
|
||||
user_oauth_repository: UserOauthRepository,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test getting OAuth by user ID and provider when it doesn't exist."""
|
||||
oauth = await user_oauth_repository.get_by_user_id_and_provider(
|
||||
test_user_id, "github"
|
||||
)
|
||||
|
||||
assert oauth is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_oauth(
|
||||
self,
|
||||
user_oauth_repository: UserOauthRepository,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test creating a new OAuth record."""
|
||||
oauth_data = {
|
||||
"user_id": test_user_id,
|
||||
"provider": "github",
|
||||
"provider_user_id": "github_789",
|
||||
"email": "test@github.com",
|
||||
"name": "Test User GitHub",
|
||||
"picture": None,
|
||||
}
|
||||
|
||||
oauth = await user_oauth_repository.create(oauth_data)
|
||||
|
||||
assert oauth.id is not None
|
||||
assert oauth.user_id == test_user_id
|
||||
assert oauth.provider == "github"
|
||||
assert oauth.provider_user_id == "github_789"
|
||||
assert oauth.email == "test@github.com"
|
||||
assert oauth.name == "Test User GitHub"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_oauth(
|
||||
self,
|
||||
user_oauth_repository: UserOauthRepository,
|
||||
test_oauth: UserOauth,
|
||||
) -> None:
|
||||
"""Test updating an OAuth record."""
|
||||
update_data = {
|
||||
"email": "updated@gmail.com",
|
||||
"name": "Updated User Name",
|
||||
"picture": "https://example.com/photo.jpg",
|
||||
}
|
||||
|
||||
updated_oauth = await user_oauth_repository.update(test_oauth, update_data)
|
||||
|
||||
assert updated_oauth.id == test_oauth.id
|
||||
assert updated_oauth.email == "updated@gmail.com"
|
||||
assert updated_oauth.name == "Updated User Name"
|
||||
assert updated_oauth.picture == "https://example.com/photo.jpg"
|
||||
assert updated_oauth.provider == test_oauth.provider # Unchanged
|
||||
assert updated_oauth.provider_user_id == test_oauth.provider_user_id # Unchanged
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_oauth(
|
||||
self,
|
||||
user_oauth_repository: UserOauthRepository,
|
||||
test_session: AsyncSession,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test deleting an OAuth record."""
|
||||
# Create an OAuth record to delete
|
||||
oauth_data = {
|
||||
"user_id": test_user_id,
|
||||
"provider": "twitter",
|
||||
"provider_user_id": "twitter_456",
|
||||
"email": "test@twitter.com",
|
||||
"name": "Test User Twitter",
|
||||
"picture": None,
|
||||
}
|
||||
oauth = await user_oauth_repository.create(oauth_data)
|
||||
oauth_id = oauth.id
|
||||
|
||||
# Delete the OAuth record
|
||||
await user_oauth_repository.delete(oauth)
|
||||
|
||||
# Verify it's deleted by trying to find it
|
||||
deleted_oauth = await user_oauth_repository.get_by_provider_user_id(
|
||||
"twitter", "twitter_456"
|
||||
)
|
||||
assert deleted_oauth is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_duplicate_provider_user_id(
|
||||
self,
|
||||
user_oauth_repository: UserOauthRepository,
|
||||
test_oauth: UserOauth,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test creating OAuth with duplicate provider user ID should fail."""
|
||||
# Try to create another OAuth with the same provider and provider_user_id
|
||||
duplicate_oauth_data = {
|
||||
"user_id": test_user_id,
|
||||
"provider": "google",
|
||||
"provider_user_id": "google_123456", # Same as test_oauth
|
||||
"email": "another@gmail.com",
|
||||
"name": "Another User",
|
||||
"picture": None,
|
||||
}
|
||||
|
||||
# This should fail due to unique constraint
|
||||
with pytest.raises(Exception): # SQLAlchemy IntegrityError or similar
|
||||
await user_oauth_repository.create(duplicate_oauth_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_providers_same_user(
|
||||
self,
|
||||
user_oauth_repository: UserOauthRepository,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test that a user can have multiple OAuth providers."""
|
||||
# Create Google OAuth
|
||||
google_oauth_data = {
|
||||
"user_id": test_user_id,
|
||||
"provider": "google",
|
||||
"provider_user_id": "google_user_1",
|
||||
"email": "user@gmail.com",
|
||||
"name": "Test User Google",
|
||||
"picture": None,
|
||||
}
|
||||
google_oauth = await user_oauth_repository.create(google_oauth_data)
|
||||
|
||||
# Create GitHub OAuth for the same user
|
||||
github_oauth_data = {
|
||||
"user_id": test_user_id,
|
||||
"provider": "github",
|
||||
"provider_user_id": "github_user_1",
|
||||
"email": "user@github.com",
|
||||
"name": "Test User GitHub",
|
||||
"picture": None,
|
||||
}
|
||||
github_oauth = await user_oauth_repository.create(github_oauth_data)
|
||||
|
||||
# Verify both exist by querying back from database
|
||||
found_google = await user_oauth_repository.get_by_user_id_and_provider(
|
||||
test_user_id, "google"
|
||||
)
|
||||
found_github = await user_oauth_repository.get_by_user_id_and_provider(
|
||||
test_user_id, "github"
|
||||
)
|
||||
|
||||
assert found_google is not None
|
||||
assert found_github is not None
|
||||
assert found_google.provider == "google"
|
||||
assert found_github.provider == "github"
|
||||
assert found_google.user_id == test_user_id
|
||||
assert found_github.user_id == test_user_id
|
||||
assert found_google.provider_user_id == "google_user_1"
|
||||
assert found_github.provider_user_id == "github_user_1"
|
||||
|
||||
# Verify we can also find them by provider_user_id
|
||||
found_google_by_provider = await user_oauth_repository.get_by_provider_user_id(
|
||||
"google", "google_user_1"
|
||||
)
|
||||
found_github_by_provider = await user_oauth_repository.get_by_provider_user_id(
|
||||
"github", "github_user_1"
|
||||
)
|
||||
|
||||
assert found_google_by_provider is not None
|
||||
assert found_github_by_provider is not None
|
||||
assert found_google_by_provider.user_id == test_user_id
|
||||
assert found_github_by_provider.user_id == test_user_id
|
||||
358
tests/services/test_credit.py
Normal file
358
tests/services/test_credit.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""Tests for credit service."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.credit_action import CreditActionType
|
||||
from app.models.credit_transaction import CreditTransaction
|
||||
from app.models.user import User
|
||||
from app.services.credit import CreditService, InsufficientCreditsError
|
||||
|
||||
|
||||
class TestCreditService:
|
||||
"""Test credit service functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session_factory(self):
|
||||
"""Create a mock database session factory."""
|
||||
session = AsyncMock(spec=AsyncSession)
|
||||
return lambda: session
|
||||
|
||||
@pytest.fixture
|
||||
def credit_service(self, mock_db_session_factory):
|
||||
"""Create a credit service instance for testing."""
|
||||
return CreditService(mock_db_session_factory)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_user(self):
|
||||
"""Create a sample user for testing."""
|
||||
return User(
|
||||
id=1,
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
role="user",
|
||||
credits=10,
|
||||
plan_id=1,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_credits_sufficient(self, credit_service, sample_user):
|
||||
"""Test checking credits when user has sufficient credits."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = sample_user
|
||||
|
||||
result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND)
|
||||
|
||||
assert result is True
|
||||
mock_repo.get_by_id.assert_called_once_with(1)
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_credits_insufficient(self, credit_service):
|
||||
"""Test checking credits when user has insufficient credits."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
poor_user = User(
|
||||
id=1,
|
||||
name="Poor User",
|
||||
email="poor@example.com",
|
||||
role="user",
|
||||
credits=0, # No credits
|
||||
plan_id=1,
|
||||
)
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = poor_user
|
||||
|
||||
result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND)
|
||||
|
||||
assert result is False
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_credits_user_not_found(self, credit_service):
|
||||
"""Test checking credits when user is not found."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = None
|
||||
|
||||
result = await credit_service.check_credits(999, CreditActionType.VLC_PLAY_SOUND)
|
||||
|
||||
assert result is False
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_and_reserve_credits_success(self, credit_service, sample_user):
|
||||
"""Test successful credit validation and reservation."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = sample_user
|
||||
|
||||
user, action = await credit_service.validate_and_reserve_credits(
|
||||
1, CreditActionType.VLC_PLAY_SOUND
|
||||
)
|
||||
|
||||
assert user == sample_user
|
||||
assert action.action_type == CreditActionType.VLC_PLAY_SOUND
|
||||
assert action.cost == 1
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_and_reserve_credits_insufficient(self, credit_service):
|
||||
"""Test credit validation with insufficient credits."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
poor_user = User(
|
||||
id=1,
|
||||
name="Poor User",
|
||||
email="poor@example.com",
|
||||
role="user",
|
||||
credits=0,
|
||||
plan_id=1,
|
||||
)
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = poor_user
|
||||
|
||||
with pytest.raises(InsufficientCreditsError) as exc_info:
|
||||
await credit_service.validate_and_reserve_credits(
|
||||
1, CreditActionType.VLC_PLAY_SOUND
|
||||
)
|
||||
|
||||
assert exc_info.value.required == 1
|
||||
assert exc_info.value.available == 0
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_and_reserve_credits_user_not_found(self, credit_service):
|
||||
"""Test credit validation when user is not found."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="User 999 not found"):
|
||||
await credit_service.validate_and_reserve_credits(
|
||||
999, CreditActionType.VLC_PLAY_SOUND
|
||||
)
|
||||
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deduct_credits_success(self, credit_service, sample_user):
|
||||
"""Test successful credit deduction."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class, \
|
||||
patch("app.services.credit.socket_manager") as mock_socket_manager:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = sample_user
|
||||
mock_socket_manager.send_to_user = AsyncMock()
|
||||
|
||||
transaction = await credit_service.deduct_credits(
|
||||
1, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"}
|
||||
)
|
||||
|
||||
# Verify user credits were updated
|
||||
mock_repo.update.assert_called_once_with(sample_user, {"credits": 9})
|
||||
|
||||
# Verify transaction was created
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
# Verify socket event was emitted
|
||||
mock_socket_manager.send_to_user.assert_called_once_with(
|
||||
"1", "user_credits_changed", {
|
||||
"user_id": "1",
|
||||
"credits_before": 10,
|
||||
"credits_after": 9,
|
||||
"credits_deducted": 1,
|
||||
"action_type": "vlc_play_sound",
|
||||
"success": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Check transaction details
|
||||
added_transaction = mock_session.add.call_args[0][0]
|
||||
assert isinstance(added_transaction, CreditTransaction)
|
||||
assert added_transaction.user_id == 1
|
||||
assert added_transaction.action_type == "vlc_play_sound"
|
||||
assert added_transaction.amount == -1
|
||||
assert added_transaction.balance_before == 10
|
||||
assert added_transaction.balance_after == 9
|
||||
assert added_transaction.success is True
|
||||
assert json.loads(added_transaction.metadata_json) == {"test": "data"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deduct_credits_failed_action_requires_success(self, credit_service, sample_user):
|
||||
"""Test credit deduction when action failed but requires success."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class, \
|
||||
patch("app.services.credit.socket_manager") as mock_socket_manager:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = sample_user
|
||||
mock_socket_manager.send_to_user = AsyncMock()
|
||||
|
||||
transaction = await credit_service.deduct_credits(
|
||||
1, CreditActionType.VLC_PLAY_SOUND, False # Action failed
|
||||
)
|
||||
|
||||
# Verify user credits were NOT updated (action requires success)
|
||||
mock_repo.update.assert_not_called()
|
||||
|
||||
# Verify transaction was still created for auditing
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
# Verify no socket event was emitted since no credits were actually deducted
|
||||
mock_socket_manager.send_to_user.assert_not_called()
|
||||
|
||||
# Check transaction details
|
||||
added_transaction = mock_session.add.call_args[0][0]
|
||||
assert added_transaction.amount == 0 # No deduction for failed action
|
||||
assert added_transaction.balance_before == 10
|
||||
assert added_transaction.balance_after == 10 # No change
|
||||
assert added_transaction.success is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deduct_credits_insufficient(self, credit_service):
|
||||
"""Test credit deduction with insufficient credits."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
poor_user = User(
|
||||
id=1,
|
||||
name="Poor User",
|
||||
email="poor@example.com",
|
||||
role="user",
|
||||
credits=0,
|
||||
plan_id=1,
|
||||
)
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class, \
|
||||
patch("app.services.credit.socket_manager") as mock_socket_manager:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = poor_user
|
||||
mock_socket_manager.send_to_user = AsyncMock()
|
||||
|
||||
with pytest.raises(InsufficientCreditsError):
|
||||
await credit_service.deduct_credits(
|
||||
1, CreditActionType.VLC_PLAY_SOUND, True
|
||||
)
|
||||
|
||||
# Verify no socket event was emitted since credits could not be deducted
|
||||
mock_socket_manager.send_to_user.assert_not_called()
|
||||
|
||||
mock_session.rollback.assert_called_once()
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_credits(self, credit_service, sample_user):
|
||||
"""Test adding credits to user account."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class, \
|
||||
patch("app.services.credit.socket_manager") as mock_socket_manager:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = sample_user
|
||||
mock_socket_manager.send_to_user = AsyncMock()
|
||||
|
||||
transaction = await credit_service.add_credits(
|
||||
1, 5, "Bonus credits", {"reason": "signup"}
|
||||
)
|
||||
|
||||
# Verify user credits were updated
|
||||
mock_repo.update.assert_called_once_with(sample_user, {"credits": 15})
|
||||
|
||||
# Verify transaction was created
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
# Verify socket event was emitted
|
||||
mock_socket_manager.send_to_user.assert_called_once_with(
|
||||
"1", "user_credits_changed", {
|
||||
"user_id": "1",
|
||||
"credits_before": 10,
|
||||
"credits_after": 15,
|
||||
"credits_added": 5,
|
||||
"description": "Bonus credits",
|
||||
"success": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Check transaction details
|
||||
added_transaction = mock_session.add.call_args[0][0]
|
||||
assert added_transaction.amount == 5
|
||||
assert added_transaction.balance_before == 10
|
||||
assert added_transaction.balance_after == 15
|
||||
assert added_transaction.description == "Bonus credits"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_credits_invalid_amount(self, credit_service):
|
||||
"""Test adding invalid amount of credits."""
|
||||
with pytest.raises(ValueError, match="Amount must be positive"):
|
||||
await credit_service.add_credits(1, 0, "Invalid")
|
||||
|
||||
with pytest.raises(ValueError, match="Amount must be positive"):
|
||||
await credit_service.add_credits(1, -5, "Invalid")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_balance(self, credit_service, sample_user):
|
||||
"""Test getting user credit balance."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = sample_user
|
||||
|
||||
balance = await credit_service.get_user_balance(1)
|
||||
|
||||
assert balance == 10
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_balance_user_not_found(self, credit_service):
|
||||
"""Test getting balance for non-existent user."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="User 999 not found"):
|
||||
await credit_service.get_user_balance(999)
|
||||
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
|
||||
class TestInsufficientCreditsError:
|
||||
"""Test InsufficientCreditsError exception."""
|
||||
|
||||
def test_insufficient_credits_error_creation(self):
|
||||
"""Test creating InsufficientCreditsError."""
|
||||
error = InsufficientCreditsError(5, 2)
|
||||
assert error.required == 5
|
||||
assert error.available == 2
|
||||
assert str(error) == "Insufficient credits: 5 required, 2 available"
|
||||
277
tests/utils/test_credit_decorators.py
Normal file
277
tests/utils/test_credit_decorators.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""Tests for credit decorators."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.credit_action import CreditActionType
|
||||
from app.services.credit import CreditService, InsufficientCreditsError
|
||||
from app.utils.credit_decorators import CreditManager, requires_credits, validate_credits_only
|
||||
|
||||
|
||||
class TestRequiresCreditsDecorator:
|
||||
"""Test requires_credits decorator."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credit_service(self):
|
||||
"""Create a mock credit service."""
|
||||
service = AsyncMock(spec=CreditService)
|
||||
service.validate_and_reserve_credits = AsyncMock()
|
||||
service.deduct_credits = AsyncMock()
|
||||
return service
|
||||
|
||||
@pytest.fixture
|
||||
def credit_service_factory(self, mock_credit_service):
|
||||
"""Create a credit service factory."""
|
||||
return lambda: mock_credit_service
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_success(self, credit_service_factory, mock_credit_service):
|
||||
"""Test decorator with successful action."""
|
||||
|
||||
@requires_credits(
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
credit_service_factory,
|
||||
user_id_param="user_id"
|
||||
)
|
||||
async def test_action(user_id: int, message: str) -> str:
|
||||
return f"Success: {message}"
|
||||
|
||||
result = await test_action(user_id=123, message="test")
|
||||
|
||||
assert result == "Success: test"
|
||||
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, None
|
||||
)
|
||||
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, True, None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_with_metadata(self, credit_service_factory, mock_credit_service):
|
||||
"""Test decorator with metadata extraction."""
|
||||
|
||||
def extract_metadata(user_id: int, sound_name: str) -> dict:
|
||||
return {"sound_name": sound_name}
|
||||
|
||||
@requires_credits(
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
credit_service_factory,
|
||||
user_id_param="user_id",
|
||||
metadata_extractor=extract_metadata
|
||||
)
|
||||
async def test_action(user_id: int, sound_name: str) -> bool:
|
||||
return True
|
||||
|
||||
await test_action(user_id=123, sound_name="test.mp3")
|
||||
|
||||
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, {"sound_name": "test.mp3"}
|
||||
)
|
||||
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, True, {"sound_name": "test.mp3"}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_failed_action(self, credit_service_factory, mock_credit_service):
|
||||
"""Test decorator with failed action."""
|
||||
|
||||
@requires_credits(
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
credit_service_factory,
|
||||
user_id_param="user_id"
|
||||
)
|
||||
async def test_action(user_id: int) -> bool:
|
||||
return False # Action fails
|
||||
|
||||
result = await test_action(user_id=123)
|
||||
|
||||
assert result is False
|
||||
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, False, None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_exception_in_action(self, credit_service_factory, mock_credit_service):
|
||||
"""Test decorator when action raises exception."""
|
||||
|
||||
@requires_credits(
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
credit_service_factory,
|
||||
user_id_param="user_id"
|
||||
)
|
||||
async def test_action(user_id: int) -> str:
|
||||
raise ValueError("Test error")
|
||||
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
await test_action(user_id=123)
|
||||
|
||||
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, False, None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_insufficient_credits(self, credit_service_factory, mock_credit_service):
|
||||
"""Test decorator with insufficient credits."""
|
||||
mock_credit_service.validate_and_reserve_credits.side_effect = InsufficientCreditsError(1, 0)
|
||||
|
||||
@requires_credits(
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
credit_service_factory,
|
||||
user_id_param="user_id"
|
||||
)
|
||||
async def test_action(user_id: int) -> str:
|
||||
return "Should not execute"
|
||||
|
||||
with pytest.raises(InsufficientCreditsError):
|
||||
await test_action(user_id=123)
|
||||
|
||||
# Should not call deduct_credits since validation failed
|
||||
mock_credit_service.deduct_credits.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_user_id_in_args(self, credit_service_factory, mock_credit_service):
|
||||
"""Test decorator extracting user_id from positional args."""
|
||||
|
||||
@requires_credits(
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
credit_service_factory,
|
||||
user_id_param="user_id"
|
||||
)
|
||||
async def test_action(user_id: int, message: str) -> str:
|
||||
return message
|
||||
|
||||
result = await test_action(123, "test")
|
||||
|
||||
assert result == "test"
|
||||
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_missing_user_id(self, credit_service_factory):
|
||||
"""Test decorator when user_id cannot be extracted."""
|
||||
|
||||
@requires_credits(
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
credit_service_factory,
|
||||
user_id_param="user_id"
|
||||
)
|
||||
async def test_action(other_param: str) -> str:
|
||||
return other_param
|
||||
|
||||
with pytest.raises(ValueError, match="Could not extract user_id"):
|
||||
await test_action(other_param="test")
|
||||
|
||||
|
||||
class TestValidateCreditsOnlyDecorator:
|
||||
"""Test validate_credits_only decorator."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credit_service(self):
|
||||
"""Create a mock credit service."""
|
||||
service = AsyncMock(spec=CreditService)
|
||||
service.validate_and_reserve_credits = AsyncMock()
|
||||
return service
|
||||
|
||||
@pytest.fixture
|
||||
def credit_service_factory(self, mock_credit_service):
|
||||
"""Create a credit service factory."""
|
||||
return lambda: mock_credit_service
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_only_decorator(self, credit_service_factory, mock_credit_service):
|
||||
"""Test validate_credits_only decorator."""
|
||||
|
||||
@validate_credits_only(
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
credit_service_factory,
|
||||
user_id_param="user_id"
|
||||
)
|
||||
async def test_action(user_id: int, message: str) -> str:
|
||||
return f"Validated: {message}"
|
||||
|
||||
result = await test_action(user_id=123, message="test")
|
||||
|
||||
assert result == "Validated: test"
|
||||
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND
|
||||
)
|
||||
# Should not deduct credits, only validate
|
||||
mock_credit_service.deduct_credits.assert_not_called()
|
||||
|
||||
|
||||
class TestCreditManager:
|
||||
"""Test CreditManager context manager."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credit_service(self):
|
||||
"""Create a mock credit service."""
|
||||
service = AsyncMock(spec=CreditService)
|
||||
service.validate_and_reserve_credits = AsyncMock()
|
||||
service.deduct_credits = AsyncMock()
|
||||
return service
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_credit_manager_success(self, mock_credit_service):
|
||||
"""Test CreditManager with successful operation."""
|
||||
async with CreditManager(
|
||||
mock_credit_service,
|
||||
123,
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
{"test": "data"}
|
||||
) as manager:
|
||||
manager.mark_success()
|
||||
|
||||
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, {"test": "data"}
|
||||
)
|
||||
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_credit_manager_failure(self, mock_credit_service):
|
||||
"""Test CreditManager with failed operation."""
|
||||
async with CreditManager(
|
||||
mock_credit_service,
|
||||
123,
|
||||
CreditActionType.VLC_PLAY_SOUND
|
||||
):
|
||||
# Don't mark as success - should be considered failed
|
||||
pass
|
||||
|
||||
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, False, None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_credit_manager_exception(self, mock_credit_service):
|
||||
"""Test CreditManager when exception occurs."""
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
async with CreditManager(
|
||||
mock_credit_service,
|
||||
123,
|
||||
CreditActionType.VLC_PLAY_SOUND
|
||||
):
|
||||
raise ValueError("Test error")
|
||||
|
||||
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, False, None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_credit_manager_validation_failure(self, mock_credit_service):
|
||||
"""Test CreditManager when validation fails."""
|
||||
mock_credit_service.validate_and_reserve_credits.side_effect = InsufficientCreditsError(1, 0)
|
||||
|
||||
with pytest.raises(InsufficientCreditsError):
|
||||
async with CreditManager(
|
||||
mock_credit_service,
|
||||
123,
|
||||
CreditActionType.VLC_PLAY_SOUND
|
||||
):
|
||||
pass
|
||||
|
||||
# Should not call deduct_credits since validation failed
|
||||
mock_credit_service.deduct_credits.assert_not_called()
|
||||
Reference in New Issue
Block a user