# File-viewer metadata: 625 lines, 24 KiB, Python.
"""Tests for credit service."""
|
|
|
|
import json
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.models.credit_action import CreditActionType
|
|
from app.models.credit_transaction import CreditTransaction
|
|
from app.models.plan import Plan
|
|
from app.models.user import User
|
|
from app.services.credit import CreditService, InsufficientCreditsError
|
|
|
|
|
|
class TestCreditService:
    """Test credit service functionality.

    Every test runs against one shared ``AsyncMock`` session handed out by
    ``mock_db_session_factory``. Where a test needs them, ``UserRepository`` and
    ``socket_manager`` inside ``app.services.credit`` are patched through the
    ``mock_user_repo`` / ``mock_socket_manager`` fixtures, so no real database
    or socket is touched.
    """

    @pytest.fixture
    def mock_db_session_factory(self):
        """Create a mock database session factory.

        The factory returns the *same* mock session on every call, so tests can
        inspect (via ``mock_session``) exactly what the service did with it.
        """
        session = AsyncMock(spec=AsyncSession)
        return lambda: session

    @pytest.fixture
    def mock_session(self, mock_db_session_factory):
        """Return the shared mock session that the service will receive."""
        # Function-scoped fixtures are cached per test, so this is the same
        # factory (and therefore the same session) that ``credit_service`` uses.
        return mock_db_session_factory()

    @pytest.fixture
    def credit_service(self, mock_db_session_factory):
        """Create a credit service instance for testing."""
        return CreditService(mock_db_session_factory)

    @pytest.fixture
    def mock_user_repo(self):
        """Patch ``UserRepository`` in the credit service module.

        Yields:
            The ``AsyncMock`` returned whenever the patched class is
            instantiated; tests configure ``get_by_id`` / ``get_all_with_plan``
            on it and assert on ``update``.
        """
        with patch("app.services.credit.UserRepository") as mock_repo_class:
            mock_repo = AsyncMock()
            mock_repo_class.return_value = mock_repo
            yield mock_repo

    @pytest.fixture
    def mock_socket_manager(self):
        """Patch ``socket_manager`` in the credit service module.

        Yields:
            The patched manager, whose ``send_to_user`` is an ``AsyncMock``
            (presumably awaited by the service) so emitted events can be
            asserted on.
        """
        with patch("app.services.credit.socket_manager") as mock_manager:
            mock_manager.send_to_user = AsyncMock()
            yield mock_manager

    @pytest.fixture
    def sample_user(self):
        """Create a sample user for testing."""
        return User(
            id=1,
            name="Test User",
            email="test@example.com",
            role="user",
            credits=10,
            plan_id=1,
        )

    @pytest.fixture
    def poor_user(self):
        """Create a user with no credits left."""
        return User(
            id=1,
            name="Poor User",
            email="poor@example.com",
            role="user",
            credits=0,  # No credits
            plan_id=1,
        )

    @pytest.fixture
    def sample_plan(self):
        """Create a sample plan for testing."""
        return Plan(
            id=1,
            code="basic",
            name="Basic Plan",
            description="Basic plan with limited credits",
            credits=50,
            max_credits=100,
        )

    @pytest.mark.asyncio
    async def test_check_credits_sufficient(
        self, credit_service, mock_session, mock_user_repo, sample_user
    ) -> None:
        """Test checking credits when user has sufficient credits."""
        mock_user_repo.get_by_id.return_value = sample_user

        result = await credit_service.check_credits(
            1,
            CreditActionType.VLC_PLAY_SOUND,
        )

        assert result is True
        mock_user_repo.get_by_id.assert_called_once_with(1)
        mock_session.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_check_credits_insufficient(
        self, credit_service, mock_session, mock_user_repo, poor_user
    ) -> None:
        """Test checking credits when user has insufficient credits."""
        mock_user_repo.get_by_id.return_value = poor_user

        result = await credit_service.check_credits(
            1,
            CreditActionType.VLC_PLAY_SOUND,
        )

        assert result is False
        mock_session.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_check_credits_user_not_found(
        self, credit_service, mock_session, mock_user_repo
    ) -> None:
        """Test checking credits when user is not found."""
        mock_user_repo.get_by_id.return_value = None

        result = await credit_service.check_credits(
            999,
            CreditActionType.VLC_PLAY_SOUND,
        )

        assert result is False
        mock_session.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_validate_and_reserve_credits_success(
        self,
        credit_service,
        mock_session,
        mock_user_repo,
        sample_user,
    ) -> None:
        """Test successful credit validation and reservation."""
        mock_user_repo.get_by_id.return_value = sample_user

        user, action = await credit_service.validate_and_reserve_credits(
            1,
            CreditActionType.VLC_PLAY_SOUND,
        )

        assert user == sample_user
        assert action.action_type == CreditActionType.VLC_PLAY_SOUND
        assert action.cost == 1
        mock_session.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_validate_and_reserve_credits_insufficient(
        self,
        credit_service,
        mock_session,
        mock_user_repo,
        poor_user,
    ) -> None:
        """Test credit validation with insufficient credits."""
        mock_user_repo.get_by_id.return_value = poor_user

        with pytest.raises(InsufficientCreditsError) as exc_info:
            await credit_service.validate_and_reserve_credits(
                1,
                CreditActionType.VLC_PLAY_SOUND,
            )

        assert exc_info.value.required == 1
        assert exc_info.value.available == 0
        mock_session.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_validate_and_reserve_credits_user_not_found(
        self,
        credit_service,
        mock_session,
        mock_user_repo,
    ) -> None:
        """Test credit validation when user is not found."""
        mock_user_repo.get_by_id.return_value = None

        with pytest.raises(ValueError, match="User 999 not found"):
            await credit_service.validate_and_reserve_credits(
                999,
                CreditActionType.VLC_PLAY_SOUND,
            )

        mock_session.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_deduct_credits_success(
        self,
        credit_service,
        mock_session,
        mock_user_repo,
        mock_socket_manager,
        sample_user,
    ) -> None:
        """Test successful credit deduction."""
        mock_user_repo.get_by_id.return_value = sample_user

        await credit_service.deduct_credits(
            1,
            CreditActionType.VLC_PLAY_SOUND,
            success=True,
            metadata={"test": "data"},
        )

        # Verify user credits were updated
        mock_user_repo.update.assert_called_once_with(sample_user, {"credits": 9})

        # Verify transaction was created and the session released
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()
        mock_session.close.assert_called_once()

        # Verify socket event was emitted
        mock_socket_manager.send_to_user.assert_called_once_with(
            "1",
            "user_credits_changed",
            {
                "user_id": "1",
                "credits_before": 10,
                "credits_after": 9,
                "credits_deducted": 1,
                "action_type": "vlc_play_sound",
                "success": True,
            },
        )

        # Check transaction details
        added_transaction = mock_session.add.call_args[0][0]
        assert isinstance(added_transaction, CreditTransaction)
        assert added_transaction.user_id == 1
        assert added_transaction.action_type == "vlc_play_sound"
        assert added_transaction.amount == -1
        assert added_transaction.balance_before == 10
        assert added_transaction.balance_after == 9
        assert added_transaction.success is True
        assert json.loads(added_transaction.metadata_json) == {"test": "data"}

    @pytest.mark.asyncio
    async def test_deduct_credits_failed_action_requires_success(
        self,
        credit_service,
        mock_session,
        mock_user_repo,
        mock_socket_manager,
        sample_user,
    ) -> None:
        """Test credit deduction when action failed but requires success."""
        mock_user_repo.get_by_id.return_value = sample_user

        await credit_service.deduct_credits(
            1,
            CreditActionType.VLC_PLAY_SOUND,
            success=False,  # Action failed
        )

        # Verify user credits were NOT updated (action requires success)
        mock_user_repo.update.assert_not_called()

        # Verify transaction was still created for auditing
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()
        mock_session.close.assert_called_once()

        # Verify no socket event was emitted since no credits were actually deducted
        mock_socket_manager.send_to_user.assert_not_called()

        # Check transaction details
        added_transaction = mock_session.add.call_args[0][0]
        assert added_transaction.amount == 0  # No deduction for failed action
        assert added_transaction.balance_before == 10
        assert added_transaction.balance_after == 10  # No change
        assert added_transaction.success is False

    @pytest.mark.asyncio
    async def test_deduct_credits_insufficient(
        self,
        credit_service,
        mock_session,
        mock_user_repo,
        mock_socket_manager,
        poor_user,
    ) -> None:
        """Test credit deduction with insufficient credits."""
        mock_user_repo.get_by_id.return_value = poor_user

        with pytest.raises(InsufficientCreditsError):
            await credit_service.deduct_credits(
                1,
                CreditActionType.VLC_PLAY_SOUND,
                success=True,
            )

        # Verify no socket event was emitted since credits could not be deducted
        mock_socket_manager.send_to_user.assert_not_called()

        mock_session.rollback.assert_called_once()
        mock_session.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_add_credits(
        self,
        credit_service,
        mock_session,
        mock_user_repo,
        mock_socket_manager,
        sample_user,
    ) -> None:
        """Test adding credits to user account."""
        mock_user_repo.get_by_id.return_value = sample_user

        await credit_service.add_credits(
            1,
            5,
            "Bonus credits",
            {"reason": "signup"},
        )

        # Verify user credits were updated
        mock_user_repo.update.assert_called_once_with(sample_user, {"credits": 15})

        # Verify transaction was created and the session released
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()
        mock_session.close.assert_called_once()

        # Verify socket event was emitted
        mock_socket_manager.send_to_user.assert_called_once_with(
            "1",
            "user_credits_changed",
            {
                "user_id": "1",
                "credits_before": 10,
                "credits_after": 15,
                "credits_added": 5,
                "description": "Bonus credits",
                "success": True,
            },
        )

        # Check transaction details
        added_transaction = mock_session.add.call_args[0][0]
        assert added_transaction.amount == 5
        assert added_transaction.balance_before == 10
        assert added_transaction.balance_after == 15
        assert added_transaction.description == "Bonus credits"

    @pytest.mark.asyncio
    async def test_add_credits_invalid_amount(self, credit_service) -> None:
        """Test adding invalid amount of credits."""
        with pytest.raises(ValueError, match="Amount must be positive"):
            await credit_service.add_credits(1, 0, "Invalid")

        with pytest.raises(ValueError, match="Amount must be positive"):
            await credit_service.add_credits(1, -5, "Invalid")

    @pytest.mark.asyncio
    async def test_get_user_balance(
        self, credit_service, mock_session, mock_user_repo, sample_user
    ) -> None:
        """Test getting user credit balance."""
        mock_user_repo.get_by_id.return_value = sample_user

        balance = await credit_service.get_user_balance(1)

        assert balance == 10
        mock_session.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_user_balance_user_not_found(
        self, credit_service, mock_session, mock_user_repo
    ) -> None:
        """Test getting balance for non-existent user."""
        mock_user_repo.get_by_id.return_value = None

        with pytest.raises(ValueError, match="User 999 not found"):
            await credit_service.get_user_balance(999)

        mock_session.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_recharge_user_credits_success(
        self,
        credit_service,
        mock_session,
        mock_user_repo,
        mock_socket_manager,
        sample_user,
    ) -> None:
        """Test successful credit recharge for a user."""
        mock_user_repo.get_by_id.return_value = sample_user

        # Test recharging 20 credits with max of 100
        transaction = await credit_service.recharge_user_credits(1, 20, 100)

        # Verify user credits were updated (10 + 20 = 30)
        mock_user_repo.update.assert_called_once_with(sample_user, {"credits": 30})

        # Verify transaction was created and the session released
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()
        mock_session.close.assert_called_once()

        # Verify socket event was emitted
        mock_socket_manager.send_to_user.assert_called_once_with(
            "1",
            "user_credits_changed",
            {
                "user_id": "1",
                "credits_before": 10,
                "credits_after": 30,
                "credits_added": 20,
                "description": "Daily credit recharge",
                "success": True,
            },
        )

        # Check transaction details
        assert transaction is not None
        added_transaction = mock_session.add.call_args[0][0]
        assert isinstance(added_transaction, CreditTransaction)
        assert added_transaction.user_id == 1
        assert (
            added_transaction.action_type == CreditActionType.DAILY_RECHARGE.value
        )
        assert added_transaction.amount == 20
        assert added_transaction.balance_before == 10
        assert added_transaction.balance_after == 30
        assert added_transaction.success is True
        assert json.loads(added_transaction.metadata_json) == {
            "plan_credits": 20,
            "max_credits": 100,
        }

    @pytest.mark.asyncio
    async def test_recharge_user_credits_at_max(
        self,
        credit_service,
        mock_session,
        mock_user_repo,
        mock_socket_manager,
    ) -> None:
        """Test credit recharge when user is already at max credits."""
        max_user = User(
            id=1,
            name="Max User",
            email="max@example.com",
            role="user",
            credits=100,  # Already at max
            plan_id=1,
        )
        mock_user_repo.get_by_id.return_value = max_user

        # Test recharging when already at max (100)
        transaction = await credit_service.recharge_user_credits(1, 50, 100)

        # Verify no credits were added
        assert transaction is None
        mock_user_repo.update.assert_not_called()
        mock_session.add.assert_not_called()
        mock_session.commit.assert_not_called()
        mock_session.close.assert_called_once()

        # socket_manager is patched so a stray event can neither reach the real
        # manager nor go unnoticed: nothing changed, so nothing may be emitted.
        mock_socket_manager.send_to_user.assert_not_called()

    @pytest.mark.asyncio
    async def test_recharge_user_credits_partial_recharge(
        self,
        credit_service,
        mock_session,
        mock_user_repo,
        mock_socket_manager,
    ) -> None:
        """Test credit recharge when it would exceed max credits."""
        high_user = User(
            id=1,
            name="High User",
            email="high@example.com",
            role="user",
            credits=90,  # Close to max
            plan_id=1,
        )
        mock_user_repo.get_by_id.return_value = high_user

        # Test recharging 50 credits when max is 100 (should only add 10)
        transaction = await credit_service.recharge_user_credits(1, 50, 100)

        # Verify only 10 credits were added (90 + 10 = 100)
        mock_user_repo.update.assert_called_once_with(high_user, {"credits": 100})

        # The capped recharge is still persisted and announced like a full one.
        mock_session.commit.assert_called_once()
        mock_session.close.assert_called_once()
        mock_socket_manager.send_to_user.assert_called_once()

        # Check transaction details
        assert transaction is not None
        added_transaction = mock_session.add.call_args[0][0]
        assert added_transaction.amount == 10  # Only 10 credits added
        assert added_transaction.balance_before == 90
        assert added_transaction.balance_after == 100

    @pytest.mark.asyncio
    async def test_recharge_user_credits_user_not_found(
        self, credit_service, mock_session, mock_user_repo
    ) -> None:
        """Test credit recharge when user is not found."""
        mock_user_repo.get_by_id.return_value = None

        with pytest.raises(ValueError, match="User 999 not found"):
            await credit_service.recharge_user_credits(999, 50, 100)

        mock_session.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_recharge_all_users_credits(
        self, credit_service, mock_session, mock_user_repo
    ) -> None:
        """Test recharging credits for all users."""
        # Create sample users with plans
        user1 = User(id=1, name="User1", email="u1@example.com", credits=10, plan_id=1)
        user1.plan = Plan(id=1, code="basic", name="Basic", credits=20, max_credits=50)

        user2 = User(id=2, name="User2", email="u2@example.com", credits=45, plan_id=2)
        user2.plan = Plan(id=2, code="pro", name="Pro", credits=30, max_credits=100)

        user3 = User(id=3, name="User3", email="u3@example.com", credits=100, plan_id=2)
        user3.plan = Plan(id=2, code="pro", name="Pro", credits=30, max_credits=100)

        # Mock get_all_with_plan to return users in batches
        mock_user_repo.get_all_with_plan.side_effect = [
            [user1, user2, user3],  # First batch
            [],  # Empty batch to end loop
        ]

        # Mock recharge_user_credits to simulate individual recharges
        with patch.object(credit_service, "recharge_user_credits") as mock_recharge:
            # Mock return values: transaction for user1 and user2, None for user3
            mock_transaction1 = CreditTransaction(
                user_id=1,
                amount=20,
                action_type=CreditActionType.DAILY_RECHARGE.value,
                balance_before=10,
                balance_after=30,
                description="Daily credit recharge",
                success=True,
            )
            mock_transaction2 = CreditTransaction(
                user_id=2,
                amount=30,
                action_type=CreditActionType.DAILY_RECHARGE.value,
                balance_before=45,
                balance_after=75,
                description="Daily credit recharge",
                success=True,
            )
            mock_recharge.side_effect = [
                mock_transaction1,  # User 1 recharged
                mock_transaction2,  # User 2 recharged
                None,  # User 3 at max, no recharge
            ]

            stats = await credit_service.recharge_all_users_credits()

            # Verify stats
            assert stats == {
                "total_users": 3,
                "recharged_users": 2,
                "skipped_users": 1,
                "total_credits_added": 50,  # 20 + 30
            }

            # Verify recharge was called for each user
            assert mock_recharge.call_count == 3
            mock_recharge.assert_any_call(1, 20, 50)
            mock_recharge.assert_any_call(2, 30, 100)
            mock_recharge.assert_any_call(3, 30, 100)

        mock_session.close.assert_called_once()


class TestInsufficientCreditsError:
    """Unit tests for the ``InsufficientCreditsError`` exception type."""

    def test_insufficient_credits_error_creation(self) -> None:
        """Constructor arguments are exposed as attributes and in the message."""
        err = InsufficientCreditsError(5, 2)
        expected_message = "Insufficient credits: 5 required, 2 available"

        assert (err.required, err.available) == (5, 2)
        assert str(err) == expected_message