"""SQLAlchemy-powered Session backend.

Usage::

    from agents.extensions.memory import SQLAlchemySession

    # Create from SQLAlchemy URL (uses asyncpg driver under the hood for Postgres)
    session = SQLAlchemySession.from_url(
        session_id="user-123",
        url="postgresql+asyncpg://app:secret@db.example.com/agents",
        create_tables=True, # If you want to auto-create tables, set to True.
    )

    # Or pass an existing AsyncEngine that your application already manages
    session = SQLAlchemySession(
        session_id="user-123",
        engine=my_async_engine,
        create_tables=True, # If you want to auto-create tables, set to True.
    )

    await Runner.run(agent, "Hello", session=session)
"""

from __future__ import annotations

import asyncio
import json
from typing import Any

from sqlalchemy import (
    TIMESTAMP,
    Column,
    ForeignKey,
    Index,
    Integer,
    MetaData,
    String,
    Table,
    Text,
    delete,
    insert,
    select,
    text as sql_text,
    update,
)
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine

from ...items import TResponseInputItem
from ...memory.session import SessionABC


class SQLAlchemySession(SessionABC):
    """SQLAlchemy implementation of :pyclass:`agents.memory.session.Session`."""

    _metadata: MetaData
    _sessions: Table
    _messages: Table

    def __init__(
        self,
        session_id: str,
        *,
        engine: AsyncEngine,
        create_tables: bool = False,
        sessions_table: str = "agent_sessions",
        messages_table: str = "agent_messages",
    ):
        """Initializes a new SQLAlchemySession.

        Args:
            session_id (str): Unique identifier for the conversation.
            engine (AsyncEngine): A pre-configured SQLAlchemy async engine. The engine
                must be created with an async driver (e.g., 'postgresql+asyncpg://',
                'mysql+aiomysql://', or 'sqlite+aiosqlite://').
            create_tables (bool, optional): Whether to automatically create the required
                tables and indexes. Defaults to False for production use. Set to True for
                development and testing when migrations aren't used.
            sessions_table (str, optional): Override the default table name for sessions if needed.
            messages_table (str, optional): Override the default table name for messages if needed.
        """
        self.session_id = session_id
        self._engine = engine
        self._lock = asyncio.Lock()

        self._metadata = MetaData()
        self._sessions = Table(
            sessions_table,
            self._metadata,
            Column("session_id", String, primary_key=True),
            Column(
                "created_at",
                TIMESTAMP(timezone=False),
                server_default=sql_text("CURRENT_TIMESTAMP"),
                nullable=False,
            ),
            Column(
                "updated_at",
                TIMESTAMP(timezone=False),
                server_default=sql_text("CURRENT_TIMESTAMP"),
                onupdate=sql_text("CURRENT_TIMESTAMP"),
                nullable=False,
            ),
        )

        self._messages = Table(
            messages_table,
            self._metadata,
            Column("id", Integer, primary_key=True, autoincrement=True),
            Column(
                "session_id",
                String,
                ForeignKey(f"{sessions_table}.session_id", ondelete="CASCADE"),
                nullable=False,
            ),
            Column("message_data", Text, nullable=False),
            Column(
                "created_at",
                TIMESTAMP(timezone=False),
                server_default=sql_text("CURRENT_TIMESTAMP"),
                nullable=False,
            ),
            Index(
                f"idx_{messages_table}_session_time",
                "session_id",
                "created_at",
            ),
            sqlite_autoincrement=True,
        )

        # Async session factory
        self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False)

        self._create_tables = create_tables

    # ---------------------------------------------------------------------
    # Convenience constructors
    # ---------------------------------------------------------------------
    @classmethod
    def from_url(
        cls,
        session_id: str,
        *,
        url: str,
        engine_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> SQLAlchemySession:
        """Create a session from a database URL string.

        Args:
            session_id (str): Conversation ID.
            url (str): Any SQLAlchemy async URL, e.g. "postgresql+asyncpg://user:pass@host/db".
            engine_kwargs (dict[str, Any] | None): Additional keyword arguments forwarded to
                sqlalchemy.ext.asyncio.create_async_engine.
            **kwargs: Additional keyword arguments forwarded to the main constructor
                (e.g., create_tables, custom table names, etc.).

        Returns:
            SQLAlchemySession: An instance of SQLAlchemySession connected to the specified database.
        """
        engine_kwargs = engine_kwargs or {}
        engine = create_async_engine(url, **engine_kwargs)
        return cls(session_id, engine=engine, **kwargs)

    async def _serialize_item(self, item: TResponseInputItem) -> str:
        """Serialize an item to JSON string. Can be overridden by subclasses."""
        return json.dumps(item, separators=(",", ":"))

    async def _deserialize_item(self, item: str) -> TResponseInputItem:
        """Deserialize a JSON string to an item. Can be overridden by subclasses."""
        return json.loads(item)  # type: ignore[no-any-return]

    # ------------------------------------------------------------------
    # Session protocol implementation
    # ------------------------------------------------------------------
    async def _ensure_tables(self) -> None:
        """Ensure tables are created before any database operations."""
        if self._create_tables:
            async with self._engine.begin() as conn:
                await conn.run_sync(self._metadata.create_all)
            self._create_tables = False  # Only create once

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Retrieve the conversation history for this session.

        Args:
            limit: Maximum number of items to retrieve. If None, retrieves all items.
                   When specified, returns the latest N items in chronological order.

        Returns:
            List of input items representing the conversation history
        """
        await self._ensure_tables()
        async with self._session_factory() as sess:
            if limit is None:
                stmt = (
                    select(self._messages.c.message_data)
                    .where(self._messages.c.session_id == self.session_id)
                    .order_by(self._messages.c.created_at.asc())
                )
            else:
                stmt = (
                    select(self._messages.c.message_data)
                    .where(self._messages.c.session_id == self.session_id)
                    # Use DESC + LIMIT to get the latest N
                    # then reverse later for chronological order.
                    .order_by(self._messages.c.created_at.desc())
                    .limit(limit)
                )

            result = await sess.execute(stmt)
            rows: list[str] = [row[0] for row in result.all()]

            if limit is not None:
                rows.reverse()

            items: list[TResponseInputItem] = []
            for raw in rows:
                try:
                    items.append(await self._deserialize_item(raw))
                except json.JSONDecodeError:
                    # Skip corrupted rows
                    continue
            return items

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Add new items to the conversation history.

        Args:
            items: List of input items to add to the history
        """
        if not items:
            return

        await self._ensure_tables()
        payload = [
            {
                "session_id": self.session_id,
                "message_data": await self._serialize_item(item),
            }
            for item in items
        ]

        async with self._session_factory() as sess:
            async with sess.begin():
                # Ensure the parent session row exists - use merge for cross-DB compatibility
                # Check if session exists
                existing = await sess.execute(
                    select(self._sessions.c.session_id).where(
                        self._sessions.c.session_id == self.session_id
                    )
                )
                if not existing.scalar_one_or_none():
                    # Session doesn't exist, create it
                    await sess.execute(
                        insert(self._sessions).values({"session_id": self.session_id})
                    )

                # Insert messages in bulk
                await sess.execute(insert(self._messages), payload)

                # Touch updated_at column
                await sess.execute(
                    update(self._sessions)
                    .where(self._sessions.c.session_id == self.session_id)
                    .values(updated_at=sql_text("CURRENT_TIMESTAMP"))
                )

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the most recent item from the session.

        Returns:
            The most recent item if it exists, None if the session is empty
        """
        await self._ensure_tables()
        async with self._session_factory() as sess:
            async with sess.begin():
                # Fallback for all dialects - get ID first, then delete
                subq = (
                    select(self._messages.c.id)
                    .where(self._messages.c.session_id == self.session_id)
                    .order_by(self._messages.c.created_at.desc())
                    .limit(1)
                )
                res = await sess.execute(subq)
                row_id = res.scalar_one_or_none()
                if row_id is None:
                    return None
                # Fetch data before deleting
                res_data = await sess.execute(
                    select(self._messages.c.message_data).where(self._messages.c.id == row_id)
                )
                row = res_data.scalar_one_or_none()
                await sess.execute(delete(self._messages).where(self._messages.c.id == row_id))

                if row is None:
                    return None
                try:
                    return await self._deserialize_item(row)
                except json.JSONDecodeError:
                    return None

    async def clear_session(self) -> None:
        """Clear all items for this session."""
        await self._ensure_tables()
        async with self._session_factory() as sess:
            async with sess.begin():
                await sess.execute(
                    delete(self._messages).where(self._messages.c.session_id == self.session_id)
                )
                await sess.execute(
                    delete(self._sessions).where(self._sessions.c.session_id == self.session_id)
                )
