database.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
  2. from sqlalchemy.orm import sessionmaker
  3. from sqlalchemy import text, select
  4. from app.models import Base, AppSettings
  5. from app.config import get_settings
  6. settings = get_settings()
  7. # Convert sqlite:/// to sqlite+aiosqlite:/// for async support
  8. db_url = settings.database_url.replace("sqlite://", "sqlite+aiosqlite://")
  9. # Create async engine
  10. engine = create_async_engine(db_url, echo=True)
  11. # Create async session factory
  12. async_session = async_sessionmaker(
  13. engine, class_=AsyncSession, expire_on_commit=False
  14. )
  15. async def init_db():
  16. """Initialize database tables."""
  17. async with engine.begin() as conn:
  18. # Check if migration is needed (users table doesn't exist)
  19. result = await conn.execute(text(
  20. "SELECT name FROM sqlite_master WHERE type='table' AND name='users'"
  21. ))
  22. needs_migration = result.fetchone() is None
  23. # Check if we have old schema (listening_sessions exists but no users table)
  24. result = await conn.execute(text(
  25. "SELECT name FROM sqlite_master WHERE type='table' AND name='listening_sessions'"
  26. ))
  27. has_old_schema = result.fetchone() is not None
  28. if needs_migration and has_old_schema:
  29. # Need to run migration
  30. print("Existing database detected without users table - migration required")
  31. print("Please run: python -m app.migrations.add_multi_user")
  32. print("Or delete absrecommend.db to start fresh")
  33. raise RuntimeError("Database migration required")
  34. # Create all tables (will skip existing ones)
  35. await conn.run_sync(Base.metadata.create_all)
  36. # Initialize default settings
  37. async with async_session() as session:
  38. # Check if settings already exist
  39. result = await session.execute(
  40. select(AppSettings).where(AppSettings.key == "allow_registration")
  41. )
  42. if not result.scalar_one_or_none():
  43. # Create default settings
  44. default_settings = [
  45. AppSettings(key="allow_registration", value="true"),
  46. ]
  47. session.add_all(default_settings)
  48. await session.commit()
  49. async def get_db():
  50. """Dependency for getting database session."""
  51. async with async_session() as session:
  52. try:
  53. yield session
  54. finally:
  55. await session.close()