Add tests for sound repository, user OAuth repository, credit service, and credit decorators

- Implement comprehensive tests for SoundRepository covering CRUD operations and search functionalities.
- Create tests for UserOauthRepository to validate OAuth record management.
- Develop tests for CreditService to ensure proper credit management, including validation, deduction, and addition of credits.
- Add tests for credit-related decorators to verify correct behavior in credit management scenarios.
This commit is contained in:
JSC
2025-07-30 21:33:55 +02:00
parent dd10ef5d41
commit e43650c26c
14 changed files with 2692 additions and 1 deletions

View File

@@ -0,0 +1,412 @@
"""Tests for credit transaction repository."""
import json
from collections.abc import AsyncGenerator
import pytest
import pytest_asyncio
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.credit_transaction import CreditTransaction
from app.models.user import User
from app.repositories.credit_transaction import CreditTransactionRepository
class TestCreditTransactionRepository:
"""Test credit transaction repository operations."""
@pytest_asyncio.fixture
async def credit_transaction_repository(
self,
test_session: AsyncSession,
) -> AsyncGenerator[CreditTransactionRepository, None]: # type: ignore[misc]
"""Create a credit transaction repository instance."""
yield CreditTransactionRepository(test_session)
@pytest_asyncio.fixture
async def test_user_id(
self,
test_user: User,
) -> int:
"""Get test user ID to avoid lazy loading issues."""
return test_user.id
@pytest_asyncio.fixture
async def test_transactions(
self,
test_session: AsyncSession,
test_user_id: int,
) -> AsyncGenerator[list[CreditTransaction], None]: # type: ignore[misc]
"""Create test credit transactions."""
transactions = []
user_id = test_user_id
# Create various types of transactions
transaction_data = [
{
"user_id": user_id,
"action_type": "vlc_play_sound",
"amount": -1,
"balance_before": 10,
"balance_after": 9,
"description": "Play sound via VLC",
"success": True,
"metadata_json": json.dumps({"sound_id": 1, "sound_name": "test.mp3"}),
},
{
"user_id": user_id,
"action_type": "audio_extraction",
"amount": -5,
"balance_before": 9,
"balance_after": 4,
"description": "Extract audio from URL",
"success": True,
"metadata_json": json.dumps({"url": "https://example.com/video"}),
},
{
"user_id": user_id,
"action_type": "vlc_play_sound",
"amount": 0,
"balance_before": 4,
"balance_after": 4,
"description": "Play sound via VLC (failed)",
"success": False,
"metadata_json": json.dumps({"sound_id": 2, "error": "File not found"}),
},
{
"user_id": user_id,
"action_type": "credit_addition",
"amount": 50,
"balance_before": 4,
"balance_after": 54,
"description": "Bonus credits",
"success": True,
"metadata_json": json.dumps({"reason": "signup_bonus"}),
},
]
for data in transaction_data:
transaction = CreditTransaction(**data)
test_session.add(transaction)
transactions.append(transaction)
await test_session.commit()
for transaction in transactions:
await test_session.refresh(transaction)
yield transactions
@pytest_asyncio.fixture
async def other_user_transaction(
self,
test_session: AsyncSession,
ensure_plans: tuple, # noqa: ARG002
) -> AsyncGenerator[CreditTransaction, None]: # type: ignore[misc]
"""Create a transaction for a different user."""
from app.models.plan import Plan
from app.repositories.user import UserRepository
# Create another user
user_repo = UserRepository(test_session)
other_user_data = {
"email": "other@example.com",
"name": "Other User",
"password_hash": "hashed_password",
"role": "user",
"is_active": True,
}
other_user = await user_repo.create(other_user_data)
# Create transaction for the other user
transaction_data = {
"user_id": other_user.id,
"action_type": "vlc_play_sound",
"amount": -1,
"balance_before": 100,
"balance_after": 99,
"description": "Other user play sound",
"success": True,
"metadata_json": None,
}
transaction = CreditTransaction(**transaction_data)
test_session.add(transaction)
await test_session.commit()
await test_session.refresh(transaction)
yield transaction
@pytest.mark.asyncio
async def test_get_by_id_existing(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
) -> None:
"""Test getting transaction by ID when it exists."""
transaction = await credit_transaction_repository.get_by_id(test_transactions[0].id)
assert transaction is not None
assert transaction.id == test_transactions[0].id
assert transaction.action_type == "vlc_play_sound"
assert transaction.amount == -1
@pytest.mark.asyncio
async def test_get_by_id_nonexistent(
self,
credit_transaction_repository: CreditTransactionRepository,
) -> None:
"""Test getting transaction by ID when it doesn't exist."""
transaction = await credit_transaction_repository.get_by_id(99999)
assert transaction is None
@pytest.mark.asyncio
async def test_get_by_user_id(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
other_user_transaction: CreditTransaction,
test_user_id: int,
) -> None:
"""Test getting transactions by user ID."""
transactions = await credit_transaction_repository.get_by_user_id(test_user_id)
# Should return all transactions for test_user
assert len(transactions) == 4
# Should be ordered by created_at desc (newest first)
assert all(t.user_id == test_user_id for t in transactions)
# Should not include other user's transaction
other_user_ids = [t.user_id for t in transactions]
assert other_user_transaction.user_id not in other_user_ids
@pytest.mark.asyncio
async def test_get_by_user_id_with_pagination(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
test_user_id: int,
) -> None:
"""Test getting transactions by user ID with pagination."""
# Get first 2 transactions
first_page = await credit_transaction_repository.get_by_user_id(
test_user_id, limit=2, offset=0
)
assert len(first_page) == 2
# Get next 2 transactions
second_page = await credit_transaction_repository.get_by_user_id(
test_user_id, limit=2, offset=2
)
assert len(second_page) == 2
# Should not overlap
first_page_ids = {t.id for t in first_page}
second_page_ids = {t.id for t in second_page}
assert first_page_ids.isdisjoint(second_page_ids)
@pytest.mark.asyncio
async def test_get_by_action_type(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
) -> None:
"""Test getting transactions by action type."""
vlc_transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound"
)
# Should return 2 VLC transactions (1 successful, 1 failed)
assert len(vlc_transactions) >= 2
assert all(t.action_type == "vlc_play_sound" for t in vlc_transactions)
extraction_transactions = await credit_transaction_repository.get_by_action_type(
"audio_extraction"
)
# Should return 1 extraction transaction
assert len(extraction_transactions) >= 1
assert all(t.action_type == "audio_extraction" for t in extraction_transactions)
@pytest.mark.asyncio
async def test_get_by_action_type_with_pagination(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
) -> None:
"""Test getting transactions by action type with pagination."""
# Test with limit
transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound", limit=1
)
assert len(transactions) == 1
assert transactions[0].action_type == "vlc_play_sound"
# Test with offset
transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound", limit=1, offset=1
)
assert len(transactions) <= 1 # Might be 0 if only 1 VLC transaction in total
@pytest.mark.asyncio
async def test_get_successful_transactions(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
) -> None:
"""Test getting only successful transactions."""
successful_transactions = await credit_transaction_repository.get_successful_transactions()
# Should only return successful transactions
assert all(t.success is True for t in successful_transactions)
# Should be at least 3 (vlc_play_sound, audio_extraction, credit_addition)
assert len(successful_transactions) >= 3
@pytest.mark.asyncio
async def test_get_successful_transactions_by_user(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
other_user_transaction: CreditTransaction,
test_user_id: int,
) -> None:
"""Test getting successful transactions filtered by user."""
successful_transactions = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id
)
# Should only return successful transactions for test_user
assert all(t.success is True for t in successful_transactions)
assert all(t.user_id == test_user_id for t in successful_transactions)
# Should be 3 successful transactions for test_user
assert len(successful_transactions) == 3
@pytest.mark.asyncio
async def test_get_successful_transactions_with_pagination(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
test_user_id: int,
) -> None:
"""Test getting successful transactions with pagination."""
# Get first 2 successful transactions
first_page = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id, limit=2, offset=0
)
assert len(first_page) == 2
assert all(t.success is True for t in first_page)
# Get next successful transaction
second_page = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id, limit=2, offset=2
)
assert len(second_page) == 1 # Should be 1 remaining
assert all(t.success is True for t in second_page)
@pytest.mark.asyncio
async def test_get_all_transactions(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
other_user_transaction: CreditTransaction,
) -> None:
"""Test getting all transactions."""
all_transactions = await credit_transaction_repository.get_all()
# Should return all transactions
assert len(all_transactions) >= 5 # 4 from test_transactions + 1 other_user_transaction
@pytest.mark.asyncio
async def test_create_transaction(
self,
credit_transaction_repository: CreditTransactionRepository,
test_user_id: int,
) -> None:
"""Test creating a new transaction."""
transaction_data = {
"user_id": test_user_id,
"action_type": "test_action",
"amount": -10,
"balance_before": 100,
"balance_after": 90,
"description": "Test transaction",
"success": True,
"metadata_json": json.dumps({"test": "data"}),
}
transaction = await credit_transaction_repository.create(transaction_data)
assert transaction.id is not None
assert transaction.user_id == test_user_id
assert transaction.action_type == "test_action"
assert transaction.amount == -10
assert transaction.balance_before == 100
assert transaction.balance_after == 90
assert transaction.success is True
assert json.loads(transaction.metadata_json) == {"test": "data"}
@pytest.mark.asyncio
async def test_update_transaction(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
) -> None:
"""Test updating a transaction."""
transaction = test_transactions[0]
update_data = {
"description": "Updated description",
"metadata_json": json.dumps({"updated": True}),
}
updated_transaction = await credit_transaction_repository.update(
transaction, update_data
)
assert updated_transaction.id == transaction.id
assert updated_transaction.description == "Updated description"
assert json.loads(updated_transaction.metadata_json) == {"updated": True}
# Other fields should remain unchanged
assert updated_transaction.amount == transaction.amount
assert updated_transaction.action_type == transaction.action_type
@pytest.mark.asyncio
async def test_delete_transaction(
self,
credit_transaction_repository: CreditTransactionRepository,
test_session: AsyncSession,
test_user_id: int,
) -> None:
"""Test deleting a transaction."""
# Create a transaction to delete
transaction_data = {
"user_id": test_user_id,
"action_type": "to_delete",
"amount": -1,
"balance_before": 10,
"balance_after": 9,
"description": "To be deleted",
"success": True,
"metadata_json": None,
}
transaction = await credit_transaction_repository.create(transaction_data)
transaction_id = transaction.id
# Delete the transaction
await credit_transaction_repository.delete(transaction)
# Verify transaction is deleted
deleted_transaction = await credit_transaction_repository.get_by_id(transaction_id)
assert deleted_transaction is None
@pytest.mark.asyncio
async def test_transaction_ordering(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
test_user_id: int,
) -> None:
"""Test that transactions are ordered by created_at desc."""
transactions = await credit_transaction_repository.get_by_user_id(test_user_id)
# Should be ordered by created_at desc (newest first)
for i in range(len(transactions) - 1):
assert transactions[i].created_at >= transactions[i + 1].created_at

