"""Generic async repository base class."""
from typing import Generic, TypeVar, Type, Optional, List, Any
from uuid import UUID

from sqlalchemy import select, func, update, delete
from sqlalchemy.ext.asyncio import AsyncSession

from src.database.base import BaseModel

ModelT = TypeVar("ModelT", bound=BaseModel)


class BaseRepository(Generic[ModelT]):
    """Generic async CRUD repository bound to one ORM model class.

    Callers supply the mapped model class and an ``AsyncSession``. The
    repository executes statements and flushes, but never commits:
    transaction boundaries belong to whoever owns the session.

    The model is expected to expose an ``id`` primary-key column; the
    default ordering of :meth:`list` additionally uses ``created_at``.
    """

    def __init__(self, model: Type[ModelT], db: AsyncSession):
        self.model = model
        self.db = db

    async def get(self, id: UUID) -> Optional[ModelT]:
        """Return the row whose ``id`` equals *id*, or ``None`` if absent."""
        result = await self.db.execute(select(self.model).where(self.model.id == id))
        return result.scalar_one_or_none()

    async def list(
        self,
        *filters: Any,
        offset: int = 0,
        limit: int = 20,
        order_by: Any = None,
    ) -> tuple[List[ModelT], int]:
        """Return one page of rows matching *filters* plus the total match count.

        Args:
            *filters: SQLAlchemy boolean expressions, combined with AND.
            offset: Number of matching rows to skip.
            limit: Maximum number of rows to return.
            order_by: A single ORDER BY clause, or a list/tuple of clauses.
                When ``None``, rows are ordered newest ``created_at`` first,
                with ``id`` as a tiebreaker.

        Returns:
            ``(items, total)`` where *total* counts every matching row,
            independent of *offset* and *limit*.
        """
        base_q = select(self.model).where(*filters)
        # Count before ordering/paging is applied: ORDER BY is irrelevant to a count.
        count_q = select(func.count()).select_from(base_q.subquery())
        total = (await self.db.execute(count_q)).scalar_one()

        if order_by is None:
            # ``created_at`` alone is not unique; without a unique tiebreaker the
            # database may return tied rows in any order, so OFFSET/LIMIT pages
            # could skip or repeat rows between requests.
            base_q = base_q.order_by(
                self.model.created_at.desc(),
                self.model.id.desc(),
            )
        elif isinstance(order_by, (list, tuple)):
            base_q = base_q.order_by(*order_by)
        else:
            base_q = base_q.order_by(order_by)

        result = await self.db.execute(base_q.offset(offset).limit(limit))
        # ``list`` below is the builtin: method bodies do not see class-scope
        # names, so this method's own name does not shadow it here.
        items = list(result.scalars().all())
        return items, total

    async def create(self, **kwargs: Any) -> ModelT:
        """Construct a model instance from *kwargs*, add it, flush, and return it.

        The flush sends the INSERT without committing; the refresh then
        reloads the instance from the database so values generated there
        are present on the returned object.
        """
        instance = self.model(**kwargs)
        self.db.add(instance)
        await self.db.flush()
        await self.db.refresh(instance)
        return instance

    async def update(self, id: UUID, **kwargs: Any) -> Optional[ModelT]:
        """Set the columns named in *kwargs* on the row with *id*.

        Returns the row as re-selected after the update, or ``None`` if no
        row has that ``id``. With no *kwargs* nothing is written and the
        current row (if any) is returned.
        """
        if not kwargs:
            # SQL has no UPDATE with an empty SET list; there is nothing to write.
            return await self.get(id)

        await self.db.execute(
            update(self.model)
            .where(self.model.id == id)
            .values(**kwargs)
            # Keep objects already loaded in this session in sync with the new values.
            .execution_options(synchronize_session="fetch")
        )
        return await self.get(id)

    async def delete(self, id: UUID) -> bool:
        """Delete the row with *id*; return ``True`` if the DELETE matched a row."""
        result = await self.db.execute(
            delete(self.model).where(self.model.id == id)
        )
        return result.rowcount > 0

    async def exists(self, *filters: Any) -> bool:
        """Return ``True`` if at least one row matches *filters* (ANDed together).

        Selects at most one primary key (``LIMIT 1``) rather than counting
        every match, so the database can stop at the first hit.
        """
        q = select(self.model.id).where(*filters).limit(1)
        return (await self.db.execute(q)).first() is not None
