Refactor test files for improved readability and consistency

- Removed unnecessary blank lines and adjusted formatting in test files.
- Ensured consistent use of commas in function calls and assertions across various test cases.
- Updated import statements for better organization and clarity.
- Enhanced mock setups in tests for better isolation and reliability.
- Improved assertions to follow a consistent style for better readability.
This commit is contained in:
JSC
2025-07-31 21:37:04 +02:00
parent e69098d633
commit 8847131f24
42 changed files with 602 additions and 616 deletions

View File

@@ -42,7 +42,7 @@ class TestCreditTransactionRepository:
"""Create test credit transactions."""
transactions = []
user_id = test_user_id
# Create various types of transactions
transaction_data = [
{
@@ -105,9 +105,8 @@ class TestCreditTransactionRepository:
ensure_plans: tuple[Any, ...], # noqa: ARG002
) -> AsyncGenerator[CreditTransaction, None]:
"""Create a transaction for a different user."""
from app.models.plan import Plan
from app.repositories.user import UserRepository
# Create another user
user_repo = UserRepository(test_session)
other_user_data = {
@@ -134,7 +133,7 @@ class TestCreditTransactionRepository:
test_session.add(transaction)
await test_session.commit()
await test_session.refresh(transaction)
yield transaction
@pytest.mark.asyncio
@@ -178,7 +177,7 @@ class TestCreditTransactionRepository:
assert len(transactions) == 4
# Should be ordered by created_at desc (newest first)
assert all(t.user_id == test_user_id for t in transactions)
# Should not include other user's transaction
other_user_ids = [t.user_id for t in transactions]
assert other_user_transaction.user_id not in other_user_ids
@@ -193,13 +192,13 @@ class TestCreditTransactionRepository:
"""Test getting transactions by user ID with pagination."""
# Get first 2 transactions
first_page = await credit_transaction_repository.get_by_user_id(
test_user_id, limit=2, offset=0
test_user_id, limit=2, offset=0,
)
assert len(first_page) == 2
# Get next 2 transactions
second_page = await credit_transaction_repository.get_by_user_id(
test_user_id, limit=2, offset=2
test_user_id, limit=2, offset=2,
)
assert len(second_page) == 2
@@ -216,17 +215,17 @@ class TestCreditTransactionRepository:
) -> None:
"""Test getting transactions by action type."""
vlc_transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound"
"vlc_play_sound",
)
# Should return 2 VLC transactions (1 successful, 1 failed)
assert len(vlc_transactions) >= 2
assert all(t.action_type == "vlc_play_sound" for t in vlc_transactions)
extraction_transactions = await credit_transaction_repository.get_by_action_type(
"audio_extraction"
"audio_extraction",
)
# Should return 1 extraction transaction
assert len(extraction_transactions) >= 1
assert all(t.action_type == "audio_extraction" for t in extraction_transactions)
@@ -240,14 +239,14 @@ class TestCreditTransactionRepository:
"""Test getting transactions by action type with pagination."""
# Test with limit
transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound", limit=1
"vlc_play_sound", limit=1,
)
assert len(transactions) == 1
assert transactions[0].action_type == "vlc_play_sound"
# Test with offset
transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound", limit=1, offset=1
"vlc_play_sound", limit=1, offset=1,
)
assert len(transactions) <= 1 # Might be 0 if only 1 VLC transaction in total
@@ -275,7 +274,7 @@ class TestCreditTransactionRepository:
) -> None:
"""Test getting successful transactions filtered by user."""
successful_transactions = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id
user_id=test_user_id,
)
# Should only return successful transactions for test_user
@@ -294,14 +293,14 @@ class TestCreditTransactionRepository:
"""Test getting successful transactions with pagination."""
# Get first 2 successful transactions
first_page = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id, limit=2, offset=0
user_id=test_user_id, limit=2, offset=0,
)
assert len(first_page) == 2
assert all(t.success is True for t in first_page)
# Get next successful transaction
second_page = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id, limit=2, offset=2
user_id=test_user_id, limit=2, offset=2,
)
assert len(second_page) == 1 # Should be 1 remaining
assert all(t.success is True for t in second_page)
@@ -363,7 +362,7 @@ class TestCreditTransactionRepository:
}
updated_transaction = await credit_transaction_repository.update(
transaction, update_data
transaction, update_data,
)
assert updated_transaction.id == transaction.id
@@ -413,7 +412,7 @@ class TestCreditTransactionRepository:
) -> None:
"""Test that transactions are ordered by created_at desc."""
transactions = await credit_transaction_repository.get_by_user_id(test_user_id)
# Should be ordered by created_at desc (newest first)
for i in range(len(transactions) - 1):
assert transactions[i].created_at >= transactions[i + 1].created_at
assert transactions[i].created_at >= transactions[i + 1].created_at

View File

@@ -52,7 +52,7 @@ class TestExtractionRepository:
assert result.service_id == extraction_data["service_id"]
assert result.title == extraction_data["title"]
assert result.status == extraction_data["status"]
# Verify session methods were called
extraction_repo.session.add.assert_called_once()
extraction_repo.session.commit.assert_called_once()

View File

@@ -151,10 +151,10 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
# Extract user ID immediately after refresh
user_id = user.id
# Create test playlist for this user
playlist = Playlist(
user_id=user_id,
@@ -167,10 +167,10 @@ class TestPlaylistRepository:
)
test_session.add(playlist)
await test_session.commit()
# Test the repository method
playlists = await playlist_repository.get_by_user_id(user_id)
# Should only return user's playlists, not the main playlist (user_id=None)
assert len(playlists) == 1
assert playlists[0].name == "Test Playlist"
@@ -194,13 +194,13 @@ class TestPlaylistRepository:
test_session.add(main_playlist)
await test_session.commit()
await test_session.refresh(main_playlist)
# Extract ID before async call
main_playlist_id = main_playlist.id
# Test the repository method
playlist = await playlist_repository.get_main_playlist()
assert playlist is not None
assert playlist.id == main_playlist_id
assert playlist.is_main is True
@@ -227,13 +227,13 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
# Extract user ID immediately after refresh
user_id = user.id
# Test the repository method - should return None when no current playlist
playlist = await playlist_repository.get_current_playlist(user_id)
# Should return None since no user playlist is marked as current
assert playlist is None
@@ -319,10 +319,10 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
# Extract user ID immediately after refresh
user_id = user.id
# Create test playlist
test_playlist = Playlist(
user_id=user_id,
@@ -334,7 +334,7 @@ class TestPlaylistRepository:
is_deletable=True,
)
test_session.add(test_playlist)
# Create main playlist
main_playlist = Playlist(
user_id=None,
@@ -346,7 +346,7 @@ class TestPlaylistRepository:
)
test_session.add(main_playlist)
await test_session.commit()
# Search for all playlists (no user filter)
all_results = await playlist_repository.search_by_name("playlist")
assert len(all_results) >= 2 # Should include both user and main playlists
@@ -382,7 +382,7 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
# Create test playlist
playlist = Playlist(
user_id=user.id,
@@ -394,7 +394,7 @@ class TestPlaylistRepository:
is_deletable=True,
)
test_session.add(playlist)
# Create test sound
sound = Sound(
name="Test Sound",
@@ -409,14 +409,14 @@ class TestPlaylistRepository:
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async call
playlist_id = playlist.id
sound_id = sound.id
# Test the repository method
playlist_sound = await playlist_repository.add_sound_to_playlist(
playlist_id, sound_id
playlist_id, sound_id,
)
assert playlist_sound.playlist_id == playlist_id
@@ -445,10 +445,10 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
# Extract user ID immediately after refresh
user_id = user.id
# Create test playlist
playlist = Playlist(
user_id=user_id,
@@ -460,7 +460,7 @@ class TestPlaylistRepository:
is_deletable=True,
)
test_session.add(playlist)
# Create test sound
sound = Sound(
name="Test Sound",
@@ -475,14 +475,14 @@ class TestPlaylistRepository:
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async call
playlist_id = playlist.id
sound_id = sound.id
# Test the repository method
playlist_sound = await playlist_repository.add_sound_to_playlist(
playlist_id, sound_id, position=5
playlist_id, sound_id, position=5,
)
assert playlist_sound.position == 5
@@ -509,9 +509,9 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
user_id = user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
@@ -522,7 +522,7 @@ class TestPlaylistRepository:
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -536,7 +536,7 @@ class TestPlaylistRepository:
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
@@ -546,17 +546,17 @@ class TestPlaylistRepository:
# Verify it was added
assert await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id
playlist_id, sound_id,
)
# Remove the sound
await playlist_repository.remove_sound_from_playlist(
playlist_id, sound_id
playlist_id, sound_id,
)
# Verify it was removed
assert not await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id
playlist_id, sound_id,
)
@pytest.mark.asyncio
@@ -581,9 +581,9 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
user_id = user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
@@ -594,7 +594,7 @@ class TestPlaylistRepository:
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -608,7 +608,7 @@ class TestPlaylistRepository:
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
@@ -647,9 +647,9 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
user_id = user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
@@ -660,7 +660,7 @@ class TestPlaylistRepository:
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -674,7 +674,7 @@ class TestPlaylistRepository:
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
@@ -712,9 +712,9 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
user_id = user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
@@ -725,7 +725,7 @@ class TestPlaylistRepository:
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -739,14 +739,14 @@ class TestPlaylistRepository:
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
# Initially not in playlist
assert not await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id
playlist_id, sound_id,
)
# Add sound
@@ -754,7 +754,7 @@ class TestPlaylistRepository:
# Now in playlist
assert await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id
playlist_id, sound_id,
)
@pytest.mark.asyncio
@@ -779,9 +779,9 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
user_id = user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
@@ -801,7 +801,7 @@ class TestPlaylistRepository:
await test_session.refresh(playlist)
await test_session.refresh(sound1)
await test_session.refresh(sound2)
# Extract IDs before async calls
playlist_id = playlist.id
sound1_id = sound1.id
@@ -809,16 +809,16 @@ class TestPlaylistRepository:
# Add sounds to playlist
await playlist_repository.add_sound_to_playlist(
playlist_id, sound1_id, position=0
playlist_id, sound1_id, position=0,
)
await playlist_repository.add_sound_to_playlist(
playlist_id, sound2_id, position=1
playlist_id, sound2_id, position=1,
)
# Reorder sounds - use different positions to avoid constraint issues
sound_positions = [(sound1_id, 10), (sound2_id, 5)]
await playlist_repository.reorder_playlist_sounds(
playlist_id, sound_positions
playlist_id, sound_positions,
)
# Verify new order

View File

@@ -359,7 +359,7 @@ class TestSoundRepository:
"""Test creating sound with duplicate hash should fail."""
# Store the hash to avoid lazy loading issues
original_hash = test_sound.hash
duplicate_sound_data = {
"name": "Duplicate Hash Sound",
"filename": "duplicate.mp3",
@@ -373,4 +373,4 @@ class TestSoundRepository:
# Should fail due to unique constraint on hash
with pytest.raises(Exception): # SQLAlchemy IntegrityError or similar
await sound_repository.create(duplicate_sound_data)
await sound_repository.create(duplicate_sound_data)

View File

@@ -60,7 +60,7 @@ class TestUserOauthRepository:
) -> None:
"""Test getting OAuth by provider user ID when it exists."""
oauth = await user_oauth_repository.get_by_provider_user_id(
"google", "google_123456"
"google", "google_123456",
)
assert oauth is not None
@@ -76,7 +76,7 @@ class TestUserOauthRepository:
) -> None:
"""Test getting OAuth by provider user ID when it doesn't exist."""
oauth = await user_oauth_repository.get_by_provider_user_id(
"google", "nonexistent_id"
"google", "nonexistent_id",
)
assert oauth is None
@@ -90,7 +90,7 @@ class TestUserOauthRepository:
) -> None:
"""Test getting OAuth by user ID and provider when it exists."""
oauth = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "google"
test_user_id, "google",
)
assert oauth is not None
@@ -106,7 +106,7 @@ class TestUserOauthRepository:
) -> None:
"""Test getting OAuth by user ID and provider when it doesn't exist."""
oauth = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "github"
test_user_id, "github",
)
assert oauth is None
@@ -183,7 +183,7 @@ class TestUserOauthRepository:
# Verify it's deleted by trying to find it
deleted_oauth = await user_oauth_repository.get_by_provider_user_id(
"twitter", "twitter_456"
"twitter", "twitter_456",
)
assert deleted_oauth is None
@@ -240,10 +240,10 @@ class TestUserOauthRepository:
# Verify both exist by querying back from database
found_google = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "google"
test_user_id, "google",
)
found_github = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "github"
test_user_id, "github",
)
assert found_google is not None
@@ -257,13 +257,13 @@ class TestUserOauthRepository:
# Verify we can also find them by provider_user_id
found_google_by_provider = await user_oauth_repository.get_by_provider_user_id(
"google", "google_user_1"
"google", "google_user_1",
)
found_github_by_provider = await user_oauth_repository.get_by_provider_user_id(
"github", "github_user_1"
"github", "github_user_1",
)
assert found_google_by_provider is not None
assert found_github_by_provider is not None
assert found_google_by_provider.user_id == test_user_id
assert found_github_by_provider.user_id == test_user_id
assert found_github_by_provider.user_id == test_user_id