- Implement comprehensive tests for SoundRepository covering CRUD operations and search functionality.
- Create tests for UserOauthRepository to validate OAuth record management.
- Develop tests for CreditService to ensure proper credit management, including validation, deduction, and addition of credits.
- Add tests for credit-related decorators to verify correct behavior in credit management scenarios.
376 lines
13 KiB
Python
376 lines
13 KiB
Python
"""Tests for sound repository."""
|
|
|
|
from collections.abc import AsyncGenerator
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.models.sound import Sound
|
|
from app.repositories.sound import SoundRepository
|
|
|
|
|
|
class TestSoundRepository:
    """Test sound repository operations."""

    @pytest_asyncio.fixture
    async def sound_repository(
        self,
        test_session: AsyncSession,
    ) -> AsyncGenerator[SoundRepository, None]:  # type: ignore[misc]
        """Create a sound repository instance bound to the test session."""
        yield SoundRepository(test_session)

    @pytest_asyncio.fixture
    async def test_sound(
        self,
        test_session: AsyncSession,
    ) -> AsyncGenerator[Sound, None]:  # type: ignore[misc]
        """Create and persist an unnormalized ``SDB`` test sound."""
        sound_data = {
            "name": "Test Sound",
            "filename": "test_sound.mp3",
            "type": "SDB",
            "duration": 5000,
            "size": 1024000,
            "hash": "test_hash_123",
            "play_count": 0,
            "is_normalized": False,
        }
        sound = Sound(**sound_data)
        test_session.add(sound)
        await test_session.commit()
        # Reload attributes after the commit so tests can read them directly.
        await test_session.refresh(sound)
        yield sound

    @pytest_asyncio.fixture
    async def normalized_sound(
        self,
        test_session: AsyncSession,
    ) -> AsyncGenerator[Sound, None]:  # type: ignore[misc]
        """Create and persist a normalized ``TTS`` test sound."""
        sound_data = {
            "name": "Normalized Sound",
            "filename": "normalized_sound.mp3",
            "type": "TTS",
            "duration": 3000,
            "size": 512000,
            "hash": "normalized_hash_456",
            "play_count": 5,
            "is_normalized": True,
            "normalized_filename": "normalized_sound_norm.mp3",
            "normalized_duration": 3000,
            "normalized_size": 480000,
            "normalized_hash": "normalized_hash_norm_456",
        }
        sound = Sound(**sound_data)
        test_session.add(sound)
        await test_session.commit()
        # Reload attributes after the commit so tests can read them directly.
        await test_session.refresh(sound)
        yield sound

    @pytest.mark.asyncio
    async def test_get_by_id_existing(
        self,
        sound_repository: SoundRepository,
        test_sound: Sound,
    ) -> None:
        """Test getting sound by ID when it exists."""
        sound = await sound_repository.get_by_id(test_sound.id)

        assert sound is not None
        assert sound.id == test_sound.id
        assert sound.name == test_sound.name
        assert sound.filename == test_sound.filename
        assert sound.type == test_sound.type

    @pytest.mark.asyncio
    async def test_get_by_id_nonexistent(
        self,
        sound_repository: SoundRepository,
    ) -> None:
        """Test getting sound by ID when it doesn't exist."""
        sound = await sound_repository.get_by_id(99999)

        assert sound is None

    @pytest.mark.asyncio
    async def test_get_by_filename_existing(
        self,
        sound_repository: SoundRepository,
        test_sound: Sound,
    ) -> None:
        """Test getting sound by filename when it exists."""
        sound = await sound_repository.get_by_filename(test_sound.filename)

        assert sound is not None
        assert sound.id == test_sound.id
        assert sound.filename == test_sound.filename

    @pytest.mark.asyncio
    async def test_get_by_filename_nonexistent(
        self,
        sound_repository: SoundRepository,
    ) -> None:
        """Test getting sound by filename when it doesn't exist."""
        sound = await sound_repository.get_by_filename("nonexistent.mp3")

        assert sound is None

    @pytest.mark.asyncio
    async def test_get_by_hash_existing(
        self,
        sound_repository: SoundRepository,
        test_sound: Sound,
    ) -> None:
        """Test getting sound by hash when it exists."""
        sound = await sound_repository.get_by_hash(test_sound.hash)

        assert sound is not None
        assert sound.id == test_sound.id
        assert sound.hash == test_sound.hash

    @pytest.mark.asyncio
    async def test_get_by_hash_nonexistent(
        self,
        sound_repository: SoundRepository,
    ) -> None:
        """Test getting sound by hash when it doesn't exist."""
        sound = await sound_repository.get_by_hash("nonexistent_hash")

        assert sound is None

    @pytest.mark.asyncio
    async def test_get_by_type(
        self,
        sound_repository: SoundRepository,
        test_sound: Sound,
        normalized_sound: Sound,
    ) -> None:
        """Test getting sounds by type."""
        sdb_sounds = await sound_repository.get_by_type("SDB")
        tts_sounds = await sound_repository.get_by_type("TTS")
        ext_sounds = await sound_repository.get_by_type("EXT")

        # Should find the SDB sound
        assert len(sdb_sounds) >= 1
        assert any(sound.id == test_sound.id for sound in sdb_sounds)

        # Should find the TTS sound
        assert len(tts_sounds) >= 1
        assert any(sound.id == normalized_sound.id for sound in tts_sounds)

        # Should not find any EXT sounds.
        # NOTE(review): assumes the test database is isolated per test.
        assert len(ext_sounds) == 0

    @pytest.mark.asyncio
    async def test_create_sound(
        self,
        sound_repository: SoundRepository,
    ) -> None:
        """Test creating a new sound."""
        sound_data = {
            "name": "New Sound",
            "filename": "new_sound.wav",
            "type": "EXT",
            "duration": 7500,
            "size": 2048000,
            "hash": "new_hash_789",
            "play_count": 0,
            "is_normalized": False,
        }

        sound = await sound_repository.create(sound_data)

        assert sound.id is not None
        assert sound.name == sound_data["name"]
        assert sound.filename == sound_data["filename"]
        assert sound.type == sound_data["type"]
        assert sound.duration == sound_data["duration"]
        assert sound.size == sound_data["size"]
        assert sound.hash == sound_data["hash"]
        assert sound.play_count == 0
        assert sound.is_normalized is False

    @pytest.mark.asyncio
    async def test_update_sound(
        self,
        sound_repository: SoundRepository,
        test_sound: Sound,
    ) -> None:
        """Test updating a sound changes only the requested fields."""
        # Capture pre-update values first: update() may mutate and return the
        # very instance it was given, in which case comparing the result
        # against test_sound afterwards would always pass.
        original_id = test_sound.id
        original_filename = test_sound.filename

        update_data = {
            "name": "Updated Sound Name",
            "play_count": 10,
            "is_normalized": True,
            "normalized_filename": "updated_norm.mp3",
        }

        updated_sound = await sound_repository.update(test_sound, update_data)

        assert updated_sound.id == original_id
        assert updated_sound.name == "Updated Sound Name"
        assert updated_sound.play_count == 10
        assert updated_sound.is_normalized is True
        assert updated_sound.normalized_filename == "updated_norm.mp3"
        assert updated_sound.filename == original_filename  # Unchanged

    @pytest.mark.asyncio
    async def test_delete_sound(
        self,
        sound_repository: SoundRepository,
        test_session: AsyncSession,
    ) -> None:
        """Test deleting a sound."""
        # Create a sound to delete
        sound_data = {
            "name": "To Delete",
            "filename": "to_delete.mp3",
            "type": "SDB",
            "duration": 1000,
            "size": 256000,
            "hash": "delete_hash",
            "play_count": 0,
            "is_normalized": False,
        }
        sound = await sound_repository.create(sound_data)
        # Keep the ID before deletion; the instance is not used afterwards.
        sound_id = sound.id

        # Delete the sound
        await sound_repository.delete(sound)

        # Verify sound is deleted
        deleted_sound = await sound_repository.get_by_id(sound_id)
        assert deleted_sound is None

    @pytest.mark.asyncio
    async def test_search_by_name(
        self,
        sound_repository: SoundRepository,
        test_sound: Sound,
        normalized_sound: Sound,
    ) -> None:
        """Test searching sounds by name."""
        # Search for "test" should find test_sound
        results = await sound_repository.search_by_name("test")
        assert len(results) >= 1
        assert any(sound.id == test_sound.id for sound in results)

        # Search for "normalized" should find normalized_sound
        results = await sound_repository.search_by_name("normalized")
        assert len(results) >= 1
        assert any(sound.id == normalized_sound.id for sound in results)

        # Case insensitive search
        results = await sound_repository.search_by_name("TEST")
        assert len(results) >= 1
        assert any(sound.id == test_sound.id for sound in results)

        # Partial match
        results = await sound_repository.search_by_name("norm")
        assert len(results) >= 1
        assert any(sound.id == normalized_sound.id for sound in results)

        # No matches
        results = await sound_repository.search_by_name("nonexistent")
        assert len(results) == 0

    @pytest.mark.asyncio
    async def test_get_popular_sounds(
        self,
        sound_repository: SoundRepository,
        test_sound: Sound,
        normalized_sound: Sound,
    ) -> None:
        """Test getting popular sounds ordered by descending play count."""
        # Update play counts to test ordering
        await sound_repository.update(test_sound, {"play_count": 15})
        await sound_repository.update(normalized_sound, {"play_count": 5})

        # Create another sound with higher play count
        high_play_sound_data = {
            "name": "Popular Sound",
            "filename": "popular.mp3",
            "type": "SDB",
            "duration": 2000,
            "size": 300000,
            "hash": "popular_hash",
            "play_count": 25,
            "is_normalized": False,
        }
        high_play_sound = await sound_repository.create(high_play_sound_data)

        # Get popular sounds
        popular_sounds = await sound_repository.get_popular_sounds(limit=10)

        assert len(popular_sounds) >= 3
        # The entire result must be ordered by play_count desc, not just the
        # first pair.
        play_counts = [sound.play_count for sound in popular_sounds]
        assert play_counts == sorted(play_counts, reverse=True)
        # The highest play count sound should be first
        assert popular_sounds[0].id == high_play_sound.id

    @pytest.mark.asyncio
    async def test_get_unnormalized_sounds(
        self,
        sound_repository: SoundRepository,
        test_sound: Sound,
        normalized_sound: Sound,
    ) -> None:
        """Test getting unnormalized sounds."""
        unnormalized_sounds = await sound_repository.get_unnormalized_sounds()

        # Should include test_sound (not normalized)
        assert any(sound.id == test_sound.id for sound in unnormalized_sounds)
        # Should not include normalized_sound (already normalized)
        assert not any(sound.id == normalized_sound.id for sound in unnormalized_sounds)

    @pytest.mark.asyncio
    async def test_get_unnormalized_sounds_by_type(
        self,
        sound_repository: SoundRepository,
        test_sound: Sound,
        normalized_sound: Sound,
    ) -> None:
        """Test getting unnormalized sounds by type."""
        # Get unnormalized SDB sounds
        sdb_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("SDB")
        # Should include test_sound (SDB, not normalized)
        assert any(sound.id == test_sound.id for sound in sdb_unnormalized)

        # Get unnormalized TTS sounds
        tts_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("TTS")
        # Should not include normalized_sound (TTS, but already normalized)
        assert not any(sound.id == normalized_sound.id for sound in tts_unnormalized)

        # Get unnormalized EXT sounds
        ext_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("EXT")
        # Should be empty.
        # NOTE(review): assumes the test database is isolated per test.
        assert len(ext_unnormalized) == 0

    @pytest.mark.asyncio
    async def test_create_duplicate_hash(
        self,
        sound_repository: SoundRepository,
        test_sound: Sound,
    ) -> None:
        """Test creating sound with duplicate hash is allowed."""
        # Store the hash to avoid lazy loading issues
        original_hash = test_sound.hash

        duplicate_sound_data = {
            "name": "Duplicate Hash Sound",
            "filename": "duplicate.mp3",
            "type": "SDB",
            "duration": 1000,
            "size": 100000,
            "hash": original_hash,  # Same hash as test_sound
            "play_count": 0,
            "is_normalized": False,
        }

        # Should succeed - duplicate hashes are allowed
        duplicate_sound = await sound_repository.create(duplicate_sound_data)

        assert duplicate_sound.id is not None
        assert duplicate_sound.name == "Duplicate Hash Sound"
        assert duplicate_sound.hash == original_hash  # Same hash is allowed