"""SQLAlchemy engine setup, migration runner, and session helpers."""

from __future__ import annotations

import logging
import sqlite3
import threading
from collections.abc import Iterator
from pathlib import Path

from sqlalchemy import create_engine, event, text
from sqlalchemy.engine import Connection, Engine

logger = logging.getLogger(__name__)

_engine: Engine | None = None
# Guards lazy creation of ``_engine`` so concurrent first callers (request
# handlers, background threads) cannot each build a separate engine/pool.
_engine_lock = threading.Lock()
_thread_local = threading.local()


def get_engine() -> Engine:
    """Return the process-wide SQLite engine, creating it on first use.

    The database file path is read from ``config.settings.db_path`` and its
    parent directory is created if missing. Every new DBAPI connection gets
    ``journal_mode=WAL``, ``foreign_keys=ON`` and ``busy_timeout=5000``.
    """
    global _engine
    # Fast path: no locking once the engine exists.
    if _engine is not None:
        return _engine

    with _engine_lock:
        # Re-check: another thread may have created it while we waited.
        if _engine is not None:
            return _engine

        # Imported lazily (presumably to avoid an import cycle with config).
        from config import settings

        db_path = Path(settings.db_path).resolve()
        db_path.parent.mkdir(parents=True, exist_ok=True)

        engine = create_engine(
            f"sqlite:///{db_path}",
            # Connections may be used from threads other than the one that
            # opened them, so disable sqlite3's same-thread check.
            connect_args={"check_same_thread": False},
            echo=False,
        )

        # Apply pragmas on every new DBAPI connection.
        @event.listens_for(engine, "connect")
        def set_sqlite_pragma(dbapi_conn, connection_record):
            cursor = dbapi_conn.cursor()
            try:
                cursor.execute("PRAGMA journal_mode=WAL")
                cursor.execute("PRAGMA foreign_keys=ON")
                cursor.execute("PRAGMA busy_timeout=5000")
            finally:
                cursor.close()

        # Publish only after the listener is attached, so no caller can open a
        # connection on an engine that would skip the pragmas.
        _engine = engine
        return _engine


def get_db() -> Iterator[Connection]:
    """FastAPI dependency. Yields a SQLAlchemy Connection, closes after request.

    The transaction is committed if the request handler finishes without
    raising; otherwise it is rolled back and the exception is re-raised.
    """
    engine = get_engine()
    with engine.connect() as conn:
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise


def get_thread_db() -> Connection:
    """
    Return a thread-local DB connection for background threads.

    Each thread gets its own connection (SQLite requires this). Call
    ``close_thread_db()`` (or ``conn.close()``) in thread teardown. If the
    thread's connection has been closed, a fresh one is opened on the next
    call instead of returning the closed connection.
    """
    conn = getattr(_thread_local, "conn", None)
    if conn is None or conn.closed:
        conn = get_engine().connect()
        _thread_local.conn = conn
    return conn


def close_thread_db() -> None:
    """Close and forget the calling thread's connection, if it has one."""
    conn = getattr(_thread_local, "conn", None)
    _thread_local.conn = None
    if conn is not None:
        conn.close()


def _migration_version(path: Path) -> int | None:
    """Return the integer prefix of a migration filename, or None.

    ``001_initial.sql`` -> 1; a name whose text before the first ``_`` is not
    an integer yields None.
    """
    try:
        return int(path.name.split("_")[0])
    except ValueError:
        return None


def _split_sql_statements(sql: str) -> list[str]:
    """Split a SQL script into individual statements.

    Candidate fragments are accumulated until ``sqlite3.complete_statement``
    reports a complete statement. Semicolons inside string literals, quoted
    identifiers, comments and ``CREATE TRIGGER ... END;`` bodies therefore do
    not end a statement. A trailing fragment without a final semicolon is
    returned as-is. Empty statements (``;;``) are dropped.
    """
    statements: list[str] = []
    pending = ""
    *terminated, tail = sql.split(";")
    for piece in terminated:
        pending += piece + ";"
        if sqlite3.complete_statement(pending):
            statement = pending.strip()
            if statement.rstrip(";").strip():
                statements.append(statement)
            pending = ""
    remainder = (pending + tail).strip()
    if remainder:
        statements.append(remainder)
    return statements


def run_migrations(engine: Engine) -> None:
    """Apply all pending SQL migration files in version order.

    Migration files are read from ``core/migrations`` next to this module and
    are expected to be named ``<number>_<description>.sql``. Files whose
    prefix is not an integer are skipped with a warning. Applied versions are
    recorded in the ``schema_migrations`` table. Each migration's statements
    and its version row are committed together.

    NOTE(review): with the pysqlite driver's default (legacy) transaction
    handling, DDL statements are not implicitly wrapped in a transaction. A
    migration that fails part-way may therefore leave earlier statements
    applied without its version being recorded. Verify this against the
    engine's isolation settings before relying on atomic migrations.
    """
    migrations_dir = Path(__file__).parent / "core" / "migrations"

    # Order by numeric version rather than filename, so "10_x.sql" runs after
    # "9_x.sql" even without zero-padding. The path breaks ties.
    versioned: list[tuple[int, Path]] = []
    for mfile in sorted(migrations_dir.glob("*.sql")):
        version = _migration_version(mfile)
        if version is None:
            logger.warning("Skipping migration with non-numeric prefix: %s", mfile.name)
            continue
        versioned.append((version, mfile))
    versioned.sort()

    with engine.connect() as conn:
        # Ensure tracking table exists
        conn.execute(text("""
            CREATE TABLE IF NOT EXISTS schema_migrations (
                version INTEGER PRIMARY KEY,
                applied_at TEXT NOT NULL DEFAULT (datetime('now'))
            )
        """))
        conn.commit()

        applied = {
            row[0]
            for row in conn.execute(text("SELECT version FROM schema_migrations"))
        }

        for version, mfile in versioned:
            if version in applied:
                continue

            logger.info("Applying migration: %s", mfile.name)
            sql = mfile.read_text(encoding="utf-8")

            # The sqlite3 driver executes one statement per call, so the script
            # is split first. exec_driver_sql passes each statement verbatim;
            # text() would misread ":word" sequences inside SQL literals
            # (e.g. '12:30') as bind parameters.
            for statement in _split_sql_statements(sql):
                conn.exec_driver_sql(statement)

            conn.execute(
                text("INSERT INTO schema_migrations (version) VALUES (:v)"),
                {"v": version},
            )
            conn.commit()
            logger.info("Migration %d applied.", version)