View File

@@ -0,0 +1,376 @@
"""Tests for sound repository."""
from collections.abc import AsyncGenerator
import pytest
import pytest_asyncio
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.sound import Sound
from app.repositories.sound import SoundRepository
class TestSoundRepository:
"""Test sound repository operations."""
@pytest_asyncio.fixture
async def sound_repository(
self,
test_session: AsyncSession,
) -> AsyncGenerator[SoundRepository, None]: # type: ignore[misc]
"""Create a sound repository instance."""
yield SoundRepository(test_session)
@pytest_asyncio.fixture
async def test_sound(
self,
test_session: AsyncSession,
) -> AsyncGenerator[Sound, None]: # type: ignore[misc]
"""Create a test sound."""
sound_data = {
"name": "Test Sound",
"filename": "test_sound.mp3",
"type": "SDB",
"duration": 5000,
"size": 1024000,
"hash": "test_hash_123",
"play_count": 0,
"is_normalized": False,
}
sound = Sound(**sound_data)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(sound)
yield sound
@pytest_asyncio.fixture
async def normalized_sound(
self,
test_session: AsyncSession,
) -> AsyncGenerator[Sound, None]: # type: ignore[misc]
"""Create a normalized test sound."""
sound_data = {
"name": "Normalized Sound",
"filename": "normalized_sound.mp3",
"type": "TTS",
"duration": 3000,
"size": 512000,
"hash": "normalized_hash_456",
"play_count": 5,
"is_normalized": True,
"normalized_filename": "normalized_sound_norm.mp3",
"normalized_duration": 3000,
"normalized_size": 480000,
"normalized_hash": "normalized_hash_norm_456",
}
sound = Sound(**sound_data)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(sound)
yield sound
@pytest.mark.asyncio
async def test_get_by_id_existing(
self,
sound_repository: SoundRepository,
test_sound: Sound,
) -> None:
"""Test getting sound by ID when it exists."""
sound = await sound_repository.get_by_id(test_sound.id)
assert sound is not None
assert sound.id == test_sound.id
assert sound.name == test_sound.name
assert sound.filename == test_sound.filename
assert sound.type == test_sound.type
@pytest.mark.asyncio
async def test_get_by_id_nonexistent(
self,
sound_repository: SoundRepository,
) -> None:
"""Test getting sound by ID when it doesn't exist."""
sound = await sound_repository.get_by_id(99999)
assert sound is None
@pytest.mark.asyncio
async def test_get_by_filename_existing(
self,
sound_repository: SoundRepository,
test_sound: Sound,
) -> None:
"""Test getting sound by filename when it exists."""
sound = await sound_repository.get_by_filename(test_sound.filename)
assert sound is not None
assert sound.id == test_sound.id
assert sound.filename == test_sound.filename
@pytest.mark.asyncio
async def test_get_by_filename_nonexistent(
self,
sound_repository: SoundRepository,
) -> None:
"""Test getting sound by filename when it doesn't exist."""
sound = await sound_repository.get_by_filename("nonexistent.mp3")
assert sound is None
@pytest.mark.asyncio
async def test_get_by_hash_existing(
self,
sound_repository: SoundRepository,
test_sound: Sound,
) -> None:
"""Test getting sound by hash when it exists."""
sound = await sound_repository.get_by_hash(test_sound.hash)
assert sound is not None
assert sound.id == test_sound.id
assert sound.hash == test_sound.hash
@pytest.mark.asyncio
async def test_get_by_hash_nonexistent(
self,
sound_repository: SoundRepository,
) -> None:
"""Test getting sound by hash when it doesn't exist."""
sound = await sound_repository.get_by_hash("nonexistent_hash")
assert sound is None
@pytest.mark.asyncio
async def test_get_by_type(
self,
sound_repository: SoundRepository,
test_sound: Sound,
normalized_sound: Sound,
) -> None:
"""Test getting sounds by type."""
sdb_sounds = await sound_repository.get_by_type("SDB")
tts_sounds = await sound_repository.get_by_type("TTS")
ext_sounds = await sound_repository.get_by_type("EXT")
# Should find the SDB sound
assert len(sdb_sounds) >= 1
assert any(sound.id == test_sound.id for sound in sdb_sounds)
# Should find the TTS sound
assert len(tts_sounds) >= 1
assert any(sound.id == normalized_sound.id for sound in tts_sounds)
# Should not find any EXT sounds
assert len(ext_sounds) == 0
@pytest.mark.asyncio
async def test_create_sound(
self,
sound_repository: SoundRepository,
) -> None:
"""Test creating a new sound."""
sound_data = {
"name": "New Sound",
"filename": "new_sound.wav",
"type": "EXT",
"duration": 7500,
"size": 2048000,
"hash": "new_hash_789",
"play_count": 0,
"is_normalized": False,
}
sound = await sound_repository.create(sound_data)
assert sound.id is not None
assert sound.name == sound_data["name"]
assert sound.filename == sound_data["filename"]
assert sound.type == sound_data["type"]
assert sound.duration == sound_data["duration"]
assert sound.size == sound_data["size"]
assert sound.hash == sound_data["hash"]
assert sound.play_count == 0
assert sound.is_normalized is False
@pytest.mark.asyncio
async def test_update_sound(
self,
sound_repository: SoundRepository,
test_sound: Sound,
) -> None:
"""Test updating a sound."""
update_data = {
"name": "Updated Sound Name",
"play_count": 10,
"is_normalized": True,
"normalized_filename": "updated_norm.mp3",
}
updated_sound = await sound_repository.update(test_sound, update_data)
assert updated_sound.id == test_sound.id
assert updated_sound.name == "Updated Sound Name"
assert updated_sound.play_count == 10
assert updated_sound.is_normalized is True
assert updated_sound.normalized_filename == "updated_norm.mp3"
assert updated_sound.filename == test_sound.filename # Unchanged
@pytest.mark.asyncio
async def test_delete_sound(
self,
sound_repository: SoundRepository,
test_session: AsyncSession,
) -> None:
"""Test deleting a sound."""
# Create a sound to delete
sound_data = {
"name": "To Delete",
"filename": "to_delete.mp3",
"type": "SDB",
"duration": 1000,
"size": 256000,
"hash": "delete_hash",
"play_count": 0,
"is_normalized": False,
}
sound = await sound_repository.create(sound_data)
sound_id = sound.id
# Delete the sound
await sound_repository.delete(sound)
# Verify sound is deleted
deleted_sound = await sound_repository.get_by_id(sound_id)
assert deleted_sound is None
@pytest.mark.asyncio
async def test_search_by_name(
self,
sound_repository: SoundRepository,
test_sound: Sound,
normalized_sound: Sound,
) -> None:
"""Test searching sounds by name."""
# Search for "test" should find test_sound
results = await sound_repository.search_by_name("test")
assert len(results) >= 1
assert any(sound.id == test_sound.id for sound in results)
# Search for "normalized" should find normalized_sound
results = await sound_repository.search_by_name("normalized")
assert len(results) >= 1
assert any(sound.id == normalized_sound.id for sound in results)
# Case insensitive search
results = await sound_repository.search_by_name("TEST")
assert len(results) >= 1
assert any(sound.id == test_sound.id for sound in results)
# Partial match
results = await sound_repository.search_by_name("norm")
assert len(results) >= 1
assert any(sound.id == normalized_sound.id for sound in results)
# No matches
results = await sound_repository.search_by_name("nonexistent")
assert len(results) == 0
@pytest.mark.asyncio
async def test_get_popular_sounds(
self,
sound_repository: SoundRepository,
test_sound: Sound,
normalized_sound: Sound,
) -> None:
"""Test getting popular sounds."""
# Update play counts to test ordering
await sound_repository.update(test_sound, {"play_count": 15})
await sound_repository.update(normalized_sound, {"play_count": 5})
# Create another sound with higher play count
high_play_sound_data = {
"name": "Popular Sound",
"filename": "popular.mp3",
"type": "SDB",
"duration": 2000,
"size": 300000,
"hash": "popular_hash",
"play_count": 25,
"is_normalized": False,
}
high_play_sound = await sound_repository.create(high_play_sound_data)
# Get popular sounds
popular_sounds = await sound_repository.get_popular_sounds(limit=10)
assert len(popular_sounds) >= 3
# Should be ordered by play_count desc
assert popular_sounds[0].play_count >= popular_sounds[1].play_count
# The highest play count sound should be first
assert popular_sounds[0].id == high_play_sound.id
@pytest.mark.asyncio
async def test_get_unnormalized_sounds(
self,
sound_repository: SoundRepository,
test_sound: Sound,
normalized_sound: Sound,
) -> None:
"""Test getting unnormalized sounds."""
unnormalized_sounds = await sound_repository.get_unnormalized_sounds()
# Should include test_sound (not normalized)
assert any(sound.id == test_sound.id for sound in unnormalized_sounds)
# Should not include normalized_sound (already normalized)
assert not any(sound.id == normalized_sound.id for sound in unnormalized_sounds)
@pytest.mark.asyncio
async def test_get_unnormalized_sounds_by_type(
self,
sound_repository: SoundRepository,
test_sound: Sound,
normalized_sound: Sound,
) -> None:
"""Test getting unnormalized sounds by type."""
# Get unnormalized SDB sounds
sdb_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("SDB")
# Should include test_sound (SDB, not normalized)
assert any(sound.id == test_sound.id for sound in sdb_unnormalized)
# Get unnormalized TTS sounds
tts_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("TTS")
# Should not include normalized_sound (TTS, but already normalized)
assert not any(sound.id == normalized_sound.id for sound in tts_unnormalized)
# Get unnormalized EXT sounds
ext_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("EXT")
# Should be empty
assert len(ext_unnormalized) == 0
@pytest.mark.asyncio
async def test_create_duplicate_hash(
self,
sound_repository: SoundRepository,
test_sound: Sound,
) -> None:
"""Test creating sound with duplicate hash is allowed."""
# Store the hash to avoid lazy loading issues
original_hash = test_sound.hash
duplicate_sound_data = {
"name": "Duplicate Hash Sound",
"filename": "duplicate.mp3",
"type": "SDB",
"duration": 1000,
"size": 100000,
"hash": original_hash, # Same hash as test_sound
"play_count": 0,
"is_normalized": False,
}
# Should succeed - duplicate hashes are allowed
duplicate_sound = await sound_repository.create(duplicate_sound_data)
assert duplicate_sound.id is not None
assert duplicate_sound.name == "Duplicate Hash Sound"
assert duplicate_sound.hash == original_hash # Same hash is allowed

View File

@@ -0,0 +1,268 @@
"""Tests for user OAuth repository."""
from collections.abc import AsyncGenerator
import pytest
import pytest_asyncio
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.user import User
from app.models.user_oauth import UserOauth
from app.repositories.user_oauth import UserOauthRepository
class TestUserOauthRepository:
"""Test user OAuth repository operations."""
@pytest_asyncio.fixture
async def user_oauth_repository(
self,
test_session: AsyncSession,
) -> AsyncGenerator[UserOauthRepository, None]: # type: ignore[misc]
"""Create a user OAuth repository instance."""
yield UserOauthRepository(test_session)
@pytest_asyncio.fixture
async def test_user_id(
self,
test_user: User,
) -> int:
"""Get test user ID to avoid lazy loading issues."""
return test_user.id
@pytest_asyncio.fixture
async def test_oauth(
self,
test_session: AsyncSession,
test_user_id: int,
) -> AsyncGenerator[UserOauth, None]: # type: ignore[misc]
"""Create a test OAuth record."""
oauth_data = {
"user_id": test_user_id,
"provider": "google",
"provider_user_id": "google_123456",
"email": "test@gmail.com",
"name": "Test User Google",
"picture": None,
}
oauth = UserOauth(**oauth_data)
test_session.add(oauth)
await test_session.commit()
await test_session.refresh(oauth)
yield oauth
@pytest.mark.asyncio
async def test_get_by_provider_user_id_existing(
self,
user_oauth_repository: UserOauthRepository,
test_oauth: UserOauth,
) -> None:
"""Test getting OAuth by provider user ID when it exists."""
oauth = await user_oauth_repository.get_by_provider_user_id(
"google", "google_123456"
)
assert oauth is not None
assert oauth.id == test_oauth.id
assert oauth.provider == "google"
assert oauth.provider_user_id == "google_123456"
assert oauth.user_id == test_oauth.user_id
@pytest.mark.asyncio
async def test_get_by_provider_user_id_nonexistent(
self,
user_oauth_repository: UserOauthRepository,
) -> None:
"""Test getting OAuth by provider user ID when it doesn't exist."""
oauth = await user_oauth_repository.get_by_provider_user_id(
"google", "nonexistent_id"
)
assert oauth is None
@pytest.mark.asyncio
async def test_get_by_user_id_and_provider_existing(
self,
user_oauth_repository: UserOauthRepository,
test_oauth: UserOauth,
test_user_id: int,
) -> None:
"""Test getting OAuth by user ID and provider when it exists."""
oauth = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "google"
)
assert oauth is not None
assert oauth.id == test_oauth.id
assert oauth.provider == "google"
assert oauth.user_id == test_user_id
@pytest.mark.asyncio
async def test_get_by_user_id_and_provider_nonexistent(
self,
user_oauth_repository: UserOauthRepository,
test_user_id: int,
) -> None:
"""Test getting OAuth by user ID and provider when it doesn't exist."""
oauth = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "github"
)
assert oauth is None
@pytest.mark.asyncio
async def test_create_oauth(
self,
user_oauth_repository: UserOauthRepository,
test_user_id: int,
) -> None:
"""Test creating a new OAuth record."""
oauth_data = {
"user_id": test_user_id,
"provider": "github",
"provider_user_id": "github_789",
"email": "test@github.com",
"name": "Test User GitHub",
"picture": None,
}
oauth = await user_oauth_repository.create(oauth_data)
assert oauth.id is not None
assert oauth.user_id == test_user_id
assert oauth.provider == "github"
assert oauth.provider_user_id == "github_789"
assert oauth.email == "test@github.com"
assert oauth.name == "Test User GitHub"
@pytest.mark.asyncio
async def test_update_oauth(
self,
user_oauth_repository: UserOauthRepository,
test_oauth: UserOauth,
) -> None:
"""Test updating an OAuth record."""
update_data = {
"email": "updated@gmail.com",
"name": "Updated User Name",
"picture": "https://example.com/photo.jpg",
}
updated_oauth = await user_oauth_repository.update(test_oauth, update_data)
assert updated_oauth.id == test_oauth.id
assert updated_oauth.email == "updated@gmail.com"
assert updated_oauth.name == "Updated User Name"
assert updated_oauth.picture == "https://example.com/photo.jpg"
assert updated_oauth.provider == test_oauth.provider # Unchanged
assert updated_oauth.provider_user_id == test_oauth.provider_user_id # Unchanged
@pytest.mark.asyncio
async def test_delete_oauth(
self,
user_oauth_repository: UserOauthRepository,
test_session: AsyncSession,
test_user_id: int,
) -> None:
"""Test deleting an OAuth record."""
# Create an OAuth record to delete
oauth_data = {
"user_id": test_user_id,
"provider": "twitter",
"provider_user_id": "twitter_456",
"email": "test@twitter.com",
"name": "Test User Twitter",
"picture": None,
}
oauth = await user_oauth_repository.create(oauth_data)
oauth_id = oauth.id
# Delete the OAuth record
await user_oauth_repository.delete(oauth)
# Verify it's deleted by trying to find it
deleted_oauth = await user_oauth_repository.get_by_provider_user_id(
"twitter", "twitter_456"
)
assert deleted_oauth is None
@pytest.mark.asyncio
async def test_create_duplicate_provider_user_id(
self,
user_oauth_repository: UserOauthRepository,
test_oauth: UserOauth,
test_user_id: int,
) -> None:
"""Test creating OAuth with duplicate provider user ID should fail."""
# Try to create another OAuth with the same provider and provider_user_id
duplicate_oauth_data = {
"user_id": test_user_id,
"provider": "google",
"provider_user_id": "google_123456", # Same as test_oauth
"email": "another@gmail.com",
"name": "Another User",
"picture": None,
}
# This should fail due to unique constraint
with pytest.raises(Exception): # SQLAlchemy IntegrityError or similar
await user_oauth_repository.create(duplicate_oauth_data)
@pytest.mark.asyncio
async def test_multiple_providers_same_user(
self,
user_oauth_repository: UserOauthRepository,
test_user_id: int,
) -> None:
"""Test that a user can have multiple OAuth providers."""
# Create Google OAuth
google_oauth_data = {
"user_id": test_user_id,
"provider": "google",
"provider_user_id": "google_user_1",
"email": "user@gmail.com",
"name": "Test User Google",
"picture": None,
}
google_oauth = await user_oauth_repository.create(google_oauth_data)
# Create GitHub OAuth for the same user
github_oauth_data = {
"user_id": test_user_id,
"provider": "github",
"provider_user_id": "github_user_1",
"email": "user@github.com",
"name": "Test User GitHub",
"picture": None,
}
github_oauth = await user_oauth_repository.create(github_oauth_data)
# Verify both exist by querying back from database
found_google = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "google"
)
found_github = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "github"
)
assert found_google is not None
assert found_github is not None
assert found_google.provider == "google"
assert found_github.provider == "github"
assert found_google.user_id == test_user_id
assert found_github.user_id == test_user_id
assert found_google.provider_user_id == "google_user_1"
assert found_github.provider_user_id == "github_user_1"
# Verify we can also find them by provider_user_id
found_google_by_provider = await user_oauth_repository.get_by_provider_user_id(
"google", "google_user_1"
)
found_github_by_provider = await user_oauth_repository.get_by_provider_user_id(
"github", "github_user_1"
)
assert found_google_by_provider is not None
assert found_github_by_provider is not None
assert found_google_by_provider.user_id == test_user_id
assert found_github_by_provider.user_id == test_user_id