Add Alembic for database migrations and initial migration scripts

- Created alembic.ini configuration file for Alembic migrations.
- Added README file for Alembic with a brief description.
- Implemented env.py for Alembic to manage database migrations.
- Created script.py.mako template for migration scripts.
- Added initial migration script to create database tables.
- Created a migration script to add initial plan and playlist data.
- Updated database initialization to run Alembic migrations.
- Enhanced credit service to automatically recharge user credits based on their plan.
- Implemented delete_task method in scheduler service to remove scheduled tasks.
- Updated scheduler API to reflect task deletion instead of cancellation.
- Added CLI tool for managing database migrations.
- Updated tests to cover new functionality for task deletion and credit recharge.
- Updated pyproject.toml and lock files to include Alembic as a dependency.
This commit is contained in:
JSC
2025-09-16 13:45:14 +02:00
parent e8f979c137
commit 83239cb4fa
16 changed files with 828 additions and 29 deletions

148
alembic.ini Normal file
View File

@@ -0,0 +1,148 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
# "version_path_separator" key, which if absent then falls back to the legacy
# behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
# behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
# sqlalchemy.url = driver://user:pass@localhost/dbname
# URL will be set dynamically in env.py from config
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

1
alembic/README Normal file
View File

@@ -0,0 +1 @@
Generic single-database configuration.

86
alembic/env.py Normal file
View File

@@ -0,0 +1,86 @@
import asyncio
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from sqlalchemy.ext.asyncio import create_async_engine
from alembic import context
import app.models # noqa: F401
from app.core.config import settings
from sqlmodel import SQLModel
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Set the database URL from settings - convert async URL to sync for alembic
sync_db_url = settings.DATABASE_URL.replace("sqlite+aiosqlite", "sqlite")
sync_db_url = sync_db_url.replace("postgresql+asyncpg", "postgresql")
config.set_main_option("sqlalchemy.url", sync_db_url)
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = SQLModel.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

28
alembic/script.py.mako Normal file
View File

@@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,222 @@
"""Initial migration
Revision ID: 7aa9892ceff3
Revises:
Create Date: 2025-09-16 13:16:58.233360
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = '7aa9892ceff3'
down_revision: Union[str, Sequence[str], None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('plan',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('code', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('credits', sa.Integer(), nullable=False),
sa.Column('max_credits', sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_plan_code'), 'plan', ['code'], unique=True)
op.create_table('sound',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('filename', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('duration', sa.Integer(), nullable=False),
sa.Column('size', sa.Integer(), nullable=False),
sa.Column('hash', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('normalized_filename', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('normalized_duration', sa.Integer(), nullable=True),
sa.Column('normalized_size', sa.Integer(), nullable=True),
sa.Column('normalized_hash', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('thumbnail', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('play_count', sa.Integer(), nullable=False),
sa.Column('is_normalized', sa.Boolean(), nullable=False),
sa.Column('is_music', sa.Boolean(), nullable=False),
sa.Column('is_deletable', sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('hash', name='uq_sound_hash')
)
op.create_table('user',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('plan_id', sa.Integer(), nullable=False),
sa.Column('role', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('picture', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('password_hash', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('credits', sa.Integer(), nullable=False),
sa.Column('api_token', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('api_token_expires_at', sa.DateTime(), nullable=True),
sa.Column('refresh_token_hash', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('refresh_token_expires_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['plan_id'], ['plan.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('api_token'),
sa.UniqueConstraint('email')
)
op.create_table('credit_transaction',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('action_type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('amount', sa.Integer(), nullable=False),
sa.Column('balance_before', sa.Integer(), nullable=False),
sa.Column('balance_after', sa.Integer(), nullable=False),
sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('success', sa.Boolean(), nullable=False),
sa.Column('metadata_json', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('extraction',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('service', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('service_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('sound_id', sa.Integer(), nullable=True),
sa.Column('url', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('title', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('track', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('artist', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('album', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('genre', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('status', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('error', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.ForeignKeyConstraint(['sound_id'], ['sound.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('playlist',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('genre', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('is_main', sa.Boolean(), nullable=False),
sa.Column('is_current', sa.Boolean(), nullable=False),
sa.Column('is_deletable', sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('name')
)
op.create_table('scheduled_task',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
sa.Column('task_type', sa.Enum('CREDIT_RECHARGE', 'PLAY_SOUND', 'PLAY_PLAYLIST', name='tasktype'), nullable=False),
sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED', name='taskstatus'), nullable=False),
sa.Column('scheduled_at', sa.DateTime(), nullable=False),
sa.Column('timezone', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('recurrence_type', sa.Enum('NONE', 'MINUTELY', 'HOURLY', 'DAILY', 'WEEKLY', 'MONTHLY', 'YEARLY', 'CRON', name='recurrencetype'), nullable=False),
sa.Column('cron_expression', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('recurrence_count', sa.Integer(), nullable=True),
sa.Column('executions_count', sa.Integer(), nullable=False),
sa.Column('parameters', sa.JSON(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('last_executed_at', sa.DateTime(), nullable=True),
sa.Column('next_execution_at', sa.DateTime(), nullable=True),
sa.Column('error_message', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('expires_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('sound_played',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('sound_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['sound_id'], ['sound.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('user_oauth',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('provider', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('provider_user_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('picture', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('provider', 'provider_user_id', name='uq_user_oauth_provider_user_id')
)
op.create_table('favorite',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('sound_id', sa.Integer(), nullable=True),
sa.Column('playlist_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['playlist_id'], ['playlist.id'], ),
sa.ForeignKeyConstraint(['sound_id'], ['sound.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('user_id', 'playlist_id', name='uq_favorite_user_playlist'),
sa.UniqueConstraint('user_id', 'sound_id', name='uq_favorite_user_sound')
)
op.create_table('playlist_sound',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('playlist_id', sa.Integer(), nullable=False),
sa.Column('sound_id', sa.Integer(), nullable=False),
sa.Column('position', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['playlist_id'], ['playlist.id'], ),
sa.ForeignKeyConstraint(['sound_id'], ['sound.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('playlist_id', 'position', name='uq_playlist_sound_playlist_position'),
sa.UniqueConstraint('playlist_id', 'sound_id', name='uq_playlist_sound_playlist_sound')
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('playlist_sound')
op.drop_table('favorite')
op.drop_table('user_oauth')
op.drop_table('sound_played')
op.drop_table('scheduled_task')
op.drop_table('playlist')
op.drop_table('extraction')
op.drop_table('credit_transaction')
op.drop_table('user')
op.drop_table('sound')
op.drop_index(op.f('ix_plan_code'), table_name='plan')
op.drop_table('plan')
# ### end Alembic commands ###

View File

@@ -0,0 +1,106 @@
"""Add initial plan and playlist data
Revision ID: a0d322857b2c
Revises: 7aa9892ceff3
Create Date: 2025-09-16 13:23:31.682276
"""
from typing import Sequence, Union
from datetime import datetime
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'a0d322857b2c'
down_revision: Union[str, Sequence[str], None] = '7aa9892ceff3'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema and add initial data."""
# Get the current timestamp
now = datetime.utcnow()
# Insert initial plans
plans_table = sa.table(
'plan',
sa.column('code', sa.String),
sa.column('name', sa.String),
sa.column('description', sa.String),
sa.column('credits', sa.Integer),
sa.column('max_credits', sa.Integer),
sa.column('created_at', sa.DateTime),
sa.column('updated_at', sa.DateTime),
)
op.bulk_insert(
plans_table,
[
{
'code': 'free',
'name': 'Free Plan',
'description': 'Basic free plan with limited features',
'credits': 25,
'max_credits': 75,
'created_at': now,
'updated_at': now,
},
{
'code': 'premium',
'name': 'Premium Plan',
'description': 'Premium plan with more features',
'credits': 50,
'max_credits': 150,
'created_at': now,
'updated_at': now,
},
{
'code': 'pro',
'name': 'Pro Plan',
'description': 'Pro plan with unlimited features',
'credits': 100,
'max_credits': 300,
'created_at': now,
'updated_at': now,
},
]
)
# Insert main playlist
playlist_table = sa.table(
'playlist',
sa.column('name', sa.String),
sa.column('description', sa.String),
sa.column('is_main', sa.Boolean),
sa.column('is_deletable', sa.Boolean),
sa.column('is_current', sa.Boolean),
sa.column('created_at', sa.DateTime),
sa.column('updated_at', sa.DateTime),
)
op.bulk_insert(
playlist_table,
[
{
'name': 'All',
'description': 'The default main playlist with all the tracks',
'is_main': True,
'is_deletable': False,
'is_current': True,
'created_at': now,
'updated_at': now,
}
]
)
def downgrade() -> None:
"""Downgrade schema and remove initial data."""
# Remove initial plans
op.execute("DELETE FROM plan WHERE code IN ('free', 'premium', 'pro')")
# Remove main playlist
op.execute("DELETE FROM playlist WHERE is_main = 1")

View File

@@ -129,7 +129,7 @@ async def update_task(
@router.delete("/tasks/{task_id}")
async def cancel_task(
async def delete_task(
task_id: int,
current_user: Annotated[User, Depends(get_current_active_user)] = ...,
scheduler_service: Annotated[
@@ -137,7 +137,7 @@ async def cancel_task(
] = ...,
db_session: Annotated[AsyncSession, Depends(get_db)] = ...,
) -> dict:
"""Cancel a scheduled task."""
"""Delete a scheduled task completely."""
repo = ScheduledTaskRepository(db_session)
task = await repo.get_by_id(task_id)
@@ -148,11 +148,11 @@ async def cancel_task(
if task.user_id != current_user.id and not current_user.is_admin:
raise HTTPException(status_code=403, detail="Access denied")
success = await scheduler_service.cancel_task(task_id)
success = await scheduler_service.delete_task(task_id)
if not success:
raise HTTPException(status_code=400, detail="Failed to cancel task")
raise HTTPException(status_code=400, detail="Failed to delete task")
return {"message": "Task cancelled successfully"}
return {"message": "Task deleted successfully"}
# Admin-only endpoints

View File

@@ -8,7 +8,6 @@ from sqlmodel.ext.asyncio.session import AsyncSession
import app.models # noqa: F401
from app.core.config import settings
from app.core.logging import get_logger
from app.core.seeds import seed_all_data
engine: AsyncEngine = create_async_engine(
settings.DATABASE_URL,
@@ -40,26 +39,23 @@ def get_session_factory() -> Callable[[], AsyncSession]:
async def init_db() -> None:
"""Initialize the database and create tables if they do not exist."""
"""Initialize the database using Alembic migrations."""
logger = get_logger(__name__)
try:
logger.info("Initializing database tables")
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
logger.info("Database tables created successfully")
logger.info("Running database migrations")
# Run Alembic migrations programmatically
from alembic import command
from alembic.config import Config
# Seed initial data
await seed_initial_data()
# Get the alembic config
alembic_cfg = Config("alembic.ini")
# Run migrations to the latest revision
command.upgrade(alembic_cfg, "head")
logger.info("Database migrations completed successfully")
except Exception:
logger.exception("Failed to initialize database")
raise
async def seed_initial_data() -> None:
"""Seed initial data into the database."""
logger = get_logger(__name__)
logger.info("Starting initial data seeding")
async with AsyncSession(engine) as session:
await seed_all_data(session)

View File

@@ -403,6 +403,44 @@ class CreditService:
finally:
await session.close()
async def recharge_user_credits_auto(
self,
user_id: int,
) -> CreditTransaction | None:
"""Recharge credits for a user automatically based on their plan.
Args:
user_id: The user ID
Returns:
The created credit transaction if credits were added, None if no recharge
needed
Raises:
ValueError: If user not found or has no plan
"""
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id_with_plan(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
if not user.plan:
msg = f"User {user_id} has no plan assigned"
raise ValueError(msg)
# Call the main method with plan details
return await self.recharge_user_credits(
user_id,
user.plan.credits,
user.plan.max_credits,
)
finally:
await session.close()
async def recharge_user_credits(
self,
user_id: int,
@@ -556,7 +594,13 @@ class CreditService:
if transaction:
stats["recharged_users"] += 1
stats["total_credits_added"] += transaction.amount
# Calculate the amount from plan data to avoid session issues
current_credits = user.credits
plan_credits = user.plan.credits
max_credits = user.plan.max_credits
target_credits = min(current_credits + plan_credits, max_credits)
credits_added = target_credits - current_credits
stats["total_credits_added"] += credits_added
else:
stats["skipped_users"] += 1

View File

@@ -144,6 +144,25 @@ class SchedulerService:
logger.info("Cancelled task: %s (%s)", task.name, task_id)
return True
async def delete_task(self, task_id: int) -> bool:
"""Delete a scheduled task completely."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
task = await repo.get_by_id(task_id)
if not task:
return False
# Remove from APScheduler first (job might not exist in scheduler)
with suppress(Exception):
self.scheduler.remove_job(str(task_id))
# Delete from database
await repo.delete(task)
logger.info("Deleted task: %s (%s)", task.name, task_id)
return True
async def get_user_tasks(
self,
user_id: int,

View File

@@ -80,8 +80,11 @@ class TaskHandlerRegistry:
msg = f"Invalid user_id format: {user_id}"
raise TaskExecutionError(msg) from e
stats = await self.credit_service.recharge_user_credits(user_id_int)
logger.info("Recharged credits for user %s: %s", user_id, stats)
transaction = await self.credit_service.recharge_user_credits_auto(user_id_int)
if transaction:
logger.info("Recharged credits for user %s: %s credits added", user_id, transaction.amount)
else:
logger.info("No credits added for user %s (already at maximum)", user_id)
else:
# Recharge all users (system task)
stats = await self.credit_service.recharge_all_users_credits()

84
migrate.py Executable file
View File

@@ -0,0 +1,84 @@
#!/usr/bin/env python3
"""Database migration CLI tool."""
import argparse
import sys
from pathlib import Path
from alembic import command
from alembic.config import Config
def main() -> None:
"""Main CLI function for database migrations."""
parser = argparse.ArgumentParser(description="Database migration tool")
subparsers = parser.add_subparsers(dest="command", help="Migration commands")
# Upgrade command
upgrade_parser = subparsers.add_parser("upgrade", help="Upgrade database to latest revision")
upgrade_parser.add_argument(
"revision",
nargs="?",
default="head",
help="Target revision (default: head)"
)
# Downgrade command
downgrade_parser = subparsers.add_parser("downgrade", help="Downgrade database")
downgrade_parser.add_argument("revision", help="Target revision")
# Current command
subparsers.add_parser("current", help="Show current revision")
# History command
subparsers.add_parser("history", help="Show revision history")
# Generate migration command
revision_parser = subparsers.add_parser("revision", help="Create new migration")
revision_parser.add_argument("-m", "--message", required=True, help="Migration message")
revision_parser.add_argument("--autogenerate", action="store_true", help="Auto-generate migration")
args = parser.parse_args()
if not args.command:
parser.print_help()
sys.exit(1)
# Get the alembic config
config_path = Path("alembic.ini")
if not config_path.exists():
print("Error: alembic.ini not found. Run from the backend directory.")
sys.exit(1)
alembic_cfg = Config(str(config_path))
try:
if args.command == "upgrade":
command.upgrade(alembic_cfg, args.revision)
print(f"Successfully upgraded database to revision: {args.revision}")
elif args.command == "downgrade":
command.downgrade(alembic_cfg, args.revision)
print(f"Successfully downgraded database to revision: {args.revision}")
elif args.command == "current":
command.current(alembic_cfg)
elif args.command == "history":
command.history(alembic_cfg)
elif args.command == "revision":
if args.autogenerate:
command.revision(alembic_cfg, message=args.message, autogenerate=True)
print(f"Created new auto-generated migration: {args.message}")
else:
command.revision(alembic_cfg, message=args.message)
print(f"Created new empty migration: {args.message}")
except Exception as e:
print(f"Error: {e}")
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -6,6 +6,7 @@ readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"aiosqlite==0.21.0",
"alembic==1.16.5",
"apscheduler==3.11.0",
"bcrypt==4.3.0",
"email-validator==2.3.0",

View File

@@ -174,6 +174,39 @@ class TestSchedulerService:
result = await scheduler_service.cancel_task(uuid.uuid4())
assert result is False
async def test_delete_task(
self,
scheduler_service: SchedulerService,
sample_task_data: dict,
):
"""Test deleting a task."""
# Create a task first
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
schema = self._create_task_schema(sample_task_data)
task = await scheduler_service.create_task(task_data=schema)
# Mock the scheduler remove_job method
with patch.object(scheduler_service.scheduler, "remove_job") as mock_remove:
result = await scheduler_service.delete_task(task.id)
assert result is True
mock_remove.assert_called_once_with(str(task.id))
# Check task is completely deleted from database
from app.repositories.scheduled_task import ScheduledTaskRepository
async with scheduler_service.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
deleted_task = await repo.get_by_id(task.id)
assert deleted_task is None
async def test_delete_nonexistent_task(
self,
scheduler_service: SchedulerService,
):
"""Test deleting a non-existent task."""
result = await scheduler_service.delete_task(uuid.uuid4())
assert result is False
async def test_get_user_tasks(
self,
scheduler_service: SchedulerService,

View File

@@ -92,14 +92,14 @@ class TestTaskHandlerRegistry:
parameters={"user_id": str(test_user_id)},
)
mock_credit_service.recharge_user_credits.return_value = {
"user_id": str(test_user_id),
"credits_added": 100,
}
# Mock transaction object
mock_transaction = MagicMock()
mock_transaction.amount = 100
mock_credit_service.recharge_user_credits_auto.return_value = mock_transaction
await task_registry.execute_task(task)
mock_credit_service.recharge_user_credits.assert_called_once_with(test_user_id)
mock_credit_service.recharge_user_credits_auto.assert_called_once_with(test_user_id)
async def test_handle_credit_recharge_uuid_user_id(
self,
@@ -117,7 +117,7 @@ class TestTaskHandlerRegistry:
await task_registry.execute_task(task)
mock_credit_service.recharge_user_credits.assert_called_once_with(test_user_id)
mock_credit_service.recharge_user_credits_auto.assert_called_once_with(test_user_id)
async def test_handle_play_sound_success(
self,

28
uv.lock generated
View File

@@ -14,6 +14,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792 },
]
[[package]]
name = "alembic"
version = "1.16.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mako" },
{ name = "sqlalchemy" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/9a/ca/4dc52902cf3491892d464f5265a81e9dff094692c8a049a3ed6a05fe7ee8/alembic-1.16.5.tar.gz", hash = "sha256:a88bb7f6e513bd4301ecf4c7f2206fe93f9913f9b48dac3b78babde2d6fe765e", size = 1969868 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/39/4a/4c61d4c84cfd9befb6fa08a702535b27b21fff08c946bc2f6139decbf7f7/alembic-1.16.5-py3-none-any.whl", hash = "sha256:e845dfe090c5ffa7b92593ae6687c5cb1a101e91fa53868497dbd79847f9dbe3", size = 247355 },
]
[[package]]
name = "annotated-types"
version = "0.7.0"
@@ -465,6 +479,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899 },
]
[[package]]
name = "mako"
version = "1.3.10"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "markupsafe" },
]
sdist = { url = "https://files.pythonhosted.org/packages/9e/38/bd5b78a920a64d708fe6bc8e0a2c075e1389d53bef8413725c63ba041535/mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28", size = 392474 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509 },
]
[[package]]
name = "markdown-it-py"
version = "3.0.0"
@@ -959,6 +985,7 @@ version = "2.0.0"
source = { virtual = "." }
dependencies = [
{ name = "aiosqlite" },
{ name = "alembic" },
{ name = "apscheduler" },
{ name = "asyncpg" },
{ name = "bcrypt" },
@@ -991,6 +1018,7 @@ dev = [
[package.metadata]
requires-dist = [
{ name = "aiosqlite", specifier = "==0.21.0" },
{ name = "alembic", specifier = "==1.16.5" },
{ name = "apscheduler", specifier = "==3.11.0" },
{ name = "asyncpg", specifier = "==0.30.0" },
{ name = "bcrypt", specifier = "==4.3.0" },