Compare commits

...

2 Commits

Author SHA1 Message Date
JSC
e69098d633 refactor: Update player seek functionality to use consistent position field across schemas and services
All checks were successful
Backend CI / test (push) Successful in 4m5s
2025-07-31 21:33:00 +02:00
JSC
3405d817d5 refactor: Simplify repository classes by inheriting from BaseRepository and removing redundant methods 2025-07-31 21:32:46 +02:00
13 changed files with 80 additions and 319 deletions

View File

@@ -137,10 +137,10 @@ async def seek(
"""Seek to specific position in current track.""" """Seek to specific position in current track."""
try: try:
player = get_player_service() player = get_player_service()
await player.seek(request.position_ms) await player.seek(request.position)
return MessageResponse(message=f"Seeked to position {request.position_ms}ms") return MessageResponse(message=f"Seeked to position {request.position}ms")
except Exception as e: except Exception as e:
logger.exception("Error seeking to position %s", request.position_ms) logger.exception("Error seeking to position %s", request.position)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to seek", detail="Failed to seek",

View File

@@ -38,7 +38,7 @@ class BaseRepository(Generic[ModelType]):
""" """
try: try:
statement = select(self.model).where(getattr(self.model, "id") == entity_id) statement = select(self.model).where(self.model.id == entity_id)
result = await self.session.exec(statement) result = await self.session.exec(statement)
return result.first() return result.first()
except Exception: except Exception:

View File

@@ -1,42 +1,29 @@
"""Extraction repository for database operations.""" """Extraction repository for database operations."""
from sqlalchemy import desc from sqlalchemy import desc
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.extraction import Extraction from app.models.extraction import Extraction
from app.repositories.base import BaseRepository
class ExtractionRepository: class ExtractionRepository(BaseRepository[Extraction]):
"""Repository for extraction database operations.""" """Repository for extraction database operations."""
def __init__(self, session: AsyncSession) -> None: def __init__(self, session: AsyncSession) -> None:
"""Initialize the extraction repository.""" """Initialize the extraction repository."""
self.session = session super().__init__(Extraction, session)
async def create(self, extraction_data: dict) -> Extraction:
"""Create a new extraction."""
extraction = Extraction(**extraction_data)
self.session.add(extraction)
await self.session.commit()
await self.session.refresh(extraction)
return extraction
async def get_by_id(self, extraction_id: int) -> Extraction | None:
"""Get an extraction by ID."""
result = await self.session.exec(
select(Extraction).where(Extraction.id == extraction_id)
)
return result.first()
async def get_by_service_and_id( async def get_by_service_and_id(
self, service: str, service_id: str self, service: str, service_id: str,
) -> Extraction | None: ) -> Extraction | None:
"""Get an extraction by service and service_id.""" """Get an extraction by service and service_id."""
result = await self.session.exec( result = await self.session.exec(
select(Extraction).where( select(Extraction).where(
Extraction.service == service, Extraction.service_id == service_id Extraction.service == service, Extraction.service_id == service_id,
) ),
) )
return result.first() return result.first()
@@ -45,7 +32,7 @@ class ExtractionRepository:
result = await self.session.exec( result = await self.session.exec(
select(Extraction) select(Extraction)
.where(Extraction.user_id == user_id) .where(Extraction.user_id == user_id)
.order_by(desc(Extraction.created_at)) .order_by(desc(Extraction.created_at)),
) )
return list(result.all()) return list(result.all())
@@ -54,29 +41,15 @@ class ExtractionRepository:
result = await self.session.exec( result = await self.session.exec(
select(Extraction) select(Extraction)
.where(Extraction.status == "pending") .where(Extraction.status == "pending")
.order_by(Extraction.created_at) .order_by(Extraction.created_at),
) )
return list(result.all()) return list(result.all())
async def update(self, extraction: Extraction, update_data: dict) -> Extraction:
"""Update an extraction."""
for key, value in update_data.items():
setattr(extraction, key, value)
await self.session.commit()
await self.session.refresh(extraction)
return extraction
async def delete(self, extraction: Extraction) -> None:
"""Delete an extraction."""
await self.session.delete(extraction)
await self.session.commit()
async def get_extractions_by_status(self, status: str) -> list[Extraction]: async def get_extractions_by_status(self, status: str) -> list[Extraction]:
"""Get extractions by status.""" """Get extractions by status."""
result = await self.session.exec( result = await self.session.exec(
select(Extraction) select(Extraction)
.where(Extraction.status == status) .where(Extraction.status == status)
.order_by(desc(Extraction.created_at)) .order_by(desc(Extraction.created_at)),
) )
return list(result.all()) return list(result.all())

View File

@@ -1,6 +1,5 @@
"""Playlist repository for database operations.""" """Playlist repository for database operations."""
from typing import Any
from sqlalchemy import func from sqlalchemy import func
from sqlmodel import select from sqlmodel import select
@@ -10,26 +9,17 @@ from app.core.logging import get_logger
from app.models.playlist import Playlist from app.models.playlist import Playlist
from app.models.playlist_sound import PlaylistSound from app.models.playlist_sound import PlaylistSound
from app.models.sound import Sound from app.models.sound import Sound
from app.repositories.base import BaseRepository
logger = get_logger(__name__) logger = get_logger(__name__)
class PlaylistRepository: class PlaylistRepository(BaseRepository[Playlist]):
"""Repository for playlist operations.""" """Repository for playlist operations."""
def __init__(self, session: AsyncSession) -> None: def __init__(self, session: AsyncSession) -> None:
"""Initialize the playlist repository.""" """Initialize the playlist repository."""
self.session = session super().__init__(Playlist, session)
async def get_by_id(self, playlist_id: int) -> Playlist | None:
"""Get a playlist by ID."""
try:
statement = select(Playlist).where(Playlist.id == playlist_id)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get playlist by ID: %s", playlist_id)
raise
async def get_by_name(self, name: str) -> Playlist | None: async def get_by_name(self, name: str) -> Playlist | None:
"""Get a playlist by name.""" """Get a playlist by name."""
@@ -51,16 +41,6 @@ class PlaylistRepository:
logger.exception("Failed to get playlists for user: %s", user_id) logger.exception("Failed to get playlists for user: %s", user_id)
raise raise
async def get_all(self) -> list[Playlist]:
"""Get all playlists from all users."""
try:
statement = select(Playlist)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get all playlists")
raise
async def get_main_playlist(self) -> Playlist | None: async def get_main_playlist(self) -> Playlist | None:
"""Get the global main playlist.""" """Get the global main playlist."""
try: try:
@@ -86,50 +66,8 @@ class PlaylistRepository:
logger.exception("Failed to get current playlist for user: %s", user_id) logger.exception("Failed to get current playlist for user: %s", user_id)
raise raise
async def create(self, playlist_data: dict[str, Any]) -> Playlist:
"""Create a new playlist."""
try:
playlist = Playlist(**playlist_data)
self.session.add(playlist)
await self.session.commit()
await self.session.refresh(playlist)
except Exception:
await self.session.rollback()
logger.exception("Failed to create playlist")
raise
else:
logger.info("Created new playlist: %s", playlist.name)
return playlist
async def update(self, playlist: Playlist, update_data: dict[str, Any]) -> Playlist:
"""Update a playlist."""
try:
for field, value in update_data.items():
setattr(playlist, field, value)
await self.session.commit()
await self.session.refresh(playlist)
except Exception:
await self.session.rollback()
logger.exception("Failed to update playlist")
raise
else:
logger.info("Updated playlist: %s", playlist.name)
return playlist
async def delete(self, playlist: Playlist) -> None:
"""Delete a playlist."""
try:
await self.session.delete(playlist)
await self.session.commit()
logger.info("Deleted playlist: %s", playlist.name)
except Exception:
await self.session.rollback()
logger.exception("Failed to delete playlist")
raise
async def search_by_name( async def search_by_name(
self, query: str, user_id: int | None = None self, query: str, user_id: int | None = None,
) -> list[Playlist]: ) -> list[Playlist]:
"""Search playlists by name (case-insensitive).""" """Search playlists by name (case-insensitive)."""
try: try:
@@ -161,14 +99,14 @@ class PlaylistRepository:
raise raise
async def add_sound_to_playlist( async def add_sound_to_playlist(
self, playlist_id: int, sound_id: int, position: int | None = None self, playlist_id: int, sound_id: int, position: int | None = None,
) -> PlaylistSound: ) -> PlaylistSound:
"""Add a sound to a playlist.""" """Add a sound to a playlist."""
try: try:
if position is None: if position is None:
# Get the next available position # Get the next available position
statement = select( statement = select(
func.coalesce(func.max(PlaylistSound.position), -1) + 1 func.coalesce(func.max(PlaylistSound.position), -1) + 1,
).where(PlaylistSound.playlist_id == playlist_id) ).where(PlaylistSound.playlist_id == playlist_id)
result = await self.session.exec(statement) result = await self.session.exec(statement)
position = result.first() or 0 position = result.first() or 0
@@ -184,7 +122,7 @@ class PlaylistRepository:
except Exception: except Exception:
await self.session.rollback() await self.session.rollback()
logger.exception( logger.exception(
"Failed to add sound %s to playlist %s", sound_id, playlist_id "Failed to add sound %s to playlist %s", sound_id, playlist_id,
) )
raise raise
else: else:
@@ -213,18 +151,19 @@ class PlaylistRepository:
except Exception: except Exception:
await self.session.rollback() await self.session.rollback()
logger.exception( logger.exception(
"Failed to remove sound %s from playlist %s", sound_id, playlist_id "Failed to remove sound %s from playlist %s", sound_id, playlist_id,
) )
raise raise
async def reorder_playlist_sounds( async def reorder_playlist_sounds(
self, playlist_id: int, sound_positions: list[tuple[int, int]] self, playlist_id: int, sound_positions: list[tuple[int, int]],
) -> None: ) -> None:
"""Reorder sounds in a playlist. """Reorder sounds in a playlist.
Args: Args:
playlist_id: The playlist ID playlist_id: The playlist ID
sound_positions: List of (sound_id, new_position) tuples sound_positions: List of (sound_id, new_position) tuples
""" """
try: try:
for sound_id, new_position in sound_positions: for sound_id, new_position in sound_positions:
@@ -249,7 +188,7 @@ class PlaylistRepository:
"""Get the number of sounds in a playlist.""" """Get the number of sounds in a playlist."""
try: try:
statement = select(func.count(PlaylistSound.id)).where( statement = select(func.count(PlaylistSound.id)).where(
PlaylistSound.playlist_id == playlist_id PlaylistSound.playlist_id == playlist_id,
) )
result = await self.session.exec(statement) result = await self.session.exec(statement)
return result.first() or 0 return result.first() or 0
@@ -268,6 +207,6 @@ class PlaylistRepository:
return result.first() is not None return result.first() is not None
except Exception: except Exception:
logger.exception( logger.exception(
"Failed to check if sound %s is in playlist %s", sound_id, playlist_id "Failed to check if sound %s is in playlist %s", sound_id, playlist_id,
) )
raise raise

View File

@@ -1,33 +1,22 @@
"""Sound repository for database operations.""" """Sound repository for database operations."""
from typing import Any from sqlalchemy import func
from sqlalchemy import desc, func
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger from app.core.logging import get_logger
from app.models.sound import Sound from app.models.sound import Sound
from app.repositories.base import BaseRepository
logger = get_logger(__name__) logger = get_logger(__name__)
class SoundRepository: class SoundRepository(BaseRepository[Sound]):
"""Repository for sound operations.""" """Repository for sound operations."""
def __init__(self, session: AsyncSession) -> None: def __init__(self, session: AsyncSession) -> None:
"""Initialize the sound repository.""" """Initialize the sound repository."""
self.session = session super().__init__(Sound, session)
async def get_by_id(self, sound_id: int) -> Sound | None:
"""Get a sound by ID."""
try:
statement = select(Sound).where(Sound.id == sound_id)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get sound by ID: %s", sound_id)
raise
async def get_by_filename(self, filename: str) -> Sound | None: async def get_by_filename(self, filename: str) -> Sound | None:
"""Get a sound by filename.""" """Get a sound by filename."""
@@ -59,48 +48,6 @@ class SoundRepository:
logger.exception("Failed to get sounds by type: %s", sound_type) logger.exception("Failed to get sounds by type: %s", sound_type)
raise raise
async def create(self, sound_data: dict[str, Any]) -> Sound:
"""Create a new sound."""
try:
sound = Sound(**sound_data)
self.session.add(sound)
await self.session.commit()
await self.session.refresh(sound)
except Exception:
await self.session.rollback()
logger.exception("Failed to create sound")
raise
else:
logger.info("Created new sound: %s", sound.name)
return sound
async def update(self, sound: Sound, update_data: dict[str, Any]) -> Sound:
"""Update a sound."""
try:
for field, value in update_data.items():
setattr(sound, field, value)
await self.session.commit()
await self.session.refresh(sound)
except Exception:
await self.session.rollback()
logger.exception("Failed to update sound")
raise
else:
logger.info("Updated sound: %s", sound.name)
return sound
async def delete(self, sound: Sound) -> None:
"""Delete a sound."""
try:
await self.session.delete(sound)
await self.session.commit()
logger.info("Deleted sound: %s", sound.name)
except Exception:
await self.session.rollback()
logger.exception("Failed to delete sound")
raise
async def search_by_name(self, query: str) -> list[Sound]: async def search_by_name(self, query: str) -> list[Sound]:
"""Search sounds by name (case-insensitive).""" """Search sounds by name (case-insensitive)."""
try: try:
@@ -144,6 +91,6 @@ class SoundRepository:
return list(result.all()) return list(result.all())
except Exception: except Exception:
logger.exception( logger.exception(
"Failed to get unnormalized sounds by type: %s", sound_type "Failed to get unnormalized sounds by type: %s", sound_type,
) )
raise raise

View File

@@ -8,26 +8,17 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger from app.core.logging import get_logger
from app.models.plan import Plan from app.models.plan import Plan
from app.models.user import User from app.models.user import User
from app.repositories.base import BaseRepository
logger = get_logger(__name__) logger = get_logger(__name__)
class UserRepository: class UserRepository(BaseRepository[User]):
"""Repository for user operations.""" """Repository for user operations."""
def __init__(self, session: AsyncSession) -> None: def __init__(self, session: AsyncSession) -> None:
"""Initialize the user repository.""" """Initialize the user repository."""
self.session = session super().__init__(User, session)
async def get_by_id(self, user_id: int) -> User | None:
"""Get a user by ID."""
try:
statement = select(User).where(User.id == user_id)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get user by ID: %s", user_id)
raise
async def get_by_email(self, email: str) -> User | None: async def get_by_email(self, email: str) -> User | None:
"""Get a user by email address.""" """Get a user by email address."""
@@ -50,7 +41,7 @@ class UserRepository:
raise raise
async def create(self, user_data: dict[str, Any]) -> User: async def create(self, user_data: dict[str, Any]) -> User:
"""Create a new user.""" """Create a new user with plan assignment and first user admin logic."""
def _raise_plan_not_found() -> None: def _raise_plan_not_found() -> None:
msg = "Default plan not found" msg = "Default plan not found"
@@ -84,45 +75,11 @@ class UserRepository:
user_data["plan_id"] = default_plan.id user_data["plan_id"] = default_plan.id
user_data["credits"] = default_plan.credits user_data["credits"] = default_plan.credits
user = User(**user_data) # Use BaseRepository's create method
self.session.add(user) return await super().create(user_data)
await self.session.commit()
await self.session.refresh(user)
except Exception: except Exception:
await self.session.rollback()
logger.exception("Failed to create user") logger.exception("Failed to create user")
raise raise
else:
logger.info("Created new user with email: %s", user.email)
return user
async def update(self, user: User, update_data: dict[str, Any]) -> User:
"""Update a user."""
try:
for field, value in update_data.items():
setattr(user, field, value)
await self.session.commit()
await self.session.refresh(user)
except Exception:
await self.session.rollback()
logger.exception("Failed to update user")
raise
else:
logger.info("Updated user: %s", user.email)
return user
async def delete(self, user: User) -> None:
"""Delete a user."""
try:
await self.session.delete(user)
await self.session.commit()
logger.info("Deleted user: %s", user.email)
except Exception:
await self.session.rollback()
logger.exception("Failed to delete user")
raise
async def email_exists(self, email: str) -> bool: async def email_exists(self, email: str) -> bool:
"""Check if an email address is already registered.""" """Check if an email address is already registered."""

View File

@@ -1,22 +1,22 @@
"""Repository for user OAuth operations.""" """Repository for user OAuth operations."""
from typing import Any
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger from app.core.logging import get_logger
from app.models.user_oauth import UserOauth from app.models.user_oauth import UserOauth
from app.repositories.base import BaseRepository
logger = get_logger(__name__) logger = get_logger(__name__)
class UserOauthRepository: class UserOauthRepository(BaseRepository[UserOauth]):
"""Repository for user OAuth operations.""" """Repository for user OAuth operations."""
def __init__(self, session: AsyncSession) -> None: def __init__(self, session: AsyncSession) -> None:
"""Initialize repository with database session.""" """Initialize repository with database session."""
self.session = session super().__init__(UserOauth, session)
async def get_by_provider_user_id( async def get_by_provider_user_id(
self, self,
@@ -61,57 +61,3 @@ class UserOauthRepository:
else: else:
return result.first() return result.first()
async def create(self, oauth_data: dict[str, Any]) -> UserOauth:
"""Create a new user OAuth record."""
try:
oauth = UserOauth(**oauth_data)
self.session.add(oauth)
await self.session.commit()
await self.session.refresh(oauth)
logger.info(
"Created OAuth link for user %s with provider %s",
oauth.user_id,
oauth.provider,
)
except Exception:
await self.session.rollback()
logger.exception("Failed to create user OAuth")
raise
else:
return oauth
async def update(self, oauth: UserOauth, update_data: dict[str, Any]) -> UserOauth:
"""Update a user OAuth record."""
try:
for key, value in update_data.items():
setattr(oauth, key, value)
self.session.add(oauth)
await self.session.commit()
await self.session.refresh(oauth)
logger.info(
"Updated OAuth link for user %s with provider %s",
oauth.user_id,
oauth.provider,
)
except Exception:
await self.session.rollback()
logger.exception("Failed to update user OAuth")
raise
else:
return oauth
async def delete(self, oauth: UserOauth) -> None:
"""Delete a user OAuth record."""
try:
await self.session.delete(oauth)
await self.session.commit()
logger.info(
"Deleted OAuth link for user %s with provider %s",
oauth.user_id,
oauth.provider,
)
except Exception:
await self.session.rollback()
logger.exception("Failed to delete user OAuth")
raise

View File

@@ -10,7 +10,7 @@ from app.services.player import PlayerMode
class PlayerSeekRequest(BaseModel): class PlayerSeekRequest(BaseModel):
"""Request model for seek operation.""" """Request model for seek operation."""
position_ms: int = Field(ge=0, description="Position in milliseconds") position: int = Field(ge=0, description="Position in milliseconds")
class PlayerVolumeRequest(BaseModel): class PlayerVolumeRequest(BaseModel):
@@ -35,8 +35,8 @@ class PlayerStateResponse(BaseModel):
playlist: dict[str, Any] | None = Field( playlist: dict[str, Any] | None = Field(
None, description="Current playlist information" None, description="Current playlist information"
) )
position_ms: int = Field(description="Current position in milliseconds") position: int = Field(description="Current position in milliseconds")
duration_ms: int | None = Field( duration: int | None = Field(
None, description="Total duration in milliseconds", None, description="Total duration in milliseconds",
) )
volume: int = Field(description="Current volume (0-100)") volume: int = Field(description="Current volume (0-100)")

View File

@@ -64,8 +64,8 @@ class PlayerState:
"status": self.status.value, "status": self.status.value,
"mode": self.mode.value, "mode": self.mode.value,
"volume": self.volume, "volume": self.volume,
"position_ms": self.current_sound_position or 0, "position": self.current_sound_position or 0,
"duration_ms": self.current_sound_duration, "duration": self.current_sound_duration,
"index": self.current_sound_index, "index": self.current_sound_index,
"current_sound": self._serialize_sound(self.current_sound), "current_sound": self._serialize_sound(self.current_sound),
"playlist": { "playlist": {

View File

@@ -278,17 +278,17 @@ class TestPlayerEndpoints:
mock_player_service, mock_player_service,
): ):
"""Test seeking to position successfully.""" """Test seeking to position successfully."""
position_ms = 5000 position = 5000
response = await authenticated_client.post( response = await authenticated_client.post(
"/api/v1/player/seek", "/api/v1/player/seek",
json={"position_ms": position_ms}, json={"position": position},
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["message"] == f"Seeked to position {position_ms}ms" assert data["message"] == f"Seeked to position {position}ms"
mock_player_service.seek.assert_called_once_with(position_ms) mock_player_service.seek.assert_called_once_with(position)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_seek_invalid_position( async def test_seek_invalid_position(
@@ -300,7 +300,7 @@ class TestPlayerEndpoints:
"""Test seeking with invalid position.""" """Test seeking with invalid position."""
response = await authenticated_client.post( response = await authenticated_client.post(
"/api/v1/player/seek", "/api/v1/player/seek",
json={"position_ms": -1000}, # Negative position json={"position": -1000}, # Negative position
) )
assert response.status_code == 422 # Validation error assert response.status_code == 422 # Validation error
@@ -310,7 +310,7 @@ class TestPlayerEndpoints:
"""Test seeking without authentication.""" """Test seeking without authentication."""
response = await client.post( response = await client.post(
"/api/v1/player/seek", "/api/v1/player/seek",
json={"position_ms": 5000}, json={"position": 5000},
) )
assert response.status_code == 401 assert response.status_code == 401
@@ -326,7 +326,7 @@ class TestPlayerEndpoints:
response = await authenticated_client.post( response = await authenticated_client.post(
"/api/v1/player/seek", "/api/v1/player/seek",
json={"position_ms": 5000}, json={"position": 5000},
) )
assert response.status_code == 500 assert response.status_code == 500
@@ -516,8 +516,8 @@ class TestPlayerEndpoints:
"status": PlayerStatus.PLAYING.value, "status": PlayerStatus.PLAYING.value,
"mode": PlayerMode.CONTINUOUS.value, "mode": PlayerMode.CONTINUOUS.value,
"volume": 50, "volume": 50,
"position_ms": 5000, "position": 5000,
"duration_ms": 30000, "duration": 30000,
"index": 0, "index": 0,
"current_sound": { "current_sound": {
"id": 1, "id": 1,
@@ -625,7 +625,7 @@ class TestPlayerEndpoints:
"""Test seeking to position zero.""" """Test seeking to position zero."""
response = await authenticated_client.post( response = await authenticated_client.post(
"/api/v1/player/seek", "/api/v1/player/seek",
json={"position_ms": 0}, json={"position": 0},
) )
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -39,24 +39,24 @@ class TestExtractionRepository:
} }
# Mock the session operations # Mock the session operations
mock_extraction = Extraction(**extraction_data, id=1)
extraction_repo.session.add = Mock() extraction_repo.session.add = Mock()
extraction_repo.session.commit = AsyncMock() extraction_repo.session.commit = AsyncMock()
extraction_repo.session.refresh = AsyncMock() extraction_repo.session.refresh = AsyncMock()
# Mock the Extraction constructor to return our mock
with pytest.MonkeyPatch().context() as m:
m.setattr(
"app.repositories.extraction.Extraction",
lambda **kwargs: mock_extraction,
)
result = await extraction_repo.create(extraction_data) result = await extraction_repo.create(extraction_data)
assert result == mock_extraction # Verify the result has the expected attributes
assert result.url == extraction_data["url"]
assert result.user_id == extraction_data["user_id"]
assert result.service == extraction_data["service"]
assert result.service_id == extraction_data["service_id"]
assert result.title == extraction_data["title"]
assert result.status == extraction_data["status"]
# Verify session methods were called
extraction_repo.session.add.assert_called_once() extraction_repo.session.add.assert_called_once()
extraction_repo.session.commit.assert_called_once() extraction_repo.session.commit.assert_called_once()
extraction_repo.session.refresh.assert_called_once_with(mock_extraction) extraction_repo.session.refresh.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_by_service_and_id(self, extraction_repo): async def test_get_by_service_and_id(self, extraction_repo):

View File

@@ -65,14 +65,13 @@ class TestPlayerState:
assert result["status"] == "playing" assert result["status"] == "playing"
assert result["mode"] == "loop" assert result["mode"] == "loop"
assert result["volume"] == 75 assert result["volume"] == 75
assert result["current_sound_id"] == 1 assert result["position"] == 5000
assert result["current_sound_index"] == 0 assert result["duration"] == 30000
assert result["current_sound_position"] == 5000 assert result["index"] == 0
assert result["current_sound_duration"] == 30000 assert result["playlist"]["id"] == 1
assert result["playlist_id"] == 1 assert result["playlist"]["name"] == "Test Playlist"
assert result["playlist_name"] == "Test Playlist" assert result["playlist"]["length"] == 5
assert result["playlist_length"] == 5 assert result["playlist"]["duration"] == 150000
assert result["playlist_duration"] == 150000
def test_serialize_sound_with_sound_object(self): def test_serialize_sound_with_sound_object(self):
"""Test serializing a sound object.""" """Test serializing a sound object."""