- Implemented tests for ScheduledTaskRepository covering task creation, retrieval, filtering, and status updates.
- Developed tests for SchedulerService, including task creation, cancellation, user task retrieval, and maintenance jobs.
- Created tests for TaskHandlerRegistry to validate task execution for various task types, including credit recharge and sound playback.
- Ensured proper handling of errors and edge cases in task execution scenarios.
- Added fixtures and mocks to support isolated testing of services and repositories.
376 lines
10 KiB
Python
376 lines
10 KiB
Python
"""Test configuration and fixtures."""
|
|
|
|
import asyncio
|
|
from collections.abc import AsyncGenerator
|
|
from typing import Any
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from fastapi import FastAPI
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from httpx import ASGITransport, AsyncClient
|
|
from sqlalchemy.ext.asyncio import create_async_engine
|
|
from sqlmodel import SQLModel, select
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.api import api_router
|
|
from app.core.database import get_db
|
|
from app.middleware.logging import LoggingMiddleware
|
|
|
|
# Import all models to ensure SQLAlchemy relationships are properly resolved
|
|
from app.models.credit_action import CreditAction # noqa: F401
|
|
from app.models.credit_transaction import CreditTransaction # noqa: F401
|
|
from app.models.extraction import Extraction # noqa: F401
|
|
from app.models.favorite import Favorite # noqa: F401
|
|
from app.models.plan import Plan
|
|
from app.models.playlist import Playlist # noqa: F401
|
|
from app.models.playlist_sound import PlaylistSound # noqa: F401
|
|
from app.models.scheduled_task import ScheduledTask # noqa: F401
|
|
from app.models.sound import Sound # noqa: F401
|
|
from app.models.sound_played import SoundPlayed # noqa: F401
|
|
from app.models.user import User
|
|
from app.models.user_oauth import UserOauth # noqa: F401
|
|
from app.utils.auth import JWTUtils, PasswordUtils
|
|
|
|
|
|
@pytest.fixture(scope="session")
def event_loop() -> Any:
    """Provide a single event loop shared by every test in the session.

    Presumably session-scoped so that the session-scoped ``test_engine``
    fixture and the per-test async fixtures all run on the same loop.

    NOTE(review): recent pytest-asyncio releases deprecate overriding this
    fixture in favour of configuring ``loop_scope="session"``; confirm against
    the pinned pytest-asyncio version.

    Yields:
        A fresh event loop, closed when the test session ends.
    """
    # ``asyncio.new_event_loop()`` is equivalent to
    # ``asyncio.get_event_loop_policy().new_event_loop()`` but does not touch
    # the event-loop policy API, which is deprecated as of Python 3.14.
    loop = asyncio.new_event_loop()
    try:
        yield loop
    finally:
        loop.close()
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="session")
async def test_engine() -> Any:
    """Yield an async engine backed by a throwaway in-memory SQLite database.

    Tables for every model registered on ``SQLModel.metadata`` are created
    once, before the engine is handed out; the engine is disposed when the
    test session finishes.
    """
    # NOTE(review): an in-memory SQLite database only lives as long as its
    # connection, so this relies on the aiosqlite dialect reusing one
    # connection for ``:memory:`` URLs -- confirm for the pinned SQLAlchemy.
    in_memory_url = "sqlite+aiosqlite:///:memory:"
    engine = create_async_engine(in_memory_url, echo=False)

    async with engine.begin() as ddl_conn:
        # ``create_all`` is a synchronous API, hence ``run_sync``.
        await ddl_conn.run_sync(SQLModel.metadata.create_all)

    yield engine

    await engine.dispose()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_session(test_engine: Any) -> AsyncGenerator[AsyncSession]:
    """Yield a database session whose work is rolled back after the test.

    The session is bound to a dedicated connection on which an outer
    transaction is started first. During teardown the session is closed and
    that outer transaction is rolled back, then the connection is closed.

    NOTE(review): fixtures in this module call ``session.commit()``. The outer
    rollback still discarding those rows relies on SQLAlchemy 2.x's default
    ``join_transaction_mode`` for sessions bound to a connection that is
    already in a transaction; confirm against the pinned SQLAlchemy version.

    Args:
        test_engine: The session-wide async engine.

    Yields:
        An ``AsyncSession`` bound to the per-test connection.
    """
    # ``async with`` guarantees the connection is closed even if closing the
    # session or rolling back the transaction raises during teardown.
    async with test_engine.connect() as connection:
        transaction = await connection.begin()
        session = AsyncSession(bind=connection)
        try:
            yield session
        finally:
            try:
                await session.close()
            finally:
                # Always attempt the rollback, even if session.close() failed.
                await transaction.rollback()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_app(test_session: AsyncSession) -> FastAPI:
    """Build a bare FastAPI app wired to the per-test database session.

    Socket.IO is intentionally left out: the app gets only the CORS and
    logging middleware plus ``api_router``. The ``get_db`` dependency is
    overridden so request handlers receive ``test_session``.
    """
    application = FastAPI()

    # Registration order kept as-is: CORS first, then logging.
    application.add_middleware(
        CORSMiddleware,
        allow_origins=["http://localhost:8001"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    application.add_middleware(LoggingMiddleware)
    application.include_router(api_router)

    async def _yield_test_session() -> AsyncGenerator[AsyncSession]:
        # Every request shares the fixture's session instead of opening its own.
        yield test_session

    application.dependency_overrides[get_db] = _yield_test_session
    return application
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_client(test_app: FastAPI) -> AsyncGenerator[AsyncClient]:
    """Yield an unauthenticated HTTP client that calls ``test_app`` in-process.

    Requests go through ``ASGITransport`` directly into the app, so no server
    is started; ``http://test`` is only a placeholder base URL.

    Args:
        test_app: The FastAPI application under test.

    Yields:
        An ``AsyncClient`` closed automatically after the test.
    """
    async with AsyncClient(
        transport=ASGITransport(app=test_app),
        base_url="http://test",
    ) as client:
        yield client
|
|
|
|
|
|
@pytest_asyncio.fixture
async def authenticated_client(
    test_app: FastAPI,
    auth_cookies: dict[str, str],
) -> AsyncGenerator[AsyncClient]:
    """Yield an in-process HTTP client preloaded with the regular user's auth cookies."""
    transport = ASGITransport(app=test_app)
    http = AsyncClient(
        transport=transport,
        base_url="http://test",
        cookies=auth_cookies,
    )
    # Entering the client opens it; leaving the block closes it after the test.
    async with http as opened:
        yield opened
|
|
|
|
|
|
@pytest_asyncio.fixture
async def authenticated_admin_client(
    test_app: FastAPI,
    admin_cookies: dict[str, str],
) -> AsyncGenerator[AsyncClient]:
    """Yield an in-process HTTP client preloaded with the admin user's auth cookies."""
    transport = ASGITransport(app=test_app)
    http = AsyncClient(
        transport=transport,
        base_url="http://test",
        cookies=admin_cookies,
    )
    # Entering the client opens it; leaving the block closes it after the test.
    async with http as opened:
        yield opened
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_plan(test_session: AsyncSession) -> Plan:
    """Return the ``free`` plan, inserting it if the session cannot find one.

    A newly inserted plan is committed and then refreshed, so attributes
    assigned by the database are loaded before the plan is returned.
    """
    lookup = await test_session.exec(select(Plan).where(Plan.code == "free"))
    found = lookup.first()
    if found:
        # Reuse the row already visible to this session.
        return found

    created = Plan(
        code="free",
        name="Free Plan",
        description="Test free plan",
        credits=100,
        max_credits=100,
    )
    test_session.add(created)
    await test_session.commit()
    await test_session.refresh(created)
    return created
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_pro_plan(test_session: AsyncSession) -> Plan:
    """Return the ``pro`` plan, inserting it if the session cannot find one.

    A newly inserted plan is committed and then refreshed, so attributes
    assigned by the database are loaded before the plan is returned.
    """
    lookup = await test_session.exec(select(Plan).where(Plan.code == "pro"))
    found = lookup.first()
    if found:
        # Reuse the row already visible to this session.
        return found

    created = Plan(
        code="pro",
        name="Pro Plan",
        description="Test pro plan",
        credits=300,
        max_credits=300,
    )
    test_session.add(created)
    await test_session.commit()
    await test_session.refresh(created)
    return created
|
|
|
|
|
|
@pytest_asyncio.fixture
async def ensure_plans(test_session: AsyncSession) -> tuple[Plan, Plan]:
    """Make sure both the ``free`` and ``pro`` plans exist, then return them.

    Missing plans are staged on the session, a single commit persists them,
    and both plans are refreshed before being returned.

    Returns:
        ``(free_plan, pro_plan)``.
    """

    async def _fetch_or_stage(
        code: str, name: str, description: str, credits: int
    ) -> Plan:
        # Look the plan up by code; when absent, stage a new one for the commit below.
        result = await test_session.exec(select(Plan).where(Plan.code == code))
        plan = result.first()
        if not plan:
            plan = Plan(
                code=code,
                name=name,
                description=description,
                credits=credits,
                max_credits=credits,
            )
            test_session.add(plan)
        return plan

    free_plan = await _fetch_or_stage("free", "Free Plan", "Test free plan", 100)
    pro_plan = await _fetch_or_stage("pro", "Pro Plan", "Test pro plan", 300)

    await test_session.commit()
    for staged in (free_plan, pro_plan):
        await test_session.refresh(staged)
    return free_plan, pro_plan
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_user(
    test_session: AsyncSession,
    ensure_plans: tuple[Plan, Plan],
) -> User:
    """Persist and return an active ``user``-role account on the free plan.

    Credentials: ``test@example.com`` / ``testpassword123``; starts with 100 credits.
    """
    free_plan, _pro_plan = ensure_plans
    account = User(
        email="test@example.com",
        name="Test User",
        password_hash=PasswordUtils.hash_password("testpassword123"),
        role="user",
        is_active=True,
        plan_id=free_plan.id,
        credits=100,
    )
    test_session.add(account)
    await test_session.commit()
    await test_session.refresh(account)
    return account
|
|
|
|
|
|
@pytest_asyncio.fixture
async def admin_user(
    test_session: AsyncSession,
    ensure_plans: tuple[Plan, Plan],
) -> User:
    """Persist and return an active ``admin``-role account on the pro plan.

    Credentials: ``admin@example.com`` / ``adminpassword123``; starts with 1000 credits.
    """
    _free_plan, pro_plan = ensure_plans
    account = User(
        email="admin@example.com",
        name="Admin User",
        password_hash=PasswordUtils.hash_password("adminpassword123"),
        role="admin",
        is_active=True,
        plan_id=pro_plan.id,
        credits=1000,
    )
    test_session.add(account)
    await test_session.commit()
    await test_session.refresh(account)
    return account
|
|
|
|
|
|
@pytest.fixture
def test_user_data() -> dict[str, Any]:
    """Registration payload for a new account, distinct from ``test_user`` and ``admin_user``."""
    return dict(
        email="newuser@example.com",
        password="newpassword123",
        name="New User",
    )
|
|
|
|
|
|
@pytest.fixture
def test_login_data() -> dict[str, str]:
    """Login payload matching the credentials the ``test_user`` fixture is created with."""
    return dict(email="test@example.com", password="testpassword123")
|
|
|
|
|
|
@pytest_asyncio.fixture
async def auth_headers(test_user: User) -> dict[str, str]:
    """Build a Bearer ``Authorization`` header carrying a JWT for ``test_user``.

    The token claims are the user's id (as ``sub``), email and role.
    """
    claims = dict(
        sub=str(test_user.id),
        email=test_user.email,
        role=test_user.role,
    )
    bearer = JWTUtils.create_access_token(claims)
    return {"Authorization": f"Bearer {bearer}"}
|
|
|
|
|
|
@pytest_asyncio.fixture
async def admin_headers(admin_user: User) -> dict[str, str]:
    """Build a Bearer ``Authorization`` header carrying a JWT for ``admin_user``.

    The token claims are the admin's id (as ``sub``), email and role.
    """
    claims = dict(
        sub=str(admin_user.id),
        email=admin_user.email,
        role=admin_user.role,
    )
    bearer = JWTUtils.create_access_token(claims)
    return {"Authorization": f"Bearer {bearer}"}
|
|
|
|
|
|
@pytest.fixture
def client(test_client: AsyncClient) -> AsyncClient:
    """Expose ``test_client`` under the shorter ``client`` name some tests request."""
    aliased: AsyncClient = test_client
    return aliased
|
|
|
|
|
|
@pytest.fixture
def authenticated_user(test_user: User) -> User:
    """Expose ``test_user`` under the ``authenticated_user`` name some tests request."""
    aliased: User = test_user
    return aliased
|
|
|
|
|
|
@pytest_asyncio.fixture
async def auth_cookies(test_user: User) -> dict[str, str]:
    """Build an ``access_token`` cookie mapping holding a JWT for ``test_user``.

    The token claims are the user's id (as ``sub``), email and role.
    """
    claims = dict(
        sub=str(test_user.id),
        email=test_user.email,
        role=test_user.role,
    )
    return dict(access_token=JWTUtils.create_access_token(claims))
|
|
|
|
|
|
@pytest_asyncio.fixture
async def admin_cookies(admin_user: User) -> dict[str, str]:
    """Build an ``access_token`` cookie mapping holding a JWT for ``admin_user``.

    The token claims are the admin's id (as ``sub``), email and role.
    """
    claims = dict(
        sub=str(admin_user.id),
        email=admin_user.email,
        role=admin_user.role,
    )
    return dict(access_token=JWTUtils.create_access_token(claims))
|
|
|
|
|
|
@pytest.fixture
def test_user_id(test_user: User):
    """Return the ``id`` attribute of the ``test_user`` fixture's account."""
    user_id = test_user.id
    return user_id
|
|
|
|
|
|
@pytest.fixture
def test_sound_id():
    """Return a random UUID to stand in for a sound id.

    The value is freshly generated and not backed by any stored ``Sound`` row.
    """
    from uuid import uuid4

    return uuid4()
|
|
|
|
|
|
@pytest.fixture
def test_playlist_id():
    """Return a random UUID to stand in for a playlist id.

    The value is freshly generated and not backed by any stored ``Playlist`` row.
    """
    from uuid import uuid4

    return uuid4()
|
|
|
|
|
|
@pytest.fixture
def db_session(test_session: AsyncSession) -> AsyncSession:
    """Expose ``test_session`` under the ``db_session`` name some tests request."""
    aliased: AsyncSession = test_session
    return aliased
|