255 lines
6.7 KiB
Python
255 lines
6.7 KiB
Python
"""Test configuration and fixtures."""
|
|
|
|
import asyncio
from collections.abc import AsyncGenerator, Iterator
from typing import Any

import pytest
import pytest_asyncio
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlalchemy.pool import StaticPool
from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.database import get_db
from app.main import create_app
from app.models.plan import Plan
from app.models.user import User
from app.utils.auth import JWTUtils, PasswordUtils
|
|
|
|
|
|
@pytest.fixture(scope="session")
def event_loop() -> Iterator[asyncio.AbstractEventLoop]:
    """Provide one event loop shared by the whole test session.

    Session scope lets the session-scoped async ``test_engine`` fixture and the
    per-test async fixtures run on the same loop.

    NOTE(review): newer pytest-asyncio releases deprecate overriding
    ``event_loop`` in favour of ``loop_scope`` settings; confirm against the
    pinned pytest-asyncio version.

    Yields:
        A fresh event loop, closed when the test session ends.
    """
    # Same as asking the current policy for a new loop, but without calling
    # ``asyncio.get_event_loop_policy()``, which Python 3.14 deprecates.
    loop = asyncio.new_event_loop()
    yield loop
    # Finalize async generators still suspended on this loop before closing it,
    # mirroring the cleanup ``asyncio.run`` performs.
    loop.run_until_complete(loop.shutdown_asyncgens())
    loop.close()
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="session")
async def test_engine() -> AsyncGenerator[AsyncEngine, None]:
    """Create the session-wide async engine on an in-memory SQLite database.

    All SQLModel tables are created once, up front. Per-test isolation comes
    from the transaction rollback done by ``test_session``.

    Yields:
        The async engine, disposed when the test session ends.
    """
    # Each connection to ":memory:" opens its own empty database. The tables
    # created below are only visible to later connections if the pool always
    # hands back the same single connection. StaticPool makes that explicit
    # instead of relying on the dialect's default pool choice.
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        echo=False,
        poolclass=StaticPool,
    )

    async with engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)

    yield engine

    await engine.dispose()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_session(test_engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
    """Yield a session whose writes are discarded when the test finishes.

    The session is bound to a connection on which an outer transaction has
    already begun. Rolling that transaction back at teardown undoes everything
    the test, and the fixtures built on this session, wrote.

    NOTE(review): the fixtures below call ``commit()`` and rely on it not ending
    the outer transaction. This is SQLAlchemy's "join an external transaction"
    behaviour; confirm it holds for the installed SQLAlchemy version.
    NOTE(review): with the default ``expire_on_commit=True``, each commit expires
    objects loaded earlier. Reading their attributes afterwards needs implicit
    IO, which an AsyncSession cannot perform, so refresh objects explicitly.
    """
    # The context manager closes the connection even if ``begin()`` or the
    # teardown below raises; closing also rolls back any open transaction.
    async with test_engine.connect() as connection:
        transaction = await connection.begin()
        session = AsyncSession(bind=connection)
        try:
            yield session
        finally:
            await session.close()
            # Guard in case the outer transaction was already ended, e.g. by a
            # session-level rollback during the test.
            if transaction.is_active:
                await transaction.rollback()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_app(test_session: AsyncSession) -> FastAPI:
    """Build a fresh FastAPI app whose ``get_db`` dependency yields ``test_session``.

    Because every request gets the shared test session, writes made through
    HTTP calls stay inside the per-test transaction.
    """
    application = create_app()

    async def _yield_test_session() -> AsyncGenerator[AsyncSession, None]:
        # Hand out the fixture's session instead of opening a new one.
        yield test_session

    application.dependency_overrides[get_db] = _yield_test_session
    return application
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_client(test_app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
    """Yield an HTTPX client that calls ``test_app`` in-process over ASGI.

    Requests are routed straight into the app through ``ASGITransport``, under
    the ``http://test`` base URL; the client is closed after the test.
    """
    transport = ASGITransport(app=test_app)
    async with AsyncClient(transport=transport, base_url="http://test") as http_client:
        yield http_client
|
|
|
|
|
|
# Field values for the two plans the fixtures below create on demand.
_FREE_PLAN_FIELDS: dict[str, Any] = {
    "code": "free",
    "name": "Free Plan",
    "description": "Test free plan",
    "credits": 100,
    "max_credits": 100,
}
_PRO_PLAN_FIELDS: dict[str, Any] = {
    "code": "pro",
    "name": "Pro Plan",
    "description": "Test pro plan",
    "credits": 300,
    "max_credits": 300,
}


async def _find_or_stage_plan(
    session: AsyncSession, fields: dict[str, Any]
) -> tuple[Plan, bool]:
    """Return the plan whose code is ``fields["code"]``, staging a new one if absent.

    A newly built plan is only added to the session. Nothing is committed here;
    callers decide whether and when to commit.

    Returns:
        ``(plan, created)``; ``created`` is True when the plan was newly added.
    """
    result = await session.exec(select(Plan).where(Plan.code == fields["code"]))
    plan = result.first()
    if plan is not None:
        return plan, False
    plan = Plan(**fields)
    session.add(plan)
    return plan, True


@pytest_asyncio.fixture
async def test_plan(test_session: AsyncSession) -> Plan:
    """Return the "free" plan, creating and committing it only if it is missing.

    An existing plan is returned without a commit. This also avoids expiring
    objects other fixtures already loaded into the session (under the default
    ``expire_on_commit``).
    """
    plan, created = await _find_or_stage_plan(test_session, _FREE_PLAN_FIELDS)
    if created:
        await test_session.commit()
        await test_session.refresh(plan)
    return plan


@pytest_asyncio.fixture
async def test_pro_plan(test_session: AsyncSession) -> Plan:
    """Return the "pro" plan, creating and committing it only if it is missing."""
    plan, created = await _find_or_stage_plan(test_session, _PRO_PLAN_FIELDS)
    if created:
        await test_session.commit()
        await test_session.refresh(plan)
    return plan


@pytest_asyncio.fixture
async def ensure_plans(test_session: AsyncSession) -> tuple[Plan, Plan]:
    """Make sure both the "free" and "pro" plans exist, and return them.

    Unlike ``test_plan``/``test_pro_plan``, this always commits and refreshes
    both plans, even when neither had to be created.

    Returns:
        ``(free_plan, pro_plan)``.
    """
    free_plan, _ = await _find_or_stage_plan(test_session, _FREE_PLAN_FIELDS)
    pro_plan, _ = await _find_or_stage_plan(test_session, _PRO_PLAN_FIELDS)
    await test_session.commit()
    await test_session.refresh(free_plan)
    await test_session.refresh(pro_plan)
    return free_plan, pro_plan
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_user(
    test_session: AsyncSession, ensure_plans: tuple[Plan, Plan]
) -> User:
    """Persist and return an active regular user on the free plan.

    The credentials (test@example.com / testpassword123) are the ones
    ``test_login_data`` supplies.
    """
    free_plan = ensure_plans[0]
    hashed_password = PasswordUtils.hash_password("testpassword123")
    user = User(
        email="test@example.com",
        name="Test User",
        password_hash=hashed_password,
        role="user",
        is_active=True,
        plan_id=free_plan.id,
        credits=100,
    )

    test_session.add(user)
    await test_session.commit()
    # Reload server-assigned state (such as the id) after the commit.
    await test_session.refresh(user)
    return user
|
|
|
|
|
|
@pytest_asyncio.fixture
async def admin_user(
    test_session: AsyncSession, ensure_plans: tuple[Plan, Plan]
) -> User:
    """Persist and return an active admin user on the pro plan.

    Login: admin@example.com / adminpassword123. Starts with 1000 credits.
    """
    pro_plan = ensure_plans[1]
    hashed_password = PasswordUtils.hash_password("adminpassword123")
    admin = User(
        email="admin@example.com",
        name="Admin User",
        password_hash=hashed_password,
        role="admin",
        is_active=True,
        plan_id=pro_plan.id,
        credits=1000,
    )

    test_session.add(admin)
    await test_session.commit()
    # Reload server-assigned state (such as the id) after the commit.
    await test_session.refresh(admin)
    return admin
|
|
|
|
|
|
@pytest.fixture
def test_user_data() -> dict[str, str]:
    """Registration payload for a user that no other fixture creates.

    Every value is a string, so the return type is ``dict[str, str]``, matching
    ``test_login_data``.
    """
    return {
        "email": "newuser@example.com",
        "password": "newpassword123",
        "name": "New User",
    }
|
|
|
|
|
|
@pytest.fixture
def test_login_data() -> dict[str, str]:
    """Login credentials that match the user created by ``test_user``."""
    credentials = dict(email="test@example.com", password="testpassword123")
    return credentials
|
|
|
|
|
|
@pytest_asyncio.fixture
async def auth_headers(test_user: User) -> dict[str, str]:
    """Return an ``Authorization: Bearer <token>`` header for ``test_user``.

    The token claims are the user's id (stringified, as ``sub``), email and role.
    """
    claims = dict(
        sub=str(test_user.id),
        email=test_user.email,
        role=test_user.role,
    )
    token = JWTUtils.create_access_token(claims)
    return {"Authorization": f"Bearer {token}"}
|
|
|
|
|
|
@pytest_asyncio.fixture
async def admin_headers(admin_user: User) -> dict[str, str]:
    """Return an ``Authorization: Bearer <token>`` header for ``admin_user``.

    The token claims are the admin's id (stringified, as ``sub``), email and role.
    """
    claims = dict(
        sub=str(admin_user.id),
        email=admin_user.email,
        role=admin_user.role,
    )
    token = JWTUtils.create_access_token(claims)
    return {"Authorization": f"Bearer {token}"}
|