Compare commits
4 Commits
e63c7a0767 ... 7128ca727b

| SHA1 |
|---|
| 7128ca727b |
| 1b597f4047 |
| 703212656f |
| 171dbb9b63 |
@@ -7,6 +7,7 @@ from flask_jwt_extended import JWTManager

from app.database import init_db
from app.services.auth_service import AuthService
from app.services.scheduler_service import scheduler_service

# Global auth service instance
auth_service = AuthService()
@@ -26,7 +27,7 @@ def create_app():

# Configure Flask-JWT-Extended
app.config["JWT_SECRET_KEY"] = os.environ.get(
"JWT_SECRET_KEY", "jwt-secret-key"
"JWT_SECRET_KEY", "jwt-secret-key",
)
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(minutes=15)
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=7)
@@ -60,10 +61,23 @@ def create_app():
# Initialize authentication service with app
auth_service.init_app(app)

# Initialize scheduler service with app
scheduler_service.app = app

# Start scheduler for background tasks
scheduler_service.start()

# Register blueprints
from app.routes import auth, main

app.register_blueprint(main.bp, url_prefix="/api")
app.register_blueprint(auth.bp, url_prefix="/api/auth")

# Shutdown scheduler when app is torn down
@app.teardown_appcontext
def shutdown_scheduler(exception):
"""Stop scheduler when app context is torn down."""
if exception:
scheduler_service.stop()

return app

@@ -1,7 +1,7 @@
"""Database configuration and initialization."""

from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from flask_sqlalchemy import SQLAlchemy

db = SQLAlchemy()
migrate = Migrate()
@@ -11,8 +11,8 @@ def init_db(app):
"""Initialize database with Flask app."""
db.init_app(app)
migrate.init_app(app, db)

# Import models here to ensure they are registered with SQLAlchemy
from app.models import user, user_oauth  # noqa: F401

return db

return db

@@ -8,10 +8,10 @@ def init_database():
"""Initialize database tables and seed with default data."""
# Create all tables
db.create_all()

# Seed plans if they don't exist
seed_plans()

# Migrate existing users to have plans
migrate_users_to_plans()

@@ -21,7 +21,7 @@ def seed_plans():
# Check if plans already exist
if Plan.query.count() > 0:
return

# Create default plans
plans_data = [
{
@@ -46,11 +46,11 @@ def seed_plans():
"max_credits": 300,
},
]

for plan_data in plans_data:
plan = Plan(**plan_data)
db.session.add(plan)

db.session.commit()
print(f"Seeded {len(plans_data)} plans into database")

@@ -58,34 +58,34 @@ def seed_plans():
def migrate_users_to_plans():
"""Assign plans to existing users who don't have one."""
from app.models.user import User

try:
# Find users without plans
users_without_plans = User.query.filter(User.plan_id.is_(None)).all()

# Find users with plans but NULL credits (only if credits column exists)
# Note: We only migrate users with NULL credits, not 0 credits
# 0 credits means they spent them, NULL means they never got assigned
try:
users_without_credits = User.query.filter(
User.plan_id.isnot(None), User.credits.is_(None)
User.plan_id.isnot(None), User.credits.is_(None),
).all()
except Exception:
# Credits column doesn't exist yet, will be handled by create_all
users_without_credits = []

if not users_without_plans and not users_without_credits:
return

# Get default and pro plans
default_plan = Plan.get_default_plan()
pro_plan = Plan.get_pro_plan()

# Get the first user (admin) from all users ordered by ID
first_user = User.query.order_by(User.id).first()

updated_count = 0

# Assign plans to users without plans
for user in users_without_plans:
# First user gets pro plan, others get free plan
@@ -104,17 +104,19 @@ def migrate_users_to_plans():
except Exception:
pass
updated_count += 1

# Assign credits to users with plans but no credits
for user in users_without_credits:
user.credits = user.plan.credits
updated_count += 1

if updated_count > 0:
db.session.commit()
print(f"Updated {updated_count} existing users with plans and credits")

print(
f"Updated {updated_count} existing users with plans and credits",
)

except Exception:
# If there's any error (like missing columns), just skip migration
# The database will be properly created by create_all()
pass
pass

@@ -1,7 +1,8 @@
"""Database models."""

from .plan import Plan
from .sound import Sound
from .user import User
from .user_oauth import UserOAuth

__all__ = ["Plan", "User", "UserOAuth"]
__all__ = ["Plan", "Sound", "User", "UserOAuth"]

@@ -55,4 +55,4 @@ class Plan(db.Model):
"description": self.description,
"credits": self.credits,
"max_credits": self.max_credits,
}
}

app/models/sound.py (new file, 225 lines)
@@ -0,0 +1,225 @@
"""Sound model for storing sound file information."""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from app.database import db
|
||||
from sqlalchemy import Boolean, DateTime, Integer, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class SoundType(Enum):
|
||||
"""Sound type enumeration."""
|
||||
|
||||
SDB = "SDB" # Soundboard sound
|
||||
SAY = "SAY" # Text-to-speech
|
||||
STR = "STR" # Stream sound
|
||||
|
||||
|
||||
class Sound(db.Model):
|
||||
"""Sound model for storing sound file information."""
|
||||
|
||||
__tablename__ = "sounds"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
# Sound type (SDB, SAY, or STR)
|
||||
type: Mapped[str] = mapped_column(String(3), nullable=False)
|
||||
|
||||
# Basic sound information
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
filename: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
duration: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
size: Mapped[int] = mapped_column(Integer, nullable=False) # Size in bytes
|
||||
hash: Mapped[str] = mapped_column(String(64), nullable=False) # SHA256 hash
|
||||
|
||||
# Normalized sound information
|
||||
normalized_filename: Mapped[str | None] = mapped_column(
|
||||
String(500),
|
||||
nullable=True,
|
||||
)
|
||||
normalized_duration: Mapped[int | None] = mapped_column(
|
||||
Integer,
|
||||
nullable=True,
|
||||
)
|
||||
normalized_size: Mapped[int | None] = mapped_column(
|
||||
Integer,
|
||||
nullable=True,
|
||||
)
|
||||
normalized_hash: Mapped[str | None] = mapped_column(
|
||||
String(64),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Sound properties
|
||||
is_normalized: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
)
|
||||
is_music: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
)
|
||||
is_deletable: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=True,
|
||||
)
|
||||
|
||||
# Usage tracking
|
||||
play_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
# Timestamps
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
default=datetime.utcnow,
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
default=datetime.utcnow,
|
||||
onupdate=datetime.utcnow,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of Sound."""
|
||||
return f"<Sound {self.name} ({self.type}) - {self.play_count} plays>"
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert sound to dictionary."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"type": self.type,
|
||||
"name": self.name,
|
||||
"filename": self.filename,
|
||||
"duration": self.duration,
|
||||
"size": self.size,
|
||||
"hash": self.hash,
|
||||
"normalized_filename": self.normalized_filename,
|
||||
"normalized_duration": self.normalized_duration,
|
||||
"normalized_size": self.normalized_size,
|
||||
"normalized_hash": self.normalized_hash,
|
||||
"is_normalized": self.is_normalized,
|
||||
"is_music": self.is_music,
|
||||
"is_deletable": self.is_deletable,
|
||||
"play_count": self.play_count,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
def increment_play_count(self) -> None:
|
||||
"""Increment the play count for this sound."""
|
||||
self.play_count += 1
|
||||
self.updated_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
def set_normalized_info(
|
||||
self,
|
||||
normalized_filename: str,
|
||||
normalized_duration: float,
|
||||
normalized_size: int,
|
||||
normalized_hash: str,
|
||||
) -> None:
|
||||
"""Set normalized sound information."""
|
||||
self.normalized_filename = normalized_filename
|
||||
self.normalized_duration = normalized_duration
|
||||
self.normalized_size = normalized_size
|
||||
self.normalized_hash = normalized_hash
|
||||
self.is_normalized = True
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
def clear_normalized_info(self) -> None:
|
||||
"""Clear normalized sound information."""
|
||||
self.normalized_filename = None
|
||||
self.normalized_duration = None
|
||||
self.normalized_hash = None
|
||||
self.normalized_size = None
|
||||
self.is_normalized = False
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
def update_file_info(
|
||||
self,
|
||||
filename: str,
|
||||
duration: float,
|
||||
size: int,
|
||||
hash_value: str,
|
||||
) -> None:
|
||||
"""Update file information for existing sound."""
|
||||
self.filename = filename
|
||||
self.duration = duration
|
||||
self.size = size
|
||||
self.hash = hash_value
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
@classmethod
|
||||
def find_by_hash(cls, hash_value: str) -> Optional["Sound"]:
|
||||
"""Find sound by hash."""
|
||||
return cls.query.filter_by(hash=hash_value).first()
|
||||
|
||||
@classmethod
|
||||
def find_by_name(cls, name: str) -> Optional["Sound"]:
|
||||
"""Find sound by name."""
|
||||
return cls.query.filter_by(name=name).first()
|
||||
|
||||
@classmethod
|
||||
def find_by_filename(cls, filename: str) -> Optional["Sound"]:
|
||||
"""Find sound by filename."""
|
||||
return cls.query.filter_by(filename=filename).first()
|
||||
|
||||
@classmethod
|
||||
def find_by_type(cls, sound_type: str) -> list["Sound"]:
|
||||
"""Find all sounds by type."""
|
||||
return cls.query.filter_by(type=sound_type).all()
|
||||
|
||||
@classmethod
|
||||
def get_most_played(cls, limit: int = 10) -> list["Sound"]:
|
||||
"""Get the most played sounds."""
|
||||
return cls.query.order_by(cls.play_count.desc()).limit(limit).all()
|
||||
|
||||
@classmethod
|
||||
def get_music_sounds(cls) -> list["Sound"]:
|
||||
"""Get all music sounds."""
|
||||
return cls.query.filter_by(is_music=True).all()
|
||||
|
||||
@classmethod
|
||||
def get_deletable_sounds(cls) -> list["Sound"]:
|
||||
"""Get all deletable sounds."""
|
||||
return cls.query.filter_by(is_deletable=True).all()
|
||||
|
||||
@classmethod
|
||||
def create_sound(
|
||||
cls,
|
||||
sound_type: str,
|
||||
name: str,
|
||||
filename: str,
|
||||
duration: float,
|
||||
size: int,
|
||||
hash_value: str,
|
||||
is_music: bool = False,
|
||||
is_deletable: bool = True,
|
||||
commit: bool = True,
|
||||
) -> "Sound":
|
||||
"""Create a new sound."""
|
||||
# Validate sound type
|
||||
if sound_type not in [t.value for t in SoundType]:
|
||||
raise ValueError(f"Invalid sound type: {sound_type}")
|
||||
|
||||
sound = cls(
|
||||
type=sound_type,
|
||||
name=name,
|
||||
filename=filename,
|
||||
duration=duration,
|
||||
size=size,
|
||||
hash=hash_value,
|
||||
is_music=is_music,
|
||||
is_deletable=is_deletable,
|
||||
)
|
||||
|
||||
db.session.add(sound)
|
||||
if commit:
|
||||
db.session.commit()
|
||||
return sound
|
||||
@@ -2,69 +2,80 @@
|
||||
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from werkzeug.security import check_password_hash, generate_password_hash
|
||||
from sqlalchemy import String, DateTime, Integer, ForeignKey
|
||||
from sqlalchemy import DateTime, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from werkzeug.security import check_password_hash, generate_password_hash
|
||||
|
||||
from app.database import db
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user_oauth import UserOAuth
|
||||
from app.models.plan import Plan
|
||||
from app.models.user_oauth import UserOAuth
|
||||
|
||||
|
||||
class User(db.Model):
|
||||
"""User model for storing user information."""
|
||||
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
|
||||
# Primary user information (can be updated from any connected provider)
|
||||
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
picture: Mapped[Optional[str]] = mapped_column(String(500), nullable=True)
|
||||
|
||||
picture: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||
|
||||
# Password authentication (optional - users can use OAuth instead)
|
||||
password_hash: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
|
||||
password_hash: Mapped[str | None] = mapped_column(
|
||||
String(255), nullable=True,
|
||||
)
|
||||
|
||||
# Role-based access control
|
||||
role: Mapped[str] = mapped_column(String(50), nullable=False, default="user")
|
||||
|
||||
role: Mapped[str] = mapped_column(
|
||||
String(50), nullable=False, default="user",
|
||||
)
|
||||
|
||||
# User status
|
||||
is_active: Mapped[bool] = mapped_column(nullable=False, default=True)
|
||||
|
||||
|
||||
# Plan relationship
|
||||
plan_id: Mapped[int] = mapped_column(Integer, ForeignKey("plans.id"), nullable=False)
|
||||
|
||||
plan_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("plans.id"), nullable=False,
|
||||
)
|
||||
|
||||
# User credits (populated from plan credits on creation)
|
||||
credits: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
|
||||
# API token for programmatic access
|
||||
api_token: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
api_token_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
api_token: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
api_token_expires_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime, nullable=True,
|
||||
)
|
||||
|
||||
# Timestamps
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
DateTime, default=datetime.utcnow, nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
|
||||
DateTime,
|
||||
default=datetime.utcnow,
|
||||
onupdate=datetime.utcnow,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
# Relationships
|
||||
oauth_providers: Mapped[list["UserOAuth"]] = relationship(
|
||||
"UserOAuth", back_populates="user", cascade="all, delete-orphan"
|
||||
"UserOAuth", back_populates="user", cascade="all, delete-orphan",
|
||||
)
|
||||
plan: Mapped["Plan"] = relationship("Plan", back_populates="users")
|
||||
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of User."""
|
||||
provider_count = len(self.oauth_providers)
|
||||
return f"<User {self.email} ({provider_count} providers)>"
|
||||
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert user to dictionary."""
|
||||
# Build comprehensive providers list
|
||||
@@ -73,7 +84,7 @@ class User(db.Model):
|
||||
providers.append("password")
|
||||
if self.api_token:
|
||||
providers.append("api_token")
|
||||
|
||||
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"email": self.email,
|
||||
@@ -82,25 +93,27 @@ class User(db.Model):
|
||||
"role": self.role,
|
||||
"is_active": self.is_active,
|
||||
"api_token": self.api_token,
|
||||
"api_token_expires_at": self.api_token_expires_at.isoformat() if self.api_token_expires_at else None,
|
||||
"api_token_expires_at": self.api_token_expires_at.isoformat()
|
||||
if self.api_token_expires_at
|
||||
else None,
|
||||
"providers": providers,
|
||||
"plan": self.plan.to_dict() if self.plan else None,
|
||||
"credits": self.credits,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
def get_provider(self, provider_name: str) -> Optional["UserOAuth"]:
|
||||
"""Get specific OAuth provider for this user."""
|
||||
for provider in self.oauth_providers:
|
||||
if provider.provider == provider_name:
|
||||
return provider
|
||||
return None
|
||||
|
||||
|
||||
def has_provider(self, provider_name: str) -> bool:
|
||||
"""Check if user has specific OAuth provider connected."""
|
||||
return self.get_provider(provider_name) is not None
|
||||
|
||||
|
||||
def update_from_provider(self, provider_data: dict) -> None:
|
||||
"""Update user info from provider data (email, name, picture)."""
|
||||
self.email = provider_data.get("email", self.email)
|
||||
@@ -108,60 +121,60 @@ class User(db.Model):
|
||||
self.picture = provider_data.get("picture", self.picture)
|
||||
self.updated_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def set_password(self, password: str) -> None:
|
||||
"""Hash and set user password."""
|
||||
self.password_hash = generate_password_hash(password)
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
|
||||
def check_password(self, password: str) -> bool:
|
||||
"""Check if provided password matches user's password."""
|
||||
if not self.password_hash:
|
||||
return False
|
||||
return check_password_hash(self.password_hash, password)
|
||||
|
||||
|
||||
def has_password(self) -> bool:
|
||||
"""Check if user has a password set."""
|
||||
return self.password_hash is not None
|
||||
|
||||
|
||||
def generate_api_token(self) -> str:
|
||||
"""Generate a new API token for the user."""
|
||||
self.api_token = secrets.token_urlsafe(32)
|
||||
self.api_token_expires_at = None # No expiration by default
|
||||
self.updated_at = datetime.utcnow()
|
||||
return self.api_token
|
||||
|
||||
|
||||
def is_api_token_valid(self) -> bool:
|
||||
"""Check if the user's API token is valid (exists and not expired)."""
|
||||
if not self.api_token:
|
||||
return False
|
||||
|
||||
|
||||
if self.api_token_expires_at is None:
|
||||
return True # No expiration
|
||||
|
||||
|
||||
return datetime.utcnow() < self.api_token_expires_at
|
||||
|
||||
|
||||
def revoke_api_token(self) -> None:
|
||||
"""Revoke the user's API token."""
|
||||
self.api_token = None
|
||||
self.api_token_expires_at = None
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
|
||||
def activate(self) -> None:
|
||||
"""Activate the user account."""
|
||||
self.is_active = True
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
|
||||
def deactivate(self) -> None:
|
||||
"""Deactivate the user account."""
|
||||
self.is_active = False
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
|
||||
@classmethod
|
||||
def find_by_email(cls, email: str) -> Optional["User"]:
|
||||
"""Find user by email address."""
|
||||
return cls.query.filter_by(email=email).first()
|
||||
|
||||
|
||||
@classmethod
|
||||
def find_by_api_token(cls, api_token: str) -> Optional["User"]:
|
||||
"""Find user by API token if token is valid."""
|
||||
@@ -169,18 +182,25 @@ class User(db.Model):
|
||||
if user and user.is_api_token_valid():
|
||||
return user
|
||||
return None
|
||||
|
||||
|
||||
@classmethod
|
||||
def find_or_create_from_oauth(
|
||||
cls, provider: str, provider_id: str, email: str, name: str, picture: Optional[str] = None
|
||||
cls,
|
||||
provider: str,
|
||||
provider_id: str,
|
||||
email: str,
|
||||
name: str,
|
||||
picture: str | None = None,
|
||||
) -> tuple["User", "UserOAuth"]:
|
||||
"""Find existing user or create new one from OAuth data."""
|
||||
from app.models.user_oauth import UserOAuth
|
||||
from app.models.plan import Plan
|
||||
|
||||
from app.models.user_oauth import UserOAuth
|
||||
|
||||
# First, try to find existing OAuth provider
|
||||
oauth_provider = UserOAuth.find_by_provider_and_id(provider, provider_id)
|
||||
|
||||
oauth_provider = UserOAuth.find_by_provider_and_id(
|
||||
provider, provider_id,
|
||||
)
|
||||
|
||||
if oauth_provider:
|
||||
# Update existing provider and user info
|
||||
user = oauth_provider.user
|
||||
@@ -188,24 +208,26 @@ class User(db.Model):
|
||||
oauth_provider.name = name
|
||||
oauth_provider.picture = picture
|
||||
oauth_provider.updated_at = datetime.utcnow()
|
||||
|
||||
|
||||
# Update user info with latest data
|
||||
user.update_from_provider({"email": email, "name": name, "picture": picture})
|
||||
user.update_from_provider(
|
||||
{"email": email, "name": name, "picture": picture},
|
||||
)
|
||||
else:
|
||||
# Try to find user by email to link the new provider
|
||||
user = cls.find_by_email(email)
|
||||
|
||||
|
||||
if not user:
|
||||
# Check if this is the first user (admin with pro plan)
|
||||
user_count = cls.query.count()
|
||||
role = "admin" if user_count == 0 else "user"
|
||||
|
||||
|
||||
# Assign plan: first user gets pro, others get free
|
||||
if user_count == 0:
|
||||
plan = Plan.get_pro_plan()
|
||||
else:
|
||||
plan = Plan.get_default_plan()
|
||||
|
||||
|
||||
# Create new user
|
||||
user = cls(
|
||||
email=email,
|
||||
@@ -218,7 +240,7 @@ class User(db.Model):
|
||||
user.generate_api_token() # Generate API token on creation
|
||||
db.session.add(user)
|
||||
db.session.flush() # Flush to get user.id
|
||||
|
||||
|
||||
# Create new OAuth provider
|
||||
oauth_provider = UserOAuth.create_or_update(
|
||||
user_id=user.id,
|
||||
@@ -228,30 +250,32 @@ class User(db.Model):
|
||||
name=name,
|
||||
picture=picture,
|
||||
)
|
||||
|
||||
|
||||
db.session.commit()
|
||||
return user, oauth_provider
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_with_password(cls, email: str, password: str, name: str) -> "User":
|
||||
def create_with_password(
|
||||
cls, email: str, password: str, name: str,
|
||||
) -> "User":
|
||||
"""Create new user with email and password."""
|
||||
from app.models.plan import Plan
|
||||
|
||||
|
||||
# Check if user already exists
|
||||
existing_user = cls.find_by_email(email)
|
||||
if existing_user:
|
||||
raise ValueError("User with this email already exists")
|
||||
|
||||
|
||||
# Check if this is the first user (admin with pro plan)
|
||||
user_count = cls.query.count()
|
||||
role = "admin" if user_count == 0 else "user"
|
||||
|
||||
|
||||
# Assign plan: first user gets pro, others get free
|
||||
if user_count == 0:
|
||||
plan = Plan.get_pro_plan()
|
||||
else:
|
||||
plan = Plan.get_default_plan()
|
||||
|
||||
|
||||
# Create new user
|
||||
user = cls(
|
||||
email=email,
|
||||
@@ -262,15 +286,17 @@ class User(db.Model):
|
||||
)
|
||||
user.set_password(password)
|
||||
user.generate_api_token() # Generate API token on creation
|
||||
|
||||
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
return user
|
||||
|
||||
|
||||
@classmethod
|
||||
def authenticate_with_password(cls, email: str, password: str) -> Optional["User"]:
|
||||
def authenticate_with_password(
|
||||
cls, email: str, password: str,
|
||||
) -> Optional["User"]:
|
||||
"""Authenticate user with email and password."""
|
||||
user = cls.find_by_email(email)
|
||||
if user and user.check_password(password) and user.is_active:
|
||||
return user
|
||||
return None
|
||||
return None
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""User OAuth model for storing user's connected providers."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from sqlalchemy import String, DateTime, Text, ForeignKey
|
||||
from sqlalchemy import DateTime, ForeignKey, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import db
|
||||
@@ -14,43 +14,50 @@ if TYPE_CHECKING:
|
||||
|
||||
class UserOAuth(db.Model):
|
||||
"""Model for storing user's connected OAuth providers."""
|
||||
|
||||
|
||||
__tablename__ = "user_oauth"
|
||||
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
|
||||
# User relationship
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), nullable=False)
|
||||
|
||||
|
||||
# OAuth provider information
|
||||
provider: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
provider_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
||||
|
||||
# Provider-specific user information
|
||||
email: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
picture: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
|
||||
picture: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
DateTime, default=datetime.utcnow, nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
|
||||
DateTime,
|
||||
default=datetime.utcnow,
|
||||
onupdate=datetime.utcnow,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
# Unique constraint on provider + provider_id combination
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint("provider", "provider_id", name="unique_provider_user"),
|
||||
db.UniqueConstraint(
|
||||
"provider", "provider_id", name="unique_provider_user",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Relationships
|
||||
user: Mapped["User"] = relationship("User", back_populates="oauth_providers")
|
||||
|
||||
user: Mapped["User"] = relationship(
|
||||
"User", back_populates="oauth_providers",
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of UserOAuth."""
|
||||
return f"<UserOAuth {self.email} ({self.provider})>"
|
||||
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert oauth provider to dictionary."""
|
||||
return {
|
||||
@@ -63,25 +70,29 @@ class UserOAuth(db.Model):
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def find_by_provider_and_id(cls, provider: str, provider_id: str) -> Optional["UserOAuth"]:
|
||||
def find_by_provider_and_id(
|
||||
cls, provider: str, provider_id: str,
|
||||
) -> Optional["UserOAuth"]:
|
||||
"""Find OAuth provider by provider name and provider ID."""
|
||||
return cls.query.filter_by(provider=provider, provider_id=provider_id).first()
|
||||
|
||||
return cls.query.filter_by(
|
||||
provider=provider, provider_id=provider_id,
|
||||
).first()
|
||||
|
||||
@classmethod
|
||||
def create_or_update(
|
||||
cls,
|
||||
user_id: int,
|
||||
provider: str,
|
||||
provider_id: str,
|
||||
email: str,
|
||||
name: str,
|
||||
picture: Optional[str] = None
|
||||
cls,
|
||||
user_id: int,
|
||||
provider: str,
|
||||
provider_id: str,
|
||||
email: str,
|
||||
name: str,
|
||||
picture: str | None = None,
|
||||
) -> "UserOAuth":
|
||||
"""Create new OAuth provider or update existing one."""
|
||||
oauth_provider = cls.find_by_provider_and_id(provider, provider_id)
|
||||
|
||||
|
||||
if oauth_provider:
|
||||
# Update existing provider
|
||||
oauth_provider.user_id = user_id
|
||||
@@ -100,6 +111,6 @@ class UserOAuth(db.Model):
|
||||
picture=picture,
|
||||
)
|
||||
db.session.add(oauth_provider)
|
||||
|
||||
|
||||
db.session.commit()
|
||||
return oauth_provider
|
||||
return oauth_provider
|
||||
|
||||
@@ -31,7 +31,7 @@ def callback(provider):
|
||||
# If successful, redirect to frontend dashboard with cookies
|
||||
if auth_response.status_code == 200:
|
||||
redirect_response = make_response(
|
||||
redirect("http://localhost:3000/dashboard")
|
||||
redirect("http://localhost:3000/dashboard"),
|
||||
)
|
||||
|
||||
# Copy all cookies from the auth response
|
||||
@@ -39,9 +39,8 @@ def callback(provider):
|
||||
redirect_response.headers.add("Set-Cookie", cookie)
|
||||
|
||||
return redirect_response
|
||||
else:
|
||||
# If there was an error, redirect to login with error
|
||||
return redirect("http://localhost:3000/login?error=oauth_failed")
|
||||
# If there was an error, redirect to login with error
|
||||
return redirect("http://localhost:3000/login?error=oauth_failed")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e).replace(" ", "_").replace('"', "")
|
||||
@@ -129,7 +128,7 @@ def refresh():
|
||||
def link_provider(provider):
|
||||
"""Link a new OAuth provider to current user account."""
|
||||
redirect_uri = url_for(
|
||||
"auth.link_callback", provider=provider, _external=True
|
||||
"auth.link_callback", provider=provider, _external=True,
|
||||
)
|
||||
return auth_service.redirect_to_login(provider, redirect_uri)
|
||||
|
||||
@@ -168,19 +167,19 @@ def link_callback(provider):
|
||||
|
||||
if not provider_data.get("id"):
|
||||
return {
|
||||
"error": "Failed to get user information from provider"
|
||||
"error": "Failed to get user information from provider",
|
||||
}, 400
|
||||
|
||||
# Check if this provider is already linked to another user
|
||||
from app.models.user_oauth import UserOAuth
|
||||
|
||||
existing_provider = UserOAuth.find_by_provider_and_id(
|
||||
provider, provider_data["id"]
|
||||
provider, provider_data["id"],
|
||||
)
|
||||
|
||||
if existing_provider and existing_provider.user_id != user.id:
|
||||
return {
|
||||
"error": "This provider account is already linked to another user"
|
||||
"error": "This provider account is already linked to another user",
|
||||
}, 409
|
||||
|
||||
# Link the provider to current user
|
||||
@@ -210,7 +209,6 @@ def unlink_provider(provider):
|
||||
|
||||
from app.database import db
|
||||
from app.models.user import User
|
||||
from app.models.user_oauth import UserOAuth
|
||||
|
||||
user = User.query.get(current_user_id)
|
||||
if not user:
|
||||
@@ -224,7 +222,7 @@ def unlink_provider(provider):
|
||||
oauth_provider = user.get_provider(provider)
|
||||
if not oauth_provider:
|
||||
return {
|
||||
"error": f"Provider '{provider}' not linked to this account"
|
||||
"error": f"Provider '{provider}' not linked to this account",
|
||||
}, 404
|
||||
|
||||
db.session.delete(oauth_provider)
|
||||
@@ -279,21 +277,22 @@ def me():
|
||||
def update_profile():
|
||||
"""Update current user profile information."""
|
||||
from flask import request
|
||||
|
||||
from app.database import db
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
data = request.get_json()
|
||||
if not data:
|
||||
return {"error": "No data provided"}, 400
|
||||
|
||||
|
||||
user_data = get_current_user()
|
||||
if not user_data:
|
||||
return {"error": "User not authenticated"}, 401
|
||||
|
||||
|
||||
user = User.query.get(int(user_data["id"]))
|
||||
if not user:
|
||||
return {"error": "User not found"}, 404
|
||||
|
||||
|
||||
# Update allowed fields
|
||||
if "name" in data:
|
||||
name = data["name"].strip()
|
||||
@@ -302,10 +301,10 @@ def update_profile():
|
||||
if len(name) > 100:
|
||||
return {"error": "Name too long (max 100 characters)"}, 400
|
||||
user.name = name
|
||||
|
||||
|
||||
try:
|
||||
db.session.commit()
|
||||
|
||||
|
||||
# Return fresh user data from database
|
||||
updated_user = {
|
||||
"id": str(user.id),
|
||||
@@ -319,14 +318,11 @@ def update_profile():
|
||||
"plan": user.plan.to_dict() if user.plan else None,
|
||||
"credits": user.credits,
|
||||
}
|
||||
|
||||
return {
|
||||
"message": "Profile updated successfully",
|
||||
"user": updated_user
|
||||
}
|
||||
|
||||
return {"message": "Profile updated successfully", "user": updated_user}
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return {"error": f"Failed to update profile: {str(e)}"}, 500
|
||||
return {"error": f"Failed to update profile: {e!s}"}, 500
|
||||
|
||||
|
||||
@bp.route("/password", methods=["PUT"])
|
||||
@@ -334,53 +330,54 @@ def update_profile():
|
||||
def change_password():
|
||||
"""Change or set user password."""
|
||||
from flask import request
|
||||
from werkzeug.security import check_password_hash
|
||||
|
||||
from app.database import db
|
||||
from app.models.user import User
|
||||
from werkzeug.security import check_password_hash
|
||||
|
||||
|
||||
data = request.get_json()
|
||||
if not data:
|
||||
return {"error": "No data provided"}, 400
|
||||
|
||||
|
||||
user_data = get_current_user()
|
||||
if not user_data:
|
||||
return {"error": "User not authenticated"}, 401
|
||||
|
||||
|
||||
user = User.query.get(int(user_data["id"]))
|
||||
if not user:
|
||||
return {"error": "User not found"}, 404
|
||||
|
||||
|
||||
new_password = data.get("new_password")
|
||||
current_password = data.get("current_password")
|
||||
|
||||
|
||||
if not new_password:
|
||||
return {"error": "New password is required"}, 400
|
||||
|
||||
|
||||
# Password validation
|
||||
if len(new_password) < 6:
|
||||
return {"error": "Password must be at least 6 characters long"}, 400
|
||||
|
||||
|
||||
# Check authentication method: if user logged in via password, require current password
|
||||
# If user logged in via OAuth, they can change password without current password
|
||||
current_auth_method = user_data.get("provider", "unknown")
|
||||
|
||||
|
||||
if user.password_hash and current_auth_method == "password":
|
||||
# User has a password AND logged in via password, require current password for verification
|
||||
if not current_password:
|
||||
return {"error": "Current password is required to change password"}, 400
|
||||
|
||||
return {
|
||||
"error": "Current password is required to change password",
|
||||
}, 400
|
||||
|
||||
if not check_password_hash(user.password_hash, current_password):
|
||||
return {"error": "Current password is incorrect"}, 400
|
||||
# If user logged in via OAuth (google, github, etc.), they can change password without current password
|
||||
|
||||
|
||||
# Set the new password
|
||||
try:
|
||||
user.set_password(new_password)
|
||||
db.session.commit()
|
||||
|
||||
return {
|
||||
"message": "Password updated successfully"
|
||||
}
|
||||
|
||||
return {"message": "Password updated successfully"}
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return {"error": f"Failed to update password: {str(e)}"}, 500
|
||||
return {"error": f"Failed to update password: {e!s}"}, 500
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
"""Main routes for the application."""
|
||||
|
||||
from flask import Blueprint
|
||||
from flask import Blueprint, request
|
||||
|
||||
from app.services.decorators import get_current_user, require_auth, require_role, require_credits
|
||||
from app.services.decorators import (
|
||||
get_current_user,
|
||||
require_auth,
|
||||
require_credits,
|
||||
require_role,
|
||||
)
|
||||
from app.services.scheduler_service import scheduler_service
|
||||
from app.services.sound_normalizer_service import SoundNormalizerService
|
||||
from app.services.sound_scanner_service import SoundScannerService
|
||||
|
||||
bp = Blueprint("main", __name__)
|
||||
|
||||
@@ -63,7 +71,8 @@ def use_credits(amount: int) -> dict[str, str]:
|
||||
return {
|
||||
"message": f"Successfully used endpoint! You requested amount: {amount}",
|
||||
"user": user["email"],
|
||||
"remaining_credits": user["credits"] - 5, # Note: credits already deducted by decorator
|
||||
"remaining_credits": user["credits"]
|
||||
- 5, # Note: credits already deducted by decorator
|
||||
}
|
||||
|
||||
|
||||
@@ -78,3 +87,71 @@ def expensive_operation() -> dict[str, str]:
|
||||
"user": user["email"],
|
||||
"operation_cost": 10,
|
||||
}
|
||||
|
||||
|
||||
@bp.route("/admin/scheduler/status")
|
||||
@require_auth
|
||||
@require_role("admin")
|
||||
def scheduler_status() -> dict:
|
||||
"""Get scheduler status (admin only)."""
|
||||
return scheduler_service.get_scheduler_status()
|
||||
|
||||
|
||||
@bp.route("/admin/credits/refill", methods=["POST"])
|
||||
@require_auth
|
||||
@require_role("admin")
|
||||
def manual_credit_refill() -> dict:
|
||||
"""Manually trigger credit refill for all users (admin only)."""
|
||||
return scheduler_service.trigger_credit_refill_now()
|
||||
|
||||
|
||||
@bp.route("/admin/sounds/scan", methods=["POST"])
|
||||
@require_auth
|
||||
@require_role("admin")
|
||||
def manual_sound_scan() -> dict:
|
||||
"""Manually trigger sound directory scan (admin only)."""
|
||||
return scheduler_service.trigger_sound_scan_now()
|
||||
|
||||
|
||||
@bp.route("/admin/sounds/stats")
|
||||
@require_auth
|
||||
@require_role("admin")
|
||||
def sound_statistics() -> dict:
|
||||
"""Get sound database statistics (admin only)."""
|
||||
return SoundScannerService.get_scan_statistics()
|
||||
|
||||
|
||||
@bp.route("/admin/sounds/normalize/<int:sound_id>", methods=["POST"])
|
||||
@require_auth
|
||||
@require_role("admin")
|
||||
def normalize_sound(sound_id: int) -> dict:
|
||||
"""Normalize a specific sound file (admin only)."""
|
||||
overwrite = request.args.get("overwrite", "false").lower() == "true"
|
||||
return SoundNormalizerService.normalize_sound(sound_id, overwrite)
|
||||
|
||||
|
||||
@bp.route("/admin/sounds/normalize-all", methods=["POST"])
|
||||
@require_auth
|
||||
@require_role("admin")
|
||||
def normalize_all_sounds() -> dict:
|
||||
"""Normalize all soundboard files (admin only)."""
|
||||
overwrite = request.args.get("overwrite", "false").lower() == "true"
|
||||
limit_str = request.args.get("limit")
|
||||
limit = int(limit_str) if limit_str else None
|
||||
return SoundNormalizerService.normalize_all_sounds(overwrite, limit)
|
||||
|
||||
|
||||
@bp.route("/admin/sounds/normalization-status")
|
||||
@require_auth
|
||||
@require_role("admin")
|
||||
def normalization_status() -> dict:
|
||||
"""Get normalization status statistics (admin only)."""
|
||||
return SoundNormalizerService.get_normalization_status()
|
||||
|
||||
|
||||
@bp.route("/admin/sounds/ffmpeg-check")
|
||||
@require_auth
|
||||
@require_role("admin")
|
||||
def ffmpeg_check() -> dict:
|
||||
"""Check ffmpeg availability and capabilities (admin only)."""
|
||||
return SoundNormalizerService.check_ffmpeg_availability()
|
||||
|
||||
@@ -83,14 +83,16 @@ class AuthService:
|
||||
|
||||
# Prepare user data for JWT token using user.to_dict()
|
||||
jwt_user_data = user.to_dict()
|
||||
jwt_user_data["provider"] = oauth_provider.provider # Override provider for OAuth login
|
||||
jwt_user_data["provider"] = (
|
||||
oauth_provider.provider
|
||||
) # Override provider for OAuth login
|
||||
|
||||
# Generate JWT tokens
|
||||
access_token = self.token_service.generate_access_token(
|
||||
jwt_user_data
|
||||
jwt_user_data,
|
||||
)
|
||||
refresh_token = self.token_service.generate_refresh_token(
|
||||
jwt_user_data
|
||||
jwt_user_data,
|
||||
)
|
||||
|
||||
# Create response and set HTTP-only cookies
|
||||
@@ -98,7 +100,7 @@ class AuthService:
|
||||
{
|
||||
"message": "Login successful",
|
||||
"user": jwt_user_data,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Set JWT cookies
|
||||
@@ -147,7 +149,7 @@ class AuthService:
|
||||
return None
|
||||
|
||||
def register_with_password(
|
||||
self, email: str, password: str, name: str
|
||||
self, email: str, password: str, name: str,
|
||||
) -> Any:
|
||||
"""Register new user with email and password."""
|
||||
try:
|
||||
@@ -156,14 +158,16 @@ class AuthService:
|
||||
|
||||
# Prepare user data for JWT token using user.to_dict()
|
||||
jwt_user_data = user.to_dict()
|
||||
jwt_user_data["provider"] = "password" # Override provider for password registration
|
||||
jwt_user_data["provider"] = (
|
||||
"password" # Override provider for password registration
|
||||
)
|
||||
|
||||
# Generate JWT tokens
|
||||
access_token = self.token_service.generate_access_token(
|
||||
jwt_user_data
|
||||
jwt_user_data,
|
||||
)
|
||||
refresh_token = self.token_service.generate_refresh_token(
|
||||
jwt_user_data
|
||||
jwt_user_data,
|
||||
)
|
||||
|
||||
# Create response and set HTTP-only cookies
|
||||
@@ -171,7 +175,7 @@ class AuthService:
|
||||
{
|
||||
"message": "Registration successful",
|
||||
"user": jwt_user_data,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Set JWT cookies
|
||||
@@ -192,14 +196,16 @@ class AuthService:
|
||||
|
||||
if not user:
|
||||
response = jsonify(
|
||||
{"error": "Invalid email, password or disabled account"}
|
||||
{"error": "Invalid email, password or disabled account"},
|
||||
)
|
||||
response.status_code = 401
|
||||
return response
|
||||
|
||||
# Prepare user data for JWT token using user.to_dict()
|
||||
jwt_user_data = user.to_dict()
|
||||
jwt_user_data["provider"] = "password" # Override provider for password login
|
||||
jwt_user_data["provider"] = (
|
||||
"password" # Override provider for password login
|
||||
)
|
||||
|
||||
# Generate JWT tokens
|
||||
access_token = self.token_service.generate_access_token(jwt_user_data)
|
||||
@@ -210,7 +216,7 @@ class AuthService:
|
||||
{
|
||||
"message": "Login successful",
|
||||
"user": jwt_user_data,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Set JWT cookies
|
||||
|
||||
app/services/credit_service.py (new file, 133 lines)
@@ -0,0 +1,133 @@
"""Credit management service for handling daily credit refills."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from app.database import db
|
||||
from app.models.user import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CreditService:
|
||||
"""Service for managing user credits and daily refills."""
|
||||
|
||||
@staticmethod
|
||||
def refill_all_users_credits() -> dict:
|
||||
"""Refill credits for all active users based on their plan.
|
||||
|
||||
This function:
|
||||
1. Gets all active users
|
||||
2. For each user, adds their plan's daily credit amount
|
||||
3. Ensures credits never exceed the plan's max_credits limit
|
||||
4. Updates all users in a single database transaction
|
||||
|
||||
Returns:
|
||||
dict: Summary of the refill operation
|
||||
|
||||
"""
|
||||
try:
|
||||
# Get all active users with their plans
|
||||
users = User.query.filter_by(is_active=True).all()
|
||||
|
||||
if not users:
|
||||
logger.info("No active users found for credit refill")
|
||||
return {
|
||||
"success": True,
|
||||
"users_processed": 0,
|
||||
"credits_added": 0,
|
||||
"message": "No active users found",
|
||||
}
|
||||
|
||||
users_processed = 0
|
||||
total_credits_added = 0
|
||||
|
||||
for user in users:
|
||||
if not user.plan:
|
||||
logger.warning(f"User {user.email} has no plan assigned, skipping")
|
||||
continue
|
||||
|
||||
# Calculate new credit amount, capped at plan max
|
||||
current_credits = user.credits or 0
|
||||
plan_daily_credits = user.plan.credits
|
||||
max_credits = user.plan.max_credits
|
||||
|
||||
# Add daily credits but don't exceed maximum
|
||||
new_credits = min(current_credits + plan_daily_credits, max_credits)
|
||||
credits_added = new_credits - current_credits
|
||||
|
||||
if credits_added > 0:
|
||||
user.credits = new_credits
|
||||
user.updated_at = datetime.utcnow()
|
||||
total_credits_added += credits_added
|
||||
|
||||
logger.debug(
|
||||
f"User {user.email}: {current_credits} -> {new_credits} "
|
||||
f"(+{credits_added} credits, plan: {user.plan.code})",
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"User {user.email}: Already at max credits "
|
||||
f"({current_credits}/{max_credits})",
|
||||
)
|
||||
|
||||
users_processed += 1
|
||||
|
||||
# Commit all changes in a single transaction
|
||||
db.session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Daily credit refill completed: {users_processed} users processed, "
|
||||
f"{total_credits_added} total credits added",
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"users_processed": users_processed,
|
||||
"credits_added": total_credits_added,
|
||||
"message": f"Successfully refilled credits for {users_processed} users",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# Rollback transaction on error
|
||||
db.session.rollback()
|
||||
logger.error(f"Error during daily credit refill: {e!s}")
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"users_processed": 0,
|
||||
"credits_added": 0,
|
||||
"error": str(e),
|
||||
"message": "Credit refill failed",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_user_credit_info(user_id: int) -> dict:
|
||||
"""Get detailed credit information for a specific user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
dict: User's credit information
|
||||
|
||||
"""
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
return {"error": "User not found"}
|
||||
|
||||
if not user.plan:
|
||||
return {"error": "User has no plan assigned"}
|
||||
|
||||
return {
|
||||
"user_id": user.id,
|
||||
"email": user.email,
|
||||
"current_credits": user.credits,
|
||||
"plan": {
|
||||
"code": user.plan.code,
|
||||
"name": user.plan.name,
|
||||
"daily_credits": user.plan.credits,
|
||||
"max_credits": user.plan.max_credits,
|
||||
},
|
||||
"is_active": user.is_active,
|
||||
}
|
||||
@@ -4,7 +4,7 @@ from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from flask import jsonify, request
|
||||
from flask_jwt_extended import get_jwt, get_jwt_identity, verify_jwt_in_request
|
||||
from flask_jwt_extended import get_jwt_identity, verify_jwt_in_request
|
||||
|
||||
|
||||
def get_user_from_jwt() -> dict[str, Any] | None:
|
||||
@@ -12,14 +12,14 @@ def get_user_from_jwt() -> dict[str, Any] | None:
|
||||
try:
|
||||
# Try to verify JWT token in request - this sets up the context
|
||||
verify_jwt_in_request()
|
||||
|
||||
|
||||
current_user_id = get_jwt_identity()
|
||||
if not current_user_id:
|
||||
return None
|
||||
|
||||
# Query database for user data instead of using JWT claims
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
user = User.query.get(int(current_user_id))
|
||||
if not user or not user.is_active:
|
||||
return None
|
||||
@@ -70,7 +70,7 @@ def get_user_from_api_token() -> dict[str, Any] | None:
|
||||
providers.append("password")
|
||||
if user.api_token:
|
||||
providers.append("api_token")
|
||||
|
||||
|
||||
return {
|
||||
"id": str(user.id),
|
||||
"email": user.email,
|
||||
@@ -109,7 +109,7 @@ def require_auth(f):
|
||||
if not user:
|
||||
return (
|
||||
jsonify(
|
||||
{"error": "Authentication required (JWT or API token)"}
|
||||
{"error": "Authentication required (JWT or API token)"},
|
||||
),
|
||||
401,
|
||||
)
|
||||
@@ -133,8 +133,8 @@ def require_role(required_role: str):
|
||||
return (
|
||||
jsonify(
|
||||
{
|
||||
"error": f"Access denied. {required_role.title()} role required"
|
||||
}
|
||||
"error": f"Access denied. {required_role.title()} role required",
|
||||
},
|
||||
),
|
||||
403,
|
||||
)
|
||||
@@ -148,41 +148,43 @@ def require_role(required_role: str):
|
||||
|
||||
def require_credits(credits_needed: int):
|
||||
"""Decorator to require and deduct credits for routes."""
|
||||
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
from app.models.user import User
|
||||
from app.database import db
|
||||
|
||||
from app.models.user import User
|
||||
|
||||
# First check authentication
|
||||
user_data = get_current_user()
|
||||
if not user_data:
|
||||
return jsonify({"error": "Authentication required"}), 401
|
||||
|
||||
|
||||
# Get the actual user from database to check/update credits
|
||||
user = User.query.get(int(user_data["id"]))
|
||||
if not user or not user.is_active:
|
||||
return jsonify({"error": "User not found or inactive"}), 401
|
||||
|
||||
|
||||
# Check if user has enough credits
|
||||
if user.credits < credits_needed:
|
||||
return (
|
||||
jsonify(
|
||||
{
|
||||
"error": f"Insufficient credits. Required: {credits_needed}, Available: {user.credits}"
|
||||
}
|
||||
"error": f"Insufficient credits. Required: {credits_needed}, Available: {user.credits}",
|
||||
},
|
||||
),
|
||||
402, # Payment Required status code
|
||||
)
|
||||
|
||||
|
||||
# Deduct credits
|
||||
user.credits -= credits_needed
|
||||
db.session.commit()
|
||||
|
||||
|
||||
# Execute the function
|
||||
result = f(*args, **kwargs)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -1,39 +1,36 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from authlib.integrations.flask_client import OAuth
|
||||
|
||||
|
||||
class OAuthProvider(ABC):
|
||||
"""Abstract base class for OAuth providers."""
|
||||
|
||||
|
||||
def __init__(self, oauth: OAuth, client_id: str, client_secret: str):
|
||||
self.oauth = oauth
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self._client = None
|
||||
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Provider name (e.g., 'google', 'github')."""
|
||||
pass
|
||||
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def display_name(self) -> str:
|
||||
"""Human-readable provider name (e.g., 'Google', 'GitHub')."""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_client_config(self) -> Dict[str, Any]:
|
||||
def get_client_config(self) -> dict[str, Any]:
|
||||
"""Return OAuth client configuration."""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def get_user_info(self, token: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract user information from OAuth token response."""
|
||||
pass
|
||||
|
||||
|
||||
def get_client(self):
|
||||
"""Get or create OAuth client."""
|
||||
if self._client is None:
|
||||
@@ -42,27 +39,29 @@ class OAuthProvider(ABC):
|
||||
name=self.name,
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret,
|
||||
**config
|
||||
**config,
|
||||
)
|
||||
return self._client
|
||||
|
||||
|
||||
def get_authorization_url(self, redirect_uri: str) -> str:
|
||||
"""Generate authorization URL for OAuth flow."""
|
||||
client = self.get_client()
|
||||
return client.authorize_redirect(redirect_uri).location
|
||||
|
||||
def exchange_code_for_token(self, code: str = None, redirect_uri: str = None) -> Dict[str, Any]:
|
||||
|
||||
def exchange_code_for_token(
|
||||
self, code: str = None, redirect_uri: str = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Exchange authorization code for access token."""
|
||||
client = self.get_client()
|
||||
token = client.authorize_access_token()
|
||||
return token
|
||||
|
||||
def normalize_user_data(self, user_info: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
def normalize_user_data(self, user_info: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Normalize user data to common format."""
|
||||
return {
|
||||
'id': user_info.get('id'),
|
||||
'email': user_info.get('email'),
|
||||
'name': user_info.get('name'),
|
||||
'picture': user_info.get('picture'),
|
||||
'provider': self.name
|
||||
}
|
||||
"id": user_info.get("id"),
|
||||
"email": user_info.get("email"),
|
||||
"name": user_info.get("name"),
|
||||
"picture": user_info.get("picture"),
|
||||
"provider": self.name,
|
||||
}
|
||||
|
||||
@@ -1,52 +1,51 @@
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from .base import OAuthProvider
|
||||
|
||||
|
||||
class GitHubOAuthProvider(OAuthProvider):
|
||||
"""GitHub OAuth provider implementation."""
|
||||
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return 'github'
|
||||
|
||||
return "github"
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return 'GitHub'
|
||||
|
||||
def get_client_config(self) -> Dict[str, Any]:
|
||||
return "GitHub"
|
||||
|
||||
def get_client_config(self) -> dict[str, Any]:
|
||||
"""Return GitHub OAuth client configuration."""
|
||||
return {
|
||||
'access_token_url': 'https://github.com/login/oauth/access_token',
|
||||
'authorize_url': 'https://github.com/login/oauth/authorize',
|
||||
'api_base_url': 'https://api.github.com/',
|
||||
'client_kwargs': {
|
||||
'scope': 'user:email'
|
||||
}
|
||||
"access_token_url": "https://github.com/login/oauth/access_token",
|
||||
"authorize_url": "https://github.com/login/oauth/authorize",
|
||||
"api_base_url": "https://api.github.com/",
|
||||
"client_kwargs": {"scope": "user:email"},
|
||||
}
|
||||
|
||||
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
def get_user_info(self, token: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract user information from GitHub OAuth token response."""
|
||||
client = self.get_client()
|
||||
|
||||
|
||||
# Get user profile
|
||||
user_resp = client.get('user', token=token)
|
||||
user_resp = client.get("user", token=token)
|
||||
user_data = user_resp.json()
|
||||
|
||||
|
||||
# Get user email (may be private)
|
||||
email = user_data.get('email')
|
||||
email = user_data.get("email")
|
||||
if not email:
|
||||
# If email is private, get from emails endpoint
|
||||
emails_resp = client.get('user/emails', token=token)
|
||||
emails_resp = client.get("user/emails", token=token)
|
||||
emails = emails_resp.json()
|
||||
# Find primary email
|
||||
for email_obj in emails:
|
||||
if email_obj.get('primary', False):
|
||||
email = email_obj.get('email')
|
||||
if email_obj.get("primary", False):
|
||||
email = email_obj.get("email")
|
||||
break
|
||||
|
||||
|
||||
return {
|
||||
'id': str(user_data.get('id')),
|
||||
'email': email,
|
||||
'name': user_data.get('name') or user_data.get('login'),
|
||||
'picture': user_data.get('avatar_url')
|
||||
}
|
||||
"id": str(user_data.get("id")),
|
||||
"email": email,
|
||||
"name": user_data.get("name") or user_data.get("login"),
|
||||
"picture": user_data.get("avatar_url"),
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from .base import OAuthProvider
|
||||
|
||||
@@ -14,14 +14,14 @@ class GoogleOAuthProvider(OAuthProvider):
|
||||
def display_name(self) -> str:
|
||||
return "Google"
|
||||
|
||||
def get_client_config(self) -> Dict[str, Any]:
|
||||
def get_client_config(self) -> dict[str, Any]:
|
||||
"""Return Google OAuth client configuration."""
|
||||
return {
|
||||
"server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
|
||||
"client_kwargs": {"scope": "openid email profile"},
|
||||
}
|
||||
|
||||
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def get_user_info(self, token: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract user information from Google OAuth token response."""
|
||||
client = self.get_client()
|
||||
user_info = client.userinfo(token=token)
|
||||
|
||||
@@ -1,45 +1,46 @@
import os
from typing import Dict, Optional

from authlib.integrations.flask_client import OAuth

from .base import OAuthProvider
from .google import GoogleOAuthProvider
from .github import GitHubOAuthProvider
from .google import GoogleOAuthProvider


class OAuthProviderRegistry:
    """Registry for OAuth providers."""

    def __init__(self, oauth: OAuth):
        self.oauth = oauth
        self._providers: Dict[str, OAuthProvider] = {}
        self._providers: dict[str, OAuthProvider] = {}
        self._initialize_providers()

    def _initialize_providers(self):
        """Initialize available providers based on environment variables."""
        # Google OAuth
        google_client_id = os.getenv('GOOGLE_CLIENT_ID')
        google_client_secret = os.getenv('GOOGLE_CLIENT_SECRET')
        google_client_id = os.getenv("GOOGLE_CLIENT_ID")
        google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET")
        if google_client_id and google_client_secret:
            self._providers['google'] = GoogleOAuthProvider(
                self.oauth, google_client_id, google_client_secret
            self._providers["google"] = GoogleOAuthProvider(
                self.oauth, google_client_id, google_client_secret,
            )

        # GitHub OAuth
        github_client_id = os.getenv('GITHUB_CLIENT_ID')
        github_client_secret = os.getenv('GITHUB_CLIENT_SECRET')
        github_client_id = os.getenv("GITHUB_CLIENT_ID")
        github_client_secret = os.getenv("GITHUB_CLIENT_SECRET")
        if github_client_id and github_client_secret:
            self._providers['github'] = GitHubOAuthProvider(
                self.oauth, github_client_id, github_client_secret
            self._providers["github"] = GitHubOAuthProvider(
                self.oauth, github_client_id, github_client_secret,
            )

    def get_provider(self, name: str) -> Optional[OAuthProvider]:
    def get_provider(self, name: str) -> OAuthProvider | None:
        """Get OAuth provider by name."""
        return self._providers.get(name)

    def get_available_providers(self) -> Dict[str, OAuthProvider]:
    def get_available_providers(self) -> dict[str, OAuthProvider]:
        """Get all available providers."""
        return self._providers.copy()

    def is_provider_available(self, name: str) -> bool:
        """Check if provider is available."""
        return name in self._providers
        return name in self._providers
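
A minimal usage sketch of the registry follows (not part of this diff; it assumes the Authlib OAuth object is created in the Flask app factory and that the GITHUB_CLIENT_ID / GITHUB_CLIENT_SECRET environment variables are set):

# Hypothetical wiring sketch, assuming `app` is the Flask app from create_app()
from authlib.integrations.flask_client import OAuth

oauth = OAuth(app)
registry = OAuthProviderRegistry(oauth)
if registry.is_provider_available("github"):
    provider = registry.get_provider("github")
    print(provider.display_name)  # -> "GitHub"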
164
app/services/scheduler_service.py
Normal file
@@ -0,0 +1,164 @@
"""Scheduler service for managing background tasks with APScheduler."""

import logging

from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from flask import current_app

from app.services.credit_service import CreditService
from app.services.sound_scanner_service import SoundScannerService

logger = logging.getLogger(__name__)


class SchedulerService:
    """Service for managing scheduled background tasks."""

    def __init__(self, app=None) -> None:
        """Initialize the scheduler service."""
        self.scheduler: BackgroundScheduler | None = None
        self.app = app

    def start(self) -> None:
        """Start the scheduler and add all scheduled jobs."""
        if self.scheduler is not None:
            logger.warning("Scheduler is already running")
            return

        self.scheduler = BackgroundScheduler()

        # Add daily credit refill job
        self._add_daily_credit_refill_job()

        # Add sound scanning job
        self._add_sound_scanning_job()

        # Start the scheduler
        self.scheduler.start()
        logger.info("Scheduler started successfully")

    def stop(self) -> None:
        """Stop the scheduler."""
        if self.scheduler is not None:
            self.scheduler.shutdown()
            self.scheduler = None
            logger.info("Scheduler stopped")

    def _add_daily_credit_refill_job(self) -> None:
        """Add the daily credit refill job."""
        if self.scheduler is None:
            raise RuntimeError("Scheduler not initialized")

        # Schedule daily at 00:00 UTC
        trigger = CronTrigger(hour=0, minute=0)

        self.scheduler.add_job(
            func=self._run_daily_credit_refill,
            trigger=trigger,
            id="daily_credit_refill",
            name="Daily Credit Refill",
            replace_existing=True,
        )

        logger.info("Daily credit refill job scheduled for 00:00 UTC")

    def _add_sound_scanning_job(self) -> None:
        """Add the sound scanning job."""
        if self.scheduler is None:
            raise RuntimeError("Scheduler not initialized")

        # Schedule every 5 minutes for sound scanning
        trigger = CronTrigger(minute="*/5")

        self.scheduler.add_job(
            func=self._run_sound_scan,
            trigger=trigger,
            id="sound_scan",
            name="Sound Directory Scan",
            replace_existing=True,
        )

        logger.info("Sound scanning job scheduled every 5 minutes")

    def _run_daily_credit_refill(self) -> None:
        """Execute the daily credit refill task."""
        logger.info("Starting daily credit refill task")

        app = self.app or current_app
        with app.app_context():
            try:
                result = CreditService.refill_all_users_credits()

                if result["success"]:
                    logger.info(
                        f"Daily credit refill completed successfully: "
                        f"{result['users_processed']} users processed, "
                        f"{result['credits_added']} credits added",
                    )
                else:
                    logger.error(f"Daily credit refill failed: {result['message']}")

            except Exception as e:
                logger.exception(f"Error during daily credit refill: {e}")

    def _run_sound_scan(self) -> None:
        """Execute the sound scanning task."""
        logger.info("Starting sound directory scan")

        app = self.app or current_app
        with app.app_context():
            try:
                result = SoundScannerService.scan_soundboard_directory()

                if result["success"]:
                    if result["files_added"] > 0:
                        logger.info(
                            f"Sound scan completed: {result['files_added']} new sounds added",
                        )
                    else:
                        logger.debug("Sound scan completed: no new files found")
                else:
                    logger.error(f"Sound scan failed: {result.get('error', 'Unknown error')}")

            except Exception as e:
                logger.exception(f"Error during sound scan: {e}")

    def trigger_credit_refill_now(self) -> dict:
        """Manually trigger credit refill for testing purposes."""
        logger.info("Manually triggering credit refill")
        app = self.app or current_app
        with app.app_context():
            return CreditService.refill_all_users_credits()

    def trigger_sound_scan_now(self) -> dict:
        """Manually trigger sound scan for testing purposes."""
        logger.info("Manually triggering sound scan")
        app = self.app or current_app
        with app.app_context():
            return SoundScannerService.scan_soundboard_directory()

    def get_scheduler_status(self) -> dict:
        """Get the current status of the scheduler."""
        if self.scheduler is None:
            return {"running": False, "jobs": []}

        jobs = [
            {
                "id": job.id,
                "name": job.name,
                "next_run": job.next_run_time.isoformat()
                if job.next_run_time else None,
                "trigger": str(job.trigger),
            }
            for job in self.scheduler.get_jobs()
        ]

        return {
            "running": self.scheduler.running,
            "jobs": jobs,
        }


# Global scheduler instance
scheduler_service = SchedulerService()
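
A short usage sketch for the scheduler service (not part of this diff; it assumes create_app() has attached the Flask app to the global scheduler_service and started it):

# Hypothetical manual-trigger sketch, e.g. from a Flask shell
from app import create_app
from app.services.scheduler_service import scheduler_service

app = create_app()
print(scheduler_service.get_scheduler_status())    # lists daily_credit_refill and sound_scan
print(scheduler_service.trigger_sound_scan_now())  # runs the directory scan immediately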
491
app/services/sound_normalizer_service.py
Normal file
@@ -0,0 +1,491 @@
"""Sound normalization service using ffmpeg loudnorm filter."""

import hashlib
import logging
from pathlib import Path

import ffmpeg
from pydub import AudioSegment

from app.database import db
from app.models.sound import Sound

logger = logging.getLogger(__name__)


class SoundNormalizerService:
    """Service for normalizing sound files using ffmpeg loudnorm."""

    SUPPORTED_EXTENSIONS = {
        ".mp3",
        ".wav",
        ".ogg",
        ".flac",
        ".m4a",
        ".aac",
        ".opus",
    }
    SOUNDS_DIR = "sounds/soundboard"
    NORMALIZED_DIR = "sounds/normalized/soundboard"

    LOUDNORM_PARAMS = {
        "integrated": -16,
        "true_peak": -1.5,
        "lra": 11.0,
        "print_format": "summary",
    }

    @staticmethod
    def normalize_sound(sound_id: int, overwrite: bool = False) -> dict:
        """Normalize a specific sound file using ffmpeg loudnorm.

        Args:
            sound_id: ID of the sound to normalize
            overwrite: Whether to overwrite existing normalized file

        Returns:
            dict: Result of the normalization operation

        """
        try:
            sound = Sound.query.get(sound_id)
            if not sound:
                return {
                    "success": False,
                    "error": f"Sound with ID {sound_id} not found",
                }

            source_path = Path(SoundNormalizerService.SOUNDS_DIR) / sound.filename
            if not source_path.exists():
                return {
                    "success": False,
                    "error": f"Source file not found: {source_path}",
                }

            # Always output as WAV regardless of input format
            filename_without_ext = Path(sound.filename).stem
            normalized_filename = f"{filename_without_ext}.wav"
            normalized_path = Path(SoundNormalizerService.NORMALIZED_DIR) / normalized_filename

            normalized_path.parent.mkdir(parents=True, exist_ok=True)

            if normalized_path.exists() and not overwrite:
                return {
                    "success": False,
                    "error": f"Normalized file already exists: {normalized_path}. Use overwrite=True to replace it.",
                }

            logger.info(
                f"Starting normalization of {sound.name} ({sound.filename})",
            )

            result = SoundNormalizerService._normalize_with_ffmpeg(
                str(source_path), str(normalized_path),
            )

            if result["success"]:
                # Calculate normalized file metadata
                normalized_metadata = (
                    SoundNormalizerService._get_normalized_metadata(
                        str(normalized_path),
                    )
                )

                # Update sound record with normalized information
                sound.set_normalized_info(
                    normalized_filename=normalized_filename,
                    normalized_duration=normalized_metadata["duration"],
                    normalized_size=normalized_metadata["size"],
                    normalized_hash=normalized_metadata["hash"],
                )

                # Commit the database changes
                db.session.commit()

                logger.info(f"Successfully normalized {sound.name}")
                return {
                    "success": True,
                    "sound_id": sound_id,
                    "sound_name": sound.name,
                    "source_path": str(source_path),
                    "normalized_path": str(normalized_path),
                    "normalized_filename": normalized_filename,
                    "normalized_duration": normalized_metadata["duration"],
                    "normalized_size": normalized_metadata["size"],
                    "normalized_hash": normalized_metadata["hash"],
                    "loudnorm_stats": result.get("stats", {}),
                }
            return result

        except Exception as e:
            logger.error(f"Error normalizing sound {sound_id}: {e}")
            return {"success": False, "error": str(e)}

    @staticmethod
    def normalize_all_sounds(
        overwrite: bool = False, limit: int = None,
    ) -> dict:
        """Normalize all soundboard files.

        Args:
            overwrite: Whether to overwrite existing normalized files
            limit: Maximum number of files to process (None for all)

        Returns:
            dict: Summary of the normalization operation

        """
        try:
            query = Sound.query.filter_by(type="SDB")
            if limit:
                query = query.limit(limit)

            sounds = query.all()

            if not sounds:
                return {
                    "success": True,
                    "message": "No soundboard files found to normalize",
                    "processed": 0,
                    "successful": 0,
                    "failed": 0,
                    "skipped": 0,
                }

            logger.info(f"Starting bulk normalization of {len(sounds)} sounds")

            processed = 0
            successful = 0
            failed = 0
            skipped = 0
            errors = []

            for sound in sounds:
                result = SoundNormalizerService.normalize_sound(
                    sound.id, overwrite,
                )
                processed += 1

                if result["success"]:
                    successful += 1
                elif "already exists" in result.get("error", ""):
                    skipped += 1
                else:
                    failed += 1
                    errors.append(f"{sound.name}: {result['error']}")

            logger.info(
                f"Bulk normalization completed: {successful} successful, {failed} failed, {skipped} skipped",
            )

            return {
                "success": True,
                "message": f"Processed {processed} sounds: {successful} successful, {failed} failed, {skipped} skipped",
                "processed": processed,
                "successful": successful,
                "failed": failed,
                "skipped": skipped,
                "errors": errors,
            }

        except Exception as e:
            logger.error(f"Error during bulk normalization: {e}")
            return {
                "success": False,
                "error": str(e),
                "processed": 0,
                "successful": 0,
                "failed": 0,
                "skipped": 0,
            }

    @staticmethod
    def _normalize_with_ffmpeg(source_path: str, output_path: str) -> dict:
        """Run ffmpeg loudnorm on a single file using python-ffmpeg.

        Args:
            source_path: Path to source audio file
            output_path: Path for normalized output file (will be WAV format)

        Returns:
            dict: Result with success status and loudnorm statistics

        """
        try:
            params = SoundNormalizerService.LOUDNORM_PARAMS

            logger.debug(
                f"Running ffmpeg normalization: {source_path} -> {output_path}",
            )

            # Create ffmpeg input stream
            input_stream = ffmpeg.input(source_path)

            # Apply loudnorm filter
            loudnorm_filter = f"loudnorm=I={params['integrated']}:TP={params['true_peak']}:LRA={params['lra']}:print_format={params['print_format']}"

            # Create output stream with WAV format
            output_stream = ffmpeg.output(
                input_stream,
                output_path,
                acodec="pcm_s16le",  # 16-bit PCM for WAV
                ar=44100,  # 44.1kHz sample rate
                af=loudnorm_filter,
                y=None,  # Overwrite output file
            )

            # Run the ffmpeg process
            out, err = ffmpeg.run(
                output_stream, capture_stdout=True, capture_stderr=True,
            )

            # Parse loudnorm statistics from stderr
            stats = SoundNormalizerService._parse_loudnorm_stats(
                err.decode() if err else "",
            )

            if not Path(output_path).exists():
                return {
                    "success": False,
                    "error": "Output file was not created",
                }

            return {"success": True, "stats": stats}

        except ffmpeg.Error as e:
            error_msg = (
                f"FFmpeg error: {e.stderr.decode() if e.stderr else str(e)}"
            )
            logger.error(error_msg)
            return {"success": False, "error": error_msg}
        except Exception as e:
            logger.error(f"Error running ffmpeg: {e}")
            return {"success": False, "error": str(e)}

    @staticmethod
    def _parse_loudnorm_stats(stderr_output: str) -> dict:
        """Parse loudnorm statistics from ffmpeg stderr output.

        Args:
            stderr_output: ffmpeg stderr output containing loudnorm stats

        Returns:
            dict: Parsed loudnorm statistics

        """
        stats = {}

        if not stderr_output:
            return stats

        lines = stderr_output.split("\n")

        for line in lines:
            line = line.strip()
            if "Input Integrated:" in line:
                try:
                    stats["input_integrated"] = float(line.split()[-2])
                except (ValueError, IndexError):
                    pass
            elif "Input True Peak:" in line:
                try:
                    stats["input_true_peak"] = float(line.split()[-2])
                except (ValueError, IndexError):
                    pass
            elif "Input LRA:" in line:
                try:
                    stats["input_lra"] = float(line.split()[-1])
                except (ValueError, IndexError):
                    pass
            elif "Output Integrated:" in line:
                try:
                    stats["output_integrated"] = float(line.split()[-2])
                except (ValueError, IndexError):
                    pass
            elif "Output True Peak:" in line:
                try:
                    stats["output_true_peak"] = float(line.split()[-2])
                except (ValueError, IndexError):
                    pass
            elif "Output LRA:" in line:
                try:
                    stats["output_lra"] = float(line.split()[-1])
                except (ValueError, IndexError):
                    pass

        return stats

    @staticmethod
    def _get_normalized_metadata(file_path: str) -> dict:
        """Calculate metadata for normalized file.

        Args:
            file_path: Path to the normalized audio file

        Returns:
            dict: Metadata including duration and hash

        """
        try:
            # Get file size
            file_size = Path(file_path).stat().st_size

            # Calculate file hash
            file_hash = SoundNormalizerService._calculate_file_hash(file_path)

            # Get duration using pydub
            audio = AudioSegment.from_wav(file_path)
            duration = len(audio)  # Duration in milliseconds

            return {
                "duration": duration,
                "size": file_size,
                "hash": file_hash,
            }

        except Exception as e:
            logger.error(f"Error calculating metadata for {file_path}: {e}")
            return {
                "duration": 0,
                "size": Path(file_path).stat().st_size,
                "hash": "",
            }

    @staticmethod
    def _calculate_file_hash(file_path: str) -> str:
        """Calculate SHA256 hash of file contents."""
        sha256_hash = hashlib.sha256()

        with Path(file_path).open("rb") as f:
            # Read file in chunks to handle large files
            for chunk in iter(lambda: f.read(4096), b""):
                sha256_hash.update(chunk)

        return sha256_hash.hexdigest()

    @staticmethod
    def get_normalization_status() -> dict:
        """Get statistics about normalized vs original files.

        Returns:
            dict: Statistics about normalization status

        """
        try:
            total_sounds = Sound.query.filter_by(type="SDB").count()

            normalized_count = 0
            total_original_size = 0
            total_normalized_size = 0

            sounds = Sound.query.filter_by(type="SDB").all()

            for sound in sounds:
                original_path = Path(SoundNormalizerService.SOUNDS_DIR) / sound.filename

                if original_path.exists():
                    total_original_size += original_path.stat().st_size

                # Use database field to check if normalized, not file existence
                if sound.is_normalized and sound.normalized_filename:
                    normalized_count += 1
                    normalized_path = Path(SoundNormalizerService.NORMALIZED_DIR) / sound.normalized_filename
                    if normalized_path.exists():
                        total_normalized_size += normalized_path.stat().st_size

            return {
                "total_sounds": total_sounds,
                "normalized_count": normalized_count,
                "normalization_percentage": (
                    (normalized_count / total_sounds * 100)
                    if total_sounds > 0
                    else 0
                ),
                "total_original_size": total_original_size,
                "total_normalized_size": total_normalized_size,
                "size_difference": (
                    total_normalized_size - total_original_size
                    if normalized_count > 0
                    else 0
                ),
            }

        except Exception as e:
            logger.error(f"Error getting normalization status: {e}")
            return {
                "error": str(e),
                "total_sounds": 0,
                "normalized_count": 0,
                "normalization_percentage": 0,
            }

    @staticmethod
    def check_ffmpeg_availability() -> dict:
        """Check if ffmpeg is available and supports loudnorm filter.

        Returns:
            dict: Information about ffmpeg availability and capabilities

        """
        try:
            # Create a minimal test audio file to check ffmpeg
            import tempfile

            with tempfile.NamedTemporaryFile(
                suffix=".wav", delete=False,
            ) as temp_file:
                temp_path = temp_file.name

            try:
                # Try a simple ffmpeg operation to check availability
                test_input = ffmpeg.input(
                    "anullsrc=channel_layout=stereo:sample_rate=44100",
                    f="lavfi",
                    t=0.1,
                )
                test_output = ffmpeg.output(test_input, temp_path)
                ffmpeg.run(
                    test_output,
                    capture_stdout=True,
                    capture_stderr=True,
                    quiet=True,
                )

                # If we get here, basic ffmpeg is working
                # Now test loudnorm filter
                try:
                    norm_input = ffmpeg.input(temp_path)
                    norm_output = ffmpeg.output(
                        norm_input,
                        "/dev/null",
                        af="loudnorm=I=-16:TP=-1.5:LRA=11.0",
                        f="null",
                    )
                    ffmpeg.run(
                        norm_output,
                        capture_stdout=True,
                        capture_stderr=True,
                        quiet=True,
                    )
                    has_loudnorm = True
                except ffmpeg.Error:
                    has_loudnorm = False

                return {
                    "available": True,
                    "version": "ffmpeg-python wrapper available",
                    "has_loudnorm": has_loudnorm,
                    "ready": has_loudnorm,
                }

            finally:
                # Clean up temp file
                temp_file_path = Path(temp_path)
                if temp_file_path.exists():
                    temp_file_path.unlink()

        except Exception as e:
            return {
                "available": False,
                "error": f"ffmpeg not available via python-ffmpeg: {e!s}",
            }
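
A usage sketch for the normalizer (not part of this diff; sound_id=1 is a placeholder, and the calls assume an application context because the service reads and commits Sound records):

# Hypothetical normalization run from a Flask shell
from app import create_app
from app.services.sound_normalizer_service import SoundNormalizerService

app = create_app()
with app.app_context():
    print(SoundNormalizerService.check_ffmpeg_availability())
    result = SoundNormalizerService.normalize_sound(sound_id=1, overwrite=True)
    print(result.get("loudnorm_stats", {}))
    print(SoundNormalizerService.get_normalization_status())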
316
app/services/sound_scanner_service.py
Normal file
@@ -0,0 +1,316 @@
"""Sound file scanning service for discovering and importing audio files."""

import hashlib
import logging
from pathlib import Path

from pydub import AudioSegment
from pydub.utils import mediainfo

from app.database import db
from app.models.sound import Sound

logger = logging.getLogger(__name__)


class SoundScannerService:
    """Service for scanning and importing sound files."""

    # Supported audio file extensions
    SUPPORTED_EXTENSIONS = {".mp3", ".wav", ".ogg", ".flac", ".m4a", ".aac"}

    # Default soundboard directory
    DEFAULT_SOUNDBOARD_DIR = "sounds/soundboard"

    @staticmethod
    def scan_soundboard_directory(
        directory: str | None = None,
    ) -> dict:
        """Scan the soundboard directory and add new files to the database.

        Args:
            directory: Directory to scan (defaults to sounds/soundboard)

        Returns:
            dict: Summary of the scan operation

        """
        scan_dir = directory or SoundScannerService.DEFAULT_SOUNDBOARD_DIR

        try:
            # Ensure directory exists
            scan_path = Path(scan_dir)
            if not scan_path.exists():
                logger.warning(
                    f"Soundboard directory does not exist: {scan_dir}",
                )
                return {
                    "success": False,
                    "error": f"Directory not found: {scan_dir}",
                    "files_found": 0,
                    "files_added": 0,
                    "files_skipped": 0,
                }

            logger.info(f"Starting soundboard scan in: {scan_dir}")

            files_found = 0
            files_added = 0
            files_skipped = 0
            errors = []

            # Walk through directory and subdirectories
            for file_path in scan_path.rglob("*"):
                if file_path.is_file():
                    filename = file_path.name

                    # Check if file has supported extension
                    if not SoundScannerService._is_supported_audio_file(
                        filename,
                    ):
                        continue

                    files_found += 1

                    try:
                        # Process the audio file
                        result = SoundScannerService._process_audio_file(
                            str(file_path),
                            scan_dir,
                        )

                        if result["added"]:
                            files_added += 1
                            logger.debug(f"Added sound: {filename}")
                        elif result.get("updated"):
                            files_added += 1  # Count updates as additions for reporting
                            logger.debug(f"Updated sound: {filename}")
                        else:
                            files_skipped += 1
                            logger.debug(
                                f"Skipped sound: {filename} ({result['reason']})",
                            )

                    except Exception as e:
                        error_msg = f"Error processing {filename}: {e!s}"
                        logger.error(error_msg)
                        errors.append(error_msg)
                        files_skipped += 1

            # Commit all changes
            db.session.commit()

            logger.info(
                f"Soundboard scan completed: {files_found} files found, "
                f"{files_added} added, {files_skipped} skipped",
            )

            return {
                "success": True,
                "directory": scan_dir,
                "files_found": files_found,
                "files_added": files_added,
                "files_skipped": files_skipped,
                "errors": errors,
                "message": f"Scan completed: {files_added} new sounds added",
            }

        except Exception as e:
            db.session.rollback()
            logger.error(f"Error during soundboard scan: {e!s}")

            return {
                "success": False,
                "error": str(e),
                "files_found": 0,
                "files_added": 0,
                "files_skipped": 0,
                "message": "Soundboard scan failed",
            }

    @staticmethod
    def _is_supported_audio_file(filename: str) -> bool:
        """Check if file has a supported audio extension."""
        return (
            Path(filename).suffix.lower()
            in SoundScannerService.SUPPORTED_EXTENSIONS
        )

    @staticmethod
    def _process_audio_file(file_path: str, base_dir: str) -> dict:
        """Process a single audio file and add it to database if new.

        Args:
            file_path: Full path to the audio file
            base_dir: Base directory for relative path calculation

        Returns:
            dict: Processing result with added flag and reason

        """
        # Calculate file hash for deduplication
        file_hash = SoundScannerService._calculate_file_hash(file_path)

        # Get file metadata
        metadata = SoundScannerService._extract_audio_metadata(file_path)

        # Calculate relative filename from base directory
        relative_path = Path(file_path).relative_to(Path(base_dir))

        # Check if file already exists in database by hash
        existing_sound = Sound.find_by_hash(file_hash)
        if existing_sound:
            return {
                "added": False,
                "reason": f"File already exists as '{existing_sound.name}'",
            }

        # Check if filename already exists in database
        existing_filename_sound = Sound.find_by_filename(str(relative_path))
        if existing_filename_sound:
            # Remove normalized files and clear normalized info
            SoundScannerService._clear_normalized_files(existing_filename_sound)
            existing_filename_sound.clear_normalized_info()

            # Update existing sound with new file information
            existing_filename_sound.update_file_info(
                filename=str(relative_path),
                duration=metadata["duration"],
                size=metadata["size"],
                hash_value=file_hash,
            )

            return {
                "added": False,
                "updated": True,
                "sound_id": existing_filename_sound.id,
                "reason": f"Updated existing sound '{existing_filename_sound.name}' with new file data",
            }

        # Generate sound name from filename (without extension)
        sound_name = Path(file_path).stem

        # Check if name already exists and make it unique if needed
        counter = 1
        original_name = sound_name
        while Sound.find_by_name(sound_name):
            sound_name = f"{original_name}_{counter}"
            counter += 1

        # Create new sound record
        sound = Sound.create_sound(
            sound_type="SDB",  # Soundboard type
            name=sound_name,
            filename=str(relative_path),
            duration=metadata["duration"],
            size=metadata["size"],
            hash_value=file_hash,
            is_music=False,
            is_deletable=False,
            commit=False,  # Don't commit individually, let scanner handle transaction
        )

        return {
            "added": True,
            "sound_id": sound.id,
            "reason": "New file added successfully",
        }

    @staticmethod
    def _calculate_file_hash(file_path: str) -> str:
        """Calculate SHA256 hash of file contents."""
        sha256_hash = hashlib.sha256()

        with Path(file_path).open("rb") as f:
            # Read file in chunks to handle large files
            for chunk in iter(lambda: f.read(4096), b""):
                sha256_hash.update(chunk)

        return sha256_hash.hexdigest()

    @staticmethod
    def _clear_normalized_files(sound: Sound) -> None:
        """Remove normalized files for a sound if they exist."""
        if sound.is_normalized and sound.normalized_filename:
            # Import here to avoid circular imports
            from app.services.sound_normalizer_service import SoundNormalizerService

            normalized_path = Path(SoundNormalizerService.NORMALIZED_DIR) / sound.normalized_filename
            if normalized_path.exists():
                try:
                    normalized_path.unlink()
                    logger.info(f"Removed normalized file: {normalized_path}")
                except Exception as e:
                    logger.warning(f"Could not remove normalized file {normalized_path}: {e}")

    @staticmethod
    def _extract_audio_metadata(file_path: str) -> dict:
        """Extract metadata from audio file using pydub and mediainfo."""
        try:
            # Get file size
            file_size = Path(file_path).stat().st_size

            # Load audio file with pydub for basic info
            audio = AudioSegment.from_file(file_path)

            # Extract basic metadata from AudioSegment
            duration = len(audio)
            channels = audio.channels
            sample_rate = audio.frame_rate

            # Use mediainfo for more accurate bitrate information
            bitrate = None
            try:
                info = mediainfo(file_path)
                if info and "bit_rate" in info:
                    bitrate = int(info["bit_rate"])
                elif info and "bitrate" in info:
                    bitrate = int(info["bitrate"])
            except (ValueError, KeyError, TypeError):
                # Fallback to calculated bitrate if mediainfo fails
                if duration > 0:
                    file_size_bits = file_size * 8
                    bitrate = int(file_size_bits / duration / 1000)

            return {
                "duration": duration,
                "size": file_size,
                "bitrate": bitrate,
                "channels": channels,
                "sample_rate": sample_rate,
            }

        except Exception as e:
            logger.warning(f"Could not extract metadata from {file_path}: {e}")
            return {
                "duration": 0,
                "size": Path(file_path).stat().st_size,
                "bitrate": None,
                "channels": None,
                "sample_rate": None,
            }

    @staticmethod
    def get_scan_statistics() -> dict:
        """Get statistics about sounds in the database."""
        total_sounds = Sound.query.count()
        sdb_sounds = Sound.query.filter_by(type="SDB").count()
        music_sounds = Sound.query.filter_by(is_music=True).count()

        # Calculate total size and duration
        sounds = Sound.query.all()
        total_size = sum(sound.size for sound in sounds)
        total_duration = sum(sound.duration for sound in sounds)
        total_plays = sum(sound.play_count for sound in sounds)

        return {
            "total_sounds": total_sounds,
            "soundboard_sounds": sdb_sounds,
            "music_sounds": music_sounds,
            "total_size_bytes": total_size,
            "total_duration": total_duration,
            "total_plays": total_plays,
            "most_played": [
                sound.to_dict() for sound in Sound.get_most_played(5)
            ],
        }
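
A usage sketch for the scanner (not part of this diff; it assumes an application context so the service can query and commit Sound records):

# Hypothetical manual scan from a Flask shell
from app import create_app
from app.services.sound_scanner_service import SoundScannerService

app = create_app()
with app.app_context():
    summary = SoundScannerService.scan_soundboard_directory()
    print(summary["message"])  # e.g. "Scan completed: 3 new sounds added"
    print(SoundScannerService.get_scan_statistics())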
@@ -1,30 +1,35 @@
#!/usr/bin/env python3
"""Database migration script for Flask-Migrate."""

import os
from flask.cli import FlaskGroup

from app import create_app
from app.database import db

app = create_app()
cli = FlaskGroup(app)


@cli.command()
def init_db():
    """Initialize the database."""
    print("Initializing database...")
    from app.database_init import init_database

    init_database()
    print("Database initialized successfully!")


@cli.command()
def reset_db():
    """Reset the database (drop all tables and recreate)."""
    print("Resetting database...")
    db.drop_all()
    from app.database_init import init_database

    init_database()
    print("Database reset successfully!")


if __name__ == "__main__":
    cli()
    cli()
@@ -6,12 +6,15 @@ authors = [{ name = "quaik8", email = "quaik8@gmail.com" }]
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "apscheduler==3.11.0",
    "authlib==1.6.0",
    "ffmpeg-python>=0.2.0",
    "flask==3.1.1",
    "flask-cors==6.0.1",
    "flask-jwt-extended==4.7.1",
    "flask-migrate==4.1.0",
    "flask-sqlalchemy==3.1.1",
    "pydub==0.25.1",
    "python-dotenv==1.1.1",
    "requests==2.32.4",
    "werkzeug==3.1.3",
14
reset.sh
@@ -1,5 +1,17 @@
#!/bin/bash

shopt -s extglob

rm instance/soundboard.db
uv run migrate_db.py init-db

rm -rf alembic/versions/!(.gitignore)
rm -rf sounds/say/!(.gitignore)
rm -rf sounds/stream/!(.gitignore|thumbnails)
rm -rf sounds/stream/thumbnails/!(.gitignore)
rm -rf sounds/temp/!(.gitignore)
rm -rf sounds/normalized/say/!(.gitignore)
rm -rf sounds/normalized/soundboard/!(.gitignore)
rm -rf sounds/normalized/stream/!(.gitignore)

# uv run migrate_db.py init-db
uv run main.py
5  sounds/normalized/.gitignore  vendored  Normal file
@@ -0,0 +1,5 @@
*
!.gitignore
!say
!soundboard
!stream

2  sounds/normalized/say/.gitignore  vendored  Normal file
@@ -0,0 +1,2 @@
*
!.gitignore

2  sounds/normalized/soundboard/.gitignore  vendored  Normal file
@@ -0,0 +1,2 @@
*
!.gitignore

3  sounds/normalized/stream/.gitignore  vendored  Normal file
@@ -0,0 +1,3 @@
*
!.gitignore
!thumbnails

2  sounds/say/.gitignore  vendored  Normal file
@@ -0,0 +1,2 @@
*
!.gitignore
BIN  sounds/soundboard/20th_century_fox.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/3corde.wav  Normal file  Binary file not shown.
BIN  sounds/soundboard/a_few_moments_later.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/aallez.wav  Normal file  Binary file not shown.
BIN  sounds/soundboard/ah_denis_brogniart.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/alerte_gogole.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/allez.wav  Normal file  Binary file not shown.
BIN  sounds/soundboard/among_us.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/and_his_name_is_john_cena.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/animal_crossing_bla.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/another_one.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/as_tu_vu_les_quenouilles.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/as_tu_vu_les_quenouilles_long.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/aughhhhh_aughhhhh.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/awwww.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/bebou.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/bebou_long.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/bizarre.opus  Normal file  Binary file not shown.
BIN  sounds/soundboard/bonk.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/brother_ewwwwwww.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/c_est_honteux.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/c_est_l_heure_de_manger.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/c_est_la_mer_noir.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/c_t_sur_sard.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/ca_va_peter.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/careless_whisper_short.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/carrefour.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/cest_moi.wav  Normal file  Binary file not shown.
BIN  sounds/soundboard/cloche_de_boxe.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/combien.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/comment_ca_mon_reuf_sans_le_quoi.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/community_chang_gay.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/cou.wav  Normal file  Binary file not shown.
BIN  sounds/soundboard/coucou.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/dancehall_horn.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/decathlon.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/dikkenek_ou_tu_sors_ou_j_te_sors.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/directed_by_robert_b_weide.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/downer_noise.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/dry_fart.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/emotional_damage.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/epic_sax_guy.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/etchebest_c_est_con_ça.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/excuse_moiiii.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/expecto_patronum.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/fart_with_extra_reverb.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/fbi_open_up.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/fdp.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/flute.wav  Normal file  Binary file not shown.
BIN  sounds/soundboard/flute_anniv_LIP.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/fonctionnaire.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/gay_echo.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/goku_drip.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/gta_mission_complete.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/gtav_wasted.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/happy_happy_happy.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/hugooo.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/i_will_be_back.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/initial_d_deja_vu.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/initial_d_gas_gas_gas.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/insult.wav  Normal file  Binary file not shown.
BIN  sounds/soundboard/je_suis_pas_venue_ici_pour_souffrir_ok.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/je_te_demande_pardon.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/je_vous_demande_de_vous_arreter.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/julien_lepers_Ah_ouai_ouai_ouai_question.mp3  Normal file  Binary file not shown.
Binary file not shown.
BIN  sounds/soundboard/kabuki.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/karime_cuisiniere.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/karime_enfant_gache_court.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/karime_enfant_gache_long.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/karime_enfant_gache_medium.mp3  Normal file  Binary file not shown.
BIN  sounds/soundboard/kendrick_mustard.mp3  Normal file  Binary file not shown.
Some files were not shown because too many files have changed in this diff.