Compare commits
110 Commits
b8346ab667
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e4c72f3b19 | ||
|
|
17eafa4872 | ||
|
|
c9f6bff723 | ||
|
|
12243b1424 | ||
|
|
f7197a89a7 | ||
|
|
b66b8e36bb | ||
|
|
95e166eefb | ||
|
|
d9697c2dd7 | ||
|
|
7b59a8216a | ||
|
|
4b8496d025 | ||
|
|
0806d541f2 | ||
|
|
acdf191a5a | ||
|
|
35b857fd0d | ||
|
|
c13e18c290 | ||
|
|
702d7ee577 | ||
|
|
d3b6e90262 | ||
|
|
50eeae4c62 | ||
|
|
e005dedcd3 | ||
|
|
72ddd98b25 | ||
|
|
b2e513a915 | ||
|
|
c8b796aa94 | ||
|
|
d5f9a3c736 | ||
|
|
2b61d35d6a | ||
|
|
5e8d619736 | ||
|
|
fb0e5e919c | ||
|
|
bccfcafe0e | ||
|
|
1bef694f38 | ||
|
|
b87a47f199 | ||
|
|
83239cb4fa | ||
|
|
e8f979c137 | ||
|
|
92571f4de9 | ||
|
|
1388ede1dc | ||
|
|
75569a60b5 | ||
|
|
2bdd109492 | ||
|
|
dc89e45675 | ||
|
|
96801dc4d6 | ||
|
|
6e74d9b940 | ||
|
|
03abed6d39 | ||
|
|
7dee6e320e | ||
|
|
d3ce17f10d | ||
|
|
da66516bb3 | ||
|
|
d81a54207c | ||
|
|
16eb789539 | ||
|
|
28faca55bc | ||
|
|
821093f64f | ||
|
|
9653062003 | ||
|
|
b808cfaddf | ||
|
|
a82acfae50 | ||
|
|
560ccd3f7e | ||
|
|
a660cc1861 | ||
|
|
6b55ff0e81 | ||
|
|
e6f796a3c9 | ||
|
|
99c757a073 | ||
|
|
f598ec2c12 | ||
|
|
66d22df7dd | ||
|
|
3326e406f8 | ||
|
|
fe15e7a6af | ||
|
|
f56cc8b4cc | ||
|
|
f906b6d643 | ||
|
|
78508c84eb | ||
|
|
a947fd830b | ||
|
|
5e6cc04ad2 | ||
|
|
c27530a25f | ||
|
|
a109a88eed | ||
|
|
4cec3b9d18 | ||
|
|
b691649f7e | ||
|
|
87d6e6ed67 | ||
|
|
bee1076239 | ||
|
|
f094fbf140 | ||
|
|
d3d7edb287 | ||
|
|
cba1653565 | ||
|
|
c69a45c9b4 | ||
|
|
53b6c4bca5 | ||
|
|
49ad6c8581 | ||
|
|
bb1f036caa | ||
|
|
d1bf2fe0a4 | ||
|
|
bdeb00d562 | ||
|
|
13e0db1fe9 | ||
|
|
357fbcecac | ||
|
|
aa9a73ac1d | ||
|
|
8544a3ce22 | ||
|
|
0a8b50a0be | ||
|
|
9e07ce393f | ||
|
|
734521c5c3 | ||
|
|
69544b6bb8 | ||
|
|
b4f0f54516 | ||
|
|
77446cb5a8 | ||
|
|
4bbae4c5d4 | ||
|
|
d2d0240fdb | ||
|
|
6068599a47 | ||
|
|
d926779fe4 | ||
|
|
0575d12b0e | ||
|
|
c0f51b2e23 | ||
|
|
3132175354 | ||
|
|
43be92c8f9 | ||
|
|
f68f4d9046 | ||
|
|
fceff92ca1 | ||
|
|
dc29915fbc | ||
|
|
389cfe2d6a | ||
|
|
502feea035 | ||
|
|
5fdc7aae85 | ||
|
|
69cdc7567d | ||
|
|
a10111793c | ||
|
|
95ccb76233 | ||
|
|
7ba52ad6fc | ||
|
|
01bb48c206 | ||
|
|
8847131f24 | ||
|
|
e69098d633 | ||
|
|
3405d817d5 | ||
|
|
c63997f591 |
55
.env.development.template
Normal file
55
.env.development.template
Normal file
@@ -0,0 +1,55 @@
|
||||
# Development Environment Configuration
|
||||
# Copy this file to .env for development setup
|
||||
|
||||
# Application Configuration
|
||||
HOST=localhost
|
||||
PORT=8000
|
||||
RELOAD=true
|
||||
|
||||
# Development URLs (for local development)
|
||||
FRONTEND_URL=http://localhost:8001
|
||||
BACKEND_URL=http://localhost:8000
|
||||
CORS_ORIGINS=["http://localhost:8001"]
|
||||
|
||||
# Database Configuration
|
||||
DATABASE_URL=sqlite+aiosqlite:///data/soundboard.db
|
||||
DATABASE_ECHO=false
|
||||
|
||||
# Logging Configuration
|
||||
LOG_LEVEL=debug
|
||||
LOG_FILE=logs/app.log
|
||||
LOG_MAX_SIZE=10485760
|
||||
LOG_BACKUP_COUNT=5
|
||||
|
||||
# JWT Configuration (Use a secure key even in development)
|
||||
JWT_SECRET_KEY=development-secret-key-change-for-production
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=15
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
|
||||
# Cookie Configuration (Development settings)
|
||||
COOKIE_SECURE=false
|
||||
COOKIE_SAMESITE=lax
|
||||
COOKIE_DOMAIN=localhost
|
||||
|
||||
# OAuth2 Configuration (Get these from OAuth providers)
|
||||
# Google: https://console.developers.google.com/
|
||||
# Redirect URI: http://localhost:8000/api/v1/auth/google/callback
|
||||
GOOGLE_CLIENT_ID=your-google-client-id
|
||||
GOOGLE_CLIENT_SECRET=your-google-client-secret
|
||||
|
||||
# GitHub: https://github.com/settings/developers
|
||||
# Redirect URI: http://localhost:8000/api/v1/auth/github/callback
|
||||
GITHUB_CLIENT_ID=your-github-client-id
|
||||
GITHUB_CLIENT_SECRET=your-github-client-secret
|
||||
|
||||
# Audio Normalization Configuration
|
||||
NORMALIZED_AUDIO_FORMAT=mp3
|
||||
NORMALIZED_AUDIO_BITRATE=256k
|
||||
NORMALIZED_AUDIO_PASSES=2
|
||||
|
||||
# Audio Extraction Configuration
|
||||
EXTRACTION_AUDIO_FORMAT=mp3
|
||||
EXTRACTION_AUDIO_BITRATE=256k
|
||||
EXTRACTION_TEMP_DIR=sounds/temp
|
||||
EXTRACTION_THUMBNAILS_DIR=sounds/originals/extracted/thumbnails
|
||||
EXTRACTION_MAX_CONCURRENT=2
|
||||
50
.env.production.template
Normal file
50
.env.production.template
Normal file
@@ -0,0 +1,50 @@
|
||||
# Production Environment Configuration
|
||||
# Copy this file to .env and configure for your production environment
|
||||
|
||||
# Application Configuration
|
||||
HOST=0.0.0.0
|
||||
PORT=8000
|
||||
RELOAD=false
|
||||
|
||||
# Production URLs (configure for your domain)
|
||||
FRONTEND_URL=https://yourdomain.com
|
||||
BACKEND_URL=https://yourdomain.com
|
||||
CORS_ORIGINS=["https://yourdomain.com"]
|
||||
|
||||
# Database Configuration (consider using PostgreSQL in production)
|
||||
DATABASE_URL=sqlite+aiosqlite:///data/soundboard.db
|
||||
DATABASE_ECHO=false
|
||||
|
||||
# Logging Configuration
|
||||
LOG_LEVEL=info
|
||||
LOG_FILE=logs/app.log
|
||||
LOG_MAX_SIZE=10485760
|
||||
LOG_BACKUP_COUNT=5
|
||||
|
||||
# JWT Configuration (IMPORTANT: Generate secure keys for production)
|
||||
JWT_SECRET_KEY=your-super-secure-secret-key-change-this-in-production
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=15
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
|
||||
# Cookie Configuration (Production settings)
|
||||
COOKIE_SECURE=true
|
||||
COOKIE_SAMESITE=lax
|
||||
COOKIE_DOMAIN= # Leave empty for same-origin cookies in production with reverse proxy
|
||||
|
||||
# OAuth2 Configuration (Configure with your OAuth providers)
|
||||
GOOGLE_CLIENT_ID=your-google-client-id
|
||||
GOOGLE_CLIENT_SECRET=your-google-client-secret
|
||||
GITHUB_CLIENT_ID=your-github-client-id
|
||||
GITHUB_CLIENT_SECRET=your-github-client-secret
|
||||
|
||||
# Audio Normalization Configuration
|
||||
NORMALIZED_AUDIO_FORMAT=mp3
|
||||
NORMALIZED_AUDIO_BITRATE=256k
|
||||
NORMALIZED_AUDIO_PASSES=2
|
||||
|
||||
# Audio Extraction Configuration
|
||||
EXTRACTION_AUDIO_FORMAT=mp3
|
||||
EXTRACTION_AUDIO_BITRATE=256k
|
||||
EXTRACTION_TEMP_DIR=sounds/temp
|
||||
EXTRACTION_THUMBNAILS_DIR=sounds/originals/extracted/thumbnails
|
||||
EXTRACTION_MAX_CONCURRENT=2
|
||||
@@ -1,29 +0,0 @@
|
||||
# Application Configuration
|
||||
HOST=localhost
|
||||
PORT=8000
|
||||
RELOAD=true
|
||||
|
||||
# Database Configuration
|
||||
DATABASE_URL=sqlite+aiosqlite:///data/soundboard.db
|
||||
DATABASE_ECHO=false
|
||||
|
||||
# Logging Configuration
|
||||
LOG_LEVEL=info
|
||||
LOG_FILE=logs/app.log
|
||||
LOG_MAX_SIZE=10485760
|
||||
LOG_BACKUP_COUNT=5
|
||||
|
||||
# JWT Configuration
|
||||
JWT_SECRET_KEY=your-secret-key-change-in-production
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=15
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
|
||||
# Cookie Configuration
|
||||
COOKIE_SECURE=false
|
||||
COOKIE_SAMESITE=lax
|
||||
|
||||
# OAuth2 Configuration
|
||||
GOOGLE_CLIENT_ID=
|
||||
GOOGLE_CLIENT_SECRET=
|
||||
GITHUB_CLIENT_ID=
|
||||
GITHUB_CLIENT_SECRET=
|
||||
@@ -9,28 +9,28 @@ on:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
# lint:
|
||||
# runs-on: ubuntu-latest
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
# steps:
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@v4
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# - name: "Set up python"
|
||||
# uses: actions/setup-python@v5
|
||||
# with:
|
||||
# python-version-file: "pyproject.toml"
|
||||
- name: "Set up python"
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version-file: "pyproject.toml"
|
||||
|
||||
# - name: Install uv
|
||||
# uses: astral-sh/setup-uv@v6
|
||||
# with:
|
||||
# enable-cache: true
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
# - name: Install requirements
|
||||
# run: uv sync --locked --all-extras --dev
|
||||
- name: Install requirements
|
||||
run: uv sync --locked --all-extras --dev
|
||||
|
||||
# - name: Run linter
|
||||
# run: uv run ruff check
|
||||
- name: Run linter
|
||||
run: uv run ruff check
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -8,4 +8,6 @@ wheels/
|
||||
|
||||
# Virtual environments
|
||||
.venv
|
||||
.env
|
||||
.env
|
||||
|
||||
.coverage
|
||||
148
alembic.ini
Normal file
148
alembic.ini
Normal file
@@ -0,0 +1,148 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts.
|
||||
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||
# format, relative to the token %(here)s which refers to the location of this
|
||||
# ini file
|
||||
script_location = %(here)s/alembic
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory. for multiple paths, the path separator
|
||||
# is defined by "path_separator" below.
|
||||
prepend_sys_path = .
|
||||
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
||||
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to <script_location>/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "path_separator"
|
||||
# below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
|
||||
|
||||
# path_separator; This indicates what character is used to split lists of file
|
||||
# paths, including version_locations and prepend_sys_path within configparser
|
||||
# files such as alembic.ini.
|
||||
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
|
||||
# to provide os-dependent path splitting.
|
||||
#
|
||||
# Note that in order to support legacy alembic.ini files, this default does NOT
|
||||
# take place if path_separator is not present in alembic.ini. If this
|
||||
# option is omitted entirely, fallback logic is as follows:
|
||||
#
|
||||
# 1. Parsing of the version_locations option falls back to using the legacy
|
||||
# "version_path_separator" key, which if absent then falls back to the legacy
|
||||
# behavior of splitting on spaces and/or commas.
|
||||
# 2. Parsing of the prepend_sys_path option falls back to the legacy
|
||||
# behavior of splitting on spaces, commas, or colons.
|
||||
#
|
||||
# Valid values for path_separator are:
|
||||
#
|
||||
# path_separator = :
|
||||
# path_separator = ;
|
||||
# path_separator = space
|
||||
# path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
path_separator = os
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# database URL. This is consumed by the user-maintained env.py script only.
|
||||
# other means of configuring database URLs may be customized within the env.py
|
||||
# file.
|
||||
# sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
# URL will be set dynamically in env.py from config
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
||||
# hooks = ruff
|
||||
# ruff.type = module
|
||||
# ruff.module = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Alternatively, use the exec runner to execute a binary found on your PATH
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration. This is also consumed by the user-maintained
|
||||
# env.py script only.
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
1
alembic/README
Normal file
1
alembic/README
Normal file
@@ -0,0 +1 @@
|
||||
Generic single-database configuration.
|
||||
86
alembic/env.py
Normal file
86
alembic/env.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from alembic import context
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.core.config import settings
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Set the database URL from settings - convert async URL to sync for alembic
|
||||
sync_db_url = settings.DATABASE_URL.replace("sqlite+aiosqlite", "sqlite")
|
||||
sync_db_url = sync_db_url.replace("postgresql+asyncpg", "postgresql+psycopg")
|
||||
config.set_main_option("sqlalchemy.url", sync_db_url)
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
target_metadata = SQLModel.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
28
alembic/script.py.mako
Normal file
28
alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,34 @@
|
||||
"""Add status and error fields to TTS table
|
||||
|
||||
Revision ID: 0d9b7f1c367f
|
||||
Revises: e617c155eea9
|
||||
Create Date: 2025-09-21 14:09:56.418372
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '0d9b7f1c367f'
|
||||
down_revision: Union[str, Sequence[str], None] = 'e617c155eea9'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('tts', sa.Column('status', sa.String(), nullable=False, server_default='pending'))
|
||||
op.add_column('tts', sa.Column('error', sa.String(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('tts', 'error')
|
||||
op.drop_column('tts', 'status')
|
||||
# ### end Alembic commands ###
|
||||
222
alembic/versions/7aa9892ceff3_initial_migration.py
Normal file
222
alembic/versions/7aa9892ceff3_initial_migration.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""Initial migration
|
||||
|
||||
Revision ID: 7aa9892ceff3
|
||||
Revises:
|
||||
Create Date: 2025-09-16 13:16:58.233360
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '7aa9892ceff3'
|
||||
down_revision: Union[str, Sequence[str], None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('plan',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('code', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('credits', sa.Integer(), nullable=False),
|
||||
sa.Column('max_credits', sa.Integer(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_plan_code'), 'plan', ['code'], unique=True)
|
||||
op.create_table('sound',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('filename', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('duration', sa.Integer(), nullable=False),
|
||||
sa.Column('size', sa.Integer(), nullable=False),
|
||||
sa.Column('hash', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('normalized_filename', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('normalized_duration', sa.Integer(), nullable=True),
|
||||
sa.Column('normalized_size', sa.Integer(), nullable=True),
|
||||
sa.Column('normalized_hash', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('thumbnail', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('play_count', sa.Integer(), nullable=False),
|
||||
sa.Column('is_normalized', sa.Boolean(), nullable=False),
|
||||
sa.Column('is_music', sa.Boolean(), nullable=False),
|
||||
sa.Column('is_deletable', sa.Boolean(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('hash', name='uq_sound_hash')
|
||||
)
|
||||
op.create_table('user',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('plan_id', sa.Integer(), nullable=False),
|
||||
sa.Column('role', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('picture', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('password_hash', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('credits', sa.Integer(), nullable=False),
|
||||
sa.Column('api_token', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('api_token_expires_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('refresh_token_hash', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('refresh_token_expires_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['plan_id'], ['plan.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('api_token'),
|
||||
sa.UniqueConstraint('email')
|
||||
)
|
||||
op.create_table('credit_transaction',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('user_id', sa.Integer(), nullable=False),
|
||||
sa.Column('action_type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('amount', sa.Integer(), nullable=False),
|
||||
sa.Column('balance_before', sa.Integer(), nullable=False),
|
||||
sa.Column('balance_after', sa.Integer(), nullable=False),
|
||||
sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('success', sa.Boolean(), nullable=False),
|
||||
sa.Column('metadata_json', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('extraction',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('service', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('service_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('user_id', sa.Integer(), nullable=False),
|
||||
sa.Column('sound_id', sa.Integer(), nullable=True),
|
||||
sa.Column('url', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('title', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('track', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('artist', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('album', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('genre', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('status', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('error', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['sound_id'], ['sound.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('playlist',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('user_id', sa.Integer(), nullable=True),
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('genre', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('is_main', sa.Boolean(), nullable=False),
|
||||
sa.Column('is_current', sa.Boolean(), nullable=False),
|
||||
sa.Column('is_deletable', sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('name')
|
||||
)
|
||||
op.create_table('scheduled_task',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
|
||||
sa.Column('task_type', sa.Enum('CREDIT_RECHARGE', 'PLAY_SOUND', 'PLAY_PLAYLIST', name='tasktype'), nullable=False),
|
||||
sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED', name='taskstatus'), nullable=False),
|
||||
sa.Column('scheduled_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('timezone', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('recurrence_type', sa.Enum('NONE', 'MINUTELY', 'HOURLY', 'DAILY', 'WEEKLY', 'MONTHLY', 'YEARLY', 'CRON', name='recurrencetype'), nullable=False),
|
||||
sa.Column('cron_expression', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('recurrence_count', sa.Integer(), nullable=True),
|
||||
sa.Column('executions_count', sa.Integer(), nullable=False),
|
||||
sa.Column('parameters', sa.JSON(), nullable=True),
|
||||
sa.Column('user_id', sa.Integer(), nullable=True),
|
||||
sa.Column('last_executed_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('next_execution_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('error_message', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('expires_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('sound_played',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('user_id', sa.Integer(), nullable=True),
|
||||
sa.Column('sound_id', sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['sound_id'], ['sound.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('user_oauth',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('user_id', sa.Integer(), nullable=False),
|
||||
sa.Column('provider', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('provider_user_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('picture', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('provider', 'provider_user_id', name='uq_user_oauth_provider_user_id')
|
||||
)
|
||||
op.create_table('favorite',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('user_id', sa.Integer(), nullable=False),
|
||||
sa.Column('sound_id', sa.Integer(), nullable=True),
|
||||
sa.Column('playlist_id', sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['playlist_id'], ['playlist.id'], ),
|
||||
sa.ForeignKeyConstraint(['sound_id'], ['sound.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('user_id', 'playlist_id', name='uq_favorite_user_playlist'),
|
||||
sa.UniqueConstraint('user_id', 'sound_id', name='uq_favorite_user_sound')
|
||||
)
|
||||
op.create_table('playlist_sound',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('playlist_id', sa.Integer(), nullable=False),
|
||||
sa.Column('sound_id', sa.Integer(), nullable=False),
|
||||
sa.Column('position', sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['playlist_id'], ['playlist.id'], ),
|
||||
sa.ForeignKeyConstraint(['sound_id'], ['sound.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('playlist_id', 'position', name='uq_playlist_sound_playlist_position'),
|
||||
sa.UniqueConstraint('playlist_id', 'sound_id', name='uq_playlist_sound_playlist_sound')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('playlist_sound')
|
||||
op.drop_table('favorite')
|
||||
op.drop_table('user_oauth')
|
||||
op.drop_table('sound_played')
|
||||
op.drop_table('scheduled_task')
|
||||
op.drop_table('playlist')
|
||||
op.drop_table('extraction')
|
||||
op.drop_table('credit_transaction')
|
||||
op.drop_table('user')
|
||||
op.drop_table('sound')
|
||||
op.drop_index(op.f('ix_plan_code'), table_name='plan')
|
||||
op.drop_table('plan')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,106 @@
|
||||
"""Add initial plan and playlist data
|
||||
|
||||
Revision ID: a0d322857b2c
|
||||
Revises: 7aa9892ceff3
|
||||
Create Date: 2025-09-16 13:23:31.682276
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from datetime import datetime
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'a0d322857b2c'
|
||||
down_revision: Union[str, Sequence[str], None] = '7aa9892ceff3'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema and add initial data."""
|
||||
# Get the current timestamp
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Insert initial plans
|
||||
plans_table = sa.table(
|
||||
'plan',
|
||||
sa.column('code', sa.String),
|
||||
sa.column('name', sa.String),
|
||||
sa.column('description', sa.String),
|
||||
sa.column('credits', sa.Integer),
|
||||
sa.column('max_credits', sa.Integer),
|
||||
sa.column('created_at', sa.DateTime),
|
||||
sa.column('updated_at', sa.DateTime),
|
||||
)
|
||||
|
||||
op.bulk_insert(
|
||||
plans_table,
|
||||
[
|
||||
{
|
||||
'code': 'free',
|
||||
'name': 'Free Plan',
|
||||
'description': 'Basic free plan with limited features',
|
||||
'credits': 25,
|
||||
'max_credits': 75,
|
||||
'created_at': now,
|
||||
'updated_at': now,
|
||||
},
|
||||
{
|
||||
'code': 'premium',
|
||||
'name': 'Premium Plan',
|
||||
'description': 'Premium plan with more features',
|
||||
'credits': 50,
|
||||
'max_credits': 150,
|
||||
'created_at': now,
|
||||
'updated_at': now,
|
||||
},
|
||||
{
|
||||
'code': 'pro',
|
||||
'name': 'Pro Plan',
|
||||
'description': 'Pro plan with unlimited features',
|
||||
'credits': 100,
|
||||
'max_credits': 300,
|
||||
'created_at': now,
|
||||
'updated_at': now,
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
# Insert main playlist
|
||||
playlist_table = sa.table(
|
||||
'playlist',
|
||||
sa.column('name', sa.String),
|
||||
sa.column('description', sa.String),
|
||||
sa.column('is_main', sa.Boolean),
|
||||
sa.column('is_deletable', sa.Boolean),
|
||||
sa.column('is_current', sa.Boolean),
|
||||
sa.column('created_at', sa.DateTime),
|
||||
sa.column('updated_at', sa.DateTime),
|
||||
)
|
||||
|
||||
op.bulk_insert(
|
||||
playlist_table,
|
||||
[
|
||||
{
|
||||
'name': 'All',
|
||||
'description': 'The default main playlist with all the tracks',
|
||||
'is_main': True,
|
||||
'is_deletable': False,
|
||||
'is_current': True,
|
||||
'created_at': now,
|
||||
'updated_at': now,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema and remove initial data."""
|
||||
# Remove initial plans
|
||||
op.execute("DELETE FROM plan WHERE code IN ('free', 'premium', 'pro')")
|
||||
|
||||
# Remove main playlist
|
||||
op.execute("DELETE FROM playlist WHERE is_main = 1")
|
||||
45
alembic/versions/e617c155eea9_add_tts_table.py
Normal file
45
alembic/versions/e617c155eea9_add_tts_table.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Add TTS table
|
||||
|
||||
Revision ID: e617c155eea9
|
||||
Revises: a0d322857b2c
|
||||
Create Date: 2025-09-20 21:51:26.557738
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'e617c155eea9'
|
||||
down_revision: Union[str, Sequence[str], None] = 'a0d322857b2c'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('tts',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('text', sqlmodel.sql.sqltypes.AutoString(length=1000), nullable=False),
|
||||
sa.Column('provider', sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False),
|
||||
sa.Column('options', sa.JSON(), nullable=True),
|
||||
sa.Column('sound_id', sa.Integer(), nullable=True),
|
||||
sa.Column('user_id', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['sound_id'], ['sound.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('tts')
|
||||
# ### end Alembic commands ###
|
||||
@@ -2,15 +2,36 @@
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.v1 import auth, main, player, playlists, socket, sounds
|
||||
from app.api.v1 import (
|
||||
admin,
|
||||
auth,
|
||||
dashboard,
|
||||
extractions,
|
||||
favorites,
|
||||
files,
|
||||
main,
|
||||
player,
|
||||
playlists,
|
||||
scheduler,
|
||||
socket,
|
||||
sounds,
|
||||
tts,
|
||||
)
|
||||
|
||||
# V1 API router with v1 prefix
|
||||
api_router = APIRouter(prefix="/v1")
|
||||
|
||||
# Include all route modules
|
||||
api_router.include_router(auth.router, tags=["authentication"])
|
||||
api_router.include_router(dashboard.router, tags=["dashboard"])
|
||||
api_router.include_router(extractions.router, tags=["extractions"])
|
||||
api_router.include_router(favorites.router, tags=["favorites"])
|
||||
api_router.include_router(files.router, tags=["files"])
|
||||
api_router.include_router(main.router, tags=["main"])
|
||||
api_router.include_router(player.router, tags=["player"])
|
||||
api_router.include_router(playlists.router, tags=["playlists"])
|
||||
api_router.include_router(scheduler.router, tags=["scheduler"])
|
||||
api_router.include_router(socket.router, tags=["socket"])
|
||||
api_router.include_router(sounds.router, tags=["sounds"])
|
||||
api_router.include_router(tts.router, tags=["tts"])
|
||||
api_router.include_router(admin.router)
|
||||
|
||||
12
app/api/v1/admin/__init__.py
Normal file
12
app/api/v1/admin/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Admin API endpoints."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.v1.admin import extractions, sounds, users
|
||||
|
||||
router = APIRouter(prefix="/admin")
|
||||
|
||||
# Include all admin sub-routers
|
||||
router.include_router(extractions.router)
|
||||
router.include_router(sounds.router)
|
||||
router.include_router(users.router)
|
||||
59
app/api/v1/admin/extractions.py
Normal file
59
app/api/v1/admin/extractions.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Admin audio extraction API endpoints."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import get_admin_user
|
||||
from app.models.user import User
|
||||
from app.services.extraction import ExtractionService
|
||||
from app.services.extraction_processor import extraction_processor
|
||||
|
||||
router = APIRouter(prefix="/extractions", tags=["admin-extractions"])
|
||||
|
||||
|
||||
async def get_extraction_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> ExtractionService:
|
||||
"""Get the extraction service."""
|
||||
return ExtractionService(session)
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_extraction_processor_status(
|
||||
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
|
||||
) -> dict:
|
||||
"""Get the status of the extraction processor. Admin only."""
|
||||
return extraction_processor.get_status()
|
||||
|
||||
|
||||
@router.delete("/{extraction_id}")
|
||||
async def delete_extraction(
|
||||
extraction_id: int,
|
||||
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
|
||||
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||
) -> dict[str, str]:
|
||||
"""Delete any extraction and its associated sound and files. Admin only."""
|
||||
try:
|
||||
deleted = await extraction_service.delete_extraction(extraction_id, None)
|
||||
|
||||
if not deleted:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Extraction {extraction_id} not found",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions without wrapping them
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to delete extraction: {e!s}",
|
||||
) from e
|
||||
else:
|
||||
return {
|
||||
"message": f"Extraction {extraction_id} deleted successfully",
|
||||
}
|
||||
228
app/api/v1/admin/sounds.py
Normal file
228
app/api/v1/admin/sounds.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""Admin sound management API endpoints."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import get_admin_user
|
||||
from app.models.user import User
|
||||
from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService
|
||||
from app.services.sound_scanner import ScanResults, SoundScannerService
|
||||
|
||||
router = APIRouter(prefix="/sounds", tags=["admin-sounds"])
|
||||
|
||||
|
||||
async def get_sound_scanner_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> SoundScannerService:
|
||||
"""Get the sound scanner service."""
|
||||
return SoundScannerService(session)
|
||||
|
||||
|
||||
async def get_sound_normalizer_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> SoundNormalizerService:
|
||||
"""Get the sound normalizer service."""
|
||||
return SoundNormalizerService(session)
|
||||
|
||||
|
||||
# SCAN ENDPOINTS
|
||||
@router.post("/scan")
|
||||
async def scan_sounds(
|
||||
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
|
||||
scanner_service: Annotated[SoundScannerService, Depends(get_sound_scanner_service)],
|
||||
) -> dict[str, ScanResults | str]:
|
||||
"""Sync the soundboard directory (add/update/delete sounds). Admin only."""
|
||||
try:
|
||||
results = await scanner_service.scan_soundboard_directory()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to sync sounds: {e!s}",
|
||||
) from e
|
||||
else:
|
||||
return {
|
||||
"message": "Sound sync completed",
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/scan/custom")
|
||||
async def scan_custom_directory(
|
||||
directory: str,
|
||||
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
|
||||
scanner_service: Annotated[SoundScannerService, Depends(get_sound_scanner_service)],
|
||||
sound_type: str = "SDB",
|
||||
) -> dict[str, ScanResults | str]:
|
||||
"""Sync a custom directory with the database. Admin only."""
|
||||
try:
|
||||
results = await scanner_service.scan_directory(directory, sound_type)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to sync directory: {e!s}",
|
||||
) from e
|
||||
else:
|
||||
return {
|
||||
"message": f"Sync of directory '{directory}' completed",
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
# NORMALIZE ENDPOINTS
|
||||
@router.post("/normalize/all")
|
||||
async def normalize_all_sounds(
|
||||
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
|
||||
normalizer_service: Annotated[
|
||||
SoundNormalizerService,
|
||||
Depends(get_sound_normalizer_service),
|
||||
],
|
||||
*,
|
||||
force: Annotated[
|
||||
bool,
|
||||
Query(
|
||||
description="Force normalization of already normalized sounds",
|
||||
),
|
||||
] = False,
|
||||
one_pass: Annotated[
|
||||
bool | None,
|
||||
Query(
|
||||
description="Use one-pass normalization (overrides config)",
|
||||
),
|
||||
] = None,
|
||||
) -> dict[str, NormalizationResults | str]:
|
||||
"""Normalize all unnormalized sounds. Admin only."""
|
||||
try:
|
||||
results = await normalizer_service.normalize_all_sounds(
|
||||
force=force,
|
||||
one_pass=one_pass,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to normalize sounds: {e!s}",
|
||||
) from e
|
||||
else:
|
||||
return {
|
||||
"message": "Sound normalization completed",
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/normalize/type/{sound_type}")
|
||||
async def normalize_sounds_by_type(
|
||||
sound_type: str,
|
||||
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
|
||||
normalizer_service: Annotated[
|
||||
SoundNormalizerService,
|
||||
Depends(get_sound_normalizer_service),
|
||||
],
|
||||
*,
|
||||
force: Annotated[
|
||||
bool,
|
||||
Query(
|
||||
description="Force normalization of already normalized sounds",
|
||||
),
|
||||
] = False,
|
||||
one_pass: Annotated[
|
||||
bool | None,
|
||||
Query(
|
||||
description="Use one-pass normalization (overrides config)",
|
||||
),
|
||||
] = None,
|
||||
) -> dict[str, NormalizationResults | str]:
|
||||
"""Normalize all sounds of a specific type (SDB, TTS, EXT). Admin only."""
|
||||
# Validate sound type
|
||||
valid_types = ["SDB", "TTS", "EXT"]
|
||||
if sound_type not in valid_types:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid sound type. Must be one of: {', '.join(valid_types)}",
|
||||
)
|
||||
|
||||
try:
|
||||
results = await normalizer_service.normalize_sounds_by_type(
|
||||
sound_type=sound_type,
|
||||
force=force,
|
||||
one_pass=one_pass,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to normalize {sound_type} sounds: {e!s}",
|
||||
) from e
|
||||
else:
|
||||
return {
|
||||
"message": f"Normalization of {sound_type} sounds completed",
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/normalize/{sound_id}")
|
||||
async def normalize_sound_by_id(
|
||||
sound_id: int,
|
||||
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
|
||||
normalizer_service: Annotated[
|
||||
SoundNormalizerService,
|
||||
Depends(get_sound_normalizer_service),
|
||||
],
|
||||
*,
|
||||
force: Annotated[
|
||||
bool,
|
||||
Query(
|
||||
description="Force normalization of already normalized sound",
|
||||
),
|
||||
] = False,
|
||||
one_pass: Annotated[
|
||||
bool | None,
|
||||
Query(
|
||||
description="Use one-pass normalization (overrides config)",
|
||||
),
|
||||
] = None,
|
||||
) -> dict[str, str]:
|
||||
"""Normalize a specific sound by ID. Admin only."""
|
||||
try:
|
||||
# Get the sound
|
||||
sound = await normalizer_service.sound_repo.get_by_id(sound_id)
|
||||
if not sound:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Sound with ID {sound_id} not found",
|
||||
)
|
||||
|
||||
# Normalize the sound
|
||||
result = await normalizer_service.normalize_sound(
|
||||
sound=sound,
|
||||
force=force,
|
||||
one_pass=one_pass,
|
||||
)
|
||||
|
||||
# Check result status
|
||||
if result["status"] == "error":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to normalize sound: {result['error']}",
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Sound normalization {result['status']}: {sound.filename}",
|
||||
"status": result["status"],
|
||||
"reason": result["reason"] or "",
|
||||
"normalized_filename": result["normalized_filename"] or "",
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions without wrapping them
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to normalize sound: {e!s}",
|
||||
) from e
|
||||
176
app/api/v1/admin/users.py
Normal file
176
app/api/v1/admin/users.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""Admin users endpoints."""
|
||||
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import get_admin_user
|
||||
from app.models.plan import Plan
|
||||
from app.models.user import User
|
||||
from app.repositories.plan import PlanRepository
|
||||
from app.repositories.user import SortOrder, UserRepository, UserSortField, UserStatus
|
||||
from app.schemas.auth import UserResponse
|
||||
from app.schemas.user import UserUpdate
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/users",
|
||||
tags=["admin-users"],
|
||||
dependencies=[Depends(get_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
def _user_to_response(user: User) -> UserResponse:
|
||||
"""Convert User model to UserResponse."""
|
||||
return UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
name=user.name,
|
||||
picture=user.picture,
|
||||
role=user.role,
|
||||
credits=user.credits,
|
||||
is_active=user.is_active,
|
||||
plan={
|
||||
"id": user.plan.id,
|
||||
"name": user.plan.name,
|
||||
"max_credits": user.plan.max_credits,
|
||||
"features": [], # Add features if needed
|
||||
}
|
||||
if user.plan
|
||||
else {},
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_users( # noqa: PLR0913
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
|
||||
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
|
||||
search: Annotated[str | None, Query(description="Search in name or email")] = None,
|
||||
sort_by: Annotated[
|
||||
UserSortField,
|
||||
Query(description="Sort by field"),
|
||||
] = UserSortField.NAME,
|
||||
sort_order: Annotated[SortOrder, Query(description="Sort order")] = SortOrder.ASC,
|
||||
status_filter: Annotated[
|
||||
UserStatus,
|
||||
Query(description="Filter by status"),
|
||||
] = UserStatus.ALL,
|
||||
) -> dict[str, Any]:
|
||||
"""Get all users with pagination, search, and filters (admin only)."""
|
||||
user_repo = UserRepository(session)
|
||||
users, total_count = await user_repo.get_all_with_plan_paginated(
|
||||
page=page,
|
||||
limit=limit,
|
||||
search=search,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
status_filter=status_filter,
|
||||
)
|
||||
|
||||
total_pages = (total_count + limit - 1) // limit # Ceiling division
|
||||
|
||||
return {
|
||||
"users": [_user_to_response(user) for user in users],
|
||||
"total": total_count,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"total_pages": total_pages,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{user_id}")
|
||||
async def get_user(
|
||||
user_id: int,
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> UserResponse:
|
||||
"""Get a specific user by ID (admin only)."""
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id_with_plan(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
return _user_to_response(user)
|
||||
|
||||
|
||||
@router.patch("/{user_id}")
|
||||
async def update_user(
|
||||
user_id: int,
|
||||
user_update: UserUpdate,
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> UserResponse:
|
||||
"""Update a user (admin only)."""
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id_with_plan(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
update_data = user_update.model_dump(exclude_unset=True)
|
||||
|
||||
# If plan_id is being updated, validate it exists
|
||||
if "plan_id" in update_data:
|
||||
plan_repo = PlanRepository(session)
|
||||
plan = await plan_repo.get_by_id(update_data["plan_id"])
|
||||
if not plan:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Plan not found",
|
||||
)
|
||||
|
||||
updated_user = await user_repo.update(user, update_data)
|
||||
# Need to refresh the plan relationship after update
|
||||
await session.refresh(updated_user, ["plan"])
|
||||
return _user_to_response(updated_user)
|
||||
|
||||
|
||||
@router.post("/{user_id}/disable")
|
||||
async def disable_user(
|
||||
user_id: int,
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> dict[str, str]:
|
||||
"""Disable a user (admin only)."""
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id_with_plan(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
await user_repo.update(user, {"is_active": False})
|
||||
return {"message": "User disabled successfully"}
|
||||
|
||||
|
||||
@router.post("/{user_id}/enable")
|
||||
async def enable_user(
|
||||
user_id: int,
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> dict[str, str]:
|
||||
"""Enable a user (admin only)."""
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id_with_plan(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
await user_repo.update(user, {"is_active": True})
|
||||
return {"message": "User enabled successfully"}
|
||||
|
||||
|
||||
@router.get("/plans/list")
|
||||
async def list_plans(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> list[Plan]:
|
||||
"""Get all plans for user editing (admin only)."""
|
||||
plan_repo = PlanRepository(session)
|
||||
return await plan_repo.get_all()
|
||||
@@ -20,6 +20,8 @@ from app.schemas.auth import (
|
||||
ApiTokenRequest,
|
||||
ApiTokenResponse,
|
||||
ApiTokenStatusResponse,
|
||||
ChangePasswordRequest,
|
||||
UpdateProfileRequest,
|
||||
UserLoginRequest,
|
||||
UserRegisterRequest,
|
||||
UserResponse,
|
||||
@@ -27,6 +29,7 @@ from app.schemas.auth import (
|
||||
from app.services.auth import AuthService
|
||||
from app.services.oauth import OAuthService
|
||||
from app.utils.auth import JWTUtils, TokenUtils
|
||||
from app.utils.cookies import set_access_token_cookie, set_auth_cookies
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||
logger = get_logger(__name__)
|
||||
@@ -54,26 +57,11 @@ async def register(
|
||||
refresh_token = await auth_service.create_and_store_refresh_token(user)
|
||||
|
||||
# Set HTTP-only cookies for both tokens
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=auth_response.token.access_token,
|
||||
max_age=auth_response.token.expires_in,
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
domain="localhost", # Allow cookie across localhost ports
|
||||
)
|
||||
response.set_cookie(
|
||||
key="refresh_token",
|
||||
value=refresh_token,
|
||||
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
* 24
|
||||
* 60
|
||||
* 60, # Convert days to seconds
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
domain="localhost", # Allow cookie across localhost ports
|
||||
set_auth_cookies(
|
||||
response=response,
|
||||
access_token=auth_response.token.access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_in=auth_response.token.expires_in,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
@@ -103,26 +91,11 @@ async def login(
|
||||
refresh_token = await auth_service.create_and_store_refresh_token(user)
|
||||
|
||||
# Set HTTP-only cookies for both tokens
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=auth_response.token.access_token,
|
||||
max_age=auth_response.token.expires_in,
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
domain="localhost", # Allow cookie across localhost ports
|
||||
)
|
||||
response.set_cookie(
|
||||
key="refresh_token",
|
||||
value=refresh_token,
|
||||
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
* 24
|
||||
* 60
|
||||
* 60, # Convert days to seconds
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
domain="localhost", # Allow cookie across localhost ports
|
||||
set_auth_cookies(
|
||||
response=response,
|
||||
access_token=auth_response.token.access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_in=auth_response.token.expires_in,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
@@ -171,14 +144,10 @@ async def refresh_token(
|
||||
token_response = await auth_service.refresh_access_token(refresh_token)
|
||||
|
||||
# Set new access token cookie
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=token_response.access_token,
|
||||
max_age=token_response.expires_in,
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
domain="localhost", # Allow cookie across localhost ports
|
||||
set_access_token_cookie(
|
||||
response=response,
|
||||
access_token=token_response.access_token,
|
||||
expires_in=token_response.expires_in,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
@@ -212,7 +181,9 @@ async def logout(
|
||||
user_id = int(user_id_str)
|
||||
user = await auth_service.get_current_user(user_id)
|
||||
logger.info("Found user from access token: %s", user.email)
|
||||
except (HTTPException, Exception) as e:
|
||||
except HTTPException as e:
|
||||
logger.info("Access token validation failed: %s", str(e))
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.info("Access token validation failed: %s", str(e))
|
||||
|
||||
# If no user found, try refresh token
|
||||
@@ -224,7 +195,9 @@ async def logout(
|
||||
user_id = int(user_id_str)
|
||||
user = await auth_service.get_current_user(user_id)
|
||||
logger.info("Found user from refresh token: %s", user.email)
|
||||
except (HTTPException, Exception) as e:
|
||||
except HTTPException as e:
|
||||
logger.info("Refresh token validation failed: %s", str(e))
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.info("Refresh token validation failed: %s", str(e))
|
||||
|
||||
# If we found a user, revoke their refresh token
|
||||
@@ -240,14 +213,14 @@ async def logout(
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
domain="localhost", # Match the domain used when setting cookies
|
||||
domain=settings.COOKIE_DOMAIN, # Match the domain used when setting cookies
|
||||
)
|
||||
response.delete_cookie(
|
||||
key="refresh_token",
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
domain="localhost", # Match the domain used when setting cookies
|
||||
domain=settings.COOKIE_DOMAIN, # Match the domain used when setting cookies
|
||||
)
|
||||
|
||||
return {"message": "Successfully logged out"}
|
||||
@@ -307,24 +280,11 @@ async def oauth_callback(
|
||||
# Set HTTP-only cookies for both tokens (not used due to cross-port issues)
|
||||
# These cookies are kept for potential future same-origin scenarios
|
||||
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=auth_response.token.access_token,
|
||||
max_age=auth_response.token.expires_in,
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
domain="localhost", # Allow cookie across localhost ports
|
||||
path="/", # Ensure cookie is available for all paths
|
||||
)
|
||||
response.set_cookie(
|
||||
key="refresh_token",
|
||||
value=refresh_token,
|
||||
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
domain="localhost", # Allow cookie across localhost ports
|
||||
set_auth_cookies(
|
||||
response=response,
|
||||
access_token=auth_response.token.access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_in=auth_response.token.expires_in,
|
||||
path="/", # Ensure cookie is available for all paths
|
||||
)
|
||||
|
||||
@@ -349,7 +309,7 @@ async def oauth_callback(
|
||||
"created_at": time.time(),
|
||||
}
|
||||
|
||||
redirect_url = f"http://localhost:8001/auth/callback?code={temp_code}"
|
||||
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?code={temp_code}"
|
||||
logger.info("Redirecting to: %s", redirect_url)
|
||||
|
||||
return RedirectResponse(
|
||||
@@ -410,24 +370,11 @@ async def exchange_oauth_token(
|
||||
)
|
||||
|
||||
# Set the proper auth cookies
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=token_data["access_token"],
|
||||
max_age=token_data["expires_in"],
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
domain="localhost",
|
||||
path="/",
|
||||
)
|
||||
response.set_cookie(
|
||||
key="refresh_token",
|
||||
value=token_data["refresh_token"],
|
||||
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
domain="localhost",
|
||||
set_auth_cookies(
|
||||
response=response,
|
||||
access_token=token_data["access_token"],
|
||||
refresh_token=token_data["refresh_token"],
|
||||
expires_in=token_data["expires_in"],
|
||||
path="/",
|
||||
)
|
||||
|
||||
@@ -505,3 +452,93 @@ async def revoke_api_token(
|
||||
) from e
|
||||
else:
|
||||
return {"message": "API token revoked successfully"}
|
||||
|
||||
|
||||
# Profile management endpoints
|
||||
@router.patch("/me")
|
||||
async def update_profile(
|
||||
request: UpdateProfileRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
) -> UserResponse:
|
||||
"""Update the current user's profile."""
|
||||
try:
|
||||
updated_user = await auth_service.update_user_profile(
|
||||
current_user,
|
||||
request.model_dump(exclude_unset=True),
|
||||
)
|
||||
return await auth_service.user_to_response(updated_user)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update profile for user: %s", current_user.email)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update profile",
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/change-password")
|
||||
async def change_password(
|
||||
request: ChangePasswordRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
) -> dict[str, str]:
|
||||
"""Change the current user's password."""
|
||||
# Store user email before operations to avoid session detachment issues
|
||||
user_email = current_user.email
|
||||
try:
|
||||
await auth_service.change_user_password(
|
||||
current_user,
|
||||
request.current_password,
|
||||
request.new_password,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.exception("Failed to change password for user: %s", user_email)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to change password",
|
||||
) from e
|
||||
else:
|
||||
return {"message": "Password changed successfully"}
|
||||
|
||||
|
||||
@router.get("/user-providers")
|
||||
async def get_user_providers(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
) -> list[dict[str, str]]:
|
||||
"""Get the current user's connected authentication providers."""
|
||||
providers = []
|
||||
|
||||
# Add password provider if user has password
|
||||
if current_user.password_hash:
|
||||
providers.append(
|
||||
{
|
||||
"provider": "password",
|
||||
"display_name": "Password",
|
||||
"connected_at": current_user.created_at.isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
# Get OAuth providers from the database
|
||||
oauth_providers = await auth_service.get_user_oauth_providers(current_user)
|
||||
for oauth in oauth_providers:
|
||||
display_name = oauth.provider.title() # Capitalize first letter
|
||||
if oauth.provider == "github":
|
||||
display_name = "GitHub"
|
||||
elif oauth.provider == "google":
|
||||
display_name = "Google"
|
||||
|
||||
providers.append(
|
||||
{
|
||||
"provider": oauth.provider,
|
||||
"display_name": display_name,
|
||||
"connected_at": oauth.created_at.isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
return providers
|
||||
|
||||
127
app/api/v1/dashboard.py
Normal file
127
app/api/v1/dashboard.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Dashboard API endpoints."""
|
||||
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
|
||||
from app.core.dependencies import get_current_user, get_dashboard_service
|
||||
from app.models.user import User
|
||||
from app.services.dashboard import DashboardService
|
||||
|
||||
router = APIRouter(prefix="/dashboard", tags=["dashboard"])
|
||||
|
||||
|
||||
@router.get("/soundboard-statistics")
|
||||
async def get_soundboard_statistics(
|
||||
_current_user: Annotated[User, Depends(get_current_user)],
|
||||
dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)],
|
||||
) -> dict[str, Any]:
|
||||
"""Get soundboard statistics."""
|
||||
try:
|
||||
return await dashboard_service.get_soundboard_statistics()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to fetch soundboard statistics: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/track-statistics")
|
||||
async def get_track_statistics(
|
||||
_current_user: Annotated[User, Depends(get_current_user)],
|
||||
dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)],
|
||||
) -> dict[str, Any]:
|
||||
"""Get track statistics."""
|
||||
try:
|
||||
return await dashboard_service.get_track_statistics()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to fetch track statistics: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/tts-statistics")
|
||||
async def get_tts_statistics(
|
||||
_current_user: Annotated[User, Depends(get_current_user)],
|
||||
dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)],
|
||||
) -> dict[str, Any]:
|
||||
"""Get TTS statistics."""
|
||||
try:
|
||||
return await dashboard_service.get_tts_statistics()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to fetch TTS statistics: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/top-users")
|
||||
async def get_top_users(
|
||||
_current_user: Annotated[User, Depends(get_current_user)],
|
||||
dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)],
|
||||
metric_type: Annotated[
|
||||
str,
|
||||
Query(
|
||||
description=(
|
||||
"Metric type: sounds_played, credits_used, tracks_added, "
|
||||
"tts_added, playlists_created"
|
||||
),
|
||||
),
|
||||
],
|
||||
period: Annotated[
|
||||
str,
|
||||
Query(
|
||||
description="Time period (today, 1_day, 1_week, 1_month, 1_year, all_time)",
|
||||
),
|
||||
] = "all_time",
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(description="Number of top users to return", ge=1, le=100),
|
||||
] = 10,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get top users by metric for a specific period."""
|
||||
try:
|
||||
return await dashboard_service.get_top_users(
|
||||
metric_type=metric_type,
|
||||
period=period,
|
||||
limit=limit,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to fetch top users: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/top-sounds")
|
||||
async def get_top_sounds(
|
||||
_current_user: Annotated[User, Depends(get_current_user)],
|
||||
dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)],
|
||||
sound_type: Annotated[
|
||||
str,
|
||||
Query(description="Sound type filter (SDB, TTS, EXT, or 'all')"),
|
||||
],
|
||||
period: Annotated[
|
||||
str,
|
||||
Query(
|
||||
description="Time period (today, 1_day, 1_week, 1_month, 1_year, all_time)",
|
||||
),
|
||||
] = "all_time",
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(description="Number of top sounds to return", ge=1, le=100),
|
||||
] = 10,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get top sounds by play count for a specific period."""
|
||||
try:
|
||||
return await dashboard_service.get_top_sounds(
|
||||
sound_type=sound_type,
|
||||
period=period,
|
||||
limit=limit,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to fetch top sounds: {e!s}",
|
||||
) from e
|
||||
248
app/api/v1/extractions.py
Normal file
248
app/api/v1/extractions.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""Audio extraction API endpoints."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import get_current_active_user_flexible
|
||||
from app.models.user import User
|
||||
from app.services.extraction import ExtractionInfo, ExtractionService
|
||||
from app.services.extraction_processor import extraction_processor
|
||||
|
||||
router = APIRouter(prefix="/extractions", tags=["extractions"])
|
||||
|
||||
|
||||
async def get_extraction_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> ExtractionService:
|
||||
"""Get the extraction service."""
|
||||
return ExtractionService(session)
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_extraction(
|
||||
url: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||
) -> dict[str, ExtractionInfo | str]:
|
||||
"""Create a new extraction job for a URL."""
|
||||
try:
|
||||
if current_user.id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User ID not available",
|
||||
)
|
||||
|
||||
extraction_info = await extraction_service.create_extraction(
|
||||
url,
|
||||
current_user.id,
|
||||
)
|
||||
|
||||
# Queue the extraction for background processing
|
||||
await extraction_processor.queue_extraction(extraction_info["id"])
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to create extraction: {e!s}",
|
||||
) from e
|
||||
else:
|
||||
return {
|
||||
"message": "Extraction queued successfully",
|
||||
"extraction": extraction_info,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/user")
|
||||
async def get_user_extractions( # noqa: PLR0913
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||
search: Annotated[
|
||||
str | None,
|
||||
Query(description="Search in title, URL, or service"),
|
||||
] = None,
|
||||
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
|
||||
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
|
||||
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
|
||||
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
|
||||
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
|
||||
) -> dict:
|
||||
"""Get all extractions for the current user."""
|
||||
try:
|
||||
if current_user.id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User ID not available",
|
||||
)
|
||||
|
||||
result = await extraction_service.get_user_extractions(
|
||||
user_id=current_user.id,
|
||||
search=search,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
status_filter=status_filter,
|
||||
page=page,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get extractions: {e!s}",
|
||||
) from e
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/{extraction_id}")
|
||||
async def get_extraction(
|
||||
extraction_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
||||
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||
) -> ExtractionInfo:
|
||||
"""Get extraction information by ID."""
|
||||
try:
|
||||
extraction_info = await extraction_service.get_extraction_by_id(extraction_id)
|
||||
|
||||
if not extraction_info:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Extraction {extraction_id} not found",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get extraction: {e!s}",
|
||||
) from e
|
||||
else:
|
||||
return extraction_info
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def get_all_extractions( # noqa: PLR0913
|
||||
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||
search: Annotated[
|
||||
str | None,
|
||||
Query(description="Search in title, URL, or service"),
|
||||
] = None,
|
||||
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
|
||||
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
|
||||
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
|
||||
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
|
||||
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
|
||||
) -> dict:
|
||||
"""Get all extractions with optional filtering, search, and sorting."""
|
||||
try:
|
||||
result = await extraction_service.get_all_extractions(
|
||||
search=search,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
status_filter=status_filter,
|
||||
page=page,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get extractions: {e!s}",
|
||||
) from e
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/processing/current")
|
||||
async def get_processing_extractions(
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
||||
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||
) -> list[ExtractionInfo]:
|
||||
"""Get all currently processing extractions for showing ongoing toasts."""
|
||||
try:
|
||||
# Get all extractions with processing status
|
||||
processing_extractions = await extraction_service.extraction_repo.get_by_status(
|
||||
"processing",
|
||||
)
|
||||
|
||||
# Convert to ExtractionInfo format
|
||||
result = []
|
||||
for extraction in processing_extractions:
|
||||
# Get user information
|
||||
user = await extraction_service.user_repo.get_by_id(extraction.user_id)
|
||||
user_name = user.name if user else None
|
||||
|
||||
extraction_info: ExtractionInfo = {
|
||||
"id": extraction.id or 0,
|
||||
"url": extraction.url,
|
||||
"service": extraction.service,
|
||||
"service_id": extraction.service_id,
|
||||
"title": extraction.title,
|
||||
"status": extraction.status,
|
||||
"error": extraction.error,
|
||||
"sound_id": extraction.sound_id,
|
||||
"user_id": extraction.user_id,
|
||||
"user_name": user_name,
|
||||
"created_at": extraction.created_at.isoformat(),
|
||||
"updated_at": extraction.updated_at.isoformat(),
|
||||
}
|
||||
result.append(extraction_info)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get processing extractions: {e!s}",
|
||||
) from e
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/{extraction_id}")
|
||||
async def delete_extraction(
|
||||
extraction_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||
) -> dict[str, str]:
|
||||
"""Delete extraction and associated sound/files. Users can only delete their own."""
|
||||
try:
|
||||
if current_user.id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User ID not available",
|
||||
)
|
||||
|
||||
deleted = await extraction_service.delete_extraction(
|
||||
extraction_id, current_user.id,
|
||||
)
|
||||
|
||||
if not deleted:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Extraction {extraction_id} not found",
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions without wrapping them
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to delete extraction: {e!s}",
|
||||
) from e
|
||||
else:
|
||||
return {
|
||||
"message": f"Extraction {extraction_id} deleted successfully",
|
||||
}
|
||||
197
app/api/v1/favorites.py
Normal file
197
app/api/v1/favorites.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""Favorites management API endpoints."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
|
||||
from app.core.database import get_session_factory
|
||||
from app.core.dependencies import get_current_active_user
|
||||
from app.models.user import User
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.schemas.favorite import (
|
||||
FavoriteCountsResponse,
|
||||
FavoriteResponse,
|
||||
FavoritesListResponse,
|
||||
)
|
||||
from app.services.favorite import FavoriteService
|
||||
|
||||
router = APIRouter(prefix="/favorites", tags=["favorites"])
|
||||
|
||||
|
||||
def get_favorite_service() -> FavoriteService:
|
||||
"""Get the favorite service."""
|
||||
return FavoriteService(get_session_factory())
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def get_user_favorites(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
offset: Annotated[int, Query(ge=0)] = 0,
|
||||
) -> FavoritesListResponse:
|
||||
"""Get all favorites for the current user."""
|
||||
favorites = await favorite_service.get_user_favorites(
|
||||
current_user.id,
|
||||
limit,
|
||||
offset,
|
||||
)
|
||||
return FavoritesListResponse(favorites=favorites)
|
||||
|
||||
|
||||
@router.get("/sounds")
|
||||
async def get_user_sound_favorites(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
offset: Annotated[int, Query(ge=0)] = 0,
|
||||
) -> FavoritesListResponse:
|
||||
"""Get sound favorites for the current user."""
|
||||
favorites = await favorite_service.get_user_sound_favorites(
|
||||
current_user.id,
|
||||
limit,
|
||||
offset,
|
||||
)
|
||||
return FavoritesListResponse(favorites=favorites)
|
||||
|
||||
|
||||
@router.get("/playlists")
|
||||
async def get_user_playlist_favorites(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
offset: Annotated[int, Query(ge=0)] = 0,
|
||||
) -> FavoritesListResponse:
|
||||
"""Get playlist favorites for the current user."""
|
||||
favorites = await favorite_service.get_user_playlist_favorites(
|
||||
current_user.id,
|
||||
limit,
|
||||
offset,
|
||||
)
|
||||
return FavoritesListResponse(favorites=favorites)
|
||||
|
||||
|
||||
@router.get("/counts")
|
||||
async def get_favorite_counts(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
) -> FavoriteCountsResponse:
|
||||
"""Get favorite counts for the current user."""
|
||||
counts = await favorite_service.get_favorite_counts(current_user.id)
|
||||
return FavoriteCountsResponse(**counts)
|
||||
|
||||
|
||||
@router.post("/sounds/{sound_id}")
|
||||
async def add_sound_favorite(
|
||||
sound_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
) -> FavoriteResponse:
|
||||
"""Add a sound to favorites."""
|
||||
try:
|
||||
favorite = await favorite_service.add_sound_favorite(current_user.id, sound_id)
|
||||
return FavoriteResponse.model_validate(favorite)
|
||||
except ValueError as e:
|
||||
if "not found" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
) from e
|
||||
if "already favorited" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=str(e),
|
||||
) from e
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/playlists/{playlist_id}")
|
||||
async def add_playlist_favorite(
|
||||
playlist_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
) -> FavoriteResponse:
|
||||
"""Add a playlist to favorites."""
|
||||
try:
|
||||
favorite = await favorite_service.add_playlist_favorite(
|
||||
current_user.id,
|
||||
playlist_id,
|
||||
)
|
||||
return FavoriteResponse.model_validate(favorite)
|
||||
except ValueError as e:
|
||||
if "not found" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
) from e
|
||||
if "already favorited" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=str(e),
|
||||
) from e
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/sounds/{sound_id}")
|
||||
async def remove_sound_favorite(
|
||||
sound_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
) -> MessageResponse:
|
||||
"""Remove a sound from favorites."""
|
||||
try:
|
||||
await favorite_service.remove_sound_favorite(current_user.id, sound_id)
|
||||
return MessageResponse(message="Sound removed from favorites")
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/playlists/{playlist_id}")
|
||||
async def remove_playlist_favorite(
|
||||
playlist_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
) -> MessageResponse:
|
||||
"""Remove a playlist from favorites."""
|
||||
try:
|
||||
await favorite_service.remove_playlist_favorite(current_user.id, playlist_id)
|
||||
return MessageResponse(message="Playlist removed from favorites")
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/sounds/{sound_id}/check")
|
||||
async def check_sound_favorited(
|
||||
sound_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
) -> dict[str, bool]:
|
||||
"""Check if a sound is favorited by the current user."""
|
||||
is_favorited = await favorite_service.is_sound_favorited(current_user.id, sound_id)
|
||||
return {"is_favorited": is_favorited}
|
||||
|
||||
|
||||
@router.get("/playlists/{playlist_id}/check")
|
||||
async def check_playlist_favorited(
|
||||
playlist_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
) -> dict[str, bool]:
|
||||
"""Check if a playlist is favorited by the current user."""
|
||||
is_favorited = await favorite_service.is_playlist_favorited(
|
||||
current_user.id,
|
||||
playlist_id,
|
||||
)
|
||||
return {"is_favorited": is_favorited}
|
||||
153
app/api/v1/files.py
Normal file
153
app/api/v1/files.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""File serving API endpoints for audio files and thumbnails."""
|
||||
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import get_current_active_user_flexible
|
||||
from app.core.logging import get_logger
|
||||
from app.models.user import User
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.utils.audio import get_sound_file_path
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/files", tags=["files"])
|
||||
|
||||
|
||||
async def get_sound_repository(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> SoundRepository:
|
||||
"""Get the sound repository."""
|
||||
return SoundRepository(session)
|
||||
|
||||
|
||||
@router.get("/sounds/{sound_id}/download")
|
||||
async def download_sound(
|
||||
sound_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
|
||||
) -> FileResponse:
|
||||
"""Download a sound file."""
|
||||
try:
|
||||
# Get the sound record
|
||||
sound = await sound_repo.get_by_id(sound_id)
|
||||
if not sound:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Sound with ID {sound_id} not found",
|
||||
)
|
||||
|
||||
# Get the file path using the audio utility
|
||||
file_path = get_sound_file_path(sound)
|
||||
|
||||
# Determine filename based on normalization status
|
||||
if sound.is_normalized and sound.normalized_filename:
|
||||
filename = sound.normalized_filename
|
||||
else:
|
||||
filename = sound.filename
|
||||
|
||||
# Check if file exists
|
||||
if not file_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Sound file not found on disk",
|
||||
)
|
||||
|
||||
# Get MIME type
|
||||
mime_type, _ = mimetypes.guess_type(str(file_path))
|
||||
if not mime_type:
|
||||
mime_type = "audio/mpeg" # Default to MP3
|
||||
|
||||
logger.info(
|
||||
"Serving sound download: %s (user: %d, sound: %d)",
|
||||
filename,
|
||||
current_user.id,
|
||||
sound_id,
|
||||
)
|
||||
|
||||
return FileResponse(
|
||||
path=str(file_path),
|
||||
filename=filename,
|
||||
media_type=mime_type,
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error serving sound download for sound %d", sound_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to serve sound file",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/sounds/{sound_id}/thumbnail")
|
||||
async def get_sound_thumbnail(
|
||||
sound_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
|
||||
) -> FileResponse:
|
||||
"""Get a sound's thumbnail image."""
|
||||
try:
|
||||
# Get the sound record
|
||||
sound = await sound_repo.get_by_id(sound_id)
|
||||
if not sound:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Sound with ID {sound_id} not found",
|
||||
)
|
||||
|
||||
# Check if sound has a thumbnail
|
||||
if not sound.thumbnail:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No thumbnail available for this sound",
|
||||
)
|
||||
|
||||
# Get thumbnail file path
|
||||
thumbnail_path = Path(settings.EXTRACTION_THUMBNAILS_DIR) / sound.thumbnail
|
||||
|
||||
# Check if thumbnail file exists
|
||||
if not thumbnail_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Thumbnail file not found on disk",
|
||||
)
|
||||
|
||||
# Get MIME type
|
||||
mime_type, _ = mimetypes.guess_type(str(thumbnail_path))
|
||||
if not mime_type:
|
||||
mime_type = "image/jpeg" # Default to JPEG
|
||||
|
||||
logger.debug(
|
||||
"Serving thumbnail: %s (user: %d, sound: %d)",
|
||||
sound.thumbnail,
|
||||
current_user.id,
|
||||
sound_id,
|
||||
)
|
||||
|
||||
return FileResponse(
|
||||
path=str(thumbnail_path),
|
||||
media_type=mime_type,
|
||||
headers={
|
||||
"Cache-Control": "public, max-age=3600", # Cache for 1 hour
|
||||
"Content-Disposition": "inline", # Display inline, not download
|
||||
},
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error serving thumbnail for sound %d", sound_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to serve thumbnail",
|
||||
) from e
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Main router for v1 endpoints."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.schemas.common import HealthResponse
|
||||
@@ -10,8 +11,77 @@ router = APIRouter()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@router.get("/health")
|
||||
def health() -> HealthResponse:
|
||||
"""Health check endpoint."""
|
||||
logger.info("Health check endpoint accessed")
|
||||
return HealthResponse(status="healthy")
|
||||
|
||||
|
||||
@router.get("/docs/scalar", response_class=HTMLResponse)
|
||||
def scalar_docs() -> HTMLResponse:
|
||||
"""Serve the API documentation using Scalar."""
|
||||
return """
|
||||
<!doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<title>API Documentation - Scalar</title>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
</head>
|
||||
<body>
|
||||
<script
|
||||
id="api-reference"
|
||||
data-url="http://localhost:8000/api/openapi.json"
|
||||
src="https://cdn.jsdelivr.net/npm/@scalar/api-reference"></script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
@router.get("/docs/rapidoc", response_class=HTMLResponse)
|
||||
async def rapidoc_docs() -> HTMLResponse:
|
||||
"""Serve the API documentation using Rapidoc."""
|
||||
return """
|
||||
<!doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<title>API Documentation - Rapidoc</title>
|
||||
<meta charset="utf-8">
|
||||
<script type="module" src="https://unpkg.com/rapidoc/dist/rapidoc-min.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
<rapi-doc
|
||||
spec-url="http://localhost:8000/api/openapi.json"
|
||||
theme="dark"
|
||||
render-style="read">
|
||||
</rapi-doc>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
@router.get("/docs/elements", response_class=HTMLResponse)
|
||||
async def elements_docs() -> HTMLResponse:
|
||||
"""Serve the API documentation using Stoplight Elements."""
|
||||
return """
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta
|
||||
name="viewport"
|
||||
content="width=device-width, initial-scale=1, shrink-to-fit=no"
|
||||
>
|
||||
<title>API Documentation - elements</title>
|
||||
<script src="https://unpkg.com/@stoplight/elements/web-components.min.js"></script>
|
||||
<link rel="stylesheet" href="https://unpkg.com/@stoplight/elements/styles.min.css">
|
||||
</head>
|
||||
<body>
|
||||
<elements-api
|
||||
apiDescriptionUrl="http://localhost:8000/api/openapi.json"
|
||||
router="hash"
|
||||
/>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Player API endpoints."""
|
||||
|
||||
from typing import Annotated, Any
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
@@ -137,10 +137,10 @@ async def seek(
|
||||
"""Seek to specific position in current track."""
|
||||
try:
|
||||
player = get_player_service()
|
||||
await player.seek(request.position_ms)
|
||||
return MessageResponse(message=f"Seeked to position {request.position_ms}ms")
|
||||
await player.seek(request.position)
|
||||
return MessageResponse(message=f"Seeked to position {request.position}ms")
|
||||
except Exception as e:
|
||||
logger.exception("Error seeking to position %s", request.position_ms)
|
||||
logger.exception("Error seeking to position %s", request.position)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to seek",
|
||||
@@ -165,6 +165,40 @@ async def set_volume(
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/mute")
|
||||
async def mute(
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
||||
) -> MessageResponse:
|
||||
"""Mute playback."""
|
||||
try:
|
||||
player = get_player_service()
|
||||
await player.mute()
|
||||
return MessageResponse(message="Playback muted")
|
||||
except Exception as e:
|
||||
logger.exception("Error muting playback")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to mute playback",
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/unmute")
|
||||
async def unmute(
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
||||
) -> MessageResponse:
|
||||
"""Unmute playback."""
|
||||
try:
|
||||
player = get_player_service()
|
||||
await player.unmute()
|
||||
return MessageResponse(message="Playback unmuted")
|
||||
except Exception as e:
|
||||
logger.exception("Error unmuting playback")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to unmute playback",
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/mode")
|
||||
async def set_mode(
|
||||
request: PlayerModeRequest,
|
||||
@@ -214,4 +248,22 @@ async def get_state(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get player state",
|
||||
) from e
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/play-next/{sound_id}")
|
||||
async def add_to_play_next(
|
||||
sound_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
||||
) -> MessageResponse:
|
||||
"""Add a sound to the play next queue."""
|
||||
try:
|
||||
player = get_player_service()
|
||||
await player.add_to_play_next(sound_id)
|
||||
return MessageResponse(message=f"Added sound {sound_id} to play next queue")
|
||||
except Exception as e:
|
||||
logger.exception("Error adding sound to play next queue: %s", sound_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to add sound to play next queue",
|
||||
) from e
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
"""Playlist management API endpoints."""
|
||||
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.database import get_db, get_session_factory
|
||||
from app.core.dependencies import get_current_active_user_flexible
|
||||
from app.models.user import User
|
||||
from app.repositories.playlist import PlaylistSortField, SortOrder
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.schemas.playlist import (
|
||||
PlaylistAddSoundRequest,
|
||||
@@ -18,6 +19,7 @@ from app.schemas.playlist import (
|
||||
PlaylistStatsResponse,
|
||||
PlaylistUpdateRequest,
|
||||
)
|
||||
from app.services.favorite import FavoriteService
|
||||
from app.services.playlist import PlaylistService
|
||||
|
||||
router = APIRouter(prefix="/playlists", tags=["playlists"])
|
||||
@@ -30,44 +32,143 @@ async def get_playlist_service(
|
||||
return PlaylistService(session)
|
||||
|
||||
|
||||
def get_favorite_service() -> FavoriteService:
|
||||
"""Get the favorite service."""
|
||||
return FavoriteService(get_session_factory())
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def get_all_playlists(
|
||||
async def get_all_playlists( # noqa: PLR0913
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
||||
) -> list[PlaylistResponse]:
|
||||
"""Get all playlists from all users."""
|
||||
playlists = await playlist_service.get_all_playlists()
|
||||
return [PlaylistResponse.from_playlist(playlist) for playlist in playlists]
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
search: Annotated[
|
||||
str | None,
|
||||
Query(description="Search playlists by name"),
|
||||
] = None,
|
||||
sort_by: Annotated[
|
||||
PlaylistSortField | None,
|
||||
Query(description="Sort by field"),
|
||||
] = None,
|
||||
sort_order: Annotated[
|
||||
SortOrder,
|
||||
Query(description="Sort order (asc or desc)"),
|
||||
] = SortOrder.ASC,
|
||||
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
|
||||
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
|
||||
favorites_only: Annotated[ # noqa: FBT002
|
||||
bool,
|
||||
Query(description="Show only favorited playlists"),
|
||||
] = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Get all playlists from all users with search and sorting."""
|
||||
result = await playlist_service.search_and_sort_playlists_paginated(
|
||||
search_query=search,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
user_id=None,
|
||||
include_stats=True,
|
||||
page=page,
|
||||
limit=limit,
|
||||
favorites_only=favorites_only,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
|
||||
# Convert to PlaylistResponse with favorite indicators
|
||||
playlist_responses = []
|
||||
for playlist_dict in result["playlists"]:
|
||||
# The playlist service returns dict, need to create playlist object structure
|
||||
playlist_id = playlist_dict["id"]
|
||||
is_favorited = await favorite_service.is_playlist_favorited(
|
||||
current_user.id,
|
||||
playlist_id,
|
||||
)
|
||||
favorite_count = await favorite_service.get_playlist_favorite_count(playlist_id)
|
||||
|
||||
# Create a PlaylistResponse-like dict with proper datetime conversion
|
||||
playlist_response = {
|
||||
**playlist_dict,
|
||||
"created_at": (
|
||||
playlist_dict["created_at"].isoformat()
|
||||
if playlist_dict["created_at"]
|
||||
else None
|
||||
),
|
||||
"updated_at": (
|
||||
playlist_dict["updated_at"].isoformat()
|
||||
if playlist_dict["updated_at"]
|
||||
else None
|
||||
),
|
||||
"is_favorited": is_favorited,
|
||||
"favorite_count": favorite_count,
|
||||
}
|
||||
playlist_responses.append(playlist_response)
|
||||
|
||||
return {
|
||||
"playlists": playlist_responses,
|
||||
"total": result["total"],
|
||||
"page": result["page"],
|
||||
"limit": result["limit"],
|
||||
"total_pages": result["total_pages"],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/user")
|
||||
async def get_user_playlists(
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
) -> list[PlaylistResponse]:
|
||||
"""Get playlists for the current user only."""
|
||||
playlists = await playlist_service.get_user_playlists(current_user.id)
|
||||
return [PlaylistResponse.from_playlist(playlist) for playlist in playlists]
|
||||
|
||||
# Add favorite indicators for each playlist
|
||||
playlist_responses = []
|
||||
for playlist in playlists:
|
||||
is_favorited = await favorite_service.is_playlist_favorited(
|
||||
current_user.id,
|
||||
playlist.id,
|
||||
)
|
||||
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
|
||||
playlist_response = PlaylistResponse.from_playlist(
|
||||
playlist,
|
||||
is_favorited,
|
||||
favorite_count,
|
||||
)
|
||||
playlist_responses.append(playlist_response)
|
||||
|
||||
return playlist_responses
|
||||
|
||||
|
||||
@router.get("/main")
|
||||
async def get_main_playlist(
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
) -> PlaylistResponse:
|
||||
"""Get the global main playlist."""
|
||||
playlist = await playlist_service.get_main_playlist()
|
||||
return PlaylistResponse.from_playlist(playlist)
|
||||
is_favorited = await favorite_service.is_playlist_favorited(
|
||||
current_user.id,
|
||||
playlist.id,
|
||||
)
|
||||
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
|
||||
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
|
||||
|
||||
|
||||
@router.get("/current")
|
||||
async def get_current_playlist(
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
) -> PlaylistResponse:
|
||||
"""Get the user's current playlist (falls back to main playlist)."""
|
||||
playlist = await playlist_service.get_current_playlist(current_user.id)
|
||||
return PlaylistResponse.from_playlist(playlist)
|
||||
"""Get the global current playlist (falls back to main playlist)."""
|
||||
playlist = await playlist_service.get_current_playlist()
|
||||
is_favorited = await favorite_service.is_playlist_favorited(
|
||||
current_user.id,
|
||||
playlist.id,
|
||||
)
|
||||
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
|
||||
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
|
||||
|
||||
|
||||
@router.post("/")
|
||||
@@ -91,10 +192,16 @@ async def get_playlist(
|
||||
playlist_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
) -> PlaylistResponse:
|
||||
"""Get a specific playlist."""
|
||||
playlist = await playlist_service.get_playlist_by_id(playlist_id)
|
||||
return PlaylistResponse.from_playlist(playlist)
|
||||
is_favorited = await favorite_service.is_playlist_favorited(
|
||||
current_user.id,
|
||||
playlist.id,
|
||||
)
|
||||
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
|
||||
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
|
||||
|
||||
|
||||
@router.put("/{playlist_id}")
|
||||
@@ -105,13 +212,19 @@ async def update_playlist(
|
||||
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
||||
) -> PlaylistResponse:
|
||||
"""Update a playlist."""
|
||||
if current_user.id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User ID not available",
|
||||
)
|
||||
|
||||
playlist = await playlist_service.update_playlist(
|
||||
playlist_id=playlist_id,
|
||||
user_id=current_user.id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
genre=request.genre,
|
||||
is_current=request.is_current,
|
||||
is_current=None, # is_current is not handled by this endpoint
|
||||
)
|
||||
return PlaylistResponse.from_playlist(playlist)
|
||||
|
||||
@@ -130,7 +243,7 @@ async def delete_playlist(
|
||||
@router.get("/search/{query}")
|
||||
async def search_playlists(
|
||||
query: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
||||
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
||||
) -> list[PlaylistResponse]:
|
||||
"""Search all playlists by name."""
|
||||
@@ -152,7 +265,7 @@ async def search_user_playlists(
|
||||
@router.get("/{playlist_id}/sounds")
|
||||
async def get_playlist_sounds(
|
||||
playlist_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
||||
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
||||
) -> list[PlaylistSoundResponse]:
|
||||
"""Get all sounds in a playlist."""
|
||||
@@ -212,28 +325,28 @@ async def reorder_playlist_sounds(
|
||||
@router.put("/{playlist_id}/set-current")
|
||||
async def set_current_playlist(
|
||||
playlist_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
||||
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
||||
) -> PlaylistResponse:
|
||||
"""Set a playlist as the current playlist."""
|
||||
playlist = await playlist_service.set_current_playlist(playlist_id, current_user.id)
|
||||
playlist = await playlist_service.set_current_playlist(playlist_id)
|
||||
return PlaylistResponse.from_playlist(playlist)
|
||||
|
||||
|
||||
@router.delete("/current")
|
||||
async def unset_current_playlist(
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
||||
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
||||
) -> MessageResponse:
|
||||
"""Unset the current playlist."""
|
||||
await playlist_service.unset_current_playlist(current_user.id)
|
||||
await playlist_service.unset_current_playlist()
|
||||
return MessageResponse(message="Current playlist unset successfully")
|
||||
|
||||
|
||||
@router.get("/{playlist_id}/stats")
|
||||
async def get_playlist_stats(
|
||||
playlist_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
||||
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
||||
) -> PlaylistStatsResponse:
|
||||
"""Get statistics for a playlist."""
|
||||
|
||||
230
app/api/v1/scheduler.py
Normal file
230
app/api/v1/scheduler.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""API endpoints for scheduled task management."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import (
|
||||
get_admin_user,
|
||||
get_current_active_user,
|
||||
)
|
||||
from app.core.services import get_global_scheduler_service
|
||||
from app.models.scheduled_task import ScheduledTask, TaskStatus, TaskType
|
||||
from app.models.user import User
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
from app.schemas.scheduler import (
|
||||
ScheduledTaskCreate,
|
||||
ScheduledTaskResponse,
|
||||
ScheduledTaskUpdate,
|
||||
TaskFilterParams,
|
||||
)
|
||||
from app.services.scheduler import SchedulerService
|
||||
|
||||
router = APIRouter(prefix="/scheduler")
|
||||
|
||||
|
||||
def get_scheduler_service() -> SchedulerService:
|
||||
"""Get the global scheduler service instance."""
|
||||
return get_global_scheduler_service()
|
||||
|
||||
|
||||
def get_task_filters(
|
||||
status: Annotated[
|
||||
TaskStatus | None, Query(description="Filter by task status"),
|
||||
] = None,
|
||||
task_type: Annotated[
|
||||
TaskType | None, Query(description="Filter by task type"),
|
||||
] = None,
|
||||
limit: Annotated[int, Query(description="Maximum number of tasks to return")] = 50,
|
||||
offset: Annotated[int, Query(description="Number of tasks to skip")] = 0,
|
||||
) -> TaskFilterParams:
|
||||
"""Create task filter parameters from query parameters."""
|
||||
return TaskFilterParams(
|
||||
status=status,
|
||||
task_type=task_type,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/tasks", response_model=ScheduledTaskResponse)
|
||||
async def create_task(
|
||||
task_data: ScheduledTaskCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
scheduler_service: Annotated[SchedulerService, Depends(get_scheduler_service)],
|
||||
) -> ScheduledTask:
|
||||
"""Create a new scheduled task."""
|
||||
try:
|
||||
return await scheduler_service.create_task(
|
||||
task_data=task_data,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/tasks", response_model=list[ScheduledTaskResponse])
|
||||
async def get_user_tasks(
|
||||
filters: Annotated[TaskFilterParams, Depends(get_task_filters)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
scheduler_service: Annotated[SchedulerService, Depends(get_scheduler_service)],
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get user's scheduled tasks."""
|
||||
return await scheduler_service.get_user_tasks(
|
||||
user_id=current_user.id,
|
||||
status=filters.status,
|
||||
task_type=filters.task_type,
|
||||
limit=filters.limit,
|
||||
offset=filters.offset,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}", response_model=ScheduledTaskResponse)
|
||||
async def get_task(
|
||||
task_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)] = ...,
|
||||
db_session: Annotated[AsyncSession, Depends(get_db)] = ...,
|
||||
) -> ScheduledTask:
|
||||
"""Get a specific scheduled task."""
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
task = await repo.get_by_id(task_id)
|
||||
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
# Check if user owns the task or is admin
|
||||
if task.user_id != current_user.id and not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
return task
|
||||
|
||||
|
||||
@router.patch("/tasks/{task_id}", response_model=ScheduledTaskResponse)
|
||||
async def update_task(
|
||||
task_id: int,
|
||||
task_update: ScheduledTaskUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)] = ...,
|
||||
db_session: Annotated[AsyncSession, Depends(get_db)] = ...,
|
||||
) -> ScheduledTask:
|
||||
"""Update a scheduled task."""
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
task = await repo.get_by_id(task_id)
|
||||
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
# Check if user owns the task or is admin
|
||||
if task.user_id != current_user.id and not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# Update task fields
|
||||
update_data = task_update.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(task, field, value)
|
||||
|
||||
return await repo.update(task)
|
||||
|
||||
|
||||
@router.delete("/tasks/{task_id}")
|
||||
async def delete_task(
|
||||
task_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)] = ...,
|
||||
scheduler_service: Annotated[
|
||||
SchedulerService, Depends(get_scheduler_service),
|
||||
] = ...,
|
||||
db_session: Annotated[AsyncSession, Depends(get_db)] = ...,
|
||||
) -> dict:
|
||||
"""Delete a scheduled task completely."""
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
task = await repo.get_by_id(task_id)
|
||||
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
# Check if user owns the task or is admin
|
||||
if task.user_id != current_user.id and not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
success = await scheduler_service.delete_task(task_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to delete task")
|
||||
|
||||
return {"message": "Task deleted successfully"}
|
||||
|
||||
|
||||
# Admin-only endpoints
|
||||
@router.get("/admin/tasks", response_model=list[ScheduledTaskResponse])
|
||||
async def get_all_tasks(
|
||||
status: Annotated[
|
||||
TaskStatus | None, Query(description="Filter by task status"),
|
||||
] = None,
|
||||
task_type: Annotated[
|
||||
TaskType | None, Query(description="Filter by task type"),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int | None, Query(description="Maximum number of tasks to return"),
|
||||
] = 100,
|
||||
offset: Annotated[
|
||||
int | None, Query(description="Number of tasks to skip"),
|
||||
] = 0,
|
||||
_: Annotated[User, Depends(get_admin_user)] = ...,
|
||||
db_session: Annotated[AsyncSession, Depends(get_db)] = ...,
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get all scheduled tasks (admin only)."""
|
||||
# Build query with pagination and filtering
|
||||
|
||||
statement = select(ScheduledTask)
|
||||
|
||||
if status:
|
||||
statement = statement.where(ScheduledTask.status == status)
|
||||
|
||||
if task_type:
|
||||
statement = statement.where(ScheduledTask.task_type == task_type)
|
||||
|
||||
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
|
||||
|
||||
if offset:
|
||||
statement = statement.offset(offset)
|
||||
|
||||
if limit:
|
||||
statement = statement.limit(limit)
|
||||
|
||||
result = await db_session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
|
||||
@router.get("/admin/system-tasks", response_model=list[ScheduledTaskResponse])
|
||||
async def get_system_tasks(
|
||||
status: Annotated[
|
||||
TaskStatus | None, Query(description="Filter by task status"),
|
||||
] = None,
|
||||
task_type: Annotated[
|
||||
TaskType | None, Query(description="Filter by task type"),
|
||||
] = None,
|
||||
_: Annotated[User, Depends(get_admin_user)] = ...,
|
||||
db_session: Annotated[AsyncSession, Depends(get_db)] = ...,
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get system tasks (admin only)."""
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
return await repo.get_system_tasks(status=status, task_type=task_type)
|
||||
|
||||
|
||||
@router.post("/admin/system-tasks", response_model=ScheduledTaskResponse)
|
||||
async def create_system_task(
|
||||
task_data: ScheduledTaskCreate,
|
||||
_: Annotated[User, Depends(get_admin_user)] = ...,
|
||||
scheduler_service: Annotated[
|
||||
SchedulerService, Depends(get_scheduler_service),
|
||||
] = ...,
|
||||
) -> ScheduledTask:
|
||||
"""Create a system task (admin only)."""
|
||||
try:
|
||||
return await scheduler_service.create_task(
|
||||
task_data=task_data,
|
||||
user_id=None, # System task
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Socket.IO API endpoints for WebSocket management."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.core.dependencies import get_current_user
|
||||
@@ -10,7 +12,9 @@ router = APIRouter(prefix="/socket", tags=["socket"])
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_socket_status(current_user: User = Depends(get_current_user)):
|
||||
async def get_socket_status(
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
) -> dict[str, int | bool]:
|
||||
"""Get current socket connection status."""
|
||||
connected_users = socket_manager.get_connected_users()
|
||||
|
||||
@@ -25,8 +29,8 @@ async def get_socket_status(current_user: User = Depends(get_current_user)):
|
||||
async def send_message_to_user(
|
||||
target_user_id: int,
|
||||
message: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
) -> dict[str, int | bool | str]:
|
||||
"""Send a message to a specific user via WebSocket."""
|
||||
success = await socket_manager.send_to_user(
|
||||
str(target_user_id),
|
||||
@@ -48,8 +52,8 @@ async def send_message_to_user(
|
||||
@router.post("/broadcast")
|
||||
async def broadcast_message(
|
||||
message: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
) -> dict[str, bool | str]:
|
||||
"""Broadcast a message to all connected users."""
|
||||
await socket_manager.broadcast_to_all(
|
||||
"broadcast_message",
|
||||
|
||||
@@ -5,52 +5,25 @@ from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.database import get_db, get_session_factory
|
||||
from app.core.dependencies import get_current_active_user_flexible
|
||||
from app.models.credit_action import CreditActionType
|
||||
from app.models.user import User
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.services.extraction import ExtractionInfo, ExtractionService
|
||||
from app.services.credit import CreditService, InsufficientCreditsError
|
||||
from app.services.extraction_processor import extraction_processor
|
||||
from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService
|
||||
from app.services.sound_scanner import ScanResults, SoundScannerService
|
||||
from app.services.vlc_player import get_vlc_player_service, VLCPlayerService
|
||||
from app.repositories.sound import SortOrder, SoundRepository, SoundSortField
|
||||
from app.schemas.sound import SoundResponse, SoundsListResponse
|
||||
from app.services.favorite import FavoriteService
|
||||
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
|
||||
|
||||
router = APIRouter(prefix="/sounds", tags=["sounds"])
|
||||
|
||||
|
||||
async def get_sound_scanner_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> SoundScannerService:
|
||||
"""Get the sound scanner service."""
|
||||
return SoundScannerService(session)
|
||||
|
||||
|
||||
async def get_sound_normalizer_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> SoundNormalizerService:
|
||||
"""Get the sound normalizer service."""
|
||||
return SoundNormalizerService(session)
|
||||
|
||||
|
||||
async def get_extraction_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> ExtractionService:
|
||||
"""Get the extraction service."""
|
||||
return ExtractionService(session)
|
||||
|
||||
|
||||
def get_vlc_player() -> VLCPlayerService:
|
||||
"""Get the VLC player service."""
|
||||
from app.core.database import get_session_factory
|
||||
return get_vlc_player_service(get_session_factory())
|
||||
|
||||
|
||||
def get_credit_service() -> CreditService:
|
||||
"""Get the credit service."""
|
||||
from app.core.database import get_session_factory
|
||||
return CreditService(get_session_factory())
|
||||
def get_favorite_service() -> FavoriteService:
|
||||
"""Get the favorite service."""
|
||||
return FavoriteService(get_session_factory())
|
||||
|
||||
|
||||
async def get_sound_repository(
|
||||
@@ -60,377 +33,92 @@ async def get_sound_repository(
|
||||
return SoundRepository(session)
|
||||
|
||||
|
||||
# SCAN
|
||||
@router.post("/scan")
|
||||
async def scan_sounds(
|
||||
@router.get("/")
|
||||
async def get_sounds( # noqa: PLR0913
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
scanner_service: Annotated[SoundScannerService, Depends(get_sound_scanner_service)],
|
||||
) -> dict[str, ScanResults | str]:
|
||||
"""Sync the soundboard directory (add/update/delete sounds)."""
|
||||
# Only allow admins to scan sounds
|
||||
if current_user.role not in ["admin", "superadmin"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only administrators can sync sounds",
|
||||
)
|
||||
|
||||
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
types: Annotated[
|
||||
list[str] | None,
|
||||
Query(description="Filter by sound types (e.g., SDB, TTS, EXT)"),
|
||||
] = None,
|
||||
search: Annotated[
|
||||
str | None,
|
||||
Query(description="Search sounds by name"),
|
||||
] = None,
|
||||
sort_by: Annotated[
|
||||
SoundSortField | None,
|
||||
Query(description="Sort by field"),
|
||||
] = None,
|
||||
sort_order: Annotated[
|
||||
SortOrder,
|
||||
Query(description="Sort order (asc or desc)"),
|
||||
] = SortOrder.ASC,
|
||||
limit: Annotated[
|
||||
int | None,
|
||||
Query(description="Maximum number of results", ge=1, le=1000),
|
||||
] = None,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(description="Number of results to skip", ge=0),
|
||||
] = 0,
|
||||
favorites_only: Annotated[ # noqa: FBT002
|
||||
bool,
|
||||
Query(description="Show only favorited sounds"),
|
||||
] = False,
|
||||
) -> SoundsListResponse:
|
||||
"""Get sounds with optional search, filtering, and sorting."""
|
||||
try:
|
||||
results = await scanner_service.scan_soundboard_directory()
|
||||
return {
|
||||
"message": "Sound sync completed",
|
||||
"results": results,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to sync sounds: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/scan/custom")
|
||||
async def scan_custom_directory(
|
||||
directory: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
scanner_service: Annotated[SoundScannerService, Depends(get_sound_scanner_service)],
|
||||
sound_type: str = "SDB",
|
||||
) -> dict[str, ScanResults | str]:
|
||||
"""Sync a custom directory with the database (add/update/delete sounds)."""
|
||||
# Only allow admins to sync sounds
|
||||
if current_user.role not in ["admin", "superadmin"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only administrators can sync sounds",
|
||||
sounds = await sound_repo.search_and_sort(
|
||||
search_query=search,
|
||||
sound_types=types,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
favorites_only=favorites_only,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
try:
|
||||
results = await scanner_service.scan_directory(directory, sound_type)
|
||||
return {
|
||||
"message": f"Sync of directory '{directory}' completed",
|
||||
"results": results,
|
||||
}
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to sync directory: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
# NORMALIZE
|
||||
@router.post("/normalize/all")
|
||||
async def normalize_all_sounds(
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
normalizer_service: Annotated[
|
||||
SoundNormalizerService, Depends(get_sound_normalizer_service)
|
||||
],
|
||||
force: bool = Query(
|
||||
False, description="Force normalization of already normalized sounds"
|
||||
),
|
||||
one_pass: bool | None = Query(
|
||||
None, description="Use one-pass normalization (overrides config)"
|
||||
),
|
||||
) -> dict[str, NormalizationResults | str]:
|
||||
"""Normalize all unnormalized sounds."""
|
||||
# Only allow admins to normalize sounds
|
||||
if current_user.role not in ["admin", "superadmin"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only administrators can normalize sounds",
|
||||
)
|
||||
|
||||
try:
|
||||
results = await normalizer_service.normalize_all_sounds(
|
||||
force=force,
|
||||
one_pass=one_pass,
|
||||
)
|
||||
return {
|
||||
"message": "Sound normalization completed",
|
||||
"results": results,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to normalize sounds: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/normalize/type/{sound_type}")
|
||||
async def normalize_sounds_by_type(
|
||||
sound_type: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
normalizer_service: Annotated[
|
||||
SoundNormalizerService, Depends(get_sound_normalizer_service)
|
||||
],
|
||||
force: bool = Query(
|
||||
False, description="Force normalization of already normalized sounds"
|
||||
),
|
||||
one_pass: bool | None = Query(
|
||||
None, description="Use one-pass normalization (overrides config)"
|
||||
),
|
||||
) -> dict[str, NormalizationResults | str]:
|
||||
"""Normalize all sounds of a specific type (SDB, TTS, EXT)."""
|
||||
# Only allow admins to normalize sounds
|
||||
if current_user.role not in ["admin", "superadmin"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only administrators can normalize sounds",
|
||||
)
|
||||
|
||||
# Validate sound type
|
||||
valid_types = ["SDB", "TTS", "EXT"]
|
||||
if sound_type not in valid_types:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid sound type. Must be one of: {', '.join(valid_types)}",
|
||||
)
|
||||
|
||||
try:
|
||||
results = await normalizer_service.normalize_sounds_by_type(
|
||||
sound_type=sound_type,
|
||||
force=force,
|
||||
one_pass=one_pass,
|
||||
)
|
||||
return {
|
||||
"message": f"Normalization of {sound_type} sounds completed",
|
||||
"results": results,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to normalize {sound_type} sounds: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/normalize/{sound_id}")
|
||||
async def normalize_sound_by_id(
|
||||
sound_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
normalizer_service: Annotated[
|
||||
SoundNormalizerService, Depends(get_sound_normalizer_service)
|
||||
],
|
||||
force: bool = Query(
|
||||
False, description="Force normalization of already normalized sound"
|
||||
),
|
||||
one_pass: bool | None = Query(
|
||||
None, description="Use one-pass normalization (overrides config)"
|
||||
),
|
||||
) -> dict[str, str]:
|
||||
"""Normalize a specific sound by ID."""
|
||||
# Only allow admins to normalize sounds
|
||||
if current_user.role not in ["admin", "superadmin"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only administrators can normalize sounds",
|
||||
)
|
||||
|
||||
try:
|
||||
# Get the sound
|
||||
sound = await normalizer_service.sound_repo.get_by_id(sound_id)
|
||||
if not sound:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Sound with ID {sound_id} not found",
|
||||
# Add favorite indicators for each sound
|
||||
sound_responses = []
|
||||
for sound in sounds:
|
||||
is_favorited = await favorite_service.is_sound_favorited(
|
||||
current_user.id,
|
||||
sound.id,
|
||||
)
|
||||
|
||||
# Normalize the sound
|
||||
result = await normalizer_service.normalize_sound(
|
||||
sound=sound,
|
||||
force=force,
|
||||
one_pass=one_pass,
|
||||
)
|
||||
|
||||
# Check result status
|
||||
if result["status"] == "error":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to normalize sound: {result['error']}",
|
||||
favorite_count = await favorite_service.get_sound_favorite_count(sound.id)
|
||||
sound_response = SoundResponse.from_sound(
|
||||
sound,
|
||||
is_favorited,
|
||||
favorite_count,
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Sound normalization {result['status']}: {sound.filename}",
|
||||
"status": result["status"],
|
||||
"reason": result["reason"] or "",
|
||||
"normalized_filename": result["normalized_filename"] or "",
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions without wrapping them
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to normalize sound: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
# EXTRACT
|
||||
@router.post("/extract")
|
||||
async def create_extraction(
|
||||
url: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||
) -> dict[str, ExtractionInfo | str]:
|
||||
"""Create a new extraction job for a URL."""
|
||||
try:
|
||||
if current_user.id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User ID not available",
|
||||
)
|
||||
|
||||
extraction_info = await extraction_service.create_extraction(
|
||||
url, current_user.id
|
||||
)
|
||||
|
||||
# Queue the extraction for background processing
|
||||
await extraction_processor.queue_extraction(extraction_info["id"])
|
||||
|
||||
return {
|
||||
"message": "Extraction queued successfully",
|
||||
"extraction": extraction_info,
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to create extraction: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/extract/status")
|
||||
async def get_extraction_processor_status(
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
) -> dict:
|
||||
"""Get the status of the extraction processor."""
|
||||
# Only allow admins to see processor status
|
||||
if current_user.role not in ["admin", "superadmin"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only administrators can view processor status",
|
||||
)
|
||||
|
||||
return extraction_processor.get_status()
|
||||
|
||||
|
||||
@router.get("/extract/{extraction_id}")
|
||||
async def get_extraction(
|
||||
extraction_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||
) -> ExtractionInfo:
|
||||
"""Get extraction information by ID."""
|
||||
try:
|
||||
extraction_info = await extraction_service.get_extraction_by_id(extraction_id)
|
||||
|
||||
if not extraction_info:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Extraction {extraction_id} not found",
|
||||
)
|
||||
|
||||
return extraction_info
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get extraction: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/extract")
|
||||
async def get_user_extractions(
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||
) -> dict[str, list[ExtractionInfo]]:
|
||||
"""Get all extractions for the current user."""
|
||||
try:
|
||||
if current_user.id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User ID not available",
|
||||
)
|
||||
|
||||
extractions = await extraction_service.get_user_extractions(current_user.id)
|
||||
|
||||
return {
|
||||
"extractions": extractions,
|
||||
}
|
||||
sound_responses.append(sound_response)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get extractions: {e!s}",
|
||||
detail=f"Failed to get sounds: {e!s}",
|
||||
) from e
|
||||
else:
|
||||
return SoundsListResponse(sounds=sound_responses)
|
||||
|
||||
|
||||
# VLC PLAYER
|
||||
@router.post("/vlc/play/{sound_id}")
|
||||
@router.post("/play/{sound_id}")
|
||||
async def play_sound_with_vlc(
|
||||
sound_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
vlc_player: Annotated[VLCPlayerService, Depends(get_vlc_player)],
|
||||
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
|
||||
credit_service: Annotated[CreditService, Depends(get_credit_service)],
|
||||
) -> dict[str, str | int | bool]:
|
||||
"""Play a sound using VLC subprocess (requires 1 credit)."""
|
||||
try:
|
||||
# Get the sound
|
||||
sound = await sound_repo.get_by_id(sound_id)
|
||||
if not sound:
|
||||
if not current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Sound with ID {sound_id} not found",
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User ID is required",
|
||||
)
|
||||
|
||||
# Check and validate credits before playing
|
||||
try:
|
||||
await credit_service.validate_and_reserve_credits(
|
||||
current_user.id,
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
{"sound_id": sound_id, "sound_name": sound.name}
|
||||
)
|
||||
except InsufficientCreditsError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Insufficient credits: {e.required} required, {e.available} available",
|
||||
) from e
|
||||
|
||||
# Play the sound using VLC
|
||||
success = await vlc_player.play_sound(sound)
|
||||
|
||||
# Deduct credits based on success
|
||||
await credit_service.deduct_credits(
|
||||
current_user.id,
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
success,
|
||||
{"sound_id": sound_id, "sound_name": sound.name},
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to launch VLC for sound playback",
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Sound '{sound.name}' is now playing via VLC",
|
||||
"sound_id": sound_id,
|
||||
"sound_name": sound.name,
|
||||
"success": True,
|
||||
"credits_deducted": 1,
|
||||
}
|
||||
|
||||
return await vlc_player.play_sound_with_credits(sound_id, current_user.id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -440,16 +128,14 @@ async def play_sound_with_vlc(
|
||||
) from e
|
||||
|
||||
|
||||
|
||||
@router.post("/vlc/stop-all")
|
||||
@router.post("/stop")
|
||||
async def stop_all_vlc_instances(
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
||||
vlc_player: Annotated[VLCPlayerService, Depends(get_vlc_player)],
|
||||
) -> dict:
|
||||
"""Stop all running VLC instances."""
|
||||
try:
|
||||
result = await vlc_player.stop_all_vlc_instances()
|
||||
return result
|
||||
return await vlc_player.stop_all_vlc_instances()
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
||||
225
app/api/v1/tts.py
Normal file
225
app/api/v1/tts.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""TTS API endpoints."""
|
||||
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import get_current_active_user_flexible
|
||||
from app.models.user import User
|
||||
from app.services.tts import TTSService
|
||||
|
||||
router = APIRouter(prefix="/tts", tags=["tts"])
|
||||
|
||||
|
||||
class TTSGenerateRequest(BaseModel):
|
||||
"""TTS generation request model."""
|
||||
|
||||
text: str = Field(
|
||||
..., min_length=1, max_length=1000, description="Text to convert to speech",
|
||||
)
|
||||
provider: str = Field(default="gtts", description="TTS provider to use")
|
||||
options: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Provider-specific options",
|
||||
)
|
||||
|
||||
|
||||
class TTSResponse(BaseModel):
|
||||
"""TTS generation response model."""
|
||||
|
||||
id: int
|
||||
text: str
|
||||
provider: str
|
||||
options: dict[str, Any]
|
||||
status: str
|
||||
error: str | None
|
||||
sound_id: int | None
|
||||
user_id: int
|
||||
created_at: str
|
||||
|
||||
|
||||
class ProviderInfo(BaseModel):
|
||||
"""Provider information model."""
|
||||
|
||||
name: str
|
||||
file_extension: str
|
||||
supported_languages: list[str]
|
||||
option_schema: dict[str, Any]
|
||||
|
||||
|
||||
async def get_tts_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> TTSService:
|
||||
"""Get the TTS service."""
|
||||
return TTSService(session)
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_tts_list(
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
tts_service: Annotated[TTSService, Depends(get_tts_service)],
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[TTSResponse]:
|
||||
"""Get TTS list for the current user."""
|
||||
try:
|
||||
if current_user.id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User ID not available",
|
||||
)
|
||||
|
||||
tts_records = await tts_service.get_user_tts_history(
|
||||
user_id=current_user.id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return [
|
||||
TTSResponse(
|
||||
id=tts.id,
|
||||
text=tts.text,
|
||||
provider=tts.provider,
|
||||
options=tts.options,
|
||||
status=tts.status,
|
||||
error=tts.error,
|
||||
sound_id=tts.sound_id,
|
||||
user_id=tts.user_id,
|
||||
created_at=tts.created_at.isoformat(),
|
||||
)
|
||||
for tts in tts_records
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get TTS history: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def generate_tts(
|
||||
request: TTSGenerateRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
tts_service: Annotated[TTSService, Depends(get_tts_service)],
|
||||
) -> dict[str, Any]:
|
||||
"""Generate TTS audio and create sound."""
|
||||
try:
|
||||
if current_user.id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User ID not available",
|
||||
)
|
||||
|
||||
result = await tts_service.create_tts_request(
|
||||
text=request.text,
|
||||
user_id=current_user.id,
|
||||
provider=request.provider,
|
||||
**request.options,
|
||||
)
|
||||
|
||||
tts_record = result["tts"]
|
||||
|
||||
return {
|
||||
"message": result["message"],
|
||||
"tts": TTSResponse(
|
||||
id=tts_record.id,
|
||||
text=tts_record.text,
|
||||
provider=tts_record.provider,
|
||||
options=tts_record.options,
|
||||
status=tts_record.status,
|
||||
error=tts_record.error,
|
||||
sound_id=tts_record.sound_id,
|
||||
user_id=tts_record.user_id,
|
||||
created_at=tts_record.created_at.isoformat(),
|
||||
),
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to generate TTS: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/providers")
|
||||
async def get_providers(
|
||||
tts_service: Annotated[TTSService, Depends(get_tts_service)],
|
||||
) -> dict[str, ProviderInfo]:
|
||||
"""Get all available TTS providers."""
|
||||
providers = tts_service.get_providers()
|
||||
result = {}
|
||||
|
||||
for name, provider in providers.items():
|
||||
result[name] = ProviderInfo(
|
||||
name=provider.name,
|
||||
file_extension=provider.file_extension,
|
||||
supported_languages=provider.get_supported_languages(),
|
||||
option_schema=provider.get_option_schema(),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/providers/{provider_name}")
|
||||
async def get_provider(
|
||||
provider_name: str,
|
||||
tts_service: Annotated[TTSService, Depends(get_tts_service)],
|
||||
) -> ProviderInfo:
|
||||
"""Get information about a specific TTS provider."""
|
||||
provider = tts_service.get_provider(provider_name)
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Provider '{provider_name}' not found",
|
||||
)
|
||||
|
||||
return ProviderInfo(
|
||||
name=provider.name,
|
||||
file_extension=provider.file_extension,
|
||||
supported_languages=provider.get_supported_languages(),
|
||||
option_schema=provider.get_option_schema(),
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{tts_id}")
|
||||
async def delete_tts(
|
||||
tts_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
tts_service: Annotated[TTSService, Depends(get_tts_service)],
|
||||
) -> dict[str, str]:
|
||||
"""Delete a TTS generation and its associated files."""
|
||||
try:
|
||||
if current_user.id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User ID not available",
|
||||
)
|
||||
|
||||
await tts_service.delete_tts(tts_id=tts_id, user_id=current_user.id)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except PermissionError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to delete TTS: {e!s}",
|
||||
) from e
|
||||
else:
|
||||
return {"message": "TTS generation deleted successfully"}
|
||||
@@ -18,6 +18,16 @@ class Settings(BaseSettings):
|
||||
PORT: int = 8000
|
||||
RELOAD: bool = True
|
||||
|
||||
# Production URLs (for reverse proxy deployment)
|
||||
FRONTEND_URL: str = "http://localhost:8001" # Frontend URL in production
|
||||
BACKEND_URL: str = "http://localhost:8000" # Backend base URL
|
||||
|
||||
# CORS Configuration
|
||||
CORS_ORIGINS: list[str] = [
|
||||
"http://localhost:8001", # Frontend development
|
||||
"chrome-extension://*", # Chrome extensions
|
||||
]
|
||||
|
||||
# Database Configuration
|
||||
DATABASE_URL: str = "sqlite+aiosqlite:///data/soundboard.db"
|
||||
DATABASE_ECHO: bool = False
|
||||
@@ -30,7 +40,9 @@ class Settings(BaseSettings):
|
||||
LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
|
||||
# JWT Configuration
|
||||
JWT_SECRET_KEY: str = "your-secret-key-change-in-production" # noqa: S105 default value if none set in .env
|
||||
JWT_SECRET_KEY: str = (
|
||||
"your-secret-key-change-in-production" # noqa: S105 default value if none set in .env
|
||||
)
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 15
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7
|
||||
@@ -38,6 +50,7 @@ class Settings(BaseSettings):
|
||||
# Cookie Configuration
|
||||
COOKIE_SECURE: bool = True
|
||||
COOKIE_SAMESITE: Literal["strict", "lax", "none"] = "lax"
|
||||
COOKIE_DOMAIN: str | None = "localhost" # Cookie domain (None for production)
|
||||
|
||||
# OAuth2 Configuration
|
||||
GOOGLE_CLIENT_ID: str = ""
|
||||
|
||||
@@ -1,22 +1,15 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
|
||||
from alembic.config import Config
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
# Import all models to ensure SQLModel metadata discovery
|
||||
import app.models # noqa: F401
|
||||
from alembic import command
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.core.seeds import seed_all_data
|
||||
from app.models import ( # noqa: F401
|
||||
extraction,
|
||||
plan,
|
||||
playlist,
|
||||
playlist_sound,
|
||||
sound,
|
||||
sound_played,
|
||||
user,
|
||||
user_oauth,
|
||||
)
|
||||
|
||||
engine: AsyncEngine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
@@ -24,7 +17,7 @@ engine: AsyncEngine = create_async_engine(
|
||||
)
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
async def get_db() -> AsyncGenerator[AsyncSession]:
|
||||
"""Get a database session for dependency injection."""
|
||||
logger = get_logger(__name__)
|
||||
async with AsyncSession(engine) as session:
|
||||
@@ -38,34 +31,33 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
await session.close()
|
||||
|
||||
|
||||
def get_session_factory():
|
||||
def get_session_factory() -> Callable[[], AsyncSession]:
|
||||
"""Get a session factory function for services."""
|
||||
def session_factory():
|
||||
|
||||
def session_factory() -> AsyncSession:
|
||||
return AsyncSession(engine)
|
||||
|
||||
return session_factory
|
||||
|
||||
|
||||
async def init_db() -> None:
|
||||
"""Initialize the database and create tables if they do not exist."""
|
||||
"""Initialize the database using Alembic migrations."""
|
||||
logger = get_logger(__name__)
|
||||
try:
|
||||
logger.info("Initializing database tables")
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
logger.info("Database tables created successfully")
|
||||
logger.info("Running database migrations")
|
||||
# Run Alembic migrations programmatically
|
||||
|
||||
# Seed initial data
|
||||
await seed_initial_data()
|
||||
# Get the alembic config
|
||||
alembic_cfg = Config("alembic.ini")
|
||||
|
||||
# Run migrations to the latest revision in a thread pool to avoid blocking
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, command.upgrade, alembic_cfg, "head",
|
||||
)
|
||||
logger.info("Database migrations completed successfully")
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize database")
|
||||
raise
|
||||
|
||||
|
||||
async def seed_initial_data() -> None:
|
||||
"""Seed initial data into the database."""
|
||||
logger = get_logger(__name__)
|
||||
logger.info("Starting initial data seeding")
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
await seed_all_data(session)
|
||||
|
||||
@@ -8,7 +8,10 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from app.core.database import get_db
|
||||
from app.core.logging import get_logger
|
||||
from app.models.user import User
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.repositories.user import UserRepository
|
||||
from app.services.auth import AuthService
|
||||
from app.services.dashboard import DashboardService
|
||||
from app.services.oauth import OAuthService
|
||||
from app.utils.auth import JWTUtils, TokenUtils
|
||||
|
||||
@@ -135,8 +138,6 @@ async def get_current_user_api_token(
|
||||
detail="Account is deactivated",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions without wrapping them
|
||||
raise
|
||||
@@ -146,6 +147,8 @@ async def get_current_user_api_token(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate API token",
|
||||
) from e
|
||||
else:
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_user_flexible(
|
||||
@@ -184,3 +187,12 @@ async def get_admin_user(
|
||||
detail="Not enough permissions",
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
async def get_dashboard_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> DashboardService:
|
||||
"""Get the dashboard service."""
|
||||
sound_repository = SoundRepository(session)
|
||||
user_repository = UserRepository(session)
|
||||
return DashboardService(sound_repository, user_repository)
|
||||
|
||||
23
app/core/services.py
Normal file
23
app/core/services.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Global services container to avoid circular imports."""
|
||||
|
||||
from app.services.scheduler import SchedulerService
|
||||
|
||||
|
||||
class AppServices:
|
||||
"""Container for application services."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the application services container."""
|
||||
self.scheduler_service: SchedulerService | None = None
|
||||
|
||||
|
||||
# Global service container
|
||||
app_services = AppServices()
|
||||
|
||||
|
||||
def get_global_scheduler_service() -> SchedulerService:
|
||||
"""Get the global scheduler service instance."""
|
||||
if app_services.scheduler_service is None:
|
||||
msg = "Scheduler service not initialized"
|
||||
raise RuntimeError(msg)
|
||||
return app_services.scheduler_service
|
||||
58
app/main.py
58
app/main.py
@@ -6,53 +6,93 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api import api_router
|
||||
from app.core.database import get_session_factory, init_db
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_session_factory
|
||||
from app.core.logging import get_logger, setup_logging
|
||||
from app.core.services import app_services
|
||||
from app.middleware.logging import LoggingMiddleware
|
||||
from app.services.extraction_processor import extraction_processor
|
||||
from app.services.player import initialize_player_service, shutdown_player_service
|
||||
from app.services.player import (
|
||||
get_player_service,
|
||||
initialize_player_service,
|
||||
shutdown_player_service,
|
||||
)
|
||||
from app.services.scheduler import SchedulerService
|
||||
from app.services.socket import socket_manager
|
||||
from app.services.tts_processor import tts_processor
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
|
||||
"""Application lifespan context manager for setup and teardown."""
|
||||
setup_logging()
|
||||
logger = get_logger(__name__)
|
||||
logger.info("Starting application")
|
||||
|
||||
await init_db()
|
||||
logger.info("Database initialized")
|
||||
|
||||
# Start the extraction processor
|
||||
await extraction_processor.start()
|
||||
logger.info("Extraction processor started")
|
||||
|
||||
# Start the TTS processor
|
||||
await tts_processor.start()
|
||||
logger.info("TTS processor started")
|
||||
|
||||
# Start the player service
|
||||
await initialize_player_service(get_session_factory())
|
||||
logger.info("Player service started")
|
||||
|
||||
# Start the scheduler service
|
||||
try:
|
||||
player_service = get_player_service() # Get the initialized player service
|
||||
app_services.scheduler_service = SchedulerService(
|
||||
get_session_factory(),
|
||||
player_service,
|
||||
)
|
||||
await app_services.scheduler_service.start()
|
||||
logger.info("Enhanced scheduler service started")
|
||||
except Exception:
|
||||
logger.exception("Failed to start scheduler service - continuing without it")
|
||||
app_services.scheduler_service = None
|
||||
|
||||
yield
|
||||
|
||||
logger.info("Shutting down application")
|
||||
|
||||
# Stop the scheduler service
|
||||
if app_services.scheduler_service:
|
||||
await app_services.scheduler_service.stop()
|
||||
logger.info("Scheduler service stopped")
|
||||
|
||||
# Stop the player service
|
||||
await shutdown_player_service()
|
||||
logger.info("Player service stopped")
|
||||
|
||||
# Stop the TTS processor
|
||||
await tts_processor.stop()
|
||||
logger.info("TTS processor stopped")
|
||||
|
||||
# Stop the extraction processor
|
||||
await extraction_processor.stop()
|
||||
logger.info("Extraction processor stopped")
|
||||
|
||||
|
||||
def create_app():
|
||||
def create_app() -> FastAPI:
|
||||
"""Create and configure the FastAPI application."""
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app = FastAPI(
|
||||
title="SBD v2 API",
|
||||
description=("API for the SBD v2 application"),
|
||||
version="2.0.0",
|
||||
lifespan=lifespan,
|
||||
# Configure docs URLs for reverse proxy setup
|
||||
docs_url="/api/docs", # Swagger UI at /api/docs
|
||||
redoc_url="/api/redoc", # ReDoc at /api/redoc
|
||||
openapi_url="/api/openapi.json", # OpenAPI schema at /api/openapi.json
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:8001"],
|
||||
allow_origins=settings.CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
|
||||
@@ -1 +1,34 @@
|
||||
"""Models package."""
|
||||
|
||||
# Import all models for SQLAlchemy metadata discovery
|
||||
from .base import BaseModel
|
||||
from .credit_action import CreditAction
|
||||
from .credit_transaction import CreditTransaction
|
||||
from .extraction import Extraction
|
||||
from .favorite import Favorite
|
||||
from .plan import Plan
|
||||
from .playlist import Playlist
|
||||
from .playlist_sound import PlaylistSound
|
||||
from .scheduled_task import ScheduledTask
|
||||
from .sound import Sound
|
||||
from .sound_played import SoundPlayed
|
||||
from .tts import TTS
|
||||
from .user import User
|
||||
from .user_oauth import UserOauth
|
||||
|
||||
__all__ = [
|
||||
"TTS",
|
||||
"BaseModel",
|
||||
"CreditAction",
|
||||
"CreditTransaction",
|
||||
"Extraction",
|
||||
"Favorite",
|
||||
"Plan",
|
||||
"Playlist",
|
||||
"PlaylistSound",
|
||||
"ScheduledTask",
|
||||
"Sound",
|
||||
"SoundPlayed",
|
||||
"User",
|
||||
"UserOauth",
|
||||
]
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.orm import Mapper
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
@@ -11,3 +15,14 @@ class BaseModel(SQLModel):
|
||||
# timestamps
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
|
||||
# SQLAlchemy event listener to automatically update updated_at timestamp
|
||||
@event.listens_for(BaseModel, "before_update", propagate=True)
|
||||
def update_timestamp(
|
||||
mapper: Mapper[Any], # noqa: ARG001
|
||||
connection: Connection, # noqa: ARG001
|
||||
target: BaseModel,
|
||||
) -> None:
|
||||
"""Automatically set updated_at timestamp before update operations."""
|
||||
target.updated_at = datetime.now(UTC)
|
||||
|
||||
@@ -13,6 +13,7 @@ class CreditActionType(str, Enum):
|
||||
SOUND_NORMALIZATION = "sound_normalization"
|
||||
API_REQUEST = "api_request"
|
||||
PLAYLIST_CREATION = "playlist_creation"
|
||||
DAILY_RECHARGE = "daily_recharge"
|
||||
|
||||
|
||||
class CreditAction:
|
||||
@@ -92,6 +93,12 @@ CREDIT_ACTIONS = {
|
||||
description="Create a new playlist",
|
||||
requires_success=True,
|
||||
),
|
||||
CreditActionType.DAILY_RECHARGE: CreditAction(
|
||||
action_type=CreditActionType.DAILY_RECHARGE,
|
||||
cost=0, # This is a credit addition, not deduction
|
||||
description="Daily credit recharge",
|
||||
requires_success=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -118,4 +125,4 @@ def get_all_credit_actions() -> dict[CreditActionType, CreditAction]:
|
||||
Dictionary of all credit actions
|
||||
|
||||
"""
|
||||
return CREDIT_ACTIONS.copy()
|
||||
return CREDIT_ACTIONS.copy()
|
||||
|
||||
@@ -17,7 +17,8 @@ class CreditTransaction(BaseModel, table=True):
|
||||
|
||||
user_id: int = Field(foreign_key="user.id", nullable=False)
|
||||
action_type: str = Field(nullable=False)
|
||||
amount: int = Field(nullable=False) # Negative for deductions, positive for additions
|
||||
# Negative for deductions, positive for additions
|
||||
amount: int = Field(nullable=False)
|
||||
balance_before: int = Field(nullable=False)
|
||||
balance_after: int = Field(nullable=False)
|
||||
description: str = Field(nullable=False)
|
||||
@@ -26,4 +27,4 @@ class CreditTransaction(BaseModel, table=True):
|
||||
metadata_json: str | None = Field(default=None)
|
||||
|
||||
# relationships
|
||||
user: "User" = Relationship(back_populates="credit_transactions")
|
||||
user: "User" = Relationship(back_populates="credit_transactions")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
@@ -25,7 +25,8 @@ class Extraction(BaseModel, table=True):
|
||||
status: str = Field(nullable=False, default="pending")
|
||||
error: str | None = Field(default=None)
|
||||
|
||||
# constraints - only enforce uniqueness when both service and service_id are not null
|
||||
# constraints - only enforce uniqueness when both service and service_id
|
||||
# are not null
|
||||
__table_args__ = ()
|
||||
|
||||
# relationships
|
||||
|
||||
29
app/models/favorite.py
Normal file
29
app/models/favorite.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.playlist import Playlist
|
||||
from app.models.sound import Sound
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class Favorite(BaseModel, table=True):
|
||||
"""Database model for user favorites (sounds and playlists)."""
|
||||
|
||||
user_id: int = Field(foreign_key="user.id", nullable=False)
|
||||
sound_id: int | None = Field(foreign_key="sound.id", default=None)
|
||||
playlist_id: int | None = Field(foreign_key="playlist.id", default=None)
|
||||
|
||||
# constraints
|
||||
__table_args__ = (
|
||||
UniqueConstraint("user_id", "sound_id", name="uq_favorite_user_sound"),
|
||||
UniqueConstraint("user_id", "playlist_id", name="uq_favorite_user_playlist"),
|
||||
)
|
||||
|
||||
# relationships
|
||||
user: "User" = Relationship(back_populates="favorites")
|
||||
sound: "Sound" = Relationship(back_populates="favorites")
|
||||
playlist: "Playlist" = Relationship(back_populates="favorites")
|
||||
@@ -5,6 +5,7 @@ from sqlmodel import Field, Relationship
|
||||
from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.favorite import Favorite
|
||||
from app.models.playlist_sound import PlaylistSound
|
||||
from app.models.user import User
|
||||
|
||||
@@ -23,3 +24,4 @@ class Playlist(BaseModel, table=True):
|
||||
# relationships
|
||||
user: "User" = Relationship(back_populates="playlists")
|
||||
playlist_sounds: list["PlaylistSound"] = Relationship(back_populates="playlist")
|
||||
favorites: list["Favorite"] = Relationship(back_populates="playlist")
|
||||
|
||||
125
app/models/scheduled_task.py
Normal file
125
app/models/scheduled_task.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Scheduled task model for flexible task scheduling with timezone support."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from sqlmodel import JSON, Column, Field
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class TaskType(str, Enum):
|
||||
"""Available task types."""
|
||||
|
||||
CREDIT_RECHARGE = "credit_recharge"
|
||||
PLAY_SOUND = "play_sound"
|
||||
PLAY_PLAYLIST = "play_playlist"
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""Task execution status."""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class RecurrenceType(str, Enum):
|
||||
"""Recurrence patterns."""
|
||||
|
||||
NONE = "none" # One-shot task
|
||||
MINUTELY = "minutely"
|
||||
HOURLY = "hourly"
|
||||
DAILY = "daily"
|
||||
WEEKLY = "weekly"
|
||||
MONTHLY = "monthly"
|
||||
YEARLY = "yearly"
|
||||
CRON = "cron" # Custom cron expression
|
||||
|
||||
|
||||
class ScheduledTask(BaseModel, table=True):
|
||||
"""Model for scheduled tasks with timezone support."""
|
||||
|
||||
__tablename__ = "scheduled_task"
|
||||
|
||||
id: int | None = Field(primary_key=True, default=None)
|
||||
name: str = Field(max_length=255, description="Human-readable task name")
|
||||
task_type: TaskType = Field(description="Type of task to execute")
|
||||
status: TaskStatus = Field(default=TaskStatus.PENDING)
|
||||
|
||||
# Scheduling fields with timezone support
|
||||
scheduled_at: datetime = Field(description="When the task should be executed (UTC)")
|
||||
timezone: str = Field(
|
||||
default="UTC",
|
||||
description="Timezone for scheduling (e.g., 'America/New_York')",
|
||||
)
|
||||
recurrence_type: RecurrenceType = Field(default=RecurrenceType.NONE)
|
||||
cron_expression: str | None = Field(
|
||||
default=None,
|
||||
description="Cron expression for custom recurrence",
|
||||
)
|
||||
recurrence_count: int | None = Field(
|
||||
default=None,
|
||||
description="Number of times to repeat (None for infinite)",
|
||||
)
|
||||
executions_count: int = Field(default=0, description="Number of times executed")
|
||||
|
||||
# Task parameters
|
||||
parameters: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
sa_column=Column(JSON),
|
||||
description="Task-specific parameters",
|
||||
)
|
||||
|
||||
# User association (None for system tasks)
|
||||
user_id: int | None = Field(
|
||||
default=None,
|
||||
foreign_key="user.id",
|
||||
description="User who created the task (None for system tasks)",
|
||||
)
|
||||
|
||||
# Execution tracking
|
||||
last_executed_at: datetime | None = Field(
|
||||
default=None,
|
||||
description="When the task was last executed (UTC)",
|
||||
)
|
||||
next_execution_at: datetime | None = Field(
|
||||
default=None,
|
||||
description="When the task should be executed next (UTC, for recurring tasks)",
|
||||
)
|
||||
error_message: str | None = Field(
|
||||
default=None,
|
||||
description="Error message if execution failed",
|
||||
)
|
||||
|
||||
# Task lifecycle
|
||||
is_active: bool = Field(default=True, description="Whether the task is active")
|
||||
expires_at: datetime | None = Field(
|
||||
default=None,
|
||||
description="When the task expires (UTC, optional)",
|
||||
)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the task has expired."""
|
||||
if self.expires_at is None:
|
||||
return False
|
||||
return datetime.now(tz=UTC).replace(tzinfo=None) > self.expires_at
|
||||
|
||||
def is_recurring(self) -> bool:
|
||||
"""Check if the task is recurring."""
|
||||
return self.recurrence_type != RecurrenceType.NONE
|
||||
|
||||
def should_repeat(self) -> bool:
|
||||
"""Check if the task should be repeated."""
|
||||
if not self.is_recurring():
|
||||
return False
|
||||
if self.recurrence_count is None:
|
||||
return True
|
||||
return self.executions_count < self.recurrence_count
|
||||
|
||||
def is_system_task(self) -> bool:
|
||||
"""Check if this is a system task (no user association)."""
|
||||
return self.user_id is None
|
||||
@@ -6,6 +6,7 @@ from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.extraction import Extraction
|
||||
from app.models.favorite import Favorite
|
||||
from app.models.playlist_sound import PlaylistSound
|
||||
from app.models.sound_played import SoundPlayed
|
||||
|
||||
@@ -30,11 +31,14 @@ class Sound(BaseModel, table=True):
|
||||
is_deletable: bool = Field(default=True, nullable=False)
|
||||
|
||||
# constraints
|
||||
__table_args__ = (
|
||||
UniqueConstraint("hash", name="uq_sound_hash"),
|
||||
)
|
||||
__table_args__ = (UniqueConstraint("hash", name="uq_sound_hash"),)
|
||||
|
||||
# relationships
|
||||
playlist_sounds: list["PlaylistSound"] = Relationship(back_populates="sound")
|
||||
playlist_sounds: list["PlaylistSound"] = Relationship(
|
||||
back_populates="sound", cascade_delete=True,
|
||||
)
|
||||
extractions: list["Extraction"] = Relationship(back_populates="sound")
|
||||
play_history: list["SoundPlayed"] = Relationship(back_populates="sound")
|
||||
play_history: list["SoundPlayed"] = Relationship(
|
||||
back_populates="sound", cascade_delete=True,
|
||||
)
|
||||
favorites: list["Favorite"] = Relationship(back_populates="sound")
|
||||
|
||||
30
app/models/tts.py
Normal file
30
app/models/tts.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""TTS model."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import JSON, Column
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class TTS(SQLModel, table=True):
|
||||
"""Text-to-Speech generation record."""
|
||||
|
||||
__tablename__ = "tts"
|
||||
|
||||
id: int | None = Field(primary_key=True)
|
||||
text: str = Field(max_length=1000, description="Text that was converted to speech")
|
||||
provider: str = Field(max_length=50, description="TTS provider used")
|
||||
options: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
sa_column=Column(JSON),
|
||||
description="Provider-specific options used",
|
||||
)
|
||||
status: str = Field(default="pending", description="Processing status")
|
||||
error: str | None = Field(default=None, description="Error message if failed")
|
||||
sound_id: int | None = Field(
|
||||
foreign_key="sound.id", description="Associated sound ID",
|
||||
)
|
||||
user_id: int = Field(foreign_key="user.id", description="User who created the TTS")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
@@ -8,6 +8,7 @@ from app.models.base import BaseModel
|
||||
if TYPE_CHECKING:
|
||||
from app.models.credit_transaction import CreditTransaction
|
||||
from app.models.extraction import Extraction
|
||||
from app.models.favorite import Favorite
|
||||
from app.models.plan import Plan
|
||||
from app.models.playlist import Playlist
|
||||
from app.models.sound_played import SoundPlayed
|
||||
@@ -37,3 +38,4 @@ class User(BaseModel, table=True):
|
||||
sounds_played: list["SoundPlayed"] = Relationship(back_populates="user")
|
||||
extractions: list["Extraction"] = Relationship(back_populates="user")
|
||||
credit_transactions: list["CreditTransaction"] = Relationship(back_populates="user")
|
||||
favorites: list["Favorite"] = Relationship(back_populates="user")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Base repository with common CRUD operations."""
|
||||
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -13,7 +13,7 @@ ModelType = TypeVar("ModelType")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BaseRepository(Generic[ModelType]):
|
||||
class BaseRepository[ModelType]:
|
||||
"""Base repository with common CRUD operations."""
|
||||
|
||||
def __init__(self, model: type[ModelType], session: AsyncSession) -> None:
|
||||
@@ -38,11 +38,15 @@ class BaseRepository(Generic[ModelType]):
|
||||
|
||||
"""
|
||||
try:
|
||||
statement = select(self.model).where(getattr(self.model, "id") == entity_id)
|
||||
statement = select(self.model).where(self.model.id == entity_id)
|
||||
result = await self.session.exec(statement)
|
||||
return result.first()
|
||||
except Exception:
|
||||
logger.exception("Failed to get %s by ID: %s", self.model.__name__, entity_id)
|
||||
logger.exception(
|
||||
"Failed to get %s by ID: %s",
|
||||
self.model.__name__,
|
||||
entity_id,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_all(
|
||||
@@ -68,27 +72,36 @@ class BaseRepository(Generic[ModelType]):
|
||||
logger.exception("Failed to get all %s", self.model.__name__)
|
||||
raise
|
||||
|
||||
async def create(self, entity_data: dict[str, Any]) -> ModelType:
|
||||
async def create(self, entity_data: dict[str, Any] | ModelType) -> ModelType:
|
||||
"""Create a new entity.
|
||||
|
||||
Args:
|
||||
entity_data: Dictionary of entity data
|
||||
entity_data: Dictionary of entity data or model instance
|
||||
|
||||
Returns:
|
||||
The created entity
|
||||
|
||||
"""
|
||||
try:
|
||||
entity = self.model(**entity_data)
|
||||
if isinstance(entity_data, dict):
|
||||
entity = self.model(**entity_data)
|
||||
else:
|
||||
# Already a model instance
|
||||
entity = entity_data
|
||||
self.session.add(entity)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(entity)
|
||||
logger.info("Created new %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
|
||||
return entity
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to create %s", self.model.__name__)
|
||||
raise
|
||||
else:
|
||||
logger.info(
|
||||
"Created new %s with ID: %s",
|
||||
self.model.__name__,
|
||||
getattr(entity, "id", "unknown"),
|
||||
)
|
||||
return entity
|
||||
|
||||
async def update(self, entity: ModelType, update_data: dict[str, Any]) -> ModelType:
|
||||
"""Update an entity.
|
||||
@@ -105,15 +118,22 @@ class BaseRepository(Generic[ModelType]):
|
||||
for field, value in update_data.items():
|
||||
setattr(entity, field, value)
|
||||
|
||||
# The updated_at timestamp will be automatically set by the SQLAlchemy
|
||||
# event listener
|
||||
self.session.add(entity)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(entity)
|
||||
logger.info("Updated %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
|
||||
return entity
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to update %s", self.model.__name__)
|
||||
raise
|
||||
else:
|
||||
logger.info(
|
||||
"Updated %s with ID: %s",
|
||||
self.model.__name__,
|
||||
getattr(entity, "id", "unknown"),
|
||||
)
|
||||
return entity
|
||||
|
||||
async def delete(self, entity: ModelType) -> None:
|
||||
"""Delete an entity.
|
||||
@@ -125,8 +145,12 @@ class BaseRepository(Generic[ModelType]):
|
||||
try:
|
||||
await self.session.delete(entity)
|
||||
await self.session.commit()
|
||||
logger.info("Deleted %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
|
||||
logger.info(
|
||||
"Deleted %s with ID: %s",
|
||||
self.model.__name__,
|
||||
getattr(entity, "id", "unknown"),
|
||||
)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to delete %s", self.model.__name__)
|
||||
raise
|
||||
raise
|
||||
|
||||
@@ -91,18 +91,17 @@ class CreditTransactionRepository(BaseRepository[CreditTransaction]):
|
||||
|
||||
"""
|
||||
stmt = (
|
||||
select(CreditTransaction)
|
||||
.where(CreditTransaction.success == True) # noqa: E712
|
||||
select(CreditTransaction).where(CreditTransaction.success == True) # noqa: E712
|
||||
)
|
||||
|
||||
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(CreditTransaction.user_id == user_id)
|
||||
|
||||
|
||||
stmt = (
|
||||
stmt.order_by(CreditTransaction.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
|
||||
|
||||
result = await self.session.exec(stmt)
|
||||
return list(result.all())
|
||||
return list(result.all())
|
||||
|
||||
@@ -1,42 +1,32 @@
|
||||
"""Extraction repository for database operations."""
|
||||
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import asc, desc, func, or_
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.extraction import Extraction
|
||||
from app.models.user import User
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class ExtractionRepository:
|
||||
class ExtractionRepository(BaseRepository[Extraction]):
|
||||
"""Repository for extraction database operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize the extraction repository."""
|
||||
self.session = session
|
||||
|
||||
async def create(self, extraction_data: dict) -> Extraction:
|
||||
"""Create a new extraction."""
|
||||
extraction = Extraction(**extraction_data)
|
||||
self.session.add(extraction)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(extraction)
|
||||
return extraction
|
||||
|
||||
async def get_by_id(self, extraction_id: int) -> Extraction | None:
|
||||
"""Get an extraction by ID."""
|
||||
result = await self.session.exec(
|
||||
select(Extraction).where(Extraction.id == extraction_id)
|
||||
)
|
||||
return result.first()
|
||||
super().__init__(Extraction, session)
|
||||
|
||||
async def get_by_service_and_id(
|
||||
self, service: str, service_id: str
|
||||
self,
|
||||
service: str,
|
||||
service_id: str,
|
||||
) -> Extraction | None:
|
||||
"""Get an extraction by service and service_id."""
|
||||
result = await self.session.exec(
|
||||
select(Extraction).where(
|
||||
Extraction.service == service, Extraction.service_id == service_id
|
||||
)
|
||||
Extraction.service == service,
|
||||
Extraction.service_id == service_id,
|
||||
),
|
||||
)
|
||||
return result.first()
|
||||
|
||||
@@ -45,38 +35,131 @@ class ExtractionRepository:
|
||||
result = await self.session.exec(
|
||||
select(Extraction)
|
||||
.where(Extraction.user_id == user_id)
|
||||
.order_by(desc(Extraction.created_at))
|
||||
.order_by(desc(Extraction.created_at)),
|
||||
)
|
||||
return list(result.all())
|
||||
|
||||
async def get_pending_extractions(self) -> list[Extraction]:
|
||||
"""Get all pending extractions."""
|
||||
async def get_by_status(self, status: str) -> list[Extraction]:
|
||||
"""Get all extractions by status."""
|
||||
result = await self.session.exec(
|
||||
select(Extraction)
|
||||
.where(Extraction.status == "pending")
|
||||
.order_by(Extraction.created_at)
|
||||
.where(Extraction.status == status)
|
||||
.order_by(Extraction.created_at),
|
||||
)
|
||||
return list(result.all())
|
||||
|
||||
async def update(self, extraction: Extraction, update_data: dict) -> Extraction:
|
||||
"""Update an extraction."""
|
||||
for key, value in update_data.items():
|
||||
setattr(extraction, key, value)
|
||||
|
||||
await self.session.commit()
|
||||
await self.session.refresh(extraction)
|
||||
return extraction
|
||||
|
||||
async def delete(self, extraction: Extraction) -> None:
|
||||
"""Delete an extraction."""
|
||||
await self.session.delete(extraction)
|
||||
await self.session.commit()
|
||||
async def get_pending_extractions(self) -> list[tuple[Extraction, User]]:
|
||||
"""Get all pending extractions."""
|
||||
result = await self.session.exec(
|
||||
select(Extraction, User)
|
||||
.join(User, Extraction.user_id == User.id)
|
||||
.where(Extraction.status == "pending")
|
||||
.order_by(Extraction.created_at),
|
||||
)
|
||||
return list(result.all())
|
||||
|
||||
async def get_extractions_by_status(self, status: str) -> list[Extraction]:
|
||||
"""Get extractions by status."""
|
||||
result = await self.session.exec(
|
||||
select(Extraction)
|
||||
.where(Extraction.status == status)
|
||||
.order_by(desc(Extraction.created_at))
|
||||
.order_by(desc(Extraction.created_at)),
|
||||
)
|
||||
return list(result.all())
|
||||
|
||||
async def get_user_extractions_filtered( # noqa: PLR0913
|
||||
self,
|
||||
user_id: int,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
status_filter: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[tuple[Extraction, User]], int]:
|
||||
"""Get extractions for a user with filtering, search, and sorting."""
|
||||
base_query = (
|
||||
select(Extraction, User)
|
||||
.join(User, Extraction.user_id == User.id)
|
||||
.where(Extraction.user_id == user_id)
|
||||
)
|
||||
|
||||
# Apply search filter
|
||||
if search:
|
||||
search_pattern = f"%{search}%"
|
||||
base_query = base_query.where(
|
||||
or_(
|
||||
Extraction.title.ilike(search_pattern),
|
||||
Extraction.url.ilike(search_pattern),
|
||||
Extraction.service.ilike(search_pattern),
|
||||
),
|
||||
)
|
||||
|
||||
# Apply status filter
|
||||
if status_filter:
|
||||
base_query = base_query.where(Extraction.status == status_filter)
|
||||
|
||||
# Get total count before pagination
|
||||
count_query = select(func.count()).select_from(
|
||||
base_query.subquery(),
|
||||
)
|
||||
count_result = await self.session.exec(count_query)
|
||||
total_count = count_result.one()
|
||||
|
||||
# Apply sorting and pagination
|
||||
sort_column = getattr(Extraction, sort_by, Extraction.created_at)
|
||||
if sort_order.lower() == "asc":
|
||||
base_query = base_query.order_by(asc(sort_column))
|
||||
else:
|
||||
base_query = base_query.order_by(desc(sort_column))
|
||||
|
||||
paginated_query = base_query.limit(limit).offset(offset)
|
||||
result = await self.session.exec(paginated_query)
|
||||
|
||||
return list(result.all()), total_count
|
||||
|
||||
async def get_all_extractions_filtered( # noqa: PLR0913
|
||||
self,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
status_filter: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[tuple[Extraction, User]], int]:
|
||||
"""Get all extractions with filtering, search, and sorting."""
|
||||
base_query = select(Extraction, User).join(User, Extraction.user_id == User.id)
|
||||
|
||||
# Apply search filter
|
||||
if search:
|
||||
search_pattern = f"%{search}%"
|
||||
base_query = base_query.where(
|
||||
or_(
|
||||
Extraction.title.ilike(search_pattern),
|
||||
Extraction.url.ilike(search_pattern),
|
||||
Extraction.service.ilike(search_pattern),
|
||||
),
|
||||
)
|
||||
|
||||
# Apply status filter
|
||||
if status_filter:
|
||||
base_query = base_query.where(Extraction.status == status_filter)
|
||||
|
||||
# Get total count before pagination
|
||||
count_query = select(func.count()).select_from(
|
||||
base_query.subquery(),
|
||||
)
|
||||
count_result = await self.session.exec(count_query)
|
||||
total_count = count_result.one()
|
||||
|
||||
# Apply sorting and pagination
|
||||
sort_column = getattr(Extraction, sort_by, Extraction.created_at)
|
||||
if sort_order.lower() == "asc":
|
||||
base_query = base_query.order_by(asc(sort_column))
|
||||
else:
|
||||
base_query = base_query.order_by(desc(sort_column))
|
||||
|
||||
paginated_query = base_query.limit(limit).offset(offset)
|
||||
result = await self.session.exec(paginated_query)
|
||||
|
||||
return list(result.all()), total_count
|
||||
|
||||
258
app/repositories/favorite.py
Normal file
258
app/repositories/favorite.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""Repository for managing favorites."""
|
||||
|
||||
from sqlmodel import and_, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.favorite import Favorite
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class FavoriteRepository(BaseRepository[Favorite]):
|
||||
"""Repository for managing favorites."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize the favorite repository.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
|
||||
"""
|
||||
super().__init__(Favorite, session)
|
||||
|
||||
async def get_user_favorites(
|
||||
self,
|
||||
user_id: int,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[Favorite]:
|
||||
"""Get all favorites for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
limit: Maximum number of favorites to return
|
||||
offset: Number of favorites to skip
|
||||
|
||||
Returns:
|
||||
List of user favorites
|
||||
|
||||
"""
|
||||
try:
|
||||
statement = (
|
||||
select(Favorite)
|
||||
.where(Favorite.user_id == user_id)
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
.order_by(Favorite.created_at.desc())
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
except Exception:
|
||||
logger.exception("Failed to get favorites for user: %s", user_id)
|
||||
raise
|
||||
|
||||
async def get_user_sound_favorites(
|
||||
self,
|
||||
user_id: int,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[Favorite]:
|
||||
"""Get sound favorites for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
limit: Maximum number of favorites to return
|
||||
offset: Number of favorites to skip
|
||||
|
||||
Returns:
|
||||
List of user sound favorites
|
||||
|
||||
"""
|
||||
try:
|
||||
statement = (
|
||||
select(Favorite)
|
||||
.where(and_(Favorite.user_id == user_id, Favorite.sound_id.isnot(None)))
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
.order_by(Favorite.created_at.desc())
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
except Exception:
|
||||
logger.exception("Failed to get sound favorites for user: %s", user_id)
|
||||
raise
|
||||
|
||||
async def get_user_playlist_favorites(
|
||||
self,
|
||||
user_id: int,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[Favorite]:
|
||||
"""Get playlist favorites for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
limit: Maximum number of favorites to return
|
||||
offset: Number of favorites to skip
|
||||
|
||||
Returns:
|
||||
List of user playlist favorites
|
||||
|
||||
"""
|
||||
try:
|
||||
statement = (
|
||||
select(Favorite)
|
||||
.where(
|
||||
and_(Favorite.user_id == user_id, Favorite.playlist_id.isnot(None)),
|
||||
)
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
.order_by(Favorite.created_at.desc())
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
except Exception:
|
||||
logger.exception("Failed to get playlist favorites for user: %s", user_id)
|
||||
raise
|
||||
|
||||
async def get_by_user_and_sound(
|
||||
self,
|
||||
user_id: int,
|
||||
sound_id: int,
|
||||
) -> Favorite | None:
|
||||
"""Get a favorite by user and sound.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
sound_id: The sound ID
|
||||
|
||||
Returns:
|
||||
The favorite if found, None otherwise
|
||||
|
||||
"""
|
||||
try:
|
||||
statement = select(Favorite).where(
|
||||
and_(Favorite.user_id == user_id, Favorite.sound_id == sound_id),
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return result.first()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to get favorite for user %s and sound %s",
|
||||
user_id,
|
||||
sound_id,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_user_and_playlist(
|
||||
self,
|
||||
user_id: int,
|
||||
playlist_id: int,
|
||||
) -> Favorite | None:
|
||||
"""Get a favorite by user and playlist.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
playlist_id: The playlist ID
|
||||
|
||||
Returns:
|
||||
The favorite if found, None otherwise
|
||||
|
||||
"""
|
||||
try:
|
||||
statement = select(Favorite).where(
|
||||
and_(Favorite.user_id == user_id, Favorite.playlist_id == playlist_id),
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return result.first()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to get favorite for user %s and playlist %s",
|
||||
user_id,
|
||||
playlist_id,
|
||||
)
|
||||
raise
|
||||
|
||||
async def is_sound_favorited(self, user_id: int, sound_id: int) -> bool:
|
||||
"""Check if a sound is favorited by a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
sound_id: The sound ID
|
||||
|
||||
Returns:
|
||||
True if the sound is favorited, False otherwise
|
||||
|
||||
"""
|
||||
favorite = await self.get_by_user_and_sound(user_id, sound_id)
|
||||
return favorite is not None
|
||||
|
||||
async def is_playlist_favorited(self, user_id: int, playlist_id: int) -> bool:
|
||||
"""Check if a playlist is favorited by a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
playlist_id: The playlist ID
|
||||
|
||||
Returns:
|
||||
True if the playlist is favorited, False otherwise
|
||||
|
||||
"""
|
||||
favorite = await self.get_by_user_and_playlist(user_id, playlist_id)
|
||||
return favorite is not None
|
||||
|
||||
async def count_user_favorites(self, user_id: int) -> int:
|
||||
"""Count total favorites for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
|
||||
Returns:
|
||||
Total number of favorites
|
||||
|
||||
"""
|
||||
try:
|
||||
statement = select(Favorite).where(Favorite.user_id == user_id)
|
||||
result = await self.session.exec(statement)
|
||||
return len(list(result.all()))
|
||||
except Exception:
|
||||
logger.exception("Failed to count favorites for user: %s", user_id)
|
||||
raise
|
||||
|
||||
async def count_sound_favorites(self, sound_id: int) -> int:
|
||||
"""Count how many users have favorited a sound.
|
||||
|
||||
Args:
|
||||
sound_id: The sound ID
|
||||
|
||||
Returns:
|
||||
Number of users who favorited this sound
|
||||
|
||||
"""
|
||||
try:
|
||||
statement = select(Favorite).where(Favorite.sound_id == sound_id)
|
||||
result = await self.session.exec(statement)
|
||||
return len(list(result.all()))
|
||||
except Exception:
|
||||
logger.exception("Failed to count favorites for sound: %s", sound_id)
|
||||
raise
|
||||
|
||||
async def count_playlist_favorites(self, playlist_id: int) -> int:
|
||||
"""Count how many users have favorited a playlist.
|
||||
|
||||
Args:
|
||||
playlist_id: The playlist ID
|
||||
|
||||
Returns:
|
||||
Number of users who favorited this playlist
|
||||
|
||||
"""
|
||||
try:
|
||||
statement = select(Favorite).where(Favorite.playlist_id == playlist_id)
|
||||
result = await self.session.exec(statement)
|
||||
return len(list(result.all()))
|
||||
except Exception:
|
||||
logger.exception("Failed to count favorites for playlist: %s", playlist_id)
|
||||
raise
|
||||
17
app/repositories/plan.py
Normal file
17
app/repositories/plan.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Plan repository."""
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.plan import Plan
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PlanRepository(BaseRepository[Plan]):
|
||||
"""Repository for plan operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize the plan repository."""
|
||||
super().__init__(Plan, session)
|
||||
@@ -1,34 +1,65 @@
|
||||
"""Playlist repository for database operations."""
|
||||
|
||||
from typing import Any
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, update
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.favorite import Favorite
|
||||
from app.models.playlist import Playlist
|
||||
from app.models.playlist_sound import PlaylistSound
|
||||
from app.models.sound import Sound
|
||||
from app.models.user import User
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PlaylistRepository:
|
||||
class PlaylistSortField(str, Enum):
|
||||
"""Playlist sort field enumeration."""
|
||||
|
||||
NAME = "name"
|
||||
GENRE = "genre"
|
||||
CREATED_AT = "created_at"
|
||||
UPDATED_AT = "updated_at"
|
||||
SOUND_COUNT = "sound_count"
|
||||
TOTAL_DURATION = "total_duration"
|
||||
|
||||
|
||||
class SortOrder(str, Enum):
|
||||
"""Sort order enumeration."""
|
||||
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
|
||||
class PlaylistRepository(BaseRepository[Playlist]):
|
||||
"""Repository for playlist operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize the playlist repository."""
|
||||
self.session = session
|
||||
super().__init__(Playlist, session)
|
||||
|
||||
async def get_by_id(self, playlist_id: int) -> Playlist | None:
|
||||
"""Get a playlist by ID."""
|
||||
async def _update_playlist_timestamp(self, playlist_id: int) -> None:
|
||||
"""Update the playlist's updated_at timestamp."""
|
||||
try:
|
||||
statement = select(Playlist).where(Playlist.id == playlist_id)
|
||||
result = await self.session.exec(statement)
|
||||
return result.first()
|
||||
update_stmt = (
|
||||
update(Playlist)
|
||||
.where(Playlist.id == playlist_id)
|
||||
.values(updated_at=datetime.now(UTC))
|
||||
)
|
||||
await self.session.exec(update_stmt)
|
||||
# Note: No commit here - let the calling method handle transaction
|
||||
# management
|
||||
except Exception:
|
||||
logger.exception("Failed to get playlist by ID: %s", playlist_id)
|
||||
logger.exception(
|
||||
"Failed to update playlist timestamp for playlist: %s",
|
||||
playlist_id,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_name(self, name: str) -> Playlist | None:
|
||||
@@ -51,16 +82,6 @@ class PlaylistRepository:
|
||||
logger.exception("Failed to get playlists for user: %s", user_id)
|
||||
raise
|
||||
|
||||
async def get_all(self) -> list[Playlist]:
|
||||
"""Get all playlists from all users."""
|
||||
try:
|
||||
statement = select(Playlist)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
except Exception:
|
||||
logger.exception("Failed to get all playlists")
|
||||
raise
|
||||
|
||||
async def get_main_playlist(self) -> Playlist | None:
|
||||
"""Get the global main playlist."""
|
||||
try:
|
||||
@@ -73,63 +94,22 @@ class PlaylistRepository:
|
||||
logger.exception("Failed to get main playlist")
|
||||
raise
|
||||
|
||||
async def get_current_playlist(self, user_id: int) -> Playlist | None:
|
||||
"""Get the user's current playlist."""
|
||||
async def get_current_playlist(self) -> Playlist | None:
|
||||
"""Get the global current playlist (app-wide)."""
|
||||
try:
|
||||
statement = select(Playlist).where(
|
||||
Playlist.user_id == user_id,
|
||||
Playlist.is_current == True, # noqa: E712
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return result.first()
|
||||
except Exception:
|
||||
logger.exception("Failed to get current playlist for user: %s", user_id)
|
||||
raise
|
||||
|
||||
async def create(self, playlist_data: dict[str, Any]) -> Playlist:
|
||||
"""Create a new playlist."""
|
||||
try:
|
||||
playlist = Playlist(**playlist_data)
|
||||
self.session.add(playlist)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(playlist)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to create playlist")
|
||||
raise
|
||||
else:
|
||||
logger.info("Created new playlist: %s", playlist.name)
|
||||
return playlist
|
||||
|
||||
async def update(self, playlist: Playlist, update_data: dict[str, Any]) -> Playlist:
|
||||
"""Update a playlist."""
|
||||
try:
|
||||
for field, value in update_data.items():
|
||||
setattr(playlist, field, value)
|
||||
|
||||
await self.session.commit()
|
||||
await self.session.refresh(playlist)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to update playlist")
|
||||
raise
|
||||
else:
|
||||
logger.info("Updated playlist: %s", playlist.name)
|
||||
return playlist
|
||||
|
||||
async def delete(self, playlist: Playlist) -> None:
|
||||
"""Delete a playlist."""
|
||||
try:
|
||||
await self.session.delete(playlist)
|
||||
await self.session.commit()
|
||||
logger.info("Deleted playlist: %s", playlist.name)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to delete playlist")
|
||||
logger.exception("Failed to get current playlist")
|
||||
raise
|
||||
|
||||
async def search_by_name(
|
||||
self, query: str, user_id: int | None = None
|
||||
self,
|
||||
query: str,
|
||||
user_id: int | None = None,
|
||||
) -> list[Playlist]:
|
||||
"""Search playlists by name (case-insensitive)."""
|
||||
try:
|
||||
@@ -146,11 +126,12 @@ class PlaylistRepository:
|
||||
raise
|
||||
|
||||
async def get_playlist_sounds(self, playlist_id: int) -> list[Sound]:
|
||||
"""Get all sounds in a playlist, ordered by position."""
|
||||
"""Get all sounds in a playlist with extractions, ordered by position."""
|
||||
try:
|
||||
statement = (
|
||||
select(Sound)
|
||||
.join(PlaylistSound)
|
||||
.options(selectinload(Sound.extractions))
|
||||
.where(PlaylistSound.playlist_id == playlist_id)
|
||||
.order_by(PlaylistSound.position)
|
||||
)
|
||||
@@ -160,18 +141,67 @@ class PlaylistRepository:
|
||||
logger.exception("Failed to get sounds for playlist: %s", playlist_id)
|
||||
raise
|
||||
|
||||
async def get_playlist_sound_entries(self, playlist_id: int) -> list[PlaylistSound]:
|
||||
"""Get all PlaylistSound entries for a playlist, ordered by position."""
|
||||
try:
|
||||
statement = (
|
||||
select(PlaylistSound)
|
||||
.where(PlaylistSound.playlist_id == playlist_id)
|
||||
.order_by(PlaylistSound.position)
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to get playlist sound entries for playlist: %s",
|
||||
playlist_id,
|
||||
)
|
||||
raise
|
||||
|
||||
async def add_sound_to_playlist(
|
||||
self, playlist_id: int, sound_id: int, position: int | None = None
|
||||
self,
|
||||
playlist_id: int,
|
||||
sound_id: int,
|
||||
position: int | None = None,
|
||||
) -> PlaylistSound:
|
||||
"""Add a sound to a playlist."""
|
||||
try:
|
||||
if position is None:
|
||||
# Get the next available position
|
||||
statement = select(
|
||||
func.coalesce(func.max(PlaylistSound.position), -1) + 1
|
||||
func.coalesce(func.max(PlaylistSound.position), -1) + 1,
|
||||
).where(PlaylistSound.playlist_id == playlist_id)
|
||||
result = await self.session.exec(statement)
|
||||
position = result.first() or 0
|
||||
else:
|
||||
# Shift existing positions to make room for the new sound
|
||||
# Use a two-step approach to avoid unique constraint violations:
|
||||
# 1. Move all affected positions to negative temporary positions
|
||||
# 2. Then move them to their final positions
|
||||
|
||||
# Step 1: Move to temporary negative positions
|
||||
update_to_negative = (
|
||||
update(PlaylistSound)
|
||||
.where(
|
||||
PlaylistSound.playlist_id == playlist_id,
|
||||
PlaylistSound.position >= position,
|
||||
)
|
||||
.values(position=PlaylistSound.position - 10000)
|
||||
)
|
||||
await self.session.exec(update_to_negative)
|
||||
await self.session.commit()
|
||||
|
||||
# Step 2: Move from temporary negative positions to final positions
|
||||
update_to_final = (
|
||||
update(PlaylistSound)
|
||||
.where(
|
||||
PlaylistSound.playlist_id == playlist_id,
|
||||
PlaylistSound.position < 0,
|
||||
)
|
||||
.values(position=PlaylistSound.position + 10001)
|
||||
)
|
||||
await self.session.exec(update_to_final)
|
||||
await self.session.commit()
|
||||
|
||||
playlist_sound = PlaylistSound(
|
||||
playlist_id=playlist_id,
|
||||
@@ -179,12 +209,17 @@ class PlaylistRepository:
|
||||
position=position,
|
||||
)
|
||||
self.session.add(playlist_sound)
|
||||
|
||||
# Update playlist timestamp before commit
|
||||
await self._update_playlist_timestamp(playlist_id)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(playlist_sound)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception(
|
||||
"Failed to add sound %s to playlist %s", sound_id, playlist_id
|
||||
"Failed to add sound %s to playlist %s",
|
||||
sound_id,
|
||||
playlist_id,
|
||||
)
|
||||
raise
|
||||
else:
|
||||
@@ -208,25 +243,47 @@ class PlaylistRepository:
|
||||
|
||||
if playlist_sound:
|
||||
await self.session.delete(playlist_sound)
|
||||
|
||||
# Update playlist timestamp before commit
|
||||
await self._update_playlist_timestamp(playlist_id)
|
||||
await self.session.commit()
|
||||
logger.info("Removed sound %s from playlist %s", sound_id, playlist_id)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception(
|
||||
"Failed to remove sound %s from playlist %s", sound_id, playlist_id
|
||||
"Failed to remove sound %s from playlist %s",
|
||||
sound_id,
|
||||
playlist_id,
|
||||
)
|
||||
raise
|
||||
|
||||
async def reorder_playlist_sounds(
|
||||
self, playlist_id: int, sound_positions: list[tuple[int, int]]
|
||||
self,
|
||||
playlist_id: int,
|
||||
sound_positions: list[tuple[int, int]],
|
||||
) -> None:
|
||||
"""Reorder sounds in a playlist.
|
||||
|
||||
Args:
|
||||
playlist_id: The playlist ID
|
||||
sound_positions: List of (sound_id, new_position) tuples
|
||||
|
||||
"""
|
||||
try:
|
||||
# Phase 1: Set all positions to temporary negative values to avoid conflicts
|
||||
temp_offset = -10000 # Use large negative number to avoid conflicts
|
||||
for i, (sound_id, _) in enumerate(sound_positions):
|
||||
statement = select(PlaylistSound).where(
|
||||
PlaylistSound.playlist_id == playlist_id,
|
||||
PlaylistSound.sound_id == sound_id,
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
playlist_sound = result.first()
|
||||
|
||||
if playlist_sound:
|
||||
playlist_sound.position = temp_offset + i
|
||||
|
||||
# Phase 2: Set the final positions
|
||||
for sound_id, new_position in sound_positions:
|
||||
statement = select(PlaylistSound).where(
|
||||
PlaylistSound.playlist_id == playlist_id,
|
||||
@@ -238,6 +295,8 @@ class PlaylistRepository:
|
||||
if playlist_sound:
|
||||
playlist_sound.position = new_position
|
||||
|
||||
# Update playlist timestamp before commit
|
||||
await self._update_playlist_timestamp(playlist_id)
|
||||
await self.session.commit()
|
||||
logger.info("Reordered sounds in playlist %s", playlist_id)
|
||||
except Exception:
|
||||
@@ -249,7 +308,7 @@ class PlaylistRepository:
|
||||
"""Get the number of sounds in a playlist."""
|
||||
try:
|
||||
statement = select(func.count(PlaylistSound.id)).where(
|
||||
PlaylistSound.playlist_id == playlist_id
|
||||
PlaylistSound.playlist_id == playlist_id,
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return result.first() or 0
|
||||
@@ -268,6 +327,230 @@ class PlaylistRepository:
|
||||
return result.first() is not None
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to check if sound %s is in playlist %s", sound_id, playlist_id
|
||||
"Failed to check if sound %s is in playlist %s",
|
||||
sound_id,
|
||||
playlist_id,
|
||||
)
|
||||
raise
|
||||
|
||||
async def search_and_sort( # noqa: C901, PLR0913, PLR0912, PLR0915
|
||||
self,
|
||||
search_query: str | None = None,
|
||||
sort_by: PlaylistSortField | None = None,
|
||||
sort_order: SortOrder = SortOrder.ASC,
|
||||
user_id: int | None = None,
|
||||
include_stats: bool = False, # noqa: FBT001, FBT002
|
||||
limit: int | None = None,
|
||||
offset: int = 0,
|
||||
favorites_only: bool = False, # noqa: FBT001, FBT002
|
||||
current_user_id: int | None = None,
|
||||
*,
|
||||
return_count: bool = False,
|
||||
) -> list[dict] | tuple[list[dict], int]:
|
||||
"""Search and sort playlists with optional statistics."""
|
||||
try:
|
||||
if include_stats and sort_by in (
|
||||
PlaylistSortField.SOUND_COUNT,
|
||||
PlaylistSortField.TOTAL_DURATION,
|
||||
):
|
||||
# Use subquery for sorting by stats
|
||||
subquery = (
|
||||
select(
|
||||
Playlist.id,
|
||||
Playlist.name,
|
||||
Playlist.description,
|
||||
Playlist.genre,
|
||||
Playlist.user_id,
|
||||
Playlist.is_main,
|
||||
Playlist.is_current,
|
||||
Playlist.is_deletable,
|
||||
Playlist.created_at,
|
||||
Playlist.updated_at,
|
||||
func.count(PlaylistSound.id).label("sound_count"),
|
||||
func.coalesce(func.sum(Sound.duration), 0).label(
|
||||
"total_duration",
|
||||
),
|
||||
User.name.label("user_name"),
|
||||
)
|
||||
.select_from(Playlist)
|
||||
.join(User, Playlist.user_id == User.id, isouter=True)
|
||||
.join(
|
||||
PlaylistSound,
|
||||
Playlist.id == PlaylistSound.playlist_id,
|
||||
isouter=True,
|
||||
)
|
||||
.join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
|
||||
.group_by(Playlist.id, User.name)
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if search_query and search_query.strip():
|
||||
search_pattern = f"%{search_query.strip().lower()}%"
|
||||
subquery = subquery.where(
|
||||
func.lower(Playlist.name).like(search_pattern),
|
||||
)
|
||||
|
||||
if user_id is not None:
|
||||
subquery = subquery.where(Playlist.user_id == user_id)
|
||||
|
||||
# Apply favorites filter
|
||||
if favorites_only and current_user_id is not None:
|
||||
# Use EXISTS subquery to avoid JOIN conflicts with GROUP BY
|
||||
favorites_subquery = (
|
||||
select(1)
|
||||
.select_from(Favorite)
|
||||
.where(
|
||||
Favorite.user_id == current_user_id,
|
||||
Favorite.playlist_id == Playlist.id,
|
||||
)
|
||||
)
|
||||
subquery = subquery.where(favorites_subquery.exists())
|
||||
|
||||
# Apply sorting
|
||||
if sort_by == PlaylistSortField.SOUND_COUNT:
|
||||
if sort_order == SortOrder.DESC:
|
||||
subquery = subquery.order_by(
|
||||
func.count(PlaylistSound.id).desc(),
|
||||
)
|
||||
else:
|
||||
subquery = subquery.order_by(func.count(PlaylistSound.id).asc())
|
||||
elif sort_by == PlaylistSortField.TOTAL_DURATION:
|
||||
if sort_order == SortOrder.DESC:
|
||||
subquery = subquery.order_by(
|
||||
func.coalesce(func.sum(Sound.duration), 0).desc(),
|
||||
)
|
||||
else:
|
||||
subquery = subquery.order_by(
|
||||
func.coalesce(func.sum(Sound.duration), 0).asc(),
|
||||
)
|
||||
else:
|
||||
# Default sorting by name
|
||||
subquery = subquery.order_by(Playlist.name.asc())
|
||||
|
||||
else:
|
||||
# Simple query without stats-based sorting
|
||||
subquery = (
|
||||
select(
|
||||
Playlist.id,
|
||||
Playlist.name,
|
||||
Playlist.description,
|
||||
Playlist.genre,
|
||||
Playlist.user_id,
|
||||
Playlist.is_main,
|
||||
Playlist.is_current,
|
||||
Playlist.is_deletable,
|
||||
Playlist.created_at,
|
||||
Playlist.updated_at,
|
||||
func.count(PlaylistSound.id).label("sound_count"),
|
||||
func.coalesce(func.sum(Sound.duration), 0).label(
|
||||
"total_duration",
|
||||
),
|
||||
User.name.label("user_name"),
|
||||
)
|
||||
.select_from(Playlist)
|
||||
.join(User, Playlist.user_id == User.id, isouter=True)
|
||||
.join(
|
||||
PlaylistSound,
|
||||
Playlist.id == PlaylistSound.playlist_id,
|
||||
isouter=True,
|
||||
)
|
||||
.join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
|
||||
.group_by(Playlist.id, User.name)
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if search_query and search_query.strip():
|
||||
search_pattern = f"%{search_query.strip().lower()}%"
|
||||
subquery = subquery.where(
|
||||
func.lower(Playlist.name).like(search_pattern),
|
||||
)
|
||||
|
||||
if user_id is not None:
|
||||
subquery = subquery.where(Playlist.user_id == user_id)
|
||||
|
||||
# Apply favorites filter
|
||||
if favorites_only and current_user_id is not None:
|
||||
# Use EXISTS subquery to avoid JOIN conflicts with GROUP BY
|
||||
favorites_subquery = (
|
||||
select(1)
|
||||
.select_from(Favorite)
|
||||
.where(
|
||||
Favorite.user_id == current_user_id,
|
||||
Favorite.playlist_id == Playlist.id,
|
||||
)
|
||||
)
|
||||
subquery = subquery.where(favorites_subquery.exists())
|
||||
|
||||
# Apply sorting
|
||||
if sort_by:
|
||||
if sort_by == PlaylistSortField.NAME:
|
||||
sort_column = Playlist.name
|
||||
elif sort_by == PlaylistSortField.GENRE:
|
||||
sort_column = Playlist.genre
|
||||
elif sort_by == PlaylistSortField.CREATED_AT:
|
||||
sort_column = Playlist.created_at
|
||||
elif sort_by == PlaylistSortField.UPDATED_AT:
|
||||
sort_column = Playlist.updated_at
|
||||
else:
|
||||
sort_column = Playlist.name
|
||||
|
||||
if sort_order == SortOrder.DESC:
|
||||
subquery = subquery.order_by(sort_column.desc())
|
||||
else:
|
||||
subquery = subquery.order_by(sort_column.asc())
|
||||
else:
|
||||
# Default sorting by name ascending
|
||||
subquery = subquery.order_by(Playlist.name.asc())
|
||||
|
||||
# Get total count if requested
|
||||
total_count = 0
|
||||
if return_count:
|
||||
# Create count query from the subquery before pagination
|
||||
count_query = select(func.count()).select_from(subquery.subquery())
|
||||
count_result = await self.session.exec(count_query)
|
||||
total_count = count_result.one()
|
||||
|
||||
# Apply pagination
|
||||
if offset > 0:
|
||||
subquery = subquery.offset(offset)
|
||||
if limit is not None:
|
||||
subquery = subquery.limit(limit)
|
||||
|
||||
result = await self.session.exec(subquery)
|
||||
rows = result.all()
|
||||
|
||||
# Convert to dictionary format
|
||||
playlists = [
|
||||
{
|
||||
"id": row.id,
|
||||
"name": row.name,
|
||||
"description": row.description,
|
||||
"genre": row.genre,
|
||||
"user_id": row.user_id,
|
||||
"user_name": row.user_name,
|
||||
"is_main": row.is_main,
|
||||
"is_current": row.is_current,
|
||||
"is_deletable": row.is_deletable,
|
||||
"created_at": row.created_at,
|
||||
"updated_at": row.updated_at,
|
||||
"sound_count": row.sound_count or 0,
|
||||
"total_duration": row.total_duration or 0,
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
(
|
||||
"Failed to search and sort playlists: "
|
||||
"query=%s, sort_by=%s, sort_order=%s"
|
||||
),
|
||||
search_query,
|
||||
sort_by,
|
||||
sort_order,
|
||||
)
|
||||
raise
|
||||
else:
|
||||
if return_count:
|
||||
return playlists, total_count
|
||||
return playlists
|
||||
|
||||
181
app/repositories/scheduled_task.py
Normal file
181
app/repositories/scheduled_task.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Repository for scheduled task operations."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.scheduled_task import (
|
||||
RecurrenceType,
|
||||
ScheduledTask,
|
||||
TaskStatus,
|
||||
TaskType,
|
||||
)
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
"""Repository for scheduled task database operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize the repository."""
|
||||
super().__init__(ScheduledTask, session)
|
||||
|
||||
async def get_pending_tasks(self) -> list[ScheduledTask]:
|
||||
"""Get all pending tasks that are ready to be executed."""
|
||||
now = datetime.now(tz=UTC)
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.status == TaskStatus.PENDING,
|
||||
ScheduledTask.is_active.is_(True),
|
||||
ScheduledTask.scheduled_at <= now,
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_active_tasks(self) -> list[ScheduledTask]:
|
||||
"""Get all active tasks."""
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.is_active.is_(True),
|
||||
ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_user_tasks(
|
||||
self,
|
||||
user_id: int,
|
||||
status: TaskStatus | None = None,
|
||||
task_type: TaskType | None = None,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get tasks for a specific user."""
|
||||
statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id)
|
||||
|
||||
if status:
|
||||
statement = statement.where(ScheduledTask.status == status)
|
||||
|
||||
if task_type:
|
||||
statement = statement.where(ScheduledTask.task_type == task_type)
|
||||
|
||||
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
|
||||
|
||||
if offset:
|
||||
statement = statement.offset(offset)
|
||||
|
||||
if limit:
|
||||
statement = statement.limit(limit)
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_system_tasks(
|
||||
self,
|
||||
status: TaskStatus | None = None,
|
||||
task_type: TaskType | None = None,
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get system tasks (tasks with no user association)."""
|
||||
statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None))
|
||||
|
||||
if status:
|
||||
statement = statement.where(ScheduledTask.status == status)
|
||||
|
||||
if task_type:
|
||||
statement = statement.where(ScheduledTask.task_type == task_type)
|
||||
|
||||
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_recurring_tasks_due_for_next_execution(self) -> list[ScheduledTask]:
|
||||
"""Get recurring tasks that need their next execution scheduled."""
|
||||
now = datetime.now(tz=UTC)
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.recurrence_type != RecurrenceType.NONE,
|
||||
ScheduledTask.is_active.is_(True),
|
||||
ScheduledTask.status == TaskStatus.COMPLETED,
|
||||
ScheduledTask.next_execution_at <= now,
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_expired_tasks(self) -> list[ScheduledTask]:
|
||||
"""Get expired tasks that should be cleaned up."""
|
||||
now = datetime.now(tz=UTC)
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.expires_at.is_not(None),
|
||||
ScheduledTask.expires_at <= now,
|
||||
ScheduledTask.is_active.is_(True),
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def cancel_user_tasks(
|
||||
self,
|
||||
user_id: int,
|
||||
task_type: TaskType | None = None,
|
||||
) -> int:
|
||||
"""Cancel all pending/running tasks for a user."""
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.user_id == user_id,
|
||||
ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
|
||||
)
|
||||
|
||||
if task_type:
|
||||
statement = statement.where(ScheduledTask.task_type == task_type)
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
tasks = list(result.all())
|
||||
|
||||
count = 0
|
||||
for task in tasks:
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.is_active = False
|
||||
self.session.add(task)
|
||||
count += 1
|
||||
|
||||
await self.session.commit()
|
||||
return count
|
||||
|
||||
async def mark_as_running(self, task: ScheduledTask) -> None:
|
||||
"""Mark a task as running."""
|
||||
task.status = TaskStatus.RUNNING
|
||||
self.session.add(task)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(task)
|
||||
|
||||
async def mark_as_completed(
|
||||
self,
|
||||
task: ScheduledTask,
|
||||
next_execution_at: datetime | None = None,
|
||||
) -> None:
|
||||
"""Mark a task as completed and set next execution if recurring."""
|
||||
task.status = TaskStatus.COMPLETED
|
||||
task.last_executed_at = datetime.now(tz=UTC)
|
||||
task.executions_count += 1
|
||||
task.error_message = None
|
||||
|
||||
if next_execution_at and task.should_repeat():
|
||||
task.next_execution_at = next_execution_at
|
||||
task.status = TaskStatus.PENDING
|
||||
elif not task.should_repeat():
|
||||
task.is_active = False
|
||||
|
||||
self.session.add(task)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(task)
|
||||
|
||||
async def mark_as_failed(self, task: ScheduledTask, error_message: str) -> None:
|
||||
"""Mark a task as failed with error message."""
|
||||
task.status = TaskStatus.FAILED
|
||||
task.error_message = error_message
|
||||
task.last_executed_at = datetime.now(tz=UTC)
|
||||
|
||||
# For non-recurring tasks, deactivate on failure
|
||||
if not task.is_recurring():
|
||||
task.is_active = False
|
||||
|
||||
self.session.add(task)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(task)
|
||||
@@ -1,33 +1,47 @@
|
||||
"""Sound repository for database operations."""
|
||||
|
||||
from typing import Any
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import desc, func
|
||||
from sqlmodel import select
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.favorite import Favorite
|
||||
from app.models.sound import Sound
|
||||
from app.models.sound_played import SoundPlayed
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SoundRepository:
|
||||
class SoundSortField(str, Enum):
|
||||
"""Sound sort field enumeration."""
|
||||
|
||||
NAME = "name"
|
||||
FILENAME = "filename"
|
||||
DURATION = "duration"
|
||||
SIZE = "size"
|
||||
TYPE = "type"
|
||||
PLAY_COUNT = "play_count"
|
||||
CREATED_AT = "created_at"
|
||||
UPDATED_AT = "updated_at"
|
||||
|
||||
|
||||
class SortOrder(str, Enum):
|
||||
"""Sort order enumeration."""
|
||||
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
|
||||
class SoundRepository(BaseRepository[Sound]):
|
||||
"""Repository for sound operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize the sound repository."""
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, sound_id: int) -> Sound | None:
|
||||
"""Get a sound by ID."""
|
||||
try:
|
||||
statement = select(Sound).where(Sound.id == sound_id)
|
||||
result = await self.session.exec(statement)
|
||||
return result.first()
|
||||
except Exception:
|
||||
logger.exception("Failed to get sound by ID: %s", sound_id)
|
||||
raise
|
||||
super().__init__(Sound, session)
|
||||
|
||||
async def get_by_filename(self, filename: str) -> Sound | None:
|
||||
"""Get a sound by filename."""
|
||||
@@ -59,48 +73,6 @@ class SoundRepository:
|
||||
logger.exception("Failed to get sounds by type: %s", sound_type)
|
||||
raise
|
||||
|
||||
async def create(self, sound_data: dict[str, Any]) -> Sound:
|
||||
"""Create a new sound."""
|
||||
try:
|
||||
sound = Sound(**sound_data)
|
||||
self.session.add(sound)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(sound)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to create sound")
|
||||
raise
|
||||
else:
|
||||
logger.info("Created new sound: %s", sound.name)
|
||||
return sound
|
||||
|
||||
async def update(self, sound: Sound, update_data: dict[str, Any]) -> Sound:
|
||||
"""Update a sound."""
|
||||
try:
|
||||
for field, value in update_data.items():
|
||||
setattr(sound, field, value)
|
||||
|
||||
await self.session.commit()
|
||||
await self.session.refresh(sound)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to update sound")
|
||||
raise
|
||||
else:
|
||||
logger.info("Updated sound: %s", sound.name)
|
||||
return sound
|
||||
|
||||
async def delete(self, sound: Sound) -> None:
|
||||
"""Delete a sound."""
|
||||
try:
|
||||
await self.session.delete(sound)
|
||||
await self.session.commit()
|
||||
logger.info("Deleted sound: %s", sound.name)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to delete sound")
|
||||
raise
|
||||
|
||||
async def search_by_name(self, query: str) -> list[Sound]:
|
||||
"""Search sounds by name (case-insensitive)."""
|
||||
try:
|
||||
@@ -144,6 +116,213 @@ class SoundRepository:
|
||||
return list(result.all())
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to get unnormalized sounds by type: %s", sound_type
|
||||
"Failed to get unnormalized sounds by type: %s",
|
||||
sound_type,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_types(self, sound_types: list[str] | None = None) -> list[Sound]:
|
||||
"""Get sounds by types. If types is None or empty, return all sounds."""
|
||||
try:
|
||||
statement = select(Sound)
|
||||
if sound_types:
|
||||
statement = statement.where(col(Sound.type).in_(sound_types))
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
except Exception:
|
||||
logger.exception("Failed to get sounds by types: %s", sound_types)
|
||||
raise
|
||||
|
||||
async def search_and_sort( # noqa: PLR0913
|
||||
self,
|
||||
search_query: str | None = None,
|
||||
sound_types: list[str] | None = None,
|
||||
sort_by: SoundSortField | None = None,
|
||||
sort_order: SortOrder = SortOrder.ASC,
|
||||
limit: int | None = None,
|
||||
offset: int = 0,
|
||||
favorites_only: bool = False, # noqa: FBT001, FBT002
|
||||
user_id: int | None = None,
|
||||
) -> list[Sound]:
|
||||
"""Search and sort sounds with optional filtering."""
|
||||
try:
|
||||
statement = select(Sound)
|
||||
|
||||
# Apply favorites filter
|
||||
if favorites_only and user_id is not None:
|
||||
statement = statement.join(Favorite).where(
|
||||
Favorite.user_id == user_id,
|
||||
Favorite.sound_id == Sound.id,
|
||||
)
|
||||
|
||||
# Apply type filter
|
||||
if sound_types:
|
||||
statement = statement.where(col(Sound.type).in_(sound_types))
|
||||
|
||||
# Apply search filter
|
||||
if search_query and search_query.strip():
|
||||
search_pattern = f"%{search_query.strip().lower()}%"
|
||||
statement = statement.where(
|
||||
func.lower(Sound.name).like(search_pattern),
|
||||
)
|
||||
|
||||
# Apply sorting
|
||||
if sort_by:
|
||||
sort_column = getattr(Sound, sort_by.value)
|
||||
if sort_order == SortOrder.DESC:
|
||||
statement = statement.order_by(sort_column.desc())
|
||||
else:
|
||||
statement = statement.order_by(sort_column.asc())
|
||||
else:
|
||||
# Default sorting by name ascending
|
||||
statement = statement.order_by(Sound.name.asc())
|
||||
|
||||
# Apply pagination
|
||||
if offset > 0:
|
||||
statement = statement.offset(offset)
|
||||
if limit is not None:
|
||||
statement = statement.limit(limit)
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
except Exception:
|
||||
logger.exception(
|
||||
(
|
||||
"Failed to search and sort sounds: "
|
||||
"query=%s, types=%s, sort_by=%s, sort_order=%s, favorites_only=%s, "
|
||||
"user_id=%s"
|
||||
),
|
||||
search_query,
|
||||
sound_types,
|
||||
sort_by,
|
||||
sort_order,
|
||||
favorites_only,
|
||||
user_id,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_soundboard_statistics(
|
||||
self,
|
||||
sound_type: str = "SDB",
|
||||
) -> dict[str, int | float]:
|
||||
"""Get statistics for sounds of a specific type."""
|
||||
try:
|
||||
statement = select(
|
||||
func.count(Sound.id).label("count"),
|
||||
func.sum(Sound.play_count).label("total_plays"),
|
||||
func.sum(Sound.duration).label("total_duration"),
|
||||
func.sum(
|
||||
Sound.size + func.coalesce(Sound.normalized_size, 0),
|
||||
).label("total_size"),
|
||||
).where(Sound.type == sound_type)
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
row = result.first()
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to get soundboard statistics")
|
||||
raise
|
||||
else:
|
||||
return {
|
||||
"count": row.count if row.count is not None else 0,
|
||||
"total_plays": row.total_plays if row.total_plays is not None else 0,
|
||||
"total_duration": (
|
||||
row.total_duration if row.total_duration is not None else 0
|
||||
),
|
||||
"total_size": row.total_size if row.total_size is not None else 0,
|
||||
}
|
||||
|
||||
async def get_track_statistics(self) -> dict[str, int | float]:
|
||||
"""Get statistics for EXT type sounds."""
|
||||
try:
|
||||
statement = select(
|
||||
func.count(Sound.id).label("count"),
|
||||
func.sum(Sound.play_count).label("total_plays"),
|
||||
func.sum(Sound.duration).label("total_duration"),
|
||||
func.sum(
|
||||
Sound.size + func.coalesce(Sound.normalized_size, 0),
|
||||
).label("total_size"),
|
||||
).where(Sound.type == "EXT")
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
row = result.first()
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to get track statistics")
|
||||
raise
|
||||
else:
|
||||
return {
|
||||
"count": row.count if row.count is not None else 0,
|
||||
"total_plays": row.total_plays if row.total_plays is not None else 0,
|
||||
"total_duration": (
|
||||
row.total_duration if row.total_duration is not None else 0
|
||||
),
|
||||
"total_size": row.total_size if row.total_size is not None else 0,
|
||||
}
|
||||
|
||||
async def get_top_sounds(
|
||||
self,
|
||||
sound_type: str,
|
||||
date_filter: datetime | None = None,
|
||||
limit: int = 10,
|
||||
) -> list[dict]:
|
||||
"""Get top sounds by play count for a specific type and period."""
|
||||
try:
|
||||
# Join SoundPlayed with Sound and count plays within the period
|
||||
statement = (
|
||||
select(
|
||||
Sound.id,
|
||||
Sound.name,
|
||||
Sound.type,
|
||||
Sound.duration,
|
||||
Sound.created_at,
|
||||
func.count(SoundPlayed.id).label("play_count"),
|
||||
)
|
||||
.select_from(SoundPlayed)
|
||||
.join(Sound, SoundPlayed.sound_id == Sound.id)
|
||||
)
|
||||
|
||||
# Apply sound type filter
|
||||
if sound_type != "all":
|
||||
statement = statement.where(Sound.type == sound_type.upper())
|
||||
|
||||
# Apply date filter if provided
|
||||
if date_filter:
|
||||
statement = statement.where(SoundPlayed.created_at >= date_filter)
|
||||
|
||||
# Group by sound and order by play count descending
|
||||
statement = (
|
||||
statement.group_by(
|
||||
Sound.id,
|
||||
Sound.name,
|
||||
Sound.type,
|
||||
Sound.duration,
|
||||
Sound.created_at,
|
||||
)
|
||||
.order_by(func.count(SoundPlayed.id).desc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
rows = result.all()
|
||||
|
||||
# Convert to dictionaries with the play count from the period
|
||||
return [
|
||||
{
|
||||
"id": row.id,
|
||||
"name": row.name,
|
||||
"type": row.type,
|
||||
"play_count": row.play_count,
|
||||
"duration": row.duration,
|
||||
"created_at": row.created_at,
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to get top sounds: type=%s, date_filter=%s, limit=%s",
|
||||
sound_type,
|
||||
date_filter,
|
||||
limit,
|
||||
)
|
||||
raise
|
||||
|
||||
74
app/repositories/tts.py
Normal file
74
app/repositories/tts.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""TTS repository for database operations."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from app.models.tts import TTS
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class TTSRepository(BaseRepository[TTS]):
|
||||
"""Repository for TTS operations."""
|
||||
|
||||
def __init__(self, session: "AsyncSession") -> None:
|
||||
"""Initialize TTS repository.
|
||||
|
||||
Args:
|
||||
session: Database session for operations
|
||||
|
||||
"""
|
||||
super().__init__(TTS, session)
|
||||
|
||||
async def get_by_user_id(
|
||||
self,
|
||||
user_id: int,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> Sequence[TTS]:
|
||||
"""Get TTS records by user ID with pagination.
|
||||
|
||||
Args:
|
||||
user_id: User ID to filter by
|
||||
limit: Maximum number of records to return
|
||||
offset: Number of records to skip
|
||||
|
||||
Returns:
|
||||
List of TTS records
|
||||
|
||||
"""
|
||||
stmt = (
|
||||
select(self.model)
|
||||
.where(self.model.user_id == user_id)
|
||||
.order_by(self.model.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
result = await self.session.exec(stmt)
|
||||
return result.all()
|
||||
|
||||
async def get_by_user_and_id(
|
||||
self,
|
||||
user_id: int,
|
||||
tts_id: int,
|
||||
) -> TTS | None:
|
||||
"""Get a specific TTS record by user ID and TTS ID.
|
||||
|
||||
Args:
|
||||
user_id: User ID to filter by
|
||||
tts_id: TTS ID to retrieve
|
||||
|
||||
Returns:
|
||||
TTS record if found and belongs to user, None otherwise
|
||||
|
||||
"""
|
||||
stmt = select(self.model).where(
|
||||
self.model.id == tts_id,
|
||||
self.model.user_id == user_id,
|
||||
)
|
||||
result = await self.session.exec(stmt)
|
||||
return result.first()
|
||||
@@ -1,32 +1,156 @@
|
||||
"""User repository."""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Select, func
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.credit_transaction import CreditTransaction
|
||||
from app.models.extraction import Extraction
|
||||
from app.models.plan import Plan
|
||||
from app.models.playlist import Playlist
|
||||
from app.models.sound_played import SoundPlayed
|
||||
from app.models.tts import TTS
|
||||
from app.models.user import User
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class UserRepository:
|
||||
class UserSortField(str, Enum):
|
||||
"""User sort fields."""
|
||||
|
||||
NAME = "name"
|
||||
EMAIL = "email"
|
||||
ROLE = "role"
|
||||
CREDITS = "credits"
|
||||
CREATED_AT = "created_at"
|
||||
|
||||
|
||||
class SortOrder(str, Enum):
|
||||
"""Sort order."""
|
||||
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
|
||||
class UserStatus(str, Enum):
|
||||
"""User status filter."""
|
||||
|
||||
ALL = "all"
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
class UserRepository(BaseRepository[User]):
|
||||
"""Repository for user operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize the user repository."""
|
||||
self.session = session
|
||||
super().__init__(User, session)
|
||||
|
||||
async def get_by_id(self, user_id: int) -> User | None:
|
||||
"""Get a user by ID."""
|
||||
async def get_all_with_plan(
|
||||
self,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[User]:
|
||||
"""Get all users with plan relationship loaded."""
|
||||
try:
|
||||
statement = select(User).where(User.id == user_id)
|
||||
statement = (
|
||||
select(User)
|
||||
.options(selectinload(User.plan))
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
except Exception:
|
||||
logger.exception("Failed to get all users with plan")
|
||||
raise
|
||||
|
||||
async def get_all_with_plan_paginated( # noqa: PLR0913
|
||||
self,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
search: str | None = None,
|
||||
sort_by: UserSortField = UserSortField.NAME,
|
||||
sort_order: SortOrder = SortOrder.ASC,
|
||||
status_filter: UserStatus = UserStatus.ALL,
|
||||
) -> tuple[list[User], int]:
|
||||
"""Get all users with plan relationship loaded and return total count."""
|
||||
try:
|
||||
# Calculate offset
|
||||
offset = (page - 1) * limit
|
||||
|
||||
# Build base query
|
||||
base_query = select(User).options(selectinload(User.plan))
|
||||
count_query = select(func.count(User.id))
|
||||
|
||||
# Apply search filter
|
||||
if search and search.strip():
|
||||
search_pattern = f"%{search.strip().lower()}%"
|
||||
search_condition = func.lower(User.name).like(
|
||||
search_pattern,
|
||||
) | func.lower(User.email).like(search_pattern)
|
||||
base_query = base_query.where(search_condition)
|
||||
count_query = count_query.where(search_condition)
|
||||
|
||||
# Apply status filter
|
||||
if status_filter == UserStatus.ACTIVE:
|
||||
base_query = base_query.where(User.is_active == True) # noqa: E712
|
||||
count_query = count_query.where(User.is_active == True) # noqa: E712
|
||||
elif status_filter == UserStatus.INACTIVE:
|
||||
base_query = base_query.where(User.is_active == False) # noqa: E712
|
||||
count_query = count_query.where(User.is_active == False) # noqa: E712
|
||||
|
||||
# Apply sorting
|
||||
sort_column = {
|
||||
UserSortField.NAME: User.name,
|
||||
UserSortField.EMAIL: User.email,
|
||||
UserSortField.ROLE: User.role,
|
||||
UserSortField.CREDITS: User.credits,
|
||||
UserSortField.CREATED_AT: User.created_at,
|
||||
}.get(sort_by, User.name)
|
||||
|
||||
if sort_order == SortOrder.DESC:
|
||||
base_query = base_query.order_by(sort_column.desc())
|
||||
else:
|
||||
base_query = base_query.order_by(sort_column.asc())
|
||||
|
||||
# Get total count
|
||||
count_result = await self.session.exec(count_query)
|
||||
total_count = count_result.one()
|
||||
|
||||
# Apply pagination and get results
|
||||
paginated_query = base_query.limit(limit).offset(offset)
|
||||
result = await self.session.exec(paginated_query)
|
||||
users = list(result.all())
|
||||
except Exception:
|
||||
logger.exception("Failed to get paginated users with plan")
|
||||
raise
|
||||
else:
|
||||
return users, total_count
|
||||
|
||||
async def get_by_id_with_plan(self, entity_id: int) -> User | None:
|
||||
"""Get a user by ID with plan relationship loaded."""
|
||||
try:
|
||||
statement = (
|
||||
select(User)
|
||||
.options(selectinload(User.plan))
|
||||
.where(User.id == entity_id)
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return result.first()
|
||||
except Exception:
|
||||
logger.exception("Failed to get user by ID: %s", user_id)
|
||||
logger.exception(
|
||||
"Failed to get user by ID with plan: %s",
|
||||
entity_id,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_email(self, email: str) -> User | None:
|
||||
@@ -49,8 +173,8 @@ class UserRepository:
|
||||
logger.exception("Failed to get user by API token")
|
||||
raise
|
||||
|
||||
async def create(self, user_data: dict[str, Any]) -> User:
|
||||
"""Create a new user."""
|
||||
async def create(self, entity_data: dict[str, Any]) -> User:
|
||||
"""Create a new user with plan assignment and first user admin logic."""
|
||||
|
||||
def _raise_plan_not_found() -> None:
|
||||
msg = "Default plan not found"
|
||||
@@ -65,7 +189,7 @@ class UserRepository:
|
||||
if is_first_user:
|
||||
# First user gets admin role and pro plan
|
||||
plan_statement = select(Plan).where(Plan.code == "pro")
|
||||
user_data["role"] = "admin"
|
||||
entity_data["role"] = "admin"
|
||||
logger.info("Creating first user with admin role and pro plan")
|
||||
else:
|
||||
# Regular users get free plan
|
||||
@@ -81,48 +205,14 @@ class UserRepository:
|
||||
assert default_plan is not None # noqa: S101
|
||||
|
||||
# Set plan_id and default credits
|
||||
user_data["plan_id"] = default_plan.id
|
||||
user_data["credits"] = default_plan.credits
|
||||
entity_data["plan_id"] = default_plan.id
|
||||
entity_data["credits"] = default_plan.credits
|
||||
|
||||
user = User(**user_data)
|
||||
self.session.add(user)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(user)
|
||||
# Use BaseRepository's create method
|
||||
return await super().create(entity_data)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to create user")
|
||||
raise
|
||||
else:
|
||||
logger.info("Created new user with email: %s", user.email)
|
||||
return user
|
||||
|
||||
async def update(self, user: User, update_data: dict[str, Any]) -> User:
|
||||
"""Update a user."""
|
||||
try:
|
||||
for field, value in update_data.items():
|
||||
setattr(user, field, value)
|
||||
|
||||
await self.session.commit()
|
||||
await self.session.refresh(user)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to update user")
|
||||
raise
|
||||
else:
|
||||
logger.info("Updated user: %s", user.email)
|
||||
return user
|
||||
|
||||
async def delete(self, user: User) -> None:
|
||||
"""Delete a user."""
|
||||
try:
|
||||
await self.session.delete(user)
|
||||
await self.session.commit()
|
||||
|
||||
logger.info("Deleted user: %s", user.email)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to delete user")
|
||||
raise
|
||||
|
||||
async def email_exists(self, email: str) -> bool:
|
||||
"""Check if an email address is already registered."""
|
||||
@@ -133,3 +223,146 @@ class UserRepository:
|
||||
except Exception:
|
||||
logger.exception("Failed to check if email exists: %s", email)
|
||||
raise
|
||||
|
||||
async def get_top_users(
|
||||
self,
|
||||
metric_type: str,
|
||||
date_filter: datetime | None = None,
|
||||
limit: int = 10,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get top users by different metrics."""
|
||||
try:
|
||||
query = self._build_top_users_query(metric_type, date_filter)
|
||||
|
||||
# Add ordering and limit
|
||||
query = query.order_by(func.count().desc()).limit(limit)
|
||||
|
||||
result = await self.session.exec(query)
|
||||
rows = result.all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": row[0],
|
||||
"name": row[1],
|
||||
"count": int(row[2]),
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to get top users for metric=%s, date_filter=%s",
|
||||
metric_type,
|
||||
date_filter,
|
||||
)
|
||||
raise
|
||||
|
||||
def _build_top_users_query(
|
||||
self,
|
||||
metric_type: str,
|
||||
date_filter: datetime | None,
|
||||
) -> Select:
|
||||
"""Build query for top users based on metric type."""
|
||||
match metric_type:
|
||||
case "sounds_played":
|
||||
query = self._build_sounds_played_query()
|
||||
case "credits_used":
|
||||
query = self._build_credits_used_query()
|
||||
case "tracks_added":
|
||||
query = self._build_tracks_added_query()
|
||||
case "tts_added":
|
||||
query = self._build_tts_added_query()
|
||||
case "playlists_created":
|
||||
query = self._build_playlists_created_query()
|
||||
case _:
|
||||
msg = f"Unknown metric type: {metric_type}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Apply date filter if provided
|
||||
if date_filter:
|
||||
query = self._apply_date_filter(query, metric_type, date_filter)
|
||||
|
||||
return query
|
||||
|
||||
def _build_sounds_played_query(self) -> Select:
|
||||
"""Build query for sounds played metric."""
|
||||
return (
|
||||
select(
|
||||
User.id,
|
||||
User.name,
|
||||
func.count(SoundPlayed.id).label("count"),
|
||||
)
|
||||
.join(SoundPlayed, User.id == SoundPlayed.user_id)
|
||||
.group_by(User.id, User.name)
|
||||
)
|
||||
|
||||
def _build_credits_used_query(self) -> Select:
|
||||
"""Build query for credits used metric."""
|
||||
return (
|
||||
select(
|
||||
User.id,
|
||||
User.name,
|
||||
func.sum(func.abs(CreditTransaction.amount)).label("count"),
|
||||
)
|
||||
.join(CreditTransaction, User.id == CreditTransaction.user_id)
|
||||
.where(CreditTransaction.amount < 0)
|
||||
.group_by(User.id, User.name)
|
||||
)
|
||||
|
||||
def _build_tracks_added_query(self) -> Select:
|
||||
"""Build query for tracks added metric."""
|
||||
return (
|
||||
select(
|
||||
User.id,
|
||||
User.name,
|
||||
func.count(Extraction.id).label("count"),
|
||||
)
|
||||
.join(Extraction, User.id == Extraction.user_id)
|
||||
.where(Extraction.sound_id.is_not(None))
|
||||
.group_by(User.id, User.name)
|
||||
)
|
||||
|
||||
def _build_tts_added_query(self) -> Select:
|
||||
"""Build query for TTS added metric."""
|
||||
return (
|
||||
select(
|
||||
User.id,
|
||||
User.name,
|
||||
func.count(TTS.id).label("count"),
|
||||
)
|
||||
.join(TTS, User.id == TTS.user_id)
|
||||
.group_by(User.id, User.name)
|
||||
)
|
||||
|
||||
def _build_playlists_created_query(self) -> Select:
|
||||
"""Build query for playlists created metric."""
|
||||
return (
|
||||
select(
|
||||
User.id,
|
||||
User.name,
|
||||
func.count(Playlist.id).label("count"),
|
||||
)
|
||||
.join(Playlist, User.id == Playlist.user_id)
|
||||
.group_by(User.id, User.name)
|
||||
)
|
||||
|
||||
def _apply_date_filter(
|
||||
self,
|
||||
query: Select,
|
||||
metric_type: str,
|
||||
date_filter: datetime,
|
||||
) -> Select:
|
||||
"""Apply date filter to query based on metric type."""
|
||||
match metric_type:
|
||||
case "sounds_played":
|
||||
return query.where(SoundPlayed.created_at >= date_filter)
|
||||
case "credits_used":
|
||||
return query.where(CreditTransaction.created_at >= date_filter)
|
||||
case "tracks_added":
|
||||
return query.where(Extraction.created_at >= date_filter)
|
||||
case "tts_added":
|
||||
return query.where(TTS.created_at >= date_filter)
|
||||
case "playlists_created":
|
||||
return query.where(Playlist.created_at >= date_filter)
|
||||
case _:
|
||||
return query
|
||||
|
||||
@@ -1,22 +1,21 @@
|
||||
"""Repository for user OAuth operations."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.user_oauth import UserOauth
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class UserOauthRepository:
|
||||
class UserOauthRepository(BaseRepository[UserOauth]):
|
||||
"""Repository for user OAuth operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize repository with database session."""
|
||||
self.session = session
|
||||
super().__init__(UserOauth, session)
|
||||
|
||||
async def get_by_provider_user_id(
|
||||
self,
|
||||
@@ -61,57 +60,12 @@ class UserOauthRepository:
|
||||
else:
|
||||
return result.first()
|
||||
|
||||
async def create(self, oauth_data: dict[str, Any]) -> UserOauth:
|
||||
"""Create a new user OAuth record."""
|
||||
async def get_by_user_id(self, user_id: int) -> list[UserOauth]:
|
||||
"""Get all OAuth providers for a user."""
|
||||
try:
|
||||
oauth = UserOauth(**oauth_data)
|
||||
self.session.add(oauth)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(oauth)
|
||||
logger.info(
|
||||
"Created OAuth link for user %s with provider %s",
|
||||
oauth.user_id,
|
||||
oauth.provider,
|
||||
)
|
||||
statement = select(UserOauth).where(UserOauth.user_id == user_id)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to create user OAuth")
|
||||
raise
|
||||
else:
|
||||
return oauth
|
||||
|
||||
async def update(self, oauth: UserOauth, update_data: dict[str, Any]) -> UserOauth:
|
||||
"""Update a user OAuth record."""
|
||||
try:
|
||||
for key, value in update_data.items():
|
||||
setattr(oauth, key, value)
|
||||
|
||||
self.session.add(oauth)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(oauth)
|
||||
logger.info(
|
||||
"Updated OAuth link for user %s with provider %s",
|
||||
oauth.user_id,
|
||||
oauth.provider,
|
||||
)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to update user OAuth")
|
||||
raise
|
||||
else:
|
||||
return oauth
|
||||
|
||||
async def delete(self, oauth: UserOauth) -> None:
|
||||
"""Delete a user OAuth record."""
|
||||
try:
|
||||
await self.session.delete(oauth)
|
||||
await self.session.commit()
|
||||
logger.info(
|
||||
"Deleted OAuth link for user %s with provider %s",
|
||||
oauth.user_id,
|
||||
oauth.provider,
|
||||
)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to delete user OAuth")
|
||||
logger.exception("Failed to get OAuth providers for user ID: %s", user_id)
|
||||
raise
|
||||
|
||||
@@ -28,25 +28,16 @@ from .playlist import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Auth schemas
|
||||
"ApiTokenRequest",
|
||||
"ApiTokenResponse",
|
||||
"ApiTokenStatusResponse",
|
||||
"AuthResponse",
|
||||
"TokenResponse",
|
||||
"UserLoginRequest",
|
||||
"UserRegisterRequest",
|
||||
"UserResponse",
|
||||
# Common schemas
|
||||
"HealthResponse",
|
||||
"MessageResponse",
|
||||
"StatusResponse",
|
||||
# Player schemas
|
||||
"MessageResponse",
|
||||
"PlayerModeRequest",
|
||||
"PlayerSeekRequest",
|
||||
"PlayerStateResponse",
|
||||
"PlayerVolumeRequest",
|
||||
# Playlist schemas
|
||||
"PlaylistAddSoundRequest",
|
||||
"PlaylistCreateRequest",
|
||||
"PlaylistReorderRequest",
|
||||
@@ -54,4 +45,9 @@ __all__ = [
|
||||
"PlaylistSoundResponse",
|
||||
"PlaylistStatsResponse",
|
||||
"PlaylistUpdateRequest",
|
||||
"StatusResponse",
|
||||
"TokenResponse",
|
||||
"UserLoginRequest",
|
||||
"UserRegisterRequest",
|
||||
"UserResponse",
|
||||
]
|
||||
|
||||
@@ -79,3 +79,28 @@ class ApiTokenStatusResponse(BaseModel):
|
||||
has_token: bool = Field(..., description="Whether user has an active API token")
|
||||
expires_at: datetime | None = Field(None, description="Token expiration timestamp")
|
||||
is_expired: bool = Field(..., description="Whether the token is expired")
|
||||
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
|
||||
"""Schema for password change request."""
|
||||
|
||||
current_password: str | None = Field(
|
||||
None,
|
||||
description="Current password (required if user has existing password)",
|
||||
)
|
||||
new_password: str = Field(
|
||||
...,
|
||||
min_length=8,
|
||||
description="New password (minimum 8 characters)",
|
||||
)
|
||||
|
||||
|
||||
class UpdateProfileRequest(BaseModel):
|
||||
"""Schema for profile update request."""
|
||||
|
||||
name: str | None = Field(
|
||||
None,
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
description="User display name",
|
||||
)
|
||||
|
||||
@@ -18,4 +18,4 @@ class StatusResponse(BaseModel):
|
||||
class HealthResponse(BaseModel):
|
||||
"""Health check response."""
|
||||
|
||||
status: str = Field(description="Health status")
|
||||
status: str = Field(description="Health status")
|
||||
|
||||
41
app/schemas/favorite.py
Normal file
41
app/schemas/favorite.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Favorite response schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FavoriteResponse(BaseModel):
|
||||
"""Response schema for a favorite."""
|
||||
|
||||
id: int = Field(description="Favorite ID")
|
||||
user_id: int = Field(description="User ID")
|
||||
sound_id: int | None = Field(
|
||||
description="Sound ID if this is a sound favorite",
|
||||
default=None,
|
||||
)
|
||||
playlist_id: int | None = Field(
|
||||
description="Playlist ID if this is a playlist favorite",
|
||||
default=None,
|
||||
)
|
||||
created_at: datetime = Field(description="Creation timestamp")
|
||||
updated_at: datetime = Field(description="Last update timestamp")
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class FavoritesListResponse(BaseModel):
|
||||
"""Response schema for a list of favorites."""
|
||||
|
||||
favorites: list[FavoriteResponse] = Field(description="List of favorites")
|
||||
|
||||
|
||||
class FavoriteCountsResponse(BaseModel):
|
||||
"""Response schema for favorite counts."""
|
||||
|
||||
total: int = Field(description="Total number of favorites")
|
||||
sounds: int = Field(description="Number of favorited sounds")
|
||||
playlists: int = Field(description="Number of favorited playlists")
|
||||
@@ -10,7 +10,7 @@ from app.services.player import PlayerMode
|
||||
class PlayerSeekRequest(BaseModel):
|
||||
"""Request model for seek operation."""
|
||||
|
||||
position_ms: int = Field(ge=0, description="Position in milliseconds")
|
||||
position: int = Field(ge=0, description="Position in milliseconds")
|
||||
|
||||
|
||||
class PlayerVolumeRequest(BaseModel):
|
||||
@@ -30,17 +30,26 @@ class PlayerStateResponse(BaseModel):
|
||||
|
||||
status: str = Field(description="Player status (playing, paused, stopped)")
|
||||
current_sound: dict[str, Any] | None = Field(
|
||||
None, description="Current sound information"
|
||||
None,
|
||||
description="Current sound information",
|
||||
)
|
||||
playlist: dict[str, Any] | None = Field(
|
||||
None, description="Current playlist information"
|
||||
None,
|
||||
description="Current playlist information",
|
||||
)
|
||||
position_ms: int = Field(description="Current position in milliseconds")
|
||||
duration_ms: int | None = Field(
|
||||
None, description="Total duration in milliseconds",
|
||||
position: int = Field(description="Current position in milliseconds")
|
||||
duration: int | None = Field(
|
||||
None,
|
||||
description="Total duration in milliseconds",
|
||||
)
|
||||
volume: int = Field(description="Current volume (0-100)")
|
||||
previous_volume: int = Field(description="Previous volume for unmuting (0-100)")
|
||||
mode: str = Field(description="Current playback mode")
|
||||
index: int | None = Field(
|
||||
None, description="Current track index in playlist",
|
||||
None,
|
||||
description="Current track index in playlist",
|
||||
)
|
||||
play_next_queue: list[dict[str, Any]] = Field(
|
||||
default_factory=list,
|
||||
description="Play next queue",
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Playlist schemas."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.models.playlist import Playlist
|
||||
from app.models.sound import Sound
|
||||
@@ -33,14 +33,32 @@ class PlaylistResponse(BaseModel):
|
||||
is_main: bool
|
||||
is_current: bool
|
||||
is_deletable: bool
|
||||
is_favorited: bool = False
|
||||
favorite_count: int = 0
|
||||
created_at: str
|
||||
updated_at: str | None
|
||||
|
||||
@classmethod
|
||||
def from_playlist(cls, playlist: Playlist) -> "PlaylistResponse":
|
||||
"""Create response from playlist model."""
|
||||
def from_playlist(
|
||||
cls,
|
||||
playlist: Playlist,
|
||||
is_favorited: bool = False, # noqa: FBT001, FBT002
|
||||
favorite_count: int = 0,
|
||||
) -> "PlaylistResponse":
|
||||
"""Create response from playlist model.
|
||||
|
||||
Args:
|
||||
playlist: The Playlist model
|
||||
is_favorited: Whether the playlist is favorited by the current user
|
||||
favorite_count: Number of users who favorited this playlist
|
||||
|
||||
Returns:
|
||||
PlaylistResponse instance
|
||||
|
||||
"""
|
||||
if playlist.id is None:
|
||||
raise ValueError("Playlist ID cannot be None")
|
||||
msg = "Playlist ID cannot be None"
|
||||
raise ValueError(msg)
|
||||
return cls(
|
||||
id=playlist.id,
|
||||
name=playlist.name,
|
||||
@@ -49,6 +67,8 @@ class PlaylistResponse(BaseModel):
|
||||
is_main=playlist.is_main,
|
||||
is_current=playlist.is_current,
|
||||
is_deletable=playlist.is_deletable,
|
||||
is_favorited=is_favorited,
|
||||
favorite_count=favorite_count,
|
||||
created_at=playlist.created_at.isoformat(),
|
||||
updated_at=playlist.updated_at.isoformat() if playlist.updated_at else None,
|
||||
)
|
||||
@@ -70,7 +90,8 @@ class PlaylistSoundResponse(BaseModel):
|
||||
def from_sound(cls, sound: Sound) -> "PlaylistSoundResponse":
|
||||
"""Create response from sound model."""
|
||||
if sound.id is None:
|
||||
raise ValueError("Sound ID cannot be None")
|
||||
msg = "Sound ID cannot be None"
|
||||
raise ValueError(msg)
|
||||
return cls(
|
||||
id=sound.id,
|
||||
name=sound.name,
|
||||
|
||||
197
app/schemas/scheduler.py
Normal file
197
app/schemas/scheduler.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""Schemas for scheduled task API."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.models.scheduled_task import RecurrenceType, TaskStatus, TaskType
|
||||
|
||||
|
||||
class TaskFilterParams(BaseModel):
|
||||
"""Query parameters for filtering tasks."""
|
||||
|
||||
status: TaskStatus | None = Field(default=None, description="Filter by task status")
|
||||
task_type: TaskType | None = Field(default=None, description="Filter by task type")
|
||||
limit: int = Field(default=50, description="Maximum number of tasks to return")
|
||||
offset: int = Field(default=0, description="Number of tasks to skip")
|
||||
|
||||
|
||||
class ScheduledTaskBase(BaseModel):
|
||||
"""Base schema for scheduled tasks."""
|
||||
|
||||
name: str = Field(description="Human-readable task name")
|
||||
task_type: TaskType = Field(description="Type of task to execute")
|
||||
scheduled_at: datetime = Field(description="When the task should be executed")
|
||||
timezone: str = Field(default="UTC", description="Timezone for scheduling")
|
||||
parameters: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Task-specific parameters",
|
||||
)
|
||||
recurrence_type: RecurrenceType = Field(
|
||||
default=RecurrenceType.NONE,
|
||||
description="Recurrence pattern",
|
||||
)
|
||||
cron_expression: str | None = Field(
|
||||
default=None,
|
||||
description="Cron expression for custom recurrence",
|
||||
)
|
||||
recurrence_count: int | None = Field(
|
||||
default=None,
|
||||
description="Number of times to repeat (None for infinite)",
|
||||
)
|
||||
expires_at: datetime | None = Field(
|
||||
default=None,
|
||||
description="When the task expires (optional)",
|
||||
)
|
||||
|
||||
|
||||
class ScheduledTaskCreate(ScheduledTaskBase):
|
||||
"""Schema for creating a scheduled task."""
|
||||
|
||||
|
||||
|
||||
class ScheduledTaskUpdate(BaseModel):
|
||||
"""Schema for updating a scheduled task."""
|
||||
|
||||
name: str | None = None
|
||||
scheduled_at: datetime | None = None
|
||||
timezone: str | None = None
|
||||
parameters: dict[str, Any] | None = None
|
||||
is_active: bool | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
|
||||
class ScheduledTaskResponse(ScheduledTaskBase):
|
||||
"""Schema for scheduled task responses."""
|
||||
|
||||
id: int
|
||||
status: TaskStatus
|
||||
user_id: int | None = None
|
||||
executions_count: int
|
||||
last_executed_at: datetime | None = None
|
||||
next_execution_at: datetime | None = None
|
||||
error_message: str | None = None
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
"""Pydantic configuration."""
|
||||
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# Task-specific parameter schemas
|
||||
class CreditRechargeParameters(BaseModel):
|
||||
"""Parameters for credit recharge tasks."""
|
||||
|
||||
user_id: int | None = Field(
|
||||
default=None,
|
||||
description="Specific user ID to recharge (None for all users)",
|
||||
)
|
||||
|
||||
|
||||
class PlaySoundParameters(BaseModel):
|
||||
"""Parameters for play sound tasks."""
|
||||
|
||||
sound_id: int = Field(description="ID of the sound to play")
|
||||
|
||||
|
||||
class PlayPlaylistParameters(BaseModel):
|
||||
"""Parameters for play playlist tasks."""
|
||||
|
||||
playlist_id: int = Field(description="ID of the playlist to play")
|
||||
play_mode: str = Field(
|
||||
default="continuous",
|
||||
description="Play mode (continuous, loop, loop_one, random, single)",
|
||||
)
|
||||
shuffle: bool = Field(default=False, description="Whether to shuffle the playlist")
|
||||
|
||||
|
||||
# Convenience schemas for creating specific task types
|
||||
class CreateCreditRechargeTask(BaseModel):
|
||||
"""Schema for creating credit recharge tasks."""
|
||||
|
||||
name: str = "Credit Recharge"
|
||||
scheduled_at: datetime
|
||||
timezone: str = "UTC"
|
||||
recurrence_type: RecurrenceType = RecurrenceType.NONE
|
||||
cron_expression: str | None = None
|
||||
recurrence_count: int | None = None
|
||||
expires_at: datetime | None = None
|
||||
user_id: int | None = None
|
||||
|
||||
def to_task_create(self) -> ScheduledTaskCreate:
|
||||
"""Convert to generic task creation schema."""
|
||||
return ScheduledTaskCreate(
|
||||
name=self.name,
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=self.scheduled_at,
|
||||
timezone=self.timezone,
|
||||
parameters={"user_id": self.user_id},
|
||||
recurrence_type=self.recurrence_type,
|
||||
cron_expression=self.cron_expression,
|
||||
recurrence_count=self.recurrence_count,
|
||||
expires_at=self.expires_at,
|
||||
)
|
||||
|
||||
|
||||
class CreatePlaySoundTask(BaseModel):
|
||||
"""Schema for creating play sound tasks."""
|
||||
|
||||
name: str
|
||||
scheduled_at: datetime
|
||||
sound_id: int
|
||||
timezone: str = "UTC"
|
||||
recurrence_type: RecurrenceType = RecurrenceType.NONE
|
||||
cron_expression: str | None = None
|
||||
recurrence_count: int | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
def to_task_create(self) -> ScheduledTaskCreate:
|
||||
"""Convert to generic task creation schema."""
|
||||
return ScheduledTaskCreate(
|
||||
name=self.name,
|
||||
task_type=TaskType.PLAY_SOUND,
|
||||
scheduled_at=self.scheduled_at,
|
||||
timezone=self.timezone,
|
||||
parameters={"sound_id": self.sound_id},
|
||||
recurrence_type=self.recurrence_type,
|
||||
cron_expression=self.cron_expression,
|
||||
recurrence_count=self.recurrence_count,
|
||||
expires_at=self.expires_at,
|
||||
)
|
||||
|
||||
|
||||
class CreatePlayPlaylistTask(BaseModel):
|
||||
"""Schema for creating play playlist tasks."""
|
||||
|
||||
name: str
|
||||
scheduled_at: datetime
|
||||
playlist_id: int
|
||||
play_mode: str = "continuous"
|
||||
shuffle: bool = False
|
||||
timezone: str = "UTC"
|
||||
recurrence_type: RecurrenceType = RecurrenceType.NONE
|
||||
cron_expression: str | None = None
|
||||
recurrence_count: int | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
def to_task_create(self) -> ScheduledTaskCreate:
|
||||
"""Convert to generic task creation schema."""
|
||||
return ScheduledTaskCreate(
|
||||
name=self.name,
|
||||
task_type=TaskType.PLAY_PLAYLIST,
|
||||
scheduled_at=self.scheduled_at,
|
||||
timezone=self.timezone,
|
||||
parameters={
|
||||
"playlist_id": self.playlist_id,
|
||||
"play_mode": self.play_mode,
|
||||
"shuffle": self.shuffle,
|
||||
},
|
||||
recurrence_type=self.recurrence_type,
|
||||
cron_expression=self.cron_expression,
|
||||
recurrence_count=self.recurrence_count,
|
||||
expires_at=self.expires_at,
|
||||
)
|
||||
106
app/schemas/sound.py
Normal file
106
app/schemas/sound.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Sound response schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.models.sound import Sound
|
||||
|
||||
|
||||
class SoundResponse(BaseModel):
|
||||
"""Response schema for a sound with favorite indicator."""
|
||||
|
||||
id: int = Field(description="Sound ID")
|
||||
type: str = Field(description="Sound type")
|
||||
name: str = Field(description="Sound name")
|
||||
filename: str = Field(description="Sound filename")
|
||||
duration: int = Field(description="Duration in milliseconds")
|
||||
size: int = Field(description="File size in bytes")
|
||||
hash: str = Field(description="File hash")
|
||||
normalized_filename: str | None = Field(
|
||||
description="Normalized filename",
|
||||
default=None,
|
||||
)
|
||||
normalized_duration: int | None = Field(
|
||||
description="Normalized duration in milliseconds",
|
||||
default=None,
|
||||
)
|
||||
normalized_size: int | None = Field(
|
||||
description="Normalized file size in bytes",
|
||||
default=None,
|
||||
)
|
||||
normalized_hash: str | None = Field(
|
||||
description="Normalized file hash",
|
||||
default=None,
|
||||
)
|
||||
thumbnail: str | None = Field(description="Thumbnail filename", default=None)
|
||||
play_count: int = Field(description="Number of times played")
|
||||
is_normalized: bool = Field(description="Whether the sound is normalized")
|
||||
is_music: bool = Field(description="Whether the sound is music")
|
||||
is_deletable: bool = Field(description="Whether the sound can be deleted")
|
||||
is_favorited: bool = Field(
|
||||
description="Whether the sound is favorited by the current user",
|
||||
default=False,
|
||||
)
|
||||
favorite_count: int = Field(
|
||||
description="Number of users who favorited this sound",
|
||||
default=0,
|
||||
)
|
||||
created_at: datetime = Field(description="Creation timestamp")
|
||||
updated_at: datetime = Field(description="Last update timestamp")
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
from_attributes = True
|
||||
|
||||
@classmethod
|
||||
def from_sound(
|
||||
cls,
|
||||
sound: Sound,
|
||||
is_favorited: bool = False, # noqa: FBT001, FBT002
|
||||
favorite_count: int = 0,
|
||||
) -> "SoundResponse":
|
||||
"""Create a SoundResponse from a Sound model.
|
||||
|
||||
Args:
|
||||
sound: The Sound model
|
||||
is_favorited: Whether the sound is favorited by the current user
|
||||
favorite_count: Number of users who favorited this sound
|
||||
|
||||
Returns:
|
||||
SoundResponse instance
|
||||
|
||||
"""
|
||||
if sound.id is None:
|
||||
msg = "Sound ID cannot be None"
|
||||
raise ValueError(msg)
|
||||
|
||||
return cls(
|
||||
id=sound.id,
|
||||
type=sound.type,
|
||||
name=sound.name,
|
||||
filename=sound.filename,
|
||||
duration=sound.duration,
|
||||
size=sound.size,
|
||||
hash=sound.hash,
|
||||
normalized_filename=sound.normalized_filename,
|
||||
normalized_duration=sound.normalized_duration,
|
||||
normalized_size=sound.normalized_size,
|
||||
normalized_hash=sound.normalized_hash,
|
||||
thumbnail=sound.thumbnail,
|
||||
play_count=sound.play_count,
|
||||
is_normalized=sound.is_normalized,
|
||||
is_music=sound.is_music,
|
||||
is_deletable=sound.is_deletable,
|
||||
is_favorited=is_favorited,
|
||||
favorite_count=favorite_count,
|
||||
created_at=sound.created_at,
|
||||
updated_at=sound.updated_at,
|
||||
)
|
||||
|
||||
|
||||
class SoundsListResponse(BaseModel):
|
||||
"""Response schema for a list of sounds."""
|
||||
|
||||
sounds: list[SoundResponse] = Field(description="List of sounds")
|
||||
27
app/schemas/user.py
Normal file
27
app/schemas/user.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""User schemas."""
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
"""Schema for updating a user."""
|
||||
|
||||
name: str | None = Field(
|
||||
None,
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
description="User full name",
|
||||
)
|
||||
plan_id: int | None = Field(None, description="User plan ID")
|
||||
credits: int | None = Field(None, ge=0, description="User credits")
|
||||
is_active: bool | None = Field(None, description="Whether user is active")
|
||||
role: str | None = Field(None, description="User role (admin or user)")
|
||||
|
||||
@field_validator("role")
|
||||
@classmethod
|
||||
def validate_role(cls, v: str | None) -> str | None:
|
||||
"""Validate that role is either 'user' or 'admin'."""
|
||||
if v is not None and v not in {"user", "admin"}:
|
||||
msg = "Role must be either 'user' or 'admin'"
|
||||
raise ValueError(msg)
|
||||
return v
|
||||
@@ -20,11 +20,6 @@ from app.schemas.auth import (
|
||||
)
|
||||
from app.services.oauth import OAuthUserInfo
|
||||
from app.utils.auth import JWTUtils, PasswordUtils, TokenUtils
|
||||
from app.utils.exceptions import (
|
||||
raise_bad_request,
|
||||
raise_not_found,
|
||||
raise_unauthorized,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -44,7 +39,10 @@ class AuthService:
|
||||
|
||||
# Check if email already exists
|
||||
if await self.user_repo.email_exists(request.email):
|
||||
raise_bad_request("Email address is already registered")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email address is already registered",
|
||||
)
|
||||
|
||||
# Hash the password
|
||||
hashed_password = PasswordUtils.hash_password(request.password)
|
||||
@@ -77,18 +75,27 @@ class AuthService:
|
||||
# Get user by email
|
||||
user = await self.user_repo.get_by_email(request.email)
|
||||
if not user:
|
||||
raise_unauthorized("Invalid email or password")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
)
|
||||
|
||||
# Check if user is active
|
||||
if not user.is_active:
|
||||
raise_unauthorized("Account is deactivated")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Account is deactivated",
|
||||
)
|
||||
|
||||
# Verify password
|
||||
if not user.password_hash or not PasswordUtils.verify_password(
|
||||
request.password,
|
||||
user.password_hash,
|
||||
):
|
||||
raise_unauthorized("Invalid email or password")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
)
|
||||
|
||||
# Generate access token
|
||||
token = self._create_access_token(user)
|
||||
@@ -103,10 +110,16 @@ class AuthService:
|
||||
"""Get the current authenticated user."""
|
||||
user = await self.user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise_not_found("User")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise_unauthorized("Account is deactivated")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Account is deactivated",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
@@ -417,3 +430,93 @@ class AuthService:
|
||||
oauth_user_info.email,
|
||||
)
|
||||
return AuthResponse(user=user_response, token=token)
|
||||
|
||||
async def update_user_profile(self, user: User, data: dict) -> User:
|
||||
"""Update user profile information."""
|
||||
logger.info("Updating profile for user: %s", user.email)
|
||||
|
||||
# Only allow updating specific fields
|
||||
allowed_fields = {"name"}
|
||||
update_data = {k: v for k, v in data.items() if k in allowed_fields}
|
||||
|
||||
if not update_data:
|
||||
return user
|
||||
|
||||
# Update user
|
||||
for field, value in update_data.items():
|
||||
setattr(user, field, value)
|
||||
|
||||
self.session.add(user)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(user, ["plan"])
|
||||
|
||||
logger.info("Profile updated successfully for user: %s", user.email)
|
||||
return user
|
||||
|
||||
async def change_user_password(
|
||||
self,
|
||||
user: User,
|
||||
current_password: str | None,
|
||||
new_password: str,
|
||||
) -> None:
|
||||
"""Change user's password."""
|
||||
# Store user email before any operations to avoid session detachment issues
|
||||
user_email = user.email
|
||||
logger.info("Changing password for user: %s", user_email)
|
||||
|
||||
# Store whether user had existing password before we modify it
|
||||
had_existing_password = user.password_hash is not None
|
||||
|
||||
# If user has existing password, verify it
|
||||
if had_existing_password:
|
||||
if not current_password:
|
||||
msg = "Current password is required when changing existing password"
|
||||
raise ValueError(msg)
|
||||
if not PasswordUtils.verify_password(current_password, user.password_hash):
|
||||
msg = "Current password is incorrect"
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
# User doesn't have a password (OAuth-only user), setting first password
|
||||
logger.info("Setting first password for OAuth user: %s", user_email)
|
||||
|
||||
# Hash new password
|
||||
new_password_hash = PasswordUtils.hash_password(new_password)
|
||||
|
||||
# Update user
|
||||
user.password_hash = new_password_hash
|
||||
self.session.add(user)
|
||||
await self.session.commit()
|
||||
|
||||
logger.info(
|
||||
"Password %s successfully for user: %s",
|
||||
"changed" if had_existing_password else "set",
|
||||
user_email,
|
||||
)
|
||||
|
||||
async def user_to_response(self, user: User) -> UserResponse:
|
||||
"""Convert User model to UserResponse with plan information."""
|
||||
# Load plan relationship if not already loaded
|
||||
if not hasattr(user, "plan") or not user.plan:
|
||||
await self.session.refresh(user, ["plan"])
|
||||
|
||||
return UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
name=user.name,
|
||||
picture=user.picture,
|
||||
role=user.role,
|
||||
credits=user.credits,
|
||||
is_active=user.is_active,
|
||||
plan={
|
||||
"id": user.plan.id,
|
||||
"name": user.plan.name,
|
||||
"max_credits": user.plan.max_credits,
|
||||
"features": [], # Add features if needed
|
||||
},
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at,
|
||||
)
|
||||
|
||||
async def get_user_oauth_providers(self, user: User) -> list:
|
||||
"""Get OAuth providers connected to the user."""
|
||||
return await self.oauth_repo.get_by_user_id(user.id)
|
||||
|
||||
@@ -30,7 +30,7 @@ class InsufficientCreditsError(Exception):
|
||||
self.required = required
|
||||
self.available = available
|
||||
super().__init__(
|
||||
f"Insufficient credits: {required} required, {available} available"
|
||||
f"Insufficient credits: {required} required, {available} available",
|
||||
)
|
||||
|
||||
|
||||
@@ -76,14 +76,12 @@ class CreditService:
|
||||
self,
|
||||
user_id: int,
|
||||
action_type: CreditActionType,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> tuple[User, CreditAction]:
|
||||
"""Validate user has sufficient credits and optionally reserve them.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
action_type: The type of action
|
||||
metadata: Optional metadata to store with transaction
|
||||
|
||||
Returns:
|
||||
Tuple of (user, credit_action)
|
||||
@@ -118,6 +116,7 @@ class CreditService:
|
||||
self,
|
||||
user_id: int,
|
||||
action_type: CreditActionType,
|
||||
*,
|
||||
success: bool = True,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> CreditTransaction:
|
||||
@@ -138,19 +137,27 @@ class CreditService:
|
||||
|
||||
"""
|
||||
action = get_credit_action(action_type)
|
||||
|
||||
# Only deduct if action requires success and was successful, or doesn't require success
|
||||
should_deduct = (action.requires_success and success) or not action.requires_success
|
||||
|
||||
|
||||
# Only deduct if action requires success and was successful,
|
||||
# or doesn't require success
|
||||
should_deduct = (
|
||||
action.requires_success and success
|
||||
) or not action.requires_success
|
||||
|
||||
if not should_deduct:
|
||||
logger.info(
|
||||
"Skipping credit deduction for user %s: action %s failed and requires success",
|
||||
"Skipping credit deduction for user %s: "
|
||||
"action %s failed and requires success",
|
||||
user_id,
|
||||
action_type.value,
|
||||
)
|
||||
# Still create a transaction record for auditing
|
||||
return await self._create_transaction_record(
|
||||
user_id, action, 0, success, metadata
|
||||
user_id,
|
||||
action,
|
||||
0,
|
||||
success=success,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
session = self.db_session_factory()
|
||||
@@ -204,18 +211,23 @@ class CreditService:
|
||||
"action_type": action_type.value,
|
||||
"success": success,
|
||||
}
|
||||
await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
|
||||
await socket_manager.send_to_user(
|
||||
str(user_id),
|
||||
"user_credits_changed",
|
||||
event_data,
|
||||
)
|
||||
logger.info("Emitted user_credits_changed event for user %s", user_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to emit user_credits_changed event for user %s", user_id,
|
||||
"Failed to emit user_credits_changed event for user %s",
|
||||
user_id,
|
||||
)
|
||||
|
||||
return transaction
|
||||
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
else:
|
||||
return transaction
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
@@ -292,18 +304,23 @@ class CreditService:
|
||||
"description": description,
|
||||
"success": True,
|
||||
}
|
||||
await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
|
||||
await socket_manager.send_to_user(
|
||||
str(user_id),
|
||||
"user_credits_changed",
|
||||
event_data,
|
||||
)
|
||||
logger.info("Emitted user_credits_changed event for user %s", user_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to emit user_credits_changed event for user %s", user_id,
|
||||
"Failed to emit user_credits_changed event for user %s",
|
||||
user_id,
|
||||
)
|
||||
|
||||
return transaction
|
||||
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
else:
|
||||
return transaction
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
@@ -312,6 +329,7 @@ class CreditService:
|
||||
user_id: int,
|
||||
action: CreditAction,
|
||||
amount: int,
|
||||
*,
|
||||
success: bool,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> CreditTransaction:
|
||||
@@ -342,19 +360,22 @@ class CreditService:
|
||||
amount=amount,
|
||||
balance_before=user.credits,
|
||||
balance_after=user.credits,
|
||||
description=f"{action.description} (failed)" if not success else action.description,
|
||||
description=(
|
||||
f"{action.description} (failed)"
|
||||
if not success
|
||||
else action.description
|
||||
),
|
||||
success=success,
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
)
|
||||
|
||||
session.add(transaction)
|
||||
await session.commit()
|
||||
|
||||
return transaction
|
||||
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
else:
|
||||
return transaction
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
@@ -380,4 +401,227 @@ class CreditService:
|
||||
raise ValueError(msg)
|
||||
return user.credits
|
||||
finally:
|
||||
await session.close()
|
||||
await session.close()
|
||||
|
||||
async def recharge_user_credits_auto(
|
||||
self,
|
||||
user_id: int,
|
||||
) -> CreditTransaction | None:
|
||||
"""Recharge credits for a user automatically based on their plan.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
|
||||
Returns:
|
||||
The created credit transaction if credits were added, None if no recharge
|
||||
needed
|
||||
|
||||
Raises:
|
||||
ValueError: If user not found or has no plan
|
||||
|
||||
"""
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id_with_plan(user_id)
|
||||
if not user:
|
||||
msg = f"User {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
if not user.plan:
|
||||
msg = f"User {user_id} has no plan assigned"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Call the main method with plan details
|
||||
return await self.recharge_user_credits(
|
||||
user_id,
|
||||
user.plan.credits,
|
||||
user.plan.max_credits,
|
||||
)
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def recharge_user_credits(
|
||||
self,
|
||||
user_id: int,
|
||||
plan_credits: int,
|
||||
max_credits: int,
|
||||
) -> CreditTransaction | None:
|
||||
"""Recharge credits for a user based on their plan.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
plan_credits: Number of credits from the plan
|
||||
max_credits: Maximum credits allowed for the plan
|
||||
|
||||
Returns:
|
||||
The created credit transaction if credits were added, None if no recharge
|
||||
needed
|
||||
|
||||
Raises:
|
||||
ValueError: If user not found
|
||||
|
||||
"""
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Calculate credits to add (can't exceed max_credits)
|
||||
current_credits = user.credits
|
||||
target_credits = min(current_credits + plan_credits, max_credits)
|
||||
credits_to_add = target_credits - current_credits
|
||||
|
||||
# If no credits to add, return None
|
||||
if credits_to_add <= 0:
|
||||
logger.info(
|
||||
"No credits to add for user %s: current=%s, "
|
||||
"plan_credits=%s, max=%s",
|
||||
user_id,
|
||||
current_credits,
|
||||
plan_credits,
|
||||
max_credits,
|
||||
)
|
||||
return None
|
||||
|
||||
# Record transaction
|
||||
transaction = CreditTransaction(
|
||||
user_id=user_id,
|
||||
action_type=CreditActionType.DAILY_RECHARGE.value,
|
||||
amount=credits_to_add,
|
||||
balance_before=current_credits,
|
||||
balance_after=target_credits,
|
||||
description="Daily credit recharge",
|
||||
success=True,
|
||||
metadata_json=json.dumps(
|
||||
{
|
||||
"plan_credits": plan_credits,
|
||||
"max_credits": max_credits,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Update user credits
|
||||
await user_repo.update(user, {"credits": target_credits})
|
||||
|
||||
# Save transaction
|
||||
session.add(transaction)
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
"Credits recharged for user %s: %s credits added (balance: %s → %s)",
|
||||
user_id,
|
||||
credits_to_add,
|
||||
current_credits,
|
||||
target_credits,
|
||||
)
|
||||
|
||||
# Emit user_credits_changed event via WebSocket
|
||||
try:
|
||||
event_data = {
|
||||
"user_id": str(user_id),
|
||||
"credits_before": current_credits,
|
||||
"credits_after": target_credits,
|
||||
"credits_added": credits_to_add,
|
||||
"description": "Daily credit recharge",
|
||||
"success": True,
|
||||
}
|
||||
await socket_manager.send_to_user(
|
||||
str(user_id),
|
||||
"user_credits_changed",
|
||||
event_data,
|
||||
)
|
||||
logger.info("Emitted user_credits_changed event for user %s", user_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to emit user_credits_changed event for user %s",
|
||||
user_id,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
else:
|
||||
return transaction
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def recharge_all_users_credits(self) -> dict[str, int]:
|
||||
"""Recharge credits for all users based on their plans.
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics about the recharge operation
|
||||
|
||||
"""
|
||||
session = self.db_session_factory()
|
||||
stats = {
|
||||
"total_users": 0,
|
||||
"recharged_users": 0,
|
||||
"skipped_users": 0,
|
||||
"total_credits_added": 0,
|
||||
}
|
||||
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
# Process users in batches to avoid memory issues
|
||||
offset = 0
|
||||
batch_size = 100
|
||||
|
||||
while True:
|
||||
users = await user_repo.get_all_with_plan(
|
||||
limit=batch_size,
|
||||
offset=offset,
|
||||
)
|
||||
if not users:
|
||||
break
|
||||
|
||||
for user in users:
|
||||
stats["total_users"] += 1
|
||||
|
||||
# Skip users without ID (shouldn't happen in practice)
|
||||
if user.id is None:
|
||||
continue
|
||||
|
||||
transaction = await self.recharge_user_credits(
|
||||
user.id,
|
||||
user.plan.credits,
|
||||
user.plan.max_credits,
|
||||
)
|
||||
|
||||
if transaction:
|
||||
stats["recharged_users"] += 1
|
||||
# Calculate the amount from plan data to avoid session issues
|
||||
current_credits = user.credits
|
||||
plan_credits = user.plan.credits
|
||||
max_credits = user.plan.max_credits
|
||||
target_credits = min(
|
||||
current_credits + plan_credits, max_credits,
|
||||
)
|
||||
credits_added = target_credits - current_credits
|
||||
stats["total_credits_added"] += credits_added
|
||||
else:
|
||||
stats["skipped_users"] += 1
|
||||
|
||||
offset += batch_size
|
||||
|
||||
# Break if we got fewer users than batch_size (last batch)
|
||||
if len(users) < batch_size:
|
||||
break
|
||||
|
||||
logger.info(
|
||||
"Daily credit recharge completed: %s total users, "
|
||||
"%s recharged, %s skipped, %s total credits added",
|
||||
stats["total_users"],
|
||||
stats["recharged_users"],
|
||||
stats["skipped_users"],
|
||||
stats["total_credits_added"],
|
||||
)
|
||||
|
||||
return stats
|
||||
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
161
app/services/dashboard.py
Normal file
161
app/services/dashboard.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Dashboard service for statistics and analytics."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.repositories.user import UserRepository
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DashboardService:
|
||||
"""Service for dashboard statistics and analytics."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
user_repository: UserRepository,
|
||||
) -> None:
|
||||
"""Initialize the dashboard service."""
|
||||
self.sound_repository = sound_repository
|
||||
self.user_repository = user_repository
|
||||
|
||||
async def get_soundboard_statistics(self) -> dict[str, Any]:
|
||||
"""Get comprehensive soundboard statistics."""
|
||||
try:
|
||||
stats = await self.sound_repository.get_soundboard_statistics()
|
||||
|
||||
return {
|
||||
"sound_count": stats["count"],
|
||||
"total_play_count": stats["total_plays"],
|
||||
"total_duration": stats["total_duration"],
|
||||
"total_size": stats["total_size"],
|
||||
}
|
||||
except Exception:
|
||||
logger.exception("Failed to get soundboard statistics")
|
||||
raise
|
||||
|
||||
async def get_track_statistics(self) -> dict[str, Any]:
|
||||
"""Get comprehensive track statistics."""
|
||||
try:
|
||||
stats = await self.sound_repository.get_track_statistics()
|
||||
|
||||
return {
|
||||
"track_count": stats["count"],
|
||||
"total_play_count": stats["total_plays"],
|
||||
"total_duration": stats["total_duration"],
|
||||
"total_size": stats["total_size"],
|
||||
}
|
||||
except Exception:
|
||||
logger.exception("Failed to get track statistics")
|
||||
raise
|
||||
|
||||
async def get_top_sounds(
|
||||
self,
|
||||
sound_type: str,
|
||||
period: str = "all_time",
|
||||
limit: int = 10,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get top sounds by play count for a specific period."""
|
||||
try:
|
||||
# Calculate the date filter based on period
|
||||
date_filter = self._get_date_filter(period)
|
||||
|
||||
# Get top sounds from repository
|
||||
top_sounds = await self.sound_repository.get_top_sounds(
|
||||
sound_type=sound_type,
|
||||
date_filter=date_filter,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": sound["id"],
|
||||
"name": sound["name"],
|
||||
"type": sound["type"],
|
||||
"play_count": sound["play_count"],
|
||||
"duration": sound["duration"],
|
||||
"created_at": (
|
||||
sound["created_at"].isoformat() if sound["created_at"] else None
|
||||
),
|
||||
}
|
||||
for sound in top_sounds
|
||||
]
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to get top sounds for type=%s, period=%s",
|
||||
sound_type,
|
||||
period,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_tts_statistics(self) -> dict[str, Any]:
|
||||
"""Get comprehensive TTS statistics."""
|
||||
try:
|
||||
stats = await self.sound_repository.get_soundboard_statistics("TTS")
|
||||
|
||||
return {
|
||||
"sound_count": stats["count"],
|
||||
"total_play_count": stats["total_plays"],
|
||||
"total_duration": stats["total_duration"],
|
||||
"total_size": stats["total_size"],
|
||||
}
|
||||
except Exception:
|
||||
logger.exception("Failed to get TTS statistics")
|
||||
raise
|
||||
|
||||
async def get_top_users(
|
||||
self,
|
||||
metric_type: str,
|
||||
period: str = "all_time",
|
||||
limit: int = 10,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get top users by different metrics for a specific period."""
|
||||
try:
|
||||
# Calculate the date filter based on period
|
||||
date_filter = self._get_date_filter(period)
|
||||
|
||||
# Get top users from repository
|
||||
top_users = await self.user_repository.get_top_users(
|
||||
metric_type=metric_type,
|
||||
date_filter=date_filter,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": user["id"],
|
||||
"name": user["name"],
|
||||
"count": user["count"],
|
||||
}
|
||||
for user in top_users
|
||||
]
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to get top users for metric=%s, period=%s",
|
||||
metric_type,
|
||||
period,
|
||||
)
|
||||
raise
|
||||
|
||||
def _get_date_filter(self, period: str) -> datetime | None: # noqa: PLR0911
|
||||
"""Calculate the date filter based on the period."""
|
||||
now = datetime.now(UTC)
|
||||
|
||||
match period:
|
||||
case "today":
|
||||
return now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
case "1_day":
|
||||
return now - timedelta(days=1)
|
||||
case "1_week":
|
||||
return now - timedelta(weeks=1)
|
||||
case "1_month":
|
||||
return now - timedelta(days=30)
|
||||
case "1_year":
|
||||
return now - timedelta(days=365)
|
||||
case "all_time":
|
||||
return None
|
||||
case _:
|
||||
return None # Default to all time for unknown periods
|
||||
@@ -2,8 +2,9 @@
|
||||
|
||||
import asyncio
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TypedDict
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import yt_dlp
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -14,6 +15,7 @@ from app.models.extraction import Extraction
|
||||
from app.models.sound import Sound
|
||||
from app.repositories.extraction import ExtractionRepository
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.repositories.user import UserRepository
|
||||
from app.services.playlist import PlaylistService
|
||||
from app.services.sound_normalizer import SoundNormalizerService
|
||||
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
|
||||
@@ -21,6 +23,18 @@ from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionContext:
|
||||
"""Context data for extraction processing."""
|
||||
|
||||
extraction_id: int
|
||||
extraction_url: str
|
||||
extraction_service: str | None
|
||||
extraction_service_id: str | None
|
||||
extraction_title: str | None
|
||||
user_id: int
|
||||
|
||||
|
||||
class ExtractionInfo(TypedDict):
|
||||
"""Type definition for extraction information."""
|
||||
|
||||
@@ -32,6 +46,20 @@ class ExtractionInfo(TypedDict):
|
||||
status: str
|
||||
error: str | None
|
||||
sound_id: int | None
|
||||
user_id: int
|
||||
user_name: str | None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class PaginatedExtractionsResponse(TypedDict):
|
||||
"""Type definition for paginated extractions response."""
|
||||
|
||||
extractions: list[ExtractionInfo]
|
||||
total: int
|
||||
page: int
|
||||
limit: int
|
||||
total_pages: int
|
||||
|
||||
|
||||
class ExtractionService:
|
||||
@@ -42,6 +70,7 @@ class ExtractionService:
|
||||
self.session = session
|
||||
self.extraction_repo = ExtractionRepository(session)
|
||||
self.sound_repo = SoundRepository(session)
|
||||
self.user_repo = UserRepository(session)
|
||||
self.playlist_service = PlaylistService(session)
|
||||
|
||||
# Ensure required directories exist
|
||||
@@ -64,6 +93,15 @@ class ExtractionService:
|
||||
logger.info("Creating extraction for URL: %s (user: %d)", url, user_id)
|
||||
|
||||
try:
|
||||
# Get user information
|
||||
user = await self.user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Extract user name immediately while in session context
|
||||
user_name = user.name
|
||||
|
||||
# Create the extraction record without service detection for fast response
|
||||
extraction_data = {
|
||||
"url": url,
|
||||
@@ -76,7 +114,10 @@ class ExtractionService:
|
||||
|
||||
extraction = await self.extraction_repo.create(extraction_data)
|
||||
logger.info("Created extraction with ID: %d", extraction.id)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to create extraction for URL: %s", url)
|
||||
raise
|
||||
else:
|
||||
return {
|
||||
"id": extraction.id or 0, # Should never be None for created extraction
|
||||
"url": extraction.url,
|
||||
@@ -86,12 +127,12 @@ class ExtractionService:
|
||||
"status": extraction.status,
|
||||
"error": extraction.error,
|
||||
"sound_id": extraction.sound_id,
|
||||
"user_id": extraction.user_id,
|
||||
"user_name": user_name,
|
||||
"created_at": extraction.created_at.isoformat(),
|
||||
"updated_at": extraction.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to create extraction for URL: %s", url)
|
||||
raise
|
||||
|
||||
async def _detect_service_info(self, url: str) -> dict[str, str | None] | None:
|
||||
"""Detect service information from URL using yt-dlp."""
|
||||
try:
|
||||
@@ -123,14 +164,16 @@ class ExtractionService:
|
||||
logger.exception("Failed to detect service info for URL: %s", url)
|
||||
return None
|
||||
|
||||
async def process_extraction(self, extraction_id: int) -> ExtractionInfo:
|
||||
"""Process an extraction job."""
|
||||
async def _validate_extraction(self, extraction_id: int) -> tuple:
|
||||
"""Validate extraction and return extraction data."""
|
||||
extraction = await self.extraction_repo.get_by_id(extraction_id)
|
||||
if not extraction:
|
||||
raise ValueError(f"Extraction {extraction_id} not found")
|
||||
msg = f"Extraction {extraction_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
if extraction.status != "pending":
|
||||
raise ValueError(f"Extraction {extraction_id} is not pending")
|
||||
msg = f"Extraction {extraction_id} is not pending"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Store all needed values early to avoid session detachment issues
|
||||
user_id = extraction.user_id
|
||||
@@ -139,106 +182,274 @@ class ExtractionService:
|
||||
extraction_service_id = extraction.service_id
|
||||
extraction_title = extraction.title
|
||||
|
||||
logger.info("Processing extraction %d: %s", extraction_id, extraction_url)
|
||||
|
||||
# Get user information for return value
|
||||
try:
|
||||
# Update status to processing
|
||||
await self.extraction_repo.update(extraction, {"status": "processing"})
|
||||
user = await self.user_repo.get_by_id(user_id)
|
||||
user_name = user.name if user else None
|
||||
except Exception:
|
||||
logger.exception("Failed to get user %d for extraction", user_id)
|
||||
user_name = None
|
||||
|
||||
# Detect service info if not already available
|
||||
if not extraction_service or not extraction_service_id:
|
||||
logger.info("Detecting service info for extraction %d", extraction_id)
|
||||
service_info = await self._detect_service_info(extraction_url)
|
||||
return (
|
||||
extraction,
|
||||
user_id,
|
||||
extraction_url,
|
||||
extraction_service,
|
||||
extraction_service_id,
|
||||
extraction_title,
|
||||
user_name,
|
||||
)
|
||||
|
||||
if not service_info:
|
||||
raise ValueError("Unable to detect service information from URL")
|
||||
|
||||
# Check if extraction already exists for this service
|
||||
existing = await self.extraction_repo.get_by_service_and_id(
|
||||
service_info["service"], service_info["service_id"]
|
||||
)
|
||||
if existing and existing.id != extraction_id:
|
||||
error_msg = (
|
||||
f"Extraction already exists for "
|
||||
f"{service_info['service']}:{service_info['service_id']}"
|
||||
)
|
||||
logger.warning(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Update extraction with service info
|
||||
update_data = {
|
||||
"service": service_info["service"],
|
||||
"service_id": service_info["service_id"],
|
||||
"title": service_info.get("title") or extraction_title,
|
||||
}
|
||||
await self.extraction_repo.update(extraction, update_data)
|
||||
|
||||
# Update values for processing
|
||||
extraction_service = service_info["service"]
|
||||
extraction_service_id = service_info["service_id"]
|
||||
extraction_title = service_info.get("title") or extraction_title
|
||||
|
||||
# Extract audio and thumbnail
|
||||
audio_file, thumbnail_file = await self._extract_media(
|
||||
extraction_id, extraction_url
|
||||
async def _handle_service_detection(
|
||||
self,
|
||||
extraction: Extraction,
|
||||
context: ExtractionContext,
|
||||
) -> tuple:
|
||||
"""Handle service detection and duplicate checking."""
|
||||
if context.extraction_service and context.extraction_service_id:
|
||||
return (
|
||||
context.extraction_service,
|
||||
context.extraction_service_id,
|
||||
context.extraction_title,
|
||||
)
|
||||
|
||||
# Move files to final locations
|
||||
(
|
||||
final_audio_path,
|
||||
final_thumbnail_path,
|
||||
) = await self._move_files_to_final_location(
|
||||
logger.info("Detecting service info for extraction %d", context.extraction_id)
|
||||
service_info = await self._detect_service_info(context.extraction_url)
|
||||
|
||||
if not service_info:
|
||||
msg = "Unable to detect service information from URL"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Check if extraction already exists for this service
|
||||
service_name = service_info["service"]
|
||||
service_id_val = service_info["service_id"]
|
||||
|
||||
if not service_name or not service_id_val:
|
||||
msg = "Service info is incomplete"
|
||||
raise ValueError(msg)
|
||||
|
||||
existing = await self.extraction_repo.get_by_service_and_id(
|
||||
service_name,
|
||||
service_id_val,
|
||||
)
|
||||
if existing and existing.id != context.extraction_id:
|
||||
error_msg = (
|
||||
f"Extraction already exists for "
|
||||
f"{service_info['service']}:{service_info['service_id']}"
|
||||
)
|
||||
logger.warning(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Update extraction with service info
|
||||
update_data = {
|
||||
"service": service_info["service"],
|
||||
"service_id": service_info["service_id"],
|
||||
"title": service_info.get("title") or context.extraction_title,
|
||||
}
|
||||
await self.extraction_repo.update(extraction, update_data)
|
||||
|
||||
# Update values for processing
|
||||
new_service = service_info["service"]
|
||||
new_service_id = service_info["service_id"]
|
||||
new_title = service_info.get("title") or context.extraction_title
|
||||
|
||||
await self._emit_extraction_event(
|
||||
context.user_id,
|
||||
{
|
||||
"extraction_id": context.extraction_id,
|
||||
"status": "processing",
|
||||
"title": new_title,
|
||||
"url": context.extraction_url,
|
||||
},
|
||||
)
|
||||
|
||||
return new_service, new_service_id, new_title
|
||||
|
||||
async def _process_media_files(
|
||||
self,
|
||||
extraction_id: int,
|
||||
extraction_url: str,
|
||||
extraction_title: str | None,
|
||||
extraction_service: str,
|
||||
extraction_service_id: str,
|
||||
) -> int:
|
||||
"""Process media files and create sound record."""
|
||||
# Extract audio and thumbnail
|
||||
audio_file, thumbnail_file = await self._extract_media(
|
||||
extraction_id,
|
||||
extraction_url,
|
||||
)
|
||||
|
||||
# Move files to final locations
|
||||
final_audio_path, final_thumbnail_path = (
|
||||
await self._move_files_to_final_location(
|
||||
audio_file,
|
||||
thumbnail_file,
|
||||
extraction_title,
|
||||
extraction_service,
|
||||
extraction_service_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Create Sound record
|
||||
sound = await self._create_sound_record(
|
||||
final_audio_path,
|
||||
extraction_title,
|
||||
# Create Sound record
|
||||
sound = await self._create_sound_record(
|
||||
final_audio_path,
|
||||
final_thumbnail_path,
|
||||
extraction_title,
|
||||
extraction_service,
|
||||
extraction_service_id,
|
||||
)
|
||||
|
||||
if not sound.id:
|
||||
msg = "Sound creation failed - no ID returned"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
return sound.id
|
||||
|
||||
async def _complete_extraction(
|
||||
self,
|
||||
extraction: Extraction,
|
||||
context: ExtractionContext,
|
||||
sound_id: int,
|
||||
) -> None:
|
||||
"""Complete extraction processing."""
|
||||
# Normalize the sound
|
||||
await self._normalize_sound(sound_id)
|
||||
|
||||
# Add to main playlist
|
||||
await self._add_to_main_playlist(sound_id, context.user_id)
|
||||
|
||||
# Update extraction with success
|
||||
await self.extraction_repo.update(
|
||||
extraction,
|
||||
{
|
||||
"status": "completed",
|
||||
"sound_id": sound_id,
|
||||
"error": None,
|
||||
},
|
||||
)
|
||||
|
||||
# Emit WebSocket event for completion
|
||||
await self._emit_extraction_event(
|
||||
context.user_id,
|
||||
{
|
||||
"extraction_id": context.extraction_id,
|
||||
"status": "completed",
|
||||
"title": context.extraction_title,
|
||||
"url": context.extraction_url,
|
||||
"sound_id": sound_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def process_extraction(self, extraction_id: int) -> ExtractionInfo:
|
||||
"""Process an extraction job."""
|
||||
# Validate extraction and get context data
|
||||
(
|
||||
extraction,
|
||||
user_id,
|
||||
extraction_url,
|
||||
extraction_service,
|
||||
extraction_service_id,
|
||||
extraction_title,
|
||||
user_name,
|
||||
) = await self._validate_extraction(extraction_id)
|
||||
|
||||
# Create context object for helper methods
|
||||
context = ExtractionContext(
|
||||
extraction_id=extraction_id,
|
||||
extraction_url=extraction_url,
|
||||
extraction_service=extraction_service,
|
||||
extraction_service_id=extraction_service_id,
|
||||
extraction_title=extraction_title,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
logger.info("Processing extraction %d: %s", extraction_id, extraction_url)
|
||||
|
||||
try:
|
||||
# Update status to processing
|
||||
await self.extraction_repo.update(extraction, {"status": "processing"})
|
||||
|
||||
# Emit WebSocket event for processing start
|
||||
await self._emit_extraction_event(
|
||||
context.user_id,
|
||||
{
|
||||
"extraction_id": context.extraction_id,
|
||||
"status": "processing",
|
||||
"title": context.extraction_title or "Processing extraction...",
|
||||
"url": context.extraction_url,
|
||||
},
|
||||
)
|
||||
|
||||
# Handle service detection and duplicate checking
|
||||
extraction_service, extraction_service_id, extraction_title = (
|
||||
await self._handle_service_detection(extraction, context)
|
||||
)
|
||||
|
||||
# Update context with potentially new values
|
||||
context.extraction_service = extraction_service
|
||||
context.extraction_service_id = extraction_service_id
|
||||
context.extraction_title = extraction_title
|
||||
|
||||
# Process media files and create sound record
|
||||
sound_id = await self._process_media_files(
|
||||
context.extraction_id,
|
||||
context.extraction_url,
|
||||
context.extraction_title,
|
||||
extraction_service,
|
||||
extraction_service_id,
|
||||
)
|
||||
|
||||
# Store sound_id early to avoid session detachment issues
|
||||
sound_id = sound.id
|
||||
# Complete extraction processing
|
||||
await self._complete_extraction(extraction, context, sound_id)
|
||||
|
||||
# Normalize the sound
|
||||
await self._normalize_sound(sound_id)
|
||||
logger.info("Successfully processed extraction %d", context.extraction_id)
|
||||
|
||||
# Add to main playlist
|
||||
await self._add_to_main_playlist(sound_id, user_id)
|
||||
|
||||
# Update extraction with success
|
||||
await self.extraction_repo.update(
|
||||
extraction,
|
||||
{
|
||||
"status": "completed",
|
||||
"sound_id": sound_id,
|
||||
"error": None,
|
||||
},
|
||||
# Get updated extraction to get latest timestamps
|
||||
updated_extraction = await self.extraction_repo.get_by_id(
|
||||
context.extraction_id,
|
||||
)
|
||||
|
||||
logger.info("Successfully processed extraction %d", extraction_id)
|
||||
|
||||
return {
|
||||
"id": extraction_id,
|
||||
"url": extraction_url,
|
||||
"id": context.extraction_id,
|
||||
"url": context.extraction_url,
|
||||
"service": extraction_service,
|
||||
"service_id": extraction_service_id,
|
||||
"title": extraction_title,
|
||||
"status": "completed",
|
||||
"error": None,
|
||||
"sound_id": sound_id,
|
||||
"user_id": context.user_id,
|
||||
"user_name": user_name,
|
||||
"created_at": (
|
||||
updated_extraction.created_at.isoformat()
|
||||
if updated_extraction
|
||||
else ""
|
||||
),
|
||||
"updated_at": (
|
||||
updated_extraction.updated_at.isoformat()
|
||||
if updated_extraction
|
||||
else ""
|
||||
),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.exception(
|
||||
"Failed to process extraction %d: %s", extraction_id, error_msg
|
||||
"Failed to process extraction %d: %s",
|
||||
context.extraction_id,
|
||||
error_msg,
|
||||
)
|
||||
|
||||
# Emit WebSocket event for failure
|
||||
await self._emit_extraction_event(
|
||||
context.user_id,
|
||||
{
|
||||
"extraction_id": context.extraction_id,
|
||||
"status": "failed",
|
||||
"title": context.extraction_title or "Extraction failed",
|
||||
"url": context.extraction_url,
|
||||
"error": error_msg,
|
||||
},
|
||||
)
|
||||
|
||||
# Update extraction with error
|
||||
@@ -250,26 +461,44 @@ class ExtractionService:
|
||||
},
|
||||
)
|
||||
|
||||
# Get updated extraction to get latest timestamps
|
||||
updated_extraction = await self.extraction_repo.get_by_id(
|
||||
context.extraction_id,
|
||||
)
|
||||
return {
|
||||
"id": extraction_id,
|
||||
"url": extraction_url,
|
||||
"service": extraction_service,
|
||||
"service_id": extraction_service_id,
|
||||
"title": extraction_title,
|
||||
"id": context.extraction_id,
|
||||
"url": context.extraction_url,
|
||||
"service": context.extraction_service,
|
||||
"service_id": context.extraction_service_id,
|
||||
"title": context.extraction_title,
|
||||
"status": "failed",
|
||||
"error": error_msg,
|
||||
"sound_id": None,
|
||||
"user_id": context.user_id,
|
||||
"user_name": user_name,
|
||||
"created_at": (
|
||||
updated_extraction.created_at.isoformat()
|
||||
if updated_extraction
|
||||
else ""
|
||||
),
|
||||
"updated_at": (
|
||||
updated_extraction.updated_at.isoformat()
|
||||
if updated_extraction
|
||||
else ""
|
||||
),
|
||||
}
|
||||
|
||||
async def _extract_media(
|
||||
self, extraction_id: int, extraction_url: str
|
||||
self,
|
||||
extraction_id: int,
|
||||
extraction_url: str,
|
||||
) -> tuple[Path, Path | None]:
|
||||
"""Extract audio and thumbnail using yt-dlp."""
|
||||
temp_dir = Path(settings.EXTRACTION_TEMP_DIR)
|
||||
|
||||
# Create unique filename based on extraction ID
|
||||
output_template = str(
|
||||
temp_dir / f"extraction_{extraction_id}_%(title)s.%(ext)s"
|
||||
temp_dir / f"extraction_{extraction_id}_%(title)s.%(ext)s",
|
||||
)
|
||||
|
||||
# Configure yt-dlp options
|
||||
@@ -304,8 +533,8 @@ class ExtractionService:
|
||||
# Find the extracted files
|
||||
audio_files = list(
|
||||
temp_dir.glob(
|
||||
f"extraction_{extraction_id}_*.{settings.EXTRACTION_AUDIO_FORMAT}"
|
||||
)
|
||||
f"extraction_{extraction_id}_*.{settings.EXTRACTION_AUDIO_FORMAT}",
|
||||
),
|
||||
)
|
||||
thumbnail_files = (
|
||||
list(temp_dir.glob(f"extraction_{extraction_id}_*.webp"))
|
||||
@@ -314,7 +543,8 @@ class ExtractionService:
|
||||
)
|
||||
|
||||
if not audio_files:
|
||||
raise RuntimeError("No audio file was created during extraction")
|
||||
msg = "No audio file was created during extraction"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
audio_file = audio_files[0]
|
||||
thumbnail_file = thumbnail_files[0] if thumbnail_files else None
|
||||
@@ -325,11 +555,12 @@ class ExtractionService:
|
||||
thumbnail_file or "None",
|
||||
)
|
||||
|
||||
return audio_file, thumbnail_file
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("yt-dlp extraction failed for %s", extraction_url)
|
||||
raise RuntimeError(f"Audio extraction failed: {e}") from e
|
||||
error_msg = f"Audio extraction failed: {e}"
|
||||
raise RuntimeError(error_msg) from e
|
||||
else:
|
||||
return audio_file, thumbnail_file
|
||||
|
||||
async def _move_files_to_final_location(
|
||||
self,
|
||||
@@ -342,7 +573,7 @@ class ExtractionService:
|
||||
"""Move extracted files to their final locations."""
|
||||
# Generate clean filename based on title and service
|
||||
safe_title = self._sanitize_filename(
|
||||
title or f"{service or 'unknown'}_{service_id or 'unknown'}"
|
||||
title or f"{service or 'unknown'}_{service_id or 'unknown'}",
|
||||
)
|
||||
|
||||
# Move audio file
|
||||
@@ -401,6 +632,7 @@ class ExtractionService:
|
||||
async def _create_sound_record(
|
||||
self,
|
||||
audio_path: Path,
|
||||
thumbnail_path: Path | None,
|
||||
title: str | None,
|
||||
service: str | None,
|
||||
service_id: str | None,
|
||||
@@ -419,6 +651,7 @@ class ExtractionService:
|
||||
"duration": duration,
|
||||
"size": size,
|
||||
"hash": file_hash,
|
||||
"thumbnail": thumbnail_path.name if thumbnail_path else None,
|
||||
"is_deletable": True, # Extracted sounds can be deleted
|
||||
"is_music": True, # Assume extracted content is music
|
||||
"is_normalized": False,
|
||||
@@ -426,7 +659,11 @@ class ExtractionService:
|
||||
}
|
||||
|
||||
sound = await self.sound_repo.create(sound_data)
|
||||
logger.info("Created sound record with ID: %d", sound.id)
|
||||
logger.info(
|
||||
"Created sound record with ID: %d, thumbnail: %s",
|
||||
sound.id,
|
||||
thumbnail_path.name if thumbnail_path else "None",
|
||||
)
|
||||
|
||||
return sound
|
||||
|
||||
@@ -451,14 +688,17 @@ class ExtractionService:
|
||||
else:
|
||||
logger.info("Successfully normalized sound %d", sound_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error normalizing sound %d: %s", sound_id, e)
|
||||
except Exception:
|
||||
logger.exception("Error normalizing sound %d", sound_id)
|
||||
# Don't fail the extraction if normalization fails
|
||||
|
||||
async def _add_to_main_playlist(self, sound_id: int, user_id: int) -> None:
|
||||
"""Add the sound to the user's main playlist."""
|
||||
try:
|
||||
await self.playlist_service.add_sound_to_main_playlist(sound_id, user_id)
|
||||
await self.playlist_service._add_sound_to_main_playlist_internal( # noqa: SLF001
|
||||
sound_id,
|
||||
user_id,
|
||||
)
|
||||
logger.info(
|
||||
"Added sound %d to main playlist for user %d",
|
||||
sound_id,
|
||||
@@ -473,12 +713,31 @@ class ExtractionService:
|
||||
)
|
||||
# Don't fail the extraction if playlist addition fails
|
||||
|
||||
async def _emit_extraction_event(self, user_id: int, data: dict) -> None:
|
||||
"""Emit WebSocket event for extraction status updates to all users."""
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from app.services.socket import socket_manager # noqa: PLC0415
|
||||
|
||||
await socket_manager.broadcast_to_all("extraction_status_update", data)
|
||||
logger.debug(
|
||||
"Broadcasted extraction event (initiated by user %d): %s",
|
||||
user_id,
|
||||
data["status"],
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to emit extraction event")
|
||||
|
||||
async def get_extraction_by_id(self, extraction_id: int) -> ExtractionInfo | None:
|
||||
"""Get extraction information by ID."""
|
||||
extraction = await self.extraction_repo.get_by_id(extraction_id)
|
||||
if not extraction:
|
||||
return None
|
||||
|
||||
# Get user information
|
||||
user = await self.user_repo.get_by_id(extraction.user_id)
|
||||
user_name = user.name if user else None
|
||||
|
||||
return {
|
||||
"id": extraction.id or 0, # Should never be None for existing extraction
|
||||
"url": extraction.url,
|
||||
@@ -488,13 +747,38 @@ class ExtractionService:
|
||||
"status": extraction.status,
|
||||
"error": extraction.error,
|
||||
"sound_id": extraction.sound_id,
|
||||
"user_id": extraction.user_id,
|
||||
"user_name": user_name,
|
||||
"created_at": extraction.created_at.isoformat(),
|
||||
"updated_at": extraction.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
async def get_user_extractions(self, user_id: int) -> list[ExtractionInfo]:
|
||||
"""Get all extractions for a user."""
|
||||
extractions = await self.extraction_repo.get_by_user(user_id)
|
||||
async def get_user_extractions( # noqa: PLR0913
|
||||
self,
|
||||
user_id: int,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
status_filter: str | None = None,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
) -> PaginatedExtractionsResponse:
|
||||
"""Get all extractions for a user with filtering, search, and sorting."""
|
||||
offset = (page - 1) * limit
|
||||
(
|
||||
extraction_user_tuples,
|
||||
total_count,
|
||||
) = await self.extraction_repo.get_user_extractions_filtered(
|
||||
user_id=user_id,
|
||||
search=search,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
status_filter=status_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return [
|
||||
extractions = [
|
||||
{
|
||||
"id": extraction.id
|
||||
or 0, # Should never be None for existing extraction
|
||||
@@ -505,13 +789,79 @@ class ExtractionService:
|
||||
"status": extraction.status,
|
||||
"error": extraction.error,
|
||||
"sound_id": extraction.sound_id,
|
||||
"user_id": extraction.user_id,
|
||||
"user_name": user.name,
|
||||
"created_at": extraction.created_at.isoformat(),
|
||||
"updated_at": extraction.updated_at.isoformat(),
|
||||
}
|
||||
for extraction in extractions
|
||||
for extraction, user in extraction_user_tuples
|
||||
]
|
||||
|
||||
total_pages = (total_count + limit - 1) // limit # Ceiling division
|
||||
|
||||
return {
|
||||
"extractions": extractions,
|
||||
"total": total_count,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"total_pages": total_pages,
|
||||
}
|
||||
|
||||
async def get_all_extractions( # noqa: PLR0913
|
||||
self,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
status_filter: str | None = None,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
) -> PaginatedExtractionsResponse:
|
||||
"""Get all extractions with filtering, search, and sorting."""
|
||||
offset = (page - 1) * limit
|
||||
(
|
||||
extraction_user_tuples,
|
||||
total_count,
|
||||
) = await self.extraction_repo.get_all_extractions_filtered(
|
||||
search=search,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
status_filter=status_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
extractions = [
|
||||
{
|
||||
"id": extraction.id
|
||||
or 0, # Should never be None for existing extraction
|
||||
"url": extraction.url,
|
||||
"service": extraction.service,
|
||||
"service_id": extraction.service_id,
|
||||
"title": extraction.title,
|
||||
"status": extraction.status,
|
||||
"error": extraction.error,
|
||||
"sound_id": extraction.sound_id,
|
||||
"user_id": extraction.user_id,
|
||||
"user_name": user.name,
|
||||
"created_at": extraction.created_at.isoformat(),
|
||||
"updated_at": extraction.updated_at.isoformat(),
|
||||
}
|
||||
for extraction, user in extraction_user_tuples
|
||||
]
|
||||
|
||||
total_pages = (total_count + limit - 1) // limit # Ceiling division
|
||||
|
||||
return {
|
||||
"extractions": extractions,
|
||||
"total": total_count,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"total_pages": total_pages,
|
||||
}
|
||||
|
||||
async def get_pending_extractions(self) -> list[ExtractionInfo]:
|
||||
"""Get all pending extractions."""
|
||||
extractions = await self.extraction_repo.get_pending_extractions()
|
||||
extraction_user_tuples = await self.extraction_repo.get_pending_extractions()
|
||||
|
||||
return [
|
||||
{
|
||||
@@ -524,6 +874,181 @@ class ExtractionService:
|
||||
"status": extraction.status,
|
||||
"error": extraction.error,
|
||||
"sound_id": extraction.sound_id,
|
||||
"user_id": extraction.user_id,
|
||||
"user_name": user.name,
|
||||
"created_at": extraction.created_at.isoformat(),
|
||||
"updated_at": extraction.updated_at.isoformat(),
|
||||
}
|
||||
for extraction in extractions
|
||||
for extraction, user in extraction_user_tuples
|
||||
]
|
||||
|
||||
async def delete_extraction(
|
||||
self,
|
||||
extraction_id: int,
|
||||
user_id: int | None = None,
|
||||
) -> bool:
|
||||
"""Delete an extraction and its associated sound and files.
|
||||
|
||||
Args:
|
||||
extraction_id: The ID of the extraction to delete
|
||||
user_id: Optional user ID for ownership verification (None for admin)
|
||||
|
||||
Returns:
|
||||
True if deletion was successful, False if extraction not found
|
||||
|
||||
Raises:
|
||||
ValueError: If user doesn't own the extraction (when user_id is provided)
|
||||
|
||||
"""
|
||||
logger.info(
|
||||
"Deleting extraction: %d (user: %s)",
|
||||
extraction_id,
|
||||
user_id or "admin",
|
||||
)
|
||||
|
||||
# Get the extraction record
|
||||
extraction = await self.extraction_repo.get_by_id(extraction_id)
|
||||
if not extraction:
|
||||
logger.warning("Extraction %d not found", extraction_id)
|
||||
return False
|
||||
|
||||
# Check ownership if user_id is provided (non-admin request)
|
||||
if user_id is not None and extraction.user_id != user_id:
|
||||
msg = "You don't have permission to delete this extraction"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Get associated sound if it exists and capture its attributes immediately
|
||||
sound_data = None
|
||||
sound_object = None
|
||||
if extraction.sound_id:
|
||||
sound_object = await self.sound_repo.get_by_id(extraction.sound_id)
|
||||
if sound_object:
|
||||
# Capture attributes immediately while session is valid
|
||||
sound_data = {
|
||||
"id": sound_object.id,
|
||||
"type": sound_object.type,
|
||||
"filename": sound_object.filename,
|
||||
"is_normalized": sound_object.is_normalized,
|
||||
"normalized_filename": sound_object.normalized_filename,
|
||||
"thumbnail": sound_object.thumbnail,
|
||||
}
|
||||
|
||||
try:
|
||||
# Delete the extraction record first
|
||||
await self.extraction_repo.delete(extraction)
|
||||
logger.info("Deleted extraction record: %d", extraction_id)
|
||||
|
||||
# Check if sound was in current playlist before deletion
|
||||
sound_was_in_current_playlist = False
|
||||
if sound_object and sound_data:
|
||||
sound_was_in_current_playlist = (
|
||||
await self._check_sound_in_current_playlist(sound_data["id"])
|
||||
)
|
||||
|
||||
# If there's an associated sound, delete it and its files
|
||||
if sound_object and sound_data:
|
||||
await self._delete_sound_and_files(sound_object, sound_data)
|
||||
logger.info(
|
||||
"Deleted associated sound: %d (%s)",
|
||||
sound_data["id"],
|
||||
sound_data["filename"],
|
||||
)
|
||||
|
||||
# Commit the transaction
|
||||
await self.session.commit()
|
||||
|
||||
# Reload player playlist if deleted sound was in current playlist
|
||||
if sound_was_in_current_playlist and sound_data:
|
||||
await self._reload_player_playlist()
|
||||
logger.info(
|
||||
"Reloaded player playlist after deleting sound %d "
|
||||
"from current playlist",
|
||||
sound_data["id"],
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# Rollback on any error
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to delete extraction %d", extraction_id)
|
||||
raise
|
||||
else:
|
||||
return True
|
||||
|
||||
async def _delete_sound_and_files(
|
||||
self,
|
||||
sound: Sound,
|
||||
sound_data: dict[str, Any],
|
||||
) -> None:
|
||||
"""Delete a sound record and all its associated files."""
|
||||
# Collect all file paths to delete using captured attributes
|
||||
files_to_delete = []
|
||||
|
||||
# Original audio file
|
||||
if sound_data["type"] == "EXT": # Extracted sounds
|
||||
original_path = Path("sounds/originals/extracted") / sound_data["filename"]
|
||||
if original_path.exists():
|
||||
files_to_delete.append(original_path)
|
||||
|
||||
# Normalized file
|
||||
if sound_data["is_normalized"] and sound_data["normalized_filename"]:
|
||||
normalized_path = (
|
||||
Path("sounds/normalized/extracted") / sound_data["normalized_filename"]
|
||||
)
|
||||
if normalized_path.exists():
|
||||
files_to_delete.append(normalized_path)
|
||||
|
||||
# Thumbnail file
|
||||
if sound_data["thumbnail"]:
|
||||
thumbnail_path = (
|
||||
Path(settings.EXTRACTION_THUMBNAILS_DIR) / sound_data["thumbnail"]
|
||||
)
|
||||
if thumbnail_path.exists():
|
||||
files_to_delete.append(thumbnail_path)
|
||||
|
||||
# Delete the sound from database first
|
||||
await self.sound_repo.delete(sound)
|
||||
|
||||
# Delete all associated files
|
||||
for file_path in files_to_delete:
|
||||
try:
|
||||
file_path.unlink()
|
||||
logger.info("Deleted file: %s", file_path)
|
||||
except OSError:
|
||||
logger.exception("Failed to delete file %s", file_path)
|
||||
# Continue with other files even if one fails
|
||||
|
||||
async def _check_sound_in_current_playlist(self, sound_id: int) -> bool:
|
||||
"""Check if a sound is in the current playlist."""
|
||||
try:
|
||||
from app.repositories.playlist import PlaylistRepository # noqa: PLC0415
|
||||
|
||||
playlist_repo = PlaylistRepository(self.session)
|
||||
current_playlist = await playlist_repo.get_current_playlist()
|
||||
|
||||
if not current_playlist or not current_playlist.id:
|
||||
return False
|
||||
|
||||
return await playlist_repo.is_sound_in_playlist(
|
||||
current_playlist.id, sound_id,
|
||||
)
|
||||
except (ImportError, AttributeError, ValueError, RuntimeError) as e:
|
||||
logger.warning(
|
||||
"Failed to check if sound %s is in current playlist: %s",
|
||||
sound_id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
return False
|
||||
|
||||
async def _reload_player_playlist(self) -> None:
|
||||
"""Reload the player playlist after a sound is deleted."""
|
||||
try:
|
||||
# Import here to avoid circular import issues
|
||||
from app.services.player import get_player_service # noqa: PLC0415
|
||||
|
||||
player = get_player_service()
|
||||
await player.reload_playlist()
|
||||
logger.debug("Player playlist reloaded after sound deletion")
|
||||
except (ImportError, AttributeError, ValueError, RuntimeError) as e:
|
||||
# Don't fail the deletion operation if player reload fails
|
||||
logger.warning("Failed to reload player playlist: %s", e, exc_info=True)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Background extraction processor for handling extraction queue."""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -34,6 +35,9 @@ class ExtractionProcessor:
|
||||
logger.warning("Extraction processor is already running")
|
||||
return
|
||||
|
||||
# Reset any stuck extractions from previous runs
|
||||
await self._reset_stuck_extractions()
|
||||
|
||||
self.shutdown_event.clear()
|
||||
self.processor_task = asyncio.create_task(self._process_queue())
|
||||
logger.info("Started extraction processor")
|
||||
@@ -46,15 +50,13 @@ class ExtractionProcessor:
|
||||
if self.processor_task and not self.processor_task.done():
|
||||
try:
|
||||
await asyncio.wait_for(self.processor_task, timeout=30.0)
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"Extraction processor did not stop gracefully, cancelling..."
|
||||
"Extraction processor did not stop gracefully, cancelling...",
|
||||
)
|
||||
self.processor_task.cancel()
|
||||
try:
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self.processor_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("Extraction processor stopped")
|
||||
|
||||
@@ -66,7 +68,8 @@ class ExtractionProcessor:
|
||||
# The processor will pick it up on the next cycle
|
||||
else:
|
||||
logger.warning(
|
||||
"Extraction %d is already being processed", extraction_id
|
||||
"Extraction %d is already being processed",
|
||||
extraction_id,
|
||||
)
|
||||
|
||||
async def _process_queue(self) -> None:
|
||||
@@ -81,16 +84,16 @@ class ExtractionProcessor:
|
||||
try:
|
||||
await asyncio.wait_for(self.shutdown_event.wait(), timeout=5.0)
|
||||
break # Shutdown requested
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
continue # Continue processing
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error in extraction queue processor: %s", e)
|
||||
except Exception:
|
||||
logger.exception("Error in extraction queue processor")
|
||||
# Wait a bit before retrying to avoid tight error loops
|
||||
try:
|
||||
await asyncio.wait_for(self.shutdown_event.wait(), timeout=10.0)
|
||||
break # Shutdown requested
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
logger.info("Extraction queue processor stopped")
|
||||
@@ -125,13 +128,13 @@ class ExtractionProcessor:
|
||||
|
||||
# Start processing this extraction in the background
|
||||
task = asyncio.create_task(
|
||||
self._process_single_extraction(extraction_id)
|
||||
self._process_single_extraction(extraction_id),
|
||||
)
|
||||
task.add_done_callback(
|
||||
lambda t, eid=extraction_id: self._on_extraction_completed(
|
||||
eid,
|
||||
t,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -156,8 +159,8 @@ class ExtractionProcessor:
|
||||
result["status"],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error processing extraction %d: %s", extraction_id, e)
|
||||
except Exception:
|
||||
logger.exception("Error processing extraction %d", extraction_id)
|
||||
|
||||
def _on_extraction_completed(self, extraction_id: int, task: asyncio.Task) -> None:
|
||||
"""Handle completion of an extraction task."""
|
||||
@@ -179,6 +182,47 @@ class ExtractionProcessor:
|
||||
self.max_concurrent,
|
||||
)
|
||||
|
||||
async def _reset_stuck_extractions(self) -> None:
|
||||
"""Reset any extractions stuck in 'processing' status back to 'pending'."""
|
||||
try:
|
||||
async with AsyncSession(engine) as session:
|
||||
extraction_service = ExtractionService(session)
|
||||
|
||||
# Get all extractions stuck in processing status
|
||||
stuck_extractions = (
|
||||
await extraction_service.extraction_repo.get_by_status("processing")
|
||||
)
|
||||
|
||||
if not stuck_extractions:
|
||||
logger.info("No stuck extractions found to reset")
|
||||
return
|
||||
|
||||
reset_count = 0
|
||||
for extraction in stuck_extractions:
|
||||
try:
|
||||
await extraction_service.extraction_repo.update(
|
||||
extraction, {"status": "pending", "error": None},
|
||||
)
|
||||
reset_count += 1
|
||||
logger.info(
|
||||
"Reset stuck extraction %d from processing to pending",
|
||||
extraction.id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to reset extraction %d", extraction.id,
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
logger.info(
|
||||
"Successfully reset %d stuck extractions from processing to "
|
||||
"pending",
|
||||
reset_count,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to reset stuck extractions")
|
||||
|
||||
def get_status(self) -> dict:
|
||||
"""Get the current status of the extraction processor."""
|
||||
return {
|
||||
|
||||
382
app/services/favorite.py
Normal file
382
app/services/favorite.py
Normal file
@@ -0,0 +1,382 @@
|
||||
"""Service for managing user favorites."""
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.favorite import Favorite
|
||||
from app.repositories.favorite import FavoriteRepository
|
||||
from app.repositories.playlist import PlaylistRepository
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.repositories.user import UserRepository
|
||||
from app.services.socket import socket_manager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class FavoriteService:
|
||||
"""Service for managing user favorites."""
|
||||
|
||||
def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
|
||||
"""Initialize the favorite service.
|
||||
|
||||
Args:
|
||||
db_session_factory: Factory function to create database sessions
|
||||
|
||||
"""
|
||||
self.db_session_factory = db_session_factory
|
||||
|
||||
async def add_sound_favorite(self, user_id: int, sound_id: int) -> Favorite:
|
||||
"""Add a sound to user's favorites.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
sound_id: The sound ID
|
||||
|
||||
Returns:
|
||||
The created favorite
|
||||
|
||||
Raises:
|
||||
ValueError: If user or sound not found, or already favorited
|
||||
|
||||
"""
|
||||
async with self.db_session_factory() as session:
|
||||
favorite_repo = FavoriteRepository(session)
|
||||
user_repo = UserRepository(session)
|
||||
sound_repo = SoundRepository(session)
|
||||
|
||||
# Verify user exists
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User with ID {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Verify sound exists
|
||||
sound = await sound_repo.get_by_id(sound_id)
|
||||
if not sound:
|
||||
msg = f"Sound with ID {sound_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Get data for the event immediately after loading
|
||||
sound_name = sound.name
|
||||
user_name = user.name
|
||||
|
||||
# Check if already favorited
|
||||
existing = await favorite_repo.get_by_user_and_sound(user_id, sound_id)
|
||||
if existing:
|
||||
msg = f"Sound {sound_id} is already favorited by user {user_id}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Create favorite
|
||||
favorite_data = {
|
||||
"user_id": user_id,
|
||||
"sound_id": sound_id,
|
||||
"playlist_id": None,
|
||||
}
|
||||
favorite = await favorite_repo.create(favorite_data)
|
||||
logger.info("User %s favorited sound %s", user_id, sound_id)
|
||||
|
||||
# Get updated favorite count within the same session
|
||||
favorite_count = await favorite_repo.count_sound_favorites(sound_id)
|
||||
|
||||
# Emit sound_favorited event via WebSocket (outside the session)
|
||||
try:
|
||||
event_data = {
|
||||
"sound_id": sound_id,
|
||||
"sound_name": sound_name,
|
||||
"user_id": user_id,
|
||||
"user_name": user_name,
|
||||
"favorite_count": favorite_count,
|
||||
}
|
||||
await socket_manager.broadcast_to_all("sound_favorited", event_data)
|
||||
logger.info("Broadcasted sound_favorited event for sound %s", sound_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to broadcast sound_favorited event for sound %s",
|
||||
sound_id,
|
||||
)
|
||||
|
||||
return favorite
|
||||
|
||||
async def add_playlist_favorite(self, user_id: int, playlist_id: int) -> Favorite:
|
||||
"""Add a playlist to user's favorites.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
playlist_id: The playlist ID
|
||||
|
||||
Returns:
|
||||
The created favorite
|
||||
|
||||
Raises:
|
||||
ValueError: If user or playlist not found, or already favorited
|
||||
|
||||
"""
|
||||
async with self.db_session_factory() as session:
|
||||
favorite_repo = FavoriteRepository(session)
|
||||
user_repo = UserRepository(session)
|
||||
playlist_repo = PlaylistRepository(session)
|
||||
|
||||
# Verify user exists
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User with ID {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Verify playlist exists
|
||||
playlist = await playlist_repo.get_by_id(playlist_id)
|
||||
if not playlist:
|
||||
msg = f"Playlist with ID {playlist_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Check if already favorited
|
||||
existing = await favorite_repo.get_by_user_and_playlist(
|
||||
user_id,
|
||||
playlist_id,
|
||||
)
|
||||
if existing:
|
||||
msg = f"Playlist {playlist_id} is already favorited by user {user_id}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Create favorite
|
||||
favorite_data = {
|
||||
"user_id": user_id,
|
||||
"sound_id": None,
|
||||
"playlist_id": playlist_id,
|
||||
}
|
||||
favorite = await favorite_repo.create(favorite_data)
|
||||
logger.info("User %s favorited playlist %s", user_id, playlist_id)
|
||||
return favorite
|
||||
|
||||
async def remove_sound_favorite(self, user_id: int, sound_id: int) -> None:
|
||||
"""Remove a sound from user's favorites.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
sound_id: The sound ID
|
||||
|
||||
Raises:
|
||||
ValueError: If favorite not found
|
||||
|
||||
"""
|
||||
async with self.db_session_factory() as session:
|
||||
favorite_repo = FavoriteRepository(session)
|
||||
|
||||
favorite = await favorite_repo.get_by_user_and_sound(user_id, sound_id)
|
||||
if not favorite:
|
||||
msg = f"Sound {sound_id} is not favorited by user {user_id}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Get user and sound info before deletion for the event
|
||||
user_repo = UserRepository(session)
|
||||
sound_repo = SoundRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
sound = await sound_repo.get_by_id(sound_id)
|
||||
|
||||
# Get data for the event immediately after loading
|
||||
sound_name = sound.name if sound else "Unknown"
|
||||
user_name = user.name if user else "Unknown"
|
||||
|
||||
await favorite_repo.delete(favorite)
|
||||
logger.info("User %s removed sound %s from favorites", user_id, sound_id)
|
||||
|
||||
# Get updated favorite count after deletion within the same session
|
||||
favorite_count = await favorite_repo.count_sound_favorites(sound_id)
|
||||
|
||||
# Emit sound_favorited event via WebSocket (outside the session)
|
||||
try:
|
||||
event_data = {
|
||||
"sound_id": sound_id,
|
||||
"sound_name": sound_name,
|
||||
"user_id": user_id,
|
||||
"user_name": user_name,
|
||||
"favorite_count": favorite_count,
|
||||
}
|
||||
await socket_manager.broadcast_to_all("sound_favorited", event_data)
|
||||
logger.info(
|
||||
"Broadcasted sound_favorited event for sound %s removal",
|
||||
sound_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to broadcast sound_favorited event for sound %s removal",
|
||||
sound_id,
|
||||
)
|
||||
|
||||
async def remove_playlist_favorite(self, user_id: int, playlist_id: int) -> None:
|
||||
"""Remove a playlist from user's favorites.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
playlist_id: The playlist ID
|
||||
|
||||
Raises:
|
||||
ValueError: If favorite not found
|
||||
|
||||
"""
|
||||
async with self.db_session_factory() as session:
|
||||
favorite_repo = FavoriteRepository(session)
|
||||
|
||||
favorite = await favorite_repo.get_by_user_and_playlist(
|
||||
user_id,
|
||||
playlist_id,
|
||||
)
|
||||
if not favorite:
|
||||
msg = f"Playlist {playlist_id} is not favorited by user {user_id}"
|
||||
raise ValueError(msg)
|
||||
|
||||
await favorite_repo.delete(favorite)
|
||||
logger.info(
|
||||
"User %s removed playlist %s from favorites",
|
||||
user_id,
|
||||
playlist_id,
|
||||
)
|
||||
|
||||
async def get_user_favorites(
|
||||
self,
|
||||
user_id: int,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[Favorite]:
|
||||
"""Get all favorites for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
limit: Maximum number of favorites to return
|
||||
offset: Number of favorites to skip
|
||||
|
||||
Returns:
|
||||
List of user favorites
|
||||
|
||||
"""
|
||||
async with self.db_session_factory() as session:
|
||||
favorite_repo = FavoriteRepository(session)
|
||||
return await favorite_repo.get_user_favorites(user_id, limit, offset)
|
||||
|
||||
async def get_user_sound_favorites(
|
||||
self,
|
||||
user_id: int,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[Favorite]:
|
||||
"""Get sound favorites for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
limit: Maximum number of favorites to return
|
||||
offset: Number of favorites to skip
|
||||
|
||||
Returns:
|
||||
List of user sound favorites
|
||||
|
||||
"""
|
||||
async with self.db_session_factory() as session:
|
||||
favorite_repo = FavoriteRepository(session)
|
||||
return await favorite_repo.get_user_sound_favorites(user_id, limit, offset)
|
||||
|
||||
async def get_user_playlist_favorites(
|
||||
self,
|
||||
user_id: int,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[Favorite]:
|
||||
"""Get playlist favorites for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
limit: Maximum number of favorites to return
|
||||
offset: Number of favorites to skip
|
||||
|
||||
Returns:
|
||||
List of user playlist favorites
|
||||
|
||||
"""
|
||||
async with self.db_session_factory() as session:
|
||||
favorite_repo = FavoriteRepository(session)
|
||||
return await favorite_repo.get_user_playlist_favorites(
|
||||
user_id,
|
||||
limit,
|
||||
offset,
|
||||
)
|
||||
|
||||
async def is_sound_favorited(self, user_id: int, sound_id: int) -> bool:
|
||||
"""Check if a sound is favorited by a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
sound_id: The sound ID
|
||||
|
||||
Returns:
|
||||
True if the sound is favorited, False otherwise
|
||||
|
||||
"""
|
||||
async with self.db_session_factory() as session:
|
||||
favorite_repo = FavoriteRepository(session)
|
||||
return await favorite_repo.is_sound_favorited(user_id, sound_id)
|
||||
|
||||
async def is_playlist_favorited(self, user_id: int, playlist_id: int) -> bool:
|
||||
"""Check if a playlist is favorited by a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
playlist_id: The playlist ID
|
||||
|
||||
Returns:
|
||||
True if the playlist is favorited, False otherwise
|
||||
|
||||
"""
|
||||
async with self.db_session_factory() as session:
|
||||
favorite_repo = FavoriteRepository(session)
|
||||
return await favorite_repo.is_playlist_favorited(user_id, playlist_id)
|
||||
|
||||
async def get_favorite_counts(self, user_id: int) -> dict[str, int]:
|
||||
"""Get favorite counts for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
|
||||
Returns:
|
||||
Dictionary with favorite counts
|
||||
|
||||
"""
|
||||
async with self.db_session_factory() as session:
|
||||
favorite_repo = FavoriteRepository(session)
|
||||
|
||||
total = await favorite_repo.count_user_favorites(user_id)
|
||||
sounds = len(await favorite_repo.get_user_sound_favorites(user_id))
|
||||
playlists = len(await favorite_repo.get_user_playlist_favorites(user_id))
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"sounds": sounds,
|
||||
"playlists": playlists,
|
||||
}
|
||||
|
||||
async def get_sound_favorite_count(self, sound_id: int) -> int:
|
||||
"""Get the number of users who have favorited a sound.
|
||||
|
||||
Args:
|
||||
sound_id: The sound ID
|
||||
|
||||
Returns:
|
||||
Number of users who favorited this sound
|
||||
|
||||
"""
|
||||
async with self.db_session_factory() as session:
|
||||
favorite_repo = FavoriteRepository(session)
|
||||
return await favorite_repo.count_sound_favorites(sound_id)
|
||||
|
||||
async def get_playlist_favorite_count(self, playlist_id: int) -> int:
|
||||
"""Get the number of users who have favorited a playlist.
|
||||
|
||||
Args:
|
||||
playlist_id: The playlist ID
|
||||
|
||||
Returns:
|
||||
Number of users who favorited this playlist
|
||||
|
||||
"""
|
||||
async with self.db_session_factory() as session:
|
||||
favorite_repo = FavoriteRepository(session)
|
||||
return await favorite_repo.count_playlist_favorites(playlist_id)
|
||||
@@ -70,7 +70,7 @@ class OAuthProvider(ABC):
|
||||
"""Generate authorization URL with state parameter."""
|
||||
# Construct provider-specific redirect URI
|
||||
redirect_uri = (
|
||||
f"http://localhost:8000/api/v1/auth/{self.provider_name}/callback"
|
||||
f"{settings.BACKEND_URL}/api/v1/auth/{self.provider_name}/callback"
|
||||
)
|
||||
|
||||
params = {
|
||||
@@ -86,7 +86,7 @@ class OAuthProvider(ABC):
|
||||
"""Exchange authorization code for access token."""
|
||||
# Construct provider-specific redirect URI (must match authorization request)
|
||||
redirect_uri = (
|
||||
f"http://localhost:8000/api/v1/auth/{self.provider_name}/callback"
|
||||
f"{settings.BACKEND_URL}/api/v1/auth/{self.provider_name}/callback"
|
||||
)
|
||||
|
||||
data = {
|
||||
@@ -150,7 +150,7 @@ class GoogleOAuthProvider(OAuthProvider):
|
||||
"""Exchange authorization code for access token."""
|
||||
# Construct provider-specific redirect URI (must match authorization request)
|
||||
redirect_uri = (
|
||||
f"http://localhost:8000/api/v1/auth/{self.provider_name}/callback"
|
||||
f"{settings.BACKEND_URL}/api/v1/auth/{self.provider_name}/callback"
|
||||
)
|
||||
|
||||
data = {
|
||||
|
||||
@@ -8,6 +8,8 @@ from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import vlc # type: ignore[import-untyped]
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.playlist import Playlist
|
||||
@@ -16,6 +18,7 @@ from app.models.sound_played import SoundPlayed
|
||||
from app.repositories.playlist import PlaylistRepository
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.services.socket import socket_manager
|
||||
from app.services.volume import volume_service
|
||||
from app.utils.audio import get_sound_file_path
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -46,7 +49,11 @@ class PlayerState:
|
||||
"""Initialize player state."""
|
||||
self.status: PlayerStatus = PlayerStatus.STOPPED
|
||||
self.mode: PlayerMode = PlayerMode.CONTINUOUS
|
||||
self.volume: int = 50
|
||||
|
||||
# Initialize volume from host system or default to 80
|
||||
host_volume = volume_service.get_volume()
|
||||
self.volume: int = host_volume if host_volume is not None else 80
|
||||
self.previous_volume: int = self.volume
|
||||
self.current_sound_id: int | None = None
|
||||
self.current_sound_index: int | None = None
|
||||
self.current_sound_position: int = 0
|
||||
@@ -57,24 +64,35 @@ class PlayerState:
|
||||
self.playlist_length: int = 0
|
||||
self.playlist_duration: int = 0
|
||||
self.playlist_sounds: list[Sound] = []
|
||||
self.play_next_queue: list[Sound] = []
|
||||
self.playlist_index_before_play_next: int | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert player state to dictionary for serialization."""
|
||||
return {
|
||||
"status": self.status.value,
|
||||
"mode": self.mode.value,
|
||||
"mode": self.mode.value if isinstance(self.mode, PlayerMode) else self.mode,
|
||||
"volume": self.volume,
|
||||
"current_sound_id": self.current_sound_id,
|
||||
"current_sound_index": self.current_sound_index,
|
||||
"current_sound_position": self.current_sound_position,
|
||||
"current_sound_duration": self.current_sound_duration,
|
||||
"previous_volume": self.previous_volume,
|
||||
"position": self.current_sound_position or 0,
|
||||
"duration": self.current_sound_duration,
|
||||
"index": self.current_sound_index,
|
||||
"current_sound": self._serialize_sound(self.current_sound),
|
||||
"playlist_id": self.playlist_id,
|
||||
"playlist_name": self.playlist_name,
|
||||
"playlist_length": self.playlist_length,
|
||||
"playlist_duration": self.playlist_duration,
|
||||
"playlist_sounds": [
|
||||
self._serialize_sound(sound) for sound in self.playlist_sounds
|
||||
"playlist": (
|
||||
{
|
||||
"id": self.playlist_id,
|
||||
"name": self.playlist_name,
|
||||
"length": self.playlist_length,
|
||||
"duration": self.playlist_duration,
|
||||
"sounds": [
|
||||
self._serialize_sound(sound) for sound in self.playlist_sounds
|
||||
],
|
||||
}
|
||||
if self.playlist_id
|
||||
else None
|
||||
),
|
||||
"play_next_queue": [
|
||||
self._serialize_sound(sound) for sound in self.play_next_queue
|
||||
],
|
||||
}
|
||||
|
||||
@@ -82,6 +100,14 @@ class PlayerState:
|
||||
"""Serialize a sound object for JSON serialization."""
|
||||
if not sound:
|
||||
return None
|
||||
|
||||
# Get extraction URL if sound is linked to an extraction
|
||||
extract_url = None
|
||||
if hasattr(sound, "extractions") and sound.extractions:
|
||||
# Get the first extraction (there should only be one per sound)
|
||||
extraction = sound.extractions[0]
|
||||
extract_url = extraction.url
|
||||
|
||||
return {
|
||||
"id": sound.id,
|
||||
"name": sound.name,
|
||||
@@ -91,6 +117,7 @@ class PlayerState:
|
||||
"type": sound.type,
|
||||
"thumbnail": sound.thumbnail,
|
||||
"play_count": sound.play_count,
|
||||
"extract_url": extract_url,
|
||||
}
|
||||
|
||||
|
||||
@@ -102,6 +129,14 @@ class PlayerService:
|
||||
self.db_session_factory = db_session_factory
|
||||
self.state = PlayerState()
|
||||
self._vlc_instance = vlc.Instance()
|
||||
|
||||
if self._vlc_instance is None:
|
||||
msg = (
|
||||
"VLC instance could not be created. "
|
||||
"Ensure VLC is installed and accessible."
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
self._player = self._vlc_instance.media_player_new()
|
||||
self._is_running = False
|
||||
self._position_thread: threading.Thread | None = None
|
||||
@@ -124,12 +159,13 @@ class PlayerService:
|
||||
|
||||
# Start position tracking thread
|
||||
self._position_thread = threading.Thread(
|
||||
target=self._position_tracker, daemon=True,
|
||||
target=self._position_tracker,
|
||||
daemon=True,
|
||||
)
|
||||
self._position_thread.start()
|
||||
|
||||
# Set initial volume
|
||||
self._player.audio_set_volume(self.state.volume)
|
||||
# Set VLC to 100% volume - host volume is controlled separately
|
||||
self._player.audio_set_volume(100)
|
||||
|
||||
logger.info("Player service started")
|
||||
|
||||
@@ -152,83 +188,139 @@ class PlayerService:
|
||||
|
||||
async def play(self, index: int | None = None) -> None:
|
||||
"""Play audio at specified index or current position."""
|
||||
# Check if we're resuming from pause
|
||||
is_resuming = (
|
||||
index is None and
|
||||
self.state.status == PlayerStatus.PAUSED and
|
||||
self.state.current_sound is not None
|
||||
)
|
||||
|
||||
if is_resuming:
|
||||
# Simply resume playback
|
||||
result = self._player.play()
|
||||
if result == 0: # VLC returns 0 on success
|
||||
self.state.status = PlayerStatus.PLAYING
|
||||
|
||||
# Ensure play time tracking is initialized for resumed track
|
||||
if (
|
||||
self.state.current_sound_id
|
||||
and self.state.current_sound_id not in self._play_time_tracking
|
||||
):
|
||||
self._play_time_tracking[self.state.current_sound_id] = {
|
||||
"total_time": 0,
|
||||
"last_position": self.state.current_sound_position,
|
||||
"last_update": time.time(),
|
||||
"threshold_reached": False,
|
||||
}
|
||||
|
||||
await self._broadcast_state()
|
||||
logger.info("Resumed playing sound: %s", self.state.current_sound.name)
|
||||
else:
|
||||
logger.error("Failed to resume playback: VLC error code %s", result)
|
||||
if self._should_resume_playback(index):
|
||||
await self._resume_playback()
|
||||
return
|
||||
|
||||
# Starting new track or changing track
|
||||
if index is not None:
|
||||
if index < 0 or index >= len(self.state.playlist_sounds):
|
||||
msg = "Invalid sound index"
|
||||
raise ValueError(msg)
|
||||
self.state.current_sound_index = index
|
||||
self.state.current_sound = self.state.playlist_sounds[index]
|
||||
self.state.current_sound_id = self.state.current_sound.id
|
||||
await self._start_new_track(index)
|
||||
|
||||
def _should_resume_playback(self, index: int | None) -> bool:
|
||||
"""Check if we should resume paused playback."""
|
||||
return (
|
||||
index is None
|
||||
and self.state.status == PlayerStatus.PAUSED
|
||||
and self.state.current_sound is not None
|
||||
)
|
||||
|
||||
async def _resume_playback(self) -> None:
|
||||
"""Resume paused playback."""
|
||||
result = self._player.play()
|
||||
if result == 0: # VLC returns 0 on success
|
||||
self.state.status = PlayerStatus.PLAYING
|
||||
self._ensure_play_time_tracking_for_resume()
|
||||
await self._broadcast_state()
|
||||
|
||||
sound_name = (
|
||||
self.state.current_sound.name if self.state.current_sound else "Unknown"
|
||||
)
|
||||
logger.info("Resumed playing sound: %s", sound_name)
|
||||
else:
|
||||
logger.error("Failed to resume playback: VLC error code %s", result)
|
||||
|
||||
def _ensure_play_time_tracking_for_resume(self) -> None:
|
||||
"""Ensure play time tracking is initialized for resumed track."""
|
||||
if (
|
||||
self.state.current_sound_id
|
||||
and self.state.current_sound_id not in self._play_time_tracking
|
||||
):
|
||||
self._play_time_tracking[self.state.current_sound_id] = {
|
||||
"total_time": 0,
|
||||
"last_position": self.state.current_sound_position,
|
||||
"last_update": time.time(),
|
||||
"threshold_reached": False,
|
||||
}
|
||||
|
||||
async def _start_new_track(self, index: int | None) -> None:
|
||||
"""Start playing a new track."""
|
||||
if not self._prepare_sound_for_playback(index):
|
||||
return
|
||||
|
||||
if not self._load_and_play_media():
|
||||
return
|
||||
|
||||
await self._handle_successful_playback()
|
||||
|
||||
def _prepare_sound_for_playback(self, index: int | None) -> bool:
|
||||
"""Prepare sound for playback, return True if ready."""
|
||||
if index is not None and not self._set_sound_by_index(index):
|
||||
return False
|
||||
|
||||
if not self.state.current_sound:
|
||||
logger.warning("No sound to play")
|
||||
return
|
||||
return False
|
||||
|
||||
return self._validate_sound_file()
|
||||
|
||||
def _set_sound_by_index(self, index: int) -> bool:
|
||||
"""Set current sound by index, return True if valid."""
|
||||
if index < 0 or index >= len(self.state.playlist_sounds):
|
||||
msg = "Invalid sound index"
|
||||
raise ValueError(msg)
|
||||
|
||||
self.state.current_sound_index = index
|
||||
self.state.current_sound = self.state.playlist_sounds[index]
|
||||
self.state.current_sound_id = self.state.current_sound.id
|
||||
return True
|
||||
|
||||
def _validate_sound_file(self) -> bool:
|
||||
"""Validate sound file exists, return True if valid."""
|
||||
if not self.state.current_sound:
|
||||
return False
|
||||
|
||||
# Get sound file path
|
||||
sound_path = get_sound_file_path(self.state.current_sound)
|
||||
if not sound_path.exists():
|
||||
logger.error("Sound file not found: %s", sound_path)
|
||||
return
|
||||
return False
|
||||
return True
|
||||
|
||||
# Load and play media (new track)
|
||||
def _load_and_play_media(self) -> bool:
|
||||
"""Load media and start playback, return True if successful."""
|
||||
if self._vlc_instance is None:
|
||||
logger.error("VLC instance is not initialized. Cannot play media.")
|
||||
return False
|
||||
|
||||
if not self.state.current_sound:
|
||||
logger.error("No current sound to play")
|
||||
return False
|
||||
|
||||
sound_path = get_sound_file_path(self.state.current_sound)
|
||||
media = self._vlc_instance.media_new(str(sound_path))
|
||||
self._player.set_media(media)
|
||||
|
||||
result = self._player.play()
|
||||
if result == 0: # VLC returns 0 on success
|
||||
self.state.status = PlayerStatus.PLAYING
|
||||
self.state.current_sound_duration = self.state.current_sound.duration or 0
|
||||
|
||||
# Initialize play time tracking for new track
|
||||
if self.state.current_sound_id:
|
||||
self._play_time_tracking[self.state.current_sound_id] = {
|
||||
"total_time": 0,
|
||||
"last_position": 0,
|
||||
"last_update": time.time(),
|
||||
"threshold_reached": False,
|
||||
}
|
||||
logger.info(
|
||||
"Initialized play time tracking for sound %s (duration: %s ms)",
|
||||
self.state.current_sound_id,
|
||||
self.state.current_sound_duration,
|
||||
)
|
||||
|
||||
await self._broadcast_state()
|
||||
logger.info("Started playing sound: %s", self.state.current_sound.name)
|
||||
else:
|
||||
if result != 0: # VLC returns 0 on success
|
||||
logger.error("Failed to start playback: VLC error code %s", result)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _handle_successful_playback(self) -> None:
|
||||
"""Handle successful playback start."""
|
||||
if not self.state.current_sound:
|
||||
logger.error("No current sound for successful playback")
|
||||
return
|
||||
|
||||
self.state.status = PlayerStatus.PLAYING
|
||||
self.state.current_sound_duration = self.state.current_sound.duration or 0
|
||||
|
||||
self._initialize_play_time_tracking()
|
||||
await self._broadcast_state()
|
||||
logger.info("Started playing sound: %s", self.state.current_sound.name)
|
||||
|
||||
def _initialize_play_time_tracking(self) -> None:
|
||||
"""Initialize play time tracking for new track."""
|
||||
if self.state.current_sound_id:
|
||||
self._play_time_tracking[self.state.current_sound_id] = {
|
||||
"total_time": 0,
|
||||
"last_position": 0,
|
||||
"last_update": time.time(),
|
||||
"threshold_reached": False,
|
||||
}
|
||||
logger.info(
|
||||
"Initialized play time tracking for sound %s (duration: %s ms)",
|
||||
self.state.current_sound_id,
|
||||
self.state.current_sound_duration,
|
||||
)
|
||||
|
||||
async def pause(self) -> None:
|
||||
"""Pause playback."""
|
||||
@@ -257,6 +349,31 @@ class PlayerService:
|
||||
|
||||
async def next(self) -> None:
|
||||
"""Skip to next track."""
|
||||
# Check if there's a track in the play_next queue
|
||||
if self.state.play_next_queue:
|
||||
await self._play_next_from_queue()
|
||||
return
|
||||
|
||||
# If currently playing from play_next queue (no index but have stored index)
|
||||
if (
|
||||
self.state.current_sound_index is None
|
||||
and self.state.playlist_index_before_play_next is not None
|
||||
and self.state.playlist_sounds
|
||||
):
|
||||
# Skipped the last play_next track, go to next in playlist
|
||||
restored_index = self.state.playlist_index_before_play_next
|
||||
next_index = self._get_next_index(restored_index)
|
||||
|
||||
# Clear the stored index
|
||||
self.state.playlist_index_before_play_next = None
|
||||
|
||||
if next_index is not None:
|
||||
await self.play(next_index)
|
||||
else:
|
||||
await self._stop_playback()
|
||||
await self._broadcast_state()
|
||||
return
|
||||
|
||||
if not self.state.playlist_sounds:
|
||||
return
|
||||
|
||||
@@ -297,26 +414,121 @@ class PlayerService:
|
||||
logger.debug("Seeked to position: %sms", position_ms)
|
||||
|
||||
async def set_volume(self, volume: int) -> None:
|
||||
"""Set playback volume (0-100)."""
|
||||
"""Set playback volume (0-100) by controlling host system volume."""
|
||||
volume = max(0, min(100, volume)) # Clamp to valid range
|
||||
|
||||
# Store previous volume when muting (going from >0 to 0)
|
||||
if self.state.volume > 0 and volume == 0:
|
||||
self.state.previous_volume = self.state.volume
|
||||
|
||||
self.state.volume = volume
|
||||
self._player.audio_set_volume(volume)
|
||||
|
||||
# Control host system volume instead of VLC volume
|
||||
if volume == 0:
|
||||
# Mute the host system
|
||||
volume_service.set_mute(muted=True)
|
||||
else:
|
||||
# Unmute and set host volume
|
||||
if volume_service.is_muted():
|
||||
volume_service.set_mute(muted=False)
|
||||
volume_service.set_volume(volume)
|
||||
|
||||
# Keep VLC at 100% volume
|
||||
self._player.audio_set_volume(100)
|
||||
|
||||
await self._broadcast_state()
|
||||
logger.debug("Volume set to: %s", volume)
|
||||
logger.debug("Host volume set to: %s", volume)
|
||||
|
||||
async def set_mode(self, mode: PlayerMode) -> None:
|
||||
async def mute(self) -> None:
|
||||
"""Mute the host system (stores current volume as previous_volume)."""
|
||||
if self.state.volume > 0:
|
||||
await self.set_volume(0)
|
||||
|
||||
async def unmute(self) -> None:
|
||||
"""Unmute the host system (restores previous_volume)."""
|
||||
if self.state.volume == 0 and self.state.previous_volume > 0:
|
||||
await self.set_volume(self.state.previous_volume)
|
||||
|
||||
async def set_mode(self, mode: PlayerMode | str) -> None:
|
||||
"""Set playback mode."""
|
||||
if isinstance(mode, str):
|
||||
# Convert string to PlayerMode enum
|
||||
try:
|
||||
mode = PlayerMode(mode)
|
||||
except ValueError:
|
||||
logger.exception("Invalid player mode: %s", mode)
|
||||
return
|
||||
|
||||
self.state.mode = mode
|
||||
await self._broadcast_state()
|
||||
logger.info("Playback mode set to: %s", mode.value)
|
||||
|
||||
async def add_to_play_next(self, sound_id: int) -> None:
|
||||
"""Add a sound to the play_next queue."""
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
# Eagerly load extractions to avoid lazy loading issues
|
||||
statement = select(Sound).where(Sound.id == sound_id)
|
||||
statement = statement.options(selectinload(Sound.extractions)) # type: ignore[arg-type]
|
||||
result = await session.exec(statement)
|
||||
sound = result.first()
|
||||
|
||||
if not sound:
|
||||
logger.warning("Sound %s not found for play_next", sound_id)
|
||||
return
|
||||
|
||||
self.state.play_next_queue.append(sound)
|
||||
await self._broadcast_state()
|
||||
logger.info("Added sound %s to play_next queue", sound.name)
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def _play_next_from_queue(self) -> None:
|
||||
"""Play the first track from the play_next queue."""
|
||||
if not self.state.play_next_queue:
|
||||
return
|
||||
|
||||
# Store current playlist index before switching to play_next track
|
||||
# Only store if we're currently playing from the playlist
|
||||
if (
|
||||
self.state.current_sound_index is not None
|
||||
and self.state.playlist_index_before_play_next is None
|
||||
):
|
||||
self.state.playlist_index_before_play_next = (
|
||||
self.state.current_sound_index
|
||||
)
|
||||
logger.info(
|
||||
"Stored playlist index %s before playing from play_next queue",
|
||||
self.state.playlist_index_before_play_next,
|
||||
)
|
||||
|
||||
# Get the first sound from the queue
|
||||
next_sound = self.state.play_next_queue.pop(0)
|
||||
|
||||
# Stop current playback and process play count
|
||||
if self.state.status != PlayerStatus.STOPPED:
|
||||
await self._stop_playback()
|
||||
|
||||
# Set the sound as current (without index since it's from play_next)
|
||||
self.state.current_sound = next_sound
|
||||
self.state.current_sound_id = next_sound.id
|
||||
self.state.current_sound_index = None # No index for play_next tracks
|
||||
|
||||
# Play the sound
|
||||
if not self._validate_sound_file():
|
||||
return
|
||||
|
||||
if not self._load_and_play_media():
|
||||
return
|
||||
|
||||
await self._handle_successful_playback()
|
||||
|
||||
async def reload_playlist(self) -> None:
|
||||
"""Reload current playlist from database."""
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
playlist_repo = PlaylistRepository(session)
|
||||
current_playlist = await playlist_repo.get_main_playlist()
|
||||
current_playlist = await playlist_repo.get_current_playlist()
|
||||
|
||||
if current_playlist and current_playlist.id:
|
||||
sounds = await playlist_repo.get_playlist_sounds(current_playlist.id)
|
||||
@@ -333,6 +545,26 @@ class PlayerService:
|
||||
|
||||
await self._broadcast_state()
|
||||
|
||||
async def load_playlist(self, playlist_id: int) -> None:
|
||||
"""Load a specific playlist by ID."""
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
playlist_repo = PlaylistRepository(session)
|
||||
playlist = await playlist_repo.get_by_id(playlist_id)
|
||||
if playlist and playlist.id:
|
||||
sounds = await playlist_repo.get_playlist_sounds(playlist.id)
|
||||
await self._handle_playlist_reload(playlist, sounds)
|
||||
logger.info(
|
||||
"Loaded playlist: %s (%s sounds)",
|
||||
playlist.name,
|
||||
len(sounds),
|
||||
)
|
||||
else:
|
||||
logger.warning("Playlist not found: %s", playlist_id)
|
||||
finally:
|
||||
await session.close()
|
||||
await self._broadcast_state()
|
||||
|
||||
async def _handle_playlist_reload(
|
||||
self,
|
||||
current_playlist: Playlist,
|
||||
@@ -353,7 +585,9 @@ class PlayerService:
|
||||
and previous_playlist_id != current_playlist.id
|
||||
):
|
||||
await self._handle_playlist_id_changed(
|
||||
previous_playlist_id, current_playlist.id, sounds,
|
||||
previous_playlist_id,
|
||||
current_playlist.id,
|
||||
sounds,
|
||||
)
|
||||
elif previous_current_sound_id:
|
||||
await self._handle_same_playlist_track_check(
|
||||
@@ -377,6 +611,16 @@ class PlayerService:
|
||||
current_id,
|
||||
)
|
||||
|
||||
# Clear play_next queue when playlist changes
|
||||
if self.state.play_next_queue:
|
||||
logger.info("Clearing play_next queue due to playlist change")
|
||||
self.state.play_next_queue.clear()
|
||||
|
||||
# Clear stored playlist index
|
||||
if self.state.playlist_index_before_play_next is not None:
|
||||
logger.info("Clearing stored playlist index due to playlist change")
|
||||
self.state.playlist_index_before_play_next = None
|
||||
|
||||
if self.state.status != PlayerStatus.STOPPED:
|
||||
await self._stop_playback()
|
||||
|
||||
@@ -392,6 +636,9 @@ class PlayerService:
|
||||
sounds: list[Sound],
|
||||
) -> None:
|
||||
"""Handle track checking when playlist ID is the same."""
|
||||
# Remove tracks from play_next queue that are no longer in the playlist
|
||||
self._clean_play_next_queue(sounds)
|
||||
|
||||
# Find the current track in the new playlist
|
||||
new_index = self._find_sound_index(previous_sound_id, sounds)
|
||||
|
||||
@@ -431,7 +678,9 @@ class PlayerService:
|
||||
self._clear_current_track()
|
||||
|
||||
def _update_playlist_state(
|
||||
self, current_playlist: Playlist, sounds: list[Sound],
|
||||
self,
|
||||
current_playlist: Playlist,
|
||||
sounds: list[Sound],
|
||||
) -> None:
|
||||
"""Update basic playlist state information."""
|
||||
self.state.playlist_id = current_playlist.id
|
||||
@@ -447,6 +696,29 @@ class PlayerService:
|
||||
return i
|
||||
return None
|
||||
|
||||
def _clean_play_next_queue(self, playlist_sounds: list[Sound]) -> None:
|
||||
"""Remove tracks from play_next queue that are no longer in the playlist."""
|
||||
if not self.state.play_next_queue:
|
||||
return
|
||||
|
||||
# Get IDs of all sounds in the current playlist
|
||||
playlist_sound_ids = {sound.id for sound in playlist_sounds}
|
||||
|
||||
# Filter out tracks that are no longer in the playlist
|
||||
original_length = len(self.state.play_next_queue)
|
||||
self.state.play_next_queue = [
|
||||
sound
|
||||
for sound in self.state.play_next_queue
|
||||
if sound.id in playlist_sound_ids
|
||||
]
|
||||
|
||||
removed_count = original_length - len(self.state.play_next_queue)
|
||||
if removed_count > 0:
|
||||
logger.info(
|
||||
"Removed %s track(s) from play_next queue (no longer in playlist)",
|
||||
removed_count,
|
||||
)
|
||||
|
||||
def _set_first_track_as_current(self, sounds: list[Sound]) -> None:
|
||||
"""Set the first track as the current track."""
|
||||
self.state.current_sound_index = 0
|
||||
@@ -463,7 +735,6 @@ class PlayerService:
|
||||
"""Get current player state."""
|
||||
return self.state.to_dict()
|
||||
|
||||
|
||||
def _get_next_index(self, current_index: int) -> int | None:
|
||||
"""Get next track index based on current mode."""
|
||||
if not self.state.playlist_sounds:
|
||||
@@ -496,11 +767,7 @@ class PlayerService:
|
||||
prev_index = current_index - 1
|
||||
|
||||
if prev_index < 0:
|
||||
return (
|
||||
playlist_length - 1
|
||||
if self.state.mode == PlayerMode.LOOP
|
||||
else None
|
||||
)
|
||||
return playlist_length - 1 if self.state.mode == PlayerMode.LOOP else None
|
||||
return prev_index
|
||||
|
||||
def _position_tracker(self) -> None:
|
||||
@@ -516,15 +783,16 @@ class PlayerService:
|
||||
|
||||
# Check if track finished
|
||||
player_state = self._player.get_state()
|
||||
if hasattr(vlc, "State") and player_state == vlc.State.Ended:
|
||||
vlc_state_ended = 6 # vlc.State.Ended value
|
||||
if player_state == vlc_state_ended:
|
||||
# Track finished, handle auto-advance
|
||||
self._schedule_async_task(self._handle_track_finished())
|
||||
|
||||
# Update play time tracking
|
||||
self._update_play_time()
|
||||
|
||||
# Broadcast state every 0.5 seconds while playing
|
||||
broadcast_interval = 0.5
|
||||
# Broadcast state every second while playing
|
||||
broadcast_interval = 1
|
||||
current_time = time.time()
|
||||
if current_time - self._last_position_broadcast >= broadcast_interval:
|
||||
self._last_position_broadcast = current_time
|
||||
@@ -534,10 +802,7 @@ class PlayerService:
|
||||
|
||||
def _update_play_time(self) -> None:
|
||||
"""Update play time tracking for current sound."""
|
||||
if (
|
||||
not self.state.current_sound_id
|
||||
or self.state.status != PlayerStatus.PLAYING
|
||||
):
|
||||
if not self.state.current_sound_id or self.state.status != PlayerStatus.PLAYING:
|
||||
return
|
||||
|
||||
sound_id = self.state.current_sound_id
|
||||
@@ -576,10 +841,8 @@ class PlayerService:
|
||||
sound_id,
|
||||
tracking["total_time"],
|
||||
self.state.current_sound_duration,
|
||||
(
|
||||
tracking["total_time"]
|
||||
/ self.state.current_sound_duration
|
||||
) * 100,
|
||||
(tracking["total_time"] / self.state.current_sound_duration)
|
||||
* 100,
|
||||
)
|
||||
self._schedule_async_task(self._record_play_count(sound_id))
|
||||
|
||||
@@ -595,7 +858,8 @@ class PlayerService:
|
||||
if sound:
|
||||
old_count = sound.play_count
|
||||
await sound_repo.update(
|
||||
sound, {"play_count": sound.play_count + 1},
|
||||
sound,
|
||||
{"play_count": sound.play_count + 1},
|
||||
)
|
||||
logger.info(
|
||||
"Updated sound %s play_count: %s -> %s",
|
||||
@@ -644,7 +908,12 @@ class PlayerService:
|
||||
"""Handle when a track finishes playing."""
|
||||
await self._process_play_count()
|
||||
|
||||
# Auto-advance to next track
|
||||
# Check if there's a track in the play_next queue
|
||||
if self.state.play_next_queue:
|
||||
await self._play_next_from_queue()
|
||||
return
|
||||
|
||||
# Auto-advance to next track in playlist
|
||||
if self.state.current_sound_index is not None:
|
||||
next_index = self._get_next_index(self.state.current_sound_index)
|
||||
if next_index is not None:
|
||||
@@ -652,6 +921,32 @@ class PlayerService:
|
||||
else:
|
||||
await self._stop_playback()
|
||||
await self._broadcast_state()
|
||||
elif (
|
||||
self.state.playlist_sounds
|
||||
and self.state.playlist_index_before_play_next is not None
|
||||
):
|
||||
# Current track was from play_next queue, restore to next track in playlist
|
||||
restored_index = self.state.playlist_index_before_play_next
|
||||
logger.info(
|
||||
"Play next queue finished, continuing from playlist index %s",
|
||||
restored_index,
|
||||
)
|
||||
|
||||
# Get the next index based on the stored position
|
||||
next_index = self._get_next_index(restored_index)
|
||||
|
||||
# Clear the stored index since we're done with play_next queue
|
||||
self.state.playlist_index_before_play_next = None
|
||||
|
||||
if next_index is not None:
|
||||
await self.play(next_index)
|
||||
else:
|
||||
# No next track (end of playlist in non-loop mode)
|
||||
await self._stop_playback()
|
||||
await self._broadcast_state()
|
||||
else:
|
||||
await self._stop_playback()
|
||||
await self._broadcast_state()
|
||||
|
||||
async def _broadcast_state(self) -> None:
|
||||
"""Broadcast current player state via WebSocket."""
|
||||
|
||||
@@ -1,24 +1,59 @@
|
||||
"""Playlist service for business logic operations."""
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.playlist import Playlist
|
||||
from app.models.playlist_sound import PlaylistSound
|
||||
from app.models.sound import Sound
|
||||
from app.repositories.playlist import PlaylistRepository
|
||||
from app.repositories.playlist import PlaylistRepository, PlaylistSortField, SortOrder
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.utils.exceptions import (
|
||||
raise_bad_request,
|
||||
raise_internal_server_error,
|
||||
raise_not_found,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PaginatedPlaylistsResponse(TypedDict):
|
||||
"""Response type for paginated playlists."""
|
||||
|
||||
playlists: list[dict]
|
||||
total: int
|
||||
page: int
|
||||
limit: int
|
||||
total_pages: int
|
||||
|
||||
|
||||
async def _reload_player_playlist() -> None:
|
||||
"""Reload the player playlist after current playlist changes."""
|
||||
try:
|
||||
# Import here to avoid circular import issues
|
||||
from app.services.player import get_player_service # noqa: PLC0415
|
||||
|
||||
player = get_player_service()
|
||||
await player.reload_playlist()
|
||||
logger.debug("Player playlist reloaded after current playlist change")
|
||||
except Exception: # noqa: BLE001
|
||||
# Don't fail the playlist operation if player reload fails
|
||||
logger.warning("Failed to reload player playlist", exc_info=True)
|
||||
|
||||
|
||||
async def _is_current_playlist(session: AsyncSession, playlist_id: int) -> bool:
|
||||
"""Check if the given playlist is the current playlist."""
|
||||
try:
|
||||
from app.repositories.playlist import PlaylistRepository # noqa: PLC0415
|
||||
|
||||
playlist_repo = PlaylistRepository(session)
|
||||
current_playlist = await playlist_repo.get_current_playlist()
|
||||
except Exception: # noqa: BLE001
|
||||
logger.warning("Failed to check if playlist is current", exc_info=True)
|
||||
return False
|
||||
else:
|
||||
return current_playlist is not None and current_playlist.id == playlist_id
|
||||
|
||||
|
||||
class PlaylistService:
|
||||
"""Service for playlist operations."""
|
||||
|
||||
@@ -28,11 +63,24 @@ class PlaylistService:
|
||||
self.playlist_repo = PlaylistRepository(session)
|
||||
self.sound_repo = SoundRepository(session)
|
||||
|
||||
async def _is_main_playlist(self, playlist_id: int) -> bool:
|
||||
"""Check if the given playlist is the main playlist."""
|
||||
try:
|
||||
playlist = await self.playlist_repo.get_by_id(playlist_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to check if playlist is main: %s", playlist_id)
|
||||
return False
|
||||
else:
|
||||
return playlist is not None and playlist.is_main
|
||||
|
||||
async def get_playlist_by_id(self, playlist_id: int) -> Playlist:
|
||||
"""Get a playlist by ID."""
|
||||
playlist = await self.playlist_repo.get_by_id(playlist_id)
|
||||
if not playlist:
|
||||
raise_not_found("Playlist")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Playlist not found",
|
||||
)
|
||||
|
||||
return playlist
|
||||
|
||||
@@ -49,27 +97,29 @@ class PlaylistService:
|
||||
main_playlist = await self.playlist_repo.get_main_playlist()
|
||||
|
||||
if not main_playlist:
|
||||
raise_internal_server_error(
|
||||
"Main playlist not found. Make sure to run database seeding."
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Main playlist not found. Make sure to run database seeding.",
|
||||
)
|
||||
|
||||
return main_playlist
|
||||
|
||||
async def get_current_playlist(self, user_id: int) -> Playlist:
|
||||
"""Get the user's current playlist, fallback to main playlist."""
|
||||
current_playlist = await self.playlist_repo.get_current_playlist(user_id)
|
||||
async def get_current_playlist(self) -> Playlist:
|
||||
"""Get the global current playlist, fallback to main playlist."""
|
||||
current_playlist = await self.playlist_repo.get_current_playlist()
|
||||
if current_playlist:
|
||||
return current_playlist
|
||||
|
||||
# Fallback to main playlist if no current playlist is set
|
||||
return await self.get_main_playlist()
|
||||
|
||||
async def create_playlist(
|
||||
async def create_playlist( # noqa: PLR0913
|
||||
self,
|
||||
user_id: int,
|
||||
name: str,
|
||||
description: str | None = None,
|
||||
genre: str | None = None,
|
||||
*,
|
||||
is_main: bool = False,
|
||||
is_current: bool = False,
|
||||
is_deletable: bool = True,
|
||||
@@ -85,7 +135,7 @@ class PlaylistService:
|
||||
|
||||
# If this is set as current, unset the previous current playlist
|
||||
if is_current:
|
||||
await self._unset_current_playlist(user_id)
|
||||
await self._unset_current_playlist()
|
||||
|
||||
playlist_data = {
|
||||
"user_id": user_id,
|
||||
@@ -99,18 +149,31 @@ class PlaylistService:
|
||||
|
||||
playlist = await self.playlist_repo.create(playlist_data)
|
||||
logger.info("Created playlist '%s' for user %s", name, user_id)
|
||||
|
||||
# If this was set as current, reload player playlist
|
||||
if is_current:
|
||||
await _reload_player_playlist()
|
||||
|
||||
return playlist
|
||||
|
||||
async def update_playlist(
|
||||
async def update_playlist( # noqa: PLR0913
|
||||
self,
|
||||
playlist_id: int,
|
||||
user_id: int,
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
genre: str | None = None,
|
||||
is_current: bool | None = None,
|
||||
) -> Playlist:
|
||||
"""Update a playlist."""
|
||||
# Check if this is the main playlist
|
||||
if await self._is_main_playlist(playlist_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="The main playlist cannot be edited",
|
||||
)
|
||||
|
||||
playlist = await self.get_playlist_by_id(playlist_id)
|
||||
|
||||
update_data: dict[str, Any] = {}
|
||||
@@ -137,17 +200,28 @@ class PlaylistService:
|
||||
|
||||
if is_current is not None:
|
||||
if is_current:
|
||||
await self._unset_current_playlist(user_id)
|
||||
await self._unset_current_playlist()
|
||||
update_data["is_current"] = is_current
|
||||
|
||||
if update_data:
|
||||
playlist = await self.playlist_repo.update(playlist, update_data)
|
||||
logger.info("Updated playlist %s for user %s", playlist_id, user_id)
|
||||
|
||||
# If is_current was changed, reload player playlist
|
||||
if "is_current" in update_data:
|
||||
await _reload_player_playlist()
|
||||
|
||||
return playlist
|
||||
|
||||
async def delete_playlist(self, playlist_id: int, user_id: int) -> None:
|
||||
"""Delete a playlist."""
|
||||
# Check if this is the main playlist
|
||||
if await self._is_main_playlist(playlist_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="The main playlist cannot be deleted",
|
||||
)
|
||||
|
||||
playlist = await self.get_playlist_by_id(playlist_id)
|
||||
|
||||
if not playlist.is_deletable:
|
||||
@@ -156,15 +230,27 @@ class PlaylistService:
|
||||
detail="This playlist cannot be deleted",
|
||||
)
|
||||
|
||||
# Check if this is the current playlist
|
||||
# Check if this was the current playlist before deleting
|
||||
was_current = playlist.is_current
|
||||
|
||||
# First, delete all playlist_sound relationships
|
||||
await self._delete_playlist_sounds(playlist_id)
|
||||
|
||||
# Then delete the playlist itself
|
||||
await self.playlist_repo.delete(playlist)
|
||||
logger.info("Deleted playlist %s for user %s", playlist_id, user_id)
|
||||
|
||||
# If the deleted playlist was current, set main playlist as current
|
||||
if was_current:
|
||||
await self._set_main_as_current(user_id)
|
||||
main_playlist = await self.get_main_playlist()
|
||||
await self.playlist_repo.update(main_playlist, {"is_current": True})
|
||||
logger.info(
|
||||
"Set main playlist as current after deleting current playlist %s",
|
||||
playlist_id,
|
||||
)
|
||||
|
||||
# Reload player to reflect the change
|
||||
await _reload_player_playlist()
|
||||
|
||||
async def search_playlists(self, query: str, user_id: int) -> list[Playlist]:
|
||||
"""Search user's playlists by name."""
|
||||
@@ -174,15 +260,91 @@ class PlaylistService:
|
||||
"""Search all playlists by name."""
|
||||
return await self.playlist_repo.search_by_name(query)
|
||||
|
||||
async def search_and_sort_playlists( # noqa: PLR0913
|
||||
self,
|
||||
search_query: str | None = None,
|
||||
sort_by: PlaylistSortField | None = None,
|
||||
sort_order: SortOrder = SortOrder.ASC,
|
||||
user_id: int | None = None,
|
||||
*,
|
||||
include_stats: bool = False,
|
||||
limit: int | None = None,
|
||||
offset: int = 0,
|
||||
favorites_only: bool = False,
|
||||
current_user_id: int | None = None,
|
||||
) -> list[dict]:
|
||||
"""Search and sort playlists with optional statistics."""
|
||||
return await self.playlist_repo.search_and_sort(
|
||||
search_query=search_query,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
user_id=user_id,
|
||||
include_stats=include_stats,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
favorites_only=favorites_only,
|
||||
current_user_id=current_user_id,
|
||||
)
|
||||
|
||||
async def search_and_sort_playlists_paginated( # noqa: PLR0913
|
||||
self,
|
||||
search_query: str | None = None,
|
||||
sort_by: PlaylistSortField | None = None,
|
||||
sort_order: SortOrder = SortOrder.ASC,
|
||||
user_id: int | None = None,
|
||||
*,
|
||||
include_stats: bool = False,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
favorites_only: bool = False,
|
||||
current_user_id: int | None = None,
|
||||
) -> PaginatedPlaylistsResponse:
|
||||
"""Search and sort playlists with pagination."""
|
||||
offset = (page - 1) * limit
|
||||
|
||||
playlists, total_count = await self.playlist_repo.search_and_sort(
|
||||
search_query=search_query,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
user_id=user_id,
|
||||
include_stats=include_stats,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
favorites_only=favorites_only,
|
||||
current_user_id=current_user_id,
|
||||
return_count=True,
|
||||
)
|
||||
|
||||
total_pages = (total_count + limit - 1) // limit # Ceiling division
|
||||
|
||||
return PaginatedPlaylistsResponse(
|
||||
playlists=playlists,
|
||||
total=total_count,
|
||||
page=page,
|
||||
limit=limit,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
async def get_playlist_sounds(self, playlist_id: int) -> list[Sound]:
|
||||
"""Get all sounds in a playlist."""
|
||||
await self.get_playlist_by_id(playlist_id) # Verify playlist exists
|
||||
return await self.playlist_repo.get_playlist_sounds(playlist_id)
|
||||
|
||||
async def add_sound_to_playlist(
|
||||
self, playlist_id: int, sound_id: int, user_id: int, position: int | None = None
|
||||
self,
|
||||
playlist_id: int,
|
||||
sound_id: int,
|
||||
user_id: int,
|
||||
position: int | None = None,
|
||||
) -> None:
|
||||
"""Add a sound to a playlist."""
|
||||
# Check if this is the main playlist
|
||||
if await self._is_main_playlist(playlist_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Sounds cannot be added to the main playlist",
|
||||
)
|
||||
|
||||
# Verify playlist exists
|
||||
await self.get_playlist_by_id(playlist_id)
|
||||
|
||||
@@ -201,15 +363,43 @@ class PlaylistService:
|
||||
detail="Sound is already in this playlist",
|
||||
)
|
||||
|
||||
# If position is None or beyond current positions, place at the end
|
||||
if position is None:
|
||||
current_sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
|
||||
position = len(current_sounds)
|
||||
else:
|
||||
# Ensure position doesn't create gaps - if too high, place at end
|
||||
current_sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
|
||||
max_position = len(current_sounds)
|
||||
position = min(position, max_position)
|
||||
|
||||
await self.playlist_repo.add_sound_to_playlist(playlist_id, sound_id, position)
|
||||
logger.info(
|
||||
"Added sound %s to playlist %s for user %s", sound_id, playlist_id, user_id
|
||||
"Added sound %s to playlist %s for user %s at position %s",
|
||||
sound_id,
|
||||
playlist_id,
|
||||
user_id,
|
||||
position,
|
||||
)
|
||||
|
||||
# If this is the current playlist, reload player
|
||||
if await _is_current_playlist(self.session, playlist_id):
|
||||
await _reload_player_playlist()
|
||||
|
||||
async def remove_sound_from_playlist(
|
||||
self, playlist_id: int, sound_id: int, user_id: int
|
||||
self,
|
||||
playlist_id: int,
|
||||
sound_id: int,
|
||||
user_id: int,
|
||||
) -> None:
|
||||
"""Remove a sound from a playlist."""
|
||||
# Check if this is the main playlist
|
||||
if await self._is_main_playlist(playlist_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Sounds cannot be removed from the main playlist",
|
||||
)
|
||||
|
||||
# Verify playlist exists
|
||||
await self.get_playlist_by_id(playlist_id)
|
||||
|
||||
@@ -221,15 +411,41 @@ class PlaylistService:
|
||||
)
|
||||
|
||||
await self.playlist_repo.remove_sound_from_playlist(playlist_id, sound_id)
|
||||
|
||||
# Reorder remaining sounds to eliminate gaps
|
||||
await self._reorder_playlist_positions(playlist_id)
|
||||
|
||||
logger.info(
|
||||
"Removed sound %s from playlist %s for user %s",
|
||||
"Removed sound %s from playlist %s for user %s and reordered positions",
|
||||
sound_id,
|
||||
playlist_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
# If this is the current playlist, reload player
|
||||
if await _is_current_playlist(self.session, playlist_id):
|
||||
await _reload_player_playlist()
|
||||
|
||||
async def _reorder_playlist_positions(self, playlist_id: int) -> None:
|
||||
"""Reorder all sounds in a playlist to eliminate position gaps."""
|
||||
sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
|
||||
if not sounds:
|
||||
return
|
||||
|
||||
# Create sequential positions: 0, 1, 2, 3...
|
||||
sound_positions = [(sound.id, index) for index, sound in enumerate(sounds)]
|
||||
await self.playlist_repo.reorder_playlist_sounds(playlist_id, sound_positions)
|
||||
logger.debug(
|
||||
"Reordered %s sounds in playlist %s to eliminate gaps",
|
||||
len(sounds),
|
||||
playlist_id,
|
||||
)
|
||||
|
||||
async def reorder_playlist_sounds(
|
||||
self, playlist_id: int, user_id: int, sound_positions: list[tuple[int, int]]
|
||||
self,
|
||||
playlist_id: int,
|
||||
user_id: int,
|
||||
sound_positions: list[tuple[int, int]],
|
||||
) -> None:
|
||||
"""Reorder sounds in a playlist."""
|
||||
# Verify playlist exists
|
||||
@@ -246,25 +462,9 @@ class PlaylistService:
|
||||
await self.playlist_repo.reorder_playlist_sounds(playlist_id, sound_positions)
|
||||
logger.info("Reordered sounds in playlist %s for user %s", playlist_id, user_id)
|
||||
|
||||
async def set_current_playlist(self, playlist_id: int, user_id: int) -> Playlist:
|
||||
"""Set a playlist as the current playlist."""
|
||||
playlist = await self.get_playlist_by_id(playlist_id)
|
||||
|
||||
# Unset previous current playlist
|
||||
await self._unset_current_playlist(user_id)
|
||||
|
||||
# Set new current playlist
|
||||
playlist = await self.playlist_repo.update(playlist, {"is_current": True})
|
||||
logger.info("Set playlist %s as current for user %s", playlist_id, user_id)
|
||||
return playlist
|
||||
|
||||
async def unset_current_playlist(self, user_id: int) -> None:
|
||||
"""Unset the current playlist and set main playlist as current."""
|
||||
await self._unset_current_playlist(user_id)
|
||||
await self._set_main_as_current(user_id)
|
||||
logger.info(
|
||||
"Unset current playlist and set main as current for user %s", user_id
|
||||
)
|
||||
# If this is the current playlist, reload player
|
||||
if await _is_current_playlist(self.session, playlist_id):
|
||||
await _reload_player_playlist()
|
||||
|
||||
async def get_playlist_stats(self, playlist_id: int) -> dict[str, Any]:
|
||||
"""Get statistics for a playlist."""
|
||||
@@ -282,36 +482,97 @@ class PlaylistService:
|
||||
"total_play_count": total_plays,
|
||||
}
|
||||
|
||||
async def add_sound_to_main_playlist(self, sound_id: int, user_id: int) -> None:
|
||||
async def add_sound_to_main_playlist(
|
||||
self,
|
||||
sound_id: int, # noqa: ARG002
|
||||
user_id: int, # noqa: ARG002
|
||||
) -> None:
|
||||
"""Add a sound to the global main playlist."""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Sounds cannot be added to the main playlist",
|
||||
)
|
||||
|
||||
async def _add_sound_to_main_playlist_internal(
|
||||
self,
|
||||
sound_id: int,
|
||||
user_id: int,
|
||||
) -> None:
|
||||
"""Add sound to main playlist bypassing restrictions.
|
||||
|
||||
This method is intended for internal system use only (e.g., extraction service).
|
||||
It bypasses the main playlist modification restrictions.
|
||||
"""
|
||||
main_playlist = await self.get_main_playlist()
|
||||
|
||||
if main_playlist.id is None:
|
||||
raise ValueError("Main playlist has no ID")
|
||||
msg = "Main playlist has no ID, cannot add sound"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Extract ID before async operations to avoid session issues
|
||||
main_playlist_id = main_playlist.id
|
||||
|
||||
# Check if sound is already in main playlist
|
||||
if not await self.playlist_repo.is_sound_in_playlist(
|
||||
main_playlist.id, sound_id
|
||||
main_playlist_id,
|
||||
sound_id,
|
||||
):
|
||||
await self.playlist_repo.add_sound_to_playlist(main_playlist.id, sound_id)
|
||||
await self.playlist_repo.add_sound_to_playlist(main_playlist_id, sound_id)
|
||||
logger.info(
|
||||
"Added sound %s to main playlist for user %s",
|
||||
"Added sound %s to main playlist for user %s (internal)",
|
||||
sound_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
async def _unset_current_playlist(self, user_id: int) -> None:
|
||||
"""Unset the current playlist for a user."""
|
||||
current_playlist = await self.playlist_repo.get_current_playlist(user_id)
|
||||
# If main playlist is current, reload player
|
||||
if await _is_current_playlist(self.session, main_playlist_id):
|
||||
await _reload_player_playlist()
|
||||
|
||||
# Current playlist methods (global by default)
|
||||
async def set_current_playlist(self, playlist_id: int) -> Playlist:
|
||||
"""Set a playlist as the current playlist (app-wide)."""
|
||||
playlist = await self.get_playlist_by_id(playlist_id)
|
||||
|
||||
# Unset any existing current playlist globally
|
||||
await self._unset_current_playlist()
|
||||
|
||||
# Set new current playlist
|
||||
playlist = await self.playlist_repo.update(playlist, {"is_current": True})
|
||||
logger.info("Set playlist %s as current playlist", playlist_id)
|
||||
|
||||
# Reload player playlist to reflect the change
|
||||
await _reload_player_playlist()
|
||||
|
||||
return playlist
|
||||
|
||||
async def unset_current_playlist(self) -> None:
|
||||
"""Unset the current playlist (main playlist becomes fallback)."""
|
||||
await self._unset_current_playlist()
|
||||
logger.info("Unset current playlist, main playlist is now fallback")
|
||||
|
||||
# Reload player playlist to reflect the change (will fallback to main)
|
||||
await _reload_player_playlist()
|
||||
|
||||
async def _delete_playlist_sounds(self, playlist_id: int) -> None:
|
||||
"""Delete all playlist_sound records for a given playlist."""
|
||||
# Get all playlist_sound records for this playlist
|
||||
stmt = select(PlaylistSound).where(PlaylistSound.playlist_id == playlist_id)
|
||||
result = await self.session.exec(stmt)
|
||||
playlist_sounds = result.all()
|
||||
|
||||
# Delete each playlist_sound record
|
||||
for playlist_sound in playlist_sounds:
|
||||
await self.session.delete(playlist_sound)
|
||||
|
||||
await self.session.commit()
|
||||
logger.info(
|
||||
"Deleted %d playlist_sound records for playlist %s",
|
||||
len(playlist_sounds),
|
||||
playlist_id,
|
||||
)
|
||||
|
||||
async def _unset_current_playlist(self) -> None:
|
||||
"""Unset any current playlist globally."""
|
||||
current_playlist = await self.playlist_repo.get_current_playlist()
|
||||
if current_playlist:
|
||||
await self.playlist_repo.update(current_playlist, {"is_current": False})
|
||||
|
||||
async def _set_main_as_current(self, user_id: int) -> None:
|
||||
"""Unset current playlist so main playlist becomes the fallback current."""
|
||||
# Just ensure no user playlist is marked as current
|
||||
# The get_current_playlist method will fallback to main playlist
|
||||
await self._unset_current_playlist(user_id)
|
||||
logger.info(
|
||||
"Unset current playlist for user %s, main playlist is now fallback",
|
||||
user_id,
|
||||
)
|
||||
|
||||
490
app/services/scheduler.py
Normal file
490
app/services/scheduler.py
Normal file
@@ -0,0 +1,490 @@
|
||||
"""Enhanced scheduler service for flexible task scheduling with timezone support."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from contextlib import suppress
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytz
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.date import DateTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.scheduled_task import (
|
||||
RecurrenceType,
|
||||
ScheduledTask,
|
||||
TaskStatus,
|
||||
TaskType,
|
||||
)
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
from app.schemas.scheduler import ScheduledTaskCreate
|
||||
from app.services.credit import CreditService
|
||||
from app.services.player import PlayerService
|
||||
from app.services.task_handlers import TaskHandlerRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SchedulerService:
|
||||
"""Enhanced service for managing scheduled tasks with timezone support."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_session_factory: Callable[[], AsyncSession],
|
||||
player_service: PlayerService,
|
||||
) -> None:
|
||||
"""Initialize the scheduler service.
|
||||
|
||||
Args:
|
||||
db_session_factory: Factory function to create database sessions
|
||||
player_service: Player service for audio playback tasks
|
||||
|
||||
"""
|
||||
self.db_session_factory = db_session_factory
|
||||
self.scheduler = AsyncIOScheduler(timezone=pytz.UTC)
|
||||
self.credit_service = CreditService(db_session_factory)
|
||||
self.player_service = player_service
|
||||
self._running_tasks: set[str] = set()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the scheduler and load all active tasks."""
|
||||
logger.info("Starting enhanced scheduler service...")
|
||||
|
||||
self.scheduler.start()
|
||||
|
||||
# Schedule system tasks initialization for after startup
|
||||
self.scheduler.add_job(
|
||||
self._initialize_system_tasks,
|
||||
"date",
|
||||
run_date=datetime.now(tz=UTC) + timedelta(seconds=2),
|
||||
id="initialize_system_tasks",
|
||||
name="Initialize System Tasks",
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
# Schedule periodic cleanup and maintenance
|
||||
self.scheduler.add_job(
|
||||
self._maintenance_job,
|
||||
"interval",
|
||||
minutes=5,
|
||||
id="scheduler_maintenance",
|
||||
name="Scheduler Maintenance",
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
logger.info("Enhanced scheduler service started successfully")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the scheduler."""
|
||||
logger.info("Stopping scheduler service...")
|
||||
self.scheduler.shutdown(wait=True)
|
||||
logger.info("Scheduler service stopped")
|
||||
|
||||
async def create_task(
|
||||
self,
|
||||
task_data: ScheduledTaskCreate,
|
||||
user_id: int | None = None,
|
||||
) -> ScheduledTask:
|
||||
"""Create a new scheduled task from schema data."""
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
|
||||
# Convert scheduled_at to UTC if it's in a different timezone
|
||||
scheduled_at = task_data.scheduled_at
|
||||
if task_data.timezone != "UTC":
|
||||
tz = pytz.timezone(task_data.timezone)
|
||||
if scheduled_at.tzinfo is None:
|
||||
# Assume the datetime is in the specified timezone
|
||||
scheduled_at = tz.localize(scheduled_at)
|
||||
scheduled_at = scheduled_at.astimezone(pytz.UTC).replace(tzinfo=None)
|
||||
|
||||
db_task_data = {
|
||||
"name": task_data.name,
|
||||
"task_type": task_data.task_type,
|
||||
"scheduled_at": scheduled_at,
|
||||
"timezone": task_data.timezone,
|
||||
"parameters": task_data.parameters,
|
||||
"user_id": user_id,
|
||||
"recurrence_type": task_data.recurrence_type,
|
||||
"cron_expression": task_data.cron_expression,
|
||||
"recurrence_count": task_data.recurrence_count,
|
||||
"expires_at": task_data.expires_at,
|
||||
}
|
||||
|
||||
created_task = await repo.create(db_task_data)
|
||||
await self._schedule_apscheduler_job(created_task)
|
||||
|
||||
logger.info(
|
||||
"Created scheduled task: %s (%s)",
|
||||
created_task.name,
|
||||
created_task.id,
|
||||
)
|
||||
return created_task
|
||||
|
||||
async def cancel_task(self, task_id: int) -> bool:
|
||||
"""Cancel a scheduled task."""
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
|
||||
task = await repo.get_by_id(task_id)
|
||||
if not task:
|
||||
return False
|
||||
|
||||
await repo.update(task, {
|
||||
"status": TaskStatus.CANCELLED,
|
||||
"is_active": False,
|
||||
})
|
||||
|
||||
# Remove from APScheduler (job might not exist in scheduler)
|
||||
with suppress(Exception):
|
||||
self.scheduler.remove_job(str(task_id))
|
||||
|
||||
logger.info("Cancelled task: %s (%s)", task.name, task_id)
|
||||
return True
|
||||
|
||||
async def delete_task(self, task_id: int) -> bool:
|
||||
"""Delete a scheduled task completely."""
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
|
||||
task = await repo.get_by_id(task_id)
|
||||
if not task:
|
||||
return False
|
||||
|
||||
# Remove from APScheduler first (job might not exist in scheduler)
|
||||
with suppress(Exception):
|
||||
self.scheduler.remove_job(str(task_id))
|
||||
|
||||
# Delete from database
|
||||
await repo.delete(task)
|
||||
|
||||
logger.info("Deleted task: %s (%s)", task.name, task_id)
|
||||
return True
|
||||
|
||||
async def get_user_tasks(
|
||||
self,
|
||||
user_id: int,
|
||||
status: TaskStatus | None = None,
|
||||
task_type: TaskType | None = None,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get tasks for a specific user."""
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
return await repo.get_user_tasks(user_id, status, task_type, limit, offset)
|
||||
|
||||
async def _initialize_system_tasks(self) -> None:
|
||||
"""Initialize system tasks and load active tasks from database."""
|
||||
logger.info("Initializing system tasks...")
|
||||
|
||||
try:
|
||||
# Create system tasks if they don't exist
|
||||
await self._ensure_system_tasks()
|
||||
|
||||
# Load all active tasks from database
|
||||
await self._load_active_tasks()
|
||||
|
||||
logger.info("System tasks initialized successfully")
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize system tasks")
|
||||
|
||||
async def _ensure_system_tasks(self) -> None:
|
||||
"""Ensure required system tasks exist."""
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
|
||||
# Check if daily credit recharge task exists
|
||||
system_tasks = await repo.get_system_tasks(
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
)
|
||||
|
||||
daily_recharge_exists = any(
|
||||
task.recurrence_type == RecurrenceType.DAILY
|
||||
and task.is_active
|
||||
for task in system_tasks
|
||||
)
|
||||
|
||||
if not daily_recharge_exists:
|
||||
# Create daily credit recharge task
|
||||
tomorrow_midnight = datetime.now(tz=UTC).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0,
|
||||
) + timedelta(days=1)
|
||||
|
||||
task_data = {
|
||||
"name": "Daily Credit Recharge",
|
||||
"task_type": TaskType.CREDIT_RECHARGE,
|
||||
"scheduled_at": tomorrow_midnight,
|
||||
"recurrence_type": RecurrenceType.DAILY,
|
||||
"parameters": {},
|
||||
}
|
||||
|
||||
await repo.create(task_data)
|
||||
logger.info("Created system daily credit recharge task")
|
||||
|
||||
async def _load_active_tasks(self) -> None:
|
||||
"""Load all active tasks from database into scheduler."""
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
active_tasks = await repo.get_active_tasks()
|
||||
|
||||
for task in active_tasks:
|
||||
await self._schedule_apscheduler_job(task)
|
||||
|
||||
logger.info("Loaded %s active tasks into scheduler", len(active_tasks))
|
||||
|
||||
async def _schedule_apscheduler_job(self, task: ScheduledTask) -> None:
|
||||
"""Schedule a task in APScheduler."""
|
||||
job_id = str(task.id)
|
||||
|
||||
# Remove existing job if it exists
|
||||
with suppress(Exception):
|
||||
self.scheduler.remove_job(job_id)
|
||||
|
||||
# Don't schedule if task is not active or already completed/failed
|
||||
inactive_statuses = [
|
||||
TaskStatus.COMPLETED,
|
||||
TaskStatus.FAILED,
|
||||
TaskStatus.CANCELLED,
|
||||
]
|
||||
if not task.is_active or task.status in inactive_statuses:
|
||||
return
|
||||
|
||||
# Create trigger based on recurrence type
|
||||
trigger = self._create_trigger(task)
|
||||
if not trigger:
|
||||
logger.warning("Could not create trigger for task %s", task.id)
|
||||
return
|
||||
|
||||
# Schedule the job
|
||||
self.scheduler.add_job(
|
||||
self._execute_task,
|
||||
trigger=trigger,
|
||||
args=[task.id],
|
||||
id=job_id,
|
||||
name=task.name,
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
logger.debug("Scheduled APScheduler job for task %s", task.id)
|
||||
|
||||
def _create_trigger(
|
||||
self, task: ScheduledTask,
|
||||
) -> DateTrigger | IntervalTrigger | CronTrigger | None:
|
||||
"""Create APScheduler trigger based on task configuration."""
|
||||
tz = pytz.timezone(task.timezone)
|
||||
scheduled_time = task.scheduled_at
|
||||
|
||||
# Handle special cases first
|
||||
if task.recurrence_type == RecurrenceType.NONE:
|
||||
return DateTrigger(run_date=scheduled_time, timezone=tz)
|
||||
|
||||
if task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
|
||||
return CronTrigger.from_crontab(task.cron_expression, timezone=tz)
|
||||
|
||||
# Handle interval-based recurrence types
|
||||
interval_configs = {
|
||||
RecurrenceType.MINUTELY: {"minutes": 1},
|
||||
RecurrenceType.HOURLY: {"hours": 1},
|
||||
RecurrenceType.DAILY: {"days": 1},
|
||||
RecurrenceType.WEEKLY: {"weeks": 1},
|
||||
}
|
||||
|
||||
if task.recurrence_type in interval_configs:
|
||||
config = interval_configs[task.recurrence_type]
|
||||
return IntervalTrigger(start_date=scheduled_time, timezone=tz, **config)
|
||||
|
||||
# Handle cron-based recurrence types
|
||||
cron_configs = {
|
||||
RecurrenceType.MONTHLY: {
|
||||
"day": scheduled_time.day,
|
||||
"hour": scheduled_time.hour,
|
||||
"minute": scheduled_time.minute,
|
||||
},
|
||||
RecurrenceType.YEARLY: {
|
||||
"month": scheduled_time.month,
|
||||
"day": scheduled_time.day,
|
||||
"hour": scheduled_time.hour,
|
||||
"minute": scheduled_time.minute,
|
||||
},
|
||||
}
|
||||
|
||||
if task.recurrence_type in cron_configs:
|
||||
config = cron_configs[task.recurrence_type]
|
||||
return CronTrigger(timezone=tz, **config)
|
||||
|
||||
return None
|
||||
|
||||
async def _execute_task(self, task_id: int) -> None:
|
||||
"""Execute a scheduled task."""
|
||||
task_id_str = str(task_id)
|
||||
|
||||
logger.info("APScheduler triggered task %s execution", task_id)
|
||||
|
||||
# Prevent concurrent execution of the same task
|
||||
if task_id_str in self._running_tasks:
|
||||
logger.warning("Task %s is already running, skipping execution", task_id)
|
||||
return
|
||||
|
||||
self._running_tasks.add(task_id_str)
|
||||
|
||||
try:
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
|
||||
# Get fresh task data
|
||||
task = await repo.get_by_id(task_id)
|
||||
if not task:
|
||||
logger.warning("Task %s not found", task_id)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Task %s current state - status: %s, is_active: %s, executions: %s",
|
||||
task_id, task.status, task.is_active, task.executions_count,
|
||||
)
|
||||
|
||||
# Check if task is still active and pending
|
||||
if not task.is_active or task.status != TaskStatus.PENDING:
|
||||
logger.warning(
|
||||
"Task %s execution skipped - is_active: %s, status: %s "
|
||||
"(should be %s)",
|
||||
task_id,
|
||||
task.is_active,
|
||||
task.status,
|
||||
TaskStatus.PENDING,
|
||||
)
|
||||
return
|
||||
|
||||
# Check if task has expired
|
||||
if task.is_expired():
|
||||
logger.info("Task %s has expired, marking as cancelled", task_id)
|
||||
await repo.update(task, {
|
||||
"status": TaskStatus.CANCELLED,
|
||||
"is_active": False,
|
||||
})
|
||||
return
|
||||
|
||||
# Mark task as running
|
||||
logger.info(
|
||||
"Task %s starting execution (type: %s)",
|
||||
task_id,
|
||||
task.recurrence_type,
|
||||
)
|
||||
await repo.mark_as_running(task)
|
||||
|
||||
# Execute the task
|
||||
try:
|
||||
handler_registry = TaskHandlerRegistry(
|
||||
session,
|
||||
self.db_session_factory,
|
||||
self.credit_service,
|
||||
self.player_service,
|
||||
)
|
||||
await handler_registry.execute_task(task)
|
||||
|
||||
# Handle completion based on task type
|
||||
if task.recurrence_type == RecurrenceType.CRON:
|
||||
# For CRON tasks, update execution metadata but keep PENDING
|
||||
# APScheduler handles the recurring schedule automatically
|
||||
logger.info(
|
||||
"Task %s (CRON) executed successfully, updating metadata",
|
||||
task_id,
|
||||
)
|
||||
task.last_executed_at = datetime.now(tz=UTC)
|
||||
task.executions_count += 1
|
||||
task.error_message = None
|
||||
task.status = TaskStatus.PENDING # Explicitly set to PENDING
|
||||
session.add(task)
|
||||
await session.commit()
|
||||
logger.info(
|
||||
"Task %s (CRON) metadata updated, status: %s, "
|
||||
"executions: %s",
|
||||
task_id,
|
||||
task.status,
|
||||
task.executions_count,
|
||||
)
|
||||
else:
|
||||
# For non-CRON recurring tasks, calculate next execution
|
||||
next_execution_at = None
|
||||
if task.should_repeat():
|
||||
next_execution_at = self._calculate_next_execution(task)
|
||||
|
||||
# Mark as completed
|
||||
await repo.mark_as_completed(task, next_execution_at)
|
||||
|
||||
# Reschedule if recurring
|
||||
if next_execution_at and task.should_repeat():
|
||||
# Refresh task to get updated data
|
||||
await session.refresh(task)
|
||||
await self._schedule_apscheduler_job(task)
|
||||
|
||||
except Exception as e:
|
||||
await repo.mark_as_failed(task, str(e))
|
||||
logger.exception("Task %s execution failed", task_id)
|
||||
|
||||
finally:
|
||||
self._running_tasks.discard(task_id_str)
|
||||
|
||||
def _calculate_next_execution(self, task: ScheduledTask) -> datetime | None:
|
||||
"""Calculate the next execution time for a recurring task."""
|
||||
now = datetime.now(tz=UTC)
|
||||
|
||||
recurrence_deltas = {
|
||||
RecurrenceType.MINUTELY: timedelta(minutes=1),
|
||||
RecurrenceType.HOURLY: timedelta(hours=1),
|
||||
RecurrenceType.DAILY: timedelta(days=1),
|
||||
RecurrenceType.WEEKLY: timedelta(weeks=1),
|
||||
RecurrenceType.MONTHLY: timedelta(days=30), # Approximate
|
||||
RecurrenceType.YEARLY: timedelta(days=365), # Approximate
|
||||
}
|
||||
|
||||
if task.recurrence_type in recurrence_deltas:
|
||||
return now + recurrence_deltas[task.recurrence_type]
|
||||
|
||||
if task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
|
||||
# For CRON tasks, let APScheduler handle the timing
|
||||
return now
|
||||
|
||||
return None
|
||||
|
||||
async def _maintenance_job(self) -> None:
|
||||
"""Periodic maintenance job to clean up expired tasks and handle scheduling."""
|
||||
try:
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
|
||||
# Handle expired tasks
|
||||
expired_tasks = await repo.get_expired_tasks()
|
||||
for task in expired_tasks:
|
||||
await repo.update(task, {
|
||||
"status": TaskStatus.CANCELLED,
|
||||
"is_active": False,
|
||||
})
|
||||
|
||||
# Remove from scheduler
|
||||
with suppress(Exception):
|
||||
self.scheduler.remove_job(str(task.id))
|
||||
|
||||
if expired_tasks:
|
||||
logger.info("Cleaned up %s expired tasks", len(expired_tasks))
|
||||
|
||||
# Handle any missed recurring tasks
|
||||
due_recurring = await repo.get_recurring_tasks_due_for_next_execution()
|
||||
for task in due_recurring:
|
||||
if task.should_repeat():
|
||||
next_scheduled_at = (
|
||||
task.next_execution_at or datetime.now(tz=UTC)
|
||||
)
|
||||
await repo.update(task, {
|
||||
"status": TaskStatus.PENDING,
|
||||
"scheduled_at": next_scheduled_at,
|
||||
})
|
||||
await self._schedule_apscheduler_job(task)
|
||||
|
||||
if due_recurring:
|
||||
logger.info("Rescheduled %s recurring tasks", len(due_recurring))
|
||||
|
||||
except Exception:
|
||||
logger.exception("Maintenance job failed")
|
||||
@@ -4,6 +4,7 @@ import logging
|
||||
|
||||
import socketio
|
||||
|
||||
from app.core.config import settings
|
||||
from app.utils.auth import JWTUtils
|
||||
from app.utils.cookies import extract_access_token_from_cookies
|
||||
|
||||
@@ -13,9 +14,10 @@ logger = logging.getLogger(__name__)
|
||||
class SocketManager:
|
||||
"""Manages WebSocket connections and user rooms."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the SocketManager with a Socket.IO server."""
|
||||
self.sio = socketio.AsyncServer(
|
||||
cors_allowed_origins=["http://localhost:8001"],
|
||||
cors_allowed_origins=settings.CORS_ORIGINS,
|
||||
logger=True,
|
||||
engineio_logger=True,
|
||||
async_mode="asgi",
|
||||
@@ -27,20 +29,20 @@ class SocketManager:
|
||||
|
||||
self._setup_handlers()
|
||||
|
||||
def _setup_handlers(self):
|
||||
def _setup_handlers(self) -> None:
|
||||
"""Set up socket event handlers."""
|
||||
|
||||
@self.sio.event
|
||||
async def connect(sid, environ, auth=None):
|
||||
async def connect(sid: str, environ: dict) -> None:
|
||||
"""Handle client connection."""
|
||||
logger.info(f"Client {sid} attempting to connect")
|
||||
logger.info("Client %s attempting to connect", sid)
|
||||
|
||||
# Extract access token from cookies
|
||||
cookie_header = environ.get("HTTP_COOKIE", "")
|
||||
access_token = extract_access_token_from_cookies(cookie_header)
|
||||
|
||||
if not access_token:
|
||||
logger.warning(f"Client {sid} connecting without access token")
|
||||
logger.warning("Client %s connecting without access token", sid)
|
||||
await self.sio.disconnect(sid)
|
||||
return
|
||||
|
||||
@@ -50,13 +52,13 @@ class SocketManager:
|
||||
user_id = payload.get("sub")
|
||||
|
||||
if not user_id:
|
||||
logger.warning(f"Client {sid} token missing user ID")
|
||||
logger.warning("Client %s token missing user ID", sid)
|
||||
await self.sio.disconnect(sid)
|
||||
return
|
||||
|
||||
logger.info(f"User {user_id} connected with socket {sid}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Client {sid} invalid token: {e}")
|
||||
logger.info("User %s connected with socket %s", user_id, sid)
|
||||
except Exception:
|
||||
logger.exception("Client %s invalid token", sid)
|
||||
await self.sio.disconnect(sid)
|
||||
return
|
||||
|
||||
@@ -70,7 +72,7 @@ class SocketManager:
|
||||
# Update room tracking
|
||||
self.user_rooms[user_id] = room_id
|
||||
|
||||
logger.info(f"User {user_id} joined room {room_id}")
|
||||
logger.info("User %s joined room %s", user_id, room_id)
|
||||
|
||||
# Send welcome message to user
|
||||
await self.sio.emit(
|
||||
@@ -84,33 +86,78 @@ class SocketManager:
|
||||
)
|
||||
|
||||
@self.sio.event
|
||||
async def disconnect(sid):
|
||||
async def disconnect(sid: str) -> None:
|
||||
"""Handle client disconnection."""
|
||||
user_id = self.socket_users.get(sid)
|
||||
|
||||
if user_id:
|
||||
logger.info(f"User {user_id} disconnected (socket {sid})")
|
||||
logger.info("User %s disconnected (socket %s)", user_id, sid)
|
||||
# Clean up mappings
|
||||
del self.socket_users[sid]
|
||||
if user_id in self.user_rooms:
|
||||
del self.user_rooms[user_id]
|
||||
else:
|
||||
logger.info(f"Unknown client {sid} disconnected")
|
||||
logger.info("Unknown client %s disconnected", sid)
|
||||
|
||||
async def send_to_user(self, user_id: str, event: str, data: dict):
|
||||
@self.sio.event
|
||||
async def play_sound(sid: str, data: dict) -> None:
|
||||
"""Handle play sound event from client."""
|
||||
await self._handle_play_sound(sid, data)
|
||||
|
||||
async def _handle_play_sound(self, sid: str, data: dict) -> None:
|
||||
"""Handle play sound request from WebSocket client."""
|
||||
user_id = self.socket_users.get(sid)
|
||||
|
||||
if not user_id:
|
||||
logger.warning("Play sound request from unknown client %s", sid)
|
||||
return
|
||||
|
||||
sound_id = data.get("sound_id")
|
||||
if not sound_id:
|
||||
logger.warning(
|
||||
"Play sound request missing sound_id from user %s",
|
||||
user_id,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from app.core.database import get_session_factory # noqa: PLC0415
|
||||
from app.services.vlc_player import get_vlc_player_service # noqa: PLC0415
|
||||
|
||||
# Get VLC player service with database factory
|
||||
vlc_player = get_vlc_player_service(get_session_factory())
|
||||
|
||||
# Call the service method
|
||||
await vlc_player.play_sound_with_credits(int(sound_id), int(user_id))
|
||||
logger.info("User %s played sound %s via WebSocket", user_id, sound_id)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Error playing sound %s for user %s",
|
||||
sound_id,
|
||||
user_id,
|
||||
)
|
||||
# Emit error back to user
|
||||
await self.sio.emit(
|
||||
"sound_play_error",
|
||||
{"sound_id": sound_id, "error": str(e)},
|
||||
room=sid,
|
||||
)
|
||||
|
||||
async def send_to_user(self, user_id: str, event: str, data: dict) -> bool:
|
||||
"""Send a message to a specific user's room."""
|
||||
room_id = self.user_rooms.get(user_id)
|
||||
if room_id:
|
||||
await self.sio.emit(event, data, room=room_id)
|
||||
logger.debug(f"Sent {event} to user {user_id} in room {room_id}")
|
||||
logger.debug("Sent %s to user %s in room %s", event, user_id, room_id)
|
||||
return True
|
||||
logger.warning(f"User {user_id} not found in any room")
|
||||
logger.warning("User %s not found in any room", user_id)
|
||||
return False
|
||||
|
||||
async def broadcast_to_all(self, event: str, data: dict):
|
||||
async def broadcast_to_all(self, event: str, data: dict) -> None:
|
||||
"""Broadcast a message to all connected users."""
|
||||
await self.sio.emit(event, data)
|
||||
logger.info(f"Broadcasted {event} to all users")
|
||||
logger.info("Broadcasted %s to all users", event)
|
||||
|
||||
def get_connected_users(self) -> list:
|
||||
"""Get list of currently connected user IDs."""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Sound normalizer service for normalizing audio files using ffmpeg loudnorm."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@@ -138,10 +139,15 @@ class SoundNormalizerService:
|
||||
stream = ffmpeg.output(stream, str(output_path), **output_args)
|
||||
stream = ffmpeg.overwrite_output(stream)
|
||||
|
||||
ffmpeg.run(stream, quiet=True, overwrite_output=True)
|
||||
await asyncio.to_thread(
|
||||
ffmpeg.run,
|
||||
stream,
|
||||
quiet=True,
|
||||
overwrite_output=True,
|
||||
)
|
||||
logger.info("One-pass normalization completed: %s", output_path)
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("One-pass normalization failed for %s", input_path)
|
||||
raise
|
||||
|
||||
@@ -153,7 +159,9 @@ class SoundNormalizerService:
|
||||
"""Normalize audio using two-pass loudnorm for better quality."""
|
||||
try:
|
||||
logger.info(
|
||||
"Starting two-pass normalization: %s -> %s", input_path, output_path
|
||||
"Starting two-pass normalization: %s -> %s",
|
||||
input_path,
|
||||
output_path,
|
||||
)
|
||||
|
||||
# First pass: analyze
|
||||
@@ -174,10 +182,15 @@ class SoundNormalizerService:
|
||||
|
||||
# Run first pass and capture output
|
||||
try:
|
||||
result = ffmpeg.run(stream, capture_stderr=True, quiet=True)
|
||||
result = await asyncio.to_thread(
|
||||
ffmpeg.run,
|
||||
stream,
|
||||
capture_stderr=True,
|
||||
quiet=True,
|
||||
)
|
||||
analysis_output = result[1].decode("utf-8")
|
||||
except ffmpeg.Error as e:
|
||||
logger.error(
|
||||
logger.exception(
|
||||
"FFmpeg first pass failed for %s. Stdout: %s, Stderr: %s",
|
||||
input_path,
|
||||
e.stdout.decode() if e.stdout else "None",
|
||||
@@ -193,9 +206,11 @@ class SoundNormalizerService:
|
||||
json_match = re.search(r'\{[^{}]*"input_i"[^{}]*\}', analysis_output)
|
||||
if not json_match:
|
||||
logger.error(
|
||||
"Could not find JSON in loudnorm output: %s", analysis_output
|
||||
"Could not find JSON in loudnorm output: %s",
|
||||
analysis_output,
|
||||
)
|
||||
raise ValueError("Could not extract loudnorm analysis data")
|
||||
msg = "Could not find JSON in loudnorm output"
|
||||
raise ValueError(msg)
|
||||
|
||||
logger.debug("Found JSON match: %s", json_match.group())
|
||||
analysis_data = json.loads(json_match.group())
|
||||
@@ -211,7 +226,10 @@ class SoundNormalizerService:
|
||||
]:
|
||||
if str(analysis_data.get(key, "")).lower() in invalid_values:
|
||||
logger.warning(
|
||||
"Invalid analysis value for %s: %s. Falling back to one-pass normalization.",
|
||||
(
|
||||
"Invalid analysis value for %s: %s. "
|
||||
"Falling back to one-pass normalization."
|
||||
),
|
||||
key,
|
||||
analysis_data.get(key),
|
||||
)
|
||||
@@ -249,10 +267,15 @@ class SoundNormalizerService:
|
||||
stream = ffmpeg.overwrite_output(stream)
|
||||
|
||||
try:
|
||||
ffmpeg.run(stream, quiet=True, overwrite_output=True)
|
||||
await asyncio.to_thread(
|
||||
ffmpeg.run,
|
||||
stream,
|
||||
quiet=True,
|
||||
overwrite_output=True,
|
||||
)
|
||||
logger.info("Two-pass normalization completed: %s", output_path)
|
||||
except ffmpeg.Error as e:
|
||||
logger.error(
|
||||
logger.exception(
|
||||
"FFmpeg second pass failed for %s. Stdout: %s, Stderr: %s",
|
||||
input_path,
|
||||
e.stdout.decode() if e.stdout else "None",
|
||||
@@ -260,19 +283,21 @@ class SoundNormalizerService:
|
||||
)
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Two-pass normalization failed for %s", input_path)
|
||||
raise
|
||||
|
||||
async def normalize_sound(
|
||||
self,
|
||||
sound: Sound,
|
||||
*,
|
||||
force: bool = False,
|
||||
one_pass: bool | None = None,
|
||||
sound_data: dict | None = None,
|
||||
) -> NormalizationInfo:
|
||||
"""Normalize a single sound."""
|
||||
# Use provided sound_data to avoid detached instance issues, or capture from sound
|
||||
# Use provided sound_data to avoid detached instance issues,
|
||||
# or capture from sound
|
||||
if sound_data:
|
||||
filename = sound_data["filename"]
|
||||
sound_id = sound_data["id"]
|
||||
@@ -391,6 +416,7 @@ class SoundNormalizerService:
|
||||
|
||||
async def normalize_all_sounds(
|
||||
self,
|
||||
*,
|
||||
force: bool = False,
|
||||
one_pass: bool | None = None,
|
||||
) -> NormalizationResults:
|
||||
@@ -409,7 +435,7 @@ class SoundNormalizerService:
|
||||
if force:
|
||||
# Get all sounds if forcing
|
||||
sounds = []
|
||||
for sound_type in self.type_directories.keys():
|
||||
for sound_type in self.type_directories:
|
||||
type_sounds = await self.sound_repo.get_by_type(sound_type)
|
||||
sounds.extend(type_sounds)
|
||||
else:
|
||||
@@ -419,17 +445,16 @@ class SoundNormalizerService:
|
||||
logger.info("Found %d sounds to process", len(sounds))
|
||||
|
||||
# Capture all sound data upfront to avoid session detachment issues
|
||||
sound_data_list = []
|
||||
for sound in sounds:
|
||||
sound_data_list.append(
|
||||
{
|
||||
"id": sound.id,
|
||||
"filename": sound.filename,
|
||||
"type": sound.type,
|
||||
"is_normalized": sound.is_normalized,
|
||||
"name": sound.name,
|
||||
}
|
||||
)
|
||||
sound_data_list = [
|
||||
{
|
||||
"id": sound.id,
|
||||
"filename": sound.filename,
|
||||
"type": sound.type,
|
||||
"is_normalized": sound.is_normalized,
|
||||
"name": sound.name,
|
||||
}
|
||||
for sound in sounds
|
||||
]
|
||||
|
||||
# Process each sound using captured data
|
||||
for i, sound in enumerate(sounds):
|
||||
@@ -476,7 +501,7 @@ class SoundNormalizerService:
|
||||
"normalized_hash": None,
|
||||
"id": sound_id,
|
||||
"error": str(e),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("Normalization completed: %s", results)
|
||||
@@ -485,6 +510,7 @@ class SoundNormalizerService:
|
||||
async def normalize_sounds_by_type(
|
||||
self,
|
||||
sound_type: str,
|
||||
*,
|
||||
force: bool = False,
|
||||
one_pass: bool | None = None,
|
||||
) -> NormalizationResults:
|
||||
@@ -508,17 +534,16 @@ class SoundNormalizerService:
|
||||
logger.info("Found %d %s sounds to process", len(sounds), sound_type)
|
||||
|
||||
# Capture all sound data upfront to avoid session detachment issues
|
||||
sound_data_list = []
|
||||
for sound in sounds:
|
||||
sound_data_list.append(
|
||||
{
|
||||
"id": sound.id,
|
||||
"filename": sound.filename,
|
||||
"type": sound.type,
|
||||
"is_normalized": sound.is_normalized,
|
||||
"name": sound.name,
|
||||
}
|
||||
)
|
||||
sound_data_list = [
|
||||
{
|
||||
"id": sound.id,
|
||||
"filename": sound.filename,
|
||||
"type": sound.type,
|
||||
"is_normalized": sound.is_normalized,
|
||||
"name": sound.name,
|
||||
}
|
||||
for sound in sounds
|
||||
]
|
||||
|
||||
# Process each sound using captured data
|
||||
for i, sound in enumerate(sounds):
|
||||
@@ -565,7 +590,7 @@ class SoundNormalizerService:
|
||||
"normalized_hash": None,
|
||||
"id": sound_id,
|
||||
"error": str(e),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("Type normalization completed: %s", results)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Sound scanner service for scanning and importing audio files."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TypedDict
|
||||
|
||||
@@ -13,6 +14,28 @@ from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioFileInfo:
|
||||
"""Data class for audio file metadata."""
|
||||
|
||||
filename: str
|
||||
name: str
|
||||
duration: int
|
||||
size: int
|
||||
file_hash: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class SyncContext:
|
||||
"""Context data for audio file synchronization."""
|
||||
|
||||
file_path: Path
|
||||
sound_type: str
|
||||
existing_sound_by_hash: dict | Sound | None
|
||||
existing_sound_by_filename: dict | Sound | None
|
||||
file_hash: str
|
||||
|
||||
|
||||
class FileInfo(TypedDict):
|
||||
"""Type definition for file information in scan results."""
|
||||
|
||||
@@ -35,6 +58,7 @@ class ScanResults(TypedDict):
|
||||
updated: int
|
||||
deleted: int
|
||||
skipped: int
|
||||
duplicates: int
|
||||
errors: int
|
||||
files: list[FileInfo]
|
||||
|
||||
@@ -56,6 +80,13 @@ class SoundScannerService:
|
||||
".aac",
|
||||
}
|
||||
|
||||
# Directory mappings for normalized files (matching sound_normalizer)
|
||||
self.normalized_directories = {
|
||||
"SDB": "sounds/normalized/soundboard",
|
||||
"TTS": "sounds/normalized/text_to_speech",
|
||||
"EXT": "sounds/normalized/extracted",
|
||||
}
|
||||
|
||||
def extract_name_from_filename(self, filename: str) -> str:
|
||||
"""Extract a clean name from filename."""
|
||||
# Remove extension
|
||||
@@ -65,6 +96,415 @@ class SoundScannerService:
|
||||
# Capitalize words
|
||||
return " ".join(word.capitalize() for word in name.split())
|
||||
|
||||
def _get_normalized_path(self, sound_type: str, filename: str) -> Path:
|
||||
"""Get the normalized file path for a sound."""
|
||||
directory = self.normalized_directories.get(
|
||||
sound_type, "sounds/normalized/other",
|
||||
)
|
||||
return Path(directory) / filename
|
||||
|
||||
def _rename_normalized_file(
|
||||
self, sound_type: str, old_filename: str, new_filename: str,
|
||||
) -> bool:
|
||||
"""Rename normalized file if exists. Returns True if renamed, else False."""
|
||||
old_path = self._get_normalized_path(sound_type, old_filename)
|
||||
new_path = self._get_normalized_path(sound_type, new_filename)
|
||||
|
||||
if old_path.exists():
|
||||
try:
|
||||
# Ensure the directory exists
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
old_path.rename(new_path)
|
||||
logger.info("Renamed normalized file: %s -> %s", old_path, new_path)
|
||||
except OSError:
|
||||
logger.exception(
|
||||
"Failed to rename normalized file %s -> %s",
|
||||
old_path,
|
||||
new_path,
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _delete_normalized_file(self, sound_type: str, filename: str) -> bool:
|
||||
"""Delete normalized file if exists. Returns True if deleted, else False."""
|
||||
normalized_path = self._get_normalized_path(sound_type, filename)
|
||||
|
||||
if normalized_path.exists():
|
||||
try:
|
||||
normalized_path.unlink()
|
||||
logger.info("Deleted normalized file: %s", normalized_path)
|
||||
except OSError:
|
||||
logger.exception(
|
||||
"Failed to delete normalized file %s", normalized_path,
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _extract_sound_attributes(self, sound_data: dict | Sound | None) -> dict:
|
||||
"""Extract attributes from sound data (dict or Sound object)."""
|
||||
if sound_data is None:
|
||||
return {}
|
||||
|
||||
if isinstance(sound_data, dict):
|
||||
return {
|
||||
"filename": sound_data.get("filename"),
|
||||
"name": sound_data.get("name"),
|
||||
"duration": sound_data.get("duration"),
|
||||
"size": sound_data.get("size"),
|
||||
"id": sound_data.get("id"),
|
||||
"object": sound_data.get("sound_object"),
|
||||
"type": sound_data.get("type"),
|
||||
"is_normalized": sound_data.get("is_normalized"),
|
||||
"normalized_filename": sound_data.get("normalized_filename"),
|
||||
}
|
||||
# Sound object (for tests)
|
||||
return {
|
||||
"filename": sound_data.filename,
|
||||
"name": sound_data.name,
|
||||
"duration": sound_data.duration,
|
||||
"size": sound_data.size,
|
||||
"id": sound_data.id,
|
||||
"object": sound_data,
|
||||
"type": sound_data.type,
|
||||
"is_normalized": sound_data.is_normalized,
|
||||
"normalized_filename": sound_data.normalized_filename,
|
||||
}
|
||||
|
||||
def _handle_unchanged_file(
|
||||
self,
|
||||
filename: str,
|
||||
existing_attrs: dict,
|
||||
results: ScanResults,
|
||||
) -> None:
|
||||
"""Handle unchanged file (same hash, same filename)."""
|
||||
logger.debug("Sound unchanged: %s", filename)
|
||||
results["skipped"] += 1
|
||||
results["files"].append({
|
||||
"filename": filename,
|
||||
"status": "skipped",
|
||||
"reason": "file unchanged",
|
||||
"name": existing_attrs["name"],
|
||||
"duration": existing_attrs["duration"],
|
||||
"size": existing_attrs["size"],
|
||||
"id": existing_attrs["id"],
|
||||
"error": None,
|
||||
"changes": None,
|
||||
})
|
||||
|
||||
def _handle_duplicate_file(
|
||||
self,
|
||||
filename: str,
|
||||
existing_filename: str,
|
||||
file_hash: str,
|
||||
existing_attrs: dict,
|
||||
results: ScanResults,
|
||||
) -> None:
|
||||
"""Handle duplicate file (same hash, different filename)."""
|
||||
logger.warning(
|
||||
"Duplicate file detected: '%s' has same content as existing "
|
||||
"'%s' (hash: %s). Skipping duplicate file.",
|
||||
filename,
|
||||
existing_filename,
|
||||
file_hash[:8] + "...",
|
||||
)
|
||||
results["skipped"] += 1
|
||||
results["duplicates"] += 1
|
||||
results["files"].append({
|
||||
"filename": filename,
|
||||
"status": "skipped",
|
||||
"reason": "duplicate content",
|
||||
"name": existing_attrs["name"],
|
||||
"duration": existing_attrs["duration"],
|
||||
"size": existing_attrs["size"],
|
||||
"id": existing_attrs["id"],
|
||||
"error": None,
|
||||
"changes": None,
|
||||
})
|
||||
|
||||
async def _handle_file_rename(
|
||||
self,
|
||||
file_info: AudioFileInfo,
|
||||
existing_attrs: dict,
|
||||
results: ScanResults,
|
||||
) -> None:
|
||||
"""Handle file rename (same hash, different filename)."""
|
||||
update_data = {
|
||||
"filename": file_info.filename,
|
||||
"name": file_info.name,
|
||||
}
|
||||
|
||||
# If the sound has a normalized file, rename it too
|
||||
if existing_attrs["is_normalized"] and existing_attrs["normalized_filename"]:
|
||||
old_normalized_base = Path(existing_attrs["normalized_filename"]).name
|
||||
new_normalized_base = (
|
||||
Path(file_info.filename).stem
|
||||
+ Path(existing_attrs["normalized_filename"]).suffix
|
||||
)
|
||||
|
||||
renamed = self._rename_normalized_file(
|
||||
existing_attrs["type"],
|
||||
old_normalized_base,
|
||||
new_normalized_base,
|
||||
)
|
||||
|
||||
if renamed:
|
||||
update_data["normalized_filename"] = new_normalized_base
|
||||
logger.info(
|
||||
"Renamed normalized file: %s -> %s",
|
||||
old_normalized_base,
|
||||
new_normalized_base,
|
||||
)
|
||||
|
||||
await self.sound_repo.update(existing_attrs["object"], update_data)
|
||||
logger.info(
|
||||
"Detected rename: %s -> %s (ID: %s)",
|
||||
existing_attrs["filename"],
|
||||
file_info.filename,
|
||||
existing_attrs["id"],
|
||||
)
|
||||
|
||||
# Build changes list
|
||||
changes = ["filename", "name"]
|
||||
if "normalized_filename" in update_data:
|
||||
changes.append("normalized_filename")
|
||||
|
||||
results["updated"] += 1
|
||||
results["files"].append({
|
||||
"filename": file_info.filename,
|
||||
"status": "updated",
|
||||
"reason": "file was renamed",
|
||||
"name": file_info.name,
|
||||
"duration": existing_attrs["duration"],
|
||||
"size": existing_attrs["size"],
|
||||
"id": existing_attrs["id"],
|
||||
"error": None,
|
||||
"changes": changes,
|
||||
# Store old filename to prevent deletion
|
||||
"old_filename": existing_attrs["filename"],
|
||||
})
|
||||
|
||||
async def _handle_file_modification(
|
||||
self,
|
||||
file_info: AudioFileInfo,
|
||||
existing_attrs: dict,
|
||||
results: ScanResults,
|
||||
) -> None:
|
||||
"""Handle file modification (same filename, different hash)."""
|
||||
update_data = {
|
||||
"name": file_info.name,
|
||||
"duration": file_info.duration,
|
||||
"size": file_info.size,
|
||||
"hash": file_info.file_hash,
|
||||
}
|
||||
|
||||
await self.sound_repo.update(existing_attrs["object"], update_data)
|
||||
logger.info(
|
||||
"Updated modified sound: %s (ID: %s)",
|
||||
file_info.name,
|
||||
existing_attrs["id"],
|
||||
)
|
||||
|
||||
results["updated"] += 1
|
||||
results["files"].append({
|
||||
"filename": file_info.filename,
|
||||
"status": "updated",
|
||||
"reason": "file was modified",
|
||||
"name": file_info.name,
|
||||
"duration": file_info.duration,
|
||||
"size": file_info.size,
|
||||
"id": existing_attrs["id"],
|
||||
"error": None,
|
||||
"changes": ["hash", "duration", "size", "name"],
|
||||
})
|
||||
|
||||
async def _handle_new_file(
|
||||
self,
|
||||
file_info: AudioFileInfo,
|
||||
sound_type: str,
|
||||
results: ScanResults,
|
||||
) -> None:
|
||||
"""Handle new file (neither hash nor filename exists)."""
|
||||
sound_data = {
|
||||
"type": sound_type,
|
||||
"name": file_info.name,
|
||||
"filename": file_info.filename,
|
||||
"duration": file_info.duration,
|
||||
"size": file_info.size,
|
||||
"hash": file_info.file_hash,
|
||||
"is_deletable": False,
|
||||
"is_music": False,
|
||||
"is_normalized": False,
|
||||
"play_count": 0,
|
||||
}
|
||||
|
||||
sound = await self.sound_repo.create(sound_data)
|
||||
logger.info("Added new sound: %s (ID: %s)", sound.name, sound.id)
|
||||
|
||||
results["added"] += 1
|
||||
results["files"].append({
|
||||
"filename": file_info.filename,
|
||||
"status": "added",
|
||||
"reason": None,
|
||||
"name": file_info.name,
|
||||
"duration": file_info.duration,
|
||||
"size": file_info.size,
|
||||
"id": sound.id,
|
||||
"error": None,
|
||||
"changes": None,
|
||||
})
|
||||
|
||||
async def _load_existing_sounds(self, sound_type: str) -> tuple[dict, dict]:
|
||||
"""Load existing sounds and create lookup dictionaries."""
|
||||
existing_sounds = await self.sound_repo.get_by_type(sound_type)
|
||||
|
||||
# Create lookup dictionaries with immediate attribute access
|
||||
# to avoid session detachment
|
||||
sounds_by_hash = {}
|
||||
sounds_by_filename = {}
|
||||
|
||||
for sound in existing_sounds:
|
||||
# Capture all attributes immediately while session is valid
|
||||
sound_data = {
|
||||
"id": sound.id,
|
||||
"hash": sound.hash,
|
||||
"filename": sound.filename,
|
||||
"name": sound.name,
|
||||
"duration": sound.duration,
|
||||
"size": sound.size,
|
||||
"type": sound.type,
|
||||
"is_normalized": sound.is_normalized,
|
||||
"normalized_filename": sound.normalized_filename,
|
||||
"sound_object": sound, # Keep reference for database operations
|
||||
}
|
||||
sounds_by_hash[sound.hash] = sound_data
|
||||
sounds_by_filename[sound.filename] = sound_data
|
||||
|
||||
return sounds_by_hash, sounds_by_filename
|
||||
|
||||
async def _process_audio_files(
|
||||
self,
|
||||
scan_path: Path,
|
||||
sound_type: str,
|
||||
sounds_by_hash: dict,
|
||||
sounds_by_filename: dict,
|
||||
results: ScanResults,
|
||||
) -> set[str]:
|
||||
"""Process all audio files in directory and return processed filenames."""
|
||||
# Get all audio files from directory
|
||||
audio_files = [
|
||||
f
|
||||
for f in scan_path.iterdir()
|
||||
if f.is_file() and f.suffix.lower() in self.supported_extensions
|
||||
]
|
||||
|
||||
# Process each file in directory
|
||||
processed_filenames = set()
|
||||
for file_path in audio_files:
|
||||
results["scanned"] += 1
|
||||
filename = file_path.name
|
||||
processed_filenames.add(filename)
|
||||
|
||||
try:
|
||||
# Calculate hash first to enable hash-based lookup
|
||||
file_hash = get_file_hash(file_path)
|
||||
existing_sound_by_hash = sounds_by_hash.get(file_hash)
|
||||
existing_sound_by_filename = sounds_by_filename.get(filename)
|
||||
|
||||
# Create sync context
|
||||
sync_context = SyncContext(
|
||||
file_path=file_path,
|
||||
sound_type=sound_type,
|
||||
existing_sound_by_hash=existing_sound_by_hash,
|
||||
existing_sound_by_filename=existing_sound_by_filename,
|
||||
file_hash=file_hash,
|
||||
)
|
||||
|
||||
await self._sync_audio_file(sync_context, results)
|
||||
|
||||
# Check if this was a rename and mark old filename as processed
|
||||
if results["files"] and results["files"][-1].get("old_filename"):
|
||||
old_filename = results["files"][-1]["old_filename"]
|
||||
processed_filenames.add(old_filename)
|
||||
logger.debug("Marked old filename as processed: %s", old_filename)
|
||||
# Remove temporary tracking field from results
|
||||
del results["files"][-1]["old_filename"]
|
||||
except Exception as e:
|
||||
logger.exception("Error processing file %s", file_path)
|
||||
results["errors"] += 1
|
||||
results["files"].append({
|
||||
"filename": filename,
|
||||
"status": "error",
|
||||
"reason": None,
|
||||
"name": None,
|
||||
"duration": None,
|
||||
"size": None,
|
||||
"id": None,
|
||||
"error": str(e),
|
||||
"changes": None,
|
||||
})
|
||||
|
||||
return processed_filenames
|
||||
|
||||
async def _delete_missing_sounds(
|
||||
self,
|
||||
sounds_by_filename: dict,
|
||||
processed_filenames: set[str],
|
||||
results: ScanResults,
|
||||
) -> None:
|
||||
"""Delete sounds that no longer exist in directory."""
|
||||
for filename, sound_data in sounds_by_filename.items():
|
||||
if filename not in processed_filenames:
|
||||
# Attributes already captured in sound_data dictionary
|
||||
sound_name = sound_data["name"]
|
||||
sound_duration = sound_data["duration"]
|
||||
sound_size = sound_data["size"]
|
||||
sound_id = sound_data["id"]
|
||||
sound_object = sound_data["sound_object"]
|
||||
sound_type = sound_data["type"]
|
||||
sound_is_normalized = sound_data["is_normalized"]
|
||||
sound_normalized_filename = sound_data["normalized_filename"]
|
||||
|
||||
try:
|
||||
# Delete the sound from database first
|
||||
await self.sound_repo.delete(sound_object)
|
||||
logger.info("Deleted sound no longer in directory: %s", filename)
|
||||
|
||||
# If the sound had a normalized file, delete it too
|
||||
if sound_is_normalized and sound_normalized_filename:
|
||||
normalized_base = Path(sound_normalized_filename).name
|
||||
self._delete_normalized_file(sound_type, normalized_base)
|
||||
|
||||
results["deleted"] += 1
|
||||
results["files"].append({
|
||||
"filename": filename,
|
||||
"status": "deleted",
|
||||
"reason": "file no longer exists",
|
||||
"name": sound_name,
|
||||
"duration": sound_duration,
|
||||
"size": sound_size,
|
||||
"id": sound_id,
|
||||
"error": None,
|
||||
"changes": None,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.exception("Error deleting sound %s", filename)
|
||||
results["errors"] += 1
|
||||
results["files"].append({
|
||||
"filename": filename,
|
||||
"status": "error",
|
||||
"reason": "failed to delete",
|
||||
"name": sound_name,
|
||||
"duration": sound_duration,
|
||||
"size": sound_size,
|
||||
"id": sound_id,
|
||||
"error": str(e),
|
||||
"changes": None,
|
||||
})
|
||||
|
||||
async def scan_directory(
|
||||
self,
|
||||
directory_path: str,
|
||||
@@ -87,185 +527,91 @@ class SoundScannerService:
|
||||
"updated": 0,
|
||||
"deleted": 0,
|
||||
"skipped": 0,
|
||||
"duplicates": 0,
|
||||
"errors": 0,
|
||||
"files": [],
|
||||
}
|
||||
|
||||
logger.info("Starting sync of directory: %s", directory_path)
|
||||
|
||||
# Get all existing sounds of this type from database
|
||||
existing_sounds = await self.sound_repo.get_by_type(sound_type)
|
||||
sounds_by_filename = {sound.filename: sound for sound in existing_sounds}
|
||||
# Load existing sounds from database
|
||||
sounds_by_hash, sounds_by_filename = await self._load_existing_sounds(
|
||||
sound_type,
|
||||
)
|
||||
|
||||
# Get all audio files from directory
|
||||
audio_files = [
|
||||
f
|
||||
for f in scan_path.iterdir()
|
||||
if f.is_file() and f.suffix.lower() in self.supported_extensions
|
||||
]
|
||||
|
||||
# Process each file in directory
|
||||
processed_filenames = set()
|
||||
for file_path in audio_files:
|
||||
results["scanned"] += 1
|
||||
filename = file_path.name
|
||||
processed_filenames.add(filename)
|
||||
|
||||
try:
|
||||
await self._sync_audio_file(
|
||||
file_path,
|
||||
sound_type,
|
||||
sounds_by_filename.get(filename),
|
||||
results,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error processing file %s", file_path)
|
||||
results["errors"] += 1
|
||||
results["files"].append(
|
||||
{
|
||||
"filename": filename,
|
||||
"status": "error",
|
||||
"reason": None,
|
||||
"name": None,
|
||||
"duration": None,
|
||||
"size": None,
|
||||
"id": None,
|
||||
"error": str(e),
|
||||
"changes": None,
|
||||
}
|
||||
)
|
||||
# Process audio files in directory
|
||||
processed_filenames = await self._process_audio_files(
|
||||
scan_path,
|
||||
sound_type,
|
||||
sounds_by_hash,
|
||||
sounds_by_filename,
|
||||
results,
|
||||
)
|
||||
|
||||
# Delete sounds that no longer exist in directory
|
||||
for filename, sound in sounds_by_filename.items():
|
||||
if filename not in processed_filenames:
|
||||
try:
|
||||
await self.sound_repo.delete(sound)
|
||||
logger.info("Deleted sound no longer in directory: %s", filename)
|
||||
results["deleted"] += 1
|
||||
results["files"].append(
|
||||
{
|
||||
"filename": filename,
|
||||
"status": "deleted",
|
||||
"reason": "file no longer exists",
|
||||
"name": sound.name,
|
||||
"duration": sound.duration,
|
||||
"size": sound.size,
|
||||
"id": sound.id,
|
||||
"error": None,
|
||||
"changes": None,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error deleting sound %s", filename)
|
||||
results["errors"] += 1
|
||||
results["files"].append(
|
||||
{
|
||||
"filename": filename,
|
||||
"status": "error",
|
||||
"reason": "failed to delete",
|
||||
"name": sound.name,
|
||||
"duration": sound.duration,
|
||||
"size": sound.size,
|
||||
"id": sound.id,
|
||||
"error": str(e),
|
||||
"changes": None,
|
||||
}
|
||||
)
|
||||
await self._delete_missing_sounds(
|
||||
sounds_by_filename,
|
||||
processed_filenames,
|
||||
results,
|
||||
)
|
||||
|
||||
logger.info("Sync completed: %s", results)
|
||||
return results
|
||||
|
||||
async def _sync_audio_file(
|
||||
self,
|
||||
file_path: Path,
|
||||
sound_type: str,
|
||||
existing_sound: Sound | None,
|
||||
sync_context: SyncContext,
|
||||
results: ScanResults,
|
||||
) -> None:
|
||||
"""Sync a single audio file (add new or update existing)."""
|
||||
filename = file_path.name
|
||||
file_hash = get_file_hash(file_path)
|
||||
duration = get_audio_duration(file_path)
|
||||
size = get_file_size(file_path)
|
||||
"""Sync a single audio file using hash-first identification strategy."""
|
||||
filename = sync_context.file_path.name
|
||||
duration = get_audio_duration(sync_context.file_path)
|
||||
size = get_file_size(sync_context.file_path)
|
||||
name = self.extract_name_from_filename(filename)
|
||||
|
||||
if existing_sound is None:
|
||||
# Add new sound
|
||||
sound_data = {
|
||||
"type": sound_type,
|
||||
"name": name,
|
||||
"filename": filename,
|
||||
"duration": duration,
|
||||
"size": size,
|
||||
"hash": file_hash,
|
||||
"is_deletable": False,
|
||||
"is_music": False,
|
||||
"is_normalized": False,
|
||||
"play_count": 0,
|
||||
}
|
||||
# Create file info object
|
||||
file_info = AudioFileInfo(
|
||||
filename=filename,
|
||||
name=name,
|
||||
duration=duration,
|
||||
size=size,
|
||||
file_hash=sync_context.file_hash,
|
||||
)
|
||||
|
||||
sound = await self.sound_repo.create(sound_data)
|
||||
logger.info("Added new sound: %s (ID: %s)", sound.name, sound.id)
|
||||
# Extract attributes from existing sounds
|
||||
hash_attrs = self._extract_sound_attributes(sync_context.existing_sound_by_hash)
|
||||
filename_attrs = self._extract_sound_attributes(
|
||||
sync_context.existing_sound_by_filename,
|
||||
)
|
||||
|
||||
results["added"] += 1
|
||||
results["files"].append(
|
||||
{
|
||||
"filename": filename,
|
||||
"status": "added",
|
||||
"reason": None,
|
||||
"name": name,
|
||||
"duration": duration,
|
||||
"size": size,
|
||||
"id": sound.id,
|
||||
"error": None,
|
||||
"changes": None,
|
||||
}
|
||||
)
|
||||
|
||||
elif existing_sound.hash != file_hash:
|
||||
# Update existing sound (file was modified)
|
||||
update_data = {
|
||||
"name": name,
|
||||
"duration": duration,
|
||||
"size": size,
|
||||
"hash": file_hash,
|
||||
}
|
||||
|
||||
await self.sound_repo.update(existing_sound, update_data)
|
||||
logger.info("Updated modified sound: %s (ID: %s)", name, existing_sound.id)
|
||||
|
||||
results["updated"] += 1
|
||||
results["files"].append(
|
||||
{
|
||||
"filename": filename,
|
||||
"status": "updated",
|
||||
"reason": "file was modified",
|
||||
"name": name,
|
||||
"duration": duration,
|
||||
"size": size,
|
||||
"id": existing_sound.id,
|
||||
"error": None,
|
||||
"changes": ["hash", "duration", "size", "name"],
|
||||
}
|
||||
)
|
||||
# Hash-first identification strategy
|
||||
if sync_context.existing_sound_by_hash is not None:
|
||||
# Content exists in database (same hash)
|
||||
if hash_attrs["filename"] == filename:
|
||||
# Same hash, same filename - file unchanged
|
||||
self._handle_unchanged_file(filename, hash_attrs, results)
|
||||
else:
|
||||
# Same hash, different filename - could be rename or duplicate
|
||||
old_file_path = sync_context.file_path.parent / hash_attrs["filename"]
|
||||
if old_file_path.exists():
|
||||
# Both files exist with same hash - this is a duplicate
|
||||
self._handle_duplicate_file(
|
||||
filename,
|
||||
hash_attrs["filename"],
|
||||
sync_context.file_hash,
|
||||
hash_attrs,
|
||||
results,
|
||||
)
|
||||
else:
|
||||
# Old file doesn't exist - this is a genuine rename
|
||||
await self._handle_file_rename(file_info, hash_attrs, results)
|
||||
|
||||
elif sync_context.existing_sound_by_filename is not None:
|
||||
# Same filename but different hash - file was modified
|
||||
await self._handle_file_modification(file_info, filename_attrs, results)
|
||||
else:
|
||||
# File unchanged, skip
|
||||
logger.debug("Sound unchanged: %s", filename)
|
||||
results["skipped"] += 1
|
||||
results["files"].append(
|
||||
{
|
||||
"filename": filename,
|
||||
"status": "skipped",
|
||||
"reason": "file unchanged",
|
||||
"name": existing_sound.name,
|
||||
"duration": existing_sound.duration,
|
||||
"size": existing_sound.size,
|
||||
"id": existing_sound.id,
|
||||
"error": None,
|
||||
"changes": None,
|
||||
}
|
||||
)
|
||||
# New file - neither hash nor filename exists
|
||||
await self._handle_new_file(file_info, sync_context.sound_type, results)
|
||||
|
||||
async def scan_soundboard_directory(self) -> ScanResults:
|
||||
"""Sync the default soundboard directory."""
|
||||
|
||||
194
app/services/task_handlers.py
Normal file
194
app/services/task_handlers.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Task execution handlers for different task types."""
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.scheduled_task import ScheduledTask, TaskType
|
||||
from app.repositories.playlist import PlaylistRepository
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.services.credit import CreditService
|
||||
from app.services.player import PlayerService
|
||||
from app.services.vlc_player import VLCPlayerService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TaskExecutionError(Exception):
|
||||
"""Exception raised when task execution fails."""
|
||||
|
||||
|
||||
|
||||
class TaskHandlerRegistry:
|
||||
"""Registry for task execution handlers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
db_session_factory: Callable[[], AsyncSession],
|
||||
credit_service: CreditService,
|
||||
player_service: PlayerService,
|
||||
) -> None:
|
||||
"""Initialize the task handler registry."""
|
||||
self.db_session = db_session
|
||||
self.db_session_factory = db_session_factory
|
||||
self.credit_service = credit_service
|
||||
self.player_service = player_service
|
||||
self.sound_repository = SoundRepository(db_session)
|
||||
self.playlist_repository = PlaylistRepository(db_session)
|
||||
|
||||
# Register handlers
|
||||
self._handlers = {
|
||||
TaskType.CREDIT_RECHARGE: self._handle_credit_recharge,
|
||||
TaskType.PLAY_SOUND: self._handle_play_sound,
|
||||
TaskType.PLAY_PLAYLIST: self._handle_play_playlist,
|
||||
}
|
||||
|
||||
async def execute_task(self, task: ScheduledTask) -> None:
|
||||
"""Execute a task based on its type."""
|
||||
handler = self._handlers.get(task.task_type)
|
||||
if not handler:
|
||||
msg = f"No handler registered for task type: {task.task_type}"
|
||||
raise TaskExecutionError(msg)
|
||||
|
||||
logger.info(
|
||||
"Executing task %s (%s): %s",
|
||||
task.id,
|
||||
task.task_type.value,
|
||||
task.name,
|
||||
)
|
||||
|
||||
try:
|
||||
await handler(task)
|
||||
logger.info("Task %s executed successfully", task.id)
|
||||
except Exception as e:
|
||||
logger.exception("Task %s execution failed", task.id)
|
||||
msg = f"Task execution failed: {e!s}"
|
||||
raise TaskExecutionError(msg) from e
|
||||
|
||||
async def _handle_credit_recharge(self, task: ScheduledTask) -> None:
|
||||
"""Handle credit recharge task."""
|
||||
parameters = task.parameters
|
||||
user_id = parameters.get("user_id")
|
||||
|
||||
if user_id:
|
||||
# Recharge specific user
|
||||
try:
|
||||
user_id_int = int(user_id)
|
||||
except (ValueError, TypeError) as e:
|
||||
msg = f"Invalid user_id format: {user_id}"
|
||||
raise TaskExecutionError(msg) from e
|
||||
|
||||
transaction = await self.credit_service.recharge_user_credits_auto(
|
||||
user_id_int,
|
||||
)
|
||||
if transaction:
|
||||
logger.info(
|
||||
"Recharged credits for user %s: %s credits added",
|
||||
user_id,
|
||||
transaction.amount,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"No credits added for user %s (already at maximum)", user_id,
|
||||
)
|
||||
else:
|
||||
# Recharge all users (system task)
|
||||
stats = await self.credit_service.recharge_all_users_credits()
|
||||
logger.info("Recharged credits for all users: %s", stats)
|
||||
|
||||
async def _handle_play_sound(self, task: ScheduledTask) -> None:
|
||||
"""Handle play sound task."""
|
||||
parameters = task.parameters
|
||||
sound_id = parameters.get("sound_id")
|
||||
|
||||
if not sound_id:
|
||||
msg = "sound_id parameter is required for PLAY_SOUND tasks"
|
||||
raise TaskExecutionError(msg)
|
||||
|
||||
try:
|
||||
# Handle both integer and string sound IDs
|
||||
sound_id_int = int(sound_id)
|
||||
except (ValueError, TypeError) as e:
|
||||
msg = f"Invalid sound_id format: {sound_id}"
|
||||
raise TaskExecutionError(msg) from e
|
||||
|
||||
# Check if this is a user task (has user_id)
|
||||
if task.user_id:
|
||||
# User task: use credit-aware playback
|
||||
|
||||
vlc_service = VLCPlayerService(self.db_session_factory)
|
||||
try:
|
||||
result = await vlc_service.play_sound_with_credits(
|
||||
sound_id_int, task.user_id,
|
||||
)
|
||||
logger.info(
|
||||
(
|
||||
"Played sound %s via scheduled task for user %s "
|
||||
"(credits deducted: %s)"
|
||||
),
|
||||
result.get("sound_name", sound_id),
|
||||
task.user_id,
|
||||
result.get("credits_deducted", 0),
|
||||
)
|
||||
except Exception as e:
|
||||
# Convert HTTP exceptions or credit errors to task execution errors
|
||||
msg = f"Failed to play sound with credits: {e!s}"
|
||||
raise TaskExecutionError(msg) from e
|
||||
else:
|
||||
# System task: play without credit deduction
|
||||
sound = await self.sound_repository.get_by_id(sound_id_int)
|
||||
if not sound:
|
||||
msg = f"Sound not found: {sound_id}"
|
||||
raise TaskExecutionError(msg)
|
||||
|
||||
|
||||
vlc_service = VLCPlayerService(self.db_session_factory)
|
||||
success = await vlc_service.play_sound(sound)
|
||||
|
||||
if not success:
|
||||
msg = f"Failed to play sound {sound.filename}"
|
||||
raise TaskExecutionError(msg)
|
||||
|
||||
logger.info("Played sound %s via scheduled system task", sound.filename)
|
||||
|
||||
async def _handle_play_playlist(self, task: ScheduledTask) -> None:
|
||||
"""Handle play playlist task."""
|
||||
parameters = task.parameters
|
||||
playlist_id = parameters.get("playlist_id")
|
||||
play_mode = parameters.get("play_mode", "continuous")
|
||||
shuffle = parameters.get("shuffle", False)
|
||||
|
||||
if not playlist_id:
|
||||
msg = "playlist_id parameter is required for PLAY_PLAYLIST tasks"
|
||||
raise TaskExecutionError(msg)
|
||||
|
||||
try:
|
||||
# Handle both integer and string playlist IDs
|
||||
playlist_id_int = int(playlist_id)
|
||||
except (ValueError, TypeError) as e:
|
||||
msg = f"Invalid playlist_id format: {playlist_id}"
|
||||
raise TaskExecutionError(msg) from e
|
||||
|
||||
# Get the playlist from database
|
||||
playlist = await self.playlist_repository.get_by_id(playlist_id_int)
|
||||
if not playlist:
|
||||
msg = f"Playlist not found: {playlist_id}"
|
||||
raise TaskExecutionError(msg)
|
||||
|
||||
# Load playlist in player
|
||||
await self.player_service.load_playlist(playlist_id_int)
|
||||
|
||||
# Set play mode if specified
|
||||
if play_mode in ["continuous", "loop", "loop_one", "random", "single"]:
|
||||
await self.player_service.set_mode(play_mode)
|
||||
|
||||
# Enable shuffle if requested
|
||||
if shuffle:
|
||||
await self.player_service.set_shuffle(shuffle=True)
|
||||
|
||||
# Start playing
|
||||
await self.player_service.play()
|
||||
|
||||
logger.info("Started playing playlist %s via scheduled task", playlist.name)
|
||||
6
app/services/tts/__init__.py
Normal file
6
app/services/tts/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Text-to-Speech services package."""
|
||||
|
||||
from .base import TTSProvider
|
||||
from .service import TTSService
|
||||
|
||||
__all__ = ["TTSProvider", "TTSService"]
|
||||
41
app/services/tts/base.py
Normal file
41
app/services/tts/base.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Base TTS provider interface."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
# Type alias for TTS options
|
||||
TTSOptions = dict[str, str | bool | int | float]
|
||||
|
||||
|
||||
class TTSProvider(ABC):
|
||||
"""Abstract base class for TTS providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def generate_speech(self, text: str, **options: str | bool | float) -> bytes:
|
||||
"""Generate speech from text with provider-specific options.
|
||||
|
||||
Args:
|
||||
text: The text to convert to speech
|
||||
**options: Provider-specific options
|
||||
|
||||
Returns:
|
||||
Audio data as bytes
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_languages(self) -> list[str]:
|
||||
"""Return list of supported language codes."""
|
||||
|
||||
@abstractmethod
|
||||
def get_option_schema(self) -> dict[str, dict[str, str | list[str] | bool]]:
|
||||
"""Return schema for provider-specific options."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def file_extension(self) -> str:
|
||||
"""Return the default file extension for this provider."""
|
||||
5
app/services/tts/providers/__init__.py
Normal file
5
app/services/tts/providers/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""TTS providers package."""
|
||||
|
||||
from .gtts import GTTSProvider
|
||||
|
||||
__all__ = ["GTTSProvider"]
|
||||
80
app/services/tts/providers/gtts.py
Normal file
80
app/services/tts/providers/gtts.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Google Text-to-Speech provider."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
|
||||
from gtts import gTTS
|
||||
|
||||
from app.services.tts.base import TTSProvider
|
||||
|
||||
|
||||
class GTTSProvider(TTSProvider):
|
||||
"""Google Text-to-Speech provider implementation."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
return "gtts"
|
||||
|
||||
@property
|
||||
def file_extension(self) -> str:
|
||||
"""Return the default file extension for this provider."""
|
||||
return "mp3"
|
||||
|
||||
async def generate_speech(self, text: str, **options: str | bool | float) -> bytes:
|
||||
"""Generate speech from text using Google TTS.
|
||||
|
||||
Args:
|
||||
text: The text to convert to speech
|
||||
**options: GTTS-specific options (lang, tld, slow)
|
||||
|
||||
Returns:
|
||||
MP3 audio data as bytes
|
||||
|
||||
"""
|
||||
lang = options.get("lang", "en")
|
||||
tld = options.get("tld", "com")
|
||||
slow = options.get("slow", False)
|
||||
|
||||
# Run TTS generation in thread pool since gTTS is synchronous
|
||||
def _generate() -> bytes:
|
||||
tts = gTTS(text=text, lang=lang, tld=tld, slow=slow)
|
||||
fp = io.BytesIO()
|
||||
tts.write_to_fp(fp)
|
||||
fp.seek(0)
|
||||
return fp.read()
|
||||
|
||||
# Use asyncio.to_thread which is more reliable than run_in_executor
|
||||
return await asyncio.to_thread(_generate)
|
||||
|
||||
def get_supported_languages(self) -> list[str]:
|
||||
"""Return list of supported language codes."""
|
||||
# Complete list of GTTS supported languages including regional variants
|
||||
return [
|
||||
"af", "ar", "bg", "bn", "bs", "ca", "cs", "cy", "da", "de", "el",
|
||||
"en", "en-au", "en-ca", "en-gb", "en-ie", "en-in", "en-ng", "en-nz",
|
||||
"en-ph", "en-za", "en-tz", "en-uk", "en-us",
|
||||
"eo", "es", "es-es", "es-mx", "es-us", "et", "eu", "fa", "fi",
|
||||
"fr", "fr-ca", "fr-fr", "ga", "gu", "he", "hi", "hr", "hu", "hy",
|
||||
"id", "is", "it", "ja", "jw", "ka", "kk", "km", "kn", "ko", "la",
|
||||
"lv", "mk", "ml", "mr", "ms", "mt", "my", "ne", "nl", "no", "pa",
|
||||
"pl", "pt", "pt-br", "pt-pt", "ro", "ru", "si", "sk", "sl", "sq",
|
||||
"sr", "su", "sv", "sw", "ta", "te", "th", "tl", "tr", "uk", "ur",
|
||||
"vi", "yo", "zh", "zh-cn", "zh-tw", "zu",
|
||||
]
|
||||
|
||||
def get_option_schema(self) -> dict[str, dict[str, str | list[str] | bool]]:
|
||||
"""Return schema for GTTS-specific options."""
|
||||
return {
|
||||
"lang": {
|
||||
"type": "string",
|
||||
"default": "en",
|
||||
"description": "Language code",
|
||||
"enum": self.get_supported_languages(),
|
||||
},
|
||||
"slow": {
|
||||
"type": "boolean",
|
||||
"default": False,
|
||||
"description": "Speak slowly",
|
||||
},
|
||||
}
|
||||
555
app/services/tts/service.py
Normal file
555
app/services/tts/service.py
Normal file
@@ -0,0 +1,555 @@
|
||||
"""TTS service implementation."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from gtts import gTTS
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_session_factory
|
||||
from app.core.logging import get_logger
|
||||
from app.models.sound import Sound
|
||||
from app.models.tts import TTS
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.repositories.tts import TTSRepository
|
||||
from app.services.socket import socket_manager
|
||||
from app.services.sound_normalizer import SoundNormalizerService
|
||||
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
|
||||
|
||||
from .base import TTSProvider
|
||||
from .providers import GTTSProvider
|
||||
|
||||
# Constants
|
||||
MAX_TEXT_LENGTH = 1000
|
||||
MAX_NAME_LENGTH = 50
|
||||
|
||||
|
||||
async def _get_tts_processor() -> object:
|
||||
"""Get TTS processor instance, avoiding circular import."""
|
||||
from app.services.tts_processor import tts_processor # noqa: PLC0415
|
||||
return tts_processor
|
||||
|
||||
|
||||
class TTSService:
|
||||
"""Text-to-Speech service with provider management."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize TTS service.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
|
||||
"""
|
||||
self.session = session
|
||||
self.sound_repo = SoundRepository(session)
|
||||
self.tts_repo = TTSRepository(session)
|
||||
self.providers: dict[str, TTSProvider] = {}
|
||||
|
||||
# Register default providers
|
||||
self._register_default_providers()
|
||||
|
||||
def _register_default_providers(self) -> None:
|
||||
"""Register default TTS providers."""
|
||||
self.register_provider(GTTSProvider())
|
||||
|
||||
def register_provider(self, provider: TTSProvider) -> None:
|
||||
"""Register a TTS provider.
|
||||
|
||||
Args:
|
||||
provider: TTS provider instance
|
||||
|
||||
"""
|
||||
self.providers[provider.name] = provider
|
||||
|
||||
def get_providers(self) -> dict[str, TTSProvider]:
|
||||
"""Get all registered providers."""
|
||||
return self.providers.copy()
|
||||
|
||||
def get_provider(self, name: str) -> TTSProvider | None:
|
||||
"""Get a specific provider by name."""
|
||||
return self.providers.get(name)
|
||||
|
||||
async def create_tts_request(
|
||||
self,
|
||||
text: str,
|
||||
user_id: int,
|
||||
provider: str = "gtts",
|
||||
**options: str | bool | float,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a TTS request that will be processed in the background.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech
|
||||
user_id: ID of user creating the sound
|
||||
provider: TTS provider name
|
||||
**options: Provider-specific options
|
||||
|
||||
Returns:
|
||||
Dictionary with TTS record information
|
||||
|
||||
Raises:
|
||||
ValueError: If provider not found or text too long
|
||||
Exception: If request creation fails
|
||||
|
||||
"""
|
||||
provider_not_found_msg = f"Provider '{provider}' not found"
|
||||
if provider not in self.providers:
|
||||
raise ValueError(provider_not_found_msg)
|
||||
|
||||
text_too_long_msg = f"Text too long (max {MAX_TEXT_LENGTH} characters)"
|
||||
if len(text) > MAX_TEXT_LENGTH:
|
||||
raise ValueError(text_too_long_msg)
|
||||
|
||||
empty_text_msg = "Text cannot be empty"
|
||||
if not text.strip():
|
||||
raise ValueError(empty_text_msg)
|
||||
|
||||
# Create TTS record with pending status
|
||||
tts = TTS(
|
||||
text=text,
|
||||
provider=provider,
|
||||
options=options,
|
||||
status="pending",
|
||||
sound_id=None, # Will be set when processing completes
|
||||
user_id=user_id,
|
||||
)
|
||||
self.session.add(tts)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(tts)
|
||||
|
||||
# Queue for background processing using the TTS processor
|
||||
if tts.id is not None:
|
||||
tts_processor = await _get_tts_processor()
|
||||
await tts_processor.queue_tts(tts.id)
|
||||
|
||||
return {"tts": tts, "message": "TTS generation queued successfully"}
|
||||
|
||||
async def _queue_tts_processing(self, tts_id: int) -> None:
|
||||
"""Queue TTS for background processing."""
|
||||
# For now, process immediately in a different way
|
||||
# This could be moved to a proper background queue later
|
||||
task = asyncio.create_task(self._process_tts_in_background(tts_id))
|
||||
# Store reference to prevent garbage collection
|
||||
self._background_tasks = getattr(self, "_background_tasks", set())
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
async def _process_tts_in_background(self, tts_id: int) -> None:
|
||||
"""Process TTS generation in background."""
|
||||
try:
|
||||
# Create a new session for background processing
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as background_session:
|
||||
tts_service = TTSService(background_session)
|
||||
|
||||
# Get the TTS record
|
||||
stmt = select(TTS).where(TTS.id == tts_id)
|
||||
result = await background_session.exec(stmt)
|
||||
tts = result.first()
|
||||
|
||||
if not tts:
|
||||
return
|
||||
|
||||
# Use a synchronous approach for the actual generation
|
||||
sound = await tts_service._generate_tts_sync(
|
||||
tts.text,
|
||||
tts.provider,
|
||||
tts.user_id,
|
||||
tts.options,
|
||||
)
|
||||
|
||||
# Update the TTS record with the sound ID
|
||||
if sound.id is not None:
|
||||
tts.sound_id = sound.id
|
||||
background_session.add(tts)
|
||||
await background_session.commit()
|
||||
|
||||
except Exception:
|
||||
# Log error but don't fail - avoiding print for production
|
||||
logger = get_logger(__name__)
|
||||
logger.exception("Error processing TTS generation %s", tts_id)
|
||||
|
||||
async def _generate_tts_sync(
|
||||
self,
|
||||
text: str,
|
||||
provider: str,
|
||||
user_id: int,
|
||||
options: dict[str, Any],
|
||||
) -> Sound:
|
||||
"""Generate TTS using a synchronous approach."""
|
||||
# Generate the audio using the provider
|
||||
# (avoid async issues by doing it directly)
|
||||
tts_provider = self.providers[provider]
|
||||
|
||||
# Create directories if they don't exist
|
||||
original_dir = Path("sounds/originals/text_to_speech")
|
||||
original_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create UUID filename
|
||||
sound_uuid = str(uuid.uuid4())
|
||||
original_filename = f"{sound_uuid}.{tts_provider.file_extension}"
|
||||
original_path = original_dir / original_filename
|
||||
|
||||
# Generate audio synchronously
|
||||
try:
|
||||
# Generate TTS audio
|
||||
lang = options.get("lang", "en")
|
||||
tld = options.get("tld", "com")
|
||||
slow = options.get("slow", False)
|
||||
|
||||
tts_instance = gTTS(text=text, lang=lang, tld=tld, slow=slow)
|
||||
fp = io.BytesIO()
|
||||
tts_instance.write_to_fp(fp)
|
||||
fp.seek(0)
|
||||
audio_bytes = fp.read()
|
||||
|
||||
# Save the file
|
||||
original_path.write_bytes(audio_bytes)
|
||||
|
||||
except Exception:
|
||||
logger = get_logger(__name__)
|
||||
logger.exception("Error generating TTS audio")
|
||||
raise
|
||||
|
||||
# Create Sound record with proper metadata
|
||||
sound = await self._create_sound_record_complete(
|
||||
original_path,
|
||||
text,
|
||||
user_id,
|
||||
)
|
||||
|
||||
# Normalize the sound
|
||||
if sound.id is not None:
|
||||
await self._normalize_sound_safe(sound.id)
|
||||
|
||||
return sound
|
||||
|
||||
async def get_user_tts_history(
|
||||
self,
|
||||
user_id: int,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[TTS]:
|
||||
"""Get TTS history for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
limit: Maximum number of records
|
||||
offset: Offset for pagination
|
||||
|
||||
Returns:
|
||||
List of TTS records
|
||||
|
||||
"""
|
||||
result = await self.tts_repo.get_by_user_id(user_id, limit, offset)
|
||||
return list(result)
|
||||
|
||||
async def _create_sound_record(
|
||||
self,
|
||||
audio_path: Path,
|
||||
text: str,
|
||||
user_id: int,
|
||||
file_hash: str,
|
||||
) -> Sound:
|
||||
"""Create a Sound record for the TTS audio."""
|
||||
# Get audio metadata
|
||||
duration = get_audio_duration(audio_path)
|
||||
size = get_file_size(audio_path)
|
||||
name = text[:MAX_NAME_LENGTH] + ("..." if len(text) > MAX_NAME_LENGTH else "")
|
||||
name = " ".join(word.capitalize() for word in name.split())
|
||||
|
||||
# Create sound data
|
||||
sound_data = {
|
||||
"type": "TTS",
|
||||
"name": name,
|
||||
"filename": audio_path.name,
|
||||
"duration": duration,
|
||||
"size": size,
|
||||
"hash": file_hash,
|
||||
"user_id": user_id,
|
||||
"is_deletable": True,
|
||||
"is_music": False, # TTS is speech, not music
|
||||
"is_normalized": False,
|
||||
"play_count": 0,
|
||||
}
|
||||
|
||||
return await self.sound_repo.create(sound_data)
|
||||
|
||||
async def _create_sound_record_simple(
|
||||
self,
|
||||
audio_path: Path,
|
||||
text: str,
|
||||
user_id: int,
|
||||
) -> Sound:
|
||||
"""Create a Sound record for the TTS audio with minimal processing."""
|
||||
# Create sound data with basic info
|
||||
name = text[:MAX_NAME_LENGTH] + ("..." if len(text) > MAX_NAME_LENGTH else "")
|
||||
name = " ".join(word.capitalize() for word in name.split())
|
||||
|
||||
sound_data = {
|
||||
"type": "TTS",
|
||||
"name": name,
|
||||
"filename": audio_path.name,
|
||||
"duration": 0, # Skip duration calculation for now
|
||||
"size": 0, # Skip size calculation for now
|
||||
"hash": str(uuid.uuid4()), # Use UUID as temporary hash
|
||||
"user_id": user_id,
|
||||
"is_deletable": True,
|
||||
"is_music": False, # TTS is speech, not music
|
||||
"is_normalized": False,
|
||||
"play_count": 0,
|
||||
}
|
||||
|
||||
return await self.sound_repo.create(sound_data)
|
||||
|
||||
async def _create_sound_record_complete(
|
||||
self,
|
||||
audio_path: Path,
|
||||
text: str,
|
||||
user_id: int,
|
||||
) -> Sound:
|
||||
"""Create a Sound record for the TTS audio with complete metadata."""
|
||||
# Get audio metadata
|
||||
duration = get_audio_duration(audio_path)
|
||||
size = get_file_size(audio_path)
|
||||
file_hash = get_file_hash(audio_path)
|
||||
name = text[:MAX_NAME_LENGTH] + ("..." if len(text) > MAX_NAME_LENGTH else "")
|
||||
name = " ".join(word.capitalize() for word in name.split())
|
||||
|
||||
# Check if a sound with this hash already exists
|
||||
existing_sound = await self.sound_repo.get_by_hash(file_hash)
|
||||
|
||||
if existing_sound:
|
||||
# Clean up the temporary file since we have a duplicate
|
||||
if audio_path.exists():
|
||||
audio_path.unlink()
|
||||
return existing_sound
|
||||
|
||||
# Create sound data with complete metadata
|
||||
sound_data = {
|
||||
"type": "TTS",
|
||||
"name": name,
|
||||
"filename": audio_path.name,
|
||||
"duration": duration,
|
||||
"size": size,
|
||||
"hash": file_hash,
|
||||
"user_id": user_id,
|
||||
"is_deletable": True,
|
||||
"is_music": False, # TTS is speech, not music
|
||||
"is_normalized": False,
|
||||
"play_count": 0,
|
||||
}
|
||||
|
||||
return await self.sound_repo.create(sound_data)
|
||||
|
||||
async def _normalize_sound_safe(self, sound_id: int) -> None:
|
||||
"""Normalize the TTS sound with error handling."""
|
||||
try:
|
||||
# Get fresh sound object from database for normalization
|
||||
sound = await self.sound_repo.get_by_id(sound_id)
|
||||
if not sound:
|
||||
return
|
||||
|
||||
normalizer_service = SoundNormalizerService(self.session)
|
||||
result = await normalizer_service.normalize_sound(sound)
|
||||
|
||||
if result["status"] == "error":
|
||||
logger = get_logger(__name__)
|
||||
logger.warning(
|
||||
"Warning: Failed to normalize TTS sound %s: %s",
|
||||
sound_id,
|
||||
result.get("error"),
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger = get_logger(__name__)
|
||||
logger.exception("Exception during TTS sound normalization %s", sound_id)
|
||||
# Don't fail the TTS generation if normalization fails
|
||||
|
||||
async def _normalize_sound(self, sound_id: int) -> None:
|
||||
"""Normalize the TTS sound."""
|
||||
try:
|
||||
# Get fresh sound object from database for normalization
|
||||
sound = await self.sound_repo.get_by_id(sound_id)
|
||||
if not sound:
|
||||
return
|
||||
|
||||
normalizer_service = SoundNormalizerService(self.session)
|
||||
result = await normalizer_service.normalize_sound(sound)
|
||||
|
||||
if result["status"] == "error":
|
||||
# Log warning but don't fail the TTS generation
|
||||
pass
|
||||
|
||||
except Exception:
|
||||
# Don't fail the TTS generation if normalization fails
|
||||
logger = get_logger(__name__)
|
||||
logger.exception("Error normalizing sound %s", sound_id)
|
||||
|
||||
async def delete_tts(self, tts_id: int, user_id: int) -> None:
|
||||
"""Delete a TTS generation and its associated sound and files."""
|
||||
# Get the TTS record
|
||||
tts = await self.tts_repo.get_by_id(tts_id)
|
||||
if not tts:
|
||||
tts_not_found_msg = f"TTS with ID {tts_id} not found"
|
||||
raise ValueError(tts_not_found_msg)
|
||||
|
||||
# Check ownership
|
||||
if tts.user_id != user_id:
|
||||
permission_error_msg = (
|
||||
"You don't have permission to delete this TTS generation"
|
||||
)
|
||||
raise PermissionError(permission_error_msg)
|
||||
|
||||
# If there's an associated sound, delete it and its files
|
||||
if tts.sound_id:
|
||||
sound = await self.sound_repo.get_by_id(tts.sound_id)
|
||||
if sound:
|
||||
# Delete the sound files
|
||||
await self._delete_sound_files(sound)
|
||||
# Delete the sound record
|
||||
await self.sound_repo.delete(sound)
|
||||
|
||||
# Delete the TTS record
|
||||
await self.tts_repo.delete(tts)
|
||||
|
||||
async def _delete_sound_files(self, sound: Sound) -> None:
|
||||
"""Delete all files associated with a sound."""
|
||||
# Delete original file
|
||||
original_path = Path("sounds/originals/text_to_speech") / sound.filename
|
||||
if original_path.exists():
|
||||
original_path.unlink()
|
||||
|
||||
# Delete normalized file if it exists
|
||||
if sound.normalized_filename:
|
||||
normalized_path = (
|
||||
Path("sounds/normalized/text_to_speech") / sound.normalized_filename
|
||||
)
|
||||
if normalized_path.exists():
|
||||
normalized_path.unlink()
|
||||
|
||||
async def get_pending_tts(self) -> list[TTS]:
|
||||
"""Get all pending TTS generations."""
|
||||
stmt = select(TTS).where(TTS.status == "pending").order_by(TTS.created_at)
|
||||
result = await self.session.exec(stmt)
|
||||
return list(result.all())
|
||||
|
||||
async def mark_tts_processing(self, tts_id: int) -> None:
|
||||
"""Mark a TTS generation as processing."""
|
||||
stmt = select(TTS).where(TTS.id == tts_id)
|
||||
result = await self.session.exec(stmt)
|
||||
tts = result.first()
|
||||
if tts:
|
||||
tts.status = "processing"
|
||||
self.session.add(tts)
|
||||
await self.session.commit()
|
||||
|
||||
async def mark_tts_completed(self, tts_id: int, sound_id: int) -> None:
|
||||
"""Mark a TTS generation as completed."""
|
||||
stmt = select(TTS).where(TTS.id == tts_id)
|
||||
result = await self.session.exec(stmt)
|
||||
tts = result.first()
|
||||
if tts:
|
||||
tts.status = "completed"
|
||||
tts.sound_id = sound_id
|
||||
tts.error = None
|
||||
self.session.add(tts)
|
||||
await self.session.commit()
|
||||
|
||||
async def mark_tts_failed(self, tts_id: int, error_message: str) -> None:
|
||||
"""Mark a TTS generation as failed."""
|
||||
stmt = select(TTS).where(TTS.id == tts_id)
|
||||
result = await self.session.exec(stmt)
|
||||
tts = result.first()
|
||||
if tts:
|
||||
tts.status = "failed"
|
||||
tts.error = error_message
|
||||
self.session.add(tts)
|
||||
await self.session.commit()
|
||||
|
||||
async def reset_stuck_tts(self) -> int:
|
||||
"""Reset stuck TTS generations from processing back to pending."""
|
||||
stmt = select(TTS).where(TTS.status == "processing")
|
||||
result = await self.session.exec(stmt)
|
||||
stuck_tts = list(result.all())
|
||||
|
||||
for tts in stuck_tts:
|
||||
tts.status = "pending"
|
||||
tts.error = None
|
||||
self.session.add(tts)
|
||||
|
||||
await self.session.commit()
|
||||
return len(stuck_tts)
|
||||
|
||||
async def process_tts_generation(self, tts_id: int) -> None:
|
||||
"""Process a TTS generation (used by the processor)."""
|
||||
# Mark as processing
|
||||
await self.mark_tts_processing(tts_id)
|
||||
|
||||
try:
|
||||
# Get the TTS record
|
||||
stmt = select(TTS).where(TTS.id == tts_id)
|
||||
result = await self.session.exec(stmt)
|
||||
tts = result.first()
|
||||
|
||||
if not tts:
|
||||
tts_not_found_msg = f"TTS with ID {tts_id} not found"
|
||||
raise ValueError(tts_not_found_msg)
|
||||
|
||||
# Generate the TTS
|
||||
sound = await self._generate_tts_sync(
|
||||
tts.text,
|
||||
tts.provider,
|
||||
tts.user_id,
|
||||
tts.options,
|
||||
)
|
||||
|
||||
# Capture sound ID before session issues
|
||||
sound_id = sound.id
|
||||
if sound_id is None:
|
||||
sound_creation_error = "Sound creation failed - no ID assigned"
|
||||
raise ValueError(sound_creation_error)
|
||||
|
||||
# Mark as completed
|
||||
await self.mark_tts_completed(tts_id, sound_id)
|
||||
|
||||
# Emit socket event for completion
|
||||
await self._emit_tts_event("tts_completed", tts_id, sound_id)
|
||||
|
||||
except Exception as e:
|
||||
# Mark as failed
|
||||
await self.mark_tts_failed(tts_id, str(e))
|
||||
|
||||
# Emit socket event for failure
|
||||
await self._emit_tts_event("tts_failed", tts_id, None, str(e))
|
||||
raise
|
||||
|
||||
async def _emit_tts_event(
|
||||
self,
|
||||
event: str,
|
||||
tts_id: int,
|
||||
sound_id: int | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""Emit a socket event for TTS status change."""
|
||||
try:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
data = {
|
||||
"tts_id": tts_id,
|
||||
"sound_id": sound_id,
|
||||
}
|
||||
if error:
|
||||
data["error"] = error
|
||||
|
||||
logger.info("Emitting TTS socket event: %s with data: %s", event, data)
|
||||
await socket_manager.broadcast_to_all(event, data)
|
||||
logger.info("Successfully emitted TTS socket event: %s", event)
|
||||
except Exception:
|
||||
# Don't fail TTS processing if socket emission fails
|
||||
logger = get_logger(__name__)
|
||||
logger.exception("Failed to emit TTS socket event %s", event)
|
||||
193
app/services/tts_processor.py
Normal file
193
app/services/tts_processor.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""Background TTS processor for handling TTS generation queue."""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import engine
|
||||
from app.core.logging import get_logger
|
||||
from app.services.tts import TTSService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TTSProcessor:
|
||||
"""Background processor for handling TTS generation queue."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the TTS processor."""
|
||||
self.max_concurrent = getattr(settings, "TTS_MAX_CONCURRENT", 3)
|
||||
self.running_tts: set[int] = set()
|
||||
self.processing_lock = asyncio.Lock()
|
||||
self.shutdown_event = asyncio.Event()
|
||||
self.processor_task: asyncio.Task | None = None
|
||||
|
||||
logger.info(
|
||||
"Initialized TTS processor with max concurrent: %d",
|
||||
self.max_concurrent,
|
||||
)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the background TTS processor."""
|
||||
if self.processor_task and not self.processor_task.done():
|
||||
logger.warning("TTS processor is already running")
|
||||
return
|
||||
|
||||
# Reset any stuck TTS generations from previous runs
|
||||
await self._reset_stuck_tts()
|
||||
|
||||
self.shutdown_event.clear()
|
||||
self.processor_task = asyncio.create_task(self._process_queue())
|
||||
logger.info("Started TTS processor")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the background TTS processor."""
|
||||
logger.info("Stopping TTS processor...")
|
||||
self.shutdown_event.set()
|
||||
|
||||
if self.processor_task and not self.processor_task.done():
|
||||
try:
|
||||
await asyncio.wait_for(self.processor_task, timeout=30.0)
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"TTS processor did not stop gracefully, cancelling...",
|
||||
)
|
||||
self.processor_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self.processor_task
|
||||
|
||||
logger.info("TTS processor stopped")
|
||||
|
||||
async def queue_tts(self, tts_id: int) -> None:
|
||||
"""Queue a TTS generation for processing."""
|
||||
async with self.processing_lock:
|
||||
if tts_id not in self.running_tts:
|
||||
logger.info("Queued TTS %d for processing", tts_id)
|
||||
# The processor will pick it up on the next cycle
|
||||
else:
|
||||
logger.warning(
|
||||
"TTS %d is already being processed",
|
||||
tts_id,
|
||||
)
|
||||
|
||||
async def _process_queue(self) -> None:
|
||||
"""Process the TTS queue in the main processing loop."""
|
||||
logger.info("Starting TTS queue processor")
|
||||
|
||||
while not self.shutdown_event.is_set():
|
||||
try:
|
||||
await self._process_pending_tts()
|
||||
|
||||
# Wait before checking for new TTS generations
|
||||
try:
|
||||
await asyncio.wait_for(self.shutdown_event.wait(), timeout=5.0)
|
||||
break # Shutdown requested
|
||||
except TimeoutError:
|
||||
continue # Continue processing
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error in TTS queue processor")
|
||||
# Wait a bit before retrying to avoid tight error loops
|
||||
try:
|
||||
await asyncio.wait_for(self.shutdown_event.wait(), timeout=10.0)
|
||||
break # Shutdown requested
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
logger.info("TTS queue processor stopped")
|
||||
|
||||
async def _process_pending_tts(self) -> None:
|
||||
"""Process pending TTS generations up to the concurrency limit."""
|
||||
async with self.processing_lock:
|
||||
# Check how many slots are available
|
||||
available_slots = self.max_concurrent - len(self.running_tts)
|
||||
|
||||
if available_slots <= 0:
|
||||
return # No available slots
|
||||
|
||||
# Get pending TTS generations from database
|
||||
async with AsyncSession(engine) as session:
|
||||
tts_service = TTSService(session)
|
||||
pending_tts = await tts_service.get_pending_tts()
|
||||
|
||||
# Filter out TTS that are already being processed
|
||||
available_tts = [
|
||||
tts
|
||||
for tts in pending_tts
|
||||
if tts.id not in self.running_tts
|
||||
]
|
||||
|
||||
# Start processing up to available slots
|
||||
tts_to_start = available_tts[:available_slots]
|
||||
|
||||
for tts in tts_to_start:
|
||||
tts_id = tts.id
|
||||
self.running_tts.add(tts_id)
|
||||
|
||||
# Start processing this TTS in the background
|
||||
task = asyncio.create_task(
|
||||
self._process_single_tts(tts_id),
|
||||
)
|
||||
task.add_done_callback(
|
||||
lambda t, tid=tts_id: self._on_tts_completed(
|
||||
tid,
|
||||
t,
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Started processing TTS %d (%d/%d slots used)",
|
||||
tts_id,
|
||||
len(self.running_tts),
|
||||
self.max_concurrent,
|
||||
)
|
||||
|
||||
async def _process_single_tts(self, tts_id: int) -> None:
|
||||
"""Process a single TTS generation."""
|
||||
try:
|
||||
async with AsyncSession(engine) as session:
|
||||
tts_service = TTSService(session)
|
||||
await tts_service.process_tts_generation(tts_id)
|
||||
logger.info("Successfully processed TTS %d", tts_id)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to process TTS %d", tts_id)
|
||||
# Mark TTS as failed in database
|
||||
try:
|
||||
async with AsyncSession(engine) as session:
|
||||
tts_service = TTSService(session)
|
||||
await tts_service.mark_tts_failed(tts_id, "Processing failed")
|
||||
except Exception:
|
||||
logger.exception("Failed to mark TTS %d as failed", tts_id)
|
||||
|
||||
def _on_tts_completed(self, tts_id: int, task: asyncio.Task) -> None:
|
||||
"""Handle completion of a TTS processing task."""
|
||||
self.running_tts.discard(tts_id)
|
||||
|
||||
if task.exception():
|
||||
logger.error(
|
||||
"TTS processing task %d failed: %s",
|
||||
tts_id,
|
||||
task.exception(),
|
||||
)
|
||||
else:
|
||||
logger.info("TTS processing task %d completed successfully", tts_id)
|
||||
|
||||
async def _reset_stuck_tts(self) -> None:
|
||||
"""Reset any TTS generations that were stuck in 'processing' state."""
|
||||
try:
|
||||
async with AsyncSession(engine) as session:
|
||||
tts_service = TTSService(session)
|
||||
reset_count = await tts_service.reset_stuck_tts()
|
||||
if reset_count > 0:
|
||||
logger.info("Reset %d stuck TTS generations", reset_count)
|
||||
else:
|
||||
logger.info("No stuck TTS generations found to reset")
|
||||
except Exception:
|
||||
logger.exception("Failed to reset stuck TTS generations")
|
||||
|
||||
|
||||
# Global TTS processor instance
|
||||
tts_processor = TTSProcessor()
|
||||
@@ -6,10 +6,10 @@ from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.credit_action import CreditActionType
|
||||
from app.models.sound import Sound
|
||||
from app.models.sound_played import SoundPlayed
|
||||
from app.repositories.sound import SoundRepository
|
||||
@@ -24,7 +24,8 @@ class VLCPlayerService:
|
||||
"""Service for launching VLC instances via subprocess to play sounds."""
|
||||
|
||||
def __init__(
|
||||
self, db_session_factory: Callable[[], AsyncSession] | None = None,
|
||||
self,
|
||||
db_session_factory: Callable[[], AsyncSession] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the VLC player service."""
|
||||
self.vlc_executable = self._find_vlc_executable()
|
||||
@@ -53,7 +54,7 @@ class VLCPlayerService:
|
||||
# For "vlc", try to find it in PATH
|
||||
if path == "vlc":
|
||||
result = subprocess.run(
|
||||
["which", "vlc"],
|
||||
["which", "vlc"], # noqa: S607
|
||||
capture_output=True,
|
||||
check=False,
|
||||
text=True,
|
||||
@@ -72,6 +73,9 @@ class VLCPlayerService:
|
||||
async def play_sound(self, sound: Sound) -> bool:
|
||||
"""Play a sound using a new VLC subprocess instance.
|
||||
|
||||
VLC always plays at 100% volume. Host system volume is controlled separately
|
||||
by the player service.
|
||||
|
||||
Args:
|
||||
sound: The Sound object to play
|
||||
|
||||
@@ -96,6 +100,7 @@ class VLCPlayerService:
|
||||
"--no-video", # Audio only
|
||||
"--no-repeat", # Don't repeat
|
||||
"--no-loop", # Don't loop
|
||||
"--volume=100", # Always use 100% VLC volume
|
||||
]
|
||||
|
||||
# Launch VLC process asynchronously without waiting
|
||||
@@ -113,13 +118,19 @@ class VLCPlayerService:
|
||||
|
||||
# Record play count and emit event
|
||||
if self.db_session_factory and sound.id:
|
||||
asyncio.create_task(self._record_play_count(sound.id, sound.name))
|
||||
|
||||
return True
|
||||
task = asyncio.create_task(
|
||||
self._record_play_count(sound.id, sound.name),
|
||||
)
|
||||
# Store reference to prevent garbage collection
|
||||
self._background_tasks = getattr(self, "_background_tasks", set())
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to launch VLC for sound %s", sound.name)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
async def stop_all_vlc_instances(self) -> dict[str, Any]:
|
||||
"""Stop all running VLC processes by killing them.
|
||||
@@ -137,7 +148,7 @@ class VLCPlayerService:
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
stdout, stderr = await find_process.communicate()
|
||||
stdout, _stderr = await find_process.communicate()
|
||||
|
||||
if find_process.returncode != 0:
|
||||
# No VLC processes found
|
||||
@@ -231,72 +242,164 @@ class VLCPlayerService:
|
||||
return
|
||||
|
||||
logger.info("Recording play count for sound %s", sound_id)
|
||||
session = self.db_session_factory()
|
||||
|
||||
# Initialize variables for WebSocket event
|
||||
old_count = 0
|
||||
sound = None
|
||||
admin_user_id = None
|
||||
admin_user_name = None
|
||||
|
||||
try:
|
||||
sound_repo = SoundRepository(session)
|
||||
user_repo = UserRepository(session)
|
||||
async with self.db_session_factory() as session:
|
||||
sound_repo = SoundRepository(session)
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
# Update sound play count
|
||||
sound = await sound_repo.get_by_id(sound_id)
|
||||
old_count = 0
|
||||
if sound:
|
||||
old_count = sound.play_count
|
||||
await sound_repo.update(
|
||||
sound,
|
||||
{"play_count": sound.play_count + 1},
|
||||
# Update sound play count
|
||||
sound = await sound_repo.get_by_id(sound_id)
|
||||
if sound:
|
||||
old_count = sound.play_count
|
||||
# Update the sound's play count using direct attribute modification
|
||||
sound.play_count = sound.play_count + 1
|
||||
session.add(sound)
|
||||
await session.commit()
|
||||
await session.refresh(sound)
|
||||
logger.info(
|
||||
"Updated sound %s play_count: %s -> %s",
|
||||
sound_id,
|
||||
old_count,
|
||||
old_count + 1,
|
||||
)
|
||||
else:
|
||||
logger.warning("Sound %s not found for play count update", sound_id)
|
||||
|
||||
# Record play history for admin user (ID 1) as placeholder
|
||||
# This could be refined to track per-user play history
|
||||
admin_user = await user_repo.get_by_id(1)
|
||||
if admin_user:
|
||||
admin_user_id = admin_user.id
|
||||
admin_user_name = admin_user.name
|
||||
|
||||
# Always create a new SoundPlayed record for each play event
|
||||
sound_played = SoundPlayed(
|
||||
user_id=admin_user_id, # Can be None for player-based plays
|
||||
sound_id=sound_id,
|
||||
)
|
||||
session.add(sound_played)
|
||||
logger.info(
|
||||
"Updated sound %s play_count: %s -> %s",
|
||||
"Created SoundPlayed record for user %s, sound %s",
|
||||
admin_user_id,
|
||||
sound_id,
|
||||
old_count,
|
||||
old_count + 1,
|
||||
)
|
||||
else:
|
||||
logger.warning("Sound %s not found for play count update", sound_id)
|
||||
|
||||
# Record play history for admin user (ID 1) as placeholder
|
||||
# This could be refined to track per-user play history
|
||||
admin_user = await user_repo.get_by_id(1)
|
||||
admin_user_id = None
|
||||
if admin_user:
|
||||
admin_user_id = admin_user.id
|
||||
await session.commit()
|
||||
logger.info("Successfully recorded play count for sound %s", sound_id)
|
||||
except Exception:
|
||||
logger.exception("Error recording play count for sound %s", sound_id)
|
||||
|
||||
# Always create a new SoundPlayed record for each play event
|
||||
sound_played = SoundPlayed(
|
||||
user_id=admin_user_id, # Can be None for player-based plays
|
||||
sound_id=sound_id,
|
||||
)
|
||||
session.add(sound_played)
|
||||
logger.info(
|
||||
"Created SoundPlayed record for user %s, sound %s",
|
||||
admin_user_id,
|
||||
# Emit sound_played event via WebSocket (outside session context)
|
||||
try:
|
||||
event_data = {
|
||||
"sound_id": sound_id,
|
||||
"sound_name": sound_name,
|
||||
"user_id": admin_user_id,
|
||||
"user_name": admin_user_name,
|
||||
"play_count": (old_count + 1) if sound else None,
|
||||
}
|
||||
await socket_manager.broadcast_to_all("sound_played", event_data)
|
||||
logger.info("Broadcasted sound_played event for sound %s", sound_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to broadcast sound_played event for sound %s",
|
||||
sound_id,
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
logger.info("Successfully recorded play count for sound %s", sound_id)
|
||||
async def play_sound_with_credits(
|
||||
self,
|
||||
sound_id: int,
|
||||
user_id: int,
|
||||
) -> dict[str, str | int | bool]:
|
||||
"""Play sound with VLC with credit validation and deduction.
|
||||
|
||||
# Emit sound_played event via WebSocket
|
||||
try:
|
||||
event_data = {
|
||||
"sound_id": sound_id,
|
||||
"sound_name": sound_name,
|
||||
"user_id": admin_user_id,
|
||||
"play_count": (old_count + 1) if sound else None,
|
||||
}
|
||||
await socket_manager.broadcast_to_all("sound_played", event_data)
|
||||
logger.info("Broadcasted sound_played event for sound %s", sound_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to broadcast sound_played event for sound %s", sound_id,
|
||||
This method combines credit checking, sound playing, and credit deduction
|
||||
in a single operation. Used by both HTTP and WebSocket endpoints.
|
||||
|
||||
Args:
|
||||
sound_id: ID of the sound to play
|
||||
user_id: ID of the user playing the sound
|
||||
|
||||
Returns:
|
||||
dict: Result information including success status and message
|
||||
|
||||
Raises:
|
||||
HTTPException: For various error conditions (sound not found,
|
||||
insufficient credits, VLC failure)
|
||||
|
||||
"""
|
||||
from fastapi import HTTPException, status # noqa: PLC0415, I001
|
||||
from app.services.credit import ( # noqa: PLC0415
|
||||
CreditService,
|
||||
InsufficientCreditsError,
|
||||
)
|
||||
|
||||
if not self.db_session_factory:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Database session factory not configured",
|
||||
)
|
||||
|
||||
async with self.db_session_factory() as session:
|
||||
sound_repo = SoundRepository(session)
|
||||
|
||||
# Get the sound
|
||||
sound = await sound_repo.get_by_id(sound_id)
|
||||
if not sound:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Sound with ID {sound_id} not found",
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error recording play count for sound %s", sound_id)
|
||||
await session.rollback()
|
||||
finally:
|
||||
await session.close()
|
||||
# Get credit service
|
||||
credit_service = CreditService(self.db_session_factory)
|
||||
|
||||
# Check and validate credits before playing
|
||||
try:
|
||||
await credit_service.validate_and_reserve_credits(
|
||||
user_id,
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
)
|
||||
except InsufficientCreditsError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=(
|
||||
f"Insufficient credits: {e.required} required, "
|
||||
f"{e.available} available"
|
||||
),
|
||||
) from e
|
||||
|
||||
# Play the sound using VLC (always at 100% VLC volume)
|
||||
success = await self.play_sound(sound)
|
||||
|
||||
# Deduct credits based on success
|
||||
await credit_service.deduct_credits(
|
||||
user_id,
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
success=success,
|
||||
metadata={"sound_id": sound_id, "sound_name": sound.name},
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to launch VLC for sound playback",
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Sound '{sound.name}' is now playing via VLC",
|
||||
"sound_id": sound_id,
|
||||
"sound_name": sound.name,
|
||||
"success": True,
|
||||
"credits_deducted": 1,
|
||||
}
|
||||
|
||||
|
||||
# Global VLC player service instance
|
||||
|
||||
251
app/services/volume.py
Normal file
251
app/services/volume.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""Volume service for host system volume control."""
|
||||
|
||||
import platform
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Constants
|
||||
MIN_VOLUME = 0
|
||||
MAX_VOLUME = 100
|
||||
|
||||
|
||||
class VolumeService:
|
||||
"""Service for controlling host system volume."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize volume service."""
|
||||
self._system = platform.system().lower()
|
||||
self._pycaw_available = False
|
||||
self._pulsectl_available = False
|
||||
|
||||
# Try to import Windows volume control
|
||||
if self._system == "windows":
|
||||
try:
|
||||
from comtypes import ( # noqa: PLC0415
|
||||
CLSCTX_ALL, # type: ignore[import-untyped]
|
||||
)
|
||||
from pycaw.pycaw import ( # type: ignore[import-untyped] # noqa: PLC0415
|
||||
AudioUtilities,
|
||||
IAudioEndpointVolume,
|
||||
)
|
||||
|
||||
self._AudioUtilities = AudioUtilities
|
||||
self._IAudioEndpointVolume = IAudioEndpointVolume
|
||||
self._CLSCTX_ALL = CLSCTX_ALL
|
||||
self._pycaw_available = True
|
||||
logger.info("Windows volume control (pycaw) initialized")
|
||||
except ImportError as e:
|
||||
logger.warning("pycaw not available: %s", e)
|
||||
|
||||
# Try to import Linux volume control
|
||||
elif self._system == "linux":
|
||||
try:
|
||||
import pulsectl # type: ignore[import-untyped] # noqa: PLC0415
|
||||
|
||||
self._pulsectl = pulsectl
|
||||
self._pulsectl_available = True
|
||||
logger.info("Linux volume control (pulsectl) initialized")
|
||||
except ImportError as e:
|
||||
logger.warning("pulsectl not available: %s", e)
|
||||
|
||||
else:
|
||||
logger.warning("Volume control not supported on %s", self._system)
|
||||
|
||||
def get_volume(self) -> int | None:
|
||||
"""Get the current system volume as a percentage (0-100).
|
||||
|
||||
Returns:
|
||||
Volume level as percentage, or None if not available
|
||||
|
||||
"""
|
||||
try:
|
||||
if self._system == "windows" and self._pycaw_available:
|
||||
return self._get_windows_volume()
|
||||
if self._system == "linux" and self._pulsectl_available:
|
||||
return self._get_linux_volume()
|
||||
except Exception:
|
||||
logger.exception("Failed to get volume")
|
||||
return None
|
||||
else:
|
||||
logger.warning("Volume control not available for this system")
|
||||
return None
|
||||
|
||||
def set_volume(self, volume: int) -> bool:
|
||||
"""Set the system volume to a percentage (0-100).
|
||||
|
||||
Args:
|
||||
volume: Volume level as percentage (0-100)
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
|
||||
"""
|
||||
if not (MIN_VOLUME <= volume <= MAX_VOLUME):
|
||||
logger.error(
|
||||
"Volume must be between %s and %s, got %s",
|
||||
MIN_VOLUME,
|
||||
MAX_VOLUME,
|
||||
volume,
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
if self._system == "windows" and self._pycaw_available:
|
||||
return self._set_windows_volume(volume)
|
||||
if self._system == "linux" and self._pulsectl_available:
|
||||
return self._set_linux_volume(volume)
|
||||
except Exception:
|
||||
logger.exception("Failed to set volume")
|
||||
return False
|
||||
else:
|
||||
logger.warning("Volume control not available for this system")
|
||||
return False
|
||||
|
||||
def is_muted(self) -> bool | None:
|
||||
"""Check if the system is muted.
|
||||
|
||||
Returns:
|
||||
True if muted, False if not muted, None if not available
|
||||
|
||||
"""
|
||||
try:
|
||||
if self._system == "windows" and self._pycaw_available:
|
||||
return self._get_windows_mute_status()
|
||||
if self._system == "linux" and self._pulsectl_available:
|
||||
return self._get_linux_mute_status()
|
||||
except Exception:
|
||||
logger.exception("Failed to get mute status")
|
||||
return None
|
||||
else:
|
||||
logger.warning("Mute status not available for this system")
|
||||
return None
|
||||
|
||||
def set_mute(self, *, muted: bool) -> bool:
|
||||
"""Set the system mute status.
|
||||
|
||||
Args:
|
||||
muted: True to mute, False to unmute
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
|
||||
"""
|
||||
try:
|
||||
if self._system == "windows" and self._pycaw_available:
|
||||
return self._set_windows_mute(muted=muted)
|
||||
if self._system == "linux" and self._pulsectl_available:
|
||||
return self._set_linux_mute(muted=muted)
|
||||
except Exception:
|
||||
logger.exception("Failed to set mute status")
|
||||
return False
|
||||
else:
|
||||
logger.warning("Mute control not available for this system")
|
||||
return False
|
||||
|
||||
def _get_windows_volume(self) -> int:
|
||||
"""Get Windows volume using pycaw."""
|
||||
devices = self._AudioUtilities.GetSpeakers()
|
||||
interface = devices.Activate(
|
||||
self._IAudioEndpointVolume._iid_, self._CLSCTX_ALL, None,
|
||||
)
|
||||
volume = interface.QueryInterface(self._IAudioEndpointVolume)
|
||||
current_volume = volume.GetMasterVolume()
|
||||
|
||||
# Convert from scalar (0.0-1.0) to percentage (0-100)
|
||||
return int(current_volume * MAX_VOLUME)
|
||||
|
||||
def _set_windows_volume(self, volume_percent: int) -> bool:
|
||||
"""Set Windows volume using pycaw."""
|
||||
devices = self._AudioUtilities.GetSpeakers()
|
||||
interface = devices.Activate(
|
||||
self._IAudioEndpointVolume._iid_, self._CLSCTX_ALL, None,
|
||||
)
|
||||
volume = interface.QueryInterface(self._IAudioEndpointVolume)
|
||||
|
||||
# Convert from percentage (0-100) to scalar (0.0-1.0)
|
||||
volume_scalar = volume_percent / MAX_VOLUME
|
||||
volume.SetMasterVolume(volume_scalar, None)
|
||||
logger.info("Windows volume set to %s%%", volume_percent)
|
||||
return True
|
||||
|
||||
def _get_windows_mute_status(self) -> bool:
|
||||
"""Get Windows mute status using pycaw."""
|
||||
devices = self._AudioUtilities.GetSpeakers()
|
||||
interface = devices.Activate(
|
||||
self._IAudioEndpointVolume._iid_, self._CLSCTX_ALL, None,
|
||||
)
|
||||
volume = interface.QueryInterface(self._IAudioEndpointVolume)
|
||||
return bool(volume.GetMute())
|
||||
|
||||
def _set_windows_mute(self, *, muted: bool) -> bool:
|
||||
"""Set Windows mute status using pycaw."""
|
||||
devices = self._AudioUtilities.GetSpeakers()
|
||||
interface = devices.Activate(
|
||||
self._IAudioEndpointVolume._iid_, self._CLSCTX_ALL, None,
|
||||
)
|
||||
volume = interface.QueryInterface(self._IAudioEndpointVolume)
|
||||
volume.SetMute(muted, None)
|
||||
logger.info("Windows mute set to %s", muted)
|
||||
return True
|
||||
|
||||
def _get_linux_volume(self) -> int:
|
||||
"""Get Linux volume using pulsectl."""
|
||||
with self._pulsectl.Pulse("volume-service") as pulse:
|
||||
# Get the default sink (output device)
|
||||
default_sink = pulse.get_sink_by_name(pulse.server_info().default_sink_name)
|
||||
if default_sink is None:
|
||||
logger.error("No default audio sink found")
|
||||
return MIN_VOLUME
|
||||
|
||||
# Get volume as percentage (PulseAudio uses 0.0-1.0, we convert to 0-100)
|
||||
volume = default_sink.volume
|
||||
avg_volume = sum(volume.values) / len(volume.values)
|
||||
return int(avg_volume * MAX_VOLUME)
|
||||
|
||||
def _set_linux_volume(self, volume_percent: int) -> bool:
|
||||
"""Set Linux volume using pulsectl."""
|
||||
with self._pulsectl.Pulse("volume-service") as pulse:
|
||||
# Get the default sink (output device)
|
||||
default_sink = pulse.get_sink_by_name(pulse.server_info().default_sink_name)
|
||||
if default_sink is None:
|
||||
logger.error("No default audio sink found")
|
||||
return False
|
||||
|
||||
# Convert percentage to PulseAudio volume (0.0-1.0)
|
||||
volume_scalar = volume_percent / MAX_VOLUME
|
||||
|
||||
# Set volume for all channels
|
||||
pulse.volume_set_all_chans(default_sink, volume_scalar)
|
||||
logger.info("Linux volume set to %s%%", volume_percent)
|
||||
return True
|
||||
|
||||
def _get_linux_mute_status(self) -> bool:
|
||||
"""Get Linux mute status using pulsectl."""
|
||||
with self._pulsectl.Pulse("volume-service") as pulse:
|
||||
# Get the default sink (output device)
|
||||
default_sink = pulse.get_sink_by_name(pulse.server_info().default_sink_name)
|
||||
if default_sink is None:
|
||||
logger.error("No default audio sink found")
|
||||
return False
|
||||
|
||||
return bool(default_sink.mute)
|
||||
|
||||
def _set_linux_mute(self, *, muted: bool) -> bool:
|
||||
"""Set Linux mute status using pulsectl."""
|
||||
with self._pulsectl.Pulse("volume-service") as pulse:
|
||||
# Get the default sink (output device)
|
||||
default_sink = pulse.get_sink_by_name(pulse.server_info().default_sink_name)
|
||||
if default_sink is None:
|
||||
logger.error("No default audio sink found")
|
||||
return False
|
||||
|
||||
# Set mute status
|
||||
pulse.mute(default_sink, muted)
|
||||
logger.info("Linux mute set to %s", muted)
|
||||
return True
|
||||
|
||||
|
||||
# Global volume service instance
|
||||
volume_service = VolumeService()
|
||||
@@ -34,7 +34,7 @@ def get_audio_duration(file_path: Path) -> int:
|
||||
probe = ffmpeg.probe(str(file_path))
|
||||
duration = float(probe["format"]["duration"])
|
||||
return int(duration * 1000) # Convert to milliseconds
|
||||
except Exception as e:
|
||||
except (ffmpeg.Error, KeyError, ValueError, TypeError, Exception) as e: # noqa: BLE001
|
||||
logger.warning("Failed to get duration for %s: %s", file_path, e)
|
||||
return 0
|
||||
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
"""Cookie parsing utilities for WebSocket authentication."""
|
||||
"""Cookie parsing and setting utilities for WebSocket and HTTP authentication."""
|
||||
|
||||
from fastapi import Response
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
def parse_cookies(cookie_header: str) -> dict[str, str]:
|
||||
"""Parse HTTP cookie header into a dictionary."""
|
||||
cookies = {}
|
||||
cookies: dict[str, str] = {}
|
||||
if not cookie_header:
|
||||
return cookies
|
||||
|
||||
for cookie in cookie_header.split(";"):
|
||||
cookie = cookie.strip()
|
||||
if "=" in cookie:
|
||||
name, value = cookie.split("=", 1)
|
||||
for cookie_part in cookie_header.split(";"):
|
||||
cookie_str = cookie_part.strip()
|
||||
if "=" in cookie_str:
|
||||
name, value = cookie_str.split("=", 1)
|
||||
cookies[name.strip()] = value.strip()
|
||||
|
||||
return cookies
|
||||
@@ -20,3 +24,52 @@ def extract_access_token_from_cookies(cookie_header: str) -> str | None:
|
||||
"""Extract access token from HTTP cookies."""
|
||||
cookies = parse_cookies(cookie_header)
|
||||
return cookies.get("access_token")
|
||||
|
||||
|
||||
def set_access_token_cookie(
|
||||
response: Response,
|
||||
access_token: str,
|
||||
expires_in: int,
|
||||
path: str = "/",
|
||||
) -> None:
|
||||
"""Set access token cookie with consistent configuration."""
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=access_token,
|
||||
max_age=expires_in,
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
domain=settings.COOKIE_DOMAIN,
|
||||
path=path,
|
||||
)
|
||||
|
||||
|
||||
def set_refresh_token_cookie(
|
||||
response: Response,
|
||||
refresh_token: str,
|
||||
path: str = "/",
|
||||
) -> None:
|
||||
"""Set refresh token cookie with consistent configuration."""
|
||||
response.set_cookie(
|
||||
key="refresh_token",
|
||||
value=refresh_token,
|
||||
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
domain=settings.COOKIE_DOMAIN,
|
||||
path=path,
|
||||
)
|
||||
|
||||
|
||||
def set_auth_cookies(
|
||||
response: Response,
|
||||
access_token: str,
|
||||
refresh_token: str,
|
||||
expires_in: int,
|
||||
path: str = "/",
|
||||
) -> None:
|
||||
"""Set both access and refresh token cookies with consistent configuration."""
|
||||
set_access_token_cookie(response, access_token, expires_in, path)
|
||||
set_refresh_token_cookie(response, refresh_token, path)
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
"""Decorators for credit management and validation."""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import types
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from app.models.credit_action import CreditActionType
|
||||
from app.services.credit import CreditService, InsufficientCreditsError
|
||||
from app.services.credit import CreditService
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
|
||||
|
||||
@@ -16,7 +18,7 @@ def requires_credits(
|
||||
user_id_param: str = "user_id",
|
||||
metadata_extractor: Callable[..., dict[str, Any]] | None = None,
|
||||
) -> Callable[[F], F]:
|
||||
"""Decorator to enforce credit requirements for actions.
|
||||
"""Enforce credit requirements for actions.
|
||||
|
||||
Args:
|
||||
action_type: The type of action that requires credits
|
||||
@@ -38,16 +40,16 @@ def requires_credits(
|
||||
return True
|
||||
|
||||
"""
|
||||
|
||||
def decorator(func: F) -> F:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
|
||||
# Extract user ID from parameters
|
||||
user_id = None
|
||||
if user_id_param in kwargs:
|
||||
user_id = kwargs[user_id_param]
|
||||
else:
|
||||
# Try to find user_id in function signature
|
||||
import inspect
|
||||
sig = inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())
|
||||
if user_id_param in param_names:
|
||||
@@ -69,26 +71,31 @@ def requires_credits(
|
||||
|
||||
# Validate credits before execution
|
||||
await credit_service.validate_and_reserve_credits(
|
||||
user_id, action_type, metadata
|
||||
user_id,
|
||||
action_type,
|
||||
)
|
||||
|
||||
# Execute the function
|
||||
success = False
|
||||
result = None
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
success = bool(result) # Consider function result as success indicator
|
||||
return result
|
||||
except Exception:
|
||||
success = False
|
||||
raise
|
||||
else:
|
||||
return result
|
||||
finally:
|
||||
# Deduct credits based on success
|
||||
await credit_service.deduct_credits(
|
||||
user_id, action_type, success, metadata
|
||||
user_id,
|
||||
action_type,
|
||||
success=success,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@@ -97,7 +104,7 @@ def validate_credits_only(
|
||||
credit_service_factory: Callable[[], CreditService],
|
||||
user_id_param: str = "user_id",
|
||||
) -> Callable[[F], F]:
|
||||
"""Decorator to only validate credits without deducting them.
|
||||
"""Validate credits without deducting them.
|
||||
|
||||
Useful for checking if a user can perform an action before actual execution.
|
||||
|
||||
@@ -110,16 +117,16 @@ def validate_credits_only(
|
||||
Decorated function that validates credits only
|
||||
|
||||
"""
|
||||
|
||||
def decorator(func: F) -> F:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
|
||||
# Extract user ID from parameters
|
||||
user_id = None
|
||||
if user_id_param in kwargs:
|
||||
user_id = kwargs[user_id_param]
|
||||
else:
|
||||
# Try to find user_id in function signature
|
||||
import inspect
|
||||
sig = inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())
|
||||
if user_id_param in param_names:
|
||||
@@ -141,6 +148,7 @@ def validate_credits_only(
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@@ -173,20 +181,29 @@ class CreditManager:
|
||||
async def __aenter__(self) -> "CreditManager":
|
||||
"""Enter context manager - validate credits."""
|
||||
await self.credit_service.validate_and_reserve_credits(
|
||||
self.user_id, self.action_type, self.metadata
|
||||
self.user_id,
|
||||
self.action_type,
|
||||
)
|
||||
self.validated = True
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: type, exc_val: Exception, exc_tb: Any) -> None:
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: types.TracebackType | None,
|
||||
) -> None:
|
||||
"""Exit context manager - deduct credits based on success."""
|
||||
if self.validated:
|
||||
# If no exception occurred, consider it successful
|
||||
success = exc_type is None and self.success
|
||||
await self.credit_service.deduct_credits(
|
||||
self.user_id, self.action_type, success, self.metadata
|
||||
self.user_id,
|
||||
self.action_type,
|
||||
success=success,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
|
||||
def mark_success(self) -> None:
|
||||
"""Mark the operation as successful."""
|
||||
self.success = True
|
||||
self.success = True
|
||||
|
||||
@@ -1,166 +0,0 @@
|
||||
"""Database utility functions for common operations."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar
|
||||
from sqlmodel import select, SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
T = TypeVar("T", bound=SQLModel)
|
||||
|
||||
|
||||
async def create_and_save(
|
||||
session: AsyncSession,
|
||||
model_class: Type[T],
|
||||
**kwargs: Any
|
||||
) -> T:
|
||||
"""Create, add, commit, and refresh a model instance.
|
||||
|
||||
This consolidates the common database pattern of:
|
||||
- instance = ModelClass(**kwargs)
|
||||
- session.add(instance)
|
||||
- await session.commit()
|
||||
- await session.refresh(instance)
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
model_class: SQLModel class to instantiate
|
||||
**kwargs: Arguments to pass to model constructor
|
||||
|
||||
Returns:
|
||||
Created and refreshed model instance
|
||||
"""
|
||||
instance = model_class(**kwargs)
|
||||
session.add(instance)
|
||||
await session.commit()
|
||||
await session.refresh(instance)
|
||||
return instance
|
||||
|
||||
|
||||
async def get_or_create(
|
||||
session: AsyncSession,
|
||||
model_class: Type[T],
|
||||
defaults: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any
|
||||
) -> tuple[T, bool]:
|
||||
"""Get an existing instance or create a new one.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
model_class: SQLModel class
|
||||
defaults: Default values for creation (if not found)
|
||||
**kwargs: Filter criteria for lookup
|
||||
|
||||
Returns:
|
||||
Tuple of (instance, created) where created is True if instance was created
|
||||
"""
|
||||
# Build filter conditions
|
||||
filters = []
|
||||
for key, value in kwargs.items():
|
||||
filters.append(getattr(model_class, key) == value)
|
||||
|
||||
# Try to find existing instance
|
||||
statement = select(model_class).where(*filters)
|
||||
result = await session.exec(statement)
|
||||
instance = result.first()
|
||||
|
||||
if instance:
|
||||
return instance, False
|
||||
|
||||
# Create new instance
|
||||
create_kwargs = {**kwargs}
|
||||
if defaults:
|
||||
create_kwargs.update(defaults)
|
||||
|
||||
instance = await create_and_save(session, model_class, **create_kwargs)
|
||||
return instance, True
|
||||
|
||||
|
||||
async def update_and_save(
|
||||
session: AsyncSession,
|
||||
instance: T,
|
||||
**updates: Any
|
||||
) -> T:
|
||||
"""Update model instance fields and save to database.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
instance: Model instance to update
|
||||
**updates: Field updates to apply
|
||||
|
||||
Returns:
|
||||
Updated and refreshed model instance
|
||||
"""
|
||||
for field, value in updates.items():
|
||||
setattr(instance, field, value)
|
||||
|
||||
session.add(instance)
|
||||
await session.commit()
|
||||
await session.refresh(instance)
|
||||
return instance
|
||||
|
||||
|
||||
async def bulk_create(
|
||||
session: AsyncSession,
|
||||
model_class: Type[T],
|
||||
items: List[Dict[str, Any]]
|
||||
) -> List[T]:
|
||||
"""Create multiple model instances in bulk.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
model_class: SQLModel class to instantiate
|
||||
items: List of dictionaries with model data
|
||||
|
||||
Returns:
|
||||
List of created model instances
|
||||
"""
|
||||
instances = []
|
||||
for item_data in items:
|
||||
instance = model_class(**item_data)
|
||||
session.add(instance)
|
||||
instances.append(instance)
|
||||
|
||||
await session.commit()
|
||||
|
||||
# Refresh all instances
|
||||
for instance in instances:
|
||||
await session.refresh(instance)
|
||||
|
||||
return instances
|
||||
|
||||
|
||||
async def delete_and_commit(
|
||||
session: AsyncSession,
|
||||
instance: T
|
||||
) -> None:
|
||||
"""Delete an instance and commit the transaction.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
instance: Model instance to delete
|
||||
"""
|
||||
await session.delete(instance)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def exists(
|
||||
session: AsyncSession,
|
||||
model_class: Type[T],
|
||||
**kwargs: Any
|
||||
) -> bool:
|
||||
"""Check if a model instance exists with given criteria.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
model_class: SQLModel class
|
||||
**kwargs: Filter criteria
|
||||
|
||||
Returns:
|
||||
True if instance exists, False otherwise
|
||||
"""
|
||||
filters = []
|
||||
for key, value in kwargs.items():
|
||||
filters.append(getattr(model_class, key) == value)
|
||||
|
||||
statement = select(model_class).where(*filters)
|
||||
result = await session.exec(statement)
|
||||
return result.first() is not None
|
||||
@@ -1,121 +0,0 @@
|
||||
"""Utility functions for common HTTP exception patterns."""
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
def raise_not_found(resource: str, identifier: str = None) -> None:
|
||||
"""Raise a standardized 404 Not Found exception.
|
||||
|
||||
Args:
|
||||
resource: Name of the resource that wasn't found
|
||||
identifier: Optional identifier for the specific resource
|
||||
|
||||
Raises:
|
||||
HTTPException with 404 status code
|
||||
"""
|
||||
if identifier:
|
||||
detail = f"{resource} with ID {identifier} not found"
|
||||
else:
|
||||
detail = f"{resource} not found"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
def raise_unauthorized(detail: str = "Could not validate credentials") -> None:
|
||||
"""Raise a standardized 401 Unauthorized exception.
|
||||
|
||||
Args:
|
||||
detail: Error message detail
|
||||
|
||||
Raises:
|
||||
HTTPException with 401 status code
|
||||
"""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
def raise_bad_request(detail: str) -> None:
|
||||
"""Raise a standardized 400 Bad Request exception.
|
||||
|
||||
Args:
|
||||
detail: Error message detail
|
||||
|
||||
Raises:
|
||||
HTTPException with 400 status code
|
||||
"""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
def raise_internal_server_error(detail: str, cause: Exception = None) -> None:
|
||||
"""Raise a standardized 500 Internal Server Error exception.
|
||||
|
||||
Args:
|
||||
detail: Error message detail
|
||||
cause: Optional underlying exception
|
||||
|
||||
Raises:
|
||||
HTTPException with 500 status code
|
||||
"""
|
||||
if cause:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=detail,
|
||||
) from cause
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
def raise_payment_required(detail: str = "Insufficient credits") -> None:
|
||||
"""Raise a standardized 402 Payment Required exception.
|
||||
|
||||
Args:
|
||||
detail: Error message detail
|
||||
|
||||
Raises:
|
||||
HTTPException with 402 status code
|
||||
"""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
def raise_forbidden(detail: str = "Access forbidden") -> None:
|
||||
"""Raise a standardized 403 Forbidden exception.
|
||||
|
||||
Args:
|
||||
detail: Error message detail
|
||||
|
||||
Raises:
|
||||
HTTPException with 403 status code
|
||||
"""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
def raise_conflict(detail: str) -> None:
|
||||
"""Raise a standardized 409 Conflict exception.
|
||||
|
||||
Args:
|
||||
detail: Error message detail
|
||||
|
||||
Raises:
|
||||
HTTPException with 409 status code
|
||||
"""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=detail,
|
||||
)
|
||||
@@ -1,179 +0,0 @@
|
||||
"""Test helper utilities for reducing code duplication."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Dict, Optional, Type, TypeVar
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from fastapi import FastAPI
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from app.models.user import User
|
||||
from app.utils.auth import JWTUtils
|
||||
|
||||
T = TypeVar("T", bound=SQLModel)
|
||||
|
||||
|
||||
def create_jwt_token_data(user: User) -> Dict[str, str]:
|
||||
"""Create standardized JWT token data dictionary for a user.
|
||||
|
||||
Args:
|
||||
user: User object to create token data for
|
||||
|
||||
Returns:
|
||||
Dictionary with sub, email, and role fields
|
||||
"""
|
||||
return {
|
||||
"sub": str(user.id),
|
||||
"email": user.email,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
|
||||
def create_access_token_for_user(user: User) -> str:
|
||||
"""Create an access token for a user using standardized token data.
|
||||
|
||||
Args:
|
||||
user: User object to create token for
|
||||
|
||||
Returns:
|
||||
JWT access token string
|
||||
"""
|
||||
token_data = create_jwt_token_data(user)
|
||||
return JWTUtils.create_access_token(token_data)
|
||||
|
||||
|
||||
async def create_and_save_model(
|
||||
session: AsyncSession,
|
||||
model_class: Type[T],
|
||||
**kwargs: Any
|
||||
) -> T:
|
||||
"""Create, save, and refresh a model instance.
|
||||
|
||||
This consolidates the common pattern of:
|
||||
- model = ModelClass(**kwargs)
|
||||
- session.add(model)
|
||||
- await session.commit()
|
||||
- await session.refresh(model)
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
model_class: SQLModel class to instantiate
|
||||
**kwargs: Arguments to pass to model constructor
|
||||
|
||||
Returns:
|
||||
Created and refreshed model instance
|
||||
"""
|
||||
instance = model_class(**kwargs)
|
||||
session.add(instance)
|
||||
await session.commit()
|
||||
await session.refresh(instance)
|
||||
return instance
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def override_dependencies(
|
||||
app: FastAPI,
|
||||
overrides: Dict[Any, Any]
|
||||
):
|
||||
"""Context manager for FastAPI dependency overrides with automatic cleanup.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance
|
||||
overrides: Dictionary mapping dependency functions to mock implementations
|
||||
|
||||
Usage:
|
||||
async with override_dependencies(test_app, {
|
||||
get_service: lambda: mock_service,
|
||||
get_repo: lambda: mock_repo
|
||||
}):
|
||||
# Test code here
|
||||
pass
|
||||
# Dependencies automatically cleaned up
|
||||
"""
|
||||
# Apply overrides
|
||||
for dependency, override in overrides.items():
|
||||
app.dependency_overrides[dependency] = override
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Clean up overrides
|
||||
for dependency in overrides:
|
||||
app.dependency_overrides.pop(dependency, None)
|
||||
|
||||
|
||||
def create_mock_vlc_services() -> Dict[str, AsyncMock]:
|
||||
"""Create standard set of mocked VLC-related services.
|
||||
|
||||
Returns:
|
||||
Dictionary with mocked vlc_service, sound_repository, and credit_service
|
||||
"""
|
||||
return {
|
||||
"vlc_service": AsyncMock(),
|
||||
"sound_repository": AsyncMock(),
|
||||
"credit_service": AsyncMock(),
|
||||
}
|
||||
|
||||
|
||||
def configure_mock_sound_play_success(
|
||||
mocks: Dict[str, AsyncMock],
|
||||
sound_data: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Configure mocks for successful sound playback scenario.
|
||||
|
||||
Args:
|
||||
mocks: Dictionary of mock services from create_mock_vlc_services()
|
||||
sound_data: Dictionary with sound properties (id, name, etc.)
|
||||
"""
|
||||
from app.models.sound import Sound
|
||||
|
||||
mock_sound = Sound(**sound_data)
|
||||
|
||||
# Configure repository mock
|
||||
mocks["sound_repository"].get_by_id.return_value = mock_sound
|
||||
|
||||
# Configure credit service mocks
|
||||
mocks["credit_service"].validate_and_reserve_credits.return_value = None
|
||||
mocks["credit_service"].deduct_credits.return_value = None
|
||||
|
||||
# Configure VLC service mock
|
||||
mocks["vlc_service"].play_sound.return_value = True
|
||||
|
||||
|
||||
def create_mock_vlc_stop_result(
|
||||
success: bool = True,
|
||||
processes_found: int = 3,
|
||||
processes_killed: int = 3,
|
||||
processes_remaining: int = 0,
|
||||
message: Optional[str] = None,
|
||||
error: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create standardized VLC stop operation result.
|
||||
|
||||
Args:
|
||||
success: Whether operation succeeded
|
||||
processes_found: Number of VLC processes found
|
||||
processes_killed: Number of processes successfully killed
|
||||
processes_remaining: Number of processes still running
|
||||
message: Success/status message
|
||||
error: Error message (for failed operations)
|
||||
|
||||
Returns:
|
||||
Dictionary with VLC stop operation result
|
||||
"""
|
||||
result = {
|
||||
"success": success,
|
||||
"processes_found": processes_found,
|
||||
"processes_killed": processes_killed,
|
||||
}
|
||||
|
||||
if not success:
|
||||
result["error"] = error or "Command failed"
|
||||
result["message"] = message or "Failed to stop VLC processes"
|
||||
else:
|
||||
# Always include processes_remaining for successful operations
|
||||
result["processes_remaining"] = processes_remaining
|
||||
result["message"] = message or f"Killed {processes_killed} VLC processes"
|
||||
|
||||
return result
|
||||
@@ -1,140 +0,0 @@
|
||||
"""Common validation utility functions."""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
# Password validation constants
|
||||
MIN_PASSWORD_LENGTH = 8
|
||||
|
||||
|
||||
def validate_email(email: str) -> bool:
|
||||
"""Validate email address format.
|
||||
|
||||
Args:
|
||||
email: Email address to validate
|
||||
|
||||
Returns:
|
||||
True if email format is valid, False otherwise
|
||||
"""
|
||||
pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
return bool(re.match(pattern, email))
|
||||
|
||||
|
||||
def validate_password_strength(password: str) -> tuple[bool, str | None]:
|
||||
"""Validate password meets security requirements.
|
||||
|
||||
Args:
|
||||
password: Password to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
if len(password) < MIN_PASSWORD_LENGTH:
|
||||
msg = f"Password must be at least {MIN_PASSWORD_LENGTH} characters long"
|
||||
return False, msg
|
||||
|
||||
if not re.search(r"[A-Z]", password):
|
||||
return False, "Password must contain at least one uppercase letter"
|
||||
|
||||
if not re.search(r"[a-z]", password):
|
||||
return False, "Password must contain at least one lowercase letter"
|
||||
|
||||
if not re.search(r"\d", password):
|
||||
return False, "Password must contain at least one number"
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
def validate_filename(
|
||||
filename: str, allowed_extensions: list[str] | None = None
|
||||
) -> bool:
|
||||
"""Validate filename format and extension.
|
||||
|
||||
Args:
|
||||
filename: Filename to validate
|
||||
allowed_extensions: List of allowed file extensions (with dots)
|
||||
|
||||
Returns:
|
||||
True if filename is valid, False otherwise
|
||||
"""
|
||||
if not filename or filename.startswith(".") or "/" in filename or "\\" in filename:
|
||||
return False
|
||||
|
||||
if allowed_extensions:
|
||||
file_path = Path(filename)
|
||||
return file_path.suffix.lower() in [ext.lower() for ext in allowed_extensions]
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def validate_audio_filename(filename: str) -> bool:
|
||||
"""Validate audio filename has allowed extension.
|
||||
|
||||
Args:
|
||||
filename: Audio filename to validate
|
||||
|
||||
Returns:
|
||||
True if filename has valid audio extension, False otherwise
|
||||
"""
|
||||
audio_extensions = [".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac", ".wma"]
|
||||
return validate_filename(filename, audio_extensions)
|
||||
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
"""Sanitize filename by removing/replacing invalid characters.
|
||||
|
||||
Args:
|
||||
filename: Filename to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized filename safe for filesystem
|
||||
"""
|
||||
# Remove or replace problematic characters
|
||||
sanitized = re.sub(r'[<>:"/\\|?*]', "_", filename)
|
||||
|
||||
# Remove leading/trailing whitespace and dots
|
||||
sanitized = sanitized.strip(" .")
|
||||
|
||||
# Ensure not empty
|
||||
if not sanitized:
|
||||
sanitized = "untitled"
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def validate_url(url: str) -> bool:
|
||||
"""Validate URL format.
|
||||
|
||||
Args:
|
||||
url: URL to validate
|
||||
|
||||
Returns:
|
||||
True if URL format is valid, False otherwise
|
||||
"""
|
||||
pattern = r"^https?://[^\s/$.?#].[^\s]*$"
|
||||
return bool(re.match(pattern, url))
|
||||
|
||||
|
||||
def validate_positive_integer(value: Any, field_name: str = "value") -> int:
|
||||
"""Validate and convert value to positive integer.
|
||||
|
||||
Args:
|
||||
value: Value to validate and convert
|
||||
field_name: Name of field for error messages
|
||||
|
||||
Returns:
|
||||
Validated positive integer
|
||||
|
||||
Raises:
|
||||
ValueError: If value is not a positive integer
|
||||
"""
|
||||
try:
|
||||
int_value = int(value)
|
||||
if int_value <= 0:
|
||||
msg = f"{field_name} must be a positive integer"
|
||||
raise ValueError(msg)
|
||||
return int_value
|
||||
except (TypeError, ValueError) as e:
|
||||
msg = f"{field_name} must be a positive integer"
|
||||
raise ValueError(msg) from e
|
||||
100
migrate.py
Executable file
100
migrate.py
Executable file
@@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Database migration CLI tool."""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from alembic.config import Config
|
||||
|
||||
from alembic import command
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run database migration CLI tool."""
|
||||
parser = argparse.ArgumentParser(description="Database migration tool")
|
||||
subparsers = parser.add_subparsers(dest="command", help="Migration commands")
|
||||
|
||||
# Upgrade command
|
||||
upgrade_parser = subparsers.add_parser(
|
||||
"upgrade", help="Upgrade database to latest revision",
|
||||
)
|
||||
upgrade_parser.add_argument(
|
||||
"revision",
|
||||
nargs="?",
|
||||
default="head",
|
||||
help="Target revision (default: head)",
|
||||
)
|
||||
|
||||
# Downgrade command
|
||||
downgrade_parser = subparsers.add_parser("downgrade", help="Downgrade database")
|
||||
downgrade_parser.add_argument("revision", help="Target revision")
|
||||
|
||||
# Current command
|
||||
subparsers.add_parser("current", help="Show current revision")
|
||||
|
||||
# History command
|
||||
subparsers.add_parser("history", help="Show revision history")
|
||||
|
||||
# Generate migration command
|
||||
revision_parser = subparsers.add_parser("revision", help="Create new migration")
|
||||
revision_parser.add_argument(
|
||||
"-m", "--message", required=True, help="Migration message",
|
||||
)
|
||||
revision_parser.add_argument(
|
||||
"--autogenerate", action="store_true", help="Auto-generate migration",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.command:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
# Get the alembic config
|
||||
config_path = Path("alembic.ini")
|
||||
if not config_path.exists():
|
||||
logger.error("Error: alembic.ini not found. Run from the backend directory.")
|
||||
sys.exit(1)
|
||||
|
||||
alembic_cfg = Config(str(config_path))
|
||||
|
||||
try:
|
||||
if args.command == "upgrade":
|
||||
command.upgrade(alembic_cfg, args.revision)
|
||||
logger.info(
|
||||
"Successfully upgraded database to revision: %s", args.revision,
|
||||
)
|
||||
|
||||
elif args.command == "downgrade":
|
||||
command.downgrade(alembic_cfg, args.revision)
|
||||
logger.info(
|
||||
"Successfully downgraded database to revision: %s", args.revision,
|
||||
)
|
||||
|
||||
elif args.command == "current":
|
||||
command.current(alembic_cfg)
|
||||
|
||||
elif args.command == "history":
|
||||
command.history(alembic_cfg)
|
||||
|
||||
elif args.command == "revision":
|
||||
if args.autogenerate:
|
||||
command.revision(alembic_cfg, message=args.message, autogenerate=True)
|
||||
logger.info("Created new auto-generated migration: %s", args.message)
|
||||
else:
|
||||
command.revision(alembic_cfg, message=args.message)
|
||||
logger.info("Created new empty migration: %s", args.message)
|
||||
|
||||
except (OSError, RuntimeError):
|
||||
logger.exception("Error occurred during migration")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,34 +1,42 @@
|
||||
[project]
|
||||
name = "backend"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
name = "sdb"
|
||||
version = "2.0.0"
|
||||
description = "Backend for the SDB v2"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"aiosqlite==0.21.0",
|
||||
"bcrypt==4.3.0",
|
||||
"email-validator==2.2.0",
|
||||
"fastapi[standard]==0.116.1",
|
||||
"alembic==1.16.5",
|
||||
"apscheduler==3.11.0",
|
||||
"bcrypt==5.0.0",
|
||||
"email-validator==2.3.0",
|
||||
"fastapi[standard]==0.118.0",
|
||||
"ffmpeg-python==0.2.0",
|
||||
"gtts==2.5.4",
|
||||
"httpx==0.28.1",
|
||||
"pydantic-settings==2.10.1",
|
||||
"pydantic-settings==2.11.0",
|
||||
"pyjwt==2.10.1",
|
||||
"python-socketio==5.13.0",
|
||||
"python-socketio==5.14.1",
|
||||
"pytz==2025.2",
|
||||
"python-vlc==3.0.21203",
|
||||
"sqlmodel==0.0.24",
|
||||
"uvicorn[standard]==0.35.0",
|
||||
"yt-dlp==2025.7.21",
|
||||
"sqlmodel==0.0.25",
|
||||
"uvicorn[standard]==0.37.0",
|
||||
"yt-dlp==2025.9.26",
|
||||
"asyncpg==0.30.0",
|
||||
"psycopg[binary]==3.2.10",
|
||||
"pycaw>=20240210",
|
||||
"pulsectl>=24.12.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
dev-dependencies = [
|
||||
"coverage==7.10.1",
|
||||
"faker==37.4.2",
|
||||
"coverage==7.10.7",
|
||||
"faker==37.8.0",
|
||||
"httpx==0.28.1",
|
||||
"mypy==1.17.0",
|
||||
"pytest==8.4.1",
|
||||
"pytest-asyncio==1.1.0",
|
||||
"ruff==0.12.6",
|
||||
"mypy==1.18.2",
|
||||
"pytest==8.4.2",
|
||||
"pytest-asyncio==1.2.0",
|
||||
"ruff==0.13.3",
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
@@ -43,10 +51,31 @@ exclude = ["alembic"]
|
||||
select = ["ALL"]
|
||||
ignore = ["D100", "D103", "TRY301"]
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
"tests/**/*.py" = ["S101", "S105"]
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/**/*.py" = [
|
||||
"S101", # Use of assert detected
|
||||
"S105", # Possible hardcoded password
|
||||
"S106", # Possible hardcoded password
|
||||
"ANN001", # Missing type annotation for function argument
|
||||
"ANN003", # Missing type annotation for **kwargs
|
||||
"ANN201", # Missing return type annotation for public function
|
||||
"ANN202", # Missing return type annotation for private function
|
||||
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
|
||||
"ARG001", # Unused function argument
|
||||
"ARG002", # Unused method argument
|
||||
"ARG005", # Unused lambda argument
|
||||
"BLE001", # Do not catch blind exception
|
||||
"E501", # Line too long
|
||||
"PLR2004", # Magic value used in comparison
|
||||
"PLC0415", # `import` should be at top-level
|
||||
"SLF001", # Private member accessed
|
||||
"SIM117", # Use a single `if` statement
|
||||
"PT011", # `pytest.raises()` is too broad
|
||||
"PT012", # `pytest.raises()` block should contain a single simple statement
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
filterwarnings = [
|
||||
"ignore:transaction already deassociated from connection:sqlalchemy.exc.SAWarning",
|
||||
]
|
||||
|
||||
55
reset.sh
Executable file
55
reset.sh
Executable file
@@ -0,0 +1,55 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Reset script for SDB2 soundboard application
|
||||
# This script removes the database and cleans extracted sounds
|
||||
|
||||
set -e # Exit on any error
|
||||
|
||||
echo "🔄 Resetting SDB2 application..."
|
||||
|
||||
# Change to backend directory
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
# Remove database file if it exists
|
||||
if [ -f "data/soundboard.db" ]; then
|
||||
echo "🗑️ Removing database: data/soundboard.db"
|
||||
rm "data/soundboard.db"
|
||||
else
|
||||
echo "ℹ️ Database file not found, skipping"
|
||||
fi
|
||||
|
||||
# List of folders to clean (only files will be deleted, preserving .gitignore)
|
||||
FOLDERS_TO_CLEAN=(
|
||||
"sounds/originals/extracted"
|
||||
"sounds/originals/extracted/thumbnails"
|
||||
"sounds/originals/text_to_speech"
|
||||
"sounds/normalized/extracted"
|
||||
"sounds/normalized/soundboard"
|
||||
"sounds/normalized/text_to_speech"
|
||||
"sounds/temp"
|
||||
)
|
||||
|
||||
# Function to clean files in a directory
|
||||
clean_folder() {
|
||||
local folder="$1"
|
||||
|
||||
if [ -d "$folder" ]; then
|
||||
echo "🧹 Cleaning folder: $folder"
|
||||
|
||||
# Find and delete all files except .gitignore (preserving subdirectories)
|
||||
find "$folder" -maxdepth 1 -type f -not -name '.gitignore' -delete
|
||||
|
||||
echo "✅ Folder cleaned: $folder"
|
||||
else
|
||||
echo "ℹ️ Folder not found, skipping: $folder"
|
||||
fi
|
||||
}
|
||||
|
||||
# Clean all specified folders
|
||||
echo "🧹 Cleaning specified folders..."
|
||||
for folder in "${FOLDERS_TO_CLEAN[@]}"; do
|
||||
clean_folder "$folder"
|
||||
done
|
||||
|
||||
echo "✅ Application reset complete!"
|
||||
echo "💡 Run 'uv run python run.py' to start fresh"
|
||||
1
tests/api/v1/admin/__init__.py
Normal file
1
tests/api/v1/admin/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for admin API endpoints."""
|
||||
154
tests/api/v1/admin/test_extraction_endpoints.py
Normal file
154
tests/api/v1/admin/test_extraction_endpoints.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""Tests for admin extraction API endpoints."""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.models.extraction import Extraction
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class TestAdminExtractionEndpoints:
|
||||
"""Test admin extraction endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_extraction_processor_status(self, authenticated_admin_client):
|
||||
"""Test getting extraction processor status."""
|
||||
response = await authenticated_admin_client.get(
|
||||
"/api/v1/admin/extractions/status",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Check expected status fields (match actual processor status format)
|
||||
assert "currently_processing" in data
|
||||
assert "max_concurrent" in data
|
||||
assert "available_slots" in data
|
||||
assert "processing_ids" in data
|
||||
assert isinstance(data["currently_processing"], int)
|
||||
assert isinstance(data["max_concurrent"], int)
|
||||
assert isinstance(data["available_slots"], int)
|
||||
assert isinstance(data["processing_ids"], list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_delete_extraction_success(
|
||||
self,
|
||||
authenticated_admin_client,
|
||||
test_session,
|
||||
test_plan,
|
||||
):
|
||||
"""Test admin successfully deleting any extraction."""
|
||||
# Create a test user
|
||||
user = User(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
is_active=True,
|
||||
plan_id=test_plan.id,
|
||||
)
|
||||
test_session.add(user)
|
||||
await test_session.commit()
|
||||
await test_session.refresh(user)
|
||||
|
||||
# Create test extraction
|
||||
extraction = Extraction(
|
||||
url="https://example.com/video",
|
||||
user_id=user.id,
|
||||
status="completed",
|
||||
)
|
||||
test_session.add(extraction)
|
||||
await test_session.commit()
|
||||
await test_session.refresh(extraction)
|
||||
|
||||
# Admin delete the extraction
|
||||
response = await authenticated_admin_client.delete(
|
||||
f"/api/v1/admin/extractions/{extraction.id}",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["message"] == f"Extraction {extraction.id} deleted successfully"
|
||||
|
||||
# Verify extraction was deleted from database
|
||||
deleted_extraction = await test_session.get(Extraction, extraction.id)
|
||||
assert deleted_extraction is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_delete_extraction_not_found(self, authenticated_admin_client):
|
||||
"""Test admin deleting non-existent extraction."""
|
||||
response = await authenticated_admin_client.delete(
|
||||
"/api/v1/admin/extractions/999",
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
assert "not found" in data["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_delete_extraction_any_user(
|
||||
self,
|
||||
authenticated_admin_client,
|
||||
test_session,
|
||||
test_plan,
|
||||
):
|
||||
"""Test admin deleting extraction owned by any user."""
|
||||
# Create another user and their extraction
|
||||
other_user = User(
|
||||
name="Other User",
|
||||
email="other@example.com",
|
||||
is_active=True,
|
||||
plan_id=test_plan.id,
|
||||
)
|
||||
test_session.add(other_user)
|
||||
await test_session.commit()
|
||||
await test_session.refresh(other_user)
|
||||
|
||||
extraction = Extraction(
|
||||
url="https://example.com/video",
|
||||
user_id=other_user.id,
|
||||
status="completed",
|
||||
)
|
||||
test_session.add(extraction)
|
||||
await test_session.commit()
|
||||
await test_session.refresh(extraction)
|
||||
|
||||
# Admin can delete any user's extraction
|
||||
response = await authenticated_admin_client.delete(
|
||||
f"/api/v1/admin/extractions/{extraction.id}",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["message"] == f"Extraction {extraction.id} deleted successfully"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_extraction_non_admin(self, authenticated_client, test_user, test_session):
|
||||
"""Test non-admin user cannot access admin deletion endpoint."""
|
||||
# Create test extraction
|
||||
extraction = Extraction(
|
||||
url="https://example.com/video",
|
||||
user_id=test_user.id,
|
||||
status="completed",
|
||||
)
|
||||
test_session.add(extraction)
|
||||
await test_session.commit()
|
||||
await test_session.refresh(extraction)
|
||||
|
||||
# Non-admin user cannot access admin endpoint
|
||||
response = await authenticated_client.delete(
|
||||
f"/api/v1/admin/extractions/{extraction.id}",
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
data = response.json()
|
||||
assert "permissions" in data["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_endpoints_unauthenticated(self, client: AsyncClient):
|
||||
"""Test admin endpoints require authentication."""
|
||||
# Status endpoint
|
||||
response = await client.get("/api/v1/admin/extractions/status")
|
||||
assert response.status_code == 401
|
||||
|
||||
# Delete endpoint
|
||||
response = await client.delete("/api/v1/admin/extractions/1")
|
||||
assert response.status_code == 401
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user