Compare commits

..

110 Commits

Author SHA1 Message Date
JSC
e4c72f3b19 chore: Remove unused .env.template and SCHEDULER_EXAMPLE.md files
Some checks failed
Backend CI / lint (push) Failing after 10s
Backend CI / test (push) Failing after 1m39s
2025-10-05 16:33:29 +02:00
JSC
17eafa4872 feat: Enhance play_next functionality by storing and restoring playlist index
Some checks failed
Backend CI / test (push) Failing after 2m17s
Backend CI / lint (push) Failing after 14m55s
2025-10-05 04:07:34 +02:00
JSC
c9f6bff723 refactor: Improve code readability by formatting function signatures and descriptions
Some checks failed
Backend CI / lint (push) Failing after 9s
Backend CI / test (push) Failing after 1m29s
2025-10-04 22:27:12 +02:00
JSC
12243b1424 feat: Clear and manage play_next queue on playlist changes
Some checks failed
Backend CI / lint (push) Failing after 9s
Backend CI / test (push) Failing after 1m36s
2025-10-04 19:39:44 +02:00
JSC
f7197a89a7 feat: Add play next functionality to player service and API 2025-10-04 19:16:37 +02:00
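The play-next commits above (17eafa4872, 12243b1424, f7197a89a7) add a queue layered on top of the playlist, store the playlist index when a detour starts, restore it afterwards, and clear the queue on playlist changes. A minimal stdlib sketch of that behavior — class and method names here are hypothetical, not the actual PlayerService API:

```python
class PlayerQueue:
    """Sketch of a 'play next' queue over a playlist (names assumed)."""

    def __init__(self, playlist):
        self.playlist = list(playlist)
        self.index = 0
        self.play_next_queue = []
        self.saved_index = None

    def play_next(self, sound):
        # Remember where we were in the playlist before the detour.
        if self.saved_index is None:
            self.saved_index = self.index
        self.play_next_queue.append(sound)

    def advance(self):
        # Drain the play-next queue before resuming the playlist.
        if self.play_next_queue:
            return self.play_next_queue.pop(0)
        if self.saved_index is not None:
            self.index = self.saved_index  # restore stored playlist index
            self.saved_index = None
        sound = self.playlist[self.index % len(self.playlist)]
        self.index += 1
        return sound

    def load_playlist(self, playlist):
        # Playlist changes clear any pending play-next entries.
        self.playlist = list(playlist)
        self.index = 0
        self.play_next_queue.clear()
        self.saved_index = None
```

After a queued sound finishes, playback resumes at the stored playlist index rather than restarting the playlist.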
JSC
b66b8e36bb feat: Enhance user metrics retrieval by integrating Extraction model and updating related queries
Some checks failed
Backend CI / lint (push) Failing after 17s
Backend CI / test (push) Failing after 2m32s
2025-10-04 13:45:36 +02:00
JSC
95e166eefb feat: Add endpoint and service method to retrieve top users by various metrics
Some checks failed
Backend CI / lint (push) Failing after 9s
Backend CI / test (push) Failing after 1m36s
2025-09-27 21:52:00 +02:00
JSC
d9697c2dd7 feat: Add TTS statistics endpoint and service method for comprehensive TTS data 2025-09-27 21:37:59 +02:00
JSC
7b59a8216a fix: Correct import formatting for CreditService in VLCPlayerService
Some checks failed
Backend CI / lint (push) Failing after 9s
Backend CI / test (push) Failing after 1m28s
2025-09-27 03:34:19 +02:00
JSC
4b8496d025 feat: Implement host system volume control and update player service to use it
Some checks failed
Backend CI / lint (push) Failing after 10s
Backend CI / test (push) Has been cancelled
2025-09-27 03:33:11 +02:00
JSC
0806d541f2 Upgrade packages
Some checks failed
Backend CI / lint (push) Failing after 11s
Backend CI / test (push) Failing after 1m59s
2025-09-27 02:32:59 +02:00
JSC
acdf191a5a refactor: Improve code readability and structure across TTS modules
Some checks failed
Backend CI / lint (push) Failing after 10s
Backend CI / test (push) Failing after 1m36s
2025-09-21 19:07:32 +02:00
JSC
35b857fd0d feat: Add GitHub as an available OAuth provider and remove database initialization logs 2025-09-21 18:58:20 +02:00
JSC
c13e18c290 feat: Implement playlist sound deletion and update current playlist logic on deletion
Some checks failed
Backend CI / lint (push) Failing after 9s
Backend CI / test (push) Failing after 1m34s
2025-09-21 18:32:48 +02:00
JSC
702d7ee577 Merge branch 'tts'
Some checks failed
Backend CI / lint (push) Failing after 10s
Backend CI / test (push) Failing after 1m35s
2025-09-21 18:19:26 +02:00
JSC
d3b6e90262 style: Format code for consistency and readability across TTS modules 2025-09-21 18:05:20 +02:00
JSC
50eeae4c62 refactor: Clean up TTSService methods for improved readability and consistency 2025-09-21 15:38:35 +02:00
JSC
e005dedcd3 refactor: Update supported languages list in GTTSProvider and remove TLD option from schema 2025-09-21 15:20:23 +02:00
JSC
72ddd98b25 feat: Add status and error fields to TTS model and implement background processing for TTS generations 2025-09-21 14:39:41 +02:00
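Commit 72ddd98b25 above adds status and error fields to the TTS model and moves generation into a background task. A small sketch of that lifecycle — the status values and field names are assumptions, not taken from the actual model:

```python
import asyncio
from dataclasses import dataclass

# Hypothetical status values; the real TTS model is not shown here.
PENDING, PROCESSING, DONE, ERROR = "pending", "processing", "done", "error"


@dataclass
class TTSGeneration:
    text: str
    status: str = PENDING
    error: "str | None" = None


async def process_generation(gen, synthesize):
    """Run one generation in the background, recording outcome on the row."""
    gen.status = PROCESSING
    try:
        await synthesize(gen.text)
        gen.status = DONE
    except Exception as exc:  # store the failure instead of raising
        gen.status = ERROR
        gen.error = str(exc)
```

The API can then return immediately with a PENDING row and let clients poll the status field.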
JSC
b2e513a915 feat: Add endpoint to retrieve TTS history for the current user and improve request model formatting 2025-09-21 13:55:24 +02:00
JSC
c8b796aa94 refactor: Simplify TTS API endpoints by removing specific paths for generate and history 2025-09-21 13:38:12 +02:00
JSC
d5f9a3c736 feat: Run database migrations in a thread pool to avoid blocking during initialization 2025-09-21 13:21:23 +02:00
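Commit d5f9a3c736 runs migrations in a thread pool so startup does not block the event loop. The idea can be sketched with `asyncio.to_thread`; the `run_migrations` stub below stands in for the blocking Alembic call (e.g. `alembic.command.upgrade(config, "head")`) and is hypothetical:

```python
import asyncio


def run_migrations():
    # Stand-in for the blocking Alembic upgrade call; hypothetical here.
    return "migrated"


async def init_database():
    # Offload the blocking migration to a worker thread so the
    # event loop stays responsive during application startup.
    return await asyncio.to_thread(run_migrations)
```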
JSC
2b61d35d6a chore: Update dependencies for fastapi, faker, sqlmodel, and uvicorn; add gtts and charset-normalizer packages 2025-09-20 23:10:59 +02:00
JSC
5e8d619736 feat: Implement Text-to-Speech (TTS) functionality with API endpoints, models, and service integration 2025-09-20 23:10:47 +02:00
JSC
fb0e5e919c fix: Remove GitHub from available OAuth providers list
Some checks failed
Backend CI / lint (push) Failing after 11s
Backend CI / test (push) Failing after 1m34s
2025-09-20 21:11:50 +02:00
JSC
bccfcafe0e feat: Update CORS origins to allow Chrome extensions and improve logging in migration tool
Some checks failed
Backend CI / lint (push) Failing after 10s
Backend CI / test (push) Failing after 1m37s
2025-09-19 16:41:11 +02:00
JSC
1bef694f38 feat: Enhance play_sound method to accept volume parameter and retrieve current volume
Some checks failed
Backend CI / lint (push) Failing after 10s
Backend CI / test (push) Failing after 1m33s
2025-09-18 13:57:54 +02:00
JSC
b87a47f199 fix: Update PostgreSQL database URL for Alembic to use psycopg driver
Some checks failed
Backend CI / lint (push) Failing after 12s
Backend CI / test (push) Failing after 1m33s
2025-09-18 13:14:01 +02:00
JSC
83239cb4fa Add Alembic for database migrations and initial migration scripts
- Created alembic.ini configuration file for Alembic migrations.
- Added README file for Alembic with a brief description.
- Implemented env.py for Alembic to manage database migrations.
- Created script.py.mako template for migration scripts.
- Added initial migration script to create database tables.
- Created a migration script to add initial plan and playlist data.
- Updated database initialization to run Alembic migrations.
- Enhanced credit service to automatically recharge user credits based on their plan.
- Implemented delete_task method in scheduler service to remove scheduled tasks.
- Updated scheduler API to reflect task deletion instead of cancellation.
- Added CLI tool for managing database migrations.
- Updated tests to cover new functionality for task deletion and credit recharge.
- Updated pyproject.toml and lock files to include Alembic as a dependency.
2025-09-16 13:45:14 +02:00
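Among the bullets above, the plan-based credit recharge can be sketched as a once-per-day top-up to the plan allowance. Field names here are assumptions; the real CreditService and Plan models are not shown in this diff:

```python
from datetime import datetime, timedelta


def recharge_credits(user, plan, now):
    """Top credits back up to the plan's daily allowance, at most once a day."""
    if now - user["last_recharge"] >= timedelta(days=1):
        user["credits"] = max(user["credits"], plan["daily_credits"])
        user["last_recharge"] = now
    return user
```

Using `max` means the recharge never reduces a balance that is already above the allowance.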
JSC
e8f979c137 feat: Add MINUTELY recurrence type and enhance scheduler handling 2025-09-13 23:44:20 +02:00
JSC
92571f4de9 Refactor code structure for improved readability and maintainability
Some checks failed
Backend CI / lint (push) Failing after 1m29s
Backend CI / test (push) Failing after 1m36s
2025-09-13 22:38:55 +02:00
JSC
1388ede1dc Merge branch 'tasks3'
Some checks failed
Backend CI / lint (push) Successful in 9m32s
Backend CI / test (push) Failing after 4m47s
2025-08-29 23:12:45 +02:00
JSC
75569a60b5 fix: Improve logging for invalid player mode by using logger.exception 2025-08-29 15:44:11 +02:00
JSC
2bdd109492 Refactor code structure for improved readability and maintainability 2025-08-29 15:27:12 +02:00
JSC
dc89e45675 Refactor scheduled task repository and schemas for improved type hints and consistency
- Updated type hints from List/Optional to list/None for better readability and consistency across the codebase.
- Refactored import statements for better organization and clarity.
- Enhanced the ScheduledTaskBase schema to use modern type hints.
- Cleaned up unnecessary comments and whitespace in various files.
- Improved error handling and logging in task execution handlers.
- Updated test cases to reflect changes in type hints and ensure compatibility with the new structure.
2025-08-28 23:38:47 +02:00
JSC
96801dc4d6 feat: Refactor TaskHandlerRegistry to include db_session_factory and enhance sound playback handling for user tasks 2025-08-28 23:36:30 +02:00
JSC
6e74d9b940 feat: Add load_playlist method to PlayerService and update task handlers for playlist management 2025-08-28 22:50:57 +02:00
JSC
03abed6d39 Add comprehensive tests for scheduled task repository, scheduler service, and task handlers
- Implemented tests for ScheduledTaskRepository covering task creation, retrieval, filtering, and status updates.
- Developed tests for SchedulerService including task creation, cancellation, user task retrieval, and maintenance jobs.
- Created tests for TaskHandlerRegistry to validate task execution for various types, including credit recharge and sound playback.
- Ensured proper error handling and edge cases in task execution scenarios.
- Added fixtures and mocks to facilitate isolated testing of services and repositories.
2025-08-28 22:37:43 +02:00
JSC
7dee6e320e Add tests for extraction API endpoints and enhance existing tests
Some checks failed
Backend CI / lint (push) Successful in 9m25s
Backend CI / test (push) Failing after 4m48s
- Implement tests for admin extraction API endpoints including status retrieval, deletion of extractions, and permission checks.
- Add tests for user extraction deletion, ensuring proper handling of permissions and non-existent extractions.
- Enhance sound endpoint tests to include duplicate handling in responses.
- Refactor favorite service tests to utilize mock dependencies for better maintainability and clarity.
- Update sound scanner tests to improve file handling and ensure proper deletion of associated files.
2025-08-25 21:40:31 +02:00
JSC
d3ce17f10d feat: Enhance SoundScannerService with duplicate detection and normalized file handling
Some checks failed
Backend CI / lint (push) Failing after 4m52s
Backend CI / test (push) Failing after 4m39s
2025-08-25 12:33:10 +02:00
JSC
da66516bb3 feat: Implement hash-first identification strategy in audio file syncing and enhance tests for renamed files
Some checks failed
Backend CI / lint (push) Failing after 4m55s
Backend CI / test (push) Failing after 4m32s
2025-08-25 11:56:07 +02:00
JSC
d81a54207c feat: Add endpoint to retrieve currently processing extractions and corresponding tests
Some checks failed
Backend CI / lint (push) Failing after 4m54s
Backend CI / test (push) Failing after 4m39s
2025-08-24 13:44:01 +02:00
JSC
16eb789539 feat: Add method to get extractions by status and implement user info retrieval in extraction service
Some checks failed
Backend CI / lint (push) Failing after 4m53s
Backend CI / test (push) Failing after 4m31s
2025-08-24 13:24:48 +02:00
JSC
28faca55bc Refactor code structure for improved readability and maintainability
Some checks failed
Backend CI / lint (push) Failing after 5m1s
Backend CI / test (push) Failing after 4m30s
2025-08-22 21:18:04 +02:00
JSC
821093f64f Refactor code structure for improved readability and maintainability
Some checks failed
Backend CI / lint (push) Failing after 4m51s
Backend CI / test (push) Failing after 4m35s
2025-08-20 11:37:28 +02:00
JSC
9653062003 refactor: Move imports to avoid circular dependencies in socket and VLCPlayerService
Some checks failed
Backend CI / lint (push) Successful in 9m24s
Backend CI / test (push) Failing after 3m55s
2025-08-19 22:32:19 +02:00
JSC
b808cfaddf feat: Enhance WebSocket sound playback with credit validation and refactor related methods
Some checks failed
Backend CI / lint (push) Has been cancelled
Backend CI / test (push) Has been cancelled
2025-08-19 22:28:54 +02:00
JSC
a82acfae50 feat: Implement sound playback with credit validation in VLCPlayerService and update WebSocket handling
Some checks failed
Backend CI / lint (push) Failing after 5m0s
Backend CI / test (push) Failing after 2m0s
2025-08-19 22:16:48 +02:00
JSC
560ccd3f7e refactor: Improve code readability by formatting query parameters in user endpoints and enhancing error handling in sound playback 2025-08-19 22:09:50 +02:00
JSC
a660cc1861 Merge branch 'favorite'
Some checks failed
Backend CI / lint (push) Successful in 9m21s
Backend CI / test (push) Failing after 3m59s
2025-08-17 13:25:59 +02:00
JSC
6b55ff0e81 Refactor user endpoint tests to include pagination and response structure validation
- Updated tests for listing users to validate pagination and response format.
- Changed mock return values to include total count and pagination details.
- Refactored user creation mocks for clarity and consistency.
- Enhanced assertions to check for presence of pagination fields in responses.
- Adjusted test cases for user retrieval and updates to ensure proper handling of user data.
- Improved readability by restructuring mock definitions and assertions across various test files.
2025-08-17 12:36:52 +02:00
JSC
e6f796a3c9 feat: Add pagination, search, and filter functionality to user retrieval endpoint 2025-08-17 11:44:15 +02:00
JSC
99c757a073 feat: Implement pagination for extractions and playlists with total count in responses 2025-08-17 11:21:55 +02:00
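The pagination-with-total-count pattern from commit 99c757a073 can be sketched as a response that carries both the page slice and the overall count, so clients can render page controls. The field names below are assumptions, not the actual response schema:

```python
def paginate(items, page=1, page_size=10):
    """Return one page of items plus the total count (assumed field names)."""
    total = len(items)
    start = (page - 1) * page_size
    return {
        "items": items[start : start + page_size],
        "total": total,
        "page": page,
        "page_size": page_size,
    }
```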
JSC
f598ec2c12 fix: Extract user name in session context for improved performance 2025-08-17 01:49:47 +02:00
JSC
66d22df7dd feat: Add filtering, searching, and sorting to extraction retrieval endpoints 2025-08-17 01:44:43 +02:00
JSC
3326e406f8 feat: Add filtering, searching, and sorting to user extractions retrieval 2025-08-17 01:27:41 +02:00
JSC
fe15e7a6af fix: Correct log message for sound favorited event broadcasting 2025-08-17 01:08:33 +02:00
JSC
f56cc8b4cc feat: Enhance sound favorite management; add WebSocket event broadcasting for favoriting and unfavoriting sounds 2025-08-16 22:19:24 +02:00
JSC
f906b6d643 feat: Enhance favorites functionality; add favorites filtering to playlists and sounds, and improve favorite indicators in responses 2025-08-16 21:41:50 +02:00
JSC
78508c84eb feat: Add favorites filtering to sound retrieval; include user-specific favorite sounds in the API response 2025-08-16 21:27:40 +02:00
JSC
a947fd830b feat: Implement favorites management API; add endpoints for adding, removing, and retrieving favorites for sounds and playlists
feat: Create Favorite model and repository for managing user favorites in the database
feat: Add FavoriteService to handle business logic for favorites management
feat: Enhance Playlist and Sound response schemas to include favorite indicators and counts
refactor: Update API routes to include favorites functionality in playlists and sounds
2025-08-16 21:16:02 +02:00
JSC
5e6cc04ad2 fix: Increase broadcast interval to 1 second while playing
All checks were successful
Backend CI / lint (push) Successful in 9m23s
Backend CI / test (push) Successful in 3m47s
2025-08-16 12:24:41 +02:00
JSC
c27530a25f refactor: Remove unused variable main_playlist_id from test cases in TestPlaylistService
All checks were successful
Backend CI / lint (push) Successful in 9m24s
Backend CI / test (push) Successful in 3m48s
2025-08-16 00:54:41 +02:00
JSC
a109a88eed feat: Implement main playlist restrictions; add internal method for sound addition and update tests
Some checks failed
Backend CI / test (push) Has been cancelled
Backend CI / lint (push) Has been cancelled
2025-08-16 00:51:38 +02:00
JSC
4cec3b9d18 feat: Enhance timestamp management in BaseModel and PlaylistRepository; add automatic updates and improve code readability
All checks were successful
Backend CI / lint (push) Successful in 9m21s
Backend CI / test (push) Successful in 4m0s
2025-08-16 00:19:53 +02:00
JSC
b691649f7e feat: Implement automatic updated_at timestamp management in BaseModel and update BaseRepository to reflect changes
Some checks failed
Backend CI / lint (push) Failing after 5m0s
Backend CI / test (push) Successful in 3m46s
2025-08-16 00:07:15 +02:00
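The automatic `updated_at` management from commit b691649f7e can be illustrated with a stdlib-only base class where any attribute write refreshes the timestamp. This is only a sketch of the idea; the actual code presumably uses SQLModel/SQLAlchemy hooks rather than `__setattr__`:

```python
from datetime import datetime, timezone


class BaseModel:
    """Sketch: every attribute write refreshes updated_at (stdlib only)."""

    def __init__(self):
        now = datetime.now(timezone.utc)
        object.__setattr__(self, "created_at", now)
        object.__setattr__(self, "updated_at", now)

    def __setattr__(self, name, value):
        object.__setattr__(self, name, value)
        if name != "updated_at":
            object.__setattr__(self, "updated_at", datetime.now(timezone.utc))
```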
JSC
87d6e6ed67 feat: Update API documentation endpoints and enhance application metadata for SBD v2
All checks were successful
Backend CI / lint (push) Successful in 9m22s
Backend CI / test (push) Successful in 3m54s
2025-08-13 13:56:01 +02:00
JSC
bee1076239 refactor: Improve exception handling and logging in authentication and playlist services; enhance code readability and structure
All checks were successful
Backend CI / lint (push) Successful in 9m21s
Backend CI / test (push) Successful in 4m18s
2025-08-13 00:04:55 +02:00
JSC
f094fbf140 fix: Add missing commas in function calls and improve code formatting
Some checks failed
Backend CI / lint (push) Failing after 4m51s
Backend CI / test (push) Successful in 4m19s
2025-08-12 23:37:38 +02:00
JSC
d3d7edb287 feat: Add tests for dashboard service including statistics and date filters
Some checks failed
Backend CI / lint (push) Has been cancelled
Backend CI / test (push) Has been cancelled
2025-08-12 23:34:02 +02:00
JSC
cba1653565 feat: Update player state tests to include previous volume and adjust volume assertions
Some checks failed
Backend CI / lint (push) Failing after 5m0s
Backend CI / test (push) Successful in 3m46s
2025-08-12 22:58:47 +02:00
JSC
c69a45c9b4 feat: Add endpoint and service method to retrieve top sounds by play count with filtering options
Some checks failed
Backend CI / lint (push) Failing after 4m54s
Backend CI / test (push) Failing after 4m24s
2025-08-11 22:04:42 +02:00
JSC
53b6c4bca5 feat: Enhance sound addition and removal in playlists with position handling and reordering
Some checks failed
Backend CI / lint (push) Failing after 4m55s
Backend CI / test (push) Failing after 3m44s
2025-08-11 20:55:31 +02:00
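The position handling and reordering from commit 53b6c4bca5 boils down to remove-then-reinsert followed by renumbering so positions stay contiguous. A stdlib sketch on plain dicts — the repository does this against database rows, and the field names are assumptions:

```python
def move_sound(sounds, old_pos, new_pos):
    """Move one entry and renumber positions 0..n-1 (assumed field names)."""
    items = sorted(sounds, key=lambda s: s["position"])
    item = items.pop(old_pos)
    items.insert(new_pos, item)
    for pos, s in enumerate(items):
        s["position"] = pos
    return items
```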
JSC
49ad6c8581 feat: Add dashboard API endpoints and service for sound statistics
Some checks failed
Backend CI / lint (push) Failing after 4m52s
Backend CI / test (push) Failing after 3m42s
2025-08-11 11:16:45 +02:00
JSC
bb1f036caa Merge branch 'add_sound_to_playlist_dnd'
Some checks failed
Backend CI / lint (push) Failing after 4m53s
Backend CI / test (push) Failing after 3m40s
2025-08-11 09:41:35 +02:00
JSC
d1bf2fe0a4 feat: Add scheduler for daily user credits recharge 2025-08-11 00:30:29 +02:00
JSC
bdeb00d562 feat: Increase default volume level to 80 and adjust volume handling in player service 2025-08-10 21:55:12 +02:00
JSC
13e0db1fe9 feat: Add position shifting logic for adding sounds to playlists in repository 2025-08-10 21:33:06 +02:00
JSC
357fbcecac feat: Implement search and sorting functionality for playlists in API and repository
Some checks failed
Backend CI / lint (push) Failing after 4m54s
Backend CI / test (push) Failing after 4m25s
2025-08-10 19:30:14 +02:00
JSC
aa9a73ac1d feat: Add search and sorting functionality to sound repository and API
Some checks failed
Backend CI / lint (push) Failing after 4m54s
Backend CI / test (push) Failing after 3m46s
2025-08-10 15:33:15 +02:00
JSC
8544a3ce22 feat: Add mute and unmute functionality to player service and API
Some checks failed
Backend CI / lint (push) Failing after 4m52s
Backend CI / test (push) Failing after 3m44s
2025-08-10 15:11:28 +02:00
JSC
0a8b50a0be feat: Add user profile management and password change endpoints
Some checks failed
Backend CI / lint (push) Failing after 4m51s
Backend CI / test (push) Successful in 3m38s
2025-08-09 23:43:20 +02:00
JSC
9e07ce393f feat: Implement admin user management endpoints and user update schema 2025-08-09 22:37:51 +02:00
JSC
734521c5c3 feat: Add environment configuration files and update settings for production and development
Some checks failed
Backend CI / lint (push) Failing after 5m0s
Backend CI / test (push) Successful in 3m39s
2025-08-09 14:43:20 +02:00
JSC
69544b6bb8 feat: Refactor cookie handling to use utility functions for setting access and refresh tokens
All checks were successful
Backend CI / lint (push) Successful in 9m30s
Backend CI / test (push) Successful in 3m31s
2025-08-08 10:06:45 +02:00
JSC
b4f0f54516 Refactor sound and extraction services to include user and timestamp fields
All checks were successful
Backend CI / lint (push) Successful in 18m8s
Backend CI / test (push) Successful in 53m35s
- Updated ExtractionInfo to include user_id, created_at, and updated_at fields.
- Modified ExtractionService to return user and timestamp information in extraction responses.
- Enhanced sound serialization in PlayerState to include extraction URL if available.
- Adjusted PlaylistRepository to load sound extractions when retrieving playlist sounds.
- Added tests for new fields in extraction and sound endpoints, ensuring proper response structure.
- Created new test file endpoints for sound downloads and thumbnail retrievals, including success and error cases.
- Refactored various test cases for consistency and clarity, ensuring proper mocking and assertions.
2025-08-03 20:54:14 +02:00
JSC
77446cb5a8 feat: Include admin user name in SoundPlayed records for enhanced tracking
Some checks failed
Backend CI / lint (push) Successful in 16m41s
Backend CI / test (push) Failing after 47m14s
2025-08-02 18:22:38 +02:00
JSC
4bbae4c5d4 feat: Add endpoint to retrieve sounds with optional type filtering and implement corresponding repository method
Some checks failed
Backend CI / lint (push) Successful in 9m41s
Backend CI / test (push) Failing after 1m39s
2025-08-01 22:03:09 +02:00
JSC
d2d0240fdb feat: Add audio extraction endpoints and refactor sound API routes 2025-08-01 21:39:42 +02:00
JSC
6068599a47 Refactor test cases for improved readability and consistency
All checks were successful
Backend CI / lint (push) Successful in 9m49s
Backend CI / test (push) Successful in 6m15s
- Adjusted function signatures in various test files to enhance clarity by aligning parameters.
- Updated patching syntax for better readability across test cases.
- Improved formatting and spacing in test assertions and mock setups.
- Ensured consistent use of async/await patterns in async test functions.
- Enhanced comments for better understanding of test intentions.
2025-08-01 20:53:30 +02:00
JSC
d926779fe4 feat: Implement playlist reordering with position swapping and reload player on current playlist changes
Some checks failed
Backend CI / lint (push) Failing after 5m7s
Backend CI / test (push) Successful in 5m14s
2025-08-01 17:49:29 +02:00
JSC
0575d12b0e refactor: Rename global current playlist methods for clarity and consistency 2025-08-01 17:12:56 +02:00
JSC
c0f51b2e23 refactor: Update playlist service and endpoints for global current playlist management 2025-08-01 16:58:25 +02:00
JSC
3132175354 refactor: Create admin enpoints and some renaming of api endpoints 2025-08-01 15:34:35 +02:00
JSC
43be92c8f9 fix: Update linter command in CI
All checks were successful
Backend CI / lint (push) Successful in 9m31s
Backend CI / test (push) Successful in 4m7s
2025-08-01 09:44:53 +02:00
JSC
f68f4d9046 refactor: Compiled ignored ruff rules in pyproject 2025-08-01 09:40:15 +02:00
JSC
fceff92ca1 fix: Lint fixes of last tests 2025-08-01 09:30:15 +02:00
JSC
dc29915fbc fix: Lint fixes of core and repositories tests
All checks were successful
Backend CI / lint (push) Successful in 9m26s
Backend CI / test (push) Successful in 4m24s
2025-08-01 09:17:20 +02:00
JSC
389cfe2d6a fix: Lint fixes of utils tests 2025-08-01 02:22:30 +02:00
JSC
502feea035 fix: Enable lint job in CI workflow
All checks were successful
Backend CI / lint (push) Successful in 9m26s
Backend CI / test (push) Successful in 4m4s
2025-08-01 02:09:45 +02:00
JSC
5fdc7aae85 fix: Lint fixes of last errors in app 2025-08-01 02:08:36 +02:00
JSC
69cdc7567d Refactor player service to diminish play complexity
All checks were successful
Backend CI / test (push) Successful in 4m53s
2025-08-01 01:34:22 +02:00
JSC
a10111793c fix: Lint fixes of services
All checks were successful
Backend CI / test (push) Successful in 3m59s
2025-08-01 01:27:47 +02:00
JSC
95ccb76233 fix: Lint fixes of api and repositories
All checks were successful
Backend CI / test (push) Successful in 3m58s
2025-07-31 22:29:11 +02:00
JSC
7ba52ad6fc fix: Lint fixes of core, models and schemas
All checks were successful
Backend CI / test (push) Successful in 4m5s
2025-07-31 22:06:31 +02:00
JSC
01bb48c206 fix: Utils lint fixes 2025-07-31 21:56:03 +02:00
JSC
8847131f24 Refactor test files for improved readability and consistency
- Removed unnecessary blank lines and adjusted formatting in test files.
- Ensured consistent use of commas in function calls and assertions across various test cases.
- Updated import statements for better organization and clarity.
- Enhanced mock setups in tests for better isolation and reliability.
- Improved assertions to follow a consistent style for better readability.
2025-07-31 21:37:04 +02:00
JSC
e69098d633 refactor: Update player seek functionality to use consistent position field across schemas and services
All checks were successful
Backend CI / test (push) Successful in 4m5s
2025-07-31 21:33:00 +02:00
JSC
3405d817d5 refactor: Simplify repository classes by inheriting from BaseRepository and removing redundant methods 2025-07-31 21:32:46 +02:00
JSC
c63997f591 refactor: Update PlayerState to improve serialization structure for current sound and playlist
Some checks failed
Backend CI / test (push) Failing after 4m2s
2025-07-31 21:01:40 +02:00
142 changed files with 20335 additions and 4764 deletions

.env.development.template (new file, 55 lines)

@@ -0,0 +1,55 @@
# Development Environment Configuration
# Copy this file to .env for development setup
# Application Configuration
HOST=localhost
PORT=8000
RELOAD=true
# Development URLs (for local development)
FRONTEND_URL=http://localhost:8001
BACKEND_URL=http://localhost:8000
CORS_ORIGINS=["http://localhost:8001"]
# Database Configuration
DATABASE_URL=sqlite+aiosqlite:///data/soundboard.db
DATABASE_ECHO=false
# Logging Configuration
LOG_LEVEL=debug
LOG_FILE=logs/app.log
LOG_MAX_SIZE=10485760
LOG_BACKUP_COUNT=5
# JWT Configuration (Use a secure key even in development)
JWT_SECRET_KEY=development-secret-key-change-for-production
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=15
JWT_REFRESH_TOKEN_EXPIRE_DAYS=7
# Cookie Configuration (Development settings)
COOKIE_SECURE=false
COOKIE_SAMESITE=lax
COOKIE_DOMAIN=localhost
# OAuth2 Configuration (Get these from OAuth providers)
# Google: https://console.developers.google.com/
# Redirect URI: http://localhost:8000/api/v1/auth/google/callback
GOOGLE_CLIENT_ID=your-google-client-id
GOOGLE_CLIENT_SECRET=your-google-client-secret
# GitHub: https://github.com/settings/developers
# Redirect URI: http://localhost:8000/api/v1/auth/github/callback
GITHUB_CLIENT_ID=your-github-client-id
GITHUB_CLIENT_SECRET=your-github-client-secret
# Audio Normalization Configuration
NORMALIZED_AUDIO_FORMAT=mp3
NORMALIZED_AUDIO_BITRATE=256k
NORMALIZED_AUDIO_PASSES=2
# Audio Extraction Configuration
EXTRACTION_AUDIO_FORMAT=mp3
EXTRACTION_AUDIO_BITRATE=256k
EXTRACTION_TEMP_DIR=sounds/temp
EXTRACTION_THUMBNAILS_DIR=sounds/originals/extracted/thumbnails
EXTRACTION_MAX_CONCURRENT=2
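Note that `CORS_ORIGINS` above is a JSON-style list in an env var. One plausible way to parse it — the app's actual settings loader is not shown in this diff, so this is only an assumed sketch:

```python
import json
import os


def cors_origins(env=None):
    """Parse the JSON-list CORS_ORIGINS env var (assumed format)."""
    if env is None:
        env = os.environ
    return json.loads(env.get("CORS_ORIGINS", "[]"))
```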

.env.production.template (new file, 50 lines)

@@ -0,0 +1,50 @@
# Production Environment Configuration
# Copy this file to .env and configure for your production environment
# Application Configuration
HOST=0.0.0.0
PORT=8000
RELOAD=false
# Production URLs (configure for your domain)
FRONTEND_URL=https://yourdomain.com
BACKEND_URL=https://yourdomain.com
CORS_ORIGINS=["https://yourdomain.com"]
# Database Configuration (consider using PostgreSQL in production)
DATABASE_URL=sqlite+aiosqlite:///data/soundboard.db
DATABASE_ECHO=false
# Logging Configuration
LOG_LEVEL=info
LOG_FILE=logs/app.log
LOG_MAX_SIZE=10485760
LOG_BACKUP_COUNT=5
# JWT Configuration (IMPORTANT: Generate secure keys for production)
JWT_SECRET_KEY=your-super-secure-secret-key-change-this-in-production
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=15
JWT_REFRESH_TOKEN_EXPIRE_DAYS=7
# Cookie Configuration (Production settings)
COOKIE_SECURE=true
COOKIE_SAMESITE=lax
COOKIE_DOMAIN= # Leave empty for same-origin cookies in production with reverse proxy
# OAuth2 Configuration (Configure with your OAuth providers)
GOOGLE_CLIENT_ID=your-google-client-id
GOOGLE_CLIENT_SECRET=your-google-client-secret
GITHUB_CLIENT_ID=your-github-client-id
GITHUB_CLIENT_SECRET=your-github-client-secret
# Audio Normalization Configuration
NORMALIZED_AUDIO_FORMAT=mp3
NORMALIZED_AUDIO_BITRATE=256k
NORMALIZED_AUDIO_PASSES=2
# Audio Extraction Configuration
EXTRACTION_AUDIO_FORMAT=mp3
EXTRACTION_AUDIO_BITRATE=256k
EXTRACTION_TEMP_DIR=sounds/temp
EXTRACTION_THUMBNAILS_DIR=sounds/originals/extracted/thumbnails
EXTRACTION_MAX_CONCURRENT=2

.env.template (deleted, 29 lines)

@@ -1,29 +0,0 @@
# Application Configuration
HOST=localhost
PORT=8000
RELOAD=true
# Database Configuration
DATABASE_URL=sqlite+aiosqlite:///data/soundboard.db
DATABASE_ECHO=false
# Logging Configuration
LOG_LEVEL=info
LOG_FILE=logs/app.log
LOG_MAX_SIZE=10485760
LOG_BACKUP_COUNT=5
# JWT Configuration
JWT_SECRET_KEY=your-secret-key-change-in-production
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=15
JWT_REFRESH_TOKEN_EXPIRE_DAYS=7
# Cookie Configuration
COOKIE_SECURE=false
COOKIE_SAMESITE=lax
# OAuth2 Configuration
GOOGLE_CLIENT_ID=
GOOGLE_CLIENT_SECRET=
GITHUB_CLIENT_ID=
GITHUB_CLIENT_SECRET=

Backend CI workflow (lint job re-enabled)

@@ -9,28 +9,28 @@ on:
       - main
 jobs:
-  # lint:
-  #   runs-on: ubuntu-latest
+  lint:
+    runs-on: ubuntu-latest
-  #   steps:
-  #     - name: Checkout code
-  #       uses: actions/checkout@v4
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
-  #     - name: "Set up python"
-  #       uses: actions/setup-python@v5
-  #       with:
-  #         python-version-file: "pyproject.toml"
+      - name: "Set up python"
+        uses: actions/setup-python@v5
+        with:
+          python-version-file: "pyproject.toml"
-  #     - name: Install uv
-  #       uses: astral-sh/setup-uv@v6
-  #       with:
-  #         enable-cache: true
+      - name: Install uv
+        uses: astral-sh/setup-uv@v6
+        with:
+          enable-cache: true
-  #     - name: Install requirements
-  #       run: uv sync --locked --all-extras --dev
+      - name: Install requirements
+        run: uv sync --locked --all-extras --dev
-  #     - name: Run linter
-  #       run: uv run ruff check
+      - name: Run linter
+        run: uv run ruff check
   test:
     runs-on: ubuntu-latest

.gitignore (4 lines changed)

@@ -8,4 +8,6 @@ wheels/
# Virtual environments
.venv
.env
.env
.coverage

alembic.ini (new file, 148 lines)

@@ -0,0 +1,148 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires python>=3.9 or the backports.zoneinfo library, plus the tzdata library.
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
# "version_path_separator" key, which if absent then falls back to the legacy
# behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
# behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
# sqlalchemy.url = driver://user:pass@localhost/dbname
# URL will be set dynamically in env.py from config
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

alembic/README Normal file

@@ -0,0 +1 @@
Generic single-database configuration.

alembic/env.py Normal file

@@ -0,0 +1,86 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
import app.models  # noqa: F401  # ensure all models are registered on SQLModel.metadata
from app.core.config import settings
from sqlmodel import SQLModel
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Set the database URL from settings - convert async URL to sync for alembic
sync_db_url = settings.DATABASE_URL.replace("sqlite+aiosqlite", "sqlite")
sync_db_url = sync_db_url.replace("postgresql+asyncpg", "postgresql+psycopg")
config.set_main_option("sqlalchemy.url", sync_db_url)
# Interpret the config file for Python logging.
# This sets up the loggers defined in alembic.ini.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = SQLModel.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
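Alembic needs a synchronous DBAPI, so env.py rewrites the app's async database URL before handing it to the config. The substitution can be sketched and checked in isolation (the helper name is hypothetical — env.py does this inline with two `replace` calls):

```python
def to_sync_url(async_url: str) -> str:
    """Rewrite an async SQLAlchemy database URL to its sync-driver equivalent."""
    return (
        async_url
        .replace("sqlite+aiosqlite", "sqlite")
        .replace("postgresql+asyncpg", "postgresql+psycopg")
    )

# Only the two async drivers used above are rewritten; other URLs pass through.
print(to_sync_url("sqlite+aiosqlite:///./app.db"))                 # sqlite:///./app.db
print(to_sync_url("postgresql+asyncpg://user:pass@localhost/db"))  # postgresql+psycopg://user:pass@localhost/db
```

Note that plain string replacement is driver-specific by design: any URL not using `aiosqlite` or `asyncpg` passes through untouched.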

alembic/script.py.mako Normal file

@@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}


@@ -0,0 +1,34 @@
"""Add status and error fields to TTS table
Revision ID: 0d9b7f1c367f
Revises: e617c155eea9
Create Date: 2025-09-21 14:09:56.418372
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '0d9b7f1c367f'
down_revision: Union[str, Sequence[str], None] = 'e617c155eea9'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('tts', sa.Column('status', sa.String(), nullable=False, server_default='pending'))
op.add_column('tts', sa.Column('error', sa.String(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('tts', 'error')
op.drop_column('tts', 'status')
# ### end Alembic commands ###


@@ -0,0 +1,222 @@
"""Initial migration
Revision ID: 7aa9892ceff3
Revises:
Create Date: 2025-09-16 13:16:58.233360
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = '7aa9892ceff3'
down_revision: Union[str, Sequence[str], None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('plan',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('code', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('credits', sa.Integer(), nullable=False),
sa.Column('max_credits', sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_plan_code'), 'plan', ['code'], unique=True)
op.create_table('sound',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('filename', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('duration', sa.Integer(), nullable=False),
sa.Column('size', sa.Integer(), nullable=False),
sa.Column('hash', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('normalized_filename', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('normalized_duration', sa.Integer(), nullable=True),
sa.Column('normalized_size', sa.Integer(), nullable=True),
sa.Column('normalized_hash', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('thumbnail', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('play_count', sa.Integer(), nullable=False),
sa.Column('is_normalized', sa.Boolean(), nullable=False),
sa.Column('is_music', sa.Boolean(), nullable=False),
sa.Column('is_deletable', sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('hash', name='uq_sound_hash')
)
op.create_table('user',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('plan_id', sa.Integer(), nullable=False),
sa.Column('role', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('picture', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('password_hash', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('credits', sa.Integer(), nullable=False),
sa.Column('api_token', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('api_token_expires_at', sa.DateTime(), nullable=True),
sa.Column('refresh_token_hash', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('refresh_token_expires_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['plan_id'], ['plan.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('api_token'),
sa.UniqueConstraint('email')
)
op.create_table('credit_transaction',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('action_type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('amount', sa.Integer(), nullable=False),
sa.Column('balance_before', sa.Integer(), nullable=False),
sa.Column('balance_after', sa.Integer(), nullable=False),
sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('success', sa.Boolean(), nullable=False),
sa.Column('metadata_json', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('extraction',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('service', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('service_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('sound_id', sa.Integer(), nullable=True),
sa.Column('url', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('title', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('track', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('artist', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('album', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('genre', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('status', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('error', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.ForeignKeyConstraint(['sound_id'], ['sound.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('playlist',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('genre', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('is_main', sa.Boolean(), nullable=False),
sa.Column('is_current', sa.Boolean(), nullable=False),
sa.Column('is_deletable', sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('name')
)
op.create_table('scheduled_task',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
sa.Column('task_type', sa.Enum('CREDIT_RECHARGE', 'PLAY_SOUND', 'PLAY_PLAYLIST', name='tasktype'), nullable=False),
sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED', name='taskstatus'), nullable=False),
sa.Column('scheduled_at', sa.DateTime(), nullable=False),
sa.Column('timezone', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('recurrence_type', sa.Enum('NONE', 'MINUTELY', 'HOURLY', 'DAILY', 'WEEKLY', 'MONTHLY', 'YEARLY', 'CRON', name='recurrencetype'), nullable=False),
sa.Column('cron_expression', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('recurrence_count', sa.Integer(), nullable=True),
sa.Column('executions_count', sa.Integer(), nullable=False),
sa.Column('parameters', sa.JSON(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('last_executed_at', sa.DateTime(), nullable=True),
sa.Column('next_execution_at', sa.DateTime(), nullable=True),
sa.Column('error_message', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('expires_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('sound_played',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('sound_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['sound_id'], ['sound.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('user_oauth',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('provider', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('provider_user_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('picture', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('provider', 'provider_user_id', name='uq_user_oauth_provider_user_id')
)
op.create_table('favorite',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('sound_id', sa.Integer(), nullable=True),
sa.Column('playlist_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['playlist_id'], ['playlist.id'], ),
sa.ForeignKeyConstraint(['sound_id'], ['sound.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('user_id', 'playlist_id', name='uq_favorite_user_playlist'),
sa.UniqueConstraint('user_id', 'sound_id', name='uq_favorite_user_sound')
)
op.create_table('playlist_sound',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('playlist_id', sa.Integer(), nullable=False),
sa.Column('sound_id', sa.Integer(), nullable=False),
sa.Column('position', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['playlist_id'], ['playlist.id'], ),
sa.ForeignKeyConstraint(['sound_id'], ['sound.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('playlist_id', 'position', name='uq_playlist_sound_playlist_position'),
sa.UniqueConstraint('playlist_id', 'sound_id', name='uq_playlist_sound_playlist_sound')
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('playlist_sound')
op.drop_table('favorite')
op.drop_table('user_oauth')
op.drop_table('sound_played')
op.drop_table('scheduled_task')
op.drop_table('playlist')
op.drop_table('extraction')
op.drop_table('credit_transaction')
op.drop_table('user')
op.drop_table('sound')
op.drop_index(op.f('ix_plan_code'), table_name='plan')
op.drop_table('plan')
# ### end Alembic commands ###


@@ -0,0 +1,106 @@
"""Add initial plan and playlist data
Revision ID: a0d322857b2c
Revises: 7aa9892ceff3
Create Date: 2025-09-16 13:23:31.682276
"""
from typing import Sequence, Union
from datetime import datetime
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'a0d322857b2c'
down_revision: Union[str, Sequence[str], None] = '7aa9892ceff3'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema and add initial data."""
# Get the current timestamp
now = datetime.utcnow()
# Insert initial plans
plans_table = sa.table(
'plan',
sa.column('code', sa.String),
sa.column('name', sa.String),
sa.column('description', sa.String),
sa.column('credits', sa.Integer),
sa.column('max_credits', sa.Integer),
sa.column('created_at', sa.DateTime),
sa.column('updated_at', sa.DateTime),
)
op.bulk_insert(
plans_table,
[
{
'code': 'free',
'name': 'Free Plan',
'description': 'Basic free plan with limited features',
'credits': 25,
'max_credits': 75,
'created_at': now,
'updated_at': now,
},
{
'code': 'premium',
'name': 'Premium Plan',
'description': 'Premium plan with more features',
'credits': 50,
'max_credits': 150,
'created_at': now,
'updated_at': now,
},
{
'code': 'pro',
'name': 'Pro Plan',
'description': 'Pro plan with unlimited features',
'credits': 100,
'max_credits': 300,
'created_at': now,
'updated_at': now,
},
]
)
# Insert main playlist
playlist_table = sa.table(
'playlist',
sa.column('name', sa.String),
sa.column('description', sa.String),
sa.column('is_main', sa.Boolean),
sa.column('is_deletable', sa.Boolean),
sa.column('is_current', sa.Boolean),
sa.column('created_at', sa.DateTime),
sa.column('updated_at', sa.DateTime),
)
op.bulk_insert(
playlist_table,
[
{
'name': 'All',
'description': 'The default main playlist with all the tracks',
'is_main': True,
'is_deletable': False,
'is_current': True,
'created_at': now,
'updated_at': now,
}
]
)
def downgrade() -> None:
"""Downgrade schema and remove initial data."""
# Remove initial plans
op.execute("DELETE FROM plan WHERE code IN ('free', 'premium', 'pro')")
# Remove main playlist
op.execute("DELETE FROM playlist WHERE is_main = 1")
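The raw-SQL downgrade above compares `is_main` to the integer `1`, which matches SQLite's boolean storage but would fail against PostgreSQL's native `BOOLEAN` column (`operator does not exist: boolean = integer`). A dialect-aware sketch of the same delete — the helper name and dialect set are hypothetical, not part of the migration:

```python
def playlist_downgrade_sql(dialect: str) -> str:
    """Render the main-playlist delete with a dialect-appropriate boolean literal."""
    # PostgreSQL requires a real boolean literal; SQLite stores booleans as 0/1.
    literal = "TRUE" if dialect == "postgresql" else "1"
    return f"DELETE FROM playlist WHERE is_main = {literal}"

print(playlist_downgrade_sql("sqlite"))      # DELETE FROM playlist WHERE is_main = 1
print(playlist_downgrade_sql("postgresql"))  # DELETE FROM playlist WHERE is_main = TRUE
```

Using `op.execute()` with `sa.delete()` on the table stub from `upgrade()` would let SQLAlchemy pick the right literal automatically.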


@@ -0,0 +1,45 @@
"""Add TTS table
Revision ID: e617c155eea9
Revises: a0d322857b2c
Create Date: 2025-09-20 21:51:26.557738
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = 'e617c155eea9'
down_revision: Union[str, Sequence[str], None] = 'a0d322857b2c'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tts',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('text', sqlmodel.sql.sqltypes.AutoString(length=1000), nullable=False),
sa.Column('provider', sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False),
sa.Column('options', sa.JSON(), nullable=True),
sa.Column('sound_id', sa.Integer(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['sound_id'], ['sound.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('tts')
# ### end Alembic commands ###


@@ -2,15 +2,36 @@
from fastapi import APIRouter
from app.api.v1 import auth, main, player, playlists, socket, sounds
from app.api.v1 import (
admin,
auth,
dashboard,
extractions,
favorites,
files,
main,
player,
playlists,
scheduler,
socket,
sounds,
tts,
)
# V1 API router with v1 prefix
api_router = APIRouter(prefix="/v1")
# Include all route modules
api_router.include_router(auth.router, tags=["authentication"])
api_router.include_router(dashboard.router, tags=["dashboard"])
api_router.include_router(extractions.router, tags=["extractions"])
api_router.include_router(favorites.router, tags=["favorites"])
api_router.include_router(files.router, tags=["files"])
api_router.include_router(main.router, tags=["main"])
api_router.include_router(player.router, tags=["player"])
api_router.include_router(playlists.router, tags=["playlists"])
api_router.include_router(scheduler.router, tags=["scheduler"])
api_router.include_router(socket.router, tags=["socket"])
api_router.include_router(sounds.router, tags=["sounds"])
api_router.include_router(tts.router, tags=["tts"])
api_router.include_router(admin.router)


@@ -0,0 +1,12 @@
"""Admin API endpoints."""
from fastapi import APIRouter
from app.api.v1.admin import extractions, sounds, users
router = APIRouter(prefix="/admin")
# Include all admin sub-routers
router.include_router(extractions.router)
router.include_router(sounds.router)
router.include_router(users.router)


@@ -0,0 +1,59 @@
"""Admin audio extraction API endpoints."""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.dependencies import get_admin_user
from app.models.user import User
from app.services.extraction import ExtractionService
from app.services.extraction_processor import extraction_processor
router = APIRouter(prefix="/extractions", tags=["admin-extractions"])
async def get_extraction_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> ExtractionService:
"""Get the extraction service."""
return ExtractionService(session)
@router.get("/status")
async def get_extraction_processor_status(
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
) -> dict:
"""Get the status of the extraction processor. Admin only."""
return extraction_processor.get_status()
@router.delete("/{extraction_id}")
async def delete_extraction(
extraction_id: int,
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> dict[str, str]:
"""Delete any extraction and its associated sound and files. Admin only."""
try:
deleted = await extraction_service.delete_extraction(extraction_id, None)
if not deleted:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Extraction {extraction_id} not found",
)
except HTTPException:
# Re-raise HTTPExceptions without wrapping them
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete extraction: {e!s}",
) from e
else:
return {
"message": f"Extraction {extraction_id} deleted successfully",
}
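The try/except/else shape used throughout these endpoints re-raises `HTTPException` untouched (so a 404 is not rewrapped as a 500) and keeps the success path in the `else` branch. A minimal standalone sketch of the pattern, with a stand-in exception class and hypothetical names so it runs without FastAPI:

```python
class HTTPException(Exception):
    """Stand-in for fastapi.HTTPException so the sketch is self-contained."""

    def __init__(self, status_code: int, detail: str) -> None:
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


def delete_item(found: bool, fails: bool = False) -> dict[str, str]:
    try:
        if fails:
            raise RuntimeError("backend unavailable")
        if not found:
            raise HTTPException(status_code=404, detail="not found")
    except HTTPException:
        raise  # re-raise as-is: don't turn an intentional 404 into a 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"delete failed: {e!s}") from e
    else:
        return {"message": "deleted"}  # success path stays out of the except chain
```

`delete_item(True)` returns the success message; `delete_item(False)` surfaces a 404; any unexpected error becomes a 500 with the original exception chained via `from e`.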

app/api/v1/admin/sounds.py Normal file

@@ -0,0 +1,228 @@
"""Admin sound management API endpoints."""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.dependencies import get_admin_user
from app.models.user import User
from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService
from app.services.sound_scanner import ScanResults, SoundScannerService
router = APIRouter(prefix="/sounds", tags=["admin-sounds"])
async def get_sound_scanner_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> SoundScannerService:
"""Get the sound scanner service."""
return SoundScannerService(session)
async def get_sound_normalizer_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> SoundNormalizerService:
"""Get the sound normalizer service."""
return SoundNormalizerService(session)
# SCAN ENDPOINTS
@router.post("/scan")
async def scan_sounds(
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
scanner_service: Annotated[SoundScannerService, Depends(get_sound_scanner_service)],
) -> dict[str, ScanResults | str]:
"""Sync the soundboard directory (add/update/delete sounds). Admin only."""
try:
results = await scanner_service.scan_soundboard_directory()
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to sync sounds: {e!s}",
) from e
else:
return {
"message": "Sound sync completed",
"results": results,
}
@router.post("/scan/custom")
async def scan_custom_directory(
directory: str,
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
scanner_service: Annotated[SoundScannerService, Depends(get_sound_scanner_service)],
sound_type: str = "SDB",
) -> dict[str, ScanResults | str]:
"""Sync a custom directory with the database. Admin only."""
try:
results = await scanner_service.scan_directory(directory, sound_type)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
) from e
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to sync directory: {e!s}",
) from e
else:
return {
"message": f"Sync of directory '{directory}' completed",
"results": results,
}
# NORMALIZE ENDPOINTS
@router.post("/normalize/all")
async def normalize_all_sounds(
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
normalizer_service: Annotated[
SoundNormalizerService,
Depends(get_sound_normalizer_service),
],
*,
force: Annotated[
bool,
Query(
description="Force normalization of already normalized sounds",
),
] = False,
one_pass: Annotated[
bool | None,
Query(
description="Use one-pass normalization (overrides config)",
),
] = None,
) -> dict[str, NormalizationResults | str]:
"""Normalize all unnormalized sounds. Admin only."""
try:
results = await normalizer_service.normalize_all_sounds(
force=force,
one_pass=one_pass,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to normalize sounds: {e!s}",
) from e
else:
return {
"message": "Sound normalization completed",
"results": results,
}
@router.post("/normalize/type/{sound_type}")
async def normalize_sounds_by_type(
sound_type: str,
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
normalizer_service: Annotated[
SoundNormalizerService,
Depends(get_sound_normalizer_service),
],
*,
force: Annotated[
bool,
Query(
description="Force normalization of already normalized sounds",
),
] = False,
one_pass: Annotated[
bool | None,
Query(
description="Use one-pass normalization (overrides config)",
),
] = None,
) -> dict[str, NormalizationResults | str]:
"""Normalize all sounds of a specific type (SDB, TTS, EXT). Admin only."""
# Validate sound type
valid_types = ["SDB", "TTS", "EXT"]
if sound_type not in valid_types:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid sound type. Must be one of: {', '.join(valid_types)}",
)
try:
results = await normalizer_service.normalize_sounds_by_type(
sound_type=sound_type,
force=force,
one_pass=one_pass,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to normalize {sound_type} sounds: {e!s}",
) from e
else:
return {
"message": f"Normalization of {sound_type} sounds completed",
"results": results,
}
@router.post("/normalize/{sound_id}")
async def normalize_sound_by_id(
sound_id: int,
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
normalizer_service: Annotated[
SoundNormalizerService,
Depends(get_sound_normalizer_service),
],
*,
force: Annotated[
bool,
Query(
description="Force normalization of already normalized sound",
),
] = False,
one_pass: Annotated[
bool | None,
Query(
description="Use one-pass normalization (overrides config)",
),
] = None,
) -> dict[str, str]:
"""Normalize a specific sound by ID. Admin only."""
try:
# Get the sound
sound = await normalizer_service.sound_repo.get_by_id(sound_id)
if not sound:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Sound with ID {sound_id} not found",
)
# Normalize the sound
result = await normalizer_service.normalize_sound(
sound=sound,
force=force,
one_pass=one_pass,
)
# Check result status
if result["status"] == "error":
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to normalize sound: {result['error']}",
)
return {
"message": f"Sound normalization {result['status']}: {sound.filename}",
"status": result["status"],
"reason": result["reason"] or "",
"normalized_filename": result["normalized_filename"] or "",
}
except HTTPException:
# Re-raise HTTPExceptions without wrapping them
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to normalize sound: {e!s}",
) from e

app/api/v1/admin/users.py Normal file

@@ -0,0 +1,176 @@
"""Admin users endpoints."""
from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.dependencies import get_admin_user
from app.models.plan import Plan
from app.models.user import User
from app.repositories.plan import PlanRepository
from app.repositories.user import SortOrder, UserRepository, UserSortField, UserStatus
from app.schemas.auth import UserResponse
from app.schemas.user import UserUpdate
router = APIRouter(
prefix="/users",
tags=["admin-users"],
dependencies=[Depends(get_admin_user)],
)
def _user_to_response(user: User) -> UserResponse:
"""Convert User model to UserResponse."""
return UserResponse(
id=user.id,
email=user.email,
name=user.name,
picture=user.picture,
role=user.role,
credits=user.credits,
is_active=user.is_active,
plan={
"id": user.plan.id,
"name": user.plan.name,
"max_credits": user.plan.max_credits,
"features": [], # Add features if needed
}
if user.plan
else {},
created_at=user.created_at,
updated_at=user.updated_at,
)
@router.get("/")
async def list_users( # noqa: PLR0913
session: Annotated[AsyncSession, Depends(get_db)],
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
search: Annotated[str | None, Query(description="Search in name or email")] = None,
sort_by: Annotated[
UserSortField,
Query(description="Sort by field"),
] = UserSortField.NAME,
sort_order: Annotated[SortOrder, Query(description="Sort order")] = SortOrder.ASC,
status_filter: Annotated[
UserStatus,
Query(description="Filter by status"),
] = UserStatus.ALL,
) -> dict[str, Any]:
"""Get all users with pagination, search, and filters (admin only)."""
user_repo = UserRepository(session)
users, total_count = await user_repo.get_all_with_plan_paginated(
page=page,
limit=limit,
search=search,
sort_by=sort_by,
sort_order=sort_order,
status_filter=status_filter,
)
total_pages = (total_count + limit - 1) // limit # Ceiling division
return {
"users": [_user_to_response(user) for user in users],
"total": total_count,
"page": page,
"limit": limit,
"total_pages": total_pages,
}
@router.get("/{user_id}")
async def get_user(
user_id: int,
session: Annotated[AsyncSession, Depends(get_db)],
) -> UserResponse:
"""Get a specific user by ID (admin only)."""
user_repo = UserRepository(session)
user = await user_repo.get_by_id_with_plan(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
return _user_to_response(user)
@router.patch("/{user_id}")
async def update_user(
user_id: int,
user_update: UserUpdate,
session: Annotated[AsyncSession, Depends(get_db)],
) -> UserResponse:
"""Update a user (admin only)."""
user_repo = UserRepository(session)
user = await user_repo.get_by_id_with_plan(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
update_data = user_update.model_dump(exclude_unset=True)
# If plan_id is being updated, validate it exists
if "plan_id" in update_data:
plan_repo = PlanRepository(session)
plan = await plan_repo.get_by_id(update_data["plan_id"])
if not plan:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Plan not found",
)
updated_user = await user_repo.update(user, update_data)
# Need to refresh the plan relationship after update
await session.refresh(updated_user, ["plan"])
return _user_to_response(updated_user)
@router.post("/{user_id}/disable")
async def disable_user(
user_id: int,
session: Annotated[AsyncSession, Depends(get_db)],
) -> dict[str, str]:
"""Disable a user (admin only)."""
user_repo = UserRepository(session)
user = await user_repo.get_by_id_with_plan(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
await user_repo.update(user, {"is_active": False})
return {"message": "User disabled successfully"}
@router.post("/{user_id}/enable")
async def enable_user(
user_id: int,
session: Annotated[AsyncSession, Depends(get_db)],
) -> dict[str, str]:
"""Enable a user (admin only)."""
user_repo = UserRepository(session)
user = await user_repo.get_by_id_with_plan(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
await user_repo.update(user, {"is_active": True})
return {"message": "User enabled successfully"}
@router.get("/plans/list")
async def list_plans(
session: Annotated[AsyncSession, Depends(get_db)],
) -> list[Plan]:
"""Get all plans for user editing (admin only)."""
plan_repo = PlanRepository(session)
return await plan_repo.get_all()
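The `total_pages` computation in `list_users` above relies on integer ceiling division; a minimal standalone illustration (the helper name is ours, not part of the diff):

```python
def total_pages(total_count: int, limit: int) -> int:
    """Ceiling division: smallest number of pages covering total_count items."""
    return (total_count + limit - 1) // limit


# 101 items at 50 per page need 3 pages; an exact multiple needs no extra page.
print(total_pages(101, 50))  # 3
print(total_pages(100, 50))  # 2
print(total_pages(0, 50))    # 0
```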


@@ -20,6 +20,8 @@ from app.schemas.auth import (
ApiTokenRequest,
ApiTokenResponse,
ApiTokenStatusResponse,
ChangePasswordRequest,
UpdateProfileRequest,
UserLoginRequest,
UserRegisterRequest,
UserResponse,
@@ -27,6 +29,7 @@ from app.schemas.auth import (
from app.services.auth import AuthService
from app.services.oauth import OAuthService
from app.utils.auth import JWTUtils, TokenUtils
from app.utils.cookies import set_access_token_cookie, set_auth_cookies
router = APIRouter(prefix="/auth", tags=["authentication"])
logger = get_logger(__name__)
@@ -54,26 +57,11 @@ async def register(
refresh_token = await auth_service.create_and_store_refresh_token(user)
# Set HTTP-only cookies for both tokens
response.set_cookie(
key="access_token",
value=auth_response.token.access_token,
max_age=auth_response.token.expires_in,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
)
response.set_cookie(
key="refresh_token",
value=refresh_token,
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
* 24
* 60
* 60, # Convert days to seconds
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
set_auth_cookies(
response=response,
access_token=auth_response.token.access_token,
refresh_token=refresh_token,
expires_in=auth_response.token.expires_in,
)
except HTTPException:
@@ -103,26 +91,11 @@ async def login(
refresh_token = await auth_service.create_and_store_refresh_token(user)
# Set HTTP-only cookies for both tokens
response.set_cookie(
key="access_token",
value=auth_response.token.access_token,
max_age=auth_response.token.expires_in,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
)
response.set_cookie(
key="refresh_token",
value=refresh_token,
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
* 24
* 60
* 60, # Convert days to seconds
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
set_auth_cookies(
response=response,
access_token=auth_response.token.access_token,
refresh_token=refresh_token,
expires_in=auth_response.token.expires_in,
)
except HTTPException:
@@ -171,14 +144,10 @@ async def refresh_token(
token_response = await auth_service.refresh_access_token(refresh_token)
# Set new access token cookie
response.set_cookie(
key="access_token",
value=token_response.access_token,
max_age=token_response.expires_in,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
set_access_token_cookie(
response=response,
access_token=token_response.access_token,
expires_in=token_response.expires_in,
)
except HTTPException:
@@ -212,7 +181,9 @@ async def logout(
user_id = int(user_id_str)
user = await auth_service.get_current_user(user_id)
logger.info("Found user from access token: %s", user.email)
except (HTTPException, Exception) as e:
except HTTPException as e:
logger.info("Access token validation failed: %s", str(e))
except Exception as e: # noqa: BLE001
logger.info("Access token validation failed: %s", str(e))
# If no user found, try refresh token
@@ -224,7 +195,9 @@ async def logout(
user_id = int(user_id_str)
user = await auth_service.get_current_user(user_id)
logger.info("Found user from refresh token: %s", user.email)
except (HTTPException, Exception) as e:
except HTTPException as e:
logger.info("Refresh token validation failed: %s", str(e))
except Exception as e: # noqa: BLE001
logger.info("Refresh token validation failed: %s", str(e))
# If we found a user, revoke their refresh token
@@ -240,14 +213,14 @@ async def logout(
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Match the domain used when setting cookies
domain=settings.COOKIE_DOMAIN, # Match the domain used when setting cookies
)
response.delete_cookie(
key="refresh_token",
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Match the domain used when setting cookies
domain=settings.COOKIE_DOMAIN, # Match the domain used when setting cookies
)
return {"message": "Successfully logged out"}
@@ -307,24 +280,11 @@ async def oauth_callback(
# Set HTTP-only cookies for both tokens (not used due to cross-port issues)
# These cookies are kept for potential future same-origin scenarios
response.set_cookie(
key="access_token",
value=auth_response.token.access_token,
max_age=auth_response.token.expires_in,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
path="/", # Ensure cookie is available for all paths
)
response.set_cookie(
key="refresh_token",
value=refresh_token,
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
set_auth_cookies(
response=response,
access_token=auth_response.token.access_token,
refresh_token=refresh_token,
expires_in=auth_response.token.expires_in,
path="/", # Ensure cookie is available for all paths
)
@@ -349,7 +309,7 @@ async def oauth_callback(
"created_at": time.time(),
}
redirect_url = f"http://localhost:8001/auth/callback?code={temp_code}"
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?code={temp_code}"
logger.info("Redirecting to: %s", redirect_url)
return RedirectResponse(
@@ -410,24 +370,11 @@ async def exchange_oauth_token(
)
# Set the proper auth cookies
response.set_cookie(
key="access_token",
value=token_data["access_token"],
max_age=token_data["expires_in"],
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost",
path="/",
)
response.set_cookie(
key="refresh_token",
value=token_data["refresh_token"],
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost",
set_auth_cookies(
response=response,
access_token=token_data["access_token"],
refresh_token=token_data["refresh_token"],
expires_in=token_data["expires_in"],
path="/",
)
@@ -505,3 +452,93 @@ async def revoke_api_token(
) from e
else:
return {"message": "API token revoked successfully"}
# Profile management endpoints
@router.patch("/me")
async def update_profile(
request: UpdateProfileRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> UserResponse:
"""Update the current user's profile."""
try:
updated_user = await auth_service.update_user_profile(
current_user,
request.model_dump(exclude_unset=True),
)
return await auth_service.user_to_response(updated_user)
except Exception as e:
logger.exception("Failed to update profile for user: %s", current_user.email)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update profile",
) from e
@router.post("/change-password")
async def change_password(
request: ChangePasswordRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> dict[str, str]:
"""Change the current user's password."""
# Store user email before operations to avoid session detachment issues
user_email = current_user.email
try:
await auth_service.change_user_password(
current_user,
request.current_password,
request.new_password,
)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
) from e
except Exception as e:
logger.exception("Failed to change password for user: %s", user_email)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to change password",
) from e
else:
return {"message": "Password changed successfully"}
@router.get("/user-providers")
async def get_user_providers(
current_user: Annotated[User, Depends(get_current_active_user)],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> list[dict[str, str]]:
"""Get the current user's connected authentication providers."""
providers = []
# Add password provider if user has password
if current_user.password_hash:
providers.append(
{
"provider": "password",
"display_name": "Password",
"connected_at": current_user.created_at.isoformat(),
},
)
# Get OAuth providers from the database
oauth_providers = await auth_service.get_user_oauth_providers(current_user)
for oauth in oauth_providers:
display_name = oauth.provider.title() # Capitalize first letter
if oauth.provider == "github":
display_name = "GitHub"
elif oauth.provider == "google":
display_name = "Google"
providers.append(
{
"provider": oauth.provider,
"display_name": display_name,
"connected_at": oauth.created_at.isoformat(),
},
)
return providers
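The `set_auth_cookies` / `set_access_token_cookie` helpers that replace the repeated `response.set_cookie(...)` blocks above are imported from `app.utils.cookies` but not shown in this diff. A plausible sketch, assuming they simply centralize the cookie attributes seen in the removed code (the `settings` stand-in and the `_Recorder` test double here are ours):

```python
from types import SimpleNamespace

# Stand-in for app.core.config.settings; real values come from configuration.
settings = SimpleNamespace(
    COOKIE_SECURE=True,
    COOKIE_SAMESITE="lax",
    COOKIE_DOMAIN="localhost",
    JWT_REFRESH_TOKEN_EXPIRE_DAYS=30,
)


def set_access_token_cookie(response, access_token: str, expires_in: int, path: str = "/") -> None:
    """Set the HTTP-only access-token cookie with the shared attributes."""
    response.set_cookie(
        key="access_token",
        value=access_token,
        max_age=expires_in,
        httponly=True,
        secure=settings.COOKIE_SECURE,
        samesite=settings.COOKIE_SAMESITE,
        domain=settings.COOKIE_DOMAIN,
        path=path,
    )


def set_auth_cookies(response, access_token: str, refresh_token: str, expires_in: int, path: str = "/") -> None:
    """Set both auth cookies; the refresh token lives as long as configured."""
    set_access_token_cookie(response, access_token, expires_in, path)
    response.set_cookie(
        key="refresh_token",
        value=refresh_token,
        max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
        httponly=True,
        secure=settings.COOKIE_SECURE,
        samesite=settings.COOKIE_SAMESITE,
        domain=settings.COOKIE_DOMAIN,
        path=path,
    )


class _Recorder:
    """Minimal stand-in for fastapi.Response, recording set_cookie calls."""

    def __init__(self) -> None:
        self.cookies: list[dict] = []

    def set_cookie(self, **kwargs) -> None:
        self.cookies.append(kwargs)


resp = _Recorder()
set_auth_cookies(resp, access_token="at", refresh_token="rt", expires_in=900)
print([c["key"] for c in resp.cookies])  # ['access_token', 'refresh_token']
```

Centralizing the attributes this way also makes the `COOKIE_DOMAIN` fix in `logout` automatic for every cookie the app sets.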

127
app/api/v1/dashboard.py Normal file

@@ -0,0 +1,127 @@
"""Dashboard API endpoints."""
from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, Query, status
from app.core.dependencies import get_current_user, get_dashboard_service
from app.models.user import User
from app.services.dashboard import DashboardService
router = APIRouter(prefix="/dashboard", tags=["dashboard"])
@router.get("/soundboard-statistics")
async def get_soundboard_statistics(
_current_user: Annotated[User, Depends(get_current_user)],
dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)],
) -> dict[str, Any]:
"""Get soundboard statistics."""
try:
return await dashboard_service.get_soundboard_statistics()
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to fetch soundboard statistics: {e!s}",
) from e
@router.get("/track-statistics")
async def get_track_statistics(
_current_user: Annotated[User, Depends(get_current_user)],
dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)],
) -> dict[str, Any]:
"""Get track statistics."""
try:
return await dashboard_service.get_track_statistics()
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to fetch track statistics: {e!s}",
) from e
@router.get("/tts-statistics")
async def get_tts_statistics(
_current_user: Annotated[User, Depends(get_current_user)],
dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)],
) -> dict[str, Any]:
"""Get TTS statistics."""
try:
return await dashboard_service.get_tts_statistics()
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to fetch TTS statistics: {e!s}",
) from e
@router.get("/top-users")
async def get_top_users(
_current_user: Annotated[User, Depends(get_current_user)],
dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)],
metric_type: Annotated[
str,
Query(
description=(
"Metric type: sounds_played, credits_used, tracks_added, "
"tts_added, playlists_created"
),
),
],
period: Annotated[
str,
Query(
description="Time period (today, 1_day, 1_week, 1_month, 1_year, all_time)",
),
] = "all_time",
limit: Annotated[
int,
Query(description="Number of top users to return", ge=1, le=100),
] = 10,
) -> list[dict[str, Any]]:
"""Get top users by metric for a specific period."""
try:
return await dashboard_service.get_top_users(
metric_type=metric_type,
period=period,
limit=limit,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to fetch top users: {e!s}",
) from e
@router.get("/top-sounds")
async def get_top_sounds(
_current_user: Annotated[User, Depends(get_current_user)],
dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)],
sound_type: Annotated[
str,
Query(description="Sound type filter (SDB, TTS, EXT, or 'all')"),
],
period: Annotated[
str,
Query(
description="Time period (today, 1_day, 1_week, 1_month, 1_year, all_time)",
),
] = "all_time",
limit: Annotated[
int,
Query(description="Number of top sounds to return", ge=1, le=100),
] = 10,
) -> list[dict[str, Any]]:
"""Get top sounds by play count for a specific period."""
try:
return await dashboard_service.get_top_sounds(
sound_type=sound_type,
period=period,
limit=limit,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to fetch top sounds: {e!s}",
) from e

248
app/api/v1/extractions.py Normal file

@@ -0,0 +1,248 @@
"""Audio extraction API endpoints."""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.dependencies import get_current_active_user_flexible
from app.models.user import User
from app.services.extraction import ExtractionInfo, ExtractionService
from app.services.extraction_processor import extraction_processor
router = APIRouter(prefix="/extractions", tags=["extractions"])
async def get_extraction_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> ExtractionService:
"""Get the extraction service."""
return ExtractionService(session)
@router.post("/")
async def create_extraction(
url: str,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> dict[str, ExtractionInfo | str]:
"""Create a new extraction job for a URL."""
try:
if current_user.id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User ID not available",
)
extraction_info = await extraction_service.create_extraction(
url,
current_user.id,
)
# Queue the extraction for background processing
await extraction_processor.queue_extraction(extraction_info["id"])
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
) from e
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create extraction: {e!s}",
) from e
else:
return {
"message": "Extraction queued successfully",
"extraction": extraction_info,
}
@router.get("/user")
async def get_user_extractions( # noqa: PLR0913
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
search: Annotated[
str | None,
Query(description="Search in title, URL, or service"),
] = None,
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
) -> dict:
"""Get all extractions for the current user."""
try:
if current_user.id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User ID not available",
)
result = await extraction_service.get_user_extractions(
user_id=current_user.id,
search=search,
sort_by=sort_by,
sort_order=sort_order,
status_filter=status_filter,
page=page,
limit=limit,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extractions: {e!s}",
) from e
else:
return result
@router.get("/{extraction_id}")
async def get_extraction(
extraction_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> ExtractionInfo:
"""Get extraction information by ID."""
try:
extraction_info = await extraction_service.get_extraction_by_id(extraction_id)
if not extraction_info:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Extraction {extraction_id} not found",
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extraction: {e!s}",
) from e
else:
return extraction_info
@router.get("/")
async def get_all_extractions( # noqa: PLR0913
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
search: Annotated[
str | None,
Query(description="Search in title, URL, or service"),
] = None,
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
) -> dict:
"""Get all extractions with optional filtering, search, and sorting."""
try:
result = await extraction_service.get_all_extractions(
search=search,
sort_by=sort_by,
sort_order=sort_order,
status_filter=status_filter,
page=page,
limit=limit,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extractions: {e!s}",
) from e
else:
return result
@router.get("/processing/current")
async def get_processing_extractions(
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> list[ExtractionInfo]:
"""Get all currently processing extractions for showing ongoing toasts."""
try:
# Get all extractions with processing status
processing_extractions = await extraction_service.extraction_repo.get_by_status(
"processing",
)
# Convert to ExtractionInfo format
result = []
for extraction in processing_extractions:
# Get user information
user = await extraction_service.user_repo.get_by_id(extraction.user_id)
user_name = user.name if user else None
extraction_info: ExtractionInfo = {
"id": extraction.id or 0,
"url": extraction.url,
"service": extraction.service,
"service_id": extraction.service_id,
"title": extraction.title,
"status": extraction.status,
"error": extraction.error,
"sound_id": extraction.sound_id,
"user_id": extraction.user_id,
"user_name": user_name,
"created_at": extraction.created_at.isoformat(),
"updated_at": extraction.updated_at.isoformat(),
}
result.append(extraction_info)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get processing extractions: {e!s}",
) from e
else:
return result
@router.delete("/{extraction_id}")
async def delete_extraction(
extraction_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> dict[str, str]:
"""Delete extraction and associated sound/files. Users can only delete their own."""
try:
if current_user.id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User ID not available",
)
deleted = await extraction_service.delete_extraction(
extraction_id, current_user.id,
)
if not deleted:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Extraction {extraction_id} not found",
)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
) from e
except HTTPException:
# Re-raise HTTPExceptions without wrapping them
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete extraction: {e!s}",
) from e
else:
return {
"message": f"Extraction {extraction_id} deleted successfully",
}

197
app/api/v1/favorites.py Normal file

@@ -0,0 +1,197 @@
"""Favorites management API endpoints."""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, status
from app.core.database import get_session_factory
from app.core.dependencies import get_current_active_user
from app.models.user import User
from app.schemas.common import MessageResponse
from app.schemas.favorite import (
FavoriteCountsResponse,
FavoriteResponse,
FavoritesListResponse,
)
from app.services.favorite import FavoriteService
router = APIRouter(prefix="/favorites", tags=["favorites"])
def get_favorite_service() -> FavoriteService:
"""Get the favorite service."""
return FavoriteService(get_session_factory())
@router.get("/")
async def get_user_favorites(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0,
) -> FavoritesListResponse:
"""Get all favorites for the current user."""
favorites = await favorite_service.get_user_favorites(
current_user.id,
limit,
offset,
)
return FavoritesListResponse(favorites=favorites)
@router.get("/sounds")
async def get_user_sound_favorites(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0,
) -> FavoritesListResponse:
"""Get sound favorites for the current user."""
favorites = await favorite_service.get_user_sound_favorites(
current_user.id,
limit,
offset,
)
return FavoritesListResponse(favorites=favorites)
@router.get("/playlists")
async def get_user_playlist_favorites(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0,
) -> FavoritesListResponse:
"""Get playlist favorites for the current user."""
favorites = await favorite_service.get_user_playlist_favorites(
current_user.id,
limit,
offset,
)
return FavoritesListResponse(favorites=favorites)
@router.get("/counts")
async def get_favorite_counts(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> FavoriteCountsResponse:
"""Get favorite counts for the current user."""
counts = await favorite_service.get_favorite_counts(current_user.id)
return FavoriteCountsResponse(**counts)
@router.post("/sounds/{sound_id}")
async def add_sound_favorite(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> FavoriteResponse:
"""Add a sound to favorites."""
try:
favorite = await favorite_service.add_sound_favorite(current_user.id, sound_id)
return FavoriteResponse.model_validate(favorite)
except ValueError as e:
if "not found" in str(e):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
if "already favorited" in str(e):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=str(e),
) from e
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
) from e
@router.post("/playlists/{playlist_id}")
async def add_playlist_favorite(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> FavoriteResponse:
"""Add a playlist to favorites."""
try:
favorite = await favorite_service.add_playlist_favorite(
current_user.id,
playlist_id,
)
return FavoriteResponse.model_validate(favorite)
except ValueError as e:
if "not found" in str(e):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
if "already favorited" in str(e):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=str(e),
) from e
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
) from e
@router.delete("/sounds/{sound_id}")
async def remove_sound_favorite(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> MessageResponse:
"""Remove a sound from favorites."""
try:
await favorite_service.remove_sound_favorite(current_user.id, sound_id)
return MessageResponse(message="Sound removed from favorites")
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
@router.delete("/playlists/{playlist_id}")
async def remove_playlist_favorite(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> MessageResponse:
"""Remove a playlist from favorites."""
try:
await favorite_service.remove_playlist_favorite(current_user.id, playlist_id)
return MessageResponse(message="Playlist removed from favorites")
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
@router.get("/sounds/{sound_id}/check")
async def check_sound_favorited(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> dict[str, bool]:
"""Check if a sound is favorited by the current user."""
is_favorited = await favorite_service.is_sound_favorited(current_user.id, sound_id)
return {"is_favorited": is_favorited}
@router.get("/playlists/{playlist_id}/check")
async def check_playlist_favorited(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> dict[str, bool]:
"""Check if a playlist is favorited by the current user."""
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id,
playlist_id,
)
return {"is_favorited": is_favorited}

153
app/api/v1/files.py Normal file

@@ -0,0 +1,153 @@
"""File serving API endpoints for audio files and thumbnails."""
import mimetypes
from pathlib import Path
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import FileResponse
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings
from app.core.database import get_db
from app.core.dependencies import get_current_active_user_flexible
from app.core.logging import get_logger
from app.models.user import User
from app.repositories.sound import SoundRepository
from app.utils.audio import get_sound_file_path
logger = get_logger(__name__)
router = APIRouter(prefix="/files", tags=["files"])
async def get_sound_repository(
session: Annotated[AsyncSession, Depends(get_db)],
) -> SoundRepository:
"""Get the sound repository."""
return SoundRepository(session)
@router.get("/sounds/{sound_id}/download")
async def download_sound(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
) -> FileResponse:
"""Download a sound file."""
try:
# Get the sound record
sound = await sound_repo.get_by_id(sound_id)
if not sound:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Sound with ID {sound_id} not found",
)
# Get the file path using the audio utility
file_path = get_sound_file_path(sound)
# Determine filename based on normalization status
if sound.is_normalized and sound.normalized_filename:
filename = sound.normalized_filename
else:
filename = sound.filename
# Check if file exists
if not file_path.exists():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Sound file not found on disk",
)
# Get MIME type
mime_type, _ = mimetypes.guess_type(str(file_path))
if not mime_type:
mime_type = "audio/mpeg" # Default to MP3
logger.info(
"Serving sound download: %s (user: %d, sound: %d)",
filename,
current_user.id,
sound_id,
)
return FileResponse(
path=str(file_path),
filename=filename,
media_type=mime_type,
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
except HTTPException:
raise
except Exception as e:
logger.exception("Error serving sound download for sound %d", sound_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to serve sound file",
) from e
@router.get("/sounds/{sound_id}/thumbnail")
async def get_sound_thumbnail(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
) -> FileResponse:
"""Get a sound's thumbnail image."""
try:
# Get the sound record
sound = await sound_repo.get_by_id(sound_id)
if not sound:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Sound with ID {sound_id} not found",
)
# Check if sound has a thumbnail
if not sound.thumbnail:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No thumbnail available for this sound",
)
# Get thumbnail file path
thumbnail_path = Path(settings.EXTRACTION_THUMBNAILS_DIR) / sound.thumbnail
# Check if thumbnail file exists
if not thumbnail_path.exists():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Thumbnail file not found on disk",
)
# Get MIME type
mime_type, _ = mimetypes.guess_type(str(thumbnail_path))
if not mime_type:
mime_type = "image/jpeg" # Default to JPEG
logger.debug(
"Serving thumbnail: %s (user: %d, sound: %d)",
sound.thumbnail,
current_user.id,
sound_id,
)
return FileResponse(
path=str(thumbnail_path),
media_type=mime_type,
headers={
"Cache-Control": "public, max-age=3600", # Cache for 1 hour
"Content-Disposition": "inline", # Display inline, not download
},
)
except HTTPException:
raise
except Exception as e:
logger.exception("Error serving thumbnail for sound %d", sound_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to serve thumbnail",
) from e

View File

@@ -1,6 +1,7 @@
"""Main router for v1 endpoints."""
from fastapi import APIRouter
from fastapi.responses import HTMLResponse
from app.core.logging import get_logger
from app.schemas.common import HealthResponse
@@ -10,8 +11,77 @@ router = APIRouter()
logger = get_logger(__name__)
@router.get("/")
@router.get("/health")
def health() -> HealthResponse:
"""Health check endpoint."""
logger.info("Health check endpoint accessed")
return HealthResponse(status="healthy")
@router.get("/docs/scalar", response_class=HTMLResponse)
def scalar_docs() -> HTMLResponse:
"""Serve the API documentation using Scalar."""
return """
<!doctype html>
<html>
<head>
<title>API Documentation - Scalar</title>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
</head>
<body>
<script
id="api-reference"
data-url="http://localhost:8000/api/openapi.json"
src="https://cdn.jsdelivr.net/npm/@scalar/api-reference"></script>
</body>
</html>
"""
@router.get("/docs/rapidoc", response_class=HTMLResponse)
async def rapidoc_docs() -> HTMLResponse:
"""Serve the API documentation using Rapidoc."""
return """
<!doctype html>
<html>
<head>
<title>API Documentation - Rapidoc</title>
<meta charset="utf-8">
<script type="module" src="https://unpkg.com/rapidoc/dist/rapidoc-min.js"></script>
</head>
<body>
<rapi-doc
spec-url="http://localhost:8000/api/openapi.json"
theme="dark"
render-style="read">
</rapi-doc>
</body>
</html>
"""
@router.get("/docs/elements", response_class=HTMLResponse)
async def elements_docs() -> HTMLResponse:
"""Serve the API documentation using Stoplight Elements."""
return """
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta
name="viewport"
content="width=device-width, initial-scale=1, shrink-to-fit=no"
>
<title>API Documentation - elements</title>
<script src="https://unpkg.com/@stoplight/elements/web-components.min.js"></script>
<link rel="stylesheet" href="https://unpkg.com/@stoplight/elements/styles.min.css">
</head>
<body>
<elements-api
apiDescriptionUrl="http://localhost:8000/api/openapi.json"
router="hash"
/>
</body>
</html>
"""
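All three docs pages hardcode `http://localhost:8000/api/openapi.json`, which breaks behind a reverse proxy or on another host. One possible refactor (assumed, not in this diff) is to render the page from a template with a configurable, relative spec URL:

```python
from string import Template

SCALAR_PAGE = Template(
    """<!doctype html>
<html>
  <head><title>API Documentation - Scalar</title><meta charset="utf-8" /></head>
  <body>
    <script
      id="api-reference"
      data-url="$spec_url"
      src="https://cdn.jsdelivr.net/npm/@scalar/api-reference"></script>
  </body>
</html>"""
)


def render_scalar_docs(spec_url: str = "/api/openapi.json") -> str:
    """Render the Scalar docs page; a relative spec URL survives proxies."""
    return SCALAR_PAGE.substitute(spec_url=spec_url)
```

The same template approach would apply to the Rapidoc and Elements pages.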

View File

@@ -1,6 +1,6 @@
"""Player API endpoints."""
from typing import Annotated, Any
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
@@ -137,10 +137,10 @@ async def seek(
"""Seek to specific position in current track."""
try:
player = get_player_service()
await player.seek(request.position_ms)
return MessageResponse(message=f"Seeked to position {request.position_ms}ms")
await player.seek(request.position)
return MessageResponse(message=f"Seeked to position {request.position}ms")
except Exception as e:
logger.exception("Error seeking to position %s", request.position_ms)
logger.exception("Error seeking to position %s", request.position)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to seek",
@@ -165,6 +165,40 @@ async def set_volume(
) from e
@router.post("/mute")
async def mute(
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
) -> MessageResponse:
"""Mute playback."""
try:
player = get_player_service()
await player.mute()
return MessageResponse(message="Playback muted")
except Exception as e:
logger.exception("Error muting playback")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to mute playback",
) from e
@router.post("/unmute")
async def unmute(
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
) -> MessageResponse:
"""Unmute playback."""
try:
player = get_player_service()
await player.unmute()
return MessageResponse(message="Playback unmuted")
except Exception as e:
logger.exception("Error unmuting playback")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to unmute playback",
) from e
@router.post("/mode")
async def set_mode(
request: PlayerModeRequest,
@@ -214,4 +248,22 @@ async def get_state(
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get player state",
) from e
) from e
@router.post("/play-next/{sound_id}")
async def add_to_play_next(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
) -> MessageResponse:
"""Add a sound to the play next queue."""
try:
player = get_player_service()
await player.add_to_play_next(sound_id)
return MessageResponse(message=f"Added sound {sound_id} to play next queue")
except Exception as e:
logger.exception("Error adding sound to play next queue: %s", sound_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to add sound to play next queue",
) from e

View File

@@ -1,13 +1,14 @@
"""Playlist management API endpoints."""
from typing import Annotated
from typing import Annotated, Any
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.database import get_db, get_session_factory
from app.core.dependencies import get_current_active_user_flexible
from app.models.user import User
from app.repositories.playlist import PlaylistSortField, SortOrder
from app.schemas.common import MessageResponse
from app.schemas.playlist import (
PlaylistAddSoundRequest,
@@ -18,6 +19,7 @@ from app.schemas.playlist import (
PlaylistStatsResponse,
PlaylistUpdateRequest,
)
from app.services.favorite import FavoriteService
from app.services.playlist import PlaylistService
router = APIRouter(prefix="/playlists", tags=["playlists"])
@@ -30,44 +32,143 @@ async def get_playlist_service(
return PlaylistService(session)
def get_favorite_service() -> FavoriteService:
"""Get the favorite service."""
return FavoriteService(get_session_factory())
@router.get("/")
async def get_all_playlists(
async def get_all_playlists( # noqa: PLR0913
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> list[PlaylistResponse]:
"""Get all playlists from all users."""
playlists = await playlist_service.get_all_playlists()
return [PlaylistResponse.from_playlist(playlist) for playlist in playlists]
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
search: Annotated[
str | None,
Query(description="Search playlists by name"),
] = None,
sort_by: Annotated[
PlaylistSortField | None,
Query(description="Sort by field"),
] = None,
sort_order: Annotated[
SortOrder,
Query(description="Sort order (asc or desc)"),
] = SortOrder.ASC,
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
favorites_only: Annotated[ # noqa: FBT002
bool,
Query(description="Show only favorited playlists"),
] = False,
) -> dict[str, Any]:
"""Get all playlists from all users with search and sorting."""
result = await playlist_service.search_and_sort_playlists_paginated(
search_query=search,
sort_by=sort_by,
sort_order=sort_order,
user_id=None,
include_stats=True,
page=page,
limit=limit,
favorites_only=favorites_only,
current_user_id=current_user.id,
)
# Convert to PlaylistResponse with favorite indicators
playlist_responses = []
for playlist_dict in result["playlists"]:
# The playlist service returns dict, need to create playlist object structure
playlist_id = playlist_dict["id"]
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id,
playlist_id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist_id)
# Create a PlaylistResponse-like dict with proper datetime conversion
playlist_response = {
**playlist_dict,
"created_at": (
playlist_dict["created_at"].isoformat()
if playlist_dict["created_at"]
else None
),
"updated_at": (
playlist_dict["updated_at"].isoformat()
if playlist_dict["updated_at"]
else None
),
"is_favorited": is_favorited,
"favorite_count": favorite_count,
}
playlist_responses.append(playlist_response)
return {
"playlists": playlist_responses,
"total": result["total"],
"page": result["page"],
"limit": result["limit"],
"total_pages": result["total_pages"],
}
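The paginated response reports `total_pages` alongside `page` and `limit`; the arithmetic the service presumably uses is ceiling division, sketched here in isolation (treating an empty result as zero pages is an assumption):

```python
import math


def total_pages(total: int, limit: int) -> int:
    """Pages needed for `total` items at `limit` per page."""
    if limit <= 0:
        raise ValueError("limit must be positive")
    return math.ceil(total / limit)
```

So 50 items at the default `limit=50` is one page, and 51 items spills onto a second.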
@router.get("/user")
async def get_user_playlists(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> list[PlaylistResponse]:
"""Get playlists for the current user only."""
playlists = await playlist_service.get_user_playlists(current_user.id)
return [PlaylistResponse.from_playlist(playlist) for playlist in playlists]
# Add favorite indicators for each playlist
playlist_responses = []
for playlist in playlists:
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id,
playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
playlist_response = PlaylistResponse.from_playlist(
playlist,
is_favorited,
favorite_count,
)
playlist_responses.append(playlist_response)
return playlist_responses
@router.get("/main")
async def get_main_playlist(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> PlaylistResponse:
"""Get the global main playlist."""
playlist = await playlist_service.get_main_playlist()
return PlaylistResponse.from_playlist(playlist)
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id,
playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
@router.get("/current")
async def get_current_playlist(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> PlaylistResponse:
"""Get the user's current playlist (falls back to main playlist)."""
playlist = await playlist_service.get_current_playlist(current_user.id)
return PlaylistResponse.from_playlist(playlist)
"""Get the global current playlist (falls back to main playlist)."""
playlist = await playlist_service.get_current_playlist()
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id,
playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
@router.post("/")
@@ -91,10 +192,16 @@ async def get_playlist(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> PlaylistResponse:
"""Get a specific playlist."""
playlist = await playlist_service.get_playlist_by_id(playlist_id)
return PlaylistResponse.from_playlist(playlist)
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id,
playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
@router.put("/{playlist_id}")
@@ -105,13 +212,19 @@ async def update_playlist(
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> PlaylistResponse:
"""Update a playlist."""
if current_user.id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User ID not available",
)
playlist = await playlist_service.update_playlist(
playlist_id=playlist_id,
user_id=current_user.id,
name=request.name,
description=request.description,
genre=request.genre,
is_current=request.is_current,
is_current=None, # is_current is not handled by this endpoint
)
return PlaylistResponse.from_playlist(playlist)
@@ -130,7 +243,7 @@ async def delete_playlist(
@router.get("/search/{query}")
async def search_playlists(
query: str,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> list[PlaylistResponse]:
"""Search all playlists by name."""
@@ -152,7 +265,7 @@ async def search_user_playlists(
@router.get("/{playlist_id}/sounds")
async def get_playlist_sounds(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> list[PlaylistSoundResponse]:
"""Get all sounds in a playlist."""
@@ -212,28 +325,28 @@ async def reorder_playlist_sounds(
@router.put("/{playlist_id}/set-current")
async def set_current_playlist(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> PlaylistResponse:
"""Set a playlist as the current playlist."""
playlist = await playlist_service.set_current_playlist(playlist_id, current_user.id)
playlist = await playlist_service.set_current_playlist(playlist_id)
return PlaylistResponse.from_playlist(playlist)
@router.delete("/current")
async def unset_current_playlist(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> MessageResponse:
"""Unset the current playlist."""
await playlist_service.unset_current_playlist(current_user.id)
await playlist_service.unset_current_playlist()
return MessageResponse(message="Current playlist unset successfully")
@router.get("/{playlist_id}/stats")
async def get_playlist_stats(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> PlaylistStatsResponse:
"""Get statistics for a playlist."""

230
app/api/v1/scheduler.py Normal file
View File

@@ -0,0 +1,230 @@
"""API endpoints for scheduled task management."""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.dependencies import (
get_admin_user,
get_current_active_user,
)
from app.core.services import get_global_scheduler_service
from app.models.scheduled_task import ScheduledTask, TaskStatus, TaskType
from app.models.user import User
from app.repositories.scheduled_task import ScheduledTaskRepository
from app.schemas.scheduler import (
ScheduledTaskCreate,
ScheduledTaskResponse,
ScheduledTaskUpdate,
TaskFilterParams,
)
from app.services.scheduler import SchedulerService
router = APIRouter(prefix="/scheduler")
def get_scheduler_service() -> SchedulerService:
"""Get the global scheduler service instance."""
return get_global_scheduler_service()
def get_task_filters(
status: Annotated[
TaskStatus | None, Query(description="Filter by task status"),
] = None,
task_type: Annotated[
TaskType | None, Query(description="Filter by task type"),
] = None,
limit: Annotated[int, Query(description="Maximum number of tasks to return")] = 50,
offset: Annotated[int, Query(description="Number of tasks to skip")] = 0,
) -> TaskFilterParams:
"""Create task filter parameters from query parameters."""
return TaskFilterParams(
status=status,
task_type=task_type,
limit=limit,
offset=offset,
)
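`get_task_filters` only packages the query parameters; applying them over an in-memory task list shows the intended semantics as a plain-Python stand-in for the repository query (the `Task` shape here is illustrative):

```python
from dataclasses import dataclass


@dataclass
class Task:
    id: int
    status: str
    task_type: str


def filter_tasks(tasks, status=None, task_type=None, limit=50, offset=0):
    """Mirror the repository filters: match, skip `offset`, cap at `limit`."""
    matched = [
        t
        for t in tasks
        if (status is None or t.status == status)
        and (task_type is None or t.task_type == task_type)
    ]
    return matched[offset : offset + limit]
```

Note that `offset` is applied after filtering, so pagination stays stable for a given filter combination.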
@router.post("/tasks", response_model=ScheduledTaskResponse)
async def create_task(
task_data: ScheduledTaskCreate,
current_user: Annotated[User, Depends(get_current_active_user)],
scheduler_service: Annotated[SchedulerService, Depends(get_scheduler_service)],
) -> ScheduledTask:
"""Create a new scheduled task."""
try:
return await scheduler_service.create_task(
task_data=task_data,
user_id=current_user.id,
)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e)) from e
@router.get("/tasks", response_model=list[ScheduledTaskResponse])
async def get_user_tasks(
filters: Annotated[TaskFilterParams, Depends(get_task_filters)],
current_user: Annotated[User, Depends(get_current_active_user)],
scheduler_service: Annotated[SchedulerService, Depends(get_scheduler_service)],
) -> list[ScheduledTask]:
"""Get user's scheduled tasks."""
return await scheduler_service.get_user_tasks(
user_id=current_user.id,
status=filters.status,
task_type=filters.task_type,
limit=filters.limit,
offset=filters.offset,
)
@router.get("/tasks/{task_id}", response_model=ScheduledTaskResponse)
async def get_task(
task_id: int,
current_user: Annotated[User, Depends(get_current_active_user)] = ...,
db_session: Annotated[AsyncSession, Depends(get_db)] = ...,
) -> ScheduledTask:
"""Get a specific scheduled task."""
repo = ScheduledTaskRepository(db_session)
task = await repo.get_by_id(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
# Check if user owns the task or is admin
if task.user_id != current_user.id and not current_user.is_admin:
raise HTTPException(status_code=403, detail="Access denied")
return task
@router.patch("/tasks/{task_id}", response_model=ScheduledTaskResponse)
async def update_task(
task_id: int,
task_update: ScheduledTaskUpdate,
current_user: Annotated[User, Depends(get_current_active_user)] = ...,
db_session: Annotated[AsyncSession, Depends(get_db)] = ...,
) -> ScheduledTask:
"""Update a scheduled task."""
repo = ScheduledTaskRepository(db_session)
task = await repo.get_by_id(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
# Check if user owns the task or is admin
if task.user_id != current_user.id and not current_user.is_admin:
raise HTTPException(status_code=403, detail="Access denied")
# Update task fields
update_data = task_update.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(task, field, value)
return await repo.update(task)
@router.delete("/tasks/{task_id}")
async def delete_task(
task_id: int,
current_user: Annotated[User, Depends(get_current_active_user)] = ...,
scheduler_service: Annotated[
SchedulerService, Depends(get_scheduler_service),
] = ...,
db_session: Annotated[AsyncSession, Depends(get_db)] = ...,
) -> dict:
"""Delete a scheduled task completely."""
repo = ScheduledTaskRepository(db_session)
task = await repo.get_by_id(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
# Check if user owns the task or is admin
if task.user_id != current_user.id and not current_user.is_admin:
raise HTTPException(status_code=403, detail="Access denied")
success = await scheduler_service.delete_task(task_id)
if not success:
raise HTTPException(status_code=400, detail="Failed to delete task")
return {"message": "Task deleted successfully"}
# Admin-only endpoints
@router.get("/admin/tasks", response_model=list[ScheduledTaskResponse])
async def get_all_tasks(
status: Annotated[
TaskStatus | None, Query(description="Filter by task status"),
] = None,
task_type: Annotated[
TaskType | None, Query(description="Filter by task type"),
] = None,
limit: Annotated[
int | None, Query(description="Maximum number of tasks to return"),
] = 100,
offset: Annotated[
int | None, Query(description="Number of tasks to skip"),
] = 0,
_: Annotated[User, Depends(get_admin_user)] = ...,
db_session: Annotated[AsyncSession, Depends(get_db)] = ...,
) -> list[ScheduledTask]:
"""Get all scheduled tasks (admin only)."""
# Build query with pagination and filtering
statement = select(ScheduledTask)
if status:
statement = statement.where(ScheduledTask.status == status)
if task_type:
statement = statement.where(ScheduledTask.task_type == task_type)
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
if offset:
statement = statement.offset(offset)
if limit:
statement = statement.limit(limit)
result = await db_session.exec(statement)
return list(result.all())
@router.get("/admin/system-tasks", response_model=list[ScheduledTaskResponse])
async def get_system_tasks(
status: Annotated[
TaskStatus | None, Query(description="Filter by task status"),
] = None,
task_type: Annotated[
TaskType | None, Query(description="Filter by task type"),
] = None,
_: Annotated[User, Depends(get_admin_user)] = ...,
db_session: Annotated[AsyncSession, Depends(get_db)] = ...,
) -> list[ScheduledTask]:
"""Get system tasks (admin only)."""
repo = ScheduledTaskRepository(db_session)
return await repo.get_system_tasks(status=status, task_type=task_type)
@router.post("/admin/system-tasks", response_model=ScheduledTaskResponse)
async def create_system_task(
task_data: ScheduledTaskCreate,
_: Annotated[User, Depends(get_admin_user)] = ...,
scheduler_service: Annotated[
SchedulerService, Depends(get_scheduler_service),
] = ...,
) -> ScheduledTask:
"""Create a system task (admin only)."""
try:
return await scheduler_service.create_task(
task_data=task_data,
user_id=None, # System task
)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e)) from e

View File

@@ -1,5 +1,7 @@
"""Socket.IO API endpoints for WebSocket management."""
from typing import Annotated
from fastapi import APIRouter, Depends
from app.core.dependencies import get_current_user
@@ -10,7 +12,9 @@ router = APIRouter(prefix="/socket", tags=["socket"])
@router.get("/status")
async def get_socket_status(current_user: User = Depends(get_current_user)):
async def get_socket_status(
current_user: Annotated[User, Depends(get_current_user)],
) -> dict[str, int | bool]:
"""Get current socket connection status."""
connected_users = socket_manager.get_connected_users()
@@ -25,8 +29,8 @@ async def get_socket_status(current_user: User = Depends(get_current_user)):
async def send_message_to_user(
target_user_id: int,
message: str,
current_user: User = Depends(get_current_user),
):
current_user: Annotated[User, Depends(get_current_user)],
) -> dict[str, int | bool | str]:
"""Send a message to a specific user via WebSocket."""
success = await socket_manager.send_to_user(
str(target_user_id),
@@ -48,8 +52,8 @@ async def send_message_to_user(
@router.post("/broadcast")
async def broadcast_message(
message: str,
current_user: User = Depends(get_current_user),
):
current_user: Annotated[User, Depends(get_current_user)],
) -> dict[str, bool | str]:
"""Broadcast a message to all connected users."""
await socket_manager.broadcast_to_all(
"broadcast_message",

View File

@@ -5,52 +5,25 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.database import get_db, get_session_factory
from app.core.dependencies import get_current_active_user_flexible
from app.models.credit_action import CreditActionType
from app.models.user import User
from app.repositories.sound import SoundRepository
from app.services.extraction import ExtractionInfo, ExtractionService
from app.services.credit import CreditService, InsufficientCreditsError
from app.services.extraction_processor import extraction_processor
from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService
from app.services.sound_scanner import ScanResults, SoundScannerService
from app.services.vlc_player import get_vlc_player_service, VLCPlayerService
from app.repositories.sound import SortOrder, SoundRepository, SoundSortField
from app.schemas.sound import SoundResponse, SoundsListResponse
from app.services.favorite import FavoriteService
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
router = APIRouter(prefix="/sounds", tags=["sounds"])
async def get_sound_scanner_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> SoundScannerService:
"""Get the sound scanner service."""
return SoundScannerService(session)
async def get_sound_normalizer_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> SoundNormalizerService:
"""Get the sound normalizer service."""
return SoundNormalizerService(session)
async def get_extraction_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> ExtractionService:
"""Get the extraction service."""
return ExtractionService(session)
def get_vlc_player() -> VLCPlayerService:
"""Get the VLC player service."""
from app.core.database import get_session_factory
return get_vlc_player_service(get_session_factory())
def get_credit_service() -> CreditService:
"""Get the credit service."""
from app.core.database import get_session_factory
return CreditService(get_session_factory())
def get_favorite_service() -> FavoriteService:
"""Get the favorite service."""
return FavoriteService(get_session_factory())
async def get_sound_repository(
@@ -60,377 +33,92 @@ async def get_sound_repository(
return SoundRepository(session)
# SCAN
@router.post("/scan")
async def scan_sounds(
@router.get("/")
async def get_sounds( # noqa: PLR0913
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
scanner_service: Annotated[SoundScannerService, Depends(get_sound_scanner_service)],
) -> dict[str, ScanResults | str]:
"""Sync the soundboard directory (add/update/delete sounds)."""
# Only allow admins to scan sounds
if current_user.role not in ["admin", "superadmin"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only administrators can sync sounds",
)
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
types: Annotated[
list[str] | None,
Query(description="Filter by sound types (e.g., SDB, TTS, EXT)"),
] = None,
search: Annotated[
str | None,
Query(description="Search sounds by name"),
] = None,
sort_by: Annotated[
SoundSortField | None,
Query(description="Sort by field"),
] = None,
sort_order: Annotated[
SortOrder,
Query(description="Sort order (asc or desc)"),
] = SortOrder.ASC,
limit: Annotated[
int | None,
Query(description="Maximum number of results", ge=1, le=1000),
] = None,
offset: Annotated[
int,
Query(description="Number of results to skip", ge=0),
] = 0,
favorites_only: Annotated[ # noqa: FBT002
bool,
Query(description="Show only favorited sounds"),
] = False,
) -> SoundsListResponse:
"""Get sounds with optional search, filtering, and sorting."""
try:
results = await scanner_service.scan_soundboard_directory()
return {
"message": "Sound sync completed",
"results": results,
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to sync sounds: {e!s}",
) from e
@router.post("/scan/custom")
async def scan_custom_directory(
directory: str,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
scanner_service: Annotated[SoundScannerService, Depends(get_sound_scanner_service)],
sound_type: str = "SDB",
) -> dict[str, ScanResults | str]:
"""Sync a custom directory with the database (add/update/delete sounds)."""
# Only allow admins to sync sounds
if current_user.role not in ["admin", "superadmin"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only administrators can sync sounds",
sounds = await sound_repo.search_and_sort(
search_query=search,
sound_types=types,
sort_by=sort_by,
sort_order=sort_order,
limit=limit,
offset=offset,
favorites_only=favorites_only,
user_id=current_user.id,
)
try:
results = await scanner_service.scan_directory(directory, sound_type)
return {
"message": f"Sync of directory '{directory}' completed",
"results": results,
}
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
) from e
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to sync directory: {e!s}",
) from e
# NORMALIZE
@router.post("/normalize/all")
async def normalize_all_sounds(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
normalizer_service: Annotated[
SoundNormalizerService, Depends(get_sound_normalizer_service)
],
force: bool = Query(
False, description="Force normalization of already normalized sounds"
),
one_pass: bool | None = Query(
None, description="Use one-pass normalization (overrides config)"
),
) -> dict[str, NormalizationResults | str]:
"""Normalize all unnormalized sounds."""
# Only allow admins to normalize sounds
if current_user.role not in ["admin", "superadmin"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only administrators can normalize sounds",
)
try:
results = await normalizer_service.normalize_all_sounds(
force=force,
one_pass=one_pass,
)
return {
"message": "Sound normalization completed",
"results": results,
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to normalize sounds: {e!s}",
) from e
@router.post("/normalize/type/{sound_type}")
async def normalize_sounds_by_type(
sound_type: str,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
normalizer_service: Annotated[
SoundNormalizerService, Depends(get_sound_normalizer_service)
],
force: bool = Query(
False, description="Force normalization of already normalized sounds"
),
one_pass: bool | None = Query(
None, description="Use one-pass normalization (overrides config)"
),
) -> dict[str, NormalizationResults | str]:
"""Normalize all sounds of a specific type (SDB, TTS, EXT)."""
# Only allow admins to normalize sounds
if current_user.role not in ["admin", "superadmin"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only administrators can normalize sounds",
)
# Validate sound type
valid_types = ["SDB", "TTS", "EXT"]
if sound_type not in valid_types:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid sound type. Must be one of: {', '.join(valid_types)}",
)
try:
results = await normalizer_service.normalize_sounds_by_type(
sound_type=sound_type,
force=force,
one_pass=one_pass,
)
return {
"message": f"Normalization of {sound_type} sounds completed",
"results": results,
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to normalize {sound_type} sounds: {e!s}",
) from e
@router.post("/normalize/{sound_id}")
async def normalize_sound_by_id(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
normalizer_service: Annotated[
SoundNormalizerService, Depends(get_sound_normalizer_service)
],
force: bool = Query(
False, description="Force normalization of already normalized sound"
),
one_pass: bool | None = Query(
None, description="Use one-pass normalization (overrides config)"
),
) -> dict[str, str]:
"""Normalize a specific sound by ID."""
# Only allow admins to normalize sounds
if current_user.role not in ["admin", "superadmin"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only administrators can normalize sounds",
)
try:
# Get the sound
sound = await normalizer_service.sound_repo.get_by_id(sound_id)
if not sound:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Sound with ID {sound_id} not found",
)
# Normalize the sound
result = await normalizer_service.normalize_sound(
sound=sound,
force=force,
one_pass=one_pass,
)
# Check result status
if result["status"] == "error":
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to normalize sound: {result['error']}",
)
return {
"message": f"Sound normalization {result['status']}: {sound.filename}",
"status": result["status"],
"reason": result["reason"] or "",
"normalized_filename": result["normalized_filename"] or "",
}
except HTTPException:
# Re-raise HTTPExceptions without wrapping them
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to normalize sound: {e!s}",
) from e
# EXTRACT
@router.post("/extract")
async def create_extraction(
url: str,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> dict[str, ExtractionInfo | str]:
"""Create a new extraction job for a URL."""
try:
if current_user.id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User ID not available",
)
extraction_info = await extraction_service.create_extraction(
url, current_user.id
)
# Queue the extraction for background processing
await extraction_processor.queue_extraction(extraction_info["id"])
return {
"message": "Extraction queued successfully",
"extraction": extraction_info,
}
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
) from e
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create extraction: {e!s}",
) from e
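`create_extraction` hands the job to `extraction_processor.queue_extraction` and returns immediately; a stdlib sketch of that queue-worker pattern (the `QueueProcessor` class is a toy stand-in, not the real processor):

```python
import asyncio

class QueueProcessor:
    """Toy extraction processor: one worker draining an asyncio.Queue."""

    def __init__(self) -> None:
        self._queue: asyncio.Queue[int] = asyncio.Queue()
        self.processed: list[int] = []
        self._worker: asyncio.Task[None] | None = None

    async def start(self) -> None:
        self._worker = asyncio.create_task(self._run())

    async def stop(self) -> None:
        await self._queue.join()  # wait for queued jobs to finish
        if self._worker:
            self._worker.cancel()

    async def queue_extraction(self, extraction_id: int) -> None:
        await self._queue.put(extraction_id)

    async def _run(self) -> None:
        while True:
            extraction_id = await self._queue.get()
            self.processed.append(extraction_id)  # real work goes here
            self._queue.task_done()

async def demo() -> list[int]:
    proc = QueueProcessor()
    await proc.start()
    await proc.queue_extraction(1)
    await proc.queue_extraction(2)
    await proc.stop()
    return proc.processed
```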
@router.get("/extract/status")
async def get_extraction_processor_status(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
) -> dict:
"""Get the status of the extraction processor."""
# Only allow admins to see processor status
if current_user.role not in ["admin", "superadmin"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only administrators can view processor status",
)
return extraction_processor.get_status()
@router.get("/extract/{extraction_id}")
async def get_extraction(
extraction_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> ExtractionInfo:
"""Get extraction information by ID."""
try:
extraction_info = await extraction_service.get_extraction_by_id(extraction_id)
if not extraction_info:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Extraction {extraction_id} not found",
)
return extraction_info
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extraction: {e!s}",
) from e
@router.get("/extract")
async def get_user_extractions(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> dict[str, list[ExtractionInfo]]:
"""Get all extractions for the current user."""
try:
if current_user.id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User ID not available",
)
extractions = await extraction_service.get_user_extractions(current_user.id)
return {
"extractions": extractions,
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extractions: {e!s}",
) from e
# VLC PLAYER
@router.post("/vlc/play/{sound_id}")
@router.post("/play/{sound_id}")
async def play_sound_with_vlc(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
vlc_player: Annotated[VLCPlayerService, Depends(get_vlc_player)],
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
credit_service: Annotated[CreditService, Depends(get_credit_service)],
) -> dict[str, str | int | bool]:
"""Play a sound using VLC subprocess (requires 1 credit)."""
try:
# Get the sound
sound = await sound_repo.get_by_id(sound_id)
if not sound:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Sound with ID {sound_id} not found",
)
if not current_user.id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User ID is required",
)
# Check and validate credits before playing
try:
await credit_service.validate_and_reserve_credits(
current_user.id,
CreditActionType.VLC_PLAY_SOUND,
{"sound_id": sound_id, "sound_name": sound.name}
)
except InsufficientCreditsError as e:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Insufficient credits: {e.required} required, {e.available} available",
) from e
# Play the sound using VLC
success = await vlc_player.play_sound(sound)
# Deduct credits based on success
await credit_service.deduct_credits(
current_user.id,
CreditActionType.VLC_PLAY_SOUND,
success,
{"sound_id": sound_id, "sound_name": sound.name},
)
if not success:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to launch VLC for sound playback",
)
return {
"message": f"Sound '{sound.name}' is now playing via VLC",
"sound_id": sound_id,
"sound_name": sound.name,
"success": True,
"credits_deducted": 1,
}
return await vlc_player.play_sound_with_credits(sound_id, current_user.id)
except HTTPException:
raise
except Exception as e:
@@ -440,16 +128,14 @@ async def play_sound_with_vlc(
) from e
@router.post("/vlc/stop-all")
@router.post("/stop")
async def stop_all_vlc_instances(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
vlc_player: Annotated[VLCPlayerService, Depends(get_vlc_player)],
) -> dict:
"""Stop all running VLC instances."""
try:
result = await vlc_player.stop_all_vlc_instances()
return result
return await vlc_player.stop_all_vlc_instances()
except Exception as e:
raise HTTPException(

225
app/api/v1/tts.py Normal file

@@ -0,0 +1,225 @@
"""TTS API endpoints."""
from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.dependencies import get_current_active_user_flexible
from app.models.user import User
from app.services.tts import TTSService
router = APIRouter(prefix="/tts", tags=["tts"])
class TTSGenerateRequest(BaseModel):
"""TTS generation request model."""
text: str = Field(
..., min_length=1, max_length=1000, description="Text to convert to speech",
)
provider: str = Field(default="gtts", description="TTS provider to use")
options: dict[str, Any] = Field(
default_factory=dict, description="Provider-specific options",
)
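`Field(min_length=1, max_length=1000)` gives the request model this validation for free; the equivalent check, sketched without pydantic (the function name is illustrative):

```python
MAX_TTS_TEXT_LENGTH = 1000  # mirrors max_length on TTSGenerateRequest

def validate_tts_text(text: str) -> str:
    """Reject empty or oversized TTS input, as the pydantic model does."""
    if not 1 <= len(text) <= MAX_TTS_TEXT_LENGTH:
        msg = f"text must be 1..{MAX_TTS_TEXT_LENGTH} characters"
        raise ValueError(msg)
    return text
```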
class TTSResponse(BaseModel):
"""TTS generation response model."""
id: int
text: str
provider: str
options: dict[str, Any]
status: str
error: str | None
sound_id: int | None
user_id: int
created_at: str
class ProviderInfo(BaseModel):
"""Provider information model."""
name: str
file_extension: str
supported_languages: list[str]
option_schema: dict[str, Any]
async def get_tts_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> TTSService:
"""Get the TTS service."""
return TTSService(session)
@router.get("")
async def get_tts_list(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
tts_service: Annotated[TTSService, Depends(get_tts_service)],
limit: int = 50,
offset: int = 0,
) -> list[TTSResponse]:
"""Get TTS list for the current user."""
try:
if current_user.id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User ID not available",
)
tts_records = await tts_service.get_user_tts_history(
user_id=current_user.id,
limit=limit,
offset=offset,
)
return [
TTSResponse(
id=tts.id,
text=tts.text,
provider=tts.provider,
options=tts.options,
status=tts.status,
error=tts.error,
sound_id=tts.sound_id,
user_id=tts.user_id,
created_at=tts.created_at.isoformat(),
)
for tts in tts_records
]
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get TTS history: {e!s}",
) from e
@router.post("")
async def generate_tts(
request: TTSGenerateRequest,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> dict[str, Any]:
"""Generate TTS audio and create sound."""
try:
if current_user.id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User ID not available",
)
result = await tts_service.create_tts_request(
text=request.text,
user_id=current_user.id,
provider=request.provider,
**request.options,
)
tts_record = result["tts"]
return {
"message": result["message"],
"tts": TTSResponse(
id=tts_record.id,
text=tts_record.text,
provider=tts_record.provider,
options=tts_record.options,
status=tts_record.status,
error=tts_record.error,
sound_id=tts_record.sound_id,
user_id=tts_record.user_id,
created_at=tts_record.created_at.isoformat(),
),
}
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
) from e
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to generate TTS: {e!s}",
) from e
@router.get("/providers")
async def get_providers(
tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> dict[str, ProviderInfo]:
"""Get all available TTS providers."""
providers = tts_service.get_providers()
result = {}
for name, provider in providers.items():
result[name] = ProviderInfo(
name=provider.name,
file_extension=provider.file_extension,
supported_languages=provider.get_supported_languages(),
option_schema=provider.get_option_schema(),
)
return result
@router.get("/providers/{provider_name}")
async def get_provider(
provider_name: str,
tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> ProviderInfo:
"""Get information about a specific TTS provider."""
provider = tts_service.get_provider(provider_name)
if not provider:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Provider '{provider_name}' not found",
)
return ProviderInfo(
name=provider.name,
file_extension=provider.file_extension,
supported_languages=provider.get_supported_languages(),
option_schema=provider.get_option_schema(),
)
@router.delete("/{tts_id}")
async def delete_tts(
tts_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> dict[str, str]:
"""Delete a TTS generation and its associated files."""
try:
if current_user.id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User ID not available",
)
await tts_service.delete_tts(tts_id=tts_id, user_id=current_user.id)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
except PermissionError as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
) from e
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete TTS: {e!s}",
) from e
else:
return {"message": "TTS generation deleted successfully"}


@@ -18,6 +18,16 @@ class Settings(BaseSettings):
PORT: int = 8000
RELOAD: bool = True
# Production URLs (for reverse proxy deployment)
FRONTEND_URL: str = "http://localhost:8001" # Frontend URL in production
BACKEND_URL: str = "http://localhost:8000" # Backend base URL
# CORS Configuration
CORS_ORIGINS: list[str] = [
"http://localhost:8001", # Frontend development
"chrome-extension://*", # Chrome extensions
]
# Database Configuration
DATABASE_URL: str = "sqlite+aiosqlite:///data/soundboard.db"
DATABASE_ECHO: bool = False
@@ -30,7 +40,9 @@ class Settings(BaseSettings):
LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
# JWT Configuration
JWT_SECRET_KEY: str = "your-secret-key-change-in-production" # noqa: S105 default value if none set in .env
JWT_SECRET_KEY: str = (
"your-secret-key-change-in-production" # noqa: S105 default value if none set in .env
)
JWT_ALGORITHM: str = "HS256"
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 15
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7
@@ -38,6 +50,7 @@ class Settings(BaseSettings):
# Cookie Configuration
COOKIE_SECURE: bool = True
COOKIE_SAMESITE: Literal["strict", "lax", "none"] = "lax"
COOKIE_DOMAIN: str | None = "localhost" # Cookie domain (None for production)
# OAuth2 Configuration
GOOGLE_CLIENT_ID: str = ""


@@ -1,22 +1,15 @@
from collections.abc import AsyncGenerator
import asyncio
from collections.abc import AsyncGenerator, Callable
from alembic.config import Config
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
# Import all models to ensure SQLModel metadata discovery
import app.models # noqa: F401
from alembic import command
from app.core.config import settings
from app.core.logging import get_logger
from app.core.seeds import seed_all_data
from app.models import ( # noqa: F401
extraction,
plan,
playlist,
playlist_sound,
sound,
sound_played,
user,
user_oauth,
)
engine: AsyncEngine = create_async_engine(
settings.DATABASE_URL,
@@ -24,7 +17,7 @@ engine: AsyncEngine = create_async_engine(
)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
async def get_db() -> AsyncGenerator[AsyncSession]:
"""Get a database session for dependency injection."""
logger = get_logger(__name__)
async with AsyncSession(engine) as session:
@@ -38,34 +31,33 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
await session.close()
def get_session_factory():
def get_session_factory() -> Callable[[], AsyncSession]:
"""Get a session factory function for services."""
def session_factory():
def session_factory() -> AsyncSession:
return AsyncSession(engine)
return session_factory
async def init_db() -> None:
"""Initialize the database and create tables if they do not exist."""
"""Initialize the database using Alembic migrations."""
logger = get_logger(__name__)
try:
logger.info("Initializing database tables")
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
logger.info("Database tables created successfully")
logger.info("Running database migrations")
# Run Alembic migrations programmatically
# Seed initial data
await seed_initial_data()
# Get the alembic config
alembic_cfg = Config("alembic.ini")
# Run migrations to the latest revision in a thread pool to avoid blocking
await asyncio.get_event_loop().run_in_executor(
None, command.upgrade, alembic_cfg, "head",
)
logger.info("Database migrations completed successfully")
except Exception:
logger.exception("Failed to initialize database")
raise
async def seed_initial_data() -> None:
"""Seed initial data into the database."""
logger = get_logger(__name__)
logger.info("Starting initial data seeding")
async with AsyncSession(engine) as session:
await seed_all_data(session)
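Alembic's `command.upgrade` is synchronous, so `init_db` above pushes it onto a thread via `run_in_executor` to avoid blocking the event loop; the pattern in isolation (the `blocking_upgrade` stand-in is hypothetical):

```python
import asyncio

def blocking_upgrade(revision: str) -> str:
    """Stand-in for command.upgrade(alembic_cfg, revision), which blocks."""
    return f"upgraded to {revision}"

async def run_migrations() -> str:
    """Run the blocking call on the default thread pool executor."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, blocking_upgrade, "head")
```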


@@ -8,7 +8,10 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.logging import get_logger
from app.models.user import User
from app.repositories.sound import SoundRepository
from app.repositories.user import UserRepository
from app.services.auth import AuthService
from app.services.dashboard import DashboardService
from app.services.oauth import OAuthService
from app.utils.auth import JWTUtils, TokenUtils
@@ -135,8 +138,6 @@ async def get_current_user_api_token(
detail="Account is deactivated",
)
return user
except HTTPException:
# Re-raise HTTPExceptions without wrapping them
raise
@@ -146,6 +147,8 @@ async def get_current_user_api_token(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate API token",
) from e
else:
return user
async def get_current_user_flexible(
@@ -184,3 +187,12 @@ async def get_admin_user(
detail="Not enough permissions",
)
return current_user
async def get_dashboard_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> DashboardService:
"""Get the dashboard service."""
sound_repository = SoundRepository(session)
user_repository = UserRepository(session)
return DashboardService(sound_repository, user_repository)

23
app/core/services.py Normal file

@@ -0,0 +1,23 @@
"""Global services container to avoid circular imports."""
from app.services.scheduler import SchedulerService
class AppServices:
"""Container for application services."""
def __init__(self) -> None:
"""Initialize the application services container."""
self.scheduler_service: SchedulerService | None = None
# Global service container
app_services = AppServices()
def get_global_scheduler_service() -> SchedulerService:
"""Get the global scheduler service instance."""
if app_services.scheduler_service is None:
msg = "Scheduler service not initialized"
raise RuntimeError(msg)
return app_services.scheduler_service


@@ -6,53 +6,93 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api import api_router
from app.core.database import get_session_factory, init_db
from app.core.config import settings
from app.core.database import get_session_factory
from app.core.logging import get_logger, setup_logging
from app.core.services import app_services
from app.middleware.logging import LoggingMiddleware
from app.services.extraction_processor import extraction_processor
from app.services.player import initialize_player_service, shutdown_player_service
from app.services.player import (
get_player_service,
initialize_player_service,
shutdown_player_service,
)
from app.services.scheduler import SchedulerService
from app.services.socket import socket_manager
from app.services.tts_processor import tts_processor
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
"""Application lifespan context manager for setup and teardown."""
setup_logging()
logger = get_logger(__name__)
logger.info("Starting application")
await init_db()
logger.info("Database initialized")
# Start the extraction processor
await extraction_processor.start()
logger.info("Extraction processor started")
# Start the TTS processor
await tts_processor.start()
logger.info("TTS processor started")
# Start the player service
await initialize_player_service(get_session_factory())
logger.info("Player service started")
# Start the scheduler service
try:
player_service = get_player_service() # Get the initialized player service
app_services.scheduler_service = SchedulerService(
get_session_factory(),
player_service,
)
await app_services.scheduler_service.start()
logger.info("Enhanced scheduler service started")
except Exception:
logger.exception("Failed to start scheduler service - continuing without it")
app_services.scheduler_service = None
yield
logger.info("Shutting down application")
# Stop the scheduler service
if app_services.scheduler_service:
await app_services.scheduler_service.stop()
logger.info("Scheduler service stopped")
# Stop the player service
await shutdown_player_service()
logger.info("Player service stopped")
# Stop the TTS processor
await tts_processor.stop()
logger.info("TTS processor stopped")
# Stop the extraction processor
await extraction_processor.stop()
logger.info("Extraction processor stopped")
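The lifespan function starts services before `yield` and stops them in reverse order after it; the shape of that contract, reduced to a stdlib sketch:

```python
import asyncio
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

events: list[str] = []

@asynccontextmanager
async def lifespan(_app: object) -> AsyncGenerator[None, None]:
    events.append("startup")   # processors and services start here
    yield                      # application serves requests
    events.append("shutdown")  # teardown runs after the app stops

async def demo() -> list[str]:
    async with lifespan(None):
        events.append("serving")
    return events
```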
def create_app():
def create_app() -> FastAPI:
"""Create and configure the FastAPI application."""
app = FastAPI(lifespan=lifespan)
app = FastAPI(
title="SBD v2 API",
description=("API for the SBD v2 application"),
version="2.0.0",
lifespan=lifespan,
# Configure docs URLs for reverse proxy setup
docs_url="/api/docs", # Swagger UI at /api/docs
redoc_url="/api/redoc", # ReDoc at /api/redoc
openapi_url="/api/openapi.json", # OpenAPI schema at /api/openapi.json
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:8001"],
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],


@@ -1 +1,34 @@
"""Models package."""
# Import all models for SQLAlchemy metadata discovery
from .base import BaseModel
from .credit_action import CreditAction
from .credit_transaction import CreditTransaction
from .extraction import Extraction
from .favorite import Favorite
from .plan import Plan
from .playlist import Playlist
from .playlist_sound import PlaylistSound
from .scheduled_task import ScheduledTask
from .sound import Sound
from .sound_played import SoundPlayed
from .tts import TTS
from .user import User
from .user_oauth import UserOauth
__all__ = [
"TTS",
"BaseModel",
"CreditAction",
"CreditTransaction",
"Extraction",
"Favorite",
"Plan",
"Playlist",
"PlaylistSound",
"ScheduledTask",
"Sound",
"SoundPlayed",
"User",
"UserOauth",
]


@@ -1,5 +1,9 @@
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import event
from sqlalchemy.engine import Connection
from sqlalchemy.orm import Mapper
from sqlmodel import Field, SQLModel
@@ -11,3 +15,14 @@ class BaseModel(SQLModel):
# timestamps
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
# SQLAlchemy event listener to automatically update updated_at timestamp
@event.listens_for(BaseModel, "before_update", propagate=True)
def update_timestamp(
mapper: Mapper[Any], # noqa: ARG001
connection: Connection, # noqa: ARG001
target: BaseModel,
) -> None:
"""Automatically set updated_at timestamp before update operations."""
target.updated_at = datetime.now(UTC)


@@ -13,6 +13,7 @@ class CreditActionType(str, Enum):
SOUND_NORMALIZATION = "sound_normalization"
API_REQUEST = "api_request"
PLAYLIST_CREATION = "playlist_creation"
DAILY_RECHARGE = "daily_recharge"
class CreditAction:
@@ -92,6 +93,12 @@ CREDIT_ACTIONS = {
description="Create a new playlist",
requires_success=True,
),
CreditActionType.DAILY_RECHARGE: CreditAction(
action_type=CreditActionType.DAILY_RECHARGE,
cost=0, # This is a credit addition, not deduction
description="Daily credit recharge",
requires_success=True,
),
}
@@ -118,4 +125,4 @@ def get_all_credit_actions() -> dict[CreditActionType, CreditAction]:
Dictionary of all credit actions
"""
return CREDIT_ACTIONS.copy()
return CREDIT_ACTIONS.copy()


@@ -17,7 +17,8 @@ class CreditTransaction(BaseModel, table=True):
user_id: int = Field(foreign_key="user.id", nullable=False)
action_type: str = Field(nullable=False)
amount: int = Field(nullable=False) # Negative for deductions, positive for additions
# Negative for deductions, positive for additions
amount: int = Field(nullable=False)
balance_before: int = Field(nullable=False)
balance_after: int = Field(nullable=False)
description: str = Field(nullable=False)
@@ -26,4 +27,4 @@ class CreditTransaction(BaseModel, table=True):
metadata_json: str | None = Field(default=None)
# relationships
user: "User" = Relationship(back_populates="credit_transactions")
user: "User" = Relationship(back_populates="credit_transactions")


@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship, UniqueConstraint
from sqlmodel import Field, Relationship
from app.models.base import BaseModel
@@ -25,7 +25,8 @@ class Extraction(BaseModel, table=True):
status: str = Field(nullable=False, default="pending")
error: str | None = Field(default=None)
# constraints - only enforce uniqueness when both service and service_id are not null
# constraints - only enforce uniqueness when both service and service_id
# are not null
__table_args__ = ()
# relationships

29
app/models/favorite.py Normal file

@@ -0,0 +1,29 @@
from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship, UniqueConstraint
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.playlist import Playlist
from app.models.sound import Sound
from app.models.user import User
class Favorite(BaseModel, table=True):
"""Database model for user favorites (sounds and playlists)."""
user_id: int = Field(foreign_key="user.id", nullable=False)
sound_id: int | None = Field(foreign_key="sound.id", default=None)
playlist_id: int | None = Field(foreign_key="playlist.id", default=None)
# constraints
__table_args__ = (
UniqueConstraint("user_id", "sound_id", name="uq_favorite_user_sound"),
UniqueConstraint("user_id", "playlist_id", name="uq_favorite_user_playlist"),
)
# relationships
user: "User" = Relationship(back_populates="favorites")
sound: "Sound" = Relationship(back_populates="favorites")
playlist: "Playlist" = Relationship(back_populates="favorites")


@@ -5,6 +5,7 @@ from sqlmodel import Field, Relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.favorite import Favorite
from app.models.playlist_sound import PlaylistSound
from app.models.user import User
@@ -23,3 +24,4 @@ class Playlist(BaseModel, table=True):
# relationships
user: "User" = Relationship(back_populates="playlists")
playlist_sounds: list["PlaylistSound"] = Relationship(back_populates="playlist")
favorites: list["Favorite"] = Relationship(back_populates="playlist")


@@ -0,0 +1,125 @@
"""Scheduled task model for flexible task scheduling with timezone support."""
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from sqlmodel import JSON, Column, Field
from app.models.base import BaseModel
class TaskType(str, Enum):
"""Available task types."""
CREDIT_RECHARGE = "credit_recharge"
PLAY_SOUND = "play_sound"
PLAY_PLAYLIST = "play_playlist"
class TaskStatus(str, Enum):
"""Task execution status."""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class RecurrenceType(str, Enum):
"""Recurrence patterns."""
NONE = "none" # One-shot task
MINUTELY = "minutely"
HOURLY = "hourly"
DAILY = "daily"
WEEKLY = "weekly"
MONTHLY = "monthly"
YEARLY = "yearly"
CRON = "cron" # Custom cron expression
class ScheduledTask(BaseModel, table=True):
"""Model for scheduled tasks with timezone support."""
__tablename__ = "scheduled_task"
id: int | None = Field(primary_key=True, default=None)
name: str = Field(max_length=255, description="Human-readable task name")
task_type: TaskType = Field(description="Type of task to execute")
status: TaskStatus = Field(default=TaskStatus.PENDING)
# Scheduling fields with timezone support
scheduled_at: datetime = Field(description="When the task should be executed (UTC)")
timezone: str = Field(
default="UTC",
description="Timezone for scheduling (e.g., 'America/New_York')",
)
recurrence_type: RecurrenceType = Field(default=RecurrenceType.NONE)
cron_expression: str | None = Field(
default=None,
description="Cron expression for custom recurrence",
)
recurrence_count: int | None = Field(
default=None,
description="Number of times to repeat (None for infinite)",
)
executions_count: int = Field(default=0, description="Number of times executed")
# Task parameters
parameters: dict[str, Any] = Field(
default_factory=dict,
sa_column=Column(JSON),
description="Task-specific parameters",
)
# User association (None for system tasks)
user_id: int | None = Field(
default=None,
foreign_key="user.id",
description="User who created the task (None for system tasks)",
)
# Execution tracking
last_executed_at: datetime | None = Field(
default=None,
description="When the task was last executed (UTC)",
)
next_execution_at: datetime | None = Field(
default=None,
description="When the task should be executed next (UTC, for recurring tasks)",
)
error_message: str | None = Field(
default=None,
description="Error message if execution failed",
)
# Task lifecycle
is_active: bool = Field(default=True, description="Whether the task is active")
expires_at: datetime | None = Field(
default=None,
description="When the task expires (UTC, optional)",
)
def is_expired(self) -> bool:
"""Check if the task has expired."""
if self.expires_at is None:
return False
return datetime.now(tz=UTC).replace(tzinfo=None) > self.expires_at
def is_recurring(self) -> bool:
"""Check if the task is recurring."""
return self.recurrence_type != RecurrenceType.NONE
def should_repeat(self) -> bool:
"""Check if the task should be repeated."""
if not self.is_recurring():
return False
if self.recurrence_count is None:
return True
return self.executions_count < self.recurrence_count
def is_system_task(self) -> bool:
"""Check if this is a system task (no user association)."""
return self.user_id is None


@@ -6,6 +6,7 @@ from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.extraction import Extraction
from app.models.favorite import Favorite
from app.models.playlist_sound import PlaylistSound
from app.models.sound_played import SoundPlayed
@@ -30,11 +31,14 @@ class Sound(BaseModel, table=True):
is_deletable: bool = Field(default=True, nullable=False)
# constraints
__table_args__ = (
UniqueConstraint("hash", name="uq_sound_hash"),
)
__table_args__ = (UniqueConstraint("hash", name="uq_sound_hash"),)
# relationships
playlist_sounds: list["PlaylistSound"] = Relationship(back_populates="sound")
playlist_sounds: list["PlaylistSound"] = Relationship(
back_populates="sound", cascade_delete=True,
)
extractions: list["Extraction"] = Relationship(back_populates="sound")
play_history: list["SoundPlayed"] = Relationship(back_populates="sound")
play_history: list["SoundPlayed"] = Relationship(
back_populates="sound", cascade_delete=True,
)
favorites: list["Favorite"] = Relationship(back_populates="sound")

30
app/models/tts.py Normal file

@@ -0,0 +1,30 @@
"""TTS model."""
from datetime import datetime
from typing import Any
from sqlalchemy import JSON, Column
from sqlmodel import Field, SQLModel
class TTS(SQLModel, table=True):
"""Text-to-Speech generation record."""
__tablename__ = "tts"
id: int | None = Field(primary_key=True)
text: str = Field(max_length=1000, description="Text that was converted to speech")
provider: str = Field(max_length=50, description="TTS provider used")
options: dict[str, Any] = Field(
default_factory=dict,
sa_column=Column(JSON),
description="Provider-specific options used",
)
status: str = Field(default="pending", description="Processing status")
error: str | None = Field(default=None, description="Error message if failed")
sound_id: int | None = Field(
foreign_key="sound.id", description="Associated sound ID",
)
user_id: int = Field(foreign_key="user.id", description="User who created the TTS")
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
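The `options` field above uses `default_factory=dict` rather than `default={}`. The reason is the classic mutable-default pitfall, which applies to SQLModel fields just as it does to stdlib dataclasses. A minimal stdlib sketch (the `TTSRecord` class here is a hypothetical stand-in, not the real model):

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class TTSRecord:
    """Illustrative stand-in for the TTS model, not the SQLModel class."""
    text: str
    provider: str
    # A factory gives every record its own dict; `options: dict = {}` would
    # make all instances share one dict (dataclasses reject it outright).
    options: dict[str, Any] = field(default_factory=dict)
    status: str = "pending"


a = TTSRecord(text="hello", provider="example")
b = TTSRecord(text="world", provider="example")
a.options["voice"] = "low"

# b is unaffected: each instance received a fresh dict from the factory.
print(a.options)  # {'voice': 'low'}
print(b.options)  # {}
```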


@@ -8,6 +8,7 @@ from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.credit_transaction import CreditTransaction
from app.models.extraction import Extraction
from app.models.favorite import Favorite
from app.models.plan import Plan
from app.models.playlist import Playlist
from app.models.sound_played import SoundPlayed
@@ -37,3 +38,4 @@ class User(BaseModel, table=True):
sounds_played: list["SoundPlayed"] = Relationship(back_populates="user")
extractions: list["Extraction"] = Relationship(back_populates="user")
credit_transactions: list["CreditTransaction"] = Relationship(back_populates="user")
favorites: list["Favorite"] = Relationship(back_populates="user")


@@ -1,6 +1,6 @@
"""Base repository with common CRUD operations."""
from typing import Any, Generic, TypeVar
from typing import Any, TypeVar
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -13,7 +13,7 @@ ModelType = TypeVar("ModelType")
logger = get_logger(__name__)
class BaseRepository(Generic[ModelType]):
class BaseRepository[ModelType]:
"""Base repository with common CRUD operations."""
def __init__(self, model: type[ModelType], session: AsyncSession) -> None:
@@ -38,11 +38,15 @@ class BaseRepository(Generic[ModelType]):
"""
try:
statement = select(self.model).where(getattr(self.model, "id") == entity_id)
statement = select(self.model).where(self.model.id == entity_id)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get %s by ID: %s", self.model.__name__, entity_id)
logger.exception(
"Failed to get %s by ID: %s",
self.model.__name__,
entity_id,
)
raise
async def get_all(
@@ -68,27 +72,36 @@ class BaseRepository(Generic[ModelType]):
logger.exception("Failed to get all %s", self.model.__name__)
raise
async def create(self, entity_data: dict[str, Any]) -> ModelType:
async def create(self, entity_data: dict[str, Any] | ModelType) -> ModelType:
"""Create a new entity.
Args:
entity_data: Dictionary of entity data
entity_data: Dictionary of entity data or model instance
Returns:
The created entity
"""
try:
entity = self.model(**entity_data)
if isinstance(entity_data, dict):
entity = self.model(**entity_data)
else:
# Already a model instance
entity = entity_data
self.session.add(entity)
await self.session.commit()
await self.session.refresh(entity)
logger.info("Created new %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
return entity
except Exception:
await self.session.rollback()
logger.exception("Failed to create %s", self.model.__name__)
raise
else:
logger.info(
"Created new %s with ID: %s",
self.model.__name__,
getattr(entity, "id", "unknown"),
)
return entity
async def update(self, entity: ModelType, update_data: dict[str, Any]) -> ModelType:
"""Update an entity.
@@ -105,15 +118,22 @@ class BaseRepository(Generic[ModelType]):
for field, value in update_data.items():
setattr(entity, field, value)
# The updated_at timestamp will be automatically set by the SQLAlchemy
# event listener
self.session.add(entity)
await self.session.commit()
await self.session.refresh(entity)
logger.info("Updated %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
return entity
except Exception:
await self.session.rollback()
logger.exception("Failed to update %s", self.model.__name__)
raise
else:
logger.info(
"Updated %s with ID: %s",
self.model.__name__,
getattr(entity, "id", "unknown"),
)
return entity
async def delete(self, entity: ModelType) -> None:
"""Delete an entity.
@@ -125,8 +145,12 @@ class BaseRepository(Generic[ModelType]):
try:
await self.session.delete(entity)
await self.session.commit()
logger.info("Deleted %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
logger.info(
"Deleted %s with ID: %s",
self.model.__name__,
getattr(entity, "id", "unknown"),
)
except Exception:
await self.session.rollback()
logger.exception("Failed to delete %s", self.model.__name__)
raise
raise


@@ -91,18 +91,17 @@ class CreditTransactionRepository(BaseRepository[CreditTransaction]):
"""
stmt = (
select(CreditTransaction)
.where(CreditTransaction.success == True) # noqa: E712
select(CreditTransaction).where(CreditTransaction.success == True) # noqa: E712
)
if user_id is not None:
stmt = stmt.where(CreditTransaction.user_id == user_id)
stmt = (
stmt.order_by(CreditTransaction.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await self.session.exec(stmt)
return list(result.all())
return list(result.all())


@@ -1,42 +1,32 @@
"""Extraction repository for database operations."""
from sqlalchemy import desc
from sqlalchemy import asc, desc, func, or_
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.extraction import Extraction
from app.models.user import User
from app.repositories.base import BaseRepository
class ExtractionRepository:
class ExtractionRepository(BaseRepository[Extraction]):
"""Repository for extraction database operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the extraction repository."""
self.session = session
async def create(self, extraction_data: dict) -> Extraction:
"""Create a new extraction."""
extraction = Extraction(**extraction_data)
self.session.add(extraction)
await self.session.commit()
await self.session.refresh(extraction)
return extraction
async def get_by_id(self, extraction_id: int) -> Extraction | None:
"""Get an extraction by ID."""
result = await self.session.exec(
select(Extraction).where(Extraction.id == extraction_id)
)
return result.first()
super().__init__(Extraction, session)
async def get_by_service_and_id(
self, service: str, service_id: str
self,
service: str,
service_id: str,
) -> Extraction | None:
"""Get an extraction by service and service_id."""
result = await self.session.exec(
select(Extraction).where(
Extraction.service == service, Extraction.service_id == service_id
)
Extraction.service == service,
Extraction.service_id == service_id,
),
)
return result.first()
@@ -45,38 +35,131 @@ class ExtractionRepository:
result = await self.session.exec(
select(Extraction)
.where(Extraction.user_id == user_id)
.order_by(desc(Extraction.created_at))
.order_by(desc(Extraction.created_at)),
)
return list(result.all())
async def get_pending_extractions(self) -> list[Extraction]:
"""Get all pending extractions."""
async def get_by_status(self, status: str) -> list[Extraction]:
"""Get all extractions by status."""
result = await self.session.exec(
select(Extraction)
.where(Extraction.status == "pending")
.order_by(Extraction.created_at)
.where(Extraction.status == status)
.order_by(Extraction.created_at),
)
return list(result.all())
async def update(self, extraction: Extraction, update_data: dict) -> Extraction:
"""Update an extraction."""
for key, value in update_data.items():
setattr(extraction, key, value)
await self.session.commit()
await self.session.refresh(extraction)
return extraction
async def delete(self, extraction: Extraction) -> None:
"""Delete an extraction."""
await self.session.delete(extraction)
await self.session.commit()
async def get_pending_extractions(self) -> list[tuple[Extraction, User]]:
"""Get all pending extractions."""
result = await self.session.exec(
select(Extraction, User)
.join(User, Extraction.user_id == User.id)
.where(Extraction.status == "pending")
.order_by(Extraction.created_at),
)
return list(result.all())
async def get_extractions_by_status(self, status: str) -> list[Extraction]:
"""Get extractions by status."""
result = await self.session.exec(
select(Extraction)
.where(Extraction.status == status)
.order_by(desc(Extraction.created_at))
.order_by(desc(Extraction.created_at)),
)
return list(result.all())
async def get_user_extractions_filtered( # noqa: PLR0913
self,
user_id: int,
search: str | None = None,
sort_by: str = "created_at",
sort_order: str = "desc",
status_filter: str | None = None,
limit: int = 50,
offset: int = 0,
) -> tuple[list[tuple[Extraction, User]], int]:
"""Get extractions for a user with filtering, search, and sorting."""
base_query = (
select(Extraction, User)
.join(User, Extraction.user_id == User.id)
.where(Extraction.user_id == user_id)
)
# Apply search filter
if search:
search_pattern = f"%{search}%"
base_query = base_query.where(
or_(
Extraction.title.ilike(search_pattern),
Extraction.url.ilike(search_pattern),
Extraction.service.ilike(search_pattern),
),
)
# Apply status filter
if status_filter:
base_query = base_query.where(Extraction.status == status_filter)
# Get total count before pagination
count_query = select(func.count()).select_from(
base_query.subquery(),
)
count_result = await self.session.exec(count_query)
total_count = count_result.one()
# Apply sorting and pagination
sort_column = getattr(Extraction, sort_by, Extraction.created_at)
if sort_order.lower() == "asc":
base_query = base_query.order_by(asc(sort_column))
else:
base_query = base_query.order_by(desc(sort_column))
paginated_query = base_query.limit(limit).offset(offset)
result = await self.session.exec(paginated_query)
return list(result.all()), total_count
async def get_all_extractions_filtered( # noqa: PLR0913
self,
search: str | None = None,
sort_by: str = "created_at",
sort_order: str = "desc",
status_filter: str | None = None,
limit: int = 50,
offset: int = 0,
) -> tuple[list[tuple[Extraction, User]], int]:
"""Get all extractions with filtering, search, and sorting."""
base_query = select(Extraction, User).join(User, Extraction.user_id == User.id)
# Apply search filter
if search:
search_pattern = f"%{search}%"
base_query = base_query.where(
or_(
Extraction.title.ilike(search_pattern),
Extraction.url.ilike(search_pattern),
Extraction.service.ilike(search_pattern),
),
)
# Apply status filter
if status_filter:
base_query = base_query.where(Extraction.status == status_filter)
# Get total count before pagination
count_query = select(func.count()).select_from(
base_query.subquery(),
)
count_result = await self.session.exec(count_query)
total_count = count_result.one()
# Apply sorting and pagination
sort_column = getattr(Extraction, sort_by, Extraction.created_at)
if sort_order.lower() == "asc":
base_query = base_query.order_by(asc(sort_column))
else:
base_query = base_query.order_by(desc(sort_column))
paginated_query = base_query.limit(limit).offset(offset)
result = await self.session.exec(paginated_query)
return list(result.all()), total_count


@@ -0,0 +1,258 @@
"""Repository for managing favorites."""
from sqlmodel import and_, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.favorite import Favorite
from app.repositories.base import BaseRepository
logger = get_logger(__name__)
class FavoriteRepository(BaseRepository[Favorite]):
"""Repository for managing favorites."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the favorite repository.
Args:
session: Database session
"""
super().__init__(Favorite, session)
async def get_user_favorites(
self,
user_id: int,
limit: int = 100,
offset: int = 0,
) -> list[Favorite]:
"""Get all favorites for a user.
Args:
user_id: The user ID
limit: Maximum number of favorites to return
offset: Number of favorites to skip
Returns:
List of user favorites
"""
try:
statement = (
select(Favorite)
.where(Favorite.user_id == user_id)
.limit(limit)
.offset(offset)
.order_by(Favorite.created_at.desc())
)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get favorites for user: %s", user_id)
raise
async def get_user_sound_favorites(
self,
user_id: int,
limit: int = 100,
offset: int = 0,
) -> list[Favorite]:
"""Get sound favorites for a user.
Args:
user_id: The user ID
limit: Maximum number of favorites to return
offset: Number of favorites to skip
Returns:
List of user sound favorites
"""
try:
statement = (
select(Favorite)
.where(and_(Favorite.user_id == user_id, Favorite.sound_id.isnot(None)))
.limit(limit)
.offset(offset)
.order_by(Favorite.created_at.desc())
)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get sound favorites for user: %s", user_id)
raise
async def get_user_playlist_favorites(
self,
user_id: int,
limit: int = 100,
offset: int = 0,
) -> list[Favorite]:
"""Get playlist favorites for a user.
Args:
user_id: The user ID
limit: Maximum number of favorites to return
offset: Number of favorites to skip
Returns:
List of user playlist favorites
"""
try:
statement = (
select(Favorite)
.where(
and_(Favorite.user_id == user_id, Favorite.playlist_id.isnot(None)),
)
.limit(limit)
.offset(offset)
.order_by(Favorite.created_at.desc())
)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get playlist favorites for user: %s", user_id)
raise
async def get_by_user_and_sound(
self,
user_id: int,
sound_id: int,
) -> Favorite | None:
"""Get a favorite by user and sound.
Args:
user_id: The user ID
sound_id: The sound ID
Returns:
The favorite if found, None otherwise
"""
try:
statement = select(Favorite).where(
and_(Favorite.user_id == user_id, Favorite.sound_id == sound_id),
)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception(
"Failed to get favorite for user %s and sound %s",
user_id,
sound_id,
)
raise
async def get_by_user_and_playlist(
self,
user_id: int,
playlist_id: int,
) -> Favorite | None:
"""Get a favorite by user and playlist.
Args:
user_id: The user ID
playlist_id: The playlist ID
Returns:
The favorite if found, None otherwise
"""
try:
statement = select(Favorite).where(
and_(Favorite.user_id == user_id, Favorite.playlist_id == playlist_id),
)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception(
"Failed to get favorite for user %s and playlist %s",
user_id,
playlist_id,
)
raise
async def is_sound_favorited(self, user_id: int, sound_id: int) -> bool:
"""Check if a sound is favorited by a user.
Args:
user_id: The user ID
sound_id: The sound ID
Returns:
True if the sound is favorited, False otherwise
"""
favorite = await self.get_by_user_and_sound(user_id, sound_id)
return favorite is not None
async def is_playlist_favorited(self, user_id: int, playlist_id: int) -> bool:
"""Check if a playlist is favorited by a user.
Args:
user_id: The user ID
playlist_id: The playlist ID
Returns:
True if the playlist is favorited, False otherwise
"""
favorite = await self.get_by_user_and_playlist(user_id, playlist_id)
return favorite is not None
async def count_user_favorites(self, user_id: int) -> int:
"""Count total favorites for a user.
Args:
user_id: The user ID
Returns:
Total number of favorites
"""
try:
statement = select(Favorite).where(Favorite.user_id == user_id)
result = await self.session.exec(statement)
return len(list(result.all()))
except Exception:
logger.exception("Failed to count favorites for user: %s", user_id)
raise
async def count_sound_favorites(self, sound_id: int) -> int:
"""Count how many users have favorited a sound.
Args:
sound_id: The sound ID
Returns:
Number of users who favorited this sound
"""
try:
statement = select(Favorite).where(Favorite.sound_id == sound_id)
result = await self.session.exec(statement)
return len(list(result.all()))
except Exception:
logger.exception("Failed to count favorites for sound: %s", sound_id)
raise
async def count_playlist_favorites(self, playlist_id: int) -> int:
"""Count how many users have favorited a playlist.
Args:
playlist_id: The playlist ID
Returns:
Number of users who favorited this playlist
"""
try:
statement = select(Favorite).where(Favorite.playlist_id == playlist_id)
result = await self.session.exec(statement)
return len(list(result.all()))
except Exception:
logger.exception("Failed to count favorites for playlist: %s", playlist_id)
raise

app/repositories/plan.py Normal file

@@ -0,0 +1,17 @@
"""Plan repository."""
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.plan import Plan
from app.repositories.base import BaseRepository
logger = get_logger(__name__)
class PlanRepository(BaseRepository[Plan]):
"""Repository for plan operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the plan repository."""
super().__init__(Plan, session)


@@ -1,34 +1,65 @@
"""Playlist repository for database operations."""
from typing import Any
from datetime import UTC, datetime
from enum import Enum
from sqlalchemy import func
from sqlalchemy import func, update
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.favorite import Favorite
from app.models.playlist import Playlist
from app.models.playlist_sound import PlaylistSound
from app.models.sound import Sound
from app.models.user import User
from app.repositories.base import BaseRepository
logger = get_logger(__name__)
class PlaylistRepository:
class PlaylistSortField(str, Enum):
"""Playlist sort field enumeration."""
NAME = "name"
GENRE = "genre"
CREATED_AT = "created_at"
UPDATED_AT = "updated_at"
SOUND_COUNT = "sound_count"
TOTAL_DURATION = "total_duration"
class SortOrder(str, Enum):
"""Sort order enumeration."""
ASC = "asc"
DESC = "desc"
class PlaylistRepository(BaseRepository[Playlist]):
"""Repository for playlist operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the playlist repository."""
self.session = session
super().__init__(Playlist, session)
async def get_by_id(self, playlist_id: int) -> Playlist | None:
"""Get a playlist by ID."""
async def _update_playlist_timestamp(self, playlist_id: int) -> None:
"""Update the playlist's updated_at timestamp."""
try:
statement = select(Playlist).where(Playlist.id == playlist_id)
result = await self.session.exec(statement)
return result.first()
update_stmt = (
update(Playlist)
.where(Playlist.id == playlist_id)
.values(updated_at=datetime.now(UTC))
)
await self.session.exec(update_stmt)
# Note: No commit here - let the calling method handle transaction
# management
except Exception:
logger.exception("Failed to get playlist by ID: %s", playlist_id)
logger.exception(
"Failed to update playlist timestamp for playlist: %s",
playlist_id,
)
raise
async def get_by_name(self, name: str) -> Playlist | None:
@@ -51,16 +82,6 @@ class PlaylistRepository:
logger.exception("Failed to get playlists for user: %s", user_id)
raise
async def get_all(self) -> list[Playlist]:
"""Get all playlists from all users."""
try:
statement = select(Playlist)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get all playlists")
raise
async def get_main_playlist(self) -> Playlist | None:
"""Get the global main playlist."""
try:
@@ -73,63 +94,22 @@ class PlaylistRepository:
logger.exception("Failed to get main playlist")
raise
async def get_current_playlist(self, user_id: int) -> Playlist | None:
"""Get the user's current playlist."""
async def get_current_playlist(self) -> Playlist | None:
"""Get the global current playlist (app-wide)."""
try:
statement = select(Playlist).where(
Playlist.user_id == user_id,
Playlist.is_current == True, # noqa: E712
)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get current playlist for user: %s", user_id)
raise
async def create(self, playlist_data: dict[str, Any]) -> Playlist:
"""Create a new playlist."""
try:
playlist = Playlist(**playlist_data)
self.session.add(playlist)
await self.session.commit()
await self.session.refresh(playlist)
except Exception:
await self.session.rollback()
logger.exception("Failed to create playlist")
raise
else:
logger.info("Created new playlist: %s", playlist.name)
return playlist
async def update(self, playlist: Playlist, update_data: dict[str, Any]) -> Playlist:
"""Update a playlist."""
try:
for field, value in update_data.items():
setattr(playlist, field, value)
await self.session.commit()
await self.session.refresh(playlist)
except Exception:
await self.session.rollback()
logger.exception("Failed to update playlist")
raise
else:
logger.info("Updated playlist: %s", playlist.name)
return playlist
async def delete(self, playlist: Playlist) -> None:
"""Delete a playlist."""
try:
await self.session.delete(playlist)
await self.session.commit()
logger.info("Deleted playlist: %s", playlist.name)
except Exception:
await self.session.rollback()
logger.exception("Failed to delete playlist")
logger.exception("Failed to get current playlist")
raise
async def search_by_name(
self, query: str, user_id: int | None = None
self,
query: str,
user_id: int | None = None,
) -> list[Playlist]:
"""Search playlists by name (case-insensitive)."""
try:
@@ -146,11 +126,12 @@ class PlaylistRepository:
raise
async def get_playlist_sounds(self, playlist_id: int) -> list[Sound]:
"""Get all sounds in a playlist, ordered by position."""
"""Get all sounds in a playlist with extractions, ordered by position."""
try:
statement = (
select(Sound)
.join(PlaylistSound)
.options(selectinload(Sound.extractions))
.where(PlaylistSound.playlist_id == playlist_id)
.order_by(PlaylistSound.position)
)
@@ -160,18 +141,67 @@ class PlaylistRepository:
logger.exception("Failed to get sounds for playlist: %s", playlist_id)
raise
async def get_playlist_sound_entries(self, playlist_id: int) -> list[PlaylistSound]:
"""Get all PlaylistSound entries for a playlist, ordered by position."""
try:
statement = (
select(PlaylistSound)
.where(PlaylistSound.playlist_id == playlist_id)
.order_by(PlaylistSound.position)
)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception(
"Failed to get playlist sound entries for playlist: %s",
playlist_id,
)
raise
async def add_sound_to_playlist(
self, playlist_id: int, sound_id: int, position: int | None = None
self,
playlist_id: int,
sound_id: int,
position: int | None = None,
) -> PlaylistSound:
"""Add a sound to a playlist."""
try:
if position is None:
# Get the next available position
statement = select(
func.coalesce(func.max(PlaylistSound.position), -1) + 1
func.coalesce(func.max(PlaylistSound.position), -1) + 1,
).where(PlaylistSound.playlist_id == playlist_id)
result = await self.session.exec(statement)
position = result.first() or 0
else:
# Shift existing positions to make room for the new sound
# Use a two-step approach to avoid unique constraint violations:
# 1. Move all affected positions to negative temporary positions
# 2. Then move them to their final positions
# Step 1: Move to temporary negative positions
update_to_negative = (
update(PlaylistSound)
.where(
PlaylistSound.playlist_id == playlist_id,
PlaylistSound.position >= position,
)
.values(position=PlaylistSound.position - 10000)
)
await self.session.exec(update_to_negative)
await self.session.commit()
# Step 2: Move from temporary negative positions to final positions
update_to_final = (
update(PlaylistSound)
.where(
PlaylistSound.playlist_id == playlist_id,
PlaylistSound.position < 0,
)
.values(position=PlaylistSound.position + 10001)
)
await self.session.exec(update_to_final)
await self.session.commit()
playlist_sound = PlaylistSound(
playlist_id=playlist_id,
@@ -179,12 +209,17 @@ class PlaylistRepository:
position=position,
)
self.session.add(playlist_sound)
# Update playlist timestamp before commit
await self._update_playlist_timestamp(playlist_id)
await self.session.commit()
await self.session.refresh(playlist_sound)
except Exception:
await self.session.rollback()
logger.exception(
"Failed to add sound %s to playlist %s", sound_id, playlist_id
"Failed to add sound %s to playlist %s",
sound_id,
playlist_id,
)
raise
else:
@@ -208,25 +243,47 @@ class PlaylistRepository:
if playlist_sound:
await self.session.delete(playlist_sound)
# Update playlist timestamp before commit
await self._update_playlist_timestamp(playlist_id)
await self.session.commit()
logger.info("Removed sound %s from playlist %s", sound_id, playlist_id)
except Exception:
await self.session.rollback()
logger.exception(
"Failed to remove sound %s from playlist %s", sound_id, playlist_id
"Failed to remove sound %s from playlist %s",
sound_id,
playlist_id,
)
raise
async def reorder_playlist_sounds(
self, playlist_id: int, sound_positions: list[tuple[int, int]]
self,
playlist_id: int,
sound_positions: list[tuple[int, int]],
) -> None:
"""Reorder sounds in a playlist.
Args:
playlist_id: The playlist ID
sound_positions: List of (sound_id, new_position) tuples
"""
try:
# Phase 1: Set all positions to temporary negative values to avoid conflicts
temp_offset = -10000 # Use large negative number to avoid conflicts
for i, (sound_id, _) in enumerate(sound_positions):
statement = select(PlaylistSound).where(
PlaylistSound.playlist_id == playlist_id,
PlaylistSound.sound_id == sound_id,
)
result = await self.session.exec(statement)
playlist_sound = result.first()
if playlist_sound:
playlist_sound.position = temp_offset + i
# Phase 2: Set the final positions
for sound_id, new_position in sound_positions:
statement = select(PlaylistSound).where(
PlaylistSound.playlist_id == playlist_id,
@@ -238,6 +295,8 @@ class PlaylistRepository:
if playlist_sound:
playlist_sound.position = new_position
# Update playlist timestamp before commit
await self._update_playlist_timestamp(playlist_id)
await self.session.commit()
logger.info("Reordered sounds in playlist %s", playlist_id)
except Exception:
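Both `add_sound_to_playlist` and `reorder_playlist_sounds` move rows through temporary out-of-range positions so that a unique `(playlist_id, position)` constraint never observes two rows in the same slot mid-update. A stdlib simulation of the reorder variant (the uniqueness assertion plays the role of the database constraint; all names here are illustrative):

```python
def reorder(positions: dict[int, int], new_order: list[tuple[int, int]]) -> dict[int, int]:
    """Simulate the two-phase reorder: sound_id -> position, under a uniqueness rule."""
    def assert_unique(p: dict[int, int]) -> None:
        assert len(set(p.values())) == len(p), "duplicate position"

    temp_offset = -10000  # large negative range that cannot collide with real positions
    # Phase 1: park every affected row at a unique temporary slot.
    for i, (sound_id, _) in enumerate(new_order):
        positions[sound_id] = temp_offset + i
        assert_unique(positions)
    # Phase 2: move rows to their final slots; no intermediate state repeats a position.
    for sound_id, new_position in new_order:
        positions[sound_id] = new_position
        assert_unique(positions)
    return positions


# Swap two sounds: a naive single pass would collide (101 -> 1 while 102 still holds 1).
print(reorder({101: 0, 102: 1}, [(101, 1), (102, 0)]))  # {101: 1, 102: 0}
```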
@@ -249,7 +308,7 @@ class PlaylistRepository:
"""Get the number of sounds in a playlist."""
try:
statement = select(func.count(PlaylistSound.id)).where(
PlaylistSound.playlist_id == playlist_id
PlaylistSound.playlist_id == playlist_id,
)
result = await self.session.exec(statement)
return result.first() or 0
@@ -268,6 +327,230 @@ class PlaylistRepository:
return result.first() is not None
except Exception:
logger.exception(
"Failed to check if sound %s is in playlist %s", sound_id, playlist_id
"Failed to check if sound %s is in playlist %s",
sound_id,
playlist_id,
)
raise
async def search_and_sort( # noqa: C901, PLR0913, PLR0912, PLR0915
self,
search_query: str | None = None,
sort_by: PlaylistSortField | None = None,
sort_order: SortOrder = SortOrder.ASC,
user_id: int | None = None,
include_stats: bool = False, # noqa: FBT001, FBT002
limit: int | None = None,
offset: int = 0,
favorites_only: bool = False, # noqa: FBT001, FBT002
current_user_id: int | None = None,
*,
return_count: bool = False,
) -> list[dict] | tuple[list[dict], int]:
"""Search and sort playlists with optional statistics."""
try:
if include_stats and sort_by in (
PlaylistSortField.SOUND_COUNT,
PlaylistSortField.TOTAL_DURATION,
):
# Use subquery for sorting by stats
subquery = (
select(
Playlist.id,
Playlist.name,
Playlist.description,
Playlist.genre,
Playlist.user_id,
Playlist.is_main,
Playlist.is_current,
Playlist.is_deletable,
Playlist.created_at,
Playlist.updated_at,
func.count(PlaylistSound.id).label("sound_count"),
func.coalesce(func.sum(Sound.duration), 0).label(
"total_duration",
),
User.name.label("user_name"),
)
.select_from(Playlist)
.join(User, Playlist.user_id == User.id, isouter=True)
.join(
PlaylistSound,
Playlist.id == PlaylistSound.playlist_id,
isouter=True,
)
.join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
.group_by(Playlist.id, User.name)
)
# Apply filters
if search_query and search_query.strip():
search_pattern = f"%{search_query.strip().lower()}%"
subquery = subquery.where(
func.lower(Playlist.name).like(search_pattern),
)
if user_id is not None:
subquery = subquery.where(Playlist.user_id == user_id)
# Apply favorites filter
if favorites_only and current_user_id is not None:
# Use EXISTS subquery to avoid JOIN conflicts with GROUP BY
favorites_subquery = (
select(1)
.select_from(Favorite)
.where(
Favorite.user_id == current_user_id,
Favorite.playlist_id == Playlist.id,
)
)
subquery = subquery.where(favorites_subquery.exists())
# Apply sorting
if sort_by == PlaylistSortField.SOUND_COUNT:
if sort_order == SortOrder.DESC:
subquery = subquery.order_by(
func.count(PlaylistSound.id).desc(),
)
else:
subquery = subquery.order_by(func.count(PlaylistSound.id).asc())
elif sort_by == PlaylistSortField.TOTAL_DURATION:
if sort_order == SortOrder.DESC:
subquery = subquery.order_by(
func.coalesce(func.sum(Sound.duration), 0).desc(),
)
else:
subquery = subquery.order_by(
func.coalesce(func.sum(Sound.duration), 0).asc(),
)
else:
# Default sorting by name
subquery = subquery.order_by(Playlist.name.asc())
else:
# Simple query without stats-based sorting
subquery = (
select(
Playlist.id,
Playlist.name,
Playlist.description,
Playlist.genre,
Playlist.user_id,
Playlist.is_main,
Playlist.is_current,
Playlist.is_deletable,
Playlist.created_at,
Playlist.updated_at,
func.count(PlaylistSound.id).label("sound_count"),
func.coalesce(func.sum(Sound.duration), 0).label(
"total_duration",
),
User.name.label("user_name"),
)
.select_from(Playlist)
.join(User, Playlist.user_id == User.id, isouter=True)
.join(
PlaylistSound,
Playlist.id == PlaylistSound.playlist_id,
isouter=True,
)
.join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
.group_by(Playlist.id, User.name)
)
# Apply filters
if search_query and search_query.strip():
search_pattern = f"%{search_query.strip().lower()}%"
subquery = subquery.where(
func.lower(Playlist.name).like(search_pattern),
)
if user_id is not None:
subquery = subquery.where(Playlist.user_id == user_id)
# Apply favorites filter
if favorites_only and current_user_id is not None:
# Use EXISTS subquery to avoid JOIN conflicts with GROUP BY
favorites_subquery = (
select(1)
.select_from(Favorite)
.where(
Favorite.user_id == current_user_id,
Favorite.playlist_id == Playlist.id,
)
)
subquery = subquery.where(favorites_subquery.exists())
# Apply sorting
if sort_by:
if sort_by == PlaylistSortField.NAME:
sort_column = Playlist.name
elif sort_by == PlaylistSortField.GENRE:
sort_column = Playlist.genre
elif sort_by == PlaylistSortField.CREATED_AT:
sort_column = Playlist.created_at
elif sort_by == PlaylistSortField.UPDATED_AT:
sort_column = Playlist.updated_at
else:
sort_column = Playlist.name
if sort_order == SortOrder.DESC:
subquery = subquery.order_by(sort_column.desc())
else:
subquery = subquery.order_by(sort_column.asc())
else:
# Default sorting by name ascending
subquery = subquery.order_by(Playlist.name.asc())
# Get total count if requested
total_count = 0
if return_count:
# Create count query from the subquery before pagination
count_query = select(func.count()).select_from(subquery.subquery())
count_result = await self.session.exec(count_query)
total_count = count_result.one()
# Apply pagination
if offset > 0:
subquery = subquery.offset(offset)
if limit is not None:
subquery = subquery.limit(limit)
result = await self.session.exec(subquery)
rows = result.all()
# Convert to dictionary format
playlists = [
{
"id": row.id,
"name": row.name,
"description": row.description,
"genre": row.genre,
"user_id": row.user_id,
"user_name": row.user_name,
"is_main": row.is_main,
"is_current": row.is_current,
"is_deletable": row.is_deletable,
"created_at": row.created_at,
"updated_at": row.updated_at,
"sound_count": row.sound_count or 0,
"total_duration": row.total_duration or 0,
}
for row in rows
]
except Exception:
logger.exception(
(
"Failed to search and sort playlists: "
"query=%s, sort_by=%s, sort_order=%s"
),
search_query,
sort_by,
sort_order,
)
raise
else:
if return_count:
return playlists, total_count
return playlists
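The `favorites_only` branch above filters with a correlated EXISTS subquery rather than another JOIN, so the `GROUP BY Playlist.id` aggregation is left undisturbed. Logically it keeps a playlist exactly when the current user has a favorite row pointing at it, as in this stdlib sketch (function and field names are illustrative):

```python
def filter_favorites(
    playlists: list[dict],
    favorites: list[dict],
    current_user_id: int,
) -> list[dict]:
    """Keep playlists for which EXISTS a favorite row (user_id, playlist_id)."""
    favorited_ids = {
        f["playlist_id"] for f in favorites if f["user_id"] == current_user_id
    }
    return [p for p in playlists if p["id"] in favorited_ids]


playlists = [{"id": 1, "name": "Main"}, {"id": 2, "name": "Chill"}]
favorites = [{"user_id": 7, "playlist_id": 2}, {"user_id": 8, "playlist_id": 1}]
print([p["name"] for p in filter_favorites(playlists, favorites, 7)])  # ['Chill']
```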


@@ -0,0 +1,181 @@
"""Repository for scheduled task operations."""
from datetime import UTC, datetime
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.scheduled_task import (
RecurrenceType,
ScheduledTask,
TaskStatus,
TaskType,
)
from app.repositories.base import BaseRepository
class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
"""Repository for scheduled task database operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the repository."""
super().__init__(ScheduledTask, session)
async def get_pending_tasks(self) -> list[ScheduledTask]:
"""Get all pending tasks that are ready to be executed."""
now = datetime.now(tz=UTC)
statement = select(ScheduledTask).where(
ScheduledTask.status == TaskStatus.PENDING,
ScheduledTask.is_active.is_(True),
ScheduledTask.scheduled_at <= now,
)
result = await self.session.exec(statement)
return list(result.all())
async def get_active_tasks(self) -> list[ScheduledTask]:
"""Get all active tasks."""
statement = select(ScheduledTask).where(
ScheduledTask.is_active.is_(True),
ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
)
result = await self.session.exec(statement)
return list(result.all())
async def get_user_tasks(
self,
user_id: int,
status: TaskStatus | None = None,
task_type: TaskType | None = None,
limit: int | None = None,
offset: int | None = None,
) -> list[ScheduledTask]:
"""Get tasks for a specific user."""
statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id)
if status:
statement = statement.where(ScheduledTask.status == status)
if task_type:
statement = statement.where(ScheduledTask.task_type == task_type)
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
if offset:
statement = statement.offset(offset)
if limit:
statement = statement.limit(limit)
result = await self.session.exec(statement)
return list(result.all())
async def get_system_tasks(
self,
status: TaskStatus | None = None,
task_type: TaskType | None = None,
) -> list[ScheduledTask]:
"""Get system tasks (tasks with no user association)."""
statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None))
if status:
statement = statement.where(ScheduledTask.status == status)
if task_type:
statement = statement.where(ScheduledTask.task_type == task_type)
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
result = await self.session.exec(statement)
return list(result.all())
async def get_recurring_tasks_due_for_next_execution(self) -> list[ScheduledTask]:
"""Get recurring tasks that need their next execution scheduled."""
now = datetime.now(tz=UTC)
statement = select(ScheduledTask).where(
ScheduledTask.recurrence_type != RecurrenceType.NONE,
ScheduledTask.is_active.is_(True),
ScheduledTask.status == TaskStatus.COMPLETED,
ScheduledTask.next_execution_at <= now,
)
result = await self.session.exec(statement)
return list(result.all())
async def get_expired_tasks(self) -> list[ScheduledTask]:
"""Get expired tasks that should be cleaned up."""
now = datetime.now(tz=UTC)
statement = select(ScheduledTask).where(
ScheduledTask.expires_at.is_not(None),
ScheduledTask.expires_at <= now,
ScheduledTask.is_active.is_(True),
)
result = await self.session.exec(statement)
return list(result.all())
async def cancel_user_tasks(
self,
user_id: int,
task_type: TaskType | None = None,
) -> int:
"""Cancel all pending/running tasks for a user."""
statement = select(ScheduledTask).where(
ScheduledTask.user_id == user_id,
ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
)
if task_type:
statement = statement.where(ScheduledTask.task_type == task_type)
result = await self.session.exec(statement)
tasks = list(result.all())
count = 0
for task in tasks:
task.status = TaskStatus.CANCELLED
task.is_active = False
self.session.add(task)
count += 1
await self.session.commit()
return count
async def mark_as_running(self, task: ScheduledTask) -> None:
"""Mark a task as running."""
task.status = TaskStatus.RUNNING
self.session.add(task)
await self.session.commit()
await self.session.refresh(task)
async def mark_as_completed(
self,
task: ScheduledTask,
next_execution_at: datetime | None = None,
) -> None:
"""Mark a task as completed and set next execution if recurring."""
task.status = TaskStatus.COMPLETED
task.last_executed_at = datetime.now(tz=UTC)
task.executions_count += 1
task.error_message = None
if next_execution_at and task.should_repeat():
task.next_execution_at = next_execution_at
task.status = TaskStatus.PENDING
elif not task.should_repeat():
task.is_active = False
self.session.add(task)
await self.session.commit()
await self.session.refresh(task)
async def mark_as_failed(self, task: ScheduledTask, error_message: str) -> None:
"""Mark a task as failed with error message."""
task.status = TaskStatus.FAILED
task.error_message = error_message
task.last_executed_at = datetime.now(tz=UTC)
# For non-recurring tasks, deactivate on failure
if not task.is_recurring():
task.is_active = False
self.session.add(task)
await self.session.commit()
await self.session.refresh(task)


@@ -1,33 +1,47 @@
"""Sound repository for database operations."""
from typing import Any
from datetime import datetime
from enum import Enum
from sqlalchemy import desc, func
from sqlmodel import select
from sqlalchemy import func
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.favorite import Favorite
from app.models.sound import Sound
from app.models.sound_played import SoundPlayed
from app.repositories.base import BaseRepository
logger = get_logger(__name__)
class SoundRepository:
class SoundSortField(str, Enum):
"""Sound sort field enumeration."""
NAME = "name"
FILENAME = "filename"
DURATION = "duration"
SIZE = "size"
TYPE = "type"
PLAY_COUNT = "play_count"
CREATED_AT = "created_at"
UPDATED_AT = "updated_at"
class SortOrder(str, Enum):
"""Sort order enumeration."""
ASC = "asc"
DESC = "desc"
class SoundRepository(BaseRepository[Sound]):
"""Repository for sound operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the sound repository."""
self.session = session
async def get_by_id(self, sound_id: int) -> Sound | None:
"""Get a sound by ID."""
try:
statement = select(Sound).where(Sound.id == sound_id)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get sound by ID: %s", sound_id)
raise
super().__init__(Sound, session)
async def get_by_filename(self, filename: str) -> Sound | None:
"""Get a sound by filename."""
@@ -59,48 +73,6 @@ class SoundRepository:
logger.exception("Failed to get sounds by type: %s", sound_type)
raise
async def create(self, sound_data: dict[str, Any]) -> Sound:
"""Create a new sound."""
try:
sound = Sound(**sound_data)
self.session.add(sound)
await self.session.commit()
await self.session.refresh(sound)
except Exception:
await self.session.rollback()
logger.exception("Failed to create sound")
raise
else:
logger.info("Created new sound: %s", sound.name)
return sound
async def update(self, sound: Sound, update_data: dict[str, Any]) -> Sound:
"""Update a sound."""
try:
for field, value in update_data.items():
setattr(sound, field, value)
await self.session.commit()
await self.session.refresh(sound)
except Exception:
await self.session.rollback()
logger.exception("Failed to update sound")
raise
else:
logger.info("Updated sound: %s", sound.name)
return sound
async def delete(self, sound: Sound) -> None:
"""Delete a sound."""
try:
await self.session.delete(sound)
await self.session.commit()
logger.info("Deleted sound: %s", sound.name)
except Exception:
await self.session.rollback()
logger.exception("Failed to delete sound")
raise
async def search_by_name(self, query: str) -> list[Sound]:
"""Search sounds by name (case-insensitive)."""
try:
@@ -144,6 +116,213 @@ class SoundRepository:
return list(result.all())
except Exception:
logger.exception(
"Failed to get unnormalized sounds by type: %s", sound_type
"Failed to get unnormalized sounds by type: %s",
sound_type,
)
raise
async def get_by_types(self, sound_types: list[str] | None = None) -> list[Sound]:
"""Get sounds by types. If types is None or empty, return all sounds."""
try:
statement = select(Sound)
if sound_types:
statement = statement.where(col(Sound.type).in_(sound_types))
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get sounds by types: %s", sound_types)
raise
async def search_and_sort( # noqa: PLR0913
self,
search_query: str | None = None,
sound_types: list[str] | None = None,
sort_by: SoundSortField | None = None,
sort_order: SortOrder = SortOrder.ASC,
limit: int | None = None,
offset: int = 0,
favorites_only: bool = False, # noqa: FBT001, FBT002
user_id: int | None = None,
) -> list[Sound]:
"""Search and sort sounds with optional filtering."""
try:
statement = select(Sound)
# Apply favorites filter
if favorites_only and user_id is not None:
statement = statement.join(Favorite).where(
Favorite.user_id == user_id,
Favorite.sound_id == Sound.id,
)
# Apply type filter
if sound_types:
statement = statement.where(col(Sound.type).in_(sound_types))
# Apply search filter
if search_query and search_query.strip():
search_pattern = f"%{search_query.strip().lower()}%"
statement = statement.where(
func.lower(Sound.name).like(search_pattern),
)
# Apply sorting
if sort_by:
sort_column = getattr(Sound, sort_by.value)
if sort_order == SortOrder.DESC:
statement = statement.order_by(sort_column.desc())
else:
statement = statement.order_by(sort_column.asc())
else:
# Default sorting by name ascending
statement = statement.order_by(Sound.name.asc())
# Apply pagination
if offset > 0:
statement = statement.offset(offset)
if limit is not None:
statement = statement.limit(limit)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception(
(
"Failed to search and sort sounds: "
"query=%s, types=%s, sort_by=%s, sort_order=%s, favorites_only=%s, "
"user_id=%s"
),
search_query,
sound_types,
sort_by,
sort_order,
favorites_only,
user_id,
)
raise
async def get_soundboard_statistics(
self,
sound_type: str = "SDB",
) -> dict[str, int | float]:
"""Get statistics for sounds of a specific type."""
try:
statement = select(
func.count(Sound.id).label("count"),
func.sum(Sound.play_count).label("total_plays"),
func.sum(Sound.duration).label("total_duration"),
func.sum(
Sound.size + func.coalesce(Sound.normalized_size, 0),
).label("total_size"),
).where(Sound.type == sound_type)
result = await self.session.exec(statement)
row = result.first()
except Exception:
logger.exception("Failed to get soundboard statistics")
raise
else:
return {
"count": row.count if row.count is not None else 0,
"total_plays": row.total_plays if row.total_plays is not None else 0,
"total_duration": (
row.total_duration if row.total_duration is not None else 0
),
"total_size": row.total_size if row.total_size is not None else 0,
}
async def get_track_statistics(self) -> dict[str, int | float]:
"""Get statistics for EXT type sounds."""
try:
statement = select(
func.count(Sound.id).label("count"),
func.sum(Sound.play_count).label("total_plays"),
func.sum(Sound.duration).label("total_duration"),
func.sum(
Sound.size + func.coalesce(Sound.normalized_size, 0),
).label("total_size"),
).where(Sound.type == "EXT")
result = await self.session.exec(statement)
row = result.first()
except Exception:
logger.exception("Failed to get track statistics")
raise
else:
return {
"count": row.count if row.count is not None else 0,
"total_plays": row.total_plays if row.total_plays is not None else 0,
"total_duration": (
row.total_duration if row.total_duration is not None else 0
),
"total_size": row.total_size if row.total_size is not None else 0,
}
async def get_top_sounds(
self,
sound_type: str,
date_filter: datetime | None = None,
limit: int = 10,
) -> list[dict]:
"""Get top sounds by play count for a specific type and period."""
try:
# Join SoundPlayed with Sound and count plays within the period
statement = (
select(
Sound.id,
Sound.name,
Sound.type,
Sound.duration,
Sound.created_at,
func.count(SoundPlayed.id).label("play_count"),
)
.select_from(SoundPlayed)
.join(Sound, SoundPlayed.sound_id == Sound.id)
)
# Apply sound type filter
if sound_type != "all":
statement = statement.where(Sound.type == sound_type.upper())
# Apply date filter if provided
if date_filter:
statement = statement.where(SoundPlayed.created_at >= date_filter)
# Group by sound and order by play count descending
statement = (
statement.group_by(
Sound.id,
Sound.name,
Sound.type,
Sound.duration,
Sound.created_at,
)
.order_by(func.count(SoundPlayed.id).desc())
.limit(limit)
)
result = await self.session.exec(statement)
rows = result.all()
# Convert to dictionaries with the play count from the period
return [
{
"id": row.id,
"name": row.name,
"type": row.type,
"play_count": row.play_count,
"duration": row.duration,
"created_at": row.created_at,
}
for row in rows
]
except Exception:
logger.exception(
"Failed to get top sounds: type=%s, date_filter=%s, limit=%s",
sound_type,
date_filter,
limit,
)
raise
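`get_top_sounds` is essentially a group-by over play events with an optional cutoff date and a descending count sort. The same aggregation can be sketched in plain Python (the play-event tuples here are made up for illustration):

```python
from collections import Counter


def top_sounds(plays, date_filter=None, limit=10):
    """plays: (sound_id, played_at) tuples; count plays per sound within the period."""
    counts = Counter(
        sound_id
        for sound_id, played_at in plays
        if date_filter is None or played_at >= date_filter
    )
    return counts.most_common(limit)  # sorted by play count, descending


plays = [(1, 5), (2, 6), (1, 7), (1, 2)]
# with a cutoff of 4, sound 1 keeps 2 plays and sound 2 keeps 1
```

The key detail the SQL version shares with this sketch: the reported `play_count` is the count *within the filtered period*, not the sound's lifetime `Sound.play_count` column.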

app/repositories/tts.py (new file, 74 lines)

@@ -0,0 +1,74 @@
"""TTS repository for database operations."""
from collections.abc import Sequence
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from app.models.tts import TTS
from app.repositories.base import BaseRepository
class TTSRepository(BaseRepository[TTS]):
"""Repository for TTS operations."""
def __init__(self, session: "AsyncSession") -> None:
"""Initialize TTS repository.
Args:
session: Database session for operations
"""
super().__init__(TTS, session)
async def get_by_user_id(
self,
user_id: int,
limit: int = 50,
offset: int = 0,
) -> Sequence[TTS]:
"""Get TTS records by user ID with pagination.
Args:
user_id: User ID to filter by
limit: Maximum number of records to return
offset: Number of records to skip
Returns:
List of TTS records
"""
stmt = (
select(self.model)
.where(self.model.user_id == user_id)
.order_by(self.model.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await self.session.exec(stmt)
return result.all()
async def get_by_user_and_id(
self,
user_id: int,
tts_id: int,
) -> TTS | None:
"""Get a specific TTS record by user ID and TTS ID.
Args:
user_id: User ID to filter by
tts_id: TTS ID to retrieve
Returns:
TTS record if found and belongs to user, None otherwise
"""
stmt = select(self.model).where(
self.model.id == tts_id,
self.model.user_id == user_id,
)
result = await self.session.exec(stmt)
return result.first()
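`get_by_user_and_id` puts both predicates in one query, so a record owned by another user is indistinguishable from a missing one. In plain Python terms (illustrative rows):

```python
def get_by_user_and_id(records, user_id, tts_id):
    """Return the record only when it both exists and belongs to the user."""
    for rec in records:
        if rec["id"] == tts_id and rec["user_id"] == user_id:
            return rec
    return None
```

Returning `None` for both "not found" and "not yours" lets the API layer answer 404 in either case without leaking the existence of other users' records.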


@@ -1,32 +1,156 @@
"""User repository."""
from datetime import datetime
from enum import Enum
from typing import Any
from sqlalchemy import Select, func
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.credit_transaction import CreditTransaction
from app.models.extraction import Extraction
from app.models.plan import Plan
from app.models.playlist import Playlist
from app.models.sound_played import SoundPlayed
from app.models.tts import TTS
from app.models.user import User
from app.repositories.base import BaseRepository
logger = get_logger(__name__)
class UserRepository:
class UserSortField(str, Enum):
"""User sort fields."""
NAME = "name"
EMAIL = "email"
ROLE = "role"
CREDITS = "credits"
CREATED_AT = "created_at"
class SortOrder(str, Enum):
"""Sort order."""
ASC = "asc"
DESC = "desc"
class UserStatus(str, Enum):
"""User status filter."""
ALL = "all"
ACTIVE = "active"
INACTIVE = "inactive"
class UserRepository(BaseRepository[User]):
"""Repository for user operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the user repository."""
self.session = session
super().__init__(User, session)
async def get_by_id(self, user_id: int) -> User | None:
"""Get a user by ID."""
async def get_all_with_plan(
self,
limit: int = 100,
offset: int = 0,
) -> list[User]:
"""Get all users with plan relationship loaded."""
try:
statement = select(User).where(User.id == user_id)
statement = (
select(User)
.options(selectinload(User.plan))
.limit(limit)
.offset(offset)
)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get all users with plan")
raise
async def get_all_with_plan_paginated( # noqa: PLR0913
self,
page: int = 1,
limit: int = 50,
search: str | None = None,
sort_by: UserSortField = UserSortField.NAME,
sort_order: SortOrder = SortOrder.ASC,
status_filter: UserStatus = UserStatus.ALL,
) -> tuple[list[User], int]:
"""Get all users with plan relationship loaded and return total count."""
try:
# Calculate offset
offset = (page - 1) * limit
# Build base query
base_query = select(User).options(selectinload(User.plan))
count_query = select(func.count(User.id))
# Apply search filter
if search and search.strip():
search_pattern = f"%{search.strip().lower()}%"
search_condition = func.lower(User.name).like(
search_pattern,
) | func.lower(User.email).like(search_pattern)
base_query = base_query.where(search_condition)
count_query = count_query.where(search_condition)
# Apply status filter
if status_filter == UserStatus.ACTIVE:
base_query = base_query.where(User.is_active == True) # noqa: E712
count_query = count_query.where(User.is_active == True) # noqa: E712
elif status_filter == UserStatus.INACTIVE:
base_query = base_query.where(User.is_active == False) # noqa: E712
count_query = count_query.where(User.is_active == False) # noqa: E712
# Apply sorting
sort_column = {
UserSortField.NAME: User.name,
UserSortField.EMAIL: User.email,
UserSortField.ROLE: User.role,
UserSortField.CREDITS: User.credits,
UserSortField.CREATED_AT: User.created_at,
}.get(sort_by, User.name)
if sort_order == SortOrder.DESC:
base_query = base_query.order_by(sort_column.desc())
else:
base_query = base_query.order_by(sort_column.asc())
# Get total count
count_result = await self.session.exec(count_query)
total_count = count_result.one()
# Apply pagination and get results
paginated_query = base_query.limit(limit).offset(offset)
result = await self.session.exec(paginated_query)
users = list(result.all())
except Exception:
logger.exception("Failed to get paginated users with plan")
raise
else:
return users, total_count
async def get_by_id_with_plan(self, entity_id: int) -> User | None:
"""Get a user by ID with plan relationship loaded."""
try:
statement = (
select(User)
.options(selectinload(User.plan))
.where(User.id == entity_id)
)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get user by ID: %s", user_id)
logger.exception(
"Failed to get user by ID with plan: %s",
entity_id,
)
raise
async def get_by_email(self, email: str) -> User | None:
@@ -49,8 +173,8 @@ class UserRepository:
logger.exception("Failed to get user by API token")
raise
async def create(self, user_data: dict[str, Any]) -> User:
"""Create a new user."""
async def create(self, entity_data: dict[str, Any]) -> User:
"""Create a new user with plan assignment and first user admin logic."""
def _raise_plan_not_found() -> None:
msg = "Default plan not found"
@@ -65,7 +189,7 @@ class UserRepository:
if is_first_user:
# First user gets admin role and pro plan
plan_statement = select(Plan).where(Plan.code == "pro")
user_data["role"] = "admin"
entity_data["role"] = "admin"
logger.info("Creating first user with admin role and pro plan")
else:
# Regular users get free plan
@@ -81,48 +205,14 @@ class UserRepository:
assert default_plan is not None # noqa: S101
# Set plan_id and default credits
user_data["plan_id"] = default_plan.id
user_data["credits"] = default_plan.credits
entity_data["plan_id"] = default_plan.id
entity_data["credits"] = default_plan.credits
user = User(**user_data)
self.session.add(user)
await self.session.commit()
await self.session.refresh(user)
# Use BaseRepository's create method
return await super().create(entity_data)
except Exception:
await self.session.rollback()
logger.exception("Failed to create user")
raise
else:
logger.info("Created new user with email: %s", user.email)
return user
async def update(self, user: User, update_data: dict[str, Any]) -> User:
"""Update a user."""
try:
for field, value in update_data.items():
setattr(user, field, value)
await self.session.commit()
await self.session.refresh(user)
except Exception:
await self.session.rollback()
logger.exception("Failed to update user")
raise
else:
logger.info("Updated user: %s", user.email)
return user
async def delete(self, user: User) -> None:
"""Delete a user."""
try:
await self.session.delete(user)
await self.session.commit()
logger.info("Deleted user: %s", user.email)
except Exception:
await self.session.rollback()
logger.exception("Failed to delete user")
raise
async def email_exists(self, email: str) -> bool:
"""Check if an email address is already registered."""
@@ -133,3 +223,146 @@ class UserRepository:
except Exception:
logger.exception("Failed to check if email exists: %s", email)
raise
async def get_top_users(
self,
metric_type: str,
date_filter: datetime | None = None,
limit: int = 10,
) -> list[dict[str, Any]]:
"""Get top users by different metrics."""
try:
query = self._build_top_users_query(metric_type, date_filter)
# Add ordering and limit
query = query.order_by(func.count().desc()).limit(limit)
result = await self.session.exec(query)
rows = result.all()
return [
{
"id": row[0],
"name": row[1],
"count": int(row[2]),
}
for row in rows
]
except Exception:
logger.exception(
"Failed to get top users for metric=%s, date_filter=%s",
metric_type,
date_filter,
)
raise
def _build_top_users_query(
self,
metric_type: str,
date_filter: datetime | None,
) -> Select:
"""Build query for top users based on metric type."""
match metric_type:
case "sounds_played":
query = self._build_sounds_played_query()
case "credits_used":
query = self._build_credits_used_query()
case "tracks_added":
query = self._build_tracks_added_query()
case "tts_added":
query = self._build_tts_added_query()
case "playlists_created":
query = self._build_playlists_created_query()
case _:
msg = f"Unknown metric type: {metric_type}"
raise ValueError(msg)
# Apply date filter if provided
if date_filter:
query = self._apply_date_filter(query, metric_type, date_filter)
return query
def _build_sounds_played_query(self) -> Select:
"""Build query for sounds played metric."""
return (
select(
User.id,
User.name,
func.count(SoundPlayed.id).label("count"),
)
.join(SoundPlayed, User.id == SoundPlayed.user_id)
.group_by(User.id, User.name)
)
def _build_credits_used_query(self) -> Select:
"""Build query for credits used metric."""
return (
select(
User.id,
User.name,
func.sum(func.abs(CreditTransaction.amount)).label("count"),
)
.join(CreditTransaction, User.id == CreditTransaction.user_id)
.where(CreditTransaction.amount < 0)
.group_by(User.id, User.name)
)
def _build_tracks_added_query(self) -> Select:
"""Build query for tracks added metric."""
return (
select(
User.id,
User.name,
func.count(Extraction.id).label("count"),
)
.join(Extraction, User.id == Extraction.user_id)
.where(Extraction.sound_id.is_not(None))
.group_by(User.id, User.name)
)
def _build_tts_added_query(self) -> Select:
"""Build query for TTS added metric."""
return (
select(
User.id,
User.name,
func.count(TTS.id).label("count"),
)
.join(TTS, User.id == TTS.user_id)
.group_by(User.id, User.name)
)
def _build_playlists_created_query(self) -> Select:
"""Build query for playlists created metric."""
return (
select(
User.id,
User.name,
func.count(Playlist.id).label("count"),
)
.join(Playlist, User.id == Playlist.user_id)
.group_by(User.id, User.name)
)
def _apply_date_filter(
self,
query: Select,
metric_type: str,
date_filter: datetime,
) -> Select:
"""Apply date filter to query based on metric type."""
match metric_type:
case "sounds_played":
return query.where(SoundPlayed.created_at >= date_filter)
case "credits_used":
return query.where(CreditTransaction.created_at >= date_filter)
case "tracks_added":
return query.where(Extraction.created_at >= date_filter)
case "tts_added":
return query.where(TTS.created_at >= date_filter)
case "playlists_created":
return query.where(Playlist.created_at >= date_filter)
case _:
return query


@@ -1,22 +1,21 @@
"""Repository for user OAuth operations."""
from typing import Any
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.user_oauth import UserOauth
from app.repositories.base import BaseRepository
logger = get_logger(__name__)
class UserOauthRepository:
class UserOauthRepository(BaseRepository[UserOauth]):
"""Repository for user OAuth operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize repository with database session."""
self.session = session
super().__init__(UserOauth, session)
async def get_by_provider_user_id(
self,
@@ -61,57 +60,12 @@ class UserOauthRepository:
else:
return result.first()
async def create(self, oauth_data: dict[str, Any]) -> UserOauth:
"""Create a new user OAuth record."""
async def get_by_user_id(self, user_id: int) -> list[UserOauth]:
"""Get all OAuth providers for a user."""
try:
oauth = UserOauth(**oauth_data)
self.session.add(oauth)
await self.session.commit()
await self.session.refresh(oauth)
logger.info(
"Created OAuth link for user %s with provider %s",
oauth.user_id,
oauth.provider,
)
statement = select(UserOauth).where(UserOauth.user_id == user_id)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
await self.session.rollback()
logger.exception("Failed to create user OAuth")
raise
else:
return oauth
async def update(self, oauth: UserOauth, update_data: dict[str, Any]) -> UserOauth:
"""Update a user OAuth record."""
try:
for key, value in update_data.items():
setattr(oauth, key, value)
self.session.add(oauth)
await self.session.commit()
await self.session.refresh(oauth)
logger.info(
"Updated OAuth link for user %s with provider %s",
oauth.user_id,
oauth.provider,
)
except Exception:
await self.session.rollback()
logger.exception("Failed to update user OAuth")
raise
else:
return oauth
async def delete(self, oauth: UserOauth) -> None:
"""Delete a user OAuth record."""
try:
await self.session.delete(oauth)
await self.session.commit()
logger.info(
"Deleted OAuth link for user %s with provider %s",
oauth.user_id,
oauth.provider,
)
except Exception:
await self.session.rollback()
logger.exception("Failed to delete user OAuth")
logger.exception("Failed to get OAuth providers for user ID: %s", user_id)
raise


@@ -28,25 +28,16 @@ from .playlist import (
)
__all__ = [
# Auth schemas
"ApiTokenRequest",
"ApiTokenResponse",
"ApiTokenStatusResponse",
"AuthResponse",
"TokenResponse",
"UserLoginRequest",
"UserRegisterRequest",
"UserResponse",
# Common schemas
"HealthResponse",
"MessageResponse",
"StatusResponse",
# Player schemas
"MessageResponse",
"PlayerModeRequest",
"PlayerSeekRequest",
"PlayerStateResponse",
"PlayerVolumeRequest",
# Playlist schemas
"PlaylistAddSoundRequest",
"PlaylistCreateRequest",
"PlaylistReorderRequest",
@@ -54,4 +45,9 @@ __all__ = [
"PlaylistSoundResponse",
"PlaylistStatsResponse",
"PlaylistUpdateRequest",
"StatusResponse",
"TokenResponse",
"UserLoginRequest",
"UserRegisterRequest",
"UserResponse",
]


@@ -79,3 +79,28 @@ class ApiTokenStatusResponse(BaseModel):
has_token: bool = Field(..., description="Whether user has an active API token")
expires_at: datetime | None = Field(None, description="Token expiration timestamp")
is_expired: bool = Field(..., description="Whether the token is expired")
class ChangePasswordRequest(BaseModel):
"""Schema for password change request."""
current_password: str | None = Field(
None,
description="Current password (required if user has existing password)",
)
new_password: str = Field(
...,
min_length=8,
description="New password (minimum 8 characters)",
)
class UpdateProfileRequest(BaseModel):
"""Schema for profile update request."""
name: str | None = Field(
None,
min_length=1,
max_length=100,
description="User display name",
)
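`ChangePasswordRequest` encodes a conditional rule: `current_password` is optional at the schema level (an OAuth-only account may have no password yet), while the service layer is expected to require it when a password already exists. A plain-Python sketch of that rule, assuming such a service-side check (the function and its return strings are hypothetical):

```python
def validate_password_change(has_existing_password, current_password, new_password):
    """Illustrative service-side check for the ChangePasswordRequest semantics."""
    if len(new_password) < 8:                  # mirrors min_length=8 on the schema
        return "new password too short"
    if has_existing_password and not current_password:
        return "current password required"     # only enforced when one is already set
    return "ok"
```

Splitting the check this way keeps the schema permissive for first-time password setup while still forcing re-authentication for changes.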


@@ -18,4 +18,4 @@ class StatusResponse(BaseModel):
class HealthResponse(BaseModel):
"""Health check response."""
status: str = Field(description="Health status")
status: str = Field(description="Health status")

app/schemas/favorite.py (new file, 41 lines)

@@ -0,0 +1,41 @@
"""Favorite response schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
class FavoriteResponse(BaseModel):
"""Response schema for a favorite."""
id: int = Field(description="Favorite ID")
user_id: int = Field(description="User ID")
sound_id: int | None = Field(
description="Sound ID if this is a sound favorite",
default=None,
)
playlist_id: int | None = Field(
description="Playlist ID if this is a playlist favorite",
default=None,
)
created_at: datetime = Field(description="Creation timestamp")
updated_at: datetime = Field(description="Last update timestamp")
class Config:
"""Pydantic config."""
from_attributes = True
class FavoritesListResponse(BaseModel):
"""Response schema for a list of favorites."""
favorites: list[FavoriteResponse] = Field(description="List of favorites")
class FavoriteCountsResponse(BaseModel):
"""Response schema for favorite counts."""
total: int = Field(description="Total number of favorites")
sounds: int = Field(description="Number of favorited sounds")
playlists: int = Field(description="Number of favorited playlists")


@@ -10,7 +10,7 @@ from app.services.player import PlayerMode
class PlayerSeekRequest(BaseModel):
"""Request model for seek operation."""
position_ms: int = Field(ge=0, description="Position in milliseconds")
position: int = Field(ge=0, description="Position in milliseconds")
class PlayerVolumeRequest(BaseModel):
@@ -30,17 +30,26 @@ class PlayerStateResponse(BaseModel):
status: str = Field(description="Player status (playing, paused, stopped)")
current_sound: dict[str, Any] | None = Field(
None, description="Current sound information"
None,
description="Current sound information",
)
playlist: dict[str, Any] | None = Field(
None, description="Current playlist information"
None,
description="Current playlist information",
)
position_ms: int = Field(description="Current position in milliseconds")
duration_ms: int | None = Field(
None, description="Total duration in milliseconds",
position: int = Field(description="Current position in milliseconds")
duration: int | None = Field(
None,
description="Total duration in milliseconds",
)
volume: int = Field(description="Current volume (0-100)")
previous_volume: int = Field(description="Previous volume for unmuting (0-100)")
mode: str = Field(description="Current playback mode")
index: int | None = Field(
None, description="Current track index in playlist",
None,
description="Current track index in playlist",
)
play_next_queue: list[dict[str, Any]] = Field(
default_factory=list,
description="Play next queue",
)


@@ -1,6 +1,6 @@
"""Playlist schemas."""
from pydantic import BaseModel, Field
from pydantic import BaseModel
from app.models.playlist import Playlist
from app.models.sound import Sound
@@ -33,14 +33,32 @@ class PlaylistResponse(BaseModel):
is_main: bool
is_current: bool
is_deletable: bool
is_favorited: bool = False
favorite_count: int = 0
created_at: str
updated_at: str | None
@classmethod
def from_playlist(cls, playlist: Playlist) -> "PlaylistResponse":
"""Create response from playlist model."""
def from_playlist(
cls,
playlist: Playlist,
is_favorited: bool = False, # noqa: FBT001, FBT002
favorite_count: int = 0,
) -> "PlaylistResponse":
"""Create response from playlist model.
Args:
playlist: The Playlist model
is_favorited: Whether the playlist is favorited by the current user
favorite_count: Number of users who favorited this playlist
Returns:
PlaylistResponse instance
"""
if playlist.id is None:
raise ValueError("Playlist ID cannot be None")
msg = "Playlist ID cannot be None"
raise ValueError(msg)
return cls(
id=playlist.id,
name=playlist.name,
@@ -49,6 +67,8 @@ class PlaylistResponse(BaseModel):
is_main=playlist.is_main,
is_current=playlist.is_current,
is_deletable=playlist.is_deletable,
is_favorited=is_favorited,
favorite_count=favorite_count,
created_at=playlist.created_at.isoformat(),
updated_at=playlist.updated_at.isoformat() if playlist.updated_at else None,
)
@@ -70,7 +90,8 @@ class PlaylistSoundResponse(BaseModel):
def from_sound(cls, sound: Sound) -> "PlaylistSoundResponse":
"""Create response from sound model."""
if sound.id is None:
raise ValueError("Sound ID cannot be None")
msg = "Sound ID cannot be None"
raise ValueError(msg)
return cls(
id=sound.id,
name=sound.name,

app/schemas/scheduler.py Normal file

@@ -0,0 +1,197 @@
"""Schemas for scheduled task API."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
from app.models.scheduled_task import RecurrenceType, TaskStatus, TaskType
class TaskFilterParams(BaseModel):
"""Query parameters for filtering tasks."""
status: TaskStatus | None = Field(default=None, description="Filter by task status")
task_type: TaskType | None = Field(default=None, description="Filter by task type")
limit: int = Field(default=50, description="Maximum number of tasks to return")
offset: int = Field(default=0, description="Number of tasks to skip")
class ScheduledTaskBase(BaseModel):
"""Base schema for scheduled tasks."""
name: str = Field(description="Human-readable task name")
task_type: TaskType = Field(description="Type of task to execute")
scheduled_at: datetime = Field(description="When the task should be executed")
timezone: str = Field(default="UTC", description="Timezone for scheduling")
parameters: dict[str, Any] = Field(
default_factory=dict,
description="Task-specific parameters",
)
recurrence_type: RecurrenceType = Field(
default=RecurrenceType.NONE,
description="Recurrence pattern",
)
cron_expression: str | None = Field(
default=None,
description="Cron expression for custom recurrence",
)
recurrence_count: int | None = Field(
default=None,
description="Number of times to repeat (None for infinite)",
)
expires_at: datetime | None = Field(
default=None,
description="When the task expires (optional)",
)
class ScheduledTaskCreate(ScheduledTaskBase):
"""Schema for creating a scheduled task."""
class ScheduledTaskUpdate(BaseModel):
"""Schema for updating a scheduled task."""
name: str | None = None
scheduled_at: datetime | None = None
timezone: str | None = None
parameters: dict[str, Any] | None = None
is_active: bool | None = None
expires_at: datetime | None = None
class ScheduledTaskResponse(ScheduledTaskBase):
"""Schema for scheduled task responses."""
id: int
status: TaskStatus
user_id: int | None = None
executions_count: int
last_executed_at: datetime | None = None
next_execution_at: datetime | None = None
error_message: str | None = None
is_active: bool
created_at: datetime
updated_at: datetime
class Config:
"""Pydantic configuration."""
from_attributes = True
# Task-specific parameter schemas
class CreditRechargeParameters(BaseModel):
"""Parameters for credit recharge tasks."""
user_id: int | None = Field(
default=None,
description="Specific user ID to recharge (None for all users)",
)
class PlaySoundParameters(BaseModel):
"""Parameters for play sound tasks."""
sound_id: int = Field(description="ID of the sound to play")
class PlayPlaylistParameters(BaseModel):
"""Parameters for play playlist tasks."""
playlist_id: int = Field(description="ID of the playlist to play")
play_mode: str = Field(
default="continuous",
description="Play mode (continuous, loop, loop_one, random, single)",
)
shuffle: bool = Field(default=False, description="Whether to shuffle the playlist")
# Convenience schemas for creating specific task types
class CreateCreditRechargeTask(BaseModel):
"""Schema for creating credit recharge tasks."""
name: str = "Credit Recharge"
scheduled_at: datetime
timezone: str = "UTC"
recurrence_type: RecurrenceType = RecurrenceType.NONE
cron_expression: str | None = None
recurrence_count: int | None = None
expires_at: datetime | None = None
user_id: int | None = None
def to_task_create(self) -> ScheduledTaskCreate:
"""Convert to generic task creation schema."""
return ScheduledTaskCreate(
name=self.name,
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=self.scheduled_at,
timezone=self.timezone,
parameters={"user_id": self.user_id},
recurrence_type=self.recurrence_type,
cron_expression=self.cron_expression,
recurrence_count=self.recurrence_count,
expires_at=self.expires_at,
)
class CreatePlaySoundTask(BaseModel):
"""Schema for creating play sound tasks."""
name: str
scheduled_at: datetime
sound_id: int
timezone: str = "UTC"
recurrence_type: RecurrenceType = RecurrenceType.NONE
cron_expression: str | None = None
recurrence_count: int | None = None
expires_at: datetime | None = None
def to_task_create(self) -> ScheduledTaskCreate:
"""Convert to generic task creation schema."""
return ScheduledTaskCreate(
name=self.name,
task_type=TaskType.PLAY_SOUND,
scheduled_at=self.scheduled_at,
timezone=self.timezone,
parameters={"sound_id": self.sound_id},
recurrence_type=self.recurrence_type,
cron_expression=self.cron_expression,
recurrence_count=self.recurrence_count,
expires_at=self.expires_at,
)
class CreatePlayPlaylistTask(BaseModel):
"""Schema for creating play playlist tasks."""
name: str
scheduled_at: datetime
playlist_id: int
play_mode: str = "continuous"
shuffle: bool = False
timezone: str = "UTC"
recurrence_type: RecurrenceType = RecurrenceType.NONE
cron_expression: str | None = None
recurrence_count: int | None = None
expires_at: datetime | None = None
def to_task_create(self) -> ScheduledTaskCreate:
"""Convert to generic task creation schema."""
return ScheduledTaskCreate(
name=self.name,
task_type=TaskType.PLAY_PLAYLIST,
scheduled_at=self.scheduled_at,
timezone=self.timezone,
parameters={
"playlist_id": self.playlist_id,
"play_mode": self.play_mode,
"shuffle": self.shuffle,
},
recurrence_type=self.recurrence_type,
cron_expression=self.cron_expression,
recurrence_count=self.recurrence_count,
expires_at=self.expires_at,
)
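Each convenience schema folds its task-specific fields into the generic `parameters` dict via `to_task_create()`. A dependency-free sketch of that conversion (`PlaySoundTaskSketch` is hypothetical, standing in for `CreatePlaySoundTask` without Pydantic):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class PlaySoundTaskSketch:
    """Hypothetical stand-in for CreatePlaySoundTask."""

    name: str
    sound_id: int
    timezone: str = "UTC"

    def to_generic(self) -> dict[str, Any]:
        # Fold the task-specific field into the generic parameters payload,
        # mirroring what to_task_create() does for ScheduledTaskCreate
        return {
            "name": self.name,
            "task_type": "play_sound",
            "timezone": self.timezone,
            "parameters": {"sound_id": self.sound_id},
        }
```

The scheduler only ever has to understand the generic shape; the convenience schemas exist purely for a friendlier API surface.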

app/schemas/sound.py Normal file

@@ -0,0 +1,106 @@
"""Sound response schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
from app.models.sound import Sound
class SoundResponse(BaseModel):
"""Response schema for a sound with favorite indicator."""
id: int = Field(description="Sound ID")
type: str = Field(description="Sound type")
name: str = Field(description="Sound name")
filename: str = Field(description="Sound filename")
duration: int = Field(description="Duration in milliseconds")
size: int = Field(description="File size in bytes")
hash: str = Field(description="File hash")
normalized_filename: str | None = Field(
description="Normalized filename",
default=None,
)
normalized_duration: int | None = Field(
description="Normalized duration in milliseconds",
default=None,
)
normalized_size: int | None = Field(
description="Normalized file size in bytes",
default=None,
)
normalized_hash: str | None = Field(
description="Normalized file hash",
default=None,
)
thumbnail: str | None = Field(description="Thumbnail filename", default=None)
play_count: int = Field(description="Number of times played")
is_normalized: bool = Field(description="Whether the sound is normalized")
is_music: bool = Field(description="Whether the sound is music")
is_deletable: bool = Field(description="Whether the sound can be deleted")
is_favorited: bool = Field(
description="Whether the sound is favorited by the current user",
default=False,
)
favorite_count: int = Field(
description="Number of users who favorited this sound",
default=0,
)
created_at: datetime = Field(description="Creation timestamp")
updated_at: datetime = Field(description="Last update timestamp")
class Config:
"""Pydantic config."""
from_attributes = True
@classmethod
def from_sound(
cls,
sound: Sound,
is_favorited: bool = False, # noqa: FBT001, FBT002
favorite_count: int = 0,
) -> "SoundResponse":
"""Create a SoundResponse from a Sound model.
Args:
sound: The Sound model
is_favorited: Whether the sound is favorited by the current user
favorite_count: Number of users who favorited this sound
Returns:
SoundResponse instance
"""
if sound.id is None:
msg = "Sound ID cannot be None"
raise ValueError(msg)
return cls(
id=sound.id,
type=sound.type,
name=sound.name,
filename=sound.filename,
duration=sound.duration,
size=sound.size,
hash=sound.hash,
normalized_filename=sound.normalized_filename,
normalized_duration=sound.normalized_duration,
normalized_size=sound.normalized_size,
normalized_hash=sound.normalized_hash,
thumbnail=sound.thumbnail,
play_count=sound.play_count,
is_normalized=sound.is_normalized,
is_music=sound.is_music,
is_deletable=sound.is_deletable,
is_favorited=is_favorited,
favorite_count=favorite_count,
created_at=sound.created_at,
updated_at=sound.updated_at,
)
class SoundsListResponse(BaseModel):
"""Response schema for a list of sounds."""
sounds: list[SoundResponse] = Field(description="List of sounds")

app/schemas/user.py Normal file

@@ -0,0 +1,27 @@
"""User schemas."""
from pydantic import BaseModel, Field, field_validator
class UserUpdate(BaseModel):
"""Schema for updating a user."""
name: str | None = Field(
None,
min_length=1,
max_length=100,
description="User full name",
)
plan_id: int | None = Field(None, description="User plan ID")
credits: int | None = Field(None, ge=0, description="User credits")
is_active: bool | None = Field(None, description="Whether user is active")
role: str | None = Field(None, description="User role (admin or user)")
@field_validator("role")
@classmethod
def validate_role(cls, v: str | None) -> str | None:
"""Validate that role is either 'user' or 'admin'."""
if v is not None and v not in {"user", "admin"}:
msg = "Role must be either 'user' or 'admin'"
raise ValueError(msg)
return v


@@ -20,11 +20,6 @@ from app.schemas.auth import (
)
from app.services.oauth import OAuthUserInfo
from app.utils.auth import JWTUtils, PasswordUtils, TokenUtils
from app.utils.exceptions import (
raise_bad_request,
raise_not_found,
raise_unauthorized,
)
logger = get_logger(__name__)
@@ -44,7 +39,10 @@ class AuthService:
# Check if email already exists
if await self.user_repo.email_exists(request.email):
raise_bad_request("Email address is already registered")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email address is already registered",
)
# Hash the password
hashed_password = PasswordUtils.hash_password(request.password)
@@ -77,18 +75,27 @@ class AuthService:
# Get user by email
user = await self.user_repo.get_by_email(request.email)
if not user:
raise_unauthorized("Invalid email or password")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",
)
# Check if user is active
if not user.is_active:
raise_unauthorized("Account is deactivated")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Account is deactivated",
)
# Verify password
if not user.password_hash or not PasswordUtils.verify_password(
request.password,
user.password_hash,
):
raise_unauthorized("Invalid email or password")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",
)
# Generate access token
token = self._create_access_token(user)
@@ -103,10 +110,16 @@ class AuthService:
"""Get the current authenticated user."""
user = await self.user_repo.get_by_id(user_id)
if not user:
raise_not_found("User")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
if not user.is_active:
raise_unauthorized("Account is deactivated")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Account is deactivated",
)
return user
@@ -417,3 +430,93 @@ class AuthService:
oauth_user_info.email,
)
return AuthResponse(user=user_response, token=token)
async def update_user_profile(self, user: User, data: dict) -> User:
"""Update user profile information."""
logger.info("Updating profile for user: %s", user.email)
# Only allow updating specific fields
allowed_fields = {"name"}
update_data = {k: v for k, v in data.items() if k in allowed_fields}
if not update_data:
return user
# Update user
for field, value in update_data.items():
setattr(user, field, value)
self.session.add(user)
await self.session.commit()
await self.session.refresh(user, ["plan"])
logger.info("Profile updated successfully for user: %s", user.email)
return user
async def change_user_password(
self,
user: User,
current_password: str | None,
new_password: str,
) -> None:
"""Change user's password."""
# Store user email before any operations to avoid session detachment issues
user_email = user.email
logger.info("Changing password for user: %s", user_email)
# Store whether user had existing password before we modify it
had_existing_password = user.password_hash is not None
# If user has existing password, verify it
if had_existing_password:
if not current_password:
msg = "Current password is required when changing existing password"
raise ValueError(msg)
if not PasswordUtils.verify_password(current_password, user.password_hash):
msg = "Current password is incorrect"
raise ValueError(msg)
else:
# User doesn't have a password (OAuth-only user), setting first password
logger.info("Setting first password for OAuth user: %s", user_email)
# Hash new password
new_password_hash = PasswordUtils.hash_password(new_password)
# Update user
user.password_hash = new_password_hash
self.session.add(user)
await self.session.commit()
logger.info(
"Password %s successfully for user: %s",
"changed" if had_existing_password else "set",
user_email,
)
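`change_user_password` branches on whether the user already has a hash: existing-password users must prove the current password, OAuth-only users are setting their first one. The decision table distilled into a sketch (`password_change_mode` is hypothetical):

```python
def password_change_mode(has_existing: bool, current_supplied: bool) -> str:
    # Distilled from change_user_password: users with an existing hash
    # must supply the current password; OAuth-only users (no hash) are
    # simply setting their first password
    if has_existing and not current_supplied:
        msg = "Current password is required when changing existing password"
        raise ValueError(msg)
    return "changed" if has_existing else "set"
```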
async def user_to_response(self, user: User) -> UserResponse:
"""Convert User model to UserResponse with plan information."""
# Load plan relationship if not already loaded
if not hasattr(user, "plan") or not user.plan:
await self.session.refresh(user, ["plan"])
return UserResponse(
id=user.id,
email=user.email,
name=user.name,
picture=user.picture,
role=user.role,
credits=user.credits,
is_active=user.is_active,
plan={
"id": user.plan.id,
"name": user.plan.name,
"max_credits": user.plan.max_credits,
"features": [], # Add features if needed
},
created_at=user.created_at,
updated_at=user.updated_at,
)
async def get_user_oauth_providers(self, user: User) -> list:
"""Get OAuth providers connected to the user."""
return await self.oauth_repo.get_by_user_id(user.id)


@@ -30,7 +30,7 @@ class InsufficientCreditsError(Exception):
self.required = required
self.available = available
super().__init__(
f"Insufficient credits: {required} required, {available} available"
f"Insufficient credits: {required} required, {available} available",
)
@@ -76,14 +76,12 @@ class CreditService:
self,
user_id: int,
action_type: CreditActionType,
metadata: dict[str, Any] | None = None,
) -> tuple[User, CreditAction]:
"""Validate user has sufficient credits and optionally reserve them.
Args:
user_id: The user ID
action_type: The type of action
metadata: Optional metadata to store with transaction
Returns:
Tuple of (user, credit_action)
@@ -118,6 +116,7 @@ class CreditService:
self,
user_id: int,
action_type: CreditActionType,
*,
success: bool = True,
metadata: dict[str, Any] | None = None,
) -> CreditTransaction:
@@ -138,19 +137,27 @@ class CreditService:
"""
action = get_credit_action(action_type)
# Only deduct if action requires success and was successful, or doesn't require success
should_deduct = (action.requires_success and success) or not action.requires_success
# Only deduct if action requires success and was successful,
# or doesn't require success
should_deduct = (
action.requires_success and success
) or not action.requires_success
if not should_deduct:
logger.info(
"Skipping credit deduction for user %s: action %s failed and requires success",
"Skipping credit deduction for user %s: "
"action %s failed and requires success",
user_id,
action_type.value,
)
# Still create a transaction record for auditing
return await self._create_transaction_record(
user_id, action, 0, success, metadata
user_id,
action,
0,
success=success,
metadata=metadata,
)
session = self.db_session_factory()
@@ -204,18 +211,23 @@ class CreditService:
"action_type": action_type.value,
"success": success,
}
await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
await socket_manager.send_to_user(
str(user_id),
"user_credits_changed",
event_data,
)
logger.info("Emitted user_credits_changed event for user %s", user_id)
except Exception:
logger.exception(
"Failed to emit user_credits_changed event for user %s", user_id,
"Failed to emit user_credits_changed event for user %s",
user_id,
)
return transaction
except Exception:
await session.rollback()
raise
else:
return transaction
finally:
await session.close()
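The reformatted `should_deduct` expression is easiest to check as a truth table; pulled out into a standalone sketch (function name hypothetical):

```python
def should_deduct(requires_success: bool, success: bool) -> bool:
    # Deduct when the action succeeded (if success is required),
    # or unconditionally when the action does not require success
    return (requires_success and success) or not requires_success
```

Only the (requires success, failed) cell skips deduction; a transaction record is still written for auditing in that case.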
@@ -292,18 +304,23 @@ class CreditService:
"description": description,
"success": True,
}
await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
await socket_manager.send_to_user(
str(user_id),
"user_credits_changed",
event_data,
)
logger.info("Emitted user_credits_changed event for user %s", user_id)
except Exception:
logger.exception(
"Failed to emit user_credits_changed event for user %s", user_id,
"Failed to emit user_credits_changed event for user %s",
user_id,
)
return transaction
except Exception:
await session.rollback()
raise
else:
return transaction
finally:
await session.close()
@@ -312,6 +329,7 @@ class CreditService:
user_id: int,
action: CreditAction,
amount: int,
*,
success: bool,
metadata: dict[str, Any] | None = None,
) -> CreditTransaction:
@@ -342,19 +360,22 @@ class CreditService:
amount=amount,
balance_before=user.credits,
balance_after=user.credits,
description=f"{action.description} (failed)" if not success else action.description,
description=(
f"{action.description} (failed)"
if not success
else action.description
),
success=success,
metadata_json=json.dumps(metadata) if metadata else None,
)
session.add(transaction)
await session.commit()
return transaction
except Exception:
await session.rollback()
raise
else:
return transaction
finally:
await session.close()
@@ -380,4 +401,227 @@ class CreditService:
raise ValueError(msg)
return user.credits
finally:
await session.close()
await session.close()
async def recharge_user_credits_auto(
self,
user_id: int,
) -> CreditTransaction | None:
"""Recharge credits for a user automatically based on their plan.
Args:
user_id: The user ID
Returns:
The created credit transaction if credits were added, None if no recharge
needed
Raises:
ValueError: If user not found or has no plan
"""
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id_with_plan(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
if not user.plan:
msg = f"User {user_id} has no plan assigned"
raise ValueError(msg)
# Call the main method with plan details
return await self.recharge_user_credits(
user_id,
user.plan.credits,
user.plan.max_credits,
)
finally:
await session.close()
async def recharge_user_credits(
self,
user_id: int,
plan_credits: int,
max_credits: int,
) -> CreditTransaction | None:
"""Recharge credits for a user based on their plan.
Args:
user_id: The user ID
plan_credits: Number of credits from the plan
max_credits: Maximum credits allowed for the plan
Returns:
The created credit transaction if credits were added, None if no recharge
needed
Raises:
ValueError: If user not found
"""
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
# Calculate credits to add (can't exceed max_credits)
current_credits = user.credits
target_credits = min(current_credits + plan_credits, max_credits)
credits_to_add = target_credits - current_credits
# If no credits to add, return None
if credits_to_add <= 0:
logger.info(
"No credits to add for user %s: current=%s, "
"plan_credits=%s, max=%s",
user_id,
current_credits,
plan_credits,
max_credits,
)
return None
# Record transaction
transaction = CreditTransaction(
user_id=user_id,
action_type=CreditActionType.DAILY_RECHARGE.value,
amount=credits_to_add,
balance_before=current_credits,
balance_after=target_credits,
description="Daily credit recharge",
success=True,
metadata_json=json.dumps(
{
"plan_credits": plan_credits,
"max_credits": max_credits,
},
),
)
# Update user credits
await user_repo.update(user, {"credits": target_credits})
# Save transaction
session.add(transaction)
await session.commit()
logger.info(
"Credits recharged for user %s: %s credits added (balance: %s -> %s)",
user_id,
credits_to_add,
current_credits,
target_credits,
)
# Emit user_credits_changed event via WebSocket
try:
event_data = {
"user_id": str(user_id),
"credits_before": current_credits,
"credits_after": target_credits,
"credits_added": credits_to_add,
"description": "Daily credit recharge",
"success": True,
}
await socket_manager.send_to_user(
str(user_id),
"user_credits_changed",
event_data,
)
logger.info("Emitted user_credits_changed event for user %s", user_id)
except Exception:
logger.exception(
"Failed to emit user_credits_changed event for user %s",
user_id,
)
except Exception:
await session.rollback()
raise
else:
return transaction
finally:
await session.close()
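`recharge_user_credits` caps the post-recharge balance at `max_credits` before computing the delta. The arithmetic in isolation (helper name hypothetical):

```python
def credits_to_add(current: int, plan_credits: int, max_credits: int) -> int:
    # Cap the target balance at the plan maximum, as recharge_user_credits
    # does before deciding whether to record a transaction
    target = min(current + plan_credits, max_credits)
    return target - current
```

For example, a user at 80/100 on a 50-credit plan gains only 20; a user already at the cap gains 0, and the service returns `None` without writing a transaction.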
async def recharge_all_users_credits(self) -> dict[str, int]:
"""Recharge credits for all users based on their plans.
Returns:
Dictionary with statistics about the recharge operation
"""
session = self.db_session_factory()
stats = {
"total_users": 0,
"recharged_users": 0,
"skipped_users": 0,
"total_credits_added": 0,
}
try:
user_repo = UserRepository(session)
# Process users in batches to avoid memory issues
offset = 0
batch_size = 100
while True:
users = await user_repo.get_all_with_plan(
limit=batch_size,
offset=offset,
)
if not users:
break
for user in users:
stats["total_users"] += 1
# Skip users without ID (shouldn't happen in practice)
if user.id is None:
continue
transaction = await self.recharge_user_credits(
user.id,
user.plan.credits,
user.plan.max_credits,
)
if transaction:
stats["recharged_users"] += 1
# Calculate the amount from plan data to avoid session issues
current_credits = user.credits
plan_credits = user.plan.credits
max_credits = user.plan.max_credits
target_credits = min(
current_credits + plan_credits, max_credits,
)
credits_added = target_credits - current_credits
stats["total_credits_added"] += credits_added
else:
stats["skipped_users"] += 1
offset += batch_size
# Break if we got fewer users than batch_size (last batch)
if len(users) < batch_size:
break
logger.info(
"Daily credit recharge completed: %s total users, "
"%s recharged, %s skipped, %s total credits added",
stats["total_users"],
stats["recharged_users"],
stats["skipped_users"],
stats["total_credits_added"],
)
return stats
finally:
await session.close()
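The batch loop in `recharge_all_users_credits` terminates either on an empty page or on a short final page. That offset-pagination pattern in isolation (generator name and shape hypothetical):

```python
def batches(total: int, batch_size: int = 100):
    # Offset pagination as in recharge_all_users_credits: stop on an
    # empty page, or after yielding a short (final) page
    offset = 0
    while True:
        page = min(batch_size, max(total - offset, 0))
        if page == 0:
            break
        yield offset, page
        offset += batch_size
        if page < batch_size:
            break
```

The short-page check saves one extra round trip to the repository when the last page is partially filled.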

app/services/dashboard.py Normal file

@@ -0,0 +1,161 @@
"""Dashboard service for statistics and analytics."""
from datetime import UTC, datetime, timedelta
from typing import Any
from app.core.logging import get_logger
from app.repositories.sound import SoundRepository
from app.repositories.user import UserRepository
logger = get_logger(__name__)
class DashboardService:
"""Service for dashboard statistics and analytics."""
def __init__(
self,
sound_repository: SoundRepository,
user_repository: UserRepository,
) -> None:
"""Initialize the dashboard service."""
self.sound_repository = sound_repository
self.user_repository = user_repository
async def get_soundboard_statistics(self) -> dict[str, Any]:
"""Get comprehensive soundboard statistics."""
try:
stats = await self.sound_repository.get_soundboard_statistics()
return {
"sound_count": stats["count"],
"total_play_count": stats["total_plays"],
"total_duration": stats["total_duration"],
"total_size": stats["total_size"],
}
except Exception:
logger.exception("Failed to get soundboard statistics")
raise
async def get_track_statistics(self) -> dict[str, Any]:
"""Get comprehensive track statistics."""
try:
stats = await self.sound_repository.get_track_statistics()
return {
"track_count": stats["count"],
"total_play_count": stats["total_plays"],
"total_duration": stats["total_duration"],
"total_size": stats["total_size"],
}
except Exception:
logger.exception("Failed to get track statistics")
raise
async def get_top_sounds(
self,
sound_type: str,
period: str = "all_time",
limit: int = 10,
) -> list[dict[str, Any]]:
"""Get top sounds by play count for a specific period."""
try:
# Calculate the date filter based on period
date_filter = self._get_date_filter(period)
# Get top sounds from repository
top_sounds = await self.sound_repository.get_top_sounds(
sound_type=sound_type,
date_filter=date_filter,
limit=limit,
)
return [
{
"id": sound["id"],
"name": sound["name"],
"type": sound["type"],
"play_count": sound["play_count"],
"duration": sound["duration"],
"created_at": (
sound["created_at"].isoformat() if sound["created_at"] else None
),
}
for sound in top_sounds
]
except Exception:
logger.exception(
"Failed to get top sounds for type=%s, period=%s",
sound_type,
period,
)
raise
async def get_tts_statistics(self) -> dict[str, Any]:
"""Get comprehensive TTS statistics."""
try:
stats = await self.sound_repository.get_soundboard_statistics("TTS")
return {
"sound_count": stats["count"],
"total_play_count": stats["total_plays"],
"total_duration": stats["total_duration"],
"total_size": stats["total_size"],
}
except Exception:
logger.exception("Failed to get TTS statistics")
raise
async def get_top_users(
self,
metric_type: str,
period: str = "all_time",
limit: int = 10,
) -> list[dict[str, Any]]:
"""Get top users by different metrics for a specific period."""
try:
# Calculate the date filter based on period
date_filter = self._get_date_filter(period)
# Get top users from repository
top_users = await self.user_repository.get_top_users(
metric_type=metric_type,
date_filter=date_filter,
limit=limit,
)
return [
{
"id": user["id"],
"name": user["name"],
"count": user["count"],
}
for user in top_users
]
except Exception:
logger.exception(
"Failed to get top users for metric=%s, period=%s",
metric_type,
period,
)
raise
def _get_date_filter(self, period: str) -> datetime | None: # noqa: PLR0911
"""Calculate the date filter based on the period."""
now = datetime.now(UTC)
match period:
case "today":
return now.replace(hour=0, minute=0, second=0, microsecond=0)
case "1_day":
return now - timedelta(days=1)
case "1_week":
return now - timedelta(weeks=1)
case "1_month":
return now - timedelta(days=30)
case "1_year":
return now - timedelta(days=365)
case "all_time":
return None
case _:
return None # Default to all time for unknown periods


@@ -2,8 +2,9 @@
import asyncio
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import TypedDict
from typing import Any, TypedDict
import yt_dlp
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -14,6 +15,7 @@ from app.models.extraction import Extraction
from app.models.sound import Sound
from app.repositories.extraction import ExtractionRepository
from app.repositories.sound import SoundRepository
from app.repositories.user import UserRepository
from app.services.playlist import PlaylistService
from app.services.sound_normalizer import SoundNormalizerService
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
@@ -21,6 +23,18 @@ from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
logger = get_logger(__name__)
@dataclass
class ExtractionContext:
"""Context data for extraction processing."""
extraction_id: int
extraction_url: str
extraction_service: str | None
extraction_service_id: str | None
extraction_title: str | None
user_id: int
class ExtractionInfo(TypedDict):
"""Type definition for extraction information."""
@@ -32,6 +46,20 @@ class ExtractionInfo(TypedDict):
status: str
error: str | None
sound_id: int | None
user_id: int
user_name: str | None
created_at: str
updated_at: str
class PaginatedExtractionsResponse(TypedDict):
"""Type definition for paginated extractions response."""
extractions: list[ExtractionInfo]
total: int
page: int
limit: int
total_pages: int
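The diff does not show how `total_pages` is derived for `PaginatedExtractionsResponse`; a common derivation, offered purely as an assumption:

```python
import math


def total_pages(total: int, limit: int) -> int:
    # Assumed derivation for the total_pages field: ceil-divide the
    # total row count by the page size, guarding against limit == 0
    return math.ceil(total / limit) if limit > 0 else 0
```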
class ExtractionService:
@@ -42,6 +70,7 @@ class ExtractionService:
self.session = session
self.extraction_repo = ExtractionRepository(session)
self.sound_repo = SoundRepository(session)
self.user_repo = UserRepository(session)
self.playlist_service = PlaylistService(session)
# Ensure required directories exist
@@ -64,6 +93,15 @@ class ExtractionService:
logger.info("Creating extraction for URL: %s (user: %d)", url, user_id)
try:
# Get user information
user = await self.user_repo.get_by_id(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
# Extract user name immediately while in session context
user_name = user.name
# Create the extraction record without service detection for fast response
extraction_data = {
"url": url,
@@ -76,7 +114,10 @@ class ExtractionService:
extraction = await self.extraction_repo.create(extraction_data)
logger.info("Created extraction with ID: %d", extraction.id)
except Exception:
logger.exception("Failed to create extraction for URL: %s", url)
raise
else:
return {
"id": extraction.id or 0, # Should never be None for created extraction
"url": extraction.url,
@@ -86,12 +127,12 @@ class ExtractionService:
"status": extraction.status,
"error": extraction.error,
"sound_id": extraction.sound_id,
"user_id": extraction.user_id,
"user_name": user_name,
"created_at": extraction.created_at.isoformat(),
"updated_at": extraction.updated_at.isoformat(),
}
except Exception:
logger.exception("Failed to create extraction for URL: %s", url)
raise
async def _detect_service_info(self, url: str) -> dict[str, str | None] | None:
"""Detect service information from URL using yt-dlp."""
try:
@@ -123,14 +164,16 @@ class ExtractionService:
logger.exception("Failed to detect service info for URL: %s", url)
return None
async def process_extraction(self, extraction_id: int) -> ExtractionInfo:
"""Process an extraction job."""
async def _validate_extraction(self, extraction_id: int) -> tuple:
"""Validate extraction and return extraction data."""
extraction = await self.extraction_repo.get_by_id(extraction_id)
if not extraction:
raise ValueError(f"Extraction {extraction_id} not found")
msg = f"Extraction {extraction_id} not found"
raise ValueError(msg)
if extraction.status != "pending":
raise ValueError(f"Extraction {extraction_id} is not pending")
msg = f"Extraction {extraction_id} is not pending"
raise ValueError(msg)
# Store all needed values early to avoid session detachment issues
user_id = extraction.user_id
@@ -139,106 +182,274 @@ class ExtractionService:
extraction_service_id = extraction.service_id
extraction_title = extraction.title
logger.info("Processing extraction %d: %s", extraction_id, extraction_url)
# Get user information for return value
try:
# Update status to processing
await self.extraction_repo.update(extraction, {"status": "processing"})
user = await self.user_repo.get_by_id(user_id)
user_name = user.name if user else None
except Exception:
logger.exception("Failed to get user %d for extraction", user_id)
user_name = None
# Detect service info if not already available
if not extraction_service or not extraction_service_id:
logger.info("Detecting service info for extraction %d", extraction_id)
service_info = await self._detect_service_info(extraction_url)
return (
extraction,
user_id,
extraction_url,
extraction_service,
extraction_service_id,
extraction_title,
user_name,
)
async def _handle_service_detection(
self,
extraction: Extraction,
context: ExtractionContext,
) -> tuple:
"""Handle service detection and duplicate checking."""
if context.extraction_service and context.extraction_service_id:
return (
context.extraction_service,
context.extraction_service_id,
context.extraction_title,
)
logger.info("Detecting service info for extraction %d", context.extraction_id)
service_info = await self._detect_service_info(context.extraction_url)
if not service_info:
msg = "Unable to detect service information from URL"
raise ValueError(msg)
# Check if extraction already exists for this service
service_name = service_info["service"]
service_id_val = service_info["service_id"]
if not service_name or not service_id_val:
msg = "Service info is incomplete"
raise ValueError(msg)
existing = await self.extraction_repo.get_by_service_and_id(
service_name,
service_id_val,
)
if existing and existing.id != context.extraction_id:
error_msg = (
f"Extraction already exists for "
f"{service_info['service']}:{service_info['service_id']}"
)
logger.warning(error_msg)
raise ValueError(error_msg)
# Update extraction with service info
update_data = {
"service": service_info["service"],
"service_id": service_info["service_id"],
"title": service_info.get("title") or context.extraction_title,
}
await self.extraction_repo.update(extraction, update_data)
# Update values for processing
new_service = service_info["service"]
new_service_id = service_info["service_id"]
new_title = service_info.get("title") or context.extraction_title
await self._emit_extraction_event(
context.user_id,
{
"extraction_id": context.extraction_id,
"status": "processing",
"title": new_title,
"url": context.extraction_url,
},
)
return new_service, new_service_id, new_title
async def _process_media_files(
self,
extraction_id: int,
extraction_url: str,
extraction_title: str | None,
extraction_service: str,
extraction_service_id: str,
) -> int:
"""Process media files and create sound record."""
# Extract audio and thumbnail
audio_file, thumbnail_file = await self._extract_media(
extraction_id,
extraction_url,
)
# Move files to final locations
final_audio_path, final_thumbnail_path = (
await self._move_files_to_final_location(
audio_file,
thumbnail_file,
extraction_title,
extraction_service,
extraction_service_id,
)
)
# Create Sound record
sound = await self._create_sound_record(
final_audio_path,
final_thumbnail_path,
extraction_title,
extraction_service,
extraction_service_id,
)
if not sound.id:
msg = "Sound creation failed - no ID returned"
raise RuntimeError(msg)
return sound.id
async def _complete_extraction(
self,
extraction: Extraction,
context: ExtractionContext,
sound_id: int,
) -> None:
"""Complete extraction processing."""
# Normalize the sound
await self._normalize_sound(sound_id)
# Add to main playlist
await self._add_to_main_playlist(sound_id, context.user_id)
# Update extraction with success
await self.extraction_repo.update(
extraction,
{
"status": "completed",
"sound_id": sound_id,
"error": None,
},
)
# Emit WebSocket event for completion
await self._emit_extraction_event(
context.user_id,
{
"extraction_id": context.extraction_id,
"status": "completed",
"title": context.extraction_title,
"url": context.extraction_url,
"sound_id": sound_id,
},
)
async def process_extraction(self, extraction_id: int) -> ExtractionInfo:
"""Process an extraction job."""
# Validate extraction and get context data
(
extraction,
user_id,
extraction_url,
extraction_service,
extraction_service_id,
extraction_title,
user_name,
) = await self._validate_extraction(extraction_id)
# Create context object for helper methods
context = ExtractionContext(
extraction_id=extraction_id,
extraction_url=extraction_url,
extraction_service=extraction_service,
extraction_service_id=extraction_service_id,
extraction_title=extraction_title,
user_id=user_id,
)
logger.info("Processing extraction %d: %s", extraction_id, extraction_url)
try:
# Update status to processing
await self.extraction_repo.update(extraction, {"status": "processing"})
# Emit WebSocket event for processing start
await self._emit_extraction_event(
context.user_id,
{
"extraction_id": context.extraction_id,
"status": "processing",
"title": context.extraction_title or "Processing extraction...",
"url": context.extraction_url,
},
)
# Handle service detection and duplicate checking
extraction_service, extraction_service_id, extraction_title = (
await self._handle_service_detection(extraction, context)
)
# Update context with potentially new values
context.extraction_service = extraction_service
context.extraction_service_id = extraction_service_id
context.extraction_title = extraction_title
# Process media files and create sound record
sound_id = await self._process_media_files(
context.extraction_id,
context.extraction_url,
context.extraction_title,
extraction_service,
extraction_service_id,
)
# Complete extraction processing
await self._complete_extraction(extraction, context, sound_id)
logger.info("Successfully processed extraction %d", context.extraction_id)
# Get updated extraction to get latest timestamps
updated_extraction = await self.extraction_repo.get_by_id(
context.extraction_id,
)
return {
"id": context.extraction_id,
"url": context.extraction_url,
"service": extraction_service,
"service_id": extraction_service_id,
"title": extraction_title,
"status": "completed",
"error": None,
"sound_id": sound_id,
"user_id": context.user_id,
"user_name": user_name,
"created_at": (
updated_extraction.created_at.isoformat()
if updated_extraction
else ""
),
"updated_at": (
updated_extraction.updated_at.isoformat()
if updated_extraction
else ""
),
}
except Exception as e:
error_msg = str(e)
logger.exception(
"Failed to process extraction %d: %s",
context.extraction_id,
error_msg,
)
# Emit WebSocket event for failure
await self._emit_extraction_event(
context.user_id,
{
"extraction_id": context.extraction_id,
"status": "failed",
"title": context.extraction_title or "Extraction failed",
"url": context.extraction_url,
"error": error_msg,
},
)
# Update extraction with error
@@ -250,26 +461,44 @@ class ExtractionService:
},
)
# Get updated extraction to get latest timestamps
updated_extraction = await self.extraction_repo.get_by_id(
context.extraction_id,
)
return {
"id": context.extraction_id,
"url": context.extraction_url,
"service": context.extraction_service,
"service_id": context.extraction_service_id,
"title": context.extraction_title,
"status": "failed",
"error": error_msg,
"sound_id": None,
"user_id": context.user_id,
"user_name": user_name,
"created_at": (
updated_extraction.created_at.isoformat()
if updated_extraction
else ""
),
"updated_at": (
updated_extraction.updated_at.isoformat()
if updated_extraction
else ""
),
}
async def _extract_media(
self,
extraction_id: int,
extraction_url: str,
) -> tuple[Path, Path | None]:
"""Extract audio and thumbnail using yt-dlp."""
temp_dir = Path(settings.EXTRACTION_TEMP_DIR)
# Create unique filename based on extraction ID
output_template = str(
temp_dir / f"extraction_{extraction_id}_%(title)s.%(ext)s",
)
# Configure yt-dlp options
@@ -304,8 +533,8 @@ class ExtractionService:
# Find the extracted files
audio_files = list(
temp_dir.glob(
f"extraction_{extraction_id}_*.{settings.EXTRACTION_AUDIO_FORMAT}",
),
)
thumbnail_files = (
list(temp_dir.glob(f"extraction_{extraction_id}_*.webp"))
@@ -314,7 +543,8 @@ class ExtractionService:
)
if not audio_files:
msg = "No audio file was created during extraction"
raise RuntimeError(msg)
audio_file = audio_files[0]
thumbnail_file = thumbnail_files[0] if thumbnail_files else None
@@ -325,11 +555,12 @@ class ExtractionService:
thumbnail_file or "None",
)
except Exception as e:
logger.exception("yt-dlp extraction failed for %s", extraction_url)
error_msg = f"Audio extraction failed: {e}"
raise RuntimeError(error_msg) from e
else:
return audio_file, thumbnail_file
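The yt-dlp configuration itself is elided by the diff (the `# Configure yt-dlp options` comment is followed by a hunk gap). A hedged sketch of what audio-plus-thumbnail extraction options commonly look like; the exact keys and settings this project uses are an assumption:

```python
def build_ydl_opts(output_template: str, audio_format: str) -> dict:
    # Typical yt-dlp options for audio-only extraction with a thumbnail.
    # This mirrors the behavior implied above (audio files in
    # EXTRACTION_AUDIO_FORMAT plus .webp thumbnails), not the real config.
    return {
        "format": "bestaudio/best",
        "outtmpl": output_template,  # e.g. "extraction_1_%(title)s.%(ext)s"
        "writethumbnail": True,      # yields the .webp files globbed above
        "noplaylist": True,
        "postprocessors": [
            {"key": "FFmpegExtractAudio", "preferredcodec": audio_format},
        ],
    }


opts = build_ydl_opts("/tmp/extraction_1_%(title)s.%(ext)s", "opus")
```

Such a dict would be handed to `yt_dlp.YoutubeDL(opts)`; the glob patterns in `_extract_media` then pick up whatever files the postprocessors wrote.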
async def _move_files_to_final_location(
self,
@@ -342,7 +573,7 @@ class ExtractionService:
"""Move extracted files to their final locations."""
# Generate clean filename based on title and service
safe_title = self._sanitize_filename(
title or f"{service or 'unknown'}_{service_id or 'unknown'}",
)
# Move audio file
@@ -401,6 +632,7 @@ class ExtractionService:
async def _create_sound_record(
self,
audio_path: Path,
thumbnail_path: Path | None,
title: str | None,
service: str | None,
service_id: str | None,
@@ -419,6 +651,7 @@ class ExtractionService:
"duration": duration,
"size": size,
"hash": file_hash,
"thumbnail": thumbnail_path.name if thumbnail_path else None,
"is_deletable": True, # Extracted sounds can be deleted
"is_music": True, # Assume extracted content is music
"is_normalized": False,
@@ -426,7 +659,11 @@ class ExtractionService:
}
sound = await self.sound_repo.create(sound_data)
logger.info(
"Created sound record with ID: %d, thumbnail: %s",
sound.id,
thumbnail_path.name if thumbnail_path else "None",
)
return sound
@@ -451,14 +688,17 @@ class ExtractionService:
else:
logger.info("Successfully normalized sound %d", sound_id)
except Exception:
logger.exception("Error normalizing sound %d", sound_id)
# Don't fail the extraction if normalization fails
async def _add_to_main_playlist(self, sound_id: int, user_id: int) -> None:
"""Add the sound to the user's main playlist."""
try:
await self.playlist_service._add_sound_to_main_playlist_internal( # noqa: SLF001
sound_id,
user_id,
)
logger.info(
"Added sound %d to main playlist for user %d",
sound_id,
@@ -473,12 +713,31 @@ class ExtractionService:
)
# Don't fail the extraction if playlist addition fails
async def _emit_extraction_event(self, user_id: int, data: dict) -> None:
"""Emit WebSocket event for extraction status updates to all users."""
try:
# Import here to avoid circular imports
from app.services.socket import socket_manager # noqa: PLC0415
await socket_manager.broadcast_to_all("extraction_status_update", data)
logger.debug(
"Broadcasted extraction event (initiated by user %d): %s",
user_id,
data["status"],
)
except Exception:
logger.exception("Failed to emit extraction event")
async def get_extraction_by_id(self, extraction_id: int) -> ExtractionInfo | None:
"""Get extraction information by ID."""
extraction = await self.extraction_repo.get_by_id(extraction_id)
if not extraction:
return None
# Get user information
user = await self.user_repo.get_by_id(extraction.user_id)
user_name = user.name if user else None
return {
"id": extraction.id or 0, # Should never be None for existing extraction
"url": extraction.url,
@@ -488,13 +747,38 @@ class ExtractionService:
"status": extraction.status,
"error": extraction.error,
"sound_id": extraction.sound_id,
"user_id": extraction.user_id,
"user_name": user_name,
"created_at": extraction.created_at.isoformat(),
"updated_at": extraction.updated_at.isoformat(),
}
async def get_user_extractions( # noqa: PLR0913
self,
user_id: int,
search: str | None = None,
sort_by: str = "created_at",
sort_order: str = "desc",
status_filter: str | None = None,
page: int = 1,
limit: int = 50,
) -> PaginatedExtractionsResponse:
"""Get all extractions for a user with filtering, search, and sorting."""
offset = (page - 1) * limit
(
extraction_user_tuples,
total_count,
) = await self.extraction_repo.get_user_extractions_filtered(
user_id=user_id,
search=search,
sort_by=sort_by,
sort_order=sort_order,
status_filter=status_filter,
limit=limit,
offset=offset,
)
extractions = [
{
"id": extraction.id
or 0, # Should never be None for existing extraction
@@ -505,13 +789,79 @@ class ExtractionService:
"status": extraction.status,
"error": extraction.error,
"sound_id": extraction.sound_id,
"user_id": extraction.user_id,
"user_name": user.name,
"created_at": extraction.created_at.isoformat(),
"updated_at": extraction.updated_at.isoformat(),
}
for extraction, user in extraction_user_tuples
]
total_pages = (total_count + limit - 1) // limit # Ceiling division
return {
"extractions": extractions,
"total": total_count,
"page": page,
"limit": limit,
"total_pages": total_pages,
}
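The `total_pages` computation relies on integer ceiling division rather than `math.ceil` on a float; a standalone check of the arithmetic:

```python
def total_pages(total_count: int, limit: int) -> int:
    # (n + d - 1) // d rounds up: a partially filled last page still counts
    return (total_count + limit - 1) // limit


# 101 items at 50 per page fill two pages and spill one item onto a third
assert total_pages(101, 50) == 3
assert total_pages(100, 50) == 2
assert total_pages(0, 50) == 0
```

The same expression, paired with `offset = (page - 1) * limit`, drives both `get_user_extractions` and `get_all_extractions`.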
async def get_all_extractions( # noqa: PLR0913
self,
search: str | None = None,
sort_by: str = "created_at",
sort_order: str = "desc",
status_filter: str | None = None,
page: int = 1,
limit: int = 50,
) -> PaginatedExtractionsResponse:
"""Get all extractions with filtering, search, and sorting."""
offset = (page - 1) * limit
(
extraction_user_tuples,
total_count,
) = await self.extraction_repo.get_all_extractions_filtered(
search=search,
sort_by=sort_by,
sort_order=sort_order,
status_filter=status_filter,
limit=limit,
offset=offset,
)
extractions = [
{
"id": extraction.id
or 0, # Should never be None for existing extraction
"url": extraction.url,
"service": extraction.service,
"service_id": extraction.service_id,
"title": extraction.title,
"status": extraction.status,
"error": extraction.error,
"sound_id": extraction.sound_id,
"user_id": extraction.user_id,
"user_name": user.name,
"created_at": extraction.created_at.isoformat(),
"updated_at": extraction.updated_at.isoformat(),
}
for extraction, user in extraction_user_tuples
]
total_pages = (total_count + limit - 1) // limit # Ceiling division
return {
"extractions": extractions,
"total": total_count,
"page": page,
"limit": limit,
"total_pages": total_pages,
}
async def get_pending_extractions(self) -> list[ExtractionInfo]:
"""Get all pending extractions."""
extraction_user_tuples = await self.extraction_repo.get_pending_extractions()
return [
{
@@ -524,6 +874,181 @@ class ExtractionService:
"status": extraction.status,
"error": extraction.error,
"sound_id": extraction.sound_id,
"user_id": extraction.user_id,
"user_name": user.name,
"created_at": extraction.created_at.isoformat(),
"updated_at": extraction.updated_at.isoformat(),
}
for extraction, user in extraction_user_tuples
]
async def delete_extraction(
self,
extraction_id: int,
user_id: int | None = None,
) -> bool:
"""Delete an extraction and its associated sound and files.
Args:
extraction_id: The ID of the extraction to delete
user_id: Optional user ID for ownership verification (None for admin)
Returns:
True if deletion was successful, False if extraction not found
Raises:
ValueError: If user doesn't own the extraction (when user_id is provided)
"""
logger.info(
"Deleting extraction: %d (user: %s)",
extraction_id,
user_id or "admin",
)
# Get the extraction record
extraction = await self.extraction_repo.get_by_id(extraction_id)
if not extraction:
logger.warning("Extraction %d not found", extraction_id)
return False
# Check ownership if user_id is provided (non-admin request)
if user_id is not None and extraction.user_id != user_id:
msg = "You don't have permission to delete this extraction"
raise ValueError(msg)
# Get associated sound if it exists and capture its attributes immediately
sound_data = None
sound_object = None
if extraction.sound_id:
sound_object = await self.sound_repo.get_by_id(extraction.sound_id)
if sound_object:
# Capture attributes immediately while session is valid
sound_data = {
"id": sound_object.id,
"type": sound_object.type,
"filename": sound_object.filename,
"is_normalized": sound_object.is_normalized,
"normalized_filename": sound_object.normalized_filename,
"thumbnail": sound_object.thumbnail,
}
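Capturing ORM attributes into a plain dict while the session is still open, as `sound_data` does here, prevents lazy-load errors on detached instances after the commit. The pattern generalizes; a sketch with illustrative names (no SQLModel dependency):

```python
def snapshot(obj: object, fields: tuple[str, ...]) -> dict:
    # Read every attribute now, while the session is live, so later code
    # touches only plain values instead of a detached ORM instance.
    return {name: getattr(obj, name) for name in fields}


class FakeSound:  # stand-in for the Sound entity
    id = 7
    filename = "song.opus"
    thumbnail = None


data = snapshot(FakeSound(), ("id", "filename", "thumbnail"))
```

After the snapshot, file cleanup and playlist-reload logic can run on `data` without holding the database session open.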
try:
# Delete the extraction record first
await self.extraction_repo.delete(extraction)
logger.info("Deleted extraction record: %d", extraction_id)
# Check if sound was in current playlist before deletion
sound_was_in_current_playlist = False
if sound_object and sound_data:
sound_was_in_current_playlist = (
await self._check_sound_in_current_playlist(sound_data["id"])
)
# If there's an associated sound, delete it and its files
if sound_object and sound_data:
await self._delete_sound_and_files(sound_object, sound_data)
logger.info(
"Deleted associated sound: %d (%s)",
sound_data["id"],
sound_data["filename"],
)
# Commit the transaction
await self.session.commit()
# Reload player playlist if deleted sound was in current playlist
if sound_was_in_current_playlist and sound_data:
await self._reload_player_playlist()
logger.info(
"Reloaded player playlist after deleting sound %d "
"from current playlist",
sound_data["id"],
)
except Exception:
# Rollback on any error
await self.session.rollback()
logger.exception("Failed to delete extraction %d", extraction_id)
raise
else:
return True
async def _delete_sound_and_files(
self,
sound: Sound,
sound_data: dict[str, Any],
) -> None:
"""Delete a sound record and all its associated files."""
# Collect all file paths to delete using captured attributes
files_to_delete = []
# Original audio file
if sound_data["type"] == "EXT": # Extracted sounds
original_path = Path("sounds/originals/extracted") / sound_data["filename"]
if original_path.exists():
files_to_delete.append(original_path)
# Normalized file
if sound_data["is_normalized"] and sound_data["normalized_filename"]:
normalized_path = (
Path("sounds/normalized/extracted") / sound_data["normalized_filename"]
)
if normalized_path.exists():
files_to_delete.append(normalized_path)
# Thumbnail file
if sound_data["thumbnail"]:
thumbnail_path = (
Path(settings.EXTRACTION_THUMBNAILS_DIR) / sound_data["thumbnail"]
)
if thumbnail_path.exists():
files_to_delete.append(thumbnail_path)
# Delete the sound from database first
await self.sound_repo.delete(sound)
# Delete all associated files
for file_path in files_to_delete:
try:
file_path.unlink()
logger.info("Deleted file: %s", file_path)
except OSError:
logger.exception("Failed to delete file %s", file_path)
# Continue with other files even if one fails
async def _check_sound_in_current_playlist(self, sound_id: int) -> bool:
"""Check if a sound is in the current playlist."""
try:
from app.repositories.playlist import PlaylistRepository # noqa: PLC0415
playlist_repo = PlaylistRepository(self.session)
current_playlist = await playlist_repo.get_current_playlist()
if not current_playlist or not current_playlist.id:
return False
return await playlist_repo.is_sound_in_playlist(
current_playlist.id, sound_id,
)
except (ImportError, AttributeError, ValueError, RuntimeError) as e:
logger.warning(
"Failed to check if sound %s is in current playlist: %s",
sound_id,
e,
exc_info=True,
)
return False
async def _reload_player_playlist(self) -> None:
"""Reload the player playlist after a sound is deleted."""
try:
# Import here to avoid circular import issues
from app.services.player import get_player_service # noqa: PLC0415
player = get_player_service()
await player.reload_playlist()
logger.debug("Player playlist reloaded after sound deletion")
except (ImportError, AttributeError, ValueError, RuntimeError) as e:
# Don't fail the deletion operation if player reload fails
logger.warning("Failed to reload player playlist: %s", e, exc_info=True)


@@ -1,6 +1,7 @@
"""Background extraction processor for handling extraction queue."""
import asyncio
import contextlib
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -34,6 +35,9 @@ class ExtractionProcessor:
logger.warning("Extraction processor is already running")
return
# Reset any stuck extractions from previous runs
await self._reset_stuck_extractions()
self.shutdown_event.clear()
self.processor_task = asyncio.create_task(self._process_queue())
logger.info("Started extraction processor")
@@ -46,15 +50,13 @@ class ExtractionProcessor:
if self.processor_task and not self.processor_task.done():
try:
await asyncio.wait_for(self.processor_task, timeout=30.0)
except TimeoutError:
logger.warning(
"Extraction processor did not stop gracefully, cancelling...",
)
self.processor_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self.processor_task
logger.info("Extraction processor stopped")
@@ -66,7 +68,8 @@ class ExtractionProcessor:
# The processor will pick it up on the next cycle
else:
logger.warning(
"Extraction %d is already being processed",
extraction_id,
)
async def _process_queue(self) -> None:
@@ -81,16 +84,16 @@ class ExtractionProcessor:
try:
await asyncio.wait_for(self.shutdown_event.wait(), timeout=5.0)
break # Shutdown requested
except TimeoutError:
continue # Continue processing
except Exception:
logger.exception("Error in extraction queue processor")
# Wait a bit before retrying to avoid tight error loops
try:
await asyncio.wait_for(self.shutdown_event.wait(), timeout=10.0)
break # Shutdown requested
except TimeoutError:
continue
logger.info("Extraction queue processor stopped")
@@ -125,13 +128,13 @@ class ExtractionProcessor:
# Start processing this extraction in the background
task = asyncio.create_task(
self._process_single_extraction(extraction_id),
)
task.add_done_callback(
lambda t, eid=extraction_id: self._on_extraction_completed(
eid,
t,
),
)
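The `eid=extraction_id` default argument in the done-callback lambda freezes the loop variable at definition time; a bare closure would instead see whatever `extraction_id` holds when the callback fires. A self-contained demonstration of the pattern:

```python
import asyncio


async def demo() -> list[tuple[int, int]]:
    async def work(n: int) -> int:
        return n * 2

    results: list[tuple[int, int]] = []
    tasks = []
    for extraction_id in (1, 2, 3):
        task = asyncio.create_task(work(extraction_id))
        # Default arg captures the current value per iteration
        task.add_done_callback(
            lambda t, eid=extraction_id: results.append((eid, t.result())),
        )
        tasks.append(task)
    await asyncio.gather(*tasks)
    await asyncio.sleep(0)  # let any remaining done-callbacks run
    return sorted(results)


pairs = asyncio.run(demo())
```

In the processor this binding lets `_on_extraction_completed` log and release the right slot even when several extractions finish concurrently.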
logger.info(
@@ -156,8 +159,8 @@ class ExtractionProcessor:
result["status"],
)
except Exception:
logger.exception("Error processing extraction %d", extraction_id)
def _on_extraction_completed(self, extraction_id: int, task: asyncio.Task) -> None:
"""Handle completion of an extraction task."""
@@ -179,6 +182,47 @@ class ExtractionProcessor:
self.max_concurrent,
)
async def _reset_stuck_extractions(self) -> None:
"""Reset any extractions stuck in 'processing' status back to 'pending'."""
try:
async with AsyncSession(engine) as session:
extraction_service = ExtractionService(session)
# Get all extractions stuck in processing status
stuck_extractions = (
await extraction_service.extraction_repo.get_by_status("processing")
)
if not stuck_extractions:
logger.info("No stuck extractions found to reset")
return
reset_count = 0
for extraction in stuck_extractions:
try:
await extraction_service.extraction_repo.update(
extraction, {"status": "pending", "error": None},
)
reset_count += 1
logger.info(
"Reset stuck extraction %d from processing to pending",
extraction.id,
)
except Exception:
logger.exception(
"Failed to reset extraction %d", extraction.id,
)
await session.commit()
logger.info(
"Successfully reset %d stuck extractions from processing to "
"pending",
reset_count,
)
except Exception:
logger.exception("Failed to reset stuck extractions")
def get_status(self) -> dict:
"""Get the current status of the extraction processor."""
return {

app/services/favorite.py Normal file

@@ -0,0 +1,382 @@
"""Service for managing user favorites."""
from collections.abc import Callable
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.favorite import Favorite
from app.repositories.favorite import FavoriteRepository
from app.repositories.playlist import PlaylistRepository
from app.repositories.sound import SoundRepository
from app.repositories.user import UserRepository
from app.services.socket import socket_manager
logger = get_logger(__name__)
class FavoriteService:
"""Service for managing user favorites."""
def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
"""Initialize the favorite service.
Args:
db_session_factory: Factory function to create database sessions
"""
self.db_session_factory = db_session_factory
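Taking a session *factory* rather than a session means every public method opens and closes its own transaction scope, which is why each `FavoriteService` method starts with `async with self.db_session_factory() as session:`. The shape of the pattern, with illustrative stand-ins for `AsyncSession`:

```python
import asyncio
from collections.abc import Callable
from contextlib import asynccontextmanager


class FakeSession:  # stand-in for AsyncSession
    def __init__(self) -> None:
        self.closed = False


@asynccontextmanager
async def make_session():
    session = FakeSession()
    try:
        yield session
    finally:
        session.closed = True  # each call gets a fresh, self-closing session


class Service:
    def __init__(self, db_session_factory: Callable[[], object]) -> None:
        self.db_session_factory = db_session_factory

    async def do_work(self) -> FakeSession:
        async with self.db_session_factory() as session:
            return session  # real code would run repository queries here


session = asyncio.run(Service(make_session).do_work())
```

Because no session outlives a method call, the service can safely be constructed once at startup and shared across requests and background tasks.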
async def add_sound_favorite(self, user_id: int, sound_id: int) -> Favorite:
"""Add a sound to user's favorites.
Args:
user_id: The user ID
sound_id: The sound ID
Returns:
The created favorite
Raises:
ValueError: If user or sound not found, or already favorited
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
user_repo = UserRepository(session)
sound_repo = SoundRepository(session)
# Verify user exists
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User with ID {user_id} not found"
raise ValueError(msg)
# Verify sound exists
sound = await sound_repo.get_by_id(sound_id)
if not sound:
msg = f"Sound with ID {sound_id} not found"
raise ValueError(msg)
# Get data for the event immediately after loading
sound_name = sound.name
user_name = user.name
# Check if already favorited
existing = await favorite_repo.get_by_user_and_sound(user_id, sound_id)
if existing:
msg = f"Sound {sound_id} is already favorited by user {user_id}"
raise ValueError(msg)
# Create favorite
favorite_data = {
"user_id": user_id,
"sound_id": sound_id,
"playlist_id": None,
}
favorite = await favorite_repo.create(favorite_data)
logger.info("User %s favorited sound %s", user_id, sound_id)
# Get updated favorite count within the same session
favorite_count = await favorite_repo.count_sound_favorites(sound_id)
# Emit sound_favorited event via WebSocket (outside the session)
try:
event_data = {
"sound_id": sound_id,
"sound_name": sound_name,
"user_id": user_id,
"user_name": user_name,
"favorite_count": favorite_count,
}
await socket_manager.broadcast_to_all("sound_favorited", event_data)
logger.info("Broadcasted sound_favorited event for sound %s", sound_id)
except Exception:
logger.exception(
"Failed to broadcast sound_favorited event for sound %s",
sound_id,
)
return favorite
async def add_playlist_favorite(self, user_id: int, playlist_id: int) -> Favorite:
"""Add a playlist to user's favorites.
Args:
user_id: The user ID
playlist_id: The playlist ID
Returns:
The created favorite
Raises:
ValueError: If user or playlist not found, or already favorited
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
user_repo = UserRepository(session)
playlist_repo = PlaylistRepository(session)
# Verify user exists
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User with ID {user_id} not found"
raise ValueError(msg)
# Verify playlist exists
playlist = await playlist_repo.get_by_id(playlist_id)
if not playlist:
msg = f"Playlist with ID {playlist_id} not found"
raise ValueError(msg)
# Check if already favorited
existing = await favorite_repo.get_by_user_and_playlist(
user_id,
playlist_id,
)
if existing:
msg = f"Playlist {playlist_id} is already favorited by user {user_id}"
raise ValueError(msg)
# Create favorite
favorite_data = {
"user_id": user_id,
"sound_id": None,
"playlist_id": playlist_id,
}
favorite = await favorite_repo.create(favorite_data)
logger.info("User %s favorited playlist %s", user_id, playlist_id)
return favorite
async def remove_sound_favorite(self, user_id: int, sound_id: int) -> None:
"""Remove a sound from user's favorites.
Args:
user_id: The user ID
sound_id: The sound ID
Raises:
ValueError: If favorite not found
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
favorite = await favorite_repo.get_by_user_and_sound(user_id, sound_id)
if not favorite:
msg = f"Sound {sound_id} is not favorited by user {user_id}"
raise ValueError(msg)
# Get user and sound info before deletion for the event
user_repo = UserRepository(session)
sound_repo = SoundRepository(session)
user = await user_repo.get_by_id(user_id)
sound = await sound_repo.get_by_id(sound_id)
# Get data for the event immediately after loading
sound_name = sound.name if sound else "Unknown"
user_name = user.name if user else "Unknown"
await favorite_repo.delete(favorite)
logger.info("User %s removed sound %s from favorites", user_id, sound_id)
# Get updated favorite count after deletion within the same session
favorite_count = await favorite_repo.count_sound_favorites(sound_id)
# Emit sound_favorited event via WebSocket (outside the session)
try:
event_data = {
"sound_id": sound_id,
"sound_name": sound_name,
"user_id": user_id,
"user_name": user_name,
"favorite_count": favorite_count,
}
await socket_manager.broadcast_to_all("sound_favorited", event_data)
logger.info(
"Broadcasted sound_favorited event for sound %s removal",
sound_id,
)
except Exception:
logger.exception(
"Failed to broadcast sound_favorited event for sound %s removal",
sound_id,
)
async def remove_playlist_favorite(self, user_id: int, playlist_id: int) -> None:
"""Remove a playlist from user's favorites.
Args:
user_id: The user ID
playlist_id: The playlist ID
Raises:
ValueError: If favorite not found
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
favorite = await favorite_repo.get_by_user_and_playlist(
user_id,
playlist_id,
)
if not favorite:
msg = f"Playlist {playlist_id} is not favorited by user {user_id}"
raise ValueError(msg)
await favorite_repo.delete(favorite)
logger.info(
"User %s removed playlist %s from favorites",
user_id,
playlist_id,
)
async def get_user_favorites(
self,
user_id: int,
limit: int = 100,
offset: int = 0,
) -> list[Favorite]:
"""Get all favorites for a user.
Args:
user_id: The user ID
limit: Maximum number of favorites to return
offset: Number of favorites to skip
Returns:
List of user favorites
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
return await favorite_repo.get_user_favorites(user_id, limit, offset)
async def get_user_sound_favorites(
self,
user_id: int,
limit: int = 100,
offset: int = 0,
) -> list[Favorite]:
"""Get sound favorites for a user.
Args:
user_id: The user ID
limit: Maximum number of favorites to return
offset: Number of favorites to skip
Returns:
List of user sound favorites
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
return await favorite_repo.get_user_sound_favorites(user_id, limit, offset)
async def get_user_playlist_favorites(
self,
user_id: int,
limit: int = 100,
offset: int = 0,
) -> list[Favorite]:
"""Get playlist favorites for a user.
Args:
user_id: The user ID
limit: Maximum number of favorites to return
offset: Number of favorites to skip
Returns:
List of user playlist favorites
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
return await favorite_repo.get_user_playlist_favorites(
user_id,
limit,
offset,
)
async def is_sound_favorited(self, user_id: int, sound_id: int) -> bool:
"""Check if a sound is favorited by a user.
Args:
user_id: The user ID
sound_id: The sound ID
Returns:
True if the sound is favorited, False otherwise
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
return await favorite_repo.is_sound_favorited(user_id, sound_id)
async def is_playlist_favorited(self, user_id: int, playlist_id: int) -> bool:
"""Check if a playlist is favorited by a user.
Args:
user_id: The user ID
playlist_id: The playlist ID
Returns:
True if the playlist is favorited, False otherwise
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
return await favorite_repo.is_playlist_favorited(user_id, playlist_id)
async def get_favorite_counts(self, user_id: int) -> dict[str, int]:
"""Get favorite counts for a user.
Args:
user_id: The user ID
Returns:
Dictionary with favorite counts
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
total = await favorite_repo.count_user_favorites(user_id)
sounds = len(await favorite_repo.get_user_sound_favorites(user_id))
playlists = len(await favorite_repo.get_user_playlist_favorites(user_id))
return {
"total": total,
"sounds": sounds,
"playlists": playlists,
}
async def get_sound_favorite_count(self, sound_id: int) -> int:
"""Get the number of users who have favorited a sound.
Args:
sound_id: The sound ID
Returns:
Number of users who favorited this sound
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
return await favorite_repo.count_sound_favorites(sound_id)
async def get_playlist_favorite_count(self, playlist_id: int) -> int:
"""Get the number of users who have favorited a playlist.
Args:
playlist_id: The playlist ID
Returns:
Number of users who favorited this playlist
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
return await favorite_repo.count_playlist_favorites(playlist_id)

View File

@@ -70,7 +70,7 @@ class OAuthProvider(ABC):
"""Generate authorization URL with state parameter."""
# Construct provider-specific redirect URI
redirect_uri = (
f"http://localhost:8000/api/v1/auth/{self.provider_name}/callback"
f"{settings.BACKEND_URL}/api/v1/auth/{self.provider_name}/callback"
)
params = {
@@ -86,7 +86,7 @@ class OAuthProvider(ABC):
"""Exchange authorization code for access token."""
# Construct provider-specific redirect URI (must match authorization request)
redirect_uri = (
f"http://localhost:8000/api/v1/auth/{self.provider_name}/callback"
f"{settings.BACKEND_URL}/api/v1/auth/{self.provider_name}/callback"
)
data = {
@@ -150,7 +150,7 @@ class GoogleOAuthProvider(OAuthProvider):
"""Exchange authorization code for access token."""
# Construct provider-specific redirect URI (must match authorization request)
redirect_uri = (
f"http://localhost:8000/api/v1/auth/{self.provider_name}/callback"
f"{settings.BACKEND_URL}/api/v1/auth/{self.provider_name}/callback"
)
data = {

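The three hunks above make the same substitution: the OAuth callback URL is derived from `settings.BACKEND_URL` instead of a hardcoded localhost origin. A minimal sketch of that pattern, with an illustrative `Settings` dataclass standing in for the real config object:

```python
from dataclasses import dataclass


@dataclass
class Settings:
    """Illustrative stand-in for the application's settings object."""

    BACKEND_URL: str = "http://localhost:8000"


def build_redirect_uri(settings: Settings, provider_name: str) -> str:
    """Build the provider callback URL from the configured backend URL."""
    return f"{settings.BACKEND_URL}/api/v1/auth/{provider_name}/callback"
```

With `BACKEND_URL` set per environment, the same code serves local development and production, and the token-exchange request reuses the exact URI sent during authorization, as OAuth 2.0 requires.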
View File

@@ -8,6 +8,8 @@ from enum import Enum
from typing import Any
import vlc # type: ignore[import-untyped]
from sqlalchemy.orm import selectinload
from sqlmodel import select
from app.core.logging import get_logger
from app.models.playlist import Playlist
@@ -16,6 +18,7 @@ from app.models.sound_played import SoundPlayed
from app.repositories.playlist import PlaylistRepository
from app.repositories.sound import SoundRepository
from app.services.socket import socket_manager
from app.services.volume import volume_service
from app.utils.audio import get_sound_file_path
logger = get_logger(__name__)
@@ -46,7 +49,11 @@ class PlayerState:
"""Initialize player state."""
self.status: PlayerStatus = PlayerStatus.STOPPED
self.mode: PlayerMode = PlayerMode.CONTINUOUS
self.volume: int = 50
# Initialize volume from host system or default to 80
host_volume = volume_service.get_volume()
self.volume: int = host_volume if host_volume is not None else 80
self.previous_volume: int = self.volume
self.current_sound_id: int | None = None
self.current_sound_index: int | None = None
self.current_sound_position: int = 0
@@ -57,24 +64,35 @@ class PlayerState:
self.playlist_length: int = 0
self.playlist_duration: int = 0
self.playlist_sounds: list[Sound] = []
self.play_next_queue: list[Sound] = []
self.playlist_index_before_play_next: int | None = None
def to_dict(self) -> dict[str, Any]:
"""Convert player state to dictionary for serialization."""
return {
"status": self.status.value,
"mode": self.mode.value,
"mode": self.mode.value if isinstance(self.mode, PlayerMode) else self.mode,
"volume": self.volume,
"current_sound_id": self.current_sound_id,
"current_sound_index": self.current_sound_index,
"current_sound_position": self.current_sound_position,
"current_sound_duration": self.current_sound_duration,
"previous_volume": self.previous_volume,
"position": self.current_sound_position or 0,
"duration": self.current_sound_duration,
"index": self.current_sound_index,
"current_sound": self._serialize_sound(self.current_sound),
"playlist_id": self.playlist_id,
"playlist_name": self.playlist_name,
"playlist_length": self.playlist_length,
"playlist_duration": self.playlist_duration,
"playlist_sounds": [
self._serialize_sound(sound) for sound in self.playlist_sounds
"playlist": (
{
"id": self.playlist_id,
"name": self.playlist_name,
"length": self.playlist_length,
"duration": self.playlist_duration,
"sounds": [
self._serialize_sound(sound) for sound in self.playlist_sounds
],
}
if self.playlist_id
else None
),
"play_next_queue": [
self._serialize_sound(sound) for sound in self.play_next_queue
],
}
@@ -82,6 +100,14 @@ class PlayerState:
"""Serialize a sound object for JSON serialization."""
if not sound:
return None
# Get extraction URL if sound is linked to an extraction
extract_url = None
if hasattr(sound, "extractions") and sound.extractions:
# Get the first extraction (there should only be one per sound)
extraction = sound.extractions[0]
extract_url = extraction.url
return {
"id": sound.id,
"name": sound.name,
@@ -91,6 +117,7 @@ class PlayerState:
"type": sound.type,
"thumbnail": sound.thumbnail,
"play_count": sound.play_count,
"extract_url": extract_url,
}
@@ -102,6 +129,14 @@ class PlayerService:
self.db_session_factory = db_session_factory
self.state = PlayerState()
self._vlc_instance = vlc.Instance()
if self._vlc_instance is None:
msg = (
"VLC instance could not be created. "
"Ensure VLC is installed and accessible."
)
raise RuntimeError(msg)
self._player = self._vlc_instance.media_player_new()
self._is_running = False
self._position_thread: threading.Thread | None = None
@@ -124,12 +159,13 @@ class PlayerService:
# Start position tracking thread
self._position_thread = threading.Thread(
target=self._position_tracker, daemon=True,
target=self._position_tracker,
daemon=True,
)
self._position_thread.start()
# Set initial volume
self._player.audio_set_volume(self.state.volume)
# Set VLC to 100% volume - host volume is controlled separately
self._player.audio_set_volume(100)
logger.info("Player service started")
@@ -152,83 +188,139 @@ class PlayerService:
async def play(self, index: int | None = None) -> None:
"""Play audio at specified index or current position."""
# Check if we're resuming from pause
is_resuming = (
index is None and
self.state.status == PlayerStatus.PAUSED and
self.state.current_sound is not None
)
if is_resuming:
# Simply resume playback
result = self._player.play()
if result == 0: # VLC returns 0 on success
self.state.status = PlayerStatus.PLAYING
# Ensure play time tracking is initialized for resumed track
if (
self.state.current_sound_id
and self.state.current_sound_id not in self._play_time_tracking
):
self._play_time_tracking[self.state.current_sound_id] = {
"total_time": 0,
"last_position": self.state.current_sound_position,
"last_update": time.time(),
"threshold_reached": False,
}
await self._broadcast_state()
logger.info("Resumed playing sound: %s", self.state.current_sound.name)
else:
logger.error("Failed to resume playback: VLC error code %s", result)
if self._should_resume_playback(index):
await self._resume_playback()
return
# Starting new track or changing track
if index is not None:
if index < 0 or index >= len(self.state.playlist_sounds):
msg = "Invalid sound index"
raise ValueError(msg)
self.state.current_sound_index = index
self.state.current_sound = self.state.playlist_sounds[index]
self.state.current_sound_id = self.state.current_sound.id
await self._start_new_track(index)
def _should_resume_playback(self, index: int | None) -> bool:
"""Check if we should resume paused playback."""
return (
index is None
and self.state.status == PlayerStatus.PAUSED
and self.state.current_sound is not None
)
async def _resume_playback(self) -> None:
"""Resume paused playback."""
result = self._player.play()
if result == 0: # VLC returns 0 on success
self.state.status = PlayerStatus.PLAYING
self._ensure_play_time_tracking_for_resume()
await self._broadcast_state()
sound_name = (
self.state.current_sound.name if self.state.current_sound else "Unknown"
)
logger.info("Resumed playing sound: %s", sound_name)
else:
logger.error("Failed to resume playback: VLC error code %s", result)
def _ensure_play_time_tracking_for_resume(self) -> None:
"""Ensure play time tracking is initialized for resumed track."""
if (
self.state.current_sound_id
and self.state.current_sound_id not in self._play_time_tracking
):
self._play_time_tracking[self.state.current_sound_id] = {
"total_time": 0,
"last_position": self.state.current_sound_position,
"last_update": time.time(),
"threshold_reached": False,
}
async def _start_new_track(self, index: int | None) -> None:
"""Start playing a new track."""
if not self._prepare_sound_for_playback(index):
return
if not self._load_and_play_media():
return
await self._handle_successful_playback()
def _prepare_sound_for_playback(self, index: int | None) -> bool:
"""Prepare sound for playback, return True if ready."""
if index is not None and not self._set_sound_by_index(index):
return False
if not self.state.current_sound:
logger.warning("No sound to play")
return
return False
return self._validate_sound_file()
def _set_sound_by_index(self, index: int) -> bool:
"""Set current sound by index, return True if valid."""
if index < 0 or index >= len(self.state.playlist_sounds):
msg = "Invalid sound index"
raise ValueError(msg)
self.state.current_sound_index = index
self.state.current_sound = self.state.playlist_sounds[index]
self.state.current_sound_id = self.state.current_sound.id
return True
def _validate_sound_file(self) -> bool:
"""Validate sound file exists, return True if valid."""
if not self.state.current_sound:
return False
# Get sound file path
sound_path = get_sound_file_path(self.state.current_sound)
if not sound_path.exists():
logger.error("Sound file not found: %s", sound_path)
return
return False
return True
# Load and play media (new track)
def _load_and_play_media(self) -> bool:
"""Load media and start playback, return True if successful."""
if self._vlc_instance is None:
logger.error("VLC instance is not initialized. Cannot play media.")
return False
if not self.state.current_sound:
logger.error("No current sound to play")
return False
sound_path = get_sound_file_path(self.state.current_sound)
media = self._vlc_instance.media_new(str(sound_path))
self._player.set_media(media)
result = self._player.play()
if result == 0: # VLC returns 0 on success
self.state.status = PlayerStatus.PLAYING
self.state.current_sound_duration = self.state.current_sound.duration or 0
# Initialize play time tracking for new track
if self.state.current_sound_id:
self._play_time_tracking[self.state.current_sound_id] = {
"total_time": 0,
"last_position": 0,
"last_update": time.time(),
"threshold_reached": False,
}
logger.info(
"Initialized play time tracking for sound %s (duration: %s ms)",
self.state.current_sound_id,
self.state.current_sound_duration,
)
await self._broadcast_state()
logger.info("Started playing sound: %s", self.state.current_sound.name)
else:
if result != 0: # VLC returns 0 on success
logger.error("Failed to start playback: VLC error code %s", result)
return False
return True
async def _handle_successful_playback(self) -> None:
"""Handle successful playback start."""
if not self.state.current_sound:
logger.error("No current sound for successful playback")
return
self.state.status = PlayerStatus.PLAYING
self.state.current_sound_duration = self.state.current_sound.duration or 0
self._initialize_play_time_tracking()
await self._broadcast_state()
logger.info("Started playing sound: %s", self.state.current_sound.name)
def _initialize_play_time_tracking(self) -> None:
"""Initialize play time tracking for new track."""
if self.state.current_sound_id:
self._play_time_tracking[self.state.current_sound_id] = {
"total_time": 0,
"last_position": 0,
"last_update": time.time(),
"threshold_reached": False,
}
logger.info(
"Initialized play time tracking for sound %s (duration: %s ms)",
self.state.current_sound_id,
self.state.current_sound_duration,
)
async def pause(self) -> None:
"""Pause playback."""
@@ -257,6 +349,31 @@ class PlayerService:
async def next(self) -> None:
"""Skip to next track."""
# Check if there's a track in the play_next queue
if self.state.play_next_queue:
await self._play_next_from_queue()
return
# If currently playing from play_next queue (no index but have stored index)
if (
self.state.current_sound_index is None
and self.state.playlist_index_before_play_next is not None
and self.state.playlist_sounds
):
# The last play_next track was skipped; resume from the playlist
restored_index = self.state.playlist_index_before_play_next
next_index = self._get_next_index(restored_index)
# Clear the stored index
self.state.playlist_index_before_play_next = None
if next_index is not None:
await self.play(next_index)
else:
await self._stop_playback()
await self._broadcast_state()
return
if not self.state.playlist_sounds:
return
@@ -297,26 +414,121 @@ class PlayerService:
logger.debug("Seeked to position: %sms", position_ms)
async def set_volume(self, volume: int) -> None:
"""Set playback volume (0-100)."""
"""Set playback volume (0-100) by controlling host system volume."""
volume = max(0, min(100, volume)) # Clamp to valid range
# Store previous volume when muting (going from >0 to 0)
if self.state.volume > 0 and volume == 0:
self.state.previous_volume = self.state.volume
self.state.volume = volume
self._player.audio_set_volume(volume)
# Control host system volume instead of VLC volume
if volume == 0:
# Mute the host system
volume_service.set_mute(muted=True)
else:
# Unmute and set host volume
if volume_service.is_muted():
volume_service.set_mute(muted=False)
volume_service.set_volume(volume)
# Keep VLC at 100% volume
self._player.audio_set_volume(100)
await self._broadcast_state()
logger.debug("Volume set to: %s", volume)
logger.debug("Host volume set to: %s", volume)
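Stripped of the `volume_service` calls, the clamp-and-remember bookkeeping in `set_volume` above comes down to this pure state update:

```python
def apply_volume(state: dict, volume: int) -> dict:
    """Clamp volume to 0-100 and remember the level when muting."""
    volume = max(0, min(100, volume))
    if state["volume"] > 0 and volume == 0:
        # Muting: store the previous audible level so unmute can restore it
        state["previous_volume"] = state["volume"]
    state["volume"] = volume
    return state
```

`mute()` and `unmute()` above then become `set_volume(0)` and `set_volume(previous_volume)` respectively.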
async def set_mode(self, mode: PlayerMode) -> None:
async def mute(self) -> None:
"""Mute the host system (stores current volume as previous_volume)."""
if self.state.volume > 0:
await self.set_volume(0)
async def unmute(self) -> None:
"""Unmute the host system (restores previous_volume)."""
if self.state.volume == 0 and self.state.previous_volume > 0:
await self.set_volume(self.state.previous_volume)
async def set_mode(self, mode: PlayerMode | str) -> None:
"""Set playback mode."""
if isinstance(mode, str):
# Convert string to PlayerMode enum
try:
mode = PlayerMode(mode)
except ValueError:
logger.exception("Invalid player mode: %s", mode)
return
self.state.mode = mode
await self._broadcast_state()
logger.info("Playback mode set to: %s", mode.value)
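The string-to-enum coercion accepted by `set_mode` above follows a common pattern; a sketch with an illustrative two-member enum (the real `PlayerMode` members are not shown in this diff):

```python
from enum import Enum


class PlayerMode(Enum):
    """Illustrative subset of the player's modes."""

    CONTINUOUS = "continuous"
    LOOP = "loop"


def coerce_mode(mode: "PlayerMode | str") -> "PlayerMode | None":
    """Accept an enum member or its string value; None if invalid."""
    if isinstance(mode, str):
        try:
            # Enum lookup by value raises ValueError for unknown strings
            return PlayerMode(mode)
        except ValueError:
            return None
    return mode
```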
async def add_to_play_next(self, sound_id: int) -> None:
"""Add a sound to the play_next queue."""
session = self.db_session_factory()
try:
# Eagerly load extractions to avoid lazy loading issues
statement = select(Sound).where(Sound.id == sound_id)
statement = statement.options(selectinload(Sound.extractions)) # type: ignore[arg-type]
result = await session.exec(statement)
sound = result.first()
if not sound:
logger.warning("Sound %s not found for play_next", sound_id)
return
self.state.play_next_queue.append(sound)
await self._broadcast_state()
logger.info("Added sound %s to play_next queue", sound.name)
finally:
await session.close()
async def _play_next_from_queue(self) -> None:
"""Play the first track from the play_next queue."""
if not self.state.play_next_queue:
return
# Store current playlist index before switching to play_next track
# Only store if we're currently playing from the playlist
if (
self.state.current_sound_index is not None
and self.state.playlist_index_before_play_next is None
):
self.state.playlist_index_before_play_next = (
self.state.current_sound_index
)
logger.info(
"Stored playlist index %s before playing from play_next queue",
self.state.playlist_index_before_play_next,
)
# Get the first sound from the queue
next_sound = self.state.play_next_queue.pop(0)
# Stop current playback and process play count
if self.state.status != PlayerStatus.STOPPED:
await self._stop_playback()
# Set the sound as current (without index since it's from play_next)
self.state.current_sound = next_sound
self.state.current_sound_id = next_sound.id
self.state.current_sound_index = None # No index for play_next tracks
# Play the sound
if not self._validate_sound_file():
return
if not self._load_and_play_media():
return
await self._handle_successful_playback()
async def reload_playlist(self) -> None:
"""Reload current playlist from database."""
session = self.db_session_factory()
try:
playlist_repo = PlaylistRepository(session)
current_playlist = await playlist_repo.get_main_playlist()
current_playlist = await playlist_repo.get_current_playlist()
if current_playlist and current_playlist.id:
sounds = await playlist_repo.get_playlist_sounds(current_playlist.id)
@@ -333,6 +545,26 @@ class PlayerService:
await self._broadcast_state()
async def load_playlist(self, playlist_id: int) -> None:
"""Load a specific playlist by ID."""
session = self.db_session_factory()
try:
playlist_repo = PlaylistRepository(session)
playlist = await playlist_repo.get_by_id(playlist_id)
if playlist and playlist.id:
sounds = await playlist_repo.get_playlist_sounds(playlist.id)
await self._handle_playlist_reload(playlist, sounds)
logger.info(
"Loaded playlist: %s (%s sounds)",
playlist.name,
len(sounds),
)
else:
logger.warning("Playlist not found: %s", playlist_id)
finally:
await session.close()
await self._broadcast_state()
async def _handle_playlist_reload(
self,
current_playlist: Playlist,
@@ -353,7 +585,9 @@ class PlayerService:
and previous_playlist_id != current_playlist.id
):
await self._handle_playlist_id_changed(
previous_playlist_id, current_playlist.id, sounds,
previous_playlist_id,
current_playlist.id,
sounds,
)
elif previous_current_sound_id:
await self._handle_same_playlist_track_check(
@@ -377,6 +611,16 @@ class PlayerService:
current_id,
)
# Clear play_next queue when playlist changes
if self.state.play_next_queue:
logger.info("Clearing play_next queue due to playlist change")
self.state.play_next_queue.clear()
# Clear stored playlist index
if self.state.playlist_index_before_play_next is not None:
logger.info("Clearing stored playlist index due to playlist change")
self.state.playlist_index_before_play_next = None
if self.state.status != PlayerStatus.STOPPED:
await self._stop_playback()
@@ -392,6 +636,9 @@ class PlayerService:
sounds: list[Sound],
) -> None:
"""Handle track checking when playlist ID is the same."""
# Remove tracks from play_next queue that are no longer in the playlist
self._clean_play_next_queue(sounds)
# Find the current track in the new playlist
new_index = self._find_sound_index(previous_sound_id, sounds)
@@ -431,7 +678,9 @@ class PlayerService:
self._clear_current_track()
def _update_playlist_state(
self, current_playlist: Playlist, sounds: list[Sound],
self,
current_playlist: Playlist,
sounds: list[Sound],
) -> None:
"""Update basic playlist state information."""
self.state.playlist_id = current_playlist.id
@@ -447,6 +696,29 @@ class PlayerService:
return i
return None
def _clean_play_next_queue(self, playlist_sounds: list[Sound]) -> None:
"""Remove tracks from play_next queue that are no longer in the playlist."""
if not self.state.play_next_queue:
return
# Get IDs of all sounds in the current playlist
playlist_sound_ids = {sound.id for sound in playlist_sounds}
# Filter out tracks that are no longer in the playlist
original_length = len(self.state.play_next_queue)
self.state.play_next_queue = [
sound
for sound in self.state.play_next_queue
if sound.id in playlist_sound_ids
]
removed_count = original_length - len(self.state.play_next_queue)
if removed_count > 0:
logger.info(
"Removed %s track(s) from play_next queue (no longer in playlist)",
removed_count,
)
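`_clean_play_next_queue` above is a set-membership filter; the core rule in isolation, with plain dicts standing in for `Sound` objects:

```python
def clean_queue(queue: list[dict], playlist: list[dict]) -> list[dict]:
    """Keep only queued sounds whose IDs still appear in the playlist."""
    playlist_ids = {sound["id"] for sound in playlist}  # O(1) lookups
    return [sound for sound in queue if sound["id"] in playlist_ids]
```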
def _set_first_track_as_current(self, sounds: list[Sound]) -> None:
"""Set the first track as the current track."""
self.state.current_sound_index = 0
@@ -463,7 +735,6 @@ class PlayerService:
"""Get current player state."""
return self.state.to_dict()
def _get_next_index(self, current_index: int) -> int | None:
"""Get next track index based on current mode."""
if not self.state.playlist_sounds:
@@ -496,11 +767,7 @@ class PlayerService:
prev_index = current_index - 1
if prev_index < 0:
return (
playlist_length - 1
if self.state.mode == PlayerMode.LOOP
else None
)
return playlist_length - 1 if self.state.mode == PlayerMode.LOOP else None
return prev_index
def _position_tracker(self) -> None:
@@ -516,15 +783,16 @@ class PlayerService:
# Check if track finished
player_state = self._player.get_state()
if hasattr(vlc, "State") and player_state == vlc.State.Ended:
vlc_state_ended = 6 # vlc.State.Ended value
if player_state == vlc_state_ended:
# Track finished, handle auto-advance
self._schedule_async_task(self._handle_track_finished())
# Update play time tracking
self._update_play_time()
# Broadcast state every 0.5 seconds while playing
broadcast_interval = 0.5
# Broadcast state every second while playing
broadcast_interval = 1
current_time = time.time()
if current_time - self._last_position_broadcast >= broadcast_interval:
self._last_position_broadcast = current_time
@@ -534,10 +802,7 @@ class PlayerService:
def _update_play_time(self) -> None:
"""Update play time tracking for current sound."""
if (
not self.state.current_sound_id
or self.state.status != PlayerStatus.PLAYING
):
if not self.state.current_sound_id or self.state.status != PlayerStatus.PLAYING:
return
sound_id = self.state.current_sound_id
@@ -576,10 +841,8 @@ class PlayerService:
sound_id,
tracking["total_time"],
self.state.current_sound_duration,
(
tracking["total_time"]
/ self.state.current_sound_duration
) * 100,
(tracking["total_time"] / self.state.current_sound_duration)
* 100,
)
self._schedule_async_task(self._record_play_count(sound_id))
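The percentage logged in the hunk above feeds a listen-threshold check before a play is counted; reduced to a pure function (the actual threshold fraction is not visible in this diff, so `0.5` is an assumption):

```python
def crossed_threshold(total_time_ms: int, duration_ms: int,
                      threshold: float = 0.5) -> bool:
    """True once accumulated listen time reaches the threshold fraction."""
    if duration_ms <= 0:
        return False  # unknown duration: never count a play
    return (total_time_ms / duration_ms) >= threshold
```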
@@ -595,7 +858,8 @@ class PlayerService:
if sound:
old_count = sound.play_count
await sound_repo.update(
sound, {"play_count": sound.play_count + 1},
sound,
{"play_count": sound.play_count + 1},
)
logger.info(
"Updated sound %s play_count: %s -> %s",
@@ -644,7 +908,12 @@ class PlayerService:
"""Handle when a track finishes playing."""
await self._process_play_count()
# Auto-advance to next track
# Check if there's a track in the play_next queue
if self.state.play_next_queue:
await self._play_next_from_queue()
return
# Auto-advance to next track in playlist
if self.state.current_sound_index is not None:
next_index = self._get_next_index(self.state.current_sound_index)
if next_index is not None:
@@ -652,6 +921,32 @@ class PlayerService:
else:
await self._stop_playback()
await self._broadcast_state()
elif (
self.state.playlist_sounds
and self.state.playlist_index_before_play_next is not None
):
# Current track was from play_next queue, restore to next track in playlist
restored_index = self.state.playlist_index_before_play_next
logger.info(
"Play next queue finished, continuing from playlist index %s",
restored_index,
)
# Get the next index based on the stored position
next_index = self._get_next_index(restored_index)
# Clear the stored index since we're done with play_next queue
self.state.playlist_index_before_play_next = None
if next_index is not None:
await self.play(next_index)
else:
# No next track (end of playlist in non-loop mode)
await self._stop_playback()
await self._broadcast_state()
else:
await self._stop_playback()
await self._broadcast_state()
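The branching in `_handle_track_finished` above can be summarized as a decision function: a queued play_next track always wins; otherwise the player advances from the current playlist index, falling back to the index stored before the queue took over; anything else stops. A sketch (the real method also clears the stored index and broadcasts state):

```python
def advance(queue: list, current_index, saved_index, get_next):
    """Decide what plays after a track finishes.

    Returns ("queue", sound), ("playlist", index), or ("stop", None).
    `get_next` maps an index to the next playlist index (or None).
    """
    if queue:
        return ("queue", queue.pop(0))
    base = current_index if current_index is not None else saved_index
    if base is None:
        return ("stop", None)
    next_index = get_next(base)
    return ("playlist", next_index) if next_index is not None else ("stop", None)
```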
async def _broadcast_state(self) -> None:
"""Broadcast current player state via WebSocket."""

View File

@@ -1,24 +1,59 @@
"""Playlist service for business logic operations."""
from typing import Any
from typing import Any, TypedDict
from fastapi import HTTPException, status
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.playlist import Playlist
from app.models.playlist_sound import PlaylistSound
from app.models.sound import Sound
from app.repositories.playlist import PlaylistRepository
from app.repositories.playlist import PlaylistRepository, PlaylistSortField, SortOrder
from app.repositories.sound import SoundRepository
from app.utils.exceptions import (
raise_bad_request,
raise_internal_server_error,
raise_not_found,
)
logger = get_logger(__name__)
class PaginatedPlaylistsResponse(TypedDict):
"""Response type for paginated playlists."""
playlists: list[dict]
total: int
page: int
limit: int
total_pages: int
async def _reload_player_playlist() -> None:
"""Reload the player playlist after current playlist changes."""
try:
# Import here to avoid circular import issues
from app.services.player import get_player_service # noqa: PLC0415
player = get_player_service()
await player.reload_playlist()
logger.debug("Player playlist reloaded after current playlist change")
except Exception: # noqa: BLE001
# Don't fail the playlist operation if player reload fails
logger.warning("Failed to reload player playlist", exc_info=True)
async def _is_current_playlist(session: AsyncSession, playlist_id: int) -> bool:
"""Check if the given playlist is the current playlist."""
try:
from app.repositories.playlist import PlaylistRepository # noqa: PLC0415
playlist_repo = PlaylistRepository(session)
current_playlist = await playlist_repo.get_current_playlist()
except Exception: # noqa: BLE001
logger.warning("Failed to check if playlist is current", exc_info=True)
return False
else:
return current_playlist is not None and current_playlist.id == playlist_id
class PlaylistService:
"""Service for playlist operations."""
@@ -28,11 +63,24 @@ class PlaylistService:
self.playlist_repo = PlaylistRepository(session)
self.sound_repo = SoundRepository(session)
async def _is_main_playlist(self, playlist_id: int) -> bool:
"""Check if the given playlist is the main playlist."""
try:
playlist = await self.playlist_repo.get_by_id(playlist_id)
except Exception:
logger.exception("Failed to check if playlist is main: %s", playlist_id)
return False
else:
return playlist is not None and playlist.is_main
async def get_playlist_by_id(self, playlist_id: int) -> Playlist:
"""Get a playlist by ID."""
playlist = await self.playlist_repo.get_by_id(playlist_id)
if not playlist:
raise_not_found("Playlist")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Playlist not found",
)
return playlist
@@ -49,27 +97,29 @@ class PlaylistService:
main_playlist = await self.playlist_repo.get_main_playlist()
if not main_playlist:
raise_internal_server_error(
"Main playlist not found. Make sure to run database seeding."
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Main playlist not found. Make sure to run database seeding.",
)
return main_playlist
async def get_current_playlist(self, user_id: int) -> Playlist:
"""Get the user's current playlist, fallback to main playlist."""
current_playlist = await self.playlist_repo.get_current_playlist(user_id)
async def get_current_playlist(self) -> Playlist:
"""Get the global current playlist, fallback to main playlist."""
current_playlist = await self.playlist_repo.get_current_playlist()
if current_playlist:
return current_playlist
# Fallback to main playlist if no current playlist is set
return await self.get_main_playlist()
async def create_playlist(
async def create_playlist( # noqa: PLR0913
self,
user_id: int,
name: str,
description: str | None = None,
genre: str | None = None,
*,
is_main: bool = False,
is_current: bool = False,
is_deletable: bool = True,
@@ -85,7 +135,7 @@ class PlaylistService:
# If this is set as current, unset the previous current playlist
if is_current:
await self._unset_current_playlist(user_id)
await self._unset_current_playlist()
playlist_data = {
"user_id": user_id,
@@ -99,18 +149,31 @@ class PlaylistService:
playlist = await self.playlist_repo.create(playlist_data)
logger.info("Created playlist '%s' for user %s", name, user_id)
# If this was set as current, reload player playlist
if is_current:
await _reload_player_playlist()
return playlist
async def update_playlist(
async def update_playlist( # noqa: PLR0913
self,
playlist_id: int,
user_id: int,
*,
name: str | None = None,
description: str | None = None,
genre: str | None = None,
is_current: bool | None = None,
) -> Playlist:
"""Update a playlist."""
# Check if this is the main playlist
if await self._is_main_playlist(playlist_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="The main playlist cannot be edited",
)
playlist = await self.get_playlist_by_id(playlist_id)
update_data: dict[str, Any] = {}
@@ -137,17 +200,28 @@ class PlaylistService:
if is_current is not None:
if is_current:
await self._unset_current_playlist(user_id)
await self._unset_current_playlist()
update_data["is_current"] = is_current
if update_data:
playlist = await self.playlist_repo.update(playlist, update_data)
logger.info("Updated playlist %s for user %s", playlist_id, user_id)
# If is_current was changed, reload player playlist
if "is_current" in update_data:
await _reload_player_playlist()
return playlist
async def delete_playlist(self, playlist_id: int, user_id: int) -> None:
"""Delete a playlist."""
# Check if this is the main playlist
if await self._is_main_playlist(playlist_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="The main playlist cannot be deleted",
)
playlist = await self.get_playlist_by_id(playlist_id)
if not playlist.is_deletable:
@@ -156,15 +230,27 @@ class PlaylistService:
detail="This playlist cannot be deleted",
)
# Check if this is the current playlist
# Check if this was the current playlist before deleting
was_current = playlist.is_current
# First, delete all playlist_sound relationships
await self._delete_playlist_sounds(playlist_id)
# Then delete the playlist itself
await self.playlist_repo.delete(playlist)
logger.info("Deleted playlist %s for user %s", playlist_id, user_id)
# If the deleted playlist was current, set main playlist as current
if was_current:
await self._set_main_as_current(user_id)
main_playlist = await self.get_main_playlist()
await self.playlist_repo.update(main_playlist, {"is_current": True})
logger.info(
"Set main playlist as current after deleting current playlist %s",
playlist_id,
)
# Reload player to reflect the change
await _reload_player_playlist()
async def search_playlists(self, query: str, user_id: int) -> list[Playlist]:
"""Search user's playlists by name."""
@@ -174,15 +260,91 @@ class PlaylistService:
"""Search all playlists by name."""
return await self.playlist_repo.search_by_name(query)
async def search_and_sort_playlists( # noqa: PLR0913
self,
search_query: str | None = None,
sort_by: PlaylistSortField | None = None,
sort_order: SortOrder = SortOrder.ASC,
user_id: int | None = None,
*,
include_stats: bool = False,
limit: int | None = None,
offset: int = 0,
favorites_only: bool = False,
current_user_id: int | None = None,
) -> list[dict]:
"""Search and sort playlists with optional statistics."""
return await self.playlist_repo.search_and_sort(
search_query=search_query,
sort_by=sort_by,
sort_order=sort_order,
user_id=user_id,
include_stats=include_stats,
limit=limit,
offset=offset,
favorites_only=favorites_only,
current_user_id=current_user_id,
)
async def search_and_sort_playlists_paginated( # noqa: PLR0913
self,
search_query: str | None = None,
sort_by: PlaylistSortField | None = None,
sort_order: SortOrder = SortOrder.ASC,
user_id: int | None = None,
*,
include_stats: bool = False,
page: int = 1,
limit: int = 50,
favorites_only: bool = False,
current_user_id: int | None = None,
) -> PaginatedPlaylistsResponse:
"""Search and sort playlists with pagination."""
offset = (page - 1) * limit
playlists, total_count = await self.playlist_repo.search_and_sort(
search_query=search_query,
sort_by=sort_by,
sort_order=sort_order,
user_id=user_id,
include_stats=include_stats,
limit=limit,
offset=offset,
favorites_only=favorites_only,
current_user_id=current_user_id,
return_count=True,
)
total_pages = (total_count + limit - 1) // limit # Ceiling division
return PaginatedPlaylistsResponse(
playlists=playlists,
total=total_count,
page=page,
limit=limit,
total_pages=total_pages,
)
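The paginated variant above derives the SQL offset from a 1-based page number and computes `total_pages` with integer ceiling division. A minimal standalone sketch of that arithmetic (not the actual service method):

```python
def paginate(page: int, limit: int, total_count: int) -> tuple[int, int]:
    """Return (offset, total_pages) for a 1-based page number."""
    offset = (page - 1) * limit
    # Ceiling division without floats: (n + d - 1) // d
    total_pages = (total_count + limit - 1) // limit
    return offset, total_pages
```

With `limit=50` and 101 matching playlists, page 2 maps to offset 50 and three total pages.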
async def get_playlist_sounds(self, playlist_id: int) -> list[Sound]:
"""Get all sounds in a playlist."""
await self.get_playlist_by_id(playlist_id) # Verify playlist exists
return await self.playlist_repo.get_playlist_sounds(playlist_id)
async def add_sound_to_playlist(
self, playlist_id: int, sound_id: int, user_id: int, position: int | None = None
self,
playlist_id: int,
sound_id: int,
user_id: int,
position: int | None = None,
) -> None:
"""Add a sound to a playlist."""
# Check if this is the main playlist
if await self._is_main_playlist(playlist_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Sounds cannot be added to the main playlist",
)
# Verify playlist exists
await self.get_playlist_by_id(playlist_id)
@@ -201,15 +363,43 @@ class PlaylistService:
detail="Sound is already in this playlist",
)
# If position is None or beyond current positions, place at the end
if position is None:
current_sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
position = len(current_sounds)
else:
# Ensure position doesn't create gaps - if too high, place at end
current_sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
max_position = len(current_sounds)
position = min(position, max_position)
await self.playlist_repo.add_sound_to_playlist(playlist_id, sound_id, position)
logger.info(
"Added sound %s to playlist %s for user %s", sound_id, playlist_id, user_id
"Added sound %s to playlist %s for user %s at position %s",
sound_id,
playlist_id,
user_id,
position,
)
# If this is the current playlist, reload player
if await _is_current_playlist(self.session, playlist_id):
await _reload_player_playlist()
async def remove_sound_from_playlist(
self, playlist_id: int, sound_id: int, user_id: int
self,
playlist_id: int,
sound_id: int,
user_id: int,
) -> None:
"""Remove a sound from a playlist."""
# Check if this is the main playlist
if await self._is_main_playlist(playlist_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Sounds cannot be removed from the main playlist",
)
# Verify playlist exists
await self.get_playlist_by_id(playlist_id)
@@ -221,15 +411,41 @@ class PlaylistService:
)
await self.playlist_repo.remove_sound_from_playlist(playlist_id, sound_id)
# Reorder remaining sounds to eliminate gaps
await self._reorder_playlist_positions(playlist_id)
logger.info(
"Removed sound %s from playlist %s for user %s",
"Removed sound %s from playlist %s for user %s and reordered positions",
sound_id,
playlist_id,
user_id,
)
# If this is the current playlist, reload player
if await _is_current_playlist(self.session, playlist_id):
await _reload_player_playlist()
async def _reorder_playlist_positions(self, playlist_id: int) -> None:
"""Reorder all sounds in a playlist to eliminate position gaps."""
sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
if not sounds:
return
# Create sequential positions: 0, 1, 2, 3...
sound_positions = [(sound.id, index) for index, sound in enumerate(sounds)]
await self.playlist_repo.reorder_playlist_sounds(playlist_id, sound_positions)
logger.debug(
"Reordered %s sounds in playlist %s to eliminate gaps",
len(sounds),
playlist_id,
)
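The reorder helper above rebuilds positions as a dense sequence `0, 1, 2, ...` from the sounds' current order. The core transformation, sketched without the repository layer:

```python
def sequential_positions(sound_ids: list[int]) -> list[tuple[int, int]]:
    """Map sounds (already in playlist order) to gapless positions."""
    return [(sound_id, index) for index, sound_id in enumerate(sound_ids)]
```

After removing the sound at position 1 from `[7, 3, 9]`, re-enumerating `[7, 9]` closes the gap.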
async def reorder_playlist_sounds(
self, playlist_id: int, user_id: int, sound_positions: list[tuple[int, int]]
self,
playlist_id: int,
user_id: int,
sound_positions: list[tuple[int, int]],
) -> None:
"""Reorder sounds in a playlist."""
# Verify playlist exists
@@ -246,25 +462,9 @@ class PlaylistService:
await self.playlist_repo.reorder_playlist_sounds(playlist_id, sound_positions)
logger.info("Reordered sounds in playlist %s for user %s", playlist_id, user_id)
async def set_current_playlist(self, playlist_id: int, user_id: int) -> Playlist:
"""Set a playlist as the current playlist."""
playlist = await self.get_playlist_by_id(playlist_id)
# Unset previous current playlist
await self._unset_current_playlist(user_id)
# Set new current playlist
playlist = await self.playlist_repo.update(playlist, {"is_current": True})
logger.info("Set playlist %s as current for user %s", playlist_id, user_id)
return playlist
async def unset_current_playlist(self, user_id: int) -> None:
"""Unset the current playlist and set main playlist as current."""
await self._unset_current_playlist(user_id)
await self._set_main_as_current(user_id)
logger.info(
"Unset current playlist and set main as current for user %s", user_id
)
# If this is the current playlist, reload player
if await _is_current_playlist(self.session, playlist_id):
await _reload_player_playlist()
async def get_playlist_stats(self, playlist_id: int) -> dict[str, Any]:
"""Get statistics for a playlist."""
@@ -282,36 +482,97 @@ class PlaylistService:
"total_play_count": total_plays,
}
async def add_sound_to_main_playlist(self, sound_id: int, user_id: int) -> None:
async def add_sound_to_main_playlist(
self,
sound_id: int, # noqa: ARG002
user_id: int, # noqa: ARG002
) -> None:
"""Add a sound to the global main playlist."""
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Sounds cannot be added to the main playlist",
)
async def _add_sound_to_main_playlist_internal(
self,
sound_id: int,
user_id: int,
) -> None:
"""Add sound to main playlist bypassing restrictions.
This method is intended for internal system use only (e.g., extraction service).
It bypasses the main playlist modification restrictions.
"""
main_playlist = await self.get_main_playlist()
if main_playlist.id is None:
raise ValueError("Main playlist has no ID")
msg = "Main playlist has no ID, cannot add sound"
raise ValueError(msg)
# Extract ID before async operations to avoid session issues
main_playlist_id = main_playlist.id
# Check if sound is already in main playlist
if not await self.playlist_repo.is_sound_in_playlist(
main_playlist.id, sound_id
main_playlist_id,
sound_id,
):
await self.playlist_repo.add_sound_to_playlist(main_playlist.id, sound_id)
await self.playlist_repo.add_sound_to_playlist(main_playlist_id, sound_id)
logger.info(
"Added sound %s to main playlist for user %s",
"Added sound %s to main playlist for user %s (internal)",
sound_id,
user_id,
)
async def _unset_current_playlist(self, user_id: int) -> None:
"""Unset the current playlist for a user."""
current_playlist = await self.playlist_repo.get_current_playlist(user_id)
# If main playlist is current, reload player
if await _is_current_playlist(self.session, main_playlist_id):
await _reload_player_playlist()
# Current playlist methods (global by default)
async def set_current_playlist(self, playlist_id: int) -> Playlist:
"""Set a playlist as the current playlist (app-wide)."""
playlist = await self.get_playlist_by_id(playlist_id)
# Unset any existing current playlist globally
await self._unset_current_playlist()
# Set new current playlist
playlist = await self.playlist_repo.update(playlist, {"is_current": True})
logger.info("Set playlist %s as current playlist", playlist_id)
# Reload player playlist to reflect the change
await _reload_player_playlist()
return playlist
async def unset_current_playlist(self) -> None:
"""Unset the current playlist (main playlist becomes fallback)."""
await self._unset_current_playlist()
logger.info("Unset current playlist, main playlist is now fallback")
# Reload player playlist to reflect the change (will fallback to main)
await _reload_player_playlist()
async def _delete_playlist_sounds(self, playlist_id: int) -> None:
"""Delete all playlist_sound records for a given playlist."""
# Get all playlist_sound records for this playlist
stmt = select(PlaylistSound).where(PlaylistSound.playlist_id == playlist_id)
result = await self.session.exec(stmt)
playlist_sounds = result.all()
# Delete each playlist_sound record
for playlist_sound in playlist_sounds:
await self.session.delete(playlist_sound)
await self.session.commit()
logger.info(
"Deleted %d playlist_sound records for playlist %s",
len(playlist_sounds),
playlist_id,
)
async def _unset_current_playlist(self) -> None:
"""Unset any current playlist globally."""
current_playlist = await self.playlist_repo.get_current_playlist()
if current_playlist:
await self.playlist_repo.update(current_playlist, {"is_current": False})
async def _set_main_as_current(self, user_id: int) -> None:
"""Unset current playlist so main playlist becomes the fallback current."""
# Just ensure no user playlist is marked as current
# The get_current_playlist method will fallback to main playlist
await self._unset_current_playlist(user_id)
logger.info(
"Unset current playlist for user %s, main playlist is now fallback",
user_id,
)


app/services/scheduler.py Normal file

@@ -0,0 +1,490 @@
"""Enhanced scheduler service for flexible task scheduling with timezone support."""
from collections.abc import Callable
from contextlib import suppress
from datetime import UTC, datetime, timedelta
import pytz
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.date import DateTrigger
from apscheduler.triggers.interval import IntervalTrigger
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.scheduled_task import (
RecurrenceType,
ScheduledTask,
TaskStatus,
TaskType,
)
from app.repositories.scheduled_task import ScheduledTaskRepository
from app.schemas.scheduler import ScheduledTaskCreate
from app.services.credit import CreditService
from app.services.player import PlayerService
from app.services.task_handlers import TaskHandlerRegistry
logger = get_logger(__name__)
class SchedulerService:
"""Enhanced service for managing scheduled tasks with timezone support."""
def __init__(
self,
db_session_factory: Callable[[], AsyncSession],
player_service: PlayerService,
) -> None:
"""Initialize the scheduler service.
Args:
db_session_factory: Factory function to create database sessions
player_service: Player service for audio playback tasks
"""
self.db_session_factory = db_session_factory
self.scheduler = AsyncIOScheduler(timezone=pytz.UTC)
self.credit_service = CreditService(db_session_factory)
self.player_service = player_service
self._running_tasks: set[str] = set()
async def start(self) -> None:
"""Start the scheduler and load all active tasks."""
logger.info("Starting enhanced scheduler service...")
self.scheduler.start()
# Schedule system tasks initialization for after startup
self.scheduler.add_job(
self._initialize_system_tasks,
"date",
run_date=datetime.now(tz=UTC) + timedelta(seconds=2),
id="initialize_system_tasks",
name="Initialize System Tasks",
replace_existing=True,
)
# Schedule periodic cleanup and maintenance
self.scheduler.add_job(
self._maintenance_job,
"interval",
minutes=5,
id="scheduler_maintenance",
name="Scheduler Maintenance",
replace_existing=True,
)
logger.info("Enhanced scheduler service started successfully")
async def stop(self) -> None:
"""Stop the scheduler."""
logger.info("Stopping scheduler service...")
self.scheduler.shutdown(wait=True)
logger.info("Scheduler service stopped")
async def create_task(
self,
task_data: ScheduledTaskCreate,
user_id: int | None = None,
) -> ScheduledTask:
"""Create a new scheduled task from schema data."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
# Convert scheduled_at to UTC if it's in a different timezone
scheduled_at = task_data.scheduled_at
if task_data.timezone != "UTC":
tz = pytz.timezone(task_data.timezone)
if scheduled_at.tzinfo is None:
# Assume the datetime is in the specified timezone
scheduled_at = tz.localize(scheduled_at)
scheduled_at = scheduled_at.astimezone(pytz.UTC).replace(tzinfo=None)
db_task_data = {
"name": task_data.name,
"task_type": task_data.task_type,
"scheduled_at": scheduled_at,
"timezone": task_data.timezone,
"parameters": task_data.parameters,
"user_id": user_id,
"recurrence_type": task_data.recurrence_type,
"cron_expression": task_data.cron_expression,
"recurrence_count": task_data.recurrence_count,
"expires_at": task_data.expires_at,
}
created_task = await repo.create(db_task_data)
await self._schedule_apscheduler_job(created_task)
logger.info(
"Created scheduled task: %s (%s)",
created_task.name,
created_task.id,
)
return created_task
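`create_task` normalizes `scheduled_at` to naive UTC before persisting: a naive datetime is interpreted in the task's timezone, converted, and stripped of tzinfo. The conversion step in isolation, using `pytz` as the service does:

```python
from datetime import datetime

import pytz


def to_naive_utc(scheduled_at: datetime, timezone: str) -> datetime:
    """Interpret a naive datetime in the given zone and return naive UTC."""
    if timezone != "UTC":
        tz = pytz.timezone(timezone)
        if scheduled_at.tzinfo is None:
            # Assume the datetime is in the specified timezone
            scheduled_at = tz.localize(scheduled_at)
        scheduled_at = scheduled_at.astimezone(pytz.UTC).replace(tzinfo=None)
    return scheduled_at
```

Note `tz.localize(dt)` rather than `dt.replace(tzinfo=tz)` — with pytz the latter would attach a zone's base offset without DST handling.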
async def cancel_task(self, task_id: int) -> bool:
"""Cancel a scheduled task."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
task = await repo.get_by_id(task_id)
if not task:
return False
await repo.update(task, {
"status": TaskStatus.CANCELLED,
"is_active": False,
})
# Remove from APScheduler (job might not exist in scheduler)
with suppress(Exception):
self.scheduler.remove_job(str(task_id))
logger.info("Cancelled task: %s (%s)", task.name, task_id)
return True
async def delete_task(self, task_id: int) -> bool:
"""Delete a scheduled task completely."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
task = await repo.get_by_id(task_id)
if not task:
return False
# Remove from APScheduler first (job might not exist in scheduler)
with suppress(Exception):
self.scheduler.remove_job(str(task_id))
# Delete from database
await repo.delete(task)
logger.info("Deleted task: %s (%s)", task.name, task_id)
return True
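Both `cancel_task` and `delete_task` wrap `scheduler.remove_job` in `contextlib.suppress(Exception)`, since the job may never have been loaded into APScheduler. A tiny illustration with a stand-in scheduler:

```python
from contextlib import suppress


class FakeScheduler:
    """Stand-in for APScheduler: remove_job raises for unknown job ids."""

    def __init__(self) -> None:
        self.jobs: set[str] = set()

    def remove_job(self, job_id: str) -> None:
        if job_id not in self.jobs:
            raise LookupError(job_id)
        self.jobs.discard(job_id)


scheduler = FakeScheduler()
# The job might not exist in the scheduler; suppress makes removal idempotent.
with suppress(Exception):
    scheduler.remove_job("42")
removed_safely = True
```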
async def get_user_tasks(
self,
user_id: int,
status: TaskStatus | None = None,
task_type: TaskType | None = None,
limit: int | None = None,
offset: int | None = None,
) -> list[ScheduledTask]:
"""Get tasks for a specific user."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
return await repo.get_user_tasks(user_id, status, task_type, limit, offset)
async def _initialize_system_tasks(self) -> None:
"""Initialize system tasks and load active tasks from database."""
logger.info("Initializing system tasks...")
try:
# Create system tasks if they don't exist
await self._ensure_system_tasks()
# Load all active tasks from database
await self._load_active_tasks()
logger.info("System tasks initialized successfully")
except Exception:
logger.exception("Failed to initialize system tasks")
async def _ensure_system_tasks(self) -> None:
"""Ensure required system tasks exist."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
# Check if daily credit recharge task exists
system_tasks = await repo.get_system_tasks(
task_type=TaskType.CREDIT_RECHARGE,
)
daily_recharge_exists = any(
task.recurrence_type == RecurrenceType.DAILY
and task.is_active
for task in system_tasks
)
if not daily_recharge_exists:
# Create daily credit recharge task
tomorrow_midnight = datetime.now(tz=UTC).replace(
hour=0, minute=0, second=0, microsecond=0,
) + timedelta(days=1)
task_data = {
"name": "Daily Credit Recharge",
"task_type": TaskType.CREDIT_RECHARGE,
"scheduled_at": tomorrow_midnight,
"recurrence_type": RecurrenceType.DAILY,
"parameters": {},
}
await repo.create(task_data)
logger.info("Created system daily credit recharge task")
async def _load_active_tasks(self) -> None:
"""Load all active tasks from database into scheduler."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
active_tasks = await repo.get_active_tasks()
for task in active_tasks:
await self._schedule_apscheduler_job(task)
logger.info("Loaded %s active tasks into scheduler", len(active_tasks))
async def _schedule_apscheduler_job(self, task: ScheduledTask) -> None:
"""Schedule a task in APScheduler."""
job_id = str(task.id)
# Remove existing job if it exists
with suppress(Exception):
self.scheduler.remove_job(job_id)
# Don't schedule if task is not active or already completed/failed
inactive_statuses = [
TaskStatus.COMPLETED,
TaskStatus.FAILED,
TaskStatus.CANCELLED,
]
if not task.is_active or task.status in inactive_statuses:
return
# Create trigger based on recurrence type
trigger = self._create_trigger(task)
if not trigger:
logger.warning("Could not create trigger for task %s", task.id)
return
# Schedule the job
self.scheduler.add_job(
self._execute_task,
trigger=trigger,
args=[task.id],
id=job_id,
name=task.name,
replace_existing=True,
)
logger.debug("Scheduled APScheduler job for task %s", task.id)
def _create_trigger(
self, task: ScheduledTask,
) -> DateTrigger | IntervalTrigger | CronTrigger | None:
"""Create APScheduler trigger based on task configuration."""
tz = pytz.timezone(task.timezone)
scheduled_time = task.scheduled_at
# Handle special cases first
if task.recurrence_type == RecurrenceType.NONE:
return DateTrigger(run_date=scheduled_time, timezone=tz)
if task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
return CronTrigger.from_crontab(task.cron_expression, timezone=tz)
# Handle interval-based recurrence types
interval_configs = {
RecurrenceType.MINUTELY: {"minutes": 1},
RecurrenceType.HOURLY: {"hours": 1},
RecurrenceType.DAILY: {"days": 1},
RecurrenceType.WEEKLY: {"weeks": 1},
}
if task.recurrence_type in interval_configs:
config = interval_configs[task.recurrence_type]
return IntervalTrigger(start_date=scheduled_time, timezone=tz, **config)
# Handle cron-based recurrence types
cron_configs = {
RecurrenceType.MONTHLY: {
"day": scheduled_time.day,
"hour": scheduled_time.hour,
"minute": scheduled_time.minute,
},
RecurrenceType.YEARLY: {
"month": scheduled_time.month,
"day": scheduled_time.day,
"hour": scheduled_time.hour,
"minute": scheduled_time.minute,
},
}
if task.recurrence_type in cron_configs:
config = cron_configs[task.recurrence_type]
return CronTrigger(timezone=tz, **config)
return None
async def _execute_task(self, task_id: int) -> None:
"""Execute a scheduled task."""
task_id_str = str(task_id)
logger.info("APScheduler triggered task %s execution", task_id)
# Prevent concurrent execution of the same task
if task_id_str in self._running_tasks:
logger.warning("Task %s is already running, skipping execution", task_id)
return
self._running_tasks.add(task_id_str)
try:
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
# Get fresh task data
task = await repo.get_by_id(task_id)
if not task:
logger.warning("Task %s not found", task_id)
return
logger.info(
"Task %s current state - status: %s, is_active: %s, executions: %s",
task_id, task.status, task.is_active, task.executions_count,
)
# Check if task is still active and pending
if not task.is_active or task.status != TaskStatus.PENDING:
logger.warning(
"Task %s execution skipped - is_active: %s, status: %s "
"(should be %s)",
task_id,
task.is_active,
task.status,
TaskStatus.PENDING,
)
return
# Check if task has expired
if task.is_expired():
logger.info("Task %s has expired, marking as cancelled", task_id)
await repo.update(task, {
"status": TaskStatus.CANCELLED,
"is_active": False,
})
return
# Mark task as running
logger.info(
"Task %s starting execution (type: %s)",
task_id,
task.recurrence_type,
)
await repo.mark_as_running(task)
# Execute the task
try:
handler_registry = TaskHandlerRegistry(
session,
self.db_session_factory,
self.credit_service,
self.player_service,
)
await handler_registry.execute_task(task)
# Handle completion based on task type
if task.recurrence_type == RecurrenceType.CRON:
# For CRON tasks, update execution metadata but keep PENDING
# APScheduler handles the recurring schedule automatically
logger.info(
"Task %s (CRON) executed successfully, updating metadata",
task_id,
)
task.last_executed_at = datetime.now(tz=UTC)
task.executions_count += 1
task.error_message = None
task.status = TaskStatus.PENDING # Explicitly set to PENDING
session.add(task)
await session.commit()
logger.info(
"Task %s (CRON) metadata updated, status: %s, "
"executions: %s",
task_id,
task.status,
task.executions_count,
)
else:
# For non-CRON recurring tasks, calculate next execution
next_execution_at = None
if task.should_repeat():
next_execution_at = self._calculate_next_execution(task)
# Mark as completed
await repo.mark_as_completed(task, next_execution_at)
# Reschedule if recurring
if next_execution_at and task.should_repeat():
# Refresh task to get updated data
await session.refresh(task)
await self._schedule_apscheduler_job(task)
except Exception as e:
await repo.mark_as_failed(task, str(e))
logger.exception("Task %s execution failed", task_id)
finally:
self._running_tasks.discard(task_id_str)
def _calculate_next_execution(self, task: ScheduledTask) -> datetime | None:
"""Calculate the next execution time for a recurring task."""
now = datetime.now(tz=UTC)
recurrence_deltas = {
RecurrenceType.MINUTELY: timedelta(minutes=1),
RecurrenceType.HOURLY: timedelta(hours=1),
RecurrenceType.DAILY: timedelta(days=1),
RecurrenceType.WEEKLY: timedelta(weeks=1),
RecurrenceType.MONTHLY: timedelta(days=30), # Approximate
RecurrenceType.YEARLY: timedelta(days=365), # Approximate
}
if task.recurrence_type in recurrence_deltas:
return now + recurrence_deltas[task.recurrence_type]
if task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
# For CRON tasks, let APScheduler handle the timing
return now
return None
async def _maintenance_job(self) -> None:
"""Periodic maintenance job to clean up expired tasks and handle scheduling."""
try:
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
# Handle expired tasks
expired_tasks = await repo.get_expired_tasks()
for task in expired_tasks:
await repo.update(task, {
"status": TaskStatus.CANCELLED,
"is_active": False,
})
# Remove from scheduler
with suppress(Exception):
self.scheduler.remove_job(str(task.id))
if expired_tasks:
logger.info("Cleaned up %s expired tasks", len(expired_tasks))
# Handle any missed recurring tasks
due_recurring = await repo.get_recurring_tasks_due_for_next_execution()
for task in due_recurring:
if task.should_repeat():
next_scheduled_at = (
task.next_execution_at or datetime.now(tz=UTC)
)
await repo.update(task, {
"status": TaskStatus.PENDING,
"scheduled_at": next_scheduled_at,
})
await self._schedule_apscheduler_job(task)
if due_recurring:
logger.info("Rescheduled %s recurring tasks", len(due_recurring))
except Exception:
logger.exception("Maintenance job failed")


@@ -4,6 +4,7 @@ import logging
import socketio
from app.core.config import settings
from app.utils.auth import JWTUtils
from app.utils.cookies import extract_access_token_from_cookies
@@ -13,9 +14,10 @@ logger = logging.getLogger(__name__)
class SocketManager:
"""Manages WebSocket connections and user rooms."""
def __init__(self):
def __init__(self) -> None:
"""Initialize the SocketManager with a Socket.IO server."""
self.sio = socketio.AsyncServer(
cors_allowed_origins=["http://localhost:8001"],
cors_allowed_origins=settings.CORS_ORIGINS,
logger=True,
engineio_logger=True,
async_mode="asgi",
@@ -27,20 +29,20 @@ class SocketManager:
self._setup_handlers()
def _setup_handlers(self):
def _setup_handlers(self) -> None:
"""Set up socket event handlers."""
@self.sio.event
async def connect(sid, environ, auth=None):
async def connect(sid: str, environ: dict) -> None:
"""Handle client connection."""
logger.info(f"Client {sid} attempting to connect")
logger.info("Client %s attempting to connect", sid)
# Extract access token from cookies
cookie_header = environ.get("HTTP_COOKIE", "")
access_token = extract_access_token_from_cookies(cookie_header)
if not access_token:
logger.warning(f"Client {sid} connecting without access token")
logger.warning("Client %s connecting without access token", sid)
await self.sio.disconnect(sid)
return
@@ -50,13 +52,13 @@ class SocketManager:
user_id = payload.get("sub")
if not user_id:
logger.warning(f"Client {sid} token missing user ID")
logger.warning("Client %s token missing user ID", sid)
await self.sio.disconnect(sid)
return
logger.info(f"User {user_id} connected with socket {sid}")
except Exception as e:
logger.warning(f"Client {sid} invalid token: {e}")
logger.info("User %s connected with socket %s", user_id, sid)
except Exception:
logger.exception("Client %s invalid token", sid)
await self.sio.disconnect(sid)
return
@@ -70,7 +72,7 @@ class SocketManager:
# Update room tracking
self.user_rooms[user_id] = room_id
logger.info(f"User {user_id} joined room {room_id}")
logger.info("User %s joined room %s", user_id, room_id)
# Send welcome message to user
await self.sio.emit(
@@ -84,33 +86,78 @@ class SocketManager:
)
@self.sio.event
async def disconnect(sid):
async def disconnect(sid: str) -> None:
"""Handle client disconnection."""
user_id = self.socket_users.get(sid)
if user_id:
logger.info(f"User {user_id} disconnected (socket {sid})")
logger.info("User %s disconnected (socket %s)", user_id, sid)
# Clean up mappings
del self.socket_users[sid]
if user_id in self.user_rooms:
del self.user_rooms[user_id]
else:
logger.info(f"Unknown client {sid} disconnected")
logger.info("Unknown client %s disconnected", sid)
async def send_to_user(self, user_id: str, event: str, data: dict):
@self.sio.event
async def play_sound(sid: str, data: dict) -> None:
"""Handle play sound event from client."""
await self._handle_play_sound(sid, data)
async def _handle_play_sound(self, sid: str, data: dict) -> None:
"""Handle play sound request from WebSocket client."""
user_id = self.socket_users.get(sid)
if not user_id:
logger.warning("Play sound request from unknown client %s", sid)
return
sound_id = data.get("sound_id")
if not sound_id:
logger.warning(
"Play sound request missing sound_id from user %s",
user_id,
)
return
try:
# Import here to avoid circular imports
from app.core.database import get_session_factory # noqa: PLC0415
from app.services.vlc_player import get_vlc_player_service # noqa: PLC0415
# Get VLC player service with database factory
vlc_player = get_vlc_player_service(get_session_factory())
# Call the service method
await vlc_player.play_sound_with_credits(int(sound_id), int(user_id))
logger.info("User %s played sound %s via WebSocket", user_id, sound_id)
except Exception as e:
logger.exception(
"Error playing sound %s for user %s",
sound_id,
user_id,
)
# Emit error back to user
await self.sio.emit(
"sound_play_error",
{"sound_id": sound_id, "error": str(e)},
room=sid,
)
async def send_to_user(self, user_id: str, event: str, data: dict) -> bool:
"""Send a message to a specific user's room."""
room_id = self.user_rooms.get(user_id)
if room_id:
await self.sio.emit(event, data, room=room_id)
logger.debug(f"Sent {event} to user {user_id} in room {room_id}")
logger.debug("Sent %s to user %s in room %s", event, user_id, room_id)
return True
logger.warning(f"User {user_id} not found in any room")
logger.warning("User %s not found in any room", user_id)
return False
async def broadcast_to_all(self, event: str, data: dict):
async def broadcast_to_all(self, event: str, data: dict) -> None:
"""Broadcast a message to all connected users."""
await self.sio.emit(event, data)
logger.info(f"Broadcasted {event} to all users")
logger.info("Broadcasted %s to all users", event)
def get_connected_users(self) -> list:
"""Get list of currently connected user IDs."""


@@ -1,5 +1,6 @@
"""Sound normalizer service for normalizing audio files using ffmpeg loudnorm."""
import asyncio
import json
import os
import re
@@ -138,10 +139,15 @@ class SoundNormalizerService:
stream = ffmpeg.output(stream, str(output_path), **output_args)
stream = ffmpeg.overwrite_output(stream)
ffmpeg.run(stream, quiet=True, overwrite_output=True)
await asyncio.to_thread(
ffmpeg.run,
stream,
quiet=True,
overwrite_output=True,
)
logger.info("One-pass normalization completed: %s", output_path)
except Exception as e:
except Exception:
logger.exception("One-pass normalization failed for %s", input_path)
raise
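The change here wraps the blocking `ffmpeg.run` call in `asyncio.to_thread` so encoding no longer stalls the event loop. The pattern, illustrated with a stand-in blocking function:

```python
import asyncio
import time


def blocking_encode(path: str) -> str:
    """Stand-in for a blocking ffmpeg.run call."""
    time.sleep(0.01)  # simulate CPU/IO-bound work
    return f"normalized:{path}"


async def normalize(path: str) -> str:
    # Run the blocking call in a worker thread; the event loop stays free
    # to serve other requests while the encode runs.
    return await asyncio.to_thread(blocking_encode, path)


result = asyncio.run(normalize("sound.mp3"))
```

`asyncio.to_thread(f, *args)` forwards positional arguments, which is why the diff passes `stream` and the keyword flags after the callable.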
@@ -153,7 +159,9 @@ class SoundNormalizerService:
"""Normalize audio using two-pass loudnorm for better quality."""
try:
logger.info(
"Starting two-pass normalization: %s -> %s", input_path, output_path
"Starting two-pass normalization: %s -> %s",
input_path,
output_path,
)
# First pass: analyze
@@ -174,10 +182,15 @@ class SoundNormalizerService:
# Run first pass and capture output
try:
result = ffmpeg.run(stream, capture_stderr=True, quiet=True)
result = await asyncio.to_thread(
ffmpeg.run,
stream,
capture_stderr=True,
quiet=True,
)
analysis_output = result[1].decode("utf-8")
except ffmpeg.Error as e:
logger.error(
logger.exception(
"FFmpeg first pass failed for %s. Stdout: %s, Stderr: %s",
input_path,
e.stdout.decode() if e.stdout else "None",
@@ -193,9 +206,11 @@ class SoundNormalizerService:
json_match = re.search(r'\{[^{}]*"input_i"[^{}]*\}', analysis_output)
if not json_match:
logger.error(
"Could not find JSON in loudnorm output: %s", analysis_output
"Could not find JSON in loudnorm output: %s",
analysis_output,
)
raise ValueError("Could not extract loudnorm analysis data")
msg = "Could not find JSON in loudnorm output"
raise ValueError(msg)
logger.debug("Found JSON match: %s", json_match.group())
analysis_data = json.loads(json_match.group())
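The first pass scrapes the loudnorm JSON block out of ffmpeg's stderr with a brace-delimited regex keyed on `"input_i"`, then feeds it to `json.loads`. A self-contained run of that extraction against sample output (values are illustrative, not real measurements):

```python
import json
import re

# Illustrative tail of ffmpeg loudnorm stderr.
analysis_output = """
[Parsed_loudnorm_0 @ 0x55]
{
    "input_i" : "-23.50",
    "input_tp" : "-4.10",
    "input_lra" : "6.30",
    "input_thresh" : "-33.70"
}
"""

# [^{}] also matches newlines, so the multi-line block is captured
# without needing re.DOTALL.
json_match = re.search(r'\{[^{}]*"input_i"[^{}]*\}', analysis_output)
if not json_match:
    msg = "Could not find JSON in loudnorm output"
    raise ValueError(msg)
analysis_data = json.loads(json_match.group())
```

The `[^{}]*` pattern assumes the loudnorm block contains no nested braces, which holds for its flat key/value output.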
@@ -211,7 +226,10 @@ class SoundNormalizerService:
]:
if str(analysis_data.get(key, "")).lower() in invalid_values:
logger.warning(
"Invalid analysis value for %s: %s. Falling back to one-pass normalization.",
(
"Invalid analysis value for %s: %s. "
"Falling back to one-pass normalization."
),
key,
analysis_data.get(key),
)
@@ -249,10 +267,15 @@ class SoundNormalizerService:
stream = ffmpeg.overwrite_output(stream)
try:
ffmpeg.run(stream, quiet=True, overwrite_output=True)
await asyncio.to_thread(
ffmpeg.run,
stream,
quiet=True,
overwrite_output=True,
)
logger.info("Two-pass normalization completed: %s", output_path)
except ffmpeg.Error as e:
logger.error(
logger.exception(
"FFmpeg second pass failed for %s. Stdout: %s, Stderr: %s",
input_path,
e.stdout.decode() if e.stdout else "None",
@@ -260,19 +283,21 @@ class SoundNormalizerService:
)
raise
except Exception as e:
except Exception:
logger.exception("Two-pass normalization failed for %s", input_path)
raise
async def normalize_sound(
self,
sound: Sound,
*,
force: bool = False,
one_pass: bool | None = None,
sound_data: dict | None = None,
) -> NormalizationInfo:
"""Normalize a single sound."""
# Use provided sound_data to avoid detached instance issues, or capture from sound
# Use provided sound_data to avoid detached instance issues,
# or capture from sound
if sound_data:
filename = sound_data["filename"]
sound_id = sound_data["id"]
@@ -391,6 +416,7 @@ class SoundNormalizerService:
async def normalize_all_sounds(
self,
*,
force: bool = False,
one_pass: bool | None = None,
) -> NormalizationResults:
@@ -409,7 +435,7 @@ class SoundNormalizerService:
if force:
# Get all sounds if forcing
sounds = []
for sound_type in self.type_directories:
type_sounds = await self.sound_repo.get_by_type(sound_type)
sounds.extend(type_sounds)
else:
@@ -419,17 +445,16 @@ class SoundNormalizerService:
logger.info("Found %d sounds to process", len(sounds))
# Capture all sound data upfront to avoid session detachment issues
sound_data_list = [
{
"id": sound.id,
"filename": sound.filename,
"type": sound.type,
"is_normalized": sound.is_normalized,
"name": sound.name,
}
for sound in sounds
]
# Process each sound using captured data
for i, sound in enumerate(sounds):
@@ -476,7 +501,7 @@ class SoundNormalizerService:
"normalized_hash": None,
"id": sound_id,
"error": str(e),
},
)
logger.info("Normalization completed: %s", results)
@@ -485,6 +510,7 @@ class SoundNormalizerService:
async def normalize_sounds_by_type(
self,
sound_type: str,
*,
force: bool = False,
one_pass: bool | None = None,
) -> NormalizationResults:
@@ -508,17 +534,16 @@ class SoundNormalizerService:
logger.info("Found %d %s sounds to process", len(sounds), sound_type)
# Capture all sound data upfront to avoid session detachment issues
sound_data_list = [
{
"id": sound.id,
"filename": sound.filename,
"type": sound.type,
"is_normalized": sound.is_normalized,
"name": sound.name,
}
for sound in sounds
]
# Process each sound using captured data
for i, sound in enumerate(sounds):
@@ -565,7 +590,7 @@ class SoundNormalizerService:
"normalized_hash": None,
"id": sound_id,
"error": str(e),
},
)
logger.info("Type normalization completed: %s", results)


@@ -1,5 +1,6 @@
"""Sound scanner service for scanning and importing audio files."""
from dataclasses import dataclass
from pathlib import Path
from typing import TypedDict
@@ -13,6 +14,28 @@ from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
logger = get_logger(__name__)
@dataclass
class AudioFileInfo:
"""Data class for audio file metadata."""
filename: str
name: str
duration: int
size: int
file_hash: str
@dataclass
class SyncContext:
"""Context data for audio file synchronization."""
file_path: Path
sound_type: str
existing_sound_by_hash: dict | Sound | None
existing_sound_by_filename: dict | Sound | None
file_hash: str
class FileInfo(TypedDict):
"""Type definition for file information in scan results."""
@@ -35,6 +58,7 @@ class ScanResults(TypedDict):
updated: int
deleted: int
skipped: int
duplicates: int
errors: int
files: list[FileInfo]
@@ -56,6 +80,13 @@ class SoundScannerService:
".aac",
}
# Directory mappings for normalized files (matching sound_normalizer)
self.normalized_directories = {
"SDB": "sounds/normalized/soundboard",
"TTS": "sounds/normalized/text_to_speech",
"EXT": "sounds/normalized/extracted",
}
def extract_name_from_filename(self, filename: str) -> str:
"""Extract a clean name from filename."""
# Remove extension
@@ -65,6 +96,415 @@ class SoundScannerService:
# Capitalize words
return " ".join(word.capitalize() for word in name.split())
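The middle of `extract_name_from_filename` is elided by the hunk boundary above; only the extension-stripping comment and the word-capitalizing return are visible. A sketch of the likely full behavior, assuming the elided part replaces `_` and `-` separators with spaces (that separator handling is an assumption, not shown in the diff):

```python
from pathlib import Path


def extract_name_from_filename(filename: str) -> str:
    # Strip the extension; treat "_" and "-" as word breaks (assumed
    # behavior for the elided middle of the real method)
    name = Path(filename).stem.replace("_", " ").replace("-", " ")
    # Capitalize words, as in the visible return statement
    return " ".join(word.capitalize() for word in name.split())


print(extract_name_from_filename("epic_air-horn.mp3"))  # Epic Air Horn
```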
def _get_normalized_path(self, sound_type: str, filename: str) -> Path:
"""Get the normalized file path for a sound."""
directory = self.normalized_directories.get(
sound_type, "sounds/normalized/other",
)
return Path(directory) / filename
def _rename_normalized_file(
self, sound_type: str, old_filename: str, new_filename: str,
) -> bool:
"""Rename the normalized file if it exists. Return True if renamed, else False."""
old_path = self._get_normalized_path(sound_type, old_filename)
new_path = self._get_normalized_path(sound_type, new_filename)
if old_path.exists():
try:
# Ensure the directory exists
new_path.parent.mkdir(parents=True, exist_ok=True)
old_path.rename(new_path)
logger.info("Renamed normalized file: %s -> %s", old_path, new_path)
except OSError:
logger.exception(
"Failed to rename normalized file %s -> %s",
old_path,
new_path,
)
return False
else:
return True
return False
def _delete_normalized_file(self, sound_type: str, filename: str) -> bool:
"""Delete the normalized file if it exists. Return True if deleted, else False."""
normalized_path = self._get_normalized_path(sound_type, filename)
if normalized_path.exists():
try:
normalized_path.unlink()
logger.info("Deleted normalized file: %s", normalized_path)
except OSError:
logger.exception(
"Failed to delete normalized file %s", normalized_path,
)
return False
else:
return True
return False
def _extract_sound_attributes(self, sound_data: dict | Sound | None) -> dict:
"""Extract attributes from sound data (dict or Sound object)."""
if sound_data is None:
return {}
if isinstance(sound_data, dict):
return {
"filename": sound_data.get("filename"),
"name": sound_data.get("name"),
"duration": sound_data.get("duration"),
"size": sound_data.get("size"),
"id": sound_data.get("id"),
"object": sound_data.get("sound_object"),
"type": sound_data.get("type"),
"is_normalized": sound_data.get("is_normalized"),
"normalized_filename": sound_data.get("normalized_filename"),
}
# Sound object (for tests)
return {
"filename": sound_data.filename,
"name": sound_data.name,
"duration": sound_data.duration,
"size": sound_data.size,
"id": sound_data.id,
"object": sound_data,
"type": sound_data.type,
"is_normalized": sound_data.is_normalized,
"normalized_filename": sound_data.normalized_filename,
}
def _handle_unchanged_file(
self,
filename: str,
existing_attrs: dict,
results: ScanResults,
) -> None:
"""Handle unchanged file (same hash, same filename)."""
logger.debug("Sound unchanged: %s", filename)
results["skipped"] += 1
results["files"].append({
"filename": filename,
"status": "skipped",
"reason": "file unchanged",
"name": existing_attrs["name"],
"duration": existing_attrs["duration"],
"size": existing_attrs["size"],
"id": existing_attrs["id"],
"error": None,
"changes": None,
})
def _handle_duplicate_file(
self,
filename: str,
existing_filename: str,
file_hash: str,
existing_attrs: dict,
results: ScanResults,
) -> None:
"""Handle duplicate file (same hash, different filename)."""
logger.warning(
"Duplicate file detected: '%s' has same content as existing "
"'%s' (hash: %s). Skipping duplicate file.",
filename,
existing_filename,
file_hash[:8] + "...",
)
results["skipped"] += 1
results["duplicates"] += 1
results["files"].append({
"filename": filename,
"status": "skipped",
"reason": "duplicate content",
"name": existing_attrs["name"],
"duration": existing_attrs["duration"],
"size": existing_attrs["size"],
"id": existing_attrs["id"],
"error": None,
"changes": None,
})
async def _handle_file_rename(
self,
file_info: AudioFileInfo,
existing_attrs: dict,
results: ScanResults,
) -> None:
"""Handle file rename (same hash, different filename)."""
update_data = {
"filename": file_info.filename,
"name": file_info.name,
}
# If the sound has a normalized file, rename it too
if existing_attrs["is_normalized"] and existing_attrs["normalized_filename"]:
old_normalized_base = Path(existing_attrs["normalized_filename"]).name
new_normalized_base = (
Path(file_info.filename).stem
+ Path(existing_attrs["normalized_filename"]).suffix
)
renamed = self._rename_normalized_file(
existing_attrs["type"],
old_normalized_base,
new_normalized_base,
)
if renamed:
update_data["normalized_filename"] = new_normalized_base
logger.info(
"Renamed normalized file: %s -> %s",
old_normalized_base,
new_normalized_base,
)
await self.sound_repo.update(existing_attrs["object"], update_data)
logger.info(
"Detected rename: %s -> %s (ID: %s)",
existing_attrs["filename"],
file_info.filename,
existing_attrs["id"],
)
# Build changes list
changes = ["filename", "name"]
if "normalized_filename" in update_data:
changes.append("normalized_filename")
results["updated"] += 1
results["files"].append({
"filename": file_info.filename,
"status": "updated",
"reason": "file was renamed",
"name": file_info.name,
"duration": existing_attrs["duration"],
"size": existing_attrs["size"],
"id": existing_attrs["id"],
"error": None,
"changes": changes,
# Store old filename to prevent deletion
"old_filename": existing_attrs["filename"],
})
async def _handle_file_modification(
self,
file_info: AudioFileInfo,
existing_attrs: dict,
results: ScanResults,
) -> None:
"""Handle file modification (same filename, different hash)."""
update_data = {
"name": file_info.name,
"duration": file_info.duration,
"size": file_info.size,
"hash": file_info.file_hash,
}
await self.sound_repo.update(existing_attrs["object"], update_data)
logger.info(
"Updated modified sound: %s (ID: %s)",
file_info.name,
existing_attrs["id"],
)
results["updated"] += 1
results["files"].append({
"filename": file_info.filename,
"status": "updated",
"reason": "file was modified",
"name": file_info.name,
"duration": file_info.duration,
"size": file_info.size,
"id": existing_attrs["id"],
"error": None,
"changes": ["hash", "duration", "size", "name"],
})
async def _handle_new_file(
self,
file_info: AudioFileInfo,
sound_type: str,
results: ScanResults,
) -> None:
"""Handle new file (neither hash nor filename exists)."""
sound_data = {
"type": sound_type,
"name": file_info.name,
"filename": file_info.filename,
"duration": file_info.duration,
"size": file_info.size,
"hash": file_info.file_hash,
"is_deletable": False,
"is_music": False,
"is_normalized": False,
"play_count": 0,
}
sound = await self.sound_repo.create(sound_data)
logger.info("Added new sound: %s (ID: %s)", sound.name, sound.id)
results["added"] += 1
results["files"].append({
"filename": file_info.filename,
"status": "added",
"reason": None,
"name": file_info.name,
"duration": file_info.duration,
"size": file_info.size,
"id": sound.id,
"error": None,
"changes": None,
})
async def _load_existing_sounds(self, sound_type: str) -> tuple[dict, dict]:
"""Load existing sounds and create lookup dictionaries."""
existing_sounds = await self.sound_repo.get_by_type(sound_type)
# Create lookup dictionaries with immediate attribute access
# to avoid session detachment
sounds_by_hash = {}
sounds_by_filename = {}
for sound in existing_sounds:
# Capture all attributes immediately while session is valid
sound_data = {
"id": sound.id,
"hash": sound.hash,
"filename": sound.filename,
"name": sound.name,
"duration": sound.duration,
"size": sound.size,
"type": sound.type,
"is_normalized": sound.is_normalized,
"normalized_filename": sound.normalized_filename,
"sound_object": sound, # Keep reference for database operations
}
sounds_by_hash[sound.hash] = sound_data
sounds_by_filename[sound.filename] = sound_data
return sounds_by_hash, sounds_by_filename
async def _process_audio_files(
self,
scan_path: Path,
sound_type: str,
sounds_by_hash: dict,
sounds_by_filename: dict,
results: ScanResults,
) -> set[str]:
"""Process all audio files in directory and return processed filenames."""
# Get all audio files from directory
audio_files = [
f
for f in scan_path.iterdir()
if f.is_file() and f.suffix.lower() in self.supported_extensions
]
# Process each file in directory
processed_filenames = set()
for file_path in audio_files:
results["scanned"] += 1
filename = file_path.name
processed_filenames.add(filename)
try:
# Calculate hash first to enable hash-based lookup
file_hash = get_file_hash(file_path)
existing_sound_by_hash = sounds_by_hash.get(file_hash)
existing_sound_by_filename = sounds_by_filename.get(filename)
# Create sync context
sync_context = SyncContext(
file_path=file_path,
sound_type=sound_type,
existing_sound_by_hash=existing_sound_by_hash,
existing_sound_by_filename=existing_sound_by_filename,
file_hash=file_hash,
)
await self._sync_audio_file(sync_context, results)
# Check if this was a rename and mark old filename as processed
if results["files"] and results["files"][-1].get("old_filename"):
old_filename = results["files"][-1]["old_filename"]
processed_filenames.add(old_filename)
logger.debug("Marked old filename as processed: %s", old_filename)
# Remove temporary tracking field from results
del results["files"][-1]["old_filename"]
except Exception as e:
logger.exception("Error processing file %s", file_path)
results["errors"] += 1
results["files"].append({
"filename": filename,
"status": "error",
"reason": None,
"name": None,
"duration": None,
"size": None,
"id": None,
"error": str(e),
"changes": None,
})
return processed_filenames
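`get_file_hash` is imported from `app.utils.audio` and its body is not part of this diff; what the hash-based lookup relies on is only that identical bytes produce an identical digest. A typical chunked content-hash implementation it might resemble (the SHA-256 choice and chunk size are assumptions):

```python
import hashlib
import tempfile
from pathlib import Path


def file_content_hash(path: Path) -> str:
    # Hash in fixed-size chunks so large audio files are never read at once
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Two files with identical bytes hash identically, regardless of filename,
# which is what lets the scanner tell renames and duplicates apart
with tempfile.TemporaryDirectory() as tmp:
    a = Path(tmp) / "a.mp3"
    b = Path(tmp) / "renamed.mp3"
    a.write_bytes(b"fake audio bytes")
    b.write_bytes(b"fake audio bytes")
    same = file_content_hash(a) == file_content_hash(b)

print(same)  # True
```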
async def _delete_missing_sounds(
self,
sounds_by_filename: dict,
processed_filenames: set[str],
results: ScanResults,
) -> None:
"""Delete sounds that no longer exist in directory."""
for filename, sound_data in sounds_by_filename.items():
if filename not in processed_filenames:
# Attributes already captured in sound_data dictionary
sound_name = sound_data["name"]
sound_duration = sound_data["duration"]
sound_size = sound_data["size"]
sound_id = sound_data["id"]
sound_object = sound_data["sound_object"]
sound_type = sound_data["type"]
sound_is_normalized = sound_data["is_normalized"]
sound_normalized_filename = sound_data["normalized_filename"]
try:
# Delete the sound from database first
await self.sound_repo.delete(sound_object)
logger.info("Deleted sound no longer in directory: %s", filename)
# If the sound had a normalized file, delete it too
if sound_is_normalized and sound_normalized_filename:
normalized_base = Path(sound_normalized_filename).name
self._delete_normalized_file(sound_type, normalized_base)
results["deleted"] += 1
results["files"].append({
"filename": filename,
"status": "deleted",
"reason": "file no longer exists",
"name": sound_name,
"duration": sound_duration,
"size": sound_size,
"id": sound_id,
"error": None,
"changes": None,
})
except Exception as e:
logger.exception("Error deleting sound %s", filename)
results["errors"] += 1
results["files"].append({
"filename": filename,
"status": "error",
"reason": "failed to delete",
"name": sound_name,
"duration": sound_duration,
"size": sound_size,
"id": sound_id,
"error": str(e),
"changes": None,
})
async def scan_directory(
self,
directory_path: str,
@@ -87,185 +527,91 @@ class SoundScannerService:
"updated": 0,
"deleted": 0,
"skipped": 0,
"duplicates": 0,
"errors": 0,
"files": [],
}
logger.info("Starting sync of directory: %s", directory_path)
# Load existing sounds from database
sounds_by_hash, sounds_by_filename = await self._load_existing_sounds(
sound_type,
)
# Process audio files in directory
processed_filenames = await self._process_audio_files(
scan_path,
sound_type,
sounds_by_hash,
sounds_by_filename,
results,
)
# Delete sounds that no longer exist in directory
await self._delete_missing_sounds(
sounds_by_filename,
processed_filenames,
results,
)
logger.info("Sync completed: %s", results)
return results
async def _sync_audio_file(
self,
sync_context: SyncContext,
results: ScanResults,
) -> None:
"""Sync a single audio file using hash-first identification strategy."""
filename = sync_context.file_path.name
duration = get_audio_duration(sync_context.file_path)
size = get_file_size(sync_context.file_path)
name = self.extract_name_from_filename(filename)
# Create file info object
file_info = AudioFileInfo(
filename=filename,
name=name,
duration=duration,
size=size,
file_hash=sync_context.file_hash,
)
# Extract attributes from existing sounds
hash_attrs = self._extract_sound_attributes(sync_context.existing_sound_by_hash)
filename_attrs = self._extract_sound_attributes(
sync_context.existing_sound_by_filename,
)
# Hash-first identification strategy
if sync_context.existing_sound_by_hash is not None:
# Content exists in database (same hash)
if hash_attrs["filename"] == filename:
# Same hash, same filename - file unchanged
self._handle_unchanged_file(filename, hash_attrs, results)
else:
# Same hash, different filename - could be rename or duplicate
old_file_path = sync_context.file_path.parent / hash_attrs["filename"]
if old_file_path.exists():
# Both files exist with same hash - this is a duplicate
self._handle_duplicate_file(
filename,
hash_attrs["filename"],
sync_context.file_hash,
hash_attrs,
results,
)
else:
# Old file doesn't exist - this is a genuine rename
await self._handle_file_rename(file_info, hash_attrs, results)
elif sync_context.existing_sound_by_filename is not None:
# Same filename but different hash - file was modified
await self._handle_file_modification(file_info, filename_attrs, results)
else:
# New file - neither hash nor filename exists
await self._handle_new_file(file_info, sync_context.sound_type, results)
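The branch structure above amounts to a small decision table, with content identity (hash) taking precedence over filename identity. A pure-function sketch of the same classification (the function and parameter names are illustrative, not from the codebase):

```python
def classify_file(
    *,
    hash_known: bool,
    filename_matches_hash_record: bool,
    old_file_still_exists: bool,
    filename_known: bool,
) -> str:
    # Hash-first strategy: content identity wins over names
    if hash_known:
        if filename_matches_hash_record:
            return "unchanged"       # same hash, same filename
        if old_file_still_exists:
            return "duplicate"       # same hash, both files on disk
        return "rename"              # same hash, old file gone
    if filename_known:
        return "modified"            # same filename, different hash
    return "new"                     # neither hash nor filename known


print(classify_file(
    hash_known=True,
    filename_matches_hash_record=False,
    old_file_still_exists=False,
    filename_known=False,
))  # rename
```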
async def scan_soundboard_directory(self) -> ScanResults:
"""Sync the default soundboard directory."""


@@ -0,0 +1,194 @@
"""Task execution handlers for different task types."""
from collections.abc import Callable
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.scheduled_task import ScheduledTask, TaskType
from app.repositories.playlist import PlaylistRepository
from app.repositories.sound import SoundRepository
from app.services.credit import CreditService
from app.services.player import PlayerService
from app.services.vlc_player import VLCPlayerService
logger = get_logger(__name__)
class TaskExecutionError(Exception):
"""Exception raised when task execution fails."""
class TaskHandlerRegistry:
"""Registry for task execution handlers."""
def __init__(
self,
db_session: AsyncSession,
db_session_factory: Callable[[], AsyncSession],
credit_service: CreditService,
player_service: PlayerService,
) -> None:
"""Initialize the task handler registry."""
self.db_session = db_session
self.db_session_factory = db_session_factory
self.credit_service = credit_service
self.player_service = player_service
self.sound_repository = SoundRepository(db_session)
self.playlist_repository = PlaylistRepository(db_session)
# Register handlers
self._handlers = {
TaskType.CREDIT_RECHARGE: self._handle_credit_recharge,
TaskType.PLAY_SOUND: self._handle_play_sound,
TaskType.PLAY_PLAYLIST: self._handle_play_playlist,
}
async def execute_task(self, task: ScheduledTask) -> None:
"""Execute a task based on its type."""
handler = self._handlers.get(task.task_type)
if not handler:
msg = f"No handler registered for task type: {task.task_type}"
raise TaskExecutionError(msg)
logger.info(
"Executing task %s (%s): %s",
task.id,
task.task_type.value,
task.name,
)
try:
await handler(task)
logger.info("Task %s executed successfully", task.id)
except Exception as e:
logger.exception("Task %s execution failed", task.id)
msg = f"Task execution failed: {e!s}"
raise TaskExecutionError(msg) from e
async def _handle_credit_recharge(self, task: ScheduledTask) -> None:
"""Handle credit recharge task."""
parameters = task.parameters
user_id = parameters.get("user_id")
if user_id:
# Recharge specific user
try:
user_id_int = int(user_id)
except (ValueError, TypeError) as e:
msg = f"Invalid user_id format: {user_id}"
raise TaskExecutionError(msg) from e
transaction = await self.credit_service.recharge_user_credits_auto(
user_id_int,
)
if transaction:
logger.info(
"Recharged credits for user %s: %s credits added",
user_id,
transaction.amount,
)
else:
logger.info(
"No credits added for user %s (already at maximum)", user_id,
)
else:
# Recharge all users (system task)
stats = await self.credit_service.recharge_all_users_credits()
logger.info("Recharged credits for all users: %s", stats)
async def _handle_play_sound(self, task: ScheduledTask) -> None:
"""Handle play sound task."""
parameters = task.parameters
sound_id = parameters.get("sound_id")
if not sound_id:
msg = "sound_id parameter is required for PLAY_SOUND tasks"
raise TaskExecutionError(msg)
try:
# Handle both integer and string sound IDs
sound_id_int = int(sound_id)
except (ValueError, TypeError) as e:
msg = f"Invalid sound_id format: {sound_id}"
raise TaskExecutionError(msg) from e
# Check if this is a user task (has user_id)
if task.user_id:
# User task: use credit-aware playback
vlc_service = VLCPlayerService(self.db_session_factory)
try:
result = await vlc_service.play_sound_with_credits(
sound_id_int, task.user_id,
)
logger.info(
(
"Played sound %s via scheduled task for user %s "
"(credits deducted: %s)"
),
result.get("sound_name", sound_id),
task.user_id,
result.get("credits_deducted", 0),
)
except Exception as e:
# Convert HTTP exceptions or credit errors to task execution errors
msg = f"Failed to play sound with credits: {e!s}"
raise TaskExecutionError(msg) from e
else:
# System task: play without credit deduction
sound = await self.sound_repository.get_by_id(sound_id_int)
if not sound:
msg = f"Sound not found: {sound_id}"
raise TaskExecutionError(msg)
vlc_service = VLCPlayerService(self.db_session_factory)
success = await vlc_service.play_sound(sound)
if not success:
msg = f"Failed to play sound {sound.filename}"
raise TaskExecutionError(msg)
logger.info("Played sound %s via scheduled system task", sound.filename)
async def _handle_play_playlist(self, task: ScheduledTask) -> None:
"""Handle play playlist task."""
parameters = task.parameters
playlist_id = parameters.get("playlist_id")
play_mode = parameters.get("play_mode", "continuous")
shuffle = parameters.get("shuffle", False)
if not playlist_id:
msg = "playlist_id parameter is required for PLAY_PLAYLIST tasks"
raise TaskExecutionError(msg)
try:
# Handle both integer and string playlist IDs
playlist_id_int = int(playlist_id)
except (ValueError, TypeError) as e:
msg = f"Invalid playlist_id format: {playlist_id}"
raise TaskExecutionError(msg) from e
# Get the playlist from database
playlist = await self.playlist_repository.get_by_id(playlist_id_int)
if not playlist:
msg = f"Playlist not found: {playlist_id}"
raise TaskExecutionError(msg)
# Load playlist in player
await self.player_service.load_playlist(playlist_id_int)
# Set play mode if specified
if play_mode in ["continuous", "loop", "loop_one", "random", "single"]:
await self.player_service.set_mode(play_mode)
# Enable shuffle if requested
if shuffle:
await self.player_service.set_shuffle(shuffle=True)
# Start playing
await self.player_service.play()
logger.info("Started playing playlist %s via scheduled task", playlist.name)
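The registry above is a plain dict keyed by `TaskType`, with a lookup-then-raise guard in `execute_task`. The dispatch shape in miniature, stripped of the async plumbing and service wiring (the two-member enum and the lambda handler are illustrative):

```python
from enum import Enum


class TaskType(Enum):
    CREDIT_RECHARGE = "credit_recharge"
    PLAY_SOUND = "play_sound"


class TaskExecutionError(Exception):
    """Raised when task execution fails."""


class MiniRegistry:
    def __init__(self) -> None:
        # Handlers registered once at construction, looked up per task
        self._handlers = {
            TaskType.PLAY_SOUND: lambda params: f"played {params['sound_id']}",
        }

    def execute(self, task_type: TaskType, params: dict) -> str:
        handler = self._handlers.get(task_type)
        if handler is None:
            msg = f"No handler registered for task type: {task_type}"
            raise TaskExecutionError(msg)
        return handler(params)


registry = MiniRegistry()
print(registry.execute(TaskType.PLAY_SOUND, {"sound_id": 7}))  # played 7
```

Keeping the mapping in `__init__` means adding a task type is one enum member plus one dict entry, with no change to the dispatch path.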


@@ -0,0 +1,6 @@
"""Text-to-Speech services package."""
from .base import TTSProvider
from .service import TTSService
__all__ = ["TTSProvider", "TTSService"]

app/services/tts/base.py

@@ -0,0 +1,41 @@
"""Base TTS provider interface."""
from abc import ABC, abstractmethod
# Type alias for TTS options
TTSOptions = dict[str, str | bool | int | float]
class TTSProvider(ABC):
"""Abstract base class for TTS providers."""
@abstractmethod
async def generate_speech(self, text: str, **options: str | bool | float) -> bytes:
"""Generate speech from text with provider-specific options.
Args:
text: The text to convert to speech
**options: Provider-specific options
Returns:
Audio data as bytes
"""
@abstractmethod
def get_supported_languages(self) -> list[str]:
"""Return list of supported language codes."""
@abstractmethod
def get_option_schema(self) -> dict[str, dict[str, str | list[str] | bool]]:
"""Return schema for provider-specific options."""
@property
@abstractmethod
def name(self) -> str:
"""Return the provider name."""
@property
@abstractmethod
def file_extension(self) -> str:
"""Return the default file extension for this provider."""


@@ -0,0 +1,5 @@
"""TTS providers package."""
from .gtts import GTTSProvider
__all__ = ["GTTSProvider"]


@@ -0,0 +1,80 @@
"""Google Text-to-Speech provider."""
import asyncio
import io
from gtts import gTTS
from app.services.tts.base import TTSProvider
class GTTSProvider(TTSProvider):
"""Google Text-to-Speech provider implementation."""
@property
def name(self) -> str:
"""Return the provider name."""
return "gtts"
@property
def file_extension(self) -> str:
"""Return the default file extension for this provider."""
return "mp3"
async def generate_speech(self, text: str, **options: str | bool | float) -> bytes:
"""Generate speech from text using Google TTS.
Args:
text: The text to convert to speech
**options: GTTS-specific options (lang, tld, slow)
Returns:
MP3 audio data as bytes
"""
lang = options.get("lang", "en")
tld = options.get("tld", "com")
slow = options.get("slow", False)
# Run TTS generation in thread pool since gTTS is synchronous
def _generate() -> bytes:
tts = gTTS(text=text, lang=lang, tld=tld, slow=slow)
fp = io.BytesIO()
tts.write_to_fp(fp)
fp.seek(0)
return fp.read()
# Use asyncio.to_thread which is more reliable than run_in_executor
return await asyncio.to_thread(_generate)
def get_supported_languages(self) -> list[str]:
"""Return list of supported language codes."""
# Complete list of GTTS supported languages including regional variants
return [
"af", "ar", "bg", "bn", "bs", "ca", "cs", "cy", "da", "de", "el",
"en", "en-au", "en-ca", "en-gb", "en-ie", "en-in", "en-ng", "en-nz",
"en-ph", "en-za", "en-tz", "en-uk", "en-us",
"eo", "es", "es-es", "es-mx", "es-us", "et", "eu", "fa", "fi",
"fr", "fr-ca", "fr-fr", "ga", "gu", "he", "hi", "hr", "hu", "hy",
"id", "is", "it", "ja", "jw", "ka", "kk", "km", "kn", "ko", "la",
"lv", "mk", "ml", "mr", "ms", "mt", "my", "ne", "nl", "no", "pa",
"pl", "pt", "pt-br", "pt-pt", "ro", "ru", "si", "sk", "sl", "sq",
"sr", "su", "sv", "sw", "ta", "te", "th", "tl", "tr", "uk", "ur",
"vi", "yo", "zh", "zh-cn", "zh-tw", "zu",
]
def get_option_schema(self) -> dict[str, dict[str, str | list[str] | bool]]:
"""Return schema for GTTS-specific options."""
return {
"lang": {
"type": "string",
"default": "en",
"description": "Language code",
"enum": self.get_supported_languages(),
},
"slow": {
"type": "boolean",
"default": False,
"description": "Speak slowly",
},
}

app/services/tts/service.py

@@ -0,0 +1,555 @@
"""TTS service implementation."""
import asyncio
import io
import uuid
from pathlib import Path
from typing import Any
from gtts import gTTS
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_session_factory
from app.core.logging import get_logger
from app.models.sound import Sound
from app.models.tts import TTS
from app.repositories.sound import SoundRepository
from app.repositories.tts import TTSRepository
from app.services.socket import socket_manager
from app.services.sound_normalizer import SoundNormalizerService
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
from .base import TTSProvider
from .providers import GTTSProvider
# Constants
MAX_TEXT_LENGTH = 1000
MAX_NAME_LENGTH = 50
async def _get_tts_processor() -> object:
"""Get TTS processor instance, avoiding circular import."""
from app.services.tts_processor import tts_processor # noqa: PLC0415
return tts_processor
class TTSService:
"""Text-to-Speech service with provider management."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize TTS service.
Args:
session: Database session
"""
self.session = session
self.sound_repo = SoundRepository(session)
self.tts_repo = TTSRepository(session)
self.providers: dict[str, TTSProvider] = {}
# Register default providers
self._register_default_providers()
def _register_default_providers(self) -> None:
"""Register default TTS providers."""
self.register_provider(GTTSProvider())
def register_provider(self, provider: TTSProvider) -> None:
"""Register a TTS provider.
Args:
provider: TTS provider instance
"""
self.providers[provider.name] = provider
def get_providers(self) -> dict[str, TTSProvider]:
"""Get all registered providers."""
return self.providers.copy()
def get_provider(self, name: str) -> TTSProvider | None:
"""Get a specific provider by name."""
return self.providers.get(name)
async def create_tts_request(
self,
text: str,
user_id: int,
provider: str = "gtts",
**options: str | bool | float,
) -> dict[str, Any]:
"""Create a TTS request that will be processed in the background.
Args:
text: Text to convert to speech
user_id: ID of user creating the sound
provider: TTS provider name
**options: Provider-specific options
Returns:
Dictionary with TTS record information
Raises:
ValueError: If provider not found or text too long
Exception: If request creation fails
"""
provider_not_found_msg = f"Provider '{provider}' not found"
if provider not in self.providers:
raise ValueError(provider_not_found_msg)
text_too_long_msg = f"Text too long (max {MAX_TEXT_LENGTH} characters)"
if len(text) > MAX_TEXT_LENGTH:
raise ValueError(text_too_long_msg)
empty_text_msg = "Text cannot be empty"
if not text.strip():
raise ValueError(empty_text_msg)
# Create TTS record with pending status
tts = TTS(
text=text,
provider=provider,
options=options,
status="pending",
sound_id=None, # Will be set when processing completes
user_id=user_id,
)
self.session.add(tts)
await self.session.commit()
await self.session.refresh(tts)
# Queue for background processing using the TTS processor
if tts.id is not None:
tts_processor = await _get_tts_processor()
await tts_processor.queue_tts(tts.id)
return {"tts": tts, "message": "TTS generation queued successfully"}
async def _queue_tts_processing(self, tts_id: int) -> None:
"""Queue TTS for background processing."""
# For now, process immediately in a different way
# This could be moved to a proper background queue later
task = asyncio.create_task(self._process_tts_in_background(tts_id))
# Store reference to prevent garbage collection
self._background_tasks = getattr(self, "_background_tasks", set())
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
async def _process_tts_in_background(self, tts_id: int) -> None:
"""Process TTS generation in background."""
try:
# Create a new session for background processing
session_factory = get_session_factory()
async with session_factory() as background_session:
tts_service = TTSService(background_session)
# Get the TTS record
stmt = select(TTS).where(TTS.id == tts_id)
result = await background_session.exec(stmt)
tts = result.first()
if not tts:
return
# Use a synchronous approach for the actual generation
sound = await tts_service._generate_tts_sync(
tts.text,
tts.provider,
tts.user_id,
tts.options,
)
# Update the TTS record with the sound ID
if sound.id is not None:
tts.sound_id = sound.id
background_session.add(tts)
await background_session.commit()
except Exception:
# Log error but don't fail - avoiding print for production
logger = get_logger(__name__)
logger.exception("Error processing TTS generation %s", tts_id)
async def _generate_tts_sync(
self,
text: str,
provider: str,
user_id: int,
options: dict[str, Any],
) -> Sound:
"""Generate TTS using a synchronous approach."""
# Generate the audio using the provider
# (avoid async issues by doing it directly)
tts_provider = self.providers[provider]
# Create directories if they don't exist
original_dir = Path("sounds/originals/text_to_speech")
original_dir.mkdir(parents=True, exist_ok=True)
# Create UUID filename
sound_uuid = str(uuid.uuid4())
original_filename = f"{sound_uuid}.{tts_provider.file_extension}"
original_path = original_dir / original_filename
# Generate audio synchronously
try:
# Generate TTS audio
lang = options.get("lang", "en")
tld = options.get("tld", "com")
slow = options.get("slow", False)
tts_instance = gTTS(text=text, lang=lang, tld=tld, slow=slow)
fp = io.BytesIO()
tts_instance.write_to_fp(fp)
fp.seek(0)
audio_bytes = fp.read()
# Save the file
original_path.write_bytes(audio_bytes)
except Exception:
logger = get_logger(__name__)
logger.exception("Error generating TTS audio")
raise
# Create Sound record with proper metadata
sound = await self._create_sound_record_complete(
original_path,
text,
user_id,
)
# Normalize the sound
if sound.id is not None:
await self._normalize_sound_safe(sound.id)
return sound
async def get_user_tts_history(
self,
user_id: int,
limit: int = 50,
offset: int = 0,
) -> list[TTS]:
"""Get TTS history for a user.
Args:
user_id: User ID
limit: Maximum number of records
offset: Offset for pagination
Returns:
List of TTS records
"""
result = await self.tts_repo.get_by_user_id(user_id, limit, offset)
return list(result)
async def _create_sound_record(
self,
audio_path: Path,
text: str,
user_id: int,
file_hash: str,
) -> Sound:
"""Create a Sound record for the TTS audio."""
# Get audio metadata
duration = get_audio_duration(audio_path)
size = get_file_size(audio_path)
name = text[:MAX_NAME_LENGTH] + ("..." if len(text) > MAX_NAME_LENGTH else "")
name = " ".join(word.capitalize() for word in name.split())
# Create sound data
sound_data = {
"type": "TTS",
"name": name,
"filename": audio_path.name,
"duration": duration,
"size": size,
"hash": file_hash,
"user_id": user_id,
"is_deletable": True,
"is_music": False, # TTS is speech, not music
"is_normalized": False,
"play_count": 0,
}
return await self.sound_repo.create(sound_data)
async def _create_sound_record_simple(
self,
audio_path: Path,
text: str,
user_id: int,
) -> Sound:
"""Create a Sound record for the TTS audio with minimal processing."""
# Create sound data with basic info
name = text[:MAX_NAME_LENGTH] + ("..." if len(text) > MAX_NAME_LENGTH else "")
name = " ".join(word.capitalize() for word in name.split())
sound_data = {
"type": "TTS",
"name": name,
"filename": audio_path.name,
"duration": 0, # Skip duration calculation for now
"size": 0, # Skip size calculation for now
"hash": str(uuid.uuid4()), # Use UUID as temporary hash
"user_id": user_id,
"is_deletable": True,
"is_music": False, # TTS is speech, not music
"is_normalized": False,
"play_count": 0,
}
return await self.sound_repo.create(sound_data)
async def _create_sound_record_complete(
self,
audio_path: Path,
text: str,
user_id: int,
) -> Sound:
"""Create a Sound record for the TTS audio with complete metadata."""
# Get audio metadata
duration = get_audio_duration(audio_path)
size = get_file_size(audio_path)
file_hash = get_file_hash(audio_path)
name = text[:MAX_NAME_LENGTH] + ("..." if len(text) > MAX_NAME_LENGTH else "")
name = " ".join(word.capitalize() for word in name.split())
# Check if a sound with this hash already exists
existing_sound = await self.sound_repo.get_by_hash(file_hash)
if existing_sound:
# Clean up the temporary file since we have a duplicate
if audio_path.exists():
audio_path.unlink()
return existing_sound
# Create sound data with complete metadata
sound_data = {
"type": "TTS",
"name": name,
"filename": audio_path.name,
"duration": duration,
"size": size,
"hash": file_hash,
"user_id": user_id,
"is_deletable": True,
"is_music": False, # TTS is speech, not music
"is_normalized": False,
"play_count": 0,
}
return await self.sound_repo.create(sound_data)
async def _normalize_sound_safe(self, sound_id: int) -> None:
"""Normalize the TTS sound with error handling."""
try:
# Get fresh sound object from database for normalization
sound = await self.sound_repo.get_by_id(sound_id)
if not sound:
return
normalizer_service = SoundNormalizerService(self.session)
result = await normalizer_service.normalize_sound(sound)
if result["status"] == "error":
logger = get_logger(__name__)
logger.warning(
                "Failed to normalize TTS sound %s: %s",
sound_id,
result.get("error"),
)
except Exception:
logger = get_logger(__name__)
logger.exception("Exception during TTS sound normalization %s", sound_id)
# Don't fail the TTS generation if normalization fails
async def _normalize_sound(self, sound_id: int) -> None:
"""Normalize the TTS sound."""
try:
# Get fresh sound object from database for normalization
sound = await self.sound_repo.get_by_id(sound_id)
if not sound:
return
normalizer_service = SoundNormalizerService(self.session)
result = await normalizer_service.normalize_sound(sound)
            if result["status"] == "error":
                # Normalization failed; intentionally ignored so TTS generation succeeds
                pass
except Exception:
# Don't fail the TTS generation if normalization fails
logger = get_logger(__name__)
logger.exception("Error normalizing sound %s", sound_id)
async def delete_tts(self, tts_id: int, user_id: int) -> None:
"""Delete a TTS generation and its associated sound and files."""
# Get the TTS record
tts = await self.tts_repo.get_by_id(tts_id)
if not tts:
tts_not_found_msg = f"TTS with ID {tts_id} not found"
raise ValueError(tts_not_found_msg)
# Check ownership
if tts.user_id != user_id:
permission_error_msg = (
"You don't have permission to delete this TTS generation"
)
raise PermissionError(permission_error_msg)
# If there's an associated sound, delete it and its files
if tts.sound_id:
sound = await self.sound_repo.get_by_id(tts.sound_id)
if sound:
# Delete the sound files
await self._delete_sound_files(sound)
# Delete the sound record
await self.sound_repo.delete(sound)
# Delete the TTS record
await self.tts_repo.delete(tts)
async def _delete_sound_files(self, sound: Sound) -> None:
"""Delete all files associated with a sound."""
# Delete original file
original_path = Path("sounds/originals/text_to_speech") / sound.filename
if original_path.exists():
original_path.unlink()
# Delete normalized file if it exists
if sound.normalized_filename:
normalized_path = (
Path("sounds/normalized/text_to_speech") / sound.normalized_filename
)
if normalized_path.exists():
normalized_path.unlink()
async def get_pending_tts(self) -> list[TTS]:
"""Get all pending TTS generations."""
stmt = select(TTS).where(TTS.status == "pending").order_by(TTS.created_at)
result = await self.session.exec(stmt)
return list(result.all())
async def mark_tts_processing(self, tts_id: int) -> None:
"""Mark a TTS generation as processing."""
stmt = select(TTS).where(TTS.id == tts_id)
result = await self.session.exec(stmt)
tts = result.first()
if tts:
tts.status = "processing"
self.session.add(tts)
await self.session.commit()
async def mark_tts_completed(self, tts_id: int, sound_id: int) -> None:
"""Mark a TTS generation as completed."""
stmt = select(TTS).where(TTS.id == tts_id)
result = await self.session.exec(stmt)
tts = result.first()
if tts:
tts.status = "completed"
tts.sound_id = sound_id
tts.error = None
self.session.add(tts)
await self.session.commit()
async def mark_tts_failed(self, tts_id: int, error_message: str) -> None:
"""Mark a TTS generation as failed."""
stmt = select(TTS).where(TTS.id == tts_id)
result = await self.session.exec(stmt)
tts = result.first()
if tts:
tts.status = "failed"
tts.error = error_message
self.session.add(tts)
await self.session.commit()
async def reset_stuck_tts(self) -> int:
"""Reset stuck TTS generations from processing back to pending."""
stmt = select(TTS).where(TTS.status == "processing")
result = await self.session.exec(stmt)
stuck_tts = list(result.all())
for tts in stuck_tts:
tts.status = "pending"
tts.error = None
self.session.add(tts)
await self.session.commit()
return len(stuck_tts)
async def process_tts_generation(self, tts_id: int) -> None:
"""Process a TTS generation (used by the processor)."""
# Mark as processing
await self.mark_tts_processing(tts_id)
try:
# Get the TTS record
stmt = select(TTS).where(TTS.id == tts_id)
result = await self.session.exec(stmt)
tts = result.first()
if not tts:
tts_not_found_msg = f"TTS with ID {tts_id} not found"
raise ValueError(tts_not_found_msg)
# Generate the TTS
sound = await self._generate_tts_sync(
tts.text,
tts.provider,
tts.user_id,
tts.options,
)
# Capture sound ID before session issues
sound_id = sound.id
if sound_id is None:
sound_creation_error = "Sound creation failed - no ID assigned"
raise ValueError(sound_creation_error)
# Mark as completed
await self.mark_tts_completed(tts_id, sound_id)
# Emit socket event for completion
await self._emit_tts_event("tts_completed", tts_id, sound_id)
except Exception as e:
# Mark as failed
await self.mark_tts_failed(tts_id, str(e))
# Emit socket event for failure
await self._emit_tts_event("tts_failed", tts_id, None, str(e))
raise
async def _emit_tts_event(
self,
event: str,
tts_id: int,
sound_id: int | None = None,
error: str | None = None,
) -> None:
"""Emit a socket event for TTS status change."""
try:
logger = get_logger(__name__)
data = {
"tts_id": tts_id,
"sound_id": sound_id,
}
if error:
data["error"] = error
logger.info("Emitting TTS socket event: %s with data: %s", event, data)
await socket_manager.broadcast_to_all(event, data)
logger.info("Successfully emitted TTS socket event: %s", event)
except Exception:
# Don't fail TTS processing if socket emission fails
logger = get_logger(__name__)
logger.exception("Failed to emit TTS socket event %s", event)
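The display-name derivation repeated in the `_create_sound_record*` helpers above (truncate to `MAX_NAME_LENGTH`, append an ellipsis, title-case each word) can be distilled into a standalone function. A minimal sketch; the `tts_display_name` name is ours, not the service's:

```python
MAX_NAME_LENGTH = 50


def tts_display_name(text: str) -> str:
    """Build a Sound display name from TTS text: truncate, ellipsize, title-case."""
    name = text[:MAX_NAME_LENGTH] + ("..." if len(text) > MAX_NAME_LENGTH else "")
    return " ".join(word.capitalize() for word in name.split())


# tts_display_name("hello world") → "Hello World"
```

Note the ellipsis is appended before title-casing, so it fuses onto the final word of a truncated name.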

app/services/tts_processor.py Normal file

@@ -0,0 +1,193 @@
"""Background TTS processor for handling TTS generation queue."""
import asyncio
import contextlib
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings
from app.core.database import engine
from app.core.logging import get_logger
from app.services.tts import TTSService
logger = get_logger(__name__)
class TTSProcessor:
"""Background processor for handling TTS generation queue."""
def __init__(self) -> None:
"""Initialize the TTS processor."""
self.max_concurrent = getattr(settings, "TTS_MAX_CONCURRENT", 3)
self.running_tts: set[int] = set()
self.processing_lock = asyncio.Lock()
self.shutdown_event = asyncio.Event()
self.processor_task: asyncio.Task | None = None
logger.info(
"Initialized TTS processor with max concurrent: %d",
self.max_concurrent,
)
async def start(self) -> None:
"""Start the background TTS processor."""
if self.processor_task and not self.processor_task.done():
logger.warning("TTS processor is already running")
return
# Reset any stuck TTS generations from previous runs
await self._reset_stuck_tts()
self.shutdown_event.clear()
self.processor_task = asyncio.create_task(self._process_queue())
logger.info("Started TTS processor")
async def stop(self) -> None:
"""Stop the background TTS processor."""
logger.info("Stopping TTS processor...")
self.shutdown_event.set()
if self.processor_task and not self.processor_task.done():
try:
await asyncio.wait_for(self.processor_task, timeout=30.0)
except TimeoutError:
logger.warning(
"TTS processor did not stop gracefully, cancelling...",
)
self.processor_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self.processor_task
logger.info("TTS processor stopped")
async def queue_tts(self, tts_id: int) -> None:
"""Queue a TTS generation for processing."""
async with self.processing_lock:
if tts_id not in self.running_tts:
logger.info("Queued TTS %d for processing", tts_id)
# The processor will pick it up on the next cycle
else:
logger.warning(
"TTS %d is already being processed",
tts_id,
)
async def _process_queue(self) -> None:
"""Process the TTS queue in the main processing loop."""
logger.info("Starting TTS queue processor")
while not self.shutdown_event.is_set():
try:
await self._process_pending_tts()
# Wait before checking for new TTS generations
try:
await asyncio.wait_for(self.shutdown_event.wait(), timeout=5.0)
break # Shutdown requested
except TimeoutError:
continue # Continue processing
except Exception:
logger.exception("Error in TTS queue processor")
# Wait a bit before retrying to avoid tight error loops
try:
await asyncio.wait_for(self.shutdown_event.wait(), timeout=10.0)
break # Shutdown requested
except TimeoutError:
continue
logger.info("TTS queue processor stopped")
async def _process_pending_tts(self) -> None:
"""Process pending TTS generations up to the concurrency limit."""
async with self.processing_lock:
# Check how many slots are available
available_slots = self.max_concurrent - len(self.running_tts)
if available_slots <= 0:
return # No available slots
# Get pending TTS generations from database
async with AsyncSession(engine) as session:
tts_service = TTSService(session)
pending_tts = await tts_service.get_pending_tts()
# Filter out TTS that are already being processed
available_tts = [
tts
for tts in pending_tts
if tts.id not in self.running_tts
]
# Start processing up to available slots
tts_to_start = available_tts[:available_slots]
for tts in tts_to_start:
tts_id = tts.id
self.running_tts.add(tts_id)
# Start processing this TTS in the background
task = asyncio.create_task(
self._process_single_tts(tts_id),
)
task.add_done_callback(
lambda t, tid=tts_id: self._on_tts_completed(
tid,
t,
),
)
logger.info(
"Started processing TTS %d (%d/%d slots used)",
tts_id,
len(self.running_tts),
self.max_concurrent,
)
async def _process_single_tts(self, tts_id: int) -> None:
"""Process a single TTS generation."""
try:
async with AsyncSession(engine) as session:
tts_service = TTSService(session)
await tts_service.process_tts_generation(tts_id)
logger.info("Successfully processed TTS %d", tts_id)
except Exception:
logger.exception("Failed to process TTS %d", tts_id)
# Mark TTS as failed in database
try:
async with AsyncSession(engine) as session:
tts_service = TTSService(session)
await tts_service.mark_tts_failed(tts_id, "Processing failed")
except Exception:
logger.exception("Failed to mark TTS %d as failed", tts_id)
def _on_tts_completed(self, tts_id: int, task: asyncio.Task) -> None:
"""Handle completion of a TTS processing task."""
self.running_tts.discard(tts_id)
if task.exception():
logger.error(
"TTS processing task %d failed: %s",
tts_id,
task.exception(),
)
else:
logger.info("TTS processing task %d completed successfully", tts_id)
async def _reset_stuck_tts(self) -> None:
"""Reset any TTS generations that were stuck in 'processing' state."""
try:
async with AsyncSession(engine) as session:
tts_service = TTSService(session)
reset_count = await tts_service.reset_stuck_tts()
if reset_count > 0:
logger.info("Reset %d stuck TTS generations", reset_count)
else:
logger.info("No stuck TTS generations found to reset")
except Exception:
logger.exception("Failed to reset stuck TTS generations")
# Global TTS processor instance
tts_processor = TTSProcessor()
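The processor's wait-or-shutdown pattern (an `asyncio.Event` awaited through `wait_for` with a timeout, so one call doubles as both the poll sleep and the shutdown check) can be exercised in isolation. A sketch with hypothetical names:

```python
import asyncio


async def poll_loop(shutdown: asyncio.Event, work: list[int], interval: float = 0.01) -> int:
    """Drain `work` until `shutdown` is set, pausing `interval` between polls.

    Mirrors TTSProcessor._process_queue: wait_for(shutdown.wait(), timeout=...)
    returns early on shutdown and raises TimeoutError to continue polling.
    """
    processed = 0
    while not shutdown.is_set():
        while work:
            work.pop(0)
            processed += 1
        try:
            await asyncio.wait_for(shutdown.wait(), timeout=interval)
            break  # shutdown requested
        except asyncio.TimeoutError:
            continue  # timeout elapsed; poll again
    return processed


async def main() -> int:
    shutdown = asyncio.Event()
    task = asyncio.create_task(poll_loop(shutdown, [1, 2, 3]))
    await asyncio.sleep(0.05)
    shutdown.set()
    return await task
```

Running `asyncio.run(main())` drains all three queued items before the shutdown flag stops the loop.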


@@ -6,10 +6,10 @@ from collections.abc import Callable
from pathlib import Path
from typing import Any
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.credit_action import CreditActionType
from app.models.sound import Sound
from app.models.sound_played import SoundPlayed
from app.repositories.sound import SoundRepository
@@ -24,7 +24,8 @@ class VLCPlayerService:
"""Service for launching VLC instances via subprocess to play sounds."""
def __init__(
self, db_session_factory: Callable[[], AsyncSession] | None = None,
self,
db_session_factory: Callable[[], AsyncSession] | None = None,
) -> None:
"""Initialize the VLC player service."""
self.vlc_executable = self._find_vlc_executable()
@@ -53,7 +54,7 @@ class VLCPlayerService:
# For "vlc", try to find it in PATH
if path == "vlc":
result = subprocess.run(
["which", "vlc"],
["which", "vlc"], # noqa: S607
capture_output=True,
check=False,
text=True,
@@ -72,6 +73,9 @@ class VLCPlayerService:
async def play_sound(self, sound: Sound) -> bool:
"""Play a sound using a new VLC subprocess instance.
VLC always plays at 100% volume. Host system volume is controlled separately
by the player service.
Args:
sound: The Sound object to play
@@ -96,6 +100,7 @@ class VLCPlayerService:
"--no-video", # Audio only
"--no-repeat", # Don't repeat
"--no-loop", # Don't loop
"--volume=100", # Always use 100% VLC volume
]
# Launch VLC process asynchronously without waiting
@@ -113,13 +118,19 @@ class VLCPlayerService:
            # Record play count and emit event
            if self.db_session_factory and sound.id:
                task = asyncio.create_task(
                    self._record_play_count(sound.id, sound.name),
                )
                # Store reference to prevent garbage collection
                self._background_tasks = getattr(self, "_background_tasks", set())
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)
        except Exception:
            logger.exception("Failed to launch VLC for sound %s", sound.name)
            return False
        else:
            return True
async def stop_all_vlc_instances(self) -> dict[str, Any]:
"""Stop all running VLC processes by killing them.
@@ -137,7 +148,7 @@ class VLCPlayerService:
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await find_process.communicate()
stdout, _stderr = await find_process.communicate()
if find_process.returncode != 0:
# No VLC processes found
@@ -231,72 +242,164 @@ class VLCPlayerService:
return
logger.info("Recording play count for sound %s", sound_id)
        # Initialize variables for WebSocket event
        old_count = 0
        sound = None
        admin_user_id = None
        admin_user_name = None

        try:
            async with self.db_session_factory() as session:
                sound_repo = SoundRepository(session)
                user_repo = UserRepository(session)

                # Update sound play count
                sound = await sound_repo.get_by_id(sound_id)
                if sound:
                    old_count = sound.play_count
                    # Update the sound's play count using direct attribute modification
                    sound.play_count = sound.play_count + 1
                    session.add(sound)
                    await session.commit()
                    await session.refresh(sound)
                    logger.info(
                        "Updated sound %s play_count: %s -> %s",
                        sound_id,
                        old_count,
                        old_count + 1,
                    )
                else:
                    logger.warning("Sound %s not found for play count update", sound_id)

                # Record play history for admin user (ID 1) as placeholder
                # This could be refined to track per-user play history
                admin_user = await user_repo.get_by_id(1)
                if admin_user:
                    admin_user_id = admin_user.id
                    admin_user_name = admin_user.name

                # Always create a new SoundPlayed record for each play event
                sound_played = SoundPlayed(
                    user_id=admin_user_id,  # Can be None for player-based plays
                    sound_id=sound_id,
                )
                session.add(sound_played)
                logger.info(
                    "Created SoundPlayed record for user %s, sound %s",
                    admin_user_id,
                    sound_id,
                )

                await session.commit()
                logger.info("Successfully recorded play count for sound %s", sound_id)
        except Exception:
            logger.exception("Error recording play count for sound %s", sound_id)

        # Emit sound_played event via WebSocket (outside session context)
        try:
            event_data = {
                "sound_id": sound_id,
                "sound_name": sound_name,
                "user_id": admin_user_id,
                "user_name": admin_user_name,
                "play_count": (old_count + 1) if sound else None,
            }
            await socket_manager.broadcast_to_all("sound_played", event_data)
            logger.info("Broadcasted sound_played event for sound %s", sound_id)
        except Exception:
            logger.exception(
                "Failed to broadcast sound_played event for sound %s",
                sound_id,
            )
async def play_sound_with_credits(
self,
sound_id: int,
user_id: int,
) -> dict[str, str | int | bool]:
"""Play sound with VLC with credit validation and deduction.
        This method combines credit checking, sound playing, and credit deduction
        in a single operation. Used by both HTTP and WebSocket endpoints.

        Args:
            sound_id: ID of the sound to play
            user_id: ID of the user playing the sound

        Returns:
            dict: Result information including success status and message

        Raises:
            HTTPException: For various error conditions (sound not found,
                insufficient credits, VLC failure)
        """
from fastapi import HTTPException, status # noqa: PLC0415, I001
from app.services.credit import ( # noqa: PLC0415
CreditService,
InsufficientCreditsError,
)
if not self.db_session_factory:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Database session factory not configured",
)
async with self.db_session_factory() as session:
sound_repo = SoundRepository(session)
# Get the sound
sound = await sound_repo.get_by_id(sound_id)
if not sound:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Sound with ID {sound_id} not found",
)
# Get credit service
credit_service = CreditService(self.db_session_factory)
# Check and validate credits before playing
try:
await credit_service.validate_and_reserve_credits(
user_id,
CreditActionType.VLC_PLAY_SOUND,
)
except InsufficientCreditsError as e:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=(
f"Insufficient credits: {e.required} required, "
f"{e.available} available"
),
) from e
# Play the sound using VLC (always at 100% VLC volume)
success = await self.play_sound(sound)
# Deduct credits based on success
await credit_service.deduct_credits(
user_id,
CreditActionType.VLC_PLAY_SOUND,
success=success,
metadata={"sound_id": sound_id, "sound_name": sound.name},
)
if not success:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to launch VLC for sound playback",
)
return {
"message": f"Sound '{sound.name}' is now playing via VLC",
"sound_id": sound_id,
"sound_name": sound.name,
"success": True,
"credits_deducted": 1,
}
# Global VLC player service instance

app/services/volume.py Normal file

@@ -0,0 +1,251 @@
"""Volume service for host system volume control."""
import platform
from app.core.logging import get_logger
logger = get_logger(__name__)
# Constants
MIN_VOLUME = 0
MAX_VOLUME = 100
class VolumeService:
"""Service for controlling host system volume."""
def __init__(self) -> None:
"""Initialize volume service."""
self._system = platform.system().lower()
self._pycaw_available = False
self._pulsectl_available = False
# Try to import Windows volume control
if self._system == "windows":
try:
from comtypes import ( # noqa: PLC0415
CLSCTX_ALL, # type: ignore[import-untyped]
)
from pycaw.pycaw import ( # type: ignore[import-untyped] # noqa: PLC0415
AudioUtilities,
IAudioEndpointVolume,
)
self._AudioUtilities = AudioUtilities
self._IAudioEndpointVolume = IAudioEndpointVolume
self._CLSCTX_ALL = CLSCTX_ALL
self._pycaw_available = True
logger.info("Windows volume control (pycaw) initialized")
except ImportError as e:
logger.warning("pycaw not available: %s", e)
# Try to import Linux volume control
elif self._system == "linux":
try:
import pulsectl # type: ignore[import-untyped] # noqa: PLC0415
self._pulsectl = pulsectl
self._pulsectl_available = True
logger.info("Linux volume control (pulsectl) initialized")
except ImportError as e:
logger.warning("pulsectl not available: %s", e)
else:
logger.warning("Volume control not supported on %s", self._system)
def get_volume(self) -> int | None:
"""Get the current system volume as a percentage (0-100).
Returns:
Volume level as percentage, or None if not available
"""
try:
if self._system == "windows" and self._pycaw_available:
return self._get_windows_volume()
if self._system == "linux" and self._pulsectl_available:
return self._get_linux_volume()
except Exception:
logger.exception("Failed to get volume")
return None
else:
logger.warning("Volume control not available for this system")
return None
def set_volume(self, volume: int) -> bool:
"""Set the system volume to a percentage (0-100).
Args:
volume: Volume level as percentage (0-100)
Returns:
True if successful, False otherwise
"""
if not (MIN_VOLUME <= volume <= MAX_VOLUME):
logger.error(
"Volume must be between %s and %s, got %s",
MIN_VOLUME,
MAX_VOLUME,
volume,
)
return False
try:
if self._system == "windows" and self._pycaw_available:
return self._set_windows_volume(volume)
if self._system == "linux" and self._pulsectl_available:
return self._set_linux_volume(volume)
except Exception:
logger.exception("Failed to set volume")
return False
else:
logger.warning("Volume control not available for this system")
return False
def is_muted(self) -> bool | None:
"""Check if the system is muted.
Returns:
True if muted, False if not muted, None if not available
"""
try:
if self._system == "windows" and self._pycaw_available:
return self._get_windows_mute_status()
if self._system == "linux" and self._pulsectl_available:
return self._get_linux_mute_status()
except Exception:
logger.exception("Failed to get mute status")
return None
else:
logger.warning("Mute status not available for this system")
return None
def set_mute(self, *, muted: bool) -> bool:
"""Set the system mute status.
Args:
muted: True to mute, False to unmute
Returns:
True if successful, False otherwise
"""
try:
if self._system == "windows" and self._pycaw_available:
return self._set_windows_mute(muted=muted)
if self._system == "linux" and self._pulsectl_available:
return self._set_linux_mute(muted=muted)
except Exception:
logger.exception("Failed to set mute status")
return False
else:
logger.warning("Mute control not available for this system")
return False
def _get_windows_volume(self) -> int:
"""Get Windows volume using pycaw."""
devices = self._AudioUtilities.GetSpeakers()
interface = devices.Activate(
self._IAudioEndpointVolume._iid_, self._CLSCTX_ALL, None,
)
volume = interface.QueryInterface(self._IAudioEndpointVolume)
current_volume = volume.GetMasterVolume()
# Convert from scalar (0.0-1.0) to percentage (0-100)
return int(current_volume * MAX_VOLUME)
def _set_windows_volume(self, volume_percent: int) -> bool:
"""Set Windows volume using pycaw."""
devices = self._AudioUtilities.GetSpeakers()
interface = devices.Activate(
self._IAudioEndpointVolume._iid_, self._CLSCTX_ALL, None,
)
volume = interface.QueryInterface(self._IAudioEndpointVolume)
# Convert from percentage (0-100) to scalar (0.0-1.0)
volume_scalar = volume_percent / MAX_VOLUME
volume.SetMasterVolume(volume_scalar, None)
logger.info("Windows volume set to %s%%", volume_percent)
return True
def _get_windows_mute_status(self) -> bool:
"""Get Windows mute status using pycaw."""
devices = self._AudioUtilities.GetSpeakers()
interface = devices.Activate(
self._IAudioEndpointVolume._iid_, self._CLSCTX_ALL, None,
)
volume = interface.QueryInterface(self._IAudioEndpointVolume)
return bool(volume.GetMute())
def _set_windows_mute(self, *, muted: bool) -> bool:
"""Set Windows mute status using pycaw."""
devices = self._AudioUtilities.GetSpeakers()
interface = devices.Activate(
self._IAudioEndpointVolume._iid_, self._CLSCTX_ALL, None,
)
volume = interface.QueryInterface(self._IAudioEndpointVolume)
volume.SetMute(muted, None)
logger.info("Windows mute set to %s", muted)
return True
def _get_linux_volume(self) -> int:
"""Get Linux volume using pulsectl."""
with self._pulsectl.Pulse("volume-service") as pulse:
# Get the default sink (output device)
default_sink = pulse.get_sink_by_name(pulse.server_info().default_sink_name)
if default_sink is None:
logger.error("No default audio sink found")
return MIN_VOLUME
# Get volume as percentage (PulseAudio uses 0.0-1.0, we convert to 0-100)
volume = default_sink.volume
avg_volume = sum(volume.values) / len(volume.values)
return int(avg_volume * MAX_VOLUME)
def _set_linux_volume(self, volume_percent: int) -> bool:
"""Set Linux volume using pulsectl."""
with self._pulsectl.Pulse("volume-service") as pulse:
# Get the default sink (output device)
default_sink = pulse.get_sink_by_name(pulse.server_info().default_sink_name)
if default_sink is None:
logger.error("No default audio sink found")
return False
# Convert percentage to PulseAudio volume (0.0-1.0)
volume_scalar = volume_percent / MAX_VOLUME
# Set volume for all channels
pulse.volume_set_all_chans(default_sink, volume_scalar)
logger.info("Linux volume set to %s%%", volume_percent)
return True
def _get_linux_mute_status(self) -> bool:
"""Get Linux mute status using pulsectl."""
with self._pulsectl.Pulse("volume-service") as pulse:
# Get the default sink (output device)
default_sink = pulse.get_sink_by_name(pulse.server_info().default_sink_name)
if default_sink is None:
logger.error("No default audio sink found")
return False
return bool(default_sink.mute)
def _set_linux_mute(self, *, muted: bool) -> bool:
"""Set Linux mute status using pulsectl."""
with self._pulsectl.Pulse("volume-service") as pulse:
# Get the default sink (output device)
default_sink = pulse.get_sink_by_name(pulse.server_info().default_sink_name)
if default_sink is None:
logger.error("No default audio sink found")
return False
# Set mute status
pulse.mute(default_sink, muted)
logger.info("Linux mute set to %s", muted)
return True
# Global volume service instance
volume_service = VolumeService()
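The Linux handlers above convert between PulseAudio's 0.0–1.0 per-channel scalars and the service's 0–100 percent scale. A minimal, pure sketch of that conversion (the pulsectl calls themselves are omitted; `MIN_VOLUME`/`MAX_VOLUME` are assumed to be 0 and 100, and the clamping is an illustrative safeguard not shown in the diff):

```python
# Hedged sketch of the percent <-> PulseAudio-scalar math used by
# _get_linux_volume / _set_linux_volume. MIN_VOLUME/MAX_VOLUME values
# are assumptions; clamping is added here for illustration only.
MIN_VOLUME = 0
MAX_VOLUME = 100


def percent_to_scalar(volume_percent: int) -> float:
    """Convert a 0-100 percentage to PulseAudio's 0.0-1.0 scalar."""
    clamped = max(MIN_VOLUME, min(MAX_VOLUME, volume_percent))
    return clamped / MAX_VOLUME


def scalar_to_percent(channel_values: list[float]) -> int:
    """Average per-channel scalars (as in sink.volume.values) to a percent."""
    if not channel_values:
        return MIN_VOLUME
    avg = sum(channel_values) / len(channel_values)
    return int(avg * MAX_VOLUME)


print(percent_to_scalar(75))          # 0.75
print(scalar_to_percent([0.5, 0.7]))  # 60
```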


@@ -34,7 +34,7 @@ def get_audio_duration(file_path: Path) -> int:
probe = ffmpeg.probe(str(file_path))
duration = float(probe["format"]["duration"])
return int(duration * 1000) # Convert to milliseconds
except Exception as e:
except (ffmpeg.Error, KeyError, ValueError, TypeError, Exception) as e: # noqa: BLE001
logger.warning("Failed to get duration for %s: %s", file_path, e)
return 0
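The hunk above narrows `get_audio_duration`'s error handling around the probe-parsing step. The parsing itself can be sketched with the `ffmpeg.probe` result replaced by an in-memory dict, so the zero fallback is visible without the ffmpeg binary:

```python
# Hedged sketch of get_audio_duration's parsing step; the function name
# is illustrative and the probe dict stands in for ffmpeg.probe output.
def duration_ms_from_probe(probe: dict) -> int:
    try:
        return int(float(probe["format"]["duration"]) * 1000)
    except (KeyError, ValueError, TypeError):
        # Mirrors the diff's fallback: unparseable probe data yields 0.
        return 0


print(duration_ms_from_probe({"format": {"duration": "2.5"}}))  # 2500
print(duration_ms_from_probe({"format": {}}))                   # 0
```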


@@ -1,16 +1,20 @@
"""Cookie parsing utilities for WebSocket authentication."""
"""Cookie parsing and setting utilities for WebSocket and HTTP authentication."""
from fastapi import Response
from app.core.config import settings
def parse_cookies(cookie_header: str) -> dict[str, str]:
"""Parse HTTP cookie header into a dictionary."""
cookies = {}
cookies: dict[str, str] = {}
if not cookie_header:
return cookies
for cookie in cookie_header.split(";"):
cookie = cookie.strip()
if "=" in cookie:
name, value = cookie.split("=", 1)
for cookie_part in cookie_header.split(";"):
cookie_str = cookie_part.strip()
if "=" in cookie_str:
name, value = cookie_str.split("=", 1)
cookies[name.strip()] = value.strip()
return cookies
@@ -20,3 +24,52 @@ def extract_access_token_from_cookies(cookie_header: str) -> str | None:
"""Extract access token from HTTP cookies."""
cookies = parse_cookies(cookie_header)
return cookies.get("access_token")
def set_access_token_cookie(
response: Response,
access_token: str,
expires_in: int,
path: str = "/",
) -> None:
"""Set access token cookie with consistent configuration."""
response.set_cookie(
key="access_token",
value=access_token,
max_age=expires_in,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain=settings.COOKIE_DOMAIN,
path=path,
)
def set_refresh_token_cookie(
response: Response,
refresh_token: str,
path: str = "/",
) -> None:
"""Set refresh token cookie with consistent configuration."""
response.set_cookie(
key="refresh_token",
value=refresh_token,
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain=settings.COOKIE_DOMAIN,
path=path,
)
def set_auth_cookies(
response: Response,
access_token: str,
refresh_token: str,
expires_in: int,
path: str = "/",
) -> None:
"""Set both access and refresh token cookies with consistent configuration."""
set_access_token_cookie(response, access_token, expires_in, path)
set_refresh_token_cookie(response, refresh_token, path)
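For reference, the renamed parsing loop from this file behaves as follows; this is a self-contained copy of `parse_cookies` with a usage example (the header value is made up):

```python
# Self-contained copy of parse_cookies as it appears after this diff,
# shown with an example Cookie header.
def parse_cookies(cookie_header: str) -> dict[str, str]:
    cookies: dict[str, str] = {}
    if not cookie_header:
        return cookies
    for cookie_part in cookie_header.split(";"):
        cookie_str = cookie_part.strip()
        if "=" in cookie_str:
            name, value = cookie_str.split("=", 1)
            cookies[name.strip()] = value.strip()
    return cookies


header = "access_token=abc123; refresh_token=def456; theme=dark"
print(parse_cookies(header)["access_token"])  # abc123
```

Note that `split("=", 1)` keeps any `=` characters inside the cookie value intact.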


@@ -1,11 +1,13 @@
"""Decorators for credit management and validation."""
import functools
import inspect
import types
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
from app.models.credit_action import CreditActionType
from app.services.credit import CreditService, InsufficientCreditsError
from app.services.credit import CreditService
F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
@@ -16,7 +18,7 @@ def requires_credits(
user_id_param: str = "user_id",
metadata_extractor: Callable[..., dict[str, Any]] | None = None,
) -> Callable[[F], F]:
"""Decorator to enforce credit requirements for actions.
"""Enforce credit requirements for actions.
Args:
action_type: The type of action that requires credits
@@ -38,16 +40,16 @@ def requires_credits(
return True
"""
def decorator(func: F) -> F:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
# Extract user ID from parameters
user_id = None
if user_id_param in kwargs:
user_id = kwargs[user_id_param]
else:
# Try to find user_id in function signature
import inspect
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
if user_id_param in param_names:
@@ -69,26 +71,31 @@ def requires_credits(
# Validate credits before execution
await credit_service.validate_and_reserve_credits(
user_id, action_type, metadata
user_id,
action_type,
)
# Execute the function
success = False
result = None
try:
result = await func(*args, **kwargs)
success = bool(result) # Consider function result as success indicator
return result
except Exception:
success = False
raise
else:
return result
finally:
# Deduct credits based on success
await credit_service.deduct_credits(
user_id, action_type, success, metadata
user_id,
action_type,
success=success,
metadata=metadata,
)
return wrapper # type: ignore[return-value]
return decorator
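The `requires_credits` flow above (validate, run the wrapped coroutine, deduct based on its truthiness) can be sketched end to end with a stub credit service. This is a simplified sketch, not the repo's implementation: it only handles a keyword `user_id` and omits the signature-inspection fallback and metadata extraction.

```python
# Simplified sketch of the requires_credits decorator flow; the stub
# service, action name, and keyword-only user_id handling are
# illustrative assumptions.
import asyncio
import functools


class StubCreditService:
    def __init__(self) -> None:
        self.calls = []

    async def validate_and_reserve_credits(self, user_id, action_type):
        self.calls.append(("validate", user_id))

    async def deduct_credits(self, user_id, action_type, *, success, metadata):
        self.calls.append(("deduct", user_id, success))


def requires_credits(action_type, credit_service_factory, user_id_param="user_id"):
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            user_id = kwargs[user_id_param]
            credit_service = credit_service_factory()
            # Validate before execution, deduct afterwards in all cases.
            await credit_service.validate_and_reserve_credits(user_id, action_type)
            success = False
            try:
                result = await func(*args, **kwargs)
                success = bool(result)  # function result drives the deduction
            finally:
                await credit_service.deduct_credits(
                    user_id, action_type, success=success, metadata=None,
                )
            return result
        return wrapper
    return decorator


service = StubCreditService()


@requires_credits("play_sound", lambda: service)
async def play(user_id: int) -> bool:
    return True


asyncio.run(play(user_id=7))
print(service.calls)  # [('validate', 7), ('deduct', 7, True)]
```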
@@ -97,7 +104,7 @@ def validate_credits_only(
credit_service_factory: Callable[[], CreditService],
user_id_param: str = "user_id",
) -> Callable[[F], F]:
"""Decorator to only validate credits without deducting them.
"""Validate credits without deducting them.
Useful for checking if a user can perform an action before actual execution.
@@ -110,16 +117,16 @@ def validate_credits_only(
Decorated function that validates credits only
"""
def decorator(func: F) -> F:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
# Extract user ID from parameters
user_id = None
if user_id_param in kwargs:
user_id = kwargs[user_id_param]
else:
# Try to find user_id in function signature
import inspect
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
if user_id_param in param_names:
@@ -141,6 +148,7 @@ def validate_credits_only(
return await func(*args, **kwargs)
return wrapper # type: ignore[return-value]
return decorator
@@ -173,20 +181,29 @@ class CreditManager:
async def __aenter__(self) -> "CreditManager":
"""Enter context manager - validate credits."""
await self.credit_service.validate_and_reserve_credits(
self.user_id, self.action_type, self.metadata
self.user_id,
self.action_type,
)
self.validated = True
return self
async def __aexit__(self, exc_type: type, exc_val: Exception, exc_tb: Any) -> None:
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: types.TracebackType | None,
) -> None:
"""Exit context manager - deduct credits based on success."""
if self.validated:
# If no exception occurred, consider it successful
success = exc_type is None and self.success
await self.credit_service.deduct_credits(
self.user_id, self.action_type, success, self.metadata
self.user_id,
self.action_type,
success=success,
metadata=self.metadata,
)
def mark_success(self) -> None:
"""Mark the operation as successful."""
self.success = True
self.success = True
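The `CreditManager` success/failure accounting changed above can be exercised with a stub service. A minimal sketch, assuming the `__init__` signature (which this hunk does not show) takes the service, user id, action type, and optional metadata:

```python
# Minimal sketch of CreditManager's deduct-on-exit behavior; the
# __init__ signature and stub service are assumptions, while the
# __aenter__/__aexit__/mark_success logic mirrors the diff.
import asyncio


class StubCreditService:
    def __init__(self) -> None:
        self.deductions: list[bool] = []

    async def validate_and_reserve_credits(self, user_id, action_type) -> None:
        pass  # assume the user has enough credits

    async def deduct_credits(self, user_id, action_type, *, success, metadata) -> None:
        self.deductions.append(success)


class CreditManager:
    def __init__(self, credit_service, user_id, action_type, metadata=None) -> None:
        self.credit_service = credit_service
        self.user_id = user_id
        self.action_type = action_type
        self.metadata = metadata
        self.validated = False
        self.success = False

    async def __aenter__(self) -> "CreditManager":
        await self.credit_service.validate_and_reserve_credits(
            self.user_id, self.action_type,
        )
        self.validated = True
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        if self.validated:
            # Successful only if no exception AND mark_success() was called.
            success = exc_type is None and self.success
            await self.credit_service.deduct_credits(
                self.user_id, self.action_type,
                success=success, metadata=self.metadata,
            )

    def mark_success(self) -> None:
        self.success = True


service = StubCreditService()


async def demo() -> None:
    async with CreditManager(service, user_id=1, action_type="play_sound") as manager:
        manager.mark_success()
    async with CreditManager(service, user_id=1, action_type="play_sound"):
        pass  # no mark_success -> deducted as a failure


asyncio.run(demo())
print(service.deductions)  # [True, False]
```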


@@ -1,166 +0,0 @@
"""Database utility functions for common operations."""
from typing import Any, Dict, List, Optional, Type, TypeVar
from sqlmodel import select, SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
T = TypeVar("T", bound=SQLModel)
async def create_and_save(
session: AsyncSession,
model_class: Type[T],
**kwargs: Any
) -> T:
"""Create, add, commit, and refresh a model instance.
This consolidates the common database pattern of:
- instance = ModelClass(**kwargs)
- session.add(instance)
- await session.commit()
- await session.refresh(instance)
Args:
session: Database session
model_class: SQLModel class to instantiate
**kwargs: Arguments to pass to model constructor
Returns:
Created and refreshed model instance
"""
instance = model_class(**kwargs)
session.add(instance)
await session.commit()
await session.refresh(instance)
return instance
async def get_or_create(
session: AsyncSession,
model_class: Type[T],
defaults: Optional[Dict[str, Any]] = None,
**kwargs: Any
) -> tuple[T, bool]:
"""Get an existing instance or create a new one.
Args:
session: Database session
model_class: SQLModel class
defaults: Default values for creation (if not found)
**kwargs: Filter criteria for lookup
Returns:
Tuple of (instance, created) where created is True if instance was created
"""
# Build filter conditions
filters = []
for key, value in kwargs.items():
filters.append(getattr(model_class, key) == value)
# Try to find existing instance
statement = select(model_class).where(*filters)
result = await session.exec(statement)
instance = result.first()
if instance:
return instance, False
# Create new instance
create_kwargs = {**kwargs}
if defaults:
create_kwargs.update(defaults)
instance = await create_and_save(session, model_class, **create_kwargs)
return instance, True
async def update_and_save(
session: AsyncSession,
instance: T,
**updates: Any
) -> T:
"""Update model instance fields and save to database.
Args:
session: Database session
instance: Model instance to update
**updates: Field updates to apply
Returns:
Updated and refreshed model instance
"""
for field, value in updates.items():
setattr(instance, field, value)
session.add(instance)
await session.commit()
await session.refresh(instance)
return instance
async def bulk_create(
session: AsyncSession,
model_class: Type[T],
items: List[Dict[str, Any]]
) -> List[T]:
"""Create multiple model instances in bulk.
Args:
session: Database session
model_class: SQLModel class to instantiate
items: List of dictionaries with model data
Returns:
List of created model instances
"""
instances = []
for item_data in items:
instance = model_class(**item_data)
session.add(instance)
instances.append(instance)
await session.commit()
# Refresh all instances
for instance in instances:
await session.refresh(instance)
return instances
async def delete_and_commit(
session: AsyncSession,
instance: T
) -> None:
"""Delete an instance and commit the transaction.
Args:
session: Database session
instance: Model instance to delete
"""
await session.delete(instance)
await session.commit()
async def exists(
session: AsyncSession,
model_class: Type[T],
**kwargs: Any
) -> bool:
"""Check if a model instance exists with given criteria.
Args:
session: Database session
model_class: SQLModel class
**kwargs: Filter criteria
Returns:
True if instance exists, False otherwise
"""
filters = []
for key, value in kwargs.items():
filters.append(getattr(model_class, key) == value)
statement = select(model_class).where(*filters)
result = await session.exec(statement)
return result.first() is not None


@@ -1,121 +0,0 @@
"""Utility functions for common HTTP exception patterns."""
from fastapi import HTTPException, status
def raise_not_found(resource: str, identifier: str = None) -> None:
"""Raise a standardized 404 Not Found exception.
Args:
resource: Name of the resource that wasn't found
identifier: Optional identifier for the specific resource
Raises:
HTTPException with 404 status code
"""
if identifier:
detail = f"{resource} with ID {identifier} not found"
else:
detail = f"{resource} not found"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
def raise_unauthorized(detail: str = "Could not validate credentials") -> None:
"""Raise a standardized 401 Unauthorized exception.
Args:
detail: Error message detail
Raises:
HTTPException with 401 status code
"""
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
)
def raise_bad_request(detail: str) -> None:
"""Raise a standardized 400 Bad Request exception.
Args:
detail: Error message detail
Raises:
HTTPException with 400 status code
"""
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=detail,
)
def raise_internal_server_error(detail: str, cause: Exception = None) -> None:
"""Raise a standardized 500 Internal Server Error exception.
Args:
detail: Error message detail
cause: Optional underlying exception
Raises:
HTTPException with 500 status code
"""
if cause:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail,
) from cause
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail,
)
def raise_payment_required(detail: str = "Insufficient credits") -> None:
"""Raise a standardized 402 Payment Required exception.
Args:
detail: Error message detail
Raises:
HTTPException with 402 status code
"""
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=detail,
)
def raise_forbidden(detail: str = "Access forbidden") -> None:
"""Raise a standardized 403 Forbidden exception.
Args:
detail: Error message detail
Raises:
HTTPException with 403 status code
"""
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=detail,
)
def raise_conflict(detail: str) -> None:
"""Raise a standardized 409 Conflict exception.
Args:
detail: Error message detail
Raises:
HTTPException with 409 status code
"""
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=detail,
)


@@ -1,179 +0,0 @@
"""Test helper utilities for reducing code duplication."""
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional, Type, TypeVar
from unittest.mock import AsyncMock
from fastapi import FastAPI
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel import SQLModel
from app.models.user import User
from app.utils.auth import JWTUtils
T = TypeVar("T", bound=SQLModel)
def create_jwt_token_data(user: User) -> Dict[str, str]:
"""Create standardized JWT token data dictionary for a user.
Args:
user: User object to create token data for
Returns:
Dictionary with sub, email, and role fields
"""
return {
"sub": str(user.id),
"email": user.email,
"role": user.role,
}
def create_access_token_for_user(user: User) -> str:
"""Create an access token for a user using standardized token data.
Args:
user: User object to create token for
Returns:
JWT access token string
"""
token_data = create_jwt_token_data(user)
return JWTUtils.create_access_token(token_data)
async def create_and_save_model(
session: AsyncSession,
model_class: Type[T],
**kwargs: Any
) -> T:
"""Create, save, and refresh a model instance.
This consolidates the common pattern of:
- model = ModelClass(**kwargs)
- session.add(model)
- await session.commit()
- await session.refresh(model)
Args:
session: Database session
model_class: SQLModel class to instantiate
**kwargs: Arguments to pass to model constructor
Returns:
Created and refreshed model instance
"""
instance = model_class(**kwargs)
session.add(instance)
await session.commit()
await session.refresh(instance)
return instance
@asynccontextmanager
async def override_dependencies(
app: FastAPI,
overrides: Dict[Any, Any]
):
"""Context manager for FastAPI dependency overrides with automatic cleanup.
Args:
app: FastAPI application instance
overrides: Dictionary mapping dependency functions to mock implementations
Usage:
async with override_dependencies(test_app, {
get_service: lambda: mock_service,
get_repo: lambda: mock_repo
}):
# Test code here
pass
# Dependencies automatically cleaned up
"""
# Apply overrides
for dependency, override in overrides.items():
app.dependency_overrides[dependency] = override
try:
yield
finally:
# Clean up overrides
for dependency in overrides:
app.dependency_overrides.pop(dependency, None)
def create_mock_vlc_services() -> Dict[str, AsyncMock]:
"""Create standard set of mocked VLC-related services.
Returns:
Dictionary with mocked vlc_service, sound_repository, and credit_service
"""
return {
"vlc_service": AsyncMock(),
"sound_repository": AsyncMock(),
"credit_service": AsyncMock(),
}
def configure_mock_sound_play_success(
mocks: Dict[str, AsyncMock],
sound_data: Dict[str, Any]
) -> None:
"""Configure mocks for successful sound playback scenario.
Args:
mocks: Dictionary of mock services from create_mock_vlc_services()
sound_data: Dictionary with sound properties (id, name, etc.)
"""
from app.models.sound import Sound
mock_sound = Sound(**sound_data)
# Configure repository mock
mocks["sound_repository"].get_by_id.return_value = mock_sound
# Configure credit service mocks
mocks["credit_service"].validate_and_reserve_credits.return_value = None
mocks["credit_service"].deduct_credits.return_value = None
# Configure VLC service mock
mocks["vlc_service"].play_sound.return_value = True
def create_mock_vlc_stop_result(
success: bool = True,
processes_found: int = 3,
processes_killed: int = 3,
processes_remaining: int = 0,
message: Optional[str] = None,
error: Optional[str] = None
) -> Dict[str, Any]:
"""Create standardized VLC stop operation result.
Args:
success: Whether operation succeeded
processes_found: Number of VLC processes found
processes_killed: Number of processes successfully killed
processes_remaining: Number of processes still running
message: Success/status message
error: Error message (for failed operations)
Returns:
Dictionary with VLC stop operation result
"""
result = {
"success": success,
"processes_found": processes_found,
"processes_killed": processes_killed,
}
if not success:
result["error"] = error or "Command failed"
result["message"] = message or "Failed to stop VLC processes"
else:
# Always include processes_remaining for successful operations
result["processes_remaining"] = processes_remaining
result["message"] = message or f"Killed {processes_killed} VLC processes"
return result


@@ -1,140 +0,0 @@
"""Common validation utility functions."""
import re
from pathlib import Path
from typing import Any, Optional
# Password validation constants
MIN_PASSWORD_LENGTH = 8
def validate_email(email: str) -> bool:
"""Validate email address format.
Args:
email: Email address to validate
Returns:
True if email format is valid, False otherwise
"""
pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
return bool(re.match(pattern, email))
def validate_password_strength(password: str) -> tuple[bool, str | None]:
"""Validate password meets security requirements.
Args:
password: Password to validate
Returns:
Tuple of (is_valid, error_message)
"""
if len(password) < MIN_PASSWORD_LENGTH:
msg = f"Password must be at least {MIN_PASSWORD_LENGTH} characters long"
return False, msg
if not re.search(r"[A-Z]", password):
return False, "Password must contain at least one uppercase letter"
if not re.search(r"[a-z]", password):
return False, "Password must contain at least one lowercase letter"
if not re.search(r"\d", password):
return False, "Password must contain at least one number"
return True, None
def validate_filename(
filename: str, allowed_extensions: list[str] | None = None
) -> bool:
"""Validate filename format and extension.
Args:
filename: Filename to validate
allowed_extensions: List of allowed file extensions (with dots)
Returns:
True if filename is valid, False otherwise
"""
if not filename or filename.startswith(".") or "/" in filename or "\\" in filename:
return False
if allowed_extensions:
file_path = Path(filename)
return file_path.suffix.lower() in [ext.lower() for ext in allowed_extensions]
return True
def validate_audio_filename(filename: str) -> bool:
"""Validate audio filename has allowed extension.
Args:
filename: Audio filename to validate
Returns:
True if filename has valid audio extension, False otherwise
"""
audio_extensions = [".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac", ".wma"]
return validate_filename(filename, audio_extensions)
def sanitize_filename(filename: str) -> str:
"""Sanitize filename by removing/replacing invalid characters.
Args:
filename: Filename to sanitize
Returns:
Sanitized filename safe for filesystem
"""
# Remove or replace problematic characters
sanitized = re.sub(r'[<>:"/\\|?*]', "_", filename)
# Remove leading/trailing whitespace and dots
sanitized = sanitized.strip(" .")
# Ensure not empty
if not sanitized:
sanitized = "untitled"
return sanitized
def validate_url(url: str) -> bool:
"""Validate URL format.
Args:
url: URL to validate
Returns:
True if URL format is valid, False otherwise
"""
pattern = r"^https?://[^\s/$.?#].[^\s]*$"
return bool(re.match(pattern, url))
def validate_positive_integer(value: Any, field_name: str = "value") -> int:
"""Validate and convert value to positive integer.
Args:
value: Value to validate and convert
field_name: Name of field for error messages
Returns:
Validated positive integer
Raises:
ValueError: If value is not a positive integer
"""
try:
int_value = int(value)
if int_value <= 0:
msg = f"{field_name} must be a positive integer"
raise ValueError(msg)
return int_value
except (TypeError, ValueError) as e:
msg = f"{field_name} must be a positive integer"
raise ValueError(msg) from e

migrate.py Executable file

@@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""Database migration CLI tool."""
import argparse
import logging
import sys
from pathlib import Path
from alembic.config import Config
from alembic import command
# Set up logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(message)s")
def main() -> None:
"""Run database migration CLI tool."""
parser = argparse.ArgumentParser(description="Database migration tool")
subparsers = parser.add_subparsers(dest="command", help="Migration commands")
# Upgrade command
upgrade_parser = subparsers.add_parser(
"upgrade", help="Upgrade database to latest revision",
)
upgrade_parser.add_argument(
"revision",
nargs="?",
default="head",
help="Target revision (default: head)",
)
# Downgrade command
downgrade_parser = subparsers.add_parser("downgrade", help="Downgrade database")
downgrade_parser.add_argument("revision", help="Target revision")
# Current command
subparsers.add_parser("current", help="Show current revision")
# History command
subparsers.add_parser("history", help="Show revision history")
# Generate migration command
revision_parser = subparsers.add_parser("revision", help="Create new migration")
revision_parser.add_argument(
"-m", "--message", required=True, help="Migration message",
)
revision_parser.add_argument(
"--autogenerate", action="store_true", help="Auto-generate migration",
)
args = parser.parse_args()
if not args.command:
parser.print_help()
sys.exit(1)
# Get the alembic config
config_path = Path("alembic.ini")
if not config_path.exists():
logger.error("Error: alembic.ini not found. Run from the backend directory.")
sys.exit(1)
alembic_cfg = Config(str(config_path))
try:
if args.command == "upgrade":
command.upgrade(alembic_cfg, args.revision)
logger.info(
"Successfully upgraded database to revision: %s", args.revision,
)
elif args.command == "downgrade":
command.downgrade(alembic_cfg, args.revision)
logger.info(
"Successfully downgraded database to revision: %s", args.revision,
)
elif args.command == "current":
command.current(alembic_cfg)
elif args.command == "history":
command.history(alembic_cfg)
elif args.command == "revision":
if args.autogenerate:
command.revision(alembic_cfg, message=args.message, autogenerate=True)
logger.info("Created new auto-generated migration: %s", args.message)
else:
command.revision(alembic_cfg, message=args.message)
logger.info("Created new empty migration: %s", args.message)
except (OSError, RuntimeError):
logger.exception("Error occurred during migration")
sys.exit(1)
if __name__ == "__main__":
main()


@@ -1,34 +1,42 @@
[project]
name = "backend"
version = "0.1.0"
description = "Add your description here"
name = "sdb"
version = "2.0.0"
description = "Backend for the SDB v2"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"aiosqlite==0.21.0",
"bcrypt==4.3.0",
"email-validator==2.2.0",
"fastapi[standard]==0.116.1",
"alembic==1.16.5",
"apscheduler==3.11.0",
"bcrypt==5.0.0",
"email-validator==2.3.0",
"fastapi[standard]==0.118.0",
"ffmpeg-python==0.2.0",
"gtts==2.5.4",
"httpx==0.28.1",
"pydantic-settings==2.10.1",
"pydantic-settings==2.11.0",
"pyjwt==2.10.1",
"python-socketio==5.13.0",
"python-socketio==5.14.1",
"pytz==2025.2",
"python-vlc==3.0.21203",
"sqlmodel==0.0.24",
"uvicorn[standard]==0.35.0",
"yt-dlp==2025.7.21",
"sqlmodel==0.0.25",
"uvicorn[standard]==0.37.0",
"yt-dlp==2025.9.26",
"asyncpg==0.30.0",
"psycopg[binary]==3.2.10",
"pycaw>=20240210",
"pulsectl>=24.12.0",
]
[tool.uv]
dev-dependencies = [
"coverage==7.10.1",
"faker==37.4.2",
"coverage==7.10.7",
"faker==37.8.0",
"httpx==0.28.1",
"mypy==1.17.0",
"pytest==8.4.1",
"pytest-asyncio==1.1.0",
"ruff==0.12.6",
"mypy==1.18.2",
"pytest==8.4.2",
"pytest-asyncio==1.2.0",
"ruff==0.13.3",
]
[tool.mypy]
@@ -43,10 +51,31 @@ exclude = ["alembic"]
select = ["ALL"]
ignore = ["D100", "D103", "TRY301"]
[tool.ruff.per-file-ignores]
"tests/**/*.py" = ["S101", "S105"]
[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = [
"S101", # Use of assert detected
"S105", # Possible hardcoded password
"S106", # Possible hardcoded password
"ANN001", # Missing type annotation for function argument
"ANN003", # Missing type annotation for **kwargs
"ANN201", # Missing return type annotation for public function
"ANN202", # Missing return type annotation for private function
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
"ARG001", # Unused function argument
"ARG002", # Unused method argument
"ARG005", # Unused lambda argument
"BLE001", # Do not catch blind exception
"E501", # Line too long
"PLR2004", # Magic value used in comparison
"PLC0415", # `import` should be at top-level
"SLF001", # Private member accessed
"SIM117", # Use a single `if` statement
"PT011", # `pytest.raises()` is too broad
"PT012", # `pytest.raises()` block should contain a single simple statement
]
[tool.pytest.ini_options]
asyncio_mode = "auto"
filterwarnings = [
"ignore:transaction already deassociated from connection:sqlalchemy.exc.SAWarning",
]

reset.sh Executable file

@@ -0,0 +1,55 @@
#!/bin/bash
# Reset script for SDB2 soundboard application
# This script removes the database and cleans extracted sounds
set -e # Exit on any error
echo "🔄 Resetting SDB2 application..."
# Change to backend directory
cd "$(dirname "$0")"
# Remove database file if it exists
if [ -f "data/soundboard.db" ]; then
echo "🗑️ Removing database: data/soundboard.db"
rm "data/soundboard.db"
else
echo " Database file not found, skipping"
fi
# List of folders to clean (only files will be deleted, preserving .gitignore)
FOLDERS_TO_CLEAN=(
"sounds/originals/extracted"
"sounds/originals/extracted/thumbnails"
"sounds/originals/text_to_speech"
"sounds/normalized/extracted"
"sounds/normalized/soundboard"
"sounds/normalized/text_to_speech"
"sounds/temp"
)
# Function to clean files in a directory
clean_folder() {
local folder="$1"
if [ -d "$folder" ]; then
echo "🧹 Cleaning folder: $folder"
# Find and delete all files except .gitignore (preserving subdirectories)
find "$folder" -maxdepth 1 -type f -not -name '.gitignore' -delete
echo "✅ Folder cleaned: $folder"
else
echo " Folder not found, skipping: $folder"
fi
}
# Clean all specified folders
echo "🧹 Cleaning specified folders..."
for folder in "${FOLDERS_TO_CLEAN[@]}"; do
clean_folder "$folder"
done
echo "✅ Application reset complete!"
echo "💡 Run 'uv run python run.py' to start fresh"


@@ -0,0 +1 @@
"""Tests for admin API endpoints."""


@@ -0,0 +1,154 @@
"""Tests for admin extraction API endpoints."""
import pytest
from httpx import AsyncClient
from app.models.extraction import Extraction
from app.models.user import User
class TestAdminExtractionEndpoints:
"""Test admin extraction endpoints."""
@pytest.mark.asyncio
async def test_get_extraction_processor_status(self, authenticated_admin_client):
"""Test getting extraction processor status."""
response = await authenticated_admin_client.get(
"/api/v1/admin/extractions/status",
)
assert response.status_code == 200
data = response.json()
# Check expected status fields (match actual processor status format)
assert "currently_processing" in data
assert "max_concurrent" in data
assert "available_slots" in data
assert "processing_ids" in data
assert isinstance(data["currently_processing"], int)
assert isinstance(data["max_concurrent"], int)
assert isinstance(data["available_slots"], int)
assert isinstance(data["processing_ids"], list)
@pytest.mark.asyncio
async def test_admin_delete_extraction_success(
self,
authenticated_admin_client,
test_session,
test_plan,
):
"""Test admin successfully deleting any extraction."""
# Create a test user
user = User(
name="Test User",
email="test@example.com",
is_active=True,
plan_id=test_plan.id,
)
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
# Create test extraction
extraction = Extraction(
url="https://example.com/video",
user_id=user.id,
status="completed",
)
test_session.add(extraction)
await test_session.commit()
await test_session.refresh(extraction)
# Admin delete the extraction
response = await authenticated_admin_client.delete(
f"/api/v1/admin/extractions/{extraction.id}",
)
assert response.status_code == 200
data = response.json()
assert data["message"] == f"Extraction {extraction.id} deleted successfully"
# Verify extraction was deleted from database
deleted_extraction = await test_session.get(Extraction, extraction.id)
assert deleted_extraction is None
@pytest.mark.asyncio
async def test_admin_delete_extraction_not_found(self, authenticated_admin_client):
"""Test admin deleting non-existent extraction."""
response = await authenticated_admin_client.delete(
"/api/v1/admin/extractions/999",
)
assert response.status_code == 404
data = response.json()
assert "not found" in data["detail"].lower()
@pytest.mark.asyncio
async def test_admin_delete_extraction_any_user(
self,
authenticated_admin_client,
test_session,
test_plan,
):
"""Test admin deleting extraction owned by any user."""
# Create another user and their extraction
other_user = User(
name="Other User",
email="other@example.com",
is_active=True,
plan_id=test_plan.id,
)
test_session.add(other_user)
await test_session.commit()
await test_session.refresh(other_user)
extraction = Extraction(
url="https://example.com/video",
user_id=other_user.id,
status="completed",
)
test_session.add(extraction)
await test_session.commit()
await test_session.refresh(extraction)
# Admin can delete any user's extraction
response = await authenticated_admin_client.delete(
f"/api/v1/admin/extractions/{extraction.id}",
)
assert response.status_code == 200
data = response.json()
assert data["message"] == f"Extraction {extraction.id} deleted successfully"
@pytest.mark.asyncio
async def test_delete_extraction_non_admin(self, authenticated_client, test_user, test_session):
"""Test non-admin user cannot access admin deletion endpoint."""
# Create test extraction
extraction = Extraction(
url="https://example.com/video",
user_id=test_user.id,
status="completed",
)
test_session.add(extraction)
await test_session.commit()
await test_session.refresh(extraction)
# Non-admin user cannot access admin endpoint
response = await authenticated_client.delete(
f"/api/v1/admin/extractions/{extraction.id}",
)
assert response.status_code == 403
data = response.json()
assert "permissions" in data["detail"].lower()
@pytest.mark.asyncio
async def test_admin_endpoints_unauthenticated(self, client: AsyncClient):
"""Test admin endpoints require authentication."""
# Status endpoint
response = await client.get("/api/v1/admin/extractions/status")
assert response.status_code == 401
# Delete endpoint
response = await client.delete("/api/v1/admin/extractions/1")
assert response.status_code == 401

Some files were not shown because too many files have changed in this diff.