Compare commits
2 Commits
c63997f591
...
e69098d633
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e69098d633 | ||
|
|
3405d817d5 |
@@ -137,10 +137,10 @@ async def seek(
|
|||||||
"""Seek to specific position in current track."""
|
"""Seek to specific position in current track."""
|
||||||
try:
|
try:
|
||||||
player = get_player_service()
|
player = get_player_service()
|
||||||
await player.seek(request.position_ms)
|
await player.seek(request.position)
|
||||||
return MessageResponse(message=f"Seeked to position {request.position_ms}ms")
|
return MessageResponse(message=f"Seeked to position {request.position}ms")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Error seeking to position %s", request.position_ms)
|
logger.exception("Error seeking to position %s", request.position)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Failed to seek",
|
detail="Failed to seek",
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ class BaseRepository(Generic[ModelType]):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
statement = select(self.model).where(getattr(self.model, "id") == entity_id)
|
statement = select(self.model).where(self.model.id == entity_id)
|
||||||
result = await self.session.exec(statement)
|
result = await self.session.exec(statement)
|
||||||
return result.first()
|
return result.first()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -1,42 +1,29 @@
|
|||||||
"""Extraction repository for database operations."""
|
"""Extraction repository for database operations."""
|
||||||
|
|
||||||
|
|
||||||
from sqlalchemy import desc
|
from sqlalchemy import desc
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.models.extraction import Extraction
|
from app.models.extraction import Extraction
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
|
|
||||||
|
|
||||||
class ExtractionRepository:
|
class ExtractionRepository(BaseRepository[Extraction]):
|
||||||
"""Repository for extraction database operations."""
|
"""Repository for extraction database operations."""
|
||||||
|
|
||||||
def __init__(self, session: AsyncSession) -> None:
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
"""Initialize the extraction repository."""
|
"""Initialize the extraction repository."""
|
||||||
self.session = session
|
super().__init__(Extraction, session)
|
||||||
|
|
||||||
async def create(self, extraction_data: dict) -> Extraction:
|
|
||||||
"""Create a new extraction."""
|
|
||||||
extraction = Extraction(**extraction_data)
|
|
||||||
self.session.add(extraction)
|
|
||||||
await self.session.commit()
|
|
||||||
await self.session.refresh(extraction)
|
|
||||||
return extraction
|
|
||||||
|
|
||||||
async def get_by_id(self, extraction_id: int) -> Extraction | None:
|
|
||||||
"""Get an extraction by ID."""
|
|
||||||
result = await self.session.exec(
|
|
||||||
select(Extraction).where(Extraction.id == extraction_id)
|
|
||||||
)
|
|
||||||
return result.first()
|
|
||||||
|
|
||||||
async def get_by_service_and_id(
|
async def get_by_service_and_id(
|
||||||
self, service: str, service_id: str
|
self, service: str, service_id: str,
|
||||||
) -> Extraction | None:
|
) -> Extraction | None:
|
||||||
"""Get an extraction by service and service_id."""
|
"""Get an extraction by service and service_id."""
|
||||||
result = await self.session.exec(
|
result = await self.session.exec(
|
||||||
select(Extraction).where(
|
select(Extraction).where(
|
||||||
Extraction.service == service, Extraction.service_id == service_id
|
Extraction.service == service, Extraction.service_id == service_id,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
return result.first()
|
return result.first()
|
||||||
|
|
||||||
@@ -45,7 +32,7 @@ class ExtractionRepository:
|
|||||||
result = await self.session.exec(
|
result = await self.session.exec(
|
||||||
select(Extraction)
|
select(Extraction)
|
||||||
.where(Extraction.user_id == user_id)
|
.where(Extraction.user_id == user_id)
|
||||||
.order_by(desc(Extraction.created_at))
|
.order_by(desc(Extraction.created_at)),
|
||||||
)
|
)
|
||||||
return list(result.all())
|
return list(result.all())
|
||||||
|
|
||||||
@@ -54,29 +41,15 @@ class ExtractionRepository:
|
|||||||
result = await self.session.exec(
|
result = await self.session.exec(
|
||||||
select(Extraction)
|
select(Extraction)
|
||||||
.where(Extraction.status == "pending")
|
.where(Extraction.status == "pending")
|
||||||
.order_by(Extraction.created_at)
|
.order_by(Extraction.created_at),
|
||||||
)
|
)
|
||||||
return list(result.all())
|
return list(result.all())
|
||||||
|
|
||||||
async def update(self, extraction: Extraction, update_data: dict) -> Extraction:
|
|
||||||
"""Update an extraction."""
|
|
||||||
for key, value in update_data.items():
|
|
||||||
setattr(extraction, key, value)
|
|
||||||
|
|
||||||
await self.session.commit()
|
|
||||||
await self.session.refresh(extraction)
|
|
||||||
return extraction
|
|
||||||
|
|
||||||
async def delete(self, extraction: Extraction) -> None:
|
|
||||||
"""Delete an extraction."""
|
|
||||||
await self.session.delete(extraction)
|
|
||||||
await self.session.commit()
|
|
||||||
|
|
||||||
async def get_extractions_by_status(self, status: str) -> list[Extraction]:
|
async def get_extractions_by_status(self, status: str) -> list[Extraction]:
|
||||||
"""Get extractions by status."""
|
"""Get extractions by status."""
|
||||||
result = await self.session.exec(
|
result = await self.session.exec(
|
||||||
select(Extraction)
|
select(Extraction)
|
||||||
.where(Extraction.status == status)
|
.where(Extraction.status == status)
|
||||||
.order_by(desc(Extraction.created_at))
|
.order_by(desc(Extraction.created_at)),
|
||||||
)
|
)
|
||||||
return list(result.all())
|
return list(result.all())
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Playlist repository for database operations."""
|
"""Playlist repository for database operations."""
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
@@ -10,26 +9,17 @@ from app.core.logging import get_logger
|
|||||||
from app.models.playlist import Playlist
|
from app.models.playlist import Playlist
|
||||||
from app.models.playlist_sound import PlaylistSound
|
from app.models.playlist_sound import PlaylistSound
|
||||||
from app.models.sound import Sound
|
from app.models.sound import Sound
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PlaylistRepository:
|
class PlaylistRepository(BaseRepository[Playlist]):
|
||||||
"""Repository for playlist operations."""
|
"""Repository for playlist operations."""
|
||||||
|
|
||||||
def __init__(self, session: AsyncSession) -> None:
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
"""Initialize the playlist repository."""
|
"""Initialize the playlist repository."""
|
||||||
self.session = session
|
super().__init__(Playlist, session)
|
||||||
|
|
||||||
async def get_by_id(self, playlist_id: int) -> Playlist | None:
|
|
||||||
"""Get a playlist by ID."""
|
|
||||||
try:
|
|
||||||
statement = select(Playlist).where(Playlist.id == playlist_id)
|
|
||||||
result = await self.session.exec(statement)
|
|
||||||
return result.first()
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to get playlist by ID: %s", playlist_id)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_by_name(self, name: str) -> Playlist | None:
|
async def get_by_name(self, name: str) -> Playlist | None:
|
||||||
"""Get a playlist by name."""
|
"""Get a playlist by name."""
|
||||||
@@ -51,16 +41,6 @@ class PlaylistRepository:
|
|||||||
logger.exception("Failed to get playlists for user: %s", user_id)
|
logger.exception("Failed to get playlists for user: %s", user_id)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_all(self) -> list[Playlist]:
|
|
||||||
"""Get all playlists from all users."""
|
|
||||||
try:
|
|
||||||
statement = select(Playlist)
|
|
||||||
result = await self.session.exec(statement)
|
|
||||||
return list(result.all())
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to get all playlists")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_main_playlist(self) -> Playlist | None:
|
async def get_main_playlist(self) -> Playlist | None:
|
||||||
"""Get the global main playlist."""
|
"""Get the global main playlist."""
|
||||||
try:
|
try:
|
||||||
@@ -86,50 +66,8 @@ class PlaylistRepository:
|
|||||||
logger.exception("Failed to get current playlist for user: %s", user_id)
|
logger.exception("Failed to get current playlist for user: %s", user_id)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def create(self, playlist_data: dict[str, Any]) -> Playlist:
|
|
||||||
"""Create a new playlist."""
|
|
||||||
try:
|
|
||||||
playlist = Playlist(**playlist_data)
|
|
||||||
self.session.add(playlist)
|
|
||||||
await self.session.commit()
|
|
||||||
await self.session.refresh(playlist)
|
|
||||||
except Exception:
|
|
||||||
await self.session.rollback()
|
|
||||||
logger.exception("Failed to create playlist")
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
logger.info("Created new playlist: %s", playlist.name)
|
|
||||||
return playlist
|
|
||||||
|
|
||||||
async def update(self, playlist: Playlist, update_data: dict[str, Any]) -> Playlist:
|
|
||||||
"""Update a playlist."""
|
|
||||||
try:
|
|
||||||
for field, value in update_data.items():
|
|
||||||
setattr(playlist, field, value)
|
|
||||||
|
|
||||||
await self.session.commit()
|
|
||||||
await self.session.refresh(playlist)
|
|
||||||
except Exception:
|
|
||||||
await self.session.rollback()
|
|
||||||
logger.exception("Failed to update playlist")
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
logger.info("Updated playlist: %s", playlist.name)
|
|
||||||
return playlist
|
|
||||||
|
|
||||||
async def delete(self, playlist: Playlist) -> None:
|
|
||||||
"""Delete a playlist."""
|
|
||||||
try:
|
|
||||||
await self.session.delete(playlist)
|
|
||||||
await self.session.commit()
|
|
||||||
logger.info("Deleted playlist: %s", playlist.name)
|
|
||||||
except Exception:
|
|
||||||
await self.session.rollback()
|
|
||||||
logger.exception("Failed to delete playlist")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def search_by_name(
|
async def search_by_name(
|
||||||
self, query: str, user_id: int | None = None
|
self, query: str, user_id: int | None = None,
|
||||||
) -> list[Playlist]:
|
) -> list[Playlist]:
|
||||||
"""Search playlists by name (case-insensitive)."""
|
"""Search playlists by name (case-insensitive)."""
|
||||||
try:
|
try:
|
||||||
@@ -161,14 +99,14 @@ class PlaylistRepository:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def add_sound_to_playlist(
|
async def add_sound_to_playlist(
|
||||||
self, playlist_id: int, sound_id: int, position: int | None = None
|
self, playlist_id: int, sound_id: int, position: int | None = None,
|
||||||
) -> PlaylistSound:
|
) -> PlaylistSound:
|
||||||
"""Add a sound to a playlist."""
|
"""Add a sound to a playlist."""
|
||||||
try:
|
try:
|
||||||
if position is None:
|
if position is None:
|
||||||
# Get the next available position
|
# Get the next available position
|
||||||
statement = select(
|
statement = select(
|
||||||
func.coalesce(func.max(PlaylistSound.position), -1) + 1
|
func.coalesce(func.max(PlaylistSound.position), -1) + 1,
|
||||||
).where(PlaylistSound.playlist_id == playlist_id)
|
).where(PlaylistSound.playlist_id == playlist_id)
|
||||||
result = await self.session.exec(statement)
|
result = await self.session.exec(statement)
|
||||||
position = result.first() or 0
|
position = result.first() or 0
|
||||||
@@ -184,7 +122,7 @@ class PlaylistRepository:
|
|||||||
except Exception:
|
except Exception:
|
||||||
await self.session.rollback()
|
await self.session.rollback()
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to add sound %s to playlist %s", sound_id, playlist_id
|
"Failed to add sound %s to playlist %s", sound_id, playlist_id,
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
@@ -213,18 +151,19 @@ class PlaylistRepository:
|
|||||||
except Exception:
|
except Exception:
|
||||||
await self.session.rollback()
|
await self.session.rollback()
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to remove sound %s from playlist %s", sound_id, playlist_id
|
"Failed to remove sound %s from playlist %s", sound_id, playlist_id,
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def reorder_playlist_sounds(
|
async def reorder_playlist_sounds(
|
||||||
self, playlist_id: int, sound_positions: list[tuple[int, int]]
|
self, playlist_id: int, sound_positions: list[tuple[int, int]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Reorder sounds in a playlist.
|
"""Reorder sounds in a playlist.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
playlist_id: The playlist ID
|
playlist_id: The playlist ID
|
||||||
sound_positions: List of (sound_id, new_position) tuples
|
sound_positions: List of (sound_id, new_position) tuples
|
||||||
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
for sound_id, new_position in sound_positions:
|
for sound_id, new_position in sound_positions:
|
||||||
@@ -249,7 +188,7 @@ class PlaylistRepository:
|
|||||||
"""Get the number of sounds in a playlist."""
|
"""Get the number of sounds in a playlist."""
|
||||||
try:
|
try:
|
||||||
statement = select(func.count(PlaylistSound.id)).where(
|
statement = select(func.count(PlaylistSound.id)).where(
|
||||||
PlaylistSound.playlist_id == playlist_id
|
PlaylistSound.playlist_id == playlist_id,
|
||||||
)
|
)
|
||||||
result = await self.session.exec(statement)
|
result = await self.session.exec(statement)
|
||||||
return result.first() or 0
|
return result.first() or 0
|
||||||
@@ -268,6 +207,6 @@ class PlaylistRepository:
|
|||||||
return result.first() is not None
|
return result.first() is not None
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to check if sound %s is in playlist %s", sound_id, playlist_id
|
"Failed to check if sound %s is in playlist %s", sound_id, playlist_id,
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -1,33 +1,22 @@
|
|||||||
"""Sound repository for database operations."""
|
"""Sound repository for database operations."""
|
||||||
|
|
||||||
from typing import Any
|
from sqlalchemy import func
|
||||||
|
|
||||||
from sqlalchemy import desc, func
|
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.models.sound import Sound
|
from app.models.sound import Sound
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SoundRepository:
|
class SoundRepository(BaseRepository[Sound]):
|
||||||
"""Repository for sound operations."""
|
"""Repository for sound operations."""
|
||||||
|
|
||||||
def __init__(self, session: AsyncSession) -> None:
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
"""Initialize the sound repository."""
|
"""Initialize the sound repository."""
|
||||||
self.session = session
|
super().__init__(Sound, session)
|
||||||
|
|
||||||
async def get_by_id(self, sound_id: int) -> Sound | None:
|
|
||||||
"""Get a sound by ID."""
|
|
||||||
try:
|
|
||||||
statement = select(Sound).where(Sound.id == sound_id)
|
|
||||||
result = await self.session.exec(statement)
|
|
||||||
return result.first()
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to get sound by ID: %s", sound_id)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_by_filename(self, filename: str) -> Sound | None:
|
async def get_by_filename(self, filename: str) -> Sound | None:
|
||||||
"""Get a sound by filename."""
|
"""Get a sound by filename."""
|
||||||
@@ -59,48 +48,6 @@ class SoundRepository:
|
|||||||
logger.exception("Failed to get sounds by type: %s", sound_type)
|
logger.exception("Failed to get sounds by type: %s", sound_type)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def create(self, sound_data: dict[str, Any]) -> Sound:
|
|
||||||
"""Create a new sound."""
|
|
||||||
try:
|
|
||||||
sound = Sound(**sound_data)
|
|
||||||
self.session.add(sound)
|
|
||||||
await self.session.commit()
|
|
||||||
await self.session.refresh(sound)
|
|
||||||
except Exception:
|
|
||||||
await self.session.rollback()
|
|
||||||
logger.exception("Failed to create sound")
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
logger.info("Created new sound: %s", sound.name)
|
|
||||||
return sound
|
|
||||||
|
|
||||||
async def update(self, sound: Sound, update_data: dict[str, Any]) -> Sound:
|
|
||||||
"""Update a sound."""
|
|
||||||
try:
|
|
||||||
for field, value in update_data.items():
|
|
||||||
setattr(sound, field, value)
|
|
||||||
|
|
||||||
await self.session.commit()
|
|
||||||
await self.session.refresh(sound)
|
|
||||||
except Exception:
|
|
||||||
await self.session.rollback()
|
|
||||||
logger.exception("Failed to update sound")
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
logger.info("Updated sound: %s", sound.name)
|
|
||||||
return sound
|
|
||||||
|
|
||||||
async def delete(self, sound: Sound) -> None:
|
|
||||||
"""Delete a sound."""
|
|
||||||
try:
|
|
||||||
await self.session.delete(sound)
|
|
||||||
await self.session.commit()
|
|
||||||
logger.info("Deleted sound: %s", sound.name)
|
|
||||||
except Exception:
|
|
||||||
await self.session.rollback()
|
|
||||||
logger.exception("Failed to delete sound")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def search_by_name(self, query: str) -> list[Sound]:
|
async def search_by_name(self, query: str) -> list[Sound]:
|
||||||
"""Search sounds by name (case-insensitive)."""
|
"""Search sounds by name (case-insensitive)."""
|
||||||
try:
|
try:
|
||||||
@@ -144,6 +91,6 @@ class SoundRepository:
|
|||||||
return list(result.all())
|
return list(result.all())
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to get unnormalized sounds by type: %s", sound_type
|
"Failed to get unnormalized sounds by type: %s", sound_type,
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -8,26 +8,17 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.models.plan import Plan
|
from app.models.plan import Plan
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UserRepository:
|
class UserRepository(BaseRepository[User]):
|
||||||
"""Repository for user operations."""
|
"""Repository for user operations."""
|
||||||
|
|
||||||
def __init__(self, session: AsyncSession) -> None:
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
"""Initialize the user repository."""
|
"""Initialize the user repository."""
|
||||||
self.session = session
|
super().__init__(User, session)
|
||||||
|
|
||||||
async def get_by_id(self, user_id: int) -> User | None:
|
|
||||||
"""Get a user by ID."""
|
|
||||||
try:
|
|
||||||
statement = select(User).where(User.id == user_id)
|
|
||||||
result = await self.session.exec(statement)
|
|
||||||
return result.first()
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to get user by ID: %s", user_id)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_by_email(self, email: str) -> User | None:
|
async def get_by_email(self, email: str) -> User | None:
|
||||||
"""Get a user by email address."""
|
"""Get a user by email address."""
|
||||||
@@ -50,7 +41,7 @@ class UserRepository:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def create(self, user_data: dict[str, Any]) -> User:
|
async def create(self, user_data: dict[str, Any]) -> User:
|
||||||
"""Create a new user."""
|
"""Create a new user with plan assignment and first user admin logic."""
|
||||||
|
|
||||||
def _raise_plan_not_found() -> None:
|
def _raise_plan_not_found() -> None:
|
||||||
msg = "Default plan not found"
|
msg = "Default plan not found"
|
||||||
@@ -84,45 +75,11 @@ class UserRepository:
|
|||||||
user_data["plan_id"] = default_plan.id
|
user_data["plan_id"] = default_plan.id
|
||||||
user_data["credits"] = default_plan.credits
|
user_data["credits"] = default_plan.credits
|
||||||
|
|
||||||
user = User(**user_data)
|
# Use BaseRepository's create method
|
||||||
self.session.add(user)
|
return await super().create(user_data)
|
||||||
await self.session.commit()
|
|
||||||
await self.session.refresh(user)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
await self.session.rollback()
|
|
||||||
logger.exception("Failed to create user")
|
logger.exception("Failed to create user")
|
||||||
raise
|
raise
|
||||||
else:
|
|
||||||
logger.info("Created new user with email: %s", user.email)
|
|
||||||
return user
|
|
||||||
|
|
||||||
async def update(self, user: User, update_data: dict[str, Any]) -> User:
|
|
||||||
"""Update a user."""
|
|
||||||
try:
|
|
||||||
for field, value in update_data.items():
|
|
||||||
setattr(user, field, value)
|
|
||||||
|
|
||||||
await self.session.commit()
|
|
||||||
await self.session.refresh(user)
|
|
||||||
except Exception:
|
|
||||||
await self.session.rollback()
|
|
||||||
logger.exception("Failed to update user")
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
logger.info("Updated user: %s", user.email)
|
|
||||||
return user
|
|
||||||
|
|
||||||
async def delete(self, user: User) -> None:
|
|
||||||
"""Delete a user."""
|
|
||||||
try:
|
|
||||||
await self.session.delete(user)
|
|
||||||
await self.session.commit()
|
|
||||||
|
|
||||||
logger.info("Deleted user: %s", user.email)
|
|
||||||
except Exception:
|
|
||||||
await self.session.rollback()
|
|
||||||
logger.exception("Failed to delete user")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def email_exists(self, email: str) -> bool:
|
async def email_exists(self, email: str) -> bool:
|
||||||
"""Check if an email address is already registered."""
|
"""Check if an email address is already registered."""
|
||||||
|
|||||||
@@ -1,22 +1,22 @@
|
|||||||
"""Repository for user OAuth operations."""
|
"""Repository for user OAuth operations."""
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.models.user_oauth import UserOauth
|
from app.models.user_oauth import UserOauth
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UserOauthRepository:
|
class UserOauthRepository(BaseRepository[UserOauth]):
|
||||||
"""Repository for user OAuth operations."""
|
"""Repository for user OAuth operations."""
|
||||||
|
|
||||||
def __init__(self, session: AsyncSession) -> None:
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
"""Initialize repository with database session."""
|
"""Initialize repository with database session."""
|
||||||
self.session = session
|
super().__init__(UserOauth, session)
|
||||||
|
|
||||||
async def get_by_provider_user_id(
|
async def get_by_provider_user_id(
|
||||||
self,
|
self,
|
||||||
@@ -61,57 +61,3 @@ class UserOauthRepository:
|
|||||||
else:
|
else:
|
||||||
return result.first()
|
return result.first()
|
||||||
|
|
||||||
async def create(self, oauth_data: dict[str, Any]) -> UserOauth:
|
|
||||||
"""Create a new user OAuth record."""
|
|
||||||
try:
|
|
||||||
oauth = UserOauth(**oauth_data)
|
|
||||||
self.session.add(oauth)
|
|
||||||
await self.session.commit()
|
|
||||||
await self.session.refresh(oauth)
|
|
||||||
logger.info(
|
|
||||||
"Created OAuth link for user %s with provider %s",
|
|
||||||
oauth.user_id,
|
|
||||||
oauth.provider,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
await self.session.rollback()
|
|
||||||
logger.exception("Failed to create user OAuth")
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
return oauth
|
|
||||||
|
|
||||||
async def update(self, oauth: UserOauth, update_data: dict[str, Any]) -> UserOauth:
|
|
||||||
"""Update a user OAuth record."""
|
|
||||||
try:
|
|
||||||
for key, value in update_data.items():
|
|
||||||
setattr(oauth, key, value)
|
|
||||||
|
|
||||||
self.session.add(oauth)
|
|
||||||
await self.session.commit()
|
|
||||||
await self.session.refresh(oauth)
|
|
||||||
logger.info(
|
|
||||||
"Updated OAuth link for user %s with provider %s",
|
|
||||||
oauth.user_id,
|
|
||||||
oauth.provider,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
await self.session.rollback()
|
|
||||||
logger.exception("Failed to update user OAuth")
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
return oauth
|
|
||||||
|
|
||||||
async def delete(self, oauth: UserOauth) -> None:
|
|
||||||
"""Delete a user OAuth record."""
|
|
||||||
try:
|
|
||||||
await self.session.delete(oauth)
|
|
||||||
await self.session.commit()
|
|
||||||
logger.info(
|
|
||||||
"Deleted OAuth link for user %s with provider %s",
|
|
||||||
oauth.user_id,
|
|
||||||
oauth.provider,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
await self.session.rollback()
|
|
||||||
logger.exception("Failed to delete user OAuth")
|
|
||||||
raise
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from app.services.player import PlayerMode
|
|||||||
class PlayerSeekRequest(BaseModel):
|
class PlayerSeekRequest(BaseModel):
|
||||||
"""Request model for seek operation."""
|
"""Request model for seek operation."""
|
||||||
|
|
||||||
position_ms: int = Field(ge=0, description="Position in milliseconds")
|
position: int = Field(ge=0, description="Position in milliseconds")
|
||||||
|
|
||||||
|
|
||||||
class PlayerVolumeRequest(BaseModel):
|
class PlayerVolumeRequest(BaseModel):
|
||||||
@@ -35,8 +35,8 @@ class PlayerStateResponse(BaseModel):
|
|||||||
playlist: dict[str, Any] | None = Field(
|
playlist: dict[str, Any] | None = Field(
|
||||||
None, description="Current playlist information"
|
None, description="Current playlist information"
|
||||||
)
|
)
|
||||||
position_ms: int = Field(description="Current position in milliseconds")
|
position: int = Field(description="Current position in milliseconds")
|
||||||
duration_ms: int | None = Field(
|
duration: int | None = Field(
|
||||||
None, description="Total duration in milliseconds",
|
None, description="Total duration in milliseconds",
|
||||||
)
|
)
|
||||||
volume: int = Field(description="Current volume (0-100)")
|
volume: int = Field(description="Current volume (0-100)")
|
||||||
|
|||||||
@@ -64,8 +64,8 @@ class PlayerState:
|
|||||||
"status": self.status.value,
|
"status": self.status.value,
|
||||||
"mode": self.mode.value,
|
"mode": self.mode.value,
|
||||||
"volume": self.volume,
|
"volume": self.volume,
|
||||||
"position_ms": self.current_sound_position or 0,
|
"position": self.current_sound_position or 0,
|
||||||
"duration_ms": self.current_sound_duration,
|
"duration": self.current_sound_duration,
|
||||||
"index": self.current_sound_index,
|
"index": self.current_sound_index,
|
||||||
"current_sound": self._serialize_sound(self.current_sound),
|
"current_sound": self._serialize_sound(self.current_sound),
|
||||||
"playlist": {
|
"playlist": {
|
||||||
|
|||||||
@@ -278,17 +278,17 @@ class TestPlayerEndpoints:
|
|||||||
mock_player_service,
|
mock_player_service,
|
||||||
):
|
):
|
||||||
"""Test seeking to position successfully."""
|
"""Test seeking to position successfully."""
|
||||||
position_ms = 5000
|
position = 5000
|
||||||
response = await authenticated_client.post(
|
response = await authenticated_client.post(
|
||||||
"/api/v1/player/seek",
|
"/api/v1/player/seek",
|
||||||
json={"position_ms": position_ms},
|
json={"position": position},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["message"] == f"Seeked to position {position_ms}ms"
|
assert data["message"] == f"Seeked to position {position}ms"
|
||||||
|
|
||||||
mock_player_service.seek.assert_called_once_with(position_ms)
|
mock_player_service.seek.assert_called_once_with(position)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_seek_invalid_position(
|
async def test_seek_invalid_position(
|
||||||
@@ -300,7 +300,7 @@ class TestPlayerEndpoints:
|
|||||||
"""Test seeking with invalid position."""
|
"""Test seeking with invalid position."""
|
||||||
response = await authenticated_client.post(
|
response = await authenticated_client.post(
|
||||||
"/api/v1/player/seek",
|
"/api/v1/player/seek",
|
||||||
json={"position_ms": -1000}, # Negative position
|
json={"position": -1000}, # Negative position
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 422 # Validation error
|
assert response.status_code == 422 # Validation error
|
||||||
@@ -310,7 +310,7 @@ class TestPlayerEndpoints:
|
|||||||
"""Test seeking without authentication."""
|
"""Test seeking without authentication."""
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"/api/v1/player/seek",
|
"/api/v1/player/seek",
|
||||||
json={"position_ms": 5000},
|
json={"position": 5000},
|
||||||
)
|
)
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
|
|
||||||
@@ -326,7 +326,7 @@ class TestPlayerEndpoints:
|
|||||||
|
|
||||||
response = await authenticated_client.post(
|
response = await authenticated_client.post(
|
||||||
"/api/v1/player/seek",
|
"/api/v1/player/seek",
|
||||||
json={"position_ms": 5000},
|
json={"position": 5000},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 500
|
assert response.status_code == 500
|
||||||
@@ -516,8 +516,8 @@ class TestPlayerEndpoints:
|
|||||||
"status": PlayerStatus.PLAYING.value,
|
"status": PlayerStatus.PLAYING.value,
|
||||||
"mode": PlayerMode.CONTINUOUS.value,
|
"mode": PlayerMode.CONTINUOUS.value,
|
||||||
"volume": 50,
|
"volume": 50,
|
||||||
"position_ms": 5000,
|
"position": 5000,
|
||||||
"duration_ms": 30000,
|
"duration": 30000,
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"current_sound": {
|
"current_sound": {
|
||||||
"id": 1,
|
"id": 1,
|
||||||
@@ -625,7 +625,7 @@ class TestPlayerEndpoints:
|
|||||||
"""Test seeking to position zero."""
|
"""Test seeking to position zero."""
|
||||||
response = await authenticated_client.post(
|
response = await authenticated_client.post(
|
||||||
"/api/v1/player/seek",
|
"/api/v1/player/seek",
|
||||||
json={"position_ms": 0},
|
json={"position": 0},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|||||||
@@ -39,24 +39,24 @@ class TestExtractionRepository:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Mock the session operations
|
# Mock the session operations
|
||||||
mock_extraction = Extraction(**extraction_data, id=1)
|
|
||||||
extraction_repo.session.add = Mock()
|
extraction_repo.session.add = Mock()
|
||||||
extraction_repo.session.commit = AsyncMock()
|
extraction_repo.session.commit = AsyncMock()
|
||||||
extraction_repo.session.refresh = AsyncMock()
|
extraction_repo.session.refresh = AsyncMock()
|
||||||
|
|
||||||
# Mock the Extraction constructor to return our mock
|
|
||||||
with pytest.MonkeyPatch().context() as m:
|
|
||||||
m.setattr(
|
|
||||||
"app.repositories.extraction.Extraction",
|
|
||||||
lambda **kwargs: mock_extraction,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await extraction_repo.create(extraction_data)
|
result = await extraction_repo.create(extraction_data)
|
||||||
|
|
||||||
assert result == mock_extraction
|
# Verify the result has the expected attributes
|
||||||
|
assert result.url == extraction_data["url"]
|
||||||
|
assert result.user_id == extraction_data["user_id"]
|
||||||
|
assert result.service == extraction_data["service"]
|
||||||
|
assert result.service_id == extraction_data["service_id"]
|
||||||
|
assert result.title == extraction_data["title"]
|
||||||
|
assert result.status == extraction_data["status"]
|
||||||
|
|
||||||
|
# Verify session methods were called
|
||||||
extraction_repo.session.add.assert_called_once()
|
extraction_repo.session.add.assert_called_once()
|
||||||
extraction_repo.session.commit.assert_called_once()
|
extraction_repo.session.commit.assert_called_once()
|
||||||
extraction_repo.session.refresh.assert_called_once_with(mock_extraction)
|
extraction_repo.session.refresh.assert_called_once()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_by_service_and_id(self, extraction_repo):
|
async def test_get_by_service_and_id(self, extraction_repo):
|
||||||
|
|||||||
@@ -65,14 +65,13 @@ class TestPlayerState:
|
|||||||
assert result["status"] == "playing"
|
assert result["status"] == "playing"
|
||||||
assert result["mode"] == "loop"
|
assert result["mode"] == "loop"
|
||||||
assert result["volume"] == 75
|
assert result["volume"] == 75
|
||||||
assert result["current_sound_id"] == 1
|
assert result["position"] == 5000
|
||||||
assert result["current_sound_index"] == 0
|
assert result["duration"] == 30000
|
||||||
assert result["current_sound_position"] == 5000
|
assert result["index"] == 0
|
||||||
assert result["current_sound_duration"] == 30000
|
assert result["playlist"]["id"] == 1
|
||||||
assert result["playlist_id"] == 1
|
assert result["playlist"]["name"] == "Test Playlist"
|
||||||
assert result["playlist_name"] == "Test Playlist"
|
assert result["playlist"]["length"] == 5
|
||||||
assert result["playlist_length"] == 5
|
assert result["playlist"]["duration"] == 150000
|
||||||
assert result["playlist_duration"] == 150000
|
|
||||||
|
|
||||||
def test_serialize_sound_with_sound_object(self):
|
def test_serialize_sound_with_sound_object(self):
|
||||||
"""Test serializing a sound object."""
|
"""Test serializing a sound object."""
|
||||||
|
|||||||
Reference in New Issue
Block a user