"""Tests for credit transaction repository.""" import json from collections.abc import AsyncGenerator from typing import Any import pytest import pytest_asyncio from sqlmodel.ext.asyncio.session import AsyncSession from app.models.credit_transaction import CreditTransaction from app.models.user import User from app.repositories.credit_transaction import CreditTransactionRepository class TestCreditTransactionRepository: """Test credit transaction repository operations.""" @pytest_asyncio.fixture async def credit_transaction_repository( self, test_session: AsyncSession, ) -> AsyncGenerator[CreditTransactionRepository, None]: """Create a credit transaction repository instance.""" yield CreditTransactionRepository(test_session) @pytest_asyncio.fixture async def test_user_id( self, test_user: User, ) -> int: """Get test user ID to avoid lazy loading issues.""" assert test_user.id is not None return test_user.id @pytest_asyncio.fixture async def test_transactions( self, test_session: AsyncSession, test_user_id: int, ) -> AsyncGenerator[list[CreditTransaction], None]: """Create test credit transactions.""" transactions = [] user_id = test_user_id # Create various types of transactions transaction_data = [ { "user_id": user_id, "action_type": "vlc_play_sound", "amount": -1, "balance_before": 10, "balance_after": 9, "description": "Play sound via VLC", "success": True, "metadata_json": json.dumps({"sound_id": 1, "sound_name": "test.mp3"}), }, { "user_id": user_id, "action_type": "audio_extraction", "amount": -5, "balance_before": 9, "balance_after": 4, "description": "Extract audio from URL", "success": True, "metadata_json": json.dumps({"url": "https://example.com/video"}), }, { "user_id": user_id, "action_type": "vlc_play_sound", "amount": 0, "balance_before": 4, "balance_after": 4, "description": "Play sound via VLC (failed)", "success": False, "metadata_json": json.dumps({"sound_id": 2, "error": "File not found"}), }, { "user_id": user_id, "action_type": "credit_addition", "amount": 50, "balance_before": 4, "balance_after": 54, "description": "Bonus credits", "success": True, "metadata_json": json.dumps({"reason": "signup_bonus"}), }, ] for data in transaction_data: transaction = CreditTransaction(**data) test_session.add(transaction) transactions.append(transaction) await test_session.commit() for transaction in transactions: await test_session.refresh(transaction) yield transactions @pytest_asyncio.fixture async def other_user_transaction( self, test_session: AsyncSession, ensure_plans: tuple[Any, ...], # noqa: ARG002 ) -> AsyncGenerator[CreditTransaction, None]: """Create a transaction for a different user.""" from app.models.plan import Plan from app.repositories.user import UserRepository # Create another user user_repo = UserRepository(test_session) other_user_data = { "email": "other@example.com", "name": "Other User", "password_hash": "hashed_password", "role": "user", "is_active": True, } other_user = await user_repo.create(other_user_data) # Create transaction for the other user transaction_data = { "user_id": other_user.id, "action_type": "vlc_play_sound", "amount": -1, "balance_before": 100, "balance_after": 99, "description": "Other user play sound", "success": True, "metadata_json": None, } transaction = CreditTransaction(**transaction_data) test_session.add(transaction) await test_session.commit() await test_session.refresh(transaction) yield transaction @pytest.mark.asyncio async def test_get_by_id_existing( self, credit_transaction_repository: CreditTransactionRepository, test_transactions: list[CreditTransaction], ) -> None: """Test getting transaction by ID when it exists.""" transaction_id = test_transactions[0].id assert transaction_id is not None transaction = await credit_transaction_repository.get_by_id(transaction_id) assert transaction is not None assert transaction.id == test_transactions[0].id assert transaction.action_type == "vlc_play_sound" assert transaction.amount == -1 @pytest.mark.asyncio async def test_get_by_id_nonexistent( self, credit_transaction_repository: CreditTransactionRepository, ) -> None: """Test getting transaction by ID when it doesn't exist.""" transaction = await credit_transaction_repository.get_by_id(99999) assert transaction is None @pytest.mark.asyncio async def test_get_by_user_id( self, credit_transaction_repository: CreditTransactionRepository, test_transactions: list[CreditTransaction], other_user_transaction: CreditTransaction, test_user_id: int, ) -> None: """Test getting transactions by user ID.""" transactions = await credit_transaction_repository.get_by_user_id(test_user_id) # Should return all transactions for test_user assert len(transactions) == 4 # Should be ordered by created_at desc (newest first) assert all(t.user_id == test_user_id for t in transactions) # Should not include other user's transaction other_user_ids = [t.user_id for t in transactions] assert other_user_transaction.user_id not in other_user_ids @pytest.mark.asyncio async def test_get_by_user_id_with_pagination( self, credit_transaction_repository: CreditTransactionRepository, test_transactions: list[CreditTransaction], test_user_id: int, ) -> None: """Test getting transactions by user ID with pagination.""" # Get first 2 transactions first_page = await credit_transaction_repository.get_by_user_id( test_user_id, limit=2, offset=0 ) assert len(first_page) == 2 # Get next 2 transactions second_page = await credit_transaction_repository.get_by_user_id( test_user_id, limit=2, offset=2 ) assert len(second_page) == 2 # Should not overlap first_page_ids = {t.id for t in first_page} second_page_ids = {t.id for t in second_page} assert first_page_ids.isdisjoint(second_page_ids) @pytest.mark.asyncio async def test_get_by_action_type( self, credit_transaction_repository: CreditTransactionRepository, test_transactions: list[CreditTransaction], ) -> None: """Test getting transactions by action type.""" vlc_transactions = await credit_transaction_repository.get_by_action_type( "vlc_play_sound" ) # Should return 2 VLC transactions (1 successful, 1 failed) assert len(vlc_transactions) >= 2 assert all(t.action_type == "vlc_play_sound" for t in vlc_transactions) extraction_transactions = await credit_transaction_repository.get_by_action_type( "audio_extraction" ) # Should return 1 extraction transaction assert len(extraction_transactions) >= 1 assert all(t.action_type == "audio_extraction" for t in extraction_transactions) @pytest.mark.asyncio async def test_get_by_action_type_with_pagination( self, credit_transaction_repository: CreditTransactionRepository, test_transactions: list[CreditTransaction], ) -> None: """Test getting transactions by action type with pagination.""" # Test with limit transactions = await credit_transaction_repository.get_by_action_type( "vlc_play_sound", limit=1 ) assert len(transactions) == 1 assert transactions[0].action_type == "vlc_play_sound" # Test with offset transactions = await credit_transaction_repository.get_by_action_type( "vlc_play_sound", limit=1, offset=1 ) assert len(transactions) <= 1 # Might be 0 if only 1 VLC transaction in total @pytest.mark.asyncio async def test_get_successful_transactions( self, credit_transaction_repository: CreditTransactionRepository, test_transactions: list[CreditTransaction], ) -> None: """Test getting only successful transactions.""" successful_transactions = await credit_transaction_repository.get_successful_transactions() # Should only return successful transactions assert all(t.success is True for t in successful_transactions) # Should be at least 3 (vlc_play_sound, audio_extraction, credit_addition) assert len(successful_transactions) >= 3 @pytest.mark.asyncio async def test_get_successful_transactions_by_user( self, credit_transaction_repository: CreditTransactionRepository, test_transactions: list[CreditTransaction], other_user_transaction: CreditTransaction, test_user_id: int, ) -> None: """Test getting successful transactions filtered by user.""" successful_transactions = await credit_transaction_repository.get_successful_transactions( user_id=test_user_id ) # Should only return successful transactions for test_user assert all(t.success is True for t in successful_transactions) assert all(t.user_id == test_user_id for t in successful_transactions) # Should be 3 successful transactions for test_user assert len(successful_transactions) == 3 @pytest.mark.asyncio async def test_get_successful_transactions_with_pagination( self, credit_transaction_repository: CreditTransactionRepository, test_transactions: list[CreditTransaction], test_user_id: int, ) -> None: """Test getting successful transactions with pagination.""" # Get first 2 successful transactions first_page = await credit_transaction_repository.get_successful_transactions( user_id=test_user_id, limit=2, offset=0 ) assert len(first_page) == 2 assert all(t.success is True for t in first_page) # Get next successful transaction second_page = await credit_transaction_repository.get_successful_transactions( user_id=test_user_id, limit=2, offset=2 ) assert len(second_page) == 1 # Should be 1 remaining assert all(t.success is True for t in second_page) @pytest.mark.asyncio async def test_get_all_transactions( self, credit_transaction_repository: CreditTransactionRepository, test_transactions: list[CreditTransaction], other_user_transaction: CreditTransaction, ) -> None: """Test getting all transactions.""" all_transactions = await credit_transaction_repository.get_all() # Should return all transactions assert len(all_transactions) >= 5 # 4 from test_transactions + 1 other_user_transaction @pytest.mark.asyncio async def test_create_transaction( self, credit_transaction_repository: CreditTransactionRepository, test_user_id: int, ) -> None: """Test creating a new transaction.""" transaction_data = { "user_id": test_user_id, "action_type": "test_action", "amount": -10, "balance_before": 100, "balance_after": 90, "description": "Test transaction", "success": True, "metadata_json": json.dumps({"test": "data"}), } transaction = await credit_transaction_repository.create(transaction_data) assert transaction.id is not None assert transaction.user_id == test_user_id assert transaction.action_type == "test_action" assert transaction.amount == -10 assert transaction.balance_before == 100 assert transaction.balance_after == 90 assert transaction.success is True assert transaction.metadata_json is not None assert json.loads(transaction.metadata_json) == {"test": "data"} @pytest.mark.asyncio async def test_update_transaction( self, credit_transaction_repository: CreditTransactionRepository, test_transactions: list[CreditTransaction], ) -> None: """Test updating a transaction.""" transaction = test_transactions[0] update_data = { "description": "Updated description", "metadata_json": json.dumps({"updated": True}), } updated_transaction = await credit_transaction_repository.update( transaction, update_data ) assert updated_transaction.id == transaction.id assert updated_transaction.description == "Updated description" assert updated_transaction.metadata_json is not None assert json.loads(updated_transaction.metadata_json) == {"updated": True} # Other fields should remain unchanged assert updated_transaction.amount == transaction.amount assert updated_transaction.action_type == transaction.action_type @pytest.mark.asyncio async def test_delete_transaction( self, credit_transaction_repository: CreditTransactionRepository, test_session: AsyncSession, test_user_id: int, ) -> None: """Test deleting a transaction.""" # Create a transaction to delete transaction_data = { "user_id": test_user_id, "action_type": "to_delete", "amount": -1, "balance_before": 10, "balance_after": 9, "description": "To be deleted", "success": True, "metadata_json": None, } transaction = await credit_transaction_repository.create(transaction_data) transaction_id = transaction.id # Delete the transaction await credit_transaction_repository.delete(transaction) # Verify transaction is deleted assert transaction_id is not None deleted_transaction = await credit_transaction_repository.get_by_id(transaction_id) assert deleted_transaction is None @pytest.mark.asyncio async def test_transaction_ordering( self, credit_transaction_repository: CreditTransactionRepository, test_transactions: list[CreditTransaction], test_user_id: int, ) -> None: """Test that transactions are ordered by created_at desc.""" transactions = await credit_transaction_repository.get_by_user_id(test_user_id) # Should be ordered by created_at desc (newest first) for i in range(len(transactions) - 1): assert transactions[i].created_at >= transactions[i + 1].created_at