Add tests for sound repository, user OAuth repository, credit service, and credit decorators
- Implement comprehensive tests for SoundRepository covering CRUD operations and search functionalities. - Create tests for UserOauthRepository to validate OAuth record management. - Develop tests for CreditService to ensure proper credit management, including validation, deduction, and addition of credits. - Add tests for credit-related decorators to verify correct behavior in credit management scenarios.
This commit is contained in:
@@ -7,9 +7,11 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import get_current_active_user_flexible
|
||||
from app.models.credit_action import CreditActionType
|
||||
from app.models.user import User
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.services.extraction import ExtractionInfo, ExtractionService
|
||||
from app.services.credit import CreditService, InsufficientCreditsError
|
||||
from app.services.extraction_processor import extraction_processor
|
||||
from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService
|
||||
from app.services.sound_scanner import ScanResults, SoundScannerService
|
||||
@@ -45,6 +47,12 @@ def get_vlc_player() -> VLCPlayerService:
|
||||
return get_vlc_player_service(get_session_factory())
|
||||
|
||||
|
||||
def get_credit_service() -> CreditService:
|
||||
"""Get the credit service."""
|
||||
from app.core.database import get_session_factory
|
||||
return CreditService(get_session_factory())
|
||||
|
||||
|
||||
async def get_sound_repository(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> SoundRepository:
|
||||
@@ -373,8 +381,9 @@ async def play_sound_with_vlc(
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
vlc_player: Annotated[VLCPlayerService, Depends(get_vlc_player)],
|
||||
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
|
||||
credit_service: Annotated[CreditService, Depends(get_credit_service)],
|
||||
) -> dict[str, str | int | bool]:
|
||||
"""Play a sound using VLC subprocess."""
|
||||
"""Play a sound using VLC subprocess (requires 1 credit)."""
|
||||
try:
|
||||
# Get the sound
|
||||
sound = await sound_repo.get_by_id(sound_id)
|
||||
@@ -384,9 +393,30 @@ async def play_sound_with_vlc(
|
||||
detail=f"Sound with ID {sound_id} not found",
|
||||
)
|
||||
|
||||
# Check and validate credits before playing
|
||||
try:
|
||||
await credit_service.validate_and_reserve_credits(
|
||||
current_user.id,
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
{"sound_id": sound_id, "sound_name": sound.name}
|
||||
)
|
||||
except InsufficientCreditsError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Insufficient credits: {e.required} required, {e.available} available",
|
||||
) from e
|
||||
|
||||
# Play the sound using VLC
|
||||
success = await vlc_player.play_sound(sound)
|
||||
|
||||
# Deduct credits based on success
|
||||
await credit_service.deduct_credits(
|
||||
current_user.id,
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
success,
|
||||
{"sound_id": sound_id, "sound_name": sound.name},
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
@@ -398,6 +428,7 @@ async def play_sound_with_vlc(
|
||||
"sound_id": sound_id,
|
||||
"sound_name": sound.name,
|
||||
"success": True,
|
||||
"credits_deducted": 1,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
|
||||
121
app/models/credit_action.py
Normal file
121
app/models/credit_action.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Credit action definitions for the credit system."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class CreditActionType(str, Enum):
|
||||
"""Types of actions that consume credits."""
|
||||
|
||||
VLC_PLAY_SOUND = "vlc_play_sound"
|
||||
AUDIO_EXTRACTION = "audio_extraction"
|
||||
TEXT_TO_SPEECH = "text_to_speech"
|
||||
SOUND_NORMALIZATION = "sound_normalization"
|
||||
API_REQUEST = "api_request"
|
||||
PLAYLIST_CREATION = "playlist_creation"
|
||||
|
||||
|
||||
class CreditAction:
|
||||
"""Definition of a credit-consuming action."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_type: CreditActionType,
|
||||
cost: int,
|
||||
description: str,
|
||||
*,
|
||||
requires_success: bool = True,
|
||||
) -> None:
|
||||
"""Initialize a credit action.
|
||||
|
||||
Args:
|
||||
action_type: The type of action
|
||||
cost: Number of credits required
|
||||
description: Human-readable description
|
||||
requires_success: Whether credits are only deducted on successful completion
|
||||
|
||||
"""
|
||||
self.action_type = action_type
|
||||
self.cost = cost
|
||||
self.description = description
|
||||
self.requires_success = requires_success
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string representation of the action."""
|
||||
return f"{self.action_type.value} ({self.cost} credits)"
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for serialization."""
|
||||
return {
|
||||
"action_type": self.action_type.value,
|
||||
"cost": self.cost,
|
||||
"description": self.description,
|
||||
"requires_success": self.requires_success,
|
||||
}
|
||||
|
||||
|
||||
# Predefined credit actions
|
||||
CREDIT_ACTIONS = {
|
||||
CreditActionType.VLC_PLAY_SOUND: CreditAction(
|
||||
action_type=CreditActionType.VLC_PLAY_SOUND,
|
||||
cost=1,
|
||||
description="Play a sound using VLC player",
|
||||
requires_success=True,
|
||||
),
|
||||
CreditActionType.AUDIO_EXTRACTION: CreditAction(
|
||||
action_type=CreditActionType.AUDIO_EXTRACTION,
|
||||
cost=5,
|
||||
description="Extract audio from external URL",
|
||||
requires_success=True,
|
||||
),
|
||||
CreditActionType.TEXT_TO_SPEECH: CreditAction(
|
||||
action_type=CreditActionType.TEXT_TO_SPEECH,
|
||||
cost=2,
|
||||
description="Generate speech from text",
|
||||
requires_success=True,
|
||||
),
|
||||
CreditActionType.SOUND_NORMALIZATION: CreditAction(
|
||||
action_type=CreditActionType.SOUND_NORMALIZATION,
|
||||
cost=1,
|
||||
description="Normalize audio levels",
|
||||
requires_success=True,
|
||||
),
|
||||
CreditActionType.API_REQUEST: CreditAction(
|
||||
action_type=CreditActionType.API_REQUEST,
|
||||
cost=1,
|
||||
description="API request (rate limiting)",
|
||||
requires_success=False, # Charged even if request fails
|
||||
),
|
||||
CreditActionType.PLAYLIST_CREATION: CreditAction(
|
||||
action_type=CreditActionType.PLAYLIST_CREATION,
|
||||
cost=3,
|
||||
description="Create a new playlist",
|
||||
requires_success=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_credit_action(action_type: CreditActionType) -> CreditAction:
|
||||
"""Get a credit action definition by type.
|
||||
|
||||
Args:
|
||||
action_type: The action type to look up
|
||||
|
||||
Returns:
|
||||
The credit action definition
|
||||
|
||||
Raises:
|
||||
KeyError: If action type is not found
|
||||
|
||||
"""
|
||||
return CREDIT_ACTIONS[action_type]
|
||||
|
||||
|
||||
def get_all_credit_actions() -> dict[CreditActionType, CreditAction]:
|
||||
"""Get all available credit actions.
|
||||
|
||||
Returns:
|
||||
Dictionary of all credit actions
|
||||
|
||||
"""
|
||||
return CREDIT_ACTIONS.copy()
|
||||
29
app/models/credit_transaction.py
Normal file
29
app/models/credit_transaction.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Credit transaction model for tracking credit usage."""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class CreditTransaction(BaseModel, table=True):
|
||||
"""Database model for credit transactions."""
|
||||
|
||||
__tablename__ = "credit_transaction" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
user_id: int = Field(foreign_key="user.id", nullable=False)
|
||||
action_type: str = Field(nullable=False)
|
||||
amount: int = Field(nullable=False) # Negative for deductions, positive for additions
|
||||
balance_before: int = Field(nullable=False)
|
||||
balance_after: int = Field(nullable=False)
|
||||
description: str = Field(nullable=False)
|
||||
success: bool = Field(nullable=False, default=True)
|
||||
# JSON string for additional data
|
||||
metadata_json: str | None = Field(default=None)
|
||||
|
||||
# relationships
|
||||
user: "User" = Relationship(back_populates="credit_transactions")
|
||||
@@ -6,6 +6,7 @@ from sqlmodel import Field, Relationship
|
||||
from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.credit_transaction import CreditTransaction
|
||||
from app.models.extraction import Extraction
|
||||
from app.models.plan import Plan
|
||||
from app.models.playlist import Playlist
|
||||
@@ -35,3 +36,4 @@ class User(BaseModel, table=True):
|
||||
playlists: list["Playlist"] = Relationship(back_populates="user")
|
||||
sounds_played: list["SoundPlayed"] = Relationship(back_populates="user")
|
||||
extractions: list["Extraction"] = Relationship(back_populates="user")
|
||||
credit_transactions: list["CreditTransaction"] = Relationship(back_populates="user")
|
||||
|
||||
132
app/repositories/base.py
Normal file
132
app/repositories/base.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Base repository with common CRUD operations."""
|
||||
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
# Type variable for the model
|
||||
ModelType = TypeVar("ModelType")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BaseRepository(Generic[ModelType]):
|
||||
"""Base repository with common CRUD operations."""
|
||||
|
||||
def __init__(self, model: type[ModelType], session: AsyncSession) -> None:
|
||||
"""Initialize the repository.
|
||||
|
||||
Args:
|
||||
model: The SQLModel class
|
||||
session: Database session
|
||||
|
||||
"""
|
||||
self.model = model
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, entity_id: int) -> ModelType | None:
|
||||
"""Get an entity by ID.
|
||||
|
||||
Args:
|
||||
entity_id: The entity ID
|
||||
|
||||
Returns:
|
||||
The entity if found, None otherwise
|
||||
|
||||
"""
|
||||
try:
|
||||
statement = select(self.model).where(getattr(self.model, "id") == entity_id)
|
||||
result = await self.session.exec(statement)
|
||||
return result.first()
|
||||
except Exception:
|
||||
logger.exception("Failed to get %s by ID: %s", self.model.__name__, entity_id)
|
||||
raise
|
||||
|
||||
async def get_all(
|
||||
self,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[ModelType]:
|
||||
"""Get all entities with pagination.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of entities to return
|
||||
offset: Number of entities to skip
|
||||
|
||||
Returns:
|
||||
List of entities
|
||||
|
||||
"""
|
||||
try:
|
||||
statement = select(self.model).limit(limit).offset(offset)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
except Exception:
|
||||
logger.exception("Failed to get all %s", self.model.__name__)
|
||||
raise
|
||||
|
||||
async def create(self, entity_data: dict[str, Any]) -> ModelType:
|
||||
"""Create a new entity.
|
||||
|
||||
Args:
|
||||
entity_data: Dictionary of entity data
|
||||
|
||||
Returns:
|
||||
The created entity
|
||||
|
||||
"""
|
||||
try:
|
||||
entity = self.model(**entity_data)
|
||||
self.session.add(entity)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(entity)
|
||||
logger.info("Created new %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
|
||||
return entity
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to create %s", self.model.__name__)
|
||||
raise
|
||||
|
||||
async def update(self, entity: ModelType, update_data: dict[str, Any]) -> ModelType:
|
||||
"""Update an entity.
|
||||
|
||||
Args:
|
||||
entity: The entity to update
|
||||
update_data: Dictionary of fields to update
|
||||
|
||||
Returns:
|
||||
The updated entity
|
||||
|
||||
"""
|
||||
try:
|
||||
for field, value in update_data.items():
|
||||
setattr(entity, field, value)
|
||||
|
||||
self.session.add(entity)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(entity)
|
||||
logger.info("Updated %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
|
||||
return entity
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to update %s", self.model.__name__)
|
||||
raise
|
||||
|
||||
async def delete(self, entity: ModelType) -> None:
|
||||
"""Delete an entity.
|
||||
|
||||
Args:
|
||||
entity: The entity to delete
|
||||
|
||||
"""
|
||||
try:
|
||||
await self.session.delete(entity)
|
||||
await self.session.commit()
|
||||
logger.info("Deleted %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to delete %s", self.model.__name__)
|
||||
raise
|
||||
108
app/repositories/credit_transaction.py
Normal file
108
app/repositories/credit_transaction.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Repository for credit transaction database operations."""
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.credit_transaction import CreditTransaction
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class CreditTransactionRepository(BaseRepository[CreditTransaction]):
|
||||
"""Repository for credit transaction operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize the repository.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
|
||||
"""
|
||||
super().__init__(CreditTransaction, session)
|
||||
|
||||
async def get_by_user_id(
|
||||
self,
|
||||
user_id: int,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[CreditTransaction]:
|
||||
"""Get credit transactions for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
limit: Maximum number of transactions to return
|
||||
offset: Number of transactions to skip
|
||||
|
||||
Returns:
|
||||
List of credit transactions ordered by creation date (newest first)
|
||||
|
||||
"""
|
||||
stmt = (
|
||||
select(CreditTransaction)
|
||||
.where(CreditTransaction.user_id == user_id)
|
||||
.order_by(CreditTransaction.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
result = await self.session.exec(stmt)
|
||||
return list(result.all())
|
||||
|
||||
async def get_by_action_type(
|
||||
self,
|
||||
action_type: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[CreditTransaction]:
|
||||
"""Get credit transactions by action type.
|
||||
|
||||
Args:
|
||||
action_type: The action type to filter by
|
||||
limit: Maximum number of transactions to return
|
||||
offset: Number of transactions to skip
|
||||
|
||||
Returns:
|
||||
List of credit transactions ordered by creation date (newest first)
|
||||
|
||||
"""
|
||||
stmt = (
|
||||
select(CreditTransaction)
|
||||
.where(CreditTransaction.action_type == action_type)
|
||||
.order_by(CreditTransaction.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
result = await self.session.exec(stmt)
|
||||
return list(result.all())
|
||||
|
||||
async def get_successful_transactions(
|
||||
self,
|
||||
user_id: int | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[CreditTransaction]:
|
||||
"""Get successful credit transactions.
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID to filter by
|
||||
limit: Maximum number of transactions to return
|
||||
offset: Number of transactions to skip
|
||||
|
||||
Returns:
|
||||
List of successful credit transactions
|
||||
|
||||
"""
|
||||
stmt = (
|
||||
select(CreditTransaction)
|
||||
.where(CreditTransaction.success == True) # noqa: E712
|
||||
)
|
||||
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(CreditTransaction.user_id == user_id)
|
||||
|
||||
stmt = (
|
||||
stmt.order_by(CreditTransaction.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
|
||||
result = await self.session.exec(stmt)
|
||||
return list(result.all())
|
||||
383
app/services/credit.py
Normal file
383
app/services/credit.py
Normal file
@@ -0,0 +1,383 @@
|
||||
"""Credit management service for tracking and validating user credit usage."""
|
||||
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.credit_action import CreditAction, CreditActionType, get_credit_action
|
||||
from app.models.credit_transaction import CreditTransaction
|
||||
from app.models.user import User
|
||||
from app.repositories.user import UserRepository
|
||||
from app.services.socket import socket_manager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class InsufficientCreditsError(Exception):
|
||||
"""Raised when user has insufficient credits for an action."""
|
||||
|
||||
def __init__(self, required: int, available: int) -> None:
|
||||
"""Initialize the error.
|
||||
|
||||
Args:
|
||||
required: Number of credits required
|
||||
available: Number of credits available
|
||||
|
||||
"""
|
||||
self.required = required
|
||||
self.available = available
|
||||
super().__init__(
|
||||
f"Insufficient credits: {required} required, {available} available"
|
||||
)
|
||||
|
||||
|
||||
class CreditService:
|
||||
"""Service for managing user credits and transactions."""
|
||||
|
||||
def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
|
||||
"""Initialize the credit service.
|
||||
|
||||
Args:
|
||||
db_session_factory: Factory function to create database sessions
|
||||
|
||||
"""
|
||||
self.db_session_factory = db_session_factory
|
||||
|
||||
async def check_credits(
|
||||
self,
|
||||
user_id: int,
|
||||
action_type: CreditActionType,
|
||||
) -> bool:
|
||||
"""Check if user has sufficient credits for an action.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
action_type: The type of action to check
|
||||
|
||||
Returns:
|
||||
True if user has sufficient credits, False otherwise
|
||||
|
||||
"""
|
||||
action = get_credit_action(action_type)
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
return False
|
||||
return user.credits >= action.cost
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def validate_and_reserve_credits(
|
||||
self,
|
||||
user_id: int,
|
||||
action_type: CreditActionType,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> tuple[User, CreditAction]:
|
||||
"""Validate user has sufficient credits and optionally reserve them.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
action_type: The type of action
|
||||
metadata: Optional metadata to store with transaction
|
||||
|
||||
Returns:
|
||||
Tuple of (user, credit_action)
|
||||
|
||||
Raises:
|
||||
InsufficientCreditsError: If user has insufficient credits
|
||||
|
||||
"""
|
||||
action = get_credit_action(action_type)
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
if user.credits < action.cost:
|
||||
raise InsufficientCreditsError(action.cost, user.credits)
|
||||
|
||||
logger.info(
|
||||
"Credits validated for user %s: %s credits available, %s required",
|
||||
user_id,
|
||||
user.credits,
|
||||
action.cost,
|
||||
)
|
||||
return user, action
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def deduct_credits(
|
||||
self,
|
||||
user_id: int,
|
||||
action_type: CreditActionType,
|
||||
success: bool = True,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> CreditTransaction:
|
||||
"""Deduct credits from user account and record transaction.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
action_type: The type of action
|
||||
success: Whether the action was successful
|
||||
metadata: Optional metadata to store with transaction
|
||||
|
||||
Returns:
|
||||
The created credit transaction
|
||||
|
||||
Raises:
|
||||
InsufficientCreditsError: If user has insufficient credits
|
||||
ValueError: If user not found
|
||||
|
||||
"""
|
||||
action = get_credit_action(action_type)
|
||||
|
||||
# Only deduct if action requires success and was successful, or doesn't require success
|
||||
should_deduct = (action.requires_success and success) or not action.requires_success
|
||||
|
||||
if not should_deduct:
|
||||
logger.info(
|
||||
"Skipping credit deduction for user %s: action %s failed and requires success",
|
||||
user_id,
|
||||
action_type.value,
|
||||
)
|
||||
# Still create a transaction record for auditing
|
||||
return await self._create_transaction_record(
|
||||
user_id, action, 0, success, metadata
|
||||
)
|
||||
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
if user.credits < action.cost:
|
||||
raise InsufficientCreditsError(action.cost, user.credits)
|
||||
|
||||
# Record transaction
|
||||
balance_before = user.credits
|
||||
balance_after = user.credits - action.cost
|
||||
|
||||
transaction = CreditTransaction(
|
||||
user_id=user_id,
|
||||
action_type=action_type.value,
|
||||
amount=-action.cost,
|
||||
balance_before=balance_before,
|
||||
balance_after=balance_after,
|
||||
description=action.description,
|
||||
success=success,
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
)
|
||||
|
||||
# Update user credits
|
||||
await user_repo.update(user, {"credits": balance_after})
|
||||
|
||||
# Save transaction
|
||||
session.add(transaction)
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
"Credits deducted for user %s: %s credits (action: %s, success: %s)",
|
||||
user_id,
|
||||
action.cost,
|
||||
action_type.value,
|
||||
success,
|
||||
)
|
||||
|
||||
# Emit user_credits_changed event via WebSocket
|
||||
try:
|
||||
event_data = {
|
||||
"user_id": str(user_id),
|
||||
"credits_before": balance_before,
|
||||
"credits_after": balance_after,
|
||||
"credits_deducted": action.cost,
|
||||
"action_type": action_type.value,
|
||||
"success": success,
|
||||
}
|
||||
await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
|
||||
logger.info("Emitted user_credits_changed event for user %s", user_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to emit user_credits_changed event for user %s", user_id,
|
||||
)
|
||||
|
||||
return transaction
|
||||
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def add_credits(
|
||||
self,
|
||||
user_id: int,
|
||||
amount: int,
|
||||
description: str,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> CreditTransaction:
|
||||
"""Add credits to user account.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
amount: Number of credits to add
|
||||
description: Description of the credit addition
|
||||
metadata: Optional metadata to store with transaction
|
||||
|
||||
Returns:
|
||||
The created credit transaction
|
||||
|
||||
Raises:
|
||||
ValueError: If user not found or amount is negative
|
||||
|
||||
"""
|
||||
if amount <= 0:
|
||||
msg = "Amount must be positive"
|
||||
raise ValueError(msg)
|
||||
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Record transaction
|
||||
balance_before = user.credits
|
||||
balance_after = user.credits + amount
|
||||
|
||||
transaction = CreditTransaction(
|
||||
user_id=user_id,
|
||||
action_type="credit_addition",
|
||||
amount=amount,
|
||||
balance_before=balance_before,
|
||||
balance_after=balance_after,
|
||||
description=description,
|
||||
success=True,
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
)
|
||||
|
||||
# Update user credits
|
||||
await user_repo.update(user, {"credits": balance_after})
|
||||
|
||||
# Save transaction
|
||||
session.add(transaction)
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
"Credits added for user %s: %s credits (description: %s)",
|
||||
user_id,
|
||||
amount,
|
||||
description,
|
||||
)
|
||||
|
||||
# Emit user_credits_changed event via WebSocket
|
||||
try:
|
||||
event_data = {
|
||||
"user_id": str(user_id),
|
||||
"credits_before": balance_before,
|
||||
"credits_after": balance_after,
|
||||
"credits_added": amount,
|
||||
"description": description,
|
||||
"success": True,
|
||||
}
|
||||
await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
|
||||
logger.info("Emitted user_credits_changed event for user %s", user_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to emit user_credits_changed event for user %s", user_id,
|
||||
)
|
||||
|
||||
return transaction
|
||||
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def _create_transaction_record(
|
||||
self,
|
||||
user_id: int,
|
||||
action: CreditAction,
|
||||
amount: int,
|
||||
success: bool,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> CreditTransaction:
|
||||
"""Create a transaction record without modifying credits.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
action: The credit action
|
||||
amount: Amount to record (typically 0 for failed actions)
|
||||
success: Whether the action was successful
|
||||
metadata: Optional metadata
|
||||
|
||||
Returns:
|
||||
The created transaction record
|
||||
|
||||
"""
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
transaction = CreditTransaction(
|
||||
user_id=user_id,
|
||||
action_type=action.action_type.value,
|
||||
amount=amount,
|
||||
balance_before=user.credits,
|
||||
balance_after=user.credits,
|
||||
description=f"{action.description} (failed)" if not success else action.description,
|
||||
success=success,
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
)
|
||||
|
||||
session.add(transaction)
|
||||
await session.commit()
|
||||
|
||||
return transaction
|
||||
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def get_user_balance(self, user_id: int) -> int:
|
||||
"""Get current credit balance for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
|
||||
Returns:
|
||||
Current credit balance
|
||||
|
||||
Raises:
|
||||
ValueError: If user not found
|
||||
|
||||
"""
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
return user.credits
|
||||
finally:
|
||||
await session.close()
|
||||
192
app/utils/credit_decorators.py
Normal file
192
app/utils/credit_decorators.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Decorators for credit management and validation."""
|
||||
|
||||
import functools
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from app.models.credit_action import CreditActionType
|
||||
from app.services.credit import CreditService, InsufficientCreditsError
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
|
||||
|
||||
|
||||
def requires_credits(
|
||||
action_type: CreditActionType,
|
||||
credit_service_factory: Callable[[], CreditService],
|
||||
user_id_param: str = "user_id",
|
||||
metadata_extractor: Callable[..., dict[str, Any]] | None = None,
|
||||
) -> Callable[[F], F]:
|
||||
"""Decorator to enforce credit requirements for actions.
|
||||
|
||||
Args:
|
||||
action_type: The type of action that requires credits
|
||||
credit_service_factory: Factory to create credit service instance
|
||||
user_id_param: Name of the parameter containing user ID
|
||||
metadata_extractor: Optional function to extract metadata from function args
|
||||
|
||||
Returns:
|
||||
Decorated function that validates and deducts credits
|
||||
|
||||
Example:
|
||||
@requires_credits(
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
lambda: get_credit_service(),
|
||||
user_id_param="user_id"
|
||||
)
|
||||
async def play_sound_for_user(user_id: int, sound: Sound) -> bool:
|
||||
# Implementation here
|
||||
return True
|
||||
|
||||
"""
|
||||
def decorator(func: F) -> F:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
# Extract user ID from parameters
|
||||
user_id = None
|
||||
if user_id_param in kwargs:
|
||||
user_id = kwargs[user_id_param]
|
||||
else:
|
||||
# Try to find user_id in function signature
|
||||
import inspect
|
||||
sig = inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())
|
||||
if user_id_param in param_names:
|
||||
param_index = param_names.index(user_id_param)
|
||||
if param_index < len(args):
|
||||
user_id = args[param_index]
|
||||
|
||||
if user_id is None:
|
||||
msg = f"Could not extract user_id from parameter '{user_id_param}'"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Extract metadata if extractor provided
|
||||
metadata = None
|
||||
if metadata_extractor:
|
||||
metadata = metadata_extractor(*args, **kwargs)
|
||||
|
||||
# Get credit service
|
||||
credit_service = credit_service_factory()
|
||||
|
||||
# Validate credits before execution
|
||||
await credit_service.validate_and_reserve_credits(
|
||||
user_id, action_type, metadata
|
||||
)
|
||||
|
||||
# Execute the function
|
||||
success = False
|
||||
result = None
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
success = bool(result) # Consider function result as success indicator
|
||||
return result
|
||||
except Exception:
|
||||
success = False
|
||||
raise
|
||||
finally:
|
||||
# Deduct credits based on success
|
||||
await credit_service.deduct_credits(
|
||||
user_id, action_type, success, metadata
|
||||
)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
return decorator
|
||||
|
||||
|
||||
def validate_credits_only(
|
||||
action_type: CreditActionType,
|
||||
credit_service_factory: Callable[[], CreditService],
|
||||
user_id_param: str = "user_id",
|
||||
) -> Callable[[F], F]:
|
||||
"""Decorator to only validate credits without deducting them.
|
||||
|
||||
Useful for checking if a user can perform an action before actual execution.
|
||||
|
||||
Args:
|
||||
action_type: The type of action that requires credits
|
||||
credit_service_factory: Factory to create credit service instance
|
||||
user_id_param: Name of the parameter containing user ID
|
||||
|
||||
Returns:
|
||||
Decorated function that validates credits only
|
||||
|
||||
"""
|
||||
def decorator(func: F) -> F:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
# Extract user ID from parameters
|
||||
user_id = None
|
||||
if user_id_param in kwargs:
|
||||
user_id = kwargs[user_id_param]
|
||||
else:
|
||||
# Try to find user_id in function signature
|
||||
import inspect
|
||||
sig = inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())
|
||||
if user_id_param in param_names:
|
||||
param_index = param_names.index(user_id_param)
|
||||
if param_index < len(args):
|
||||
user_id = args[param_index]
|
||||
|
||||
if user_id is None:
|
||||
msg = f"Could not extract user_id from parameter '{user_id_param}'"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Get credit service
|
||||
credit_service = credit_service_factory()
|
||||
|
||||
# Validate credits only
|
||||
await credit_service.validate_and_reserve_credits(user_id, action_type)
|
||||
|
||||
# Execute the function
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
return decorator
|
||||
|
||||
|
||||
class CreditManager:
|
||||
"""Context manager for credit operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
credit_service: CreditService,
|
||||
user_id: int,
|
||||
action_type: CreditActionType,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Initialize credit manager.
|
||||
|
||||
Args:
|
||||
credit_service: Credit service instance
|
||||
user_id: User ID
|
||||
action_type: Action type
|
||||
metadata: Optional metadata
|
||||
|
||||
"""
|
||||
self.credit_service = credit_service
|
||||
self.user_id = user_id
|
||||
self.action_type = action_type
|
||||
self.metadata = metadata
|
||||
self.validated = False
|
||||
self.success = False
|
||||
|
||||
async def __aenter__(self) -> "CreditManager":
|
||||
"""Enter context manager - validate credits."""
|
||||
await self.credit_service.validate_and_reserve_credits(
|
||||
self.user_id, self.action_type, self.metadata
|
||||
)
|
||||
self.validated = True
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: type, exc_val: Exception, exc_tb: Any) -> None:
|
||||
"""Exit context manager - deduct credits based on success."""
|
||||
if self.validated:
|
||||
# If no exception occurred, consider it successful
|
||||
success = exc_type is None and self.success
|
||||
await self.credit_service.deduct_credits(
|
||||
self.user_id, self.action_type, success, self.metadata
|
||||
)
|
||||
|
||||
def mark_success(self) -> None:
|
||||
"""Mark the operation as successful."""
|
||||
self.success = True
|
||||
Reference in New Issue
Block a user