from datetime import datetime, timezone
from typing import Optional, Tuple
from uuid import UUID

from sqlalchemy import and_, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from src.apps.plots.models.plot import Plot
from src.apps.records.models.burial_info import BurialInfo
from src.apps.records.models.family_contact import FamilyContact
from src.apps.records.models.record import Record
from src.apps.sections.models.section import Section
from src.apps.site_admin.services.audit_service import AuditService
from src.core.exceptions import NotFoundError


class RecordService:
    """Tenant-scoped CRUD, search and CSV export for ``Record`` rows.

    Every query is restricted to ``tenant_id`` and, for records, excludes
    soft-deleted rows (``deleted_at`` set). Methods ``flush`` the session but
    this class never calls ``commit``; transaction control stays with the
    caller.
    """

    # Rows fetched per round-trip when exporting a filtered (non-id) selection.
    _EXPORT_BATCH_SIZE = 1000

    def __init__(self, db: AsyncSession):
        self.db = db

    # --- Query helpers ---

    @staticmethod
    def _filter_conditions(
        tenant_id: UUID,
        search: Optional[str],
        record_type: Optional[str],
    ) -> list:
        """Return the WHERE conditions shared by listing and filtered export."""
        conditions = [
            Record.tenant_id == tenant_id,
            Record.deleted_at.is_(None),
        ]
        # Whitespace-only input would produce an empty tsquery; treat it as
        # "no search" rather than filtering on it.
        if search and search.strip():
            document = (
                func.coalesce(Record.first_name, '')
                + ' '
                + func.coalesce(Record.last_name, '')
                + ' '
                + func.coalesce(Record.maiden_name, '')
            )
            conditions.append(
                func.to_tsvector('english', document).op('@@')(
                    func.plainto_tsquery('english', search)
                )
            )
        if record_type:
            # NOTE(review): the "record_type" filter is matched against
            # Record.status; kept as-is for API compatibility.
            conditions.append(Record.status == record_type)
        return conditions

    @classmethod
    def _apply_filters(
        cls,
        query,
        tenant_id: UUID,
        search: Optional[str],
        section: Optional[str],
        record_type: Optional[str],
    ):
        """Restrict a ``select`` over ``Record`` to the listing filters.

        The plot/section (inner) join is added only when filtering by section
        code, so records without a plot still appear in unfiltered results.
        """
        query = query.where(
            and_(*cls._filter_conditions(tenant_id, search, record_type))
        )
        if section:
            query = (
                query.join(Record.plot)
                .join(Plot.section)
                .where(Section.code == section)
            )
        return query

    @staticmethod
    def _with_list_relations(query):
        """Eager-load the relations read by the list view and CSV export."""
        return query.options(
            selectinload(Record.plot).selectinload(Plot.section),
            selectinload(Record.burial_info),
        )

    @staticmethod
    def _ordered(query):
        """Order by last name, then first name, with ``id`` as a tie-breaker.

        Without the tie-breaker, rows sharing a name have no defined order and
        OFFSET/LIMIT pages can repeat or skip them.
        """
        return query.order_by(
            Record.last_name.asc(),
            Record.first_name.asc(),
            Record.id.asc(),
        )

    # --- Records ---

    async def list_records(
        self,
        tenant_id: UUID,
        search: Optional[str] = None,
        section: Optional[str] = None,
        record_type: Optional[str] = None,
        page: int = 1,
        page_size: int = 20,
    ) -> Tuple[list, int]:
        """Return one page of matching records plus the total match count.

        Args:
            tenant_id: Tenant whose records are listed.
            search: Optional full-text query over first, last and maiden names.
            section: Optional ``Section.code`` the record's plot must belong to.
            record_type: Optional value matched against ``Record.status``.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; negative values are treated as 0.

        Returns:
            ``(records, total)``, where ``total`` counts all matches regardless
            of pagination.
        """
        count_query = self._apply_filters(
            select(func.count(Record.id)).select_from(Record),
            tenant_id, search, section, record_type,
        )
        total = (await self.db.execute(count_query)).scalar_one()

        # Clamp so bad paging input cannot yield a negative OFFSET/LIMIT,
        # which PostgreSQL rejects.
        page = max(page, 1)
        page_size = max(page_size, 0)

        data_query = self._ordered(
            self._with_list_relations(
                self._apply_filters(
                    select(Record), tenant_id, search, section, record_type
                )
            )
        )
        result = await self.db.execute(
            data_query.offset((page - 1) * page_size).limit(page_size)
        )
        return result.scalars().all(), total

    async def get_by_id(self, record_id: UUID, tenant_id: UUID) -> Record:
        """Return a non-deleted record with burial info, contacts and plot loaded.

        Raises:
            NotFoundError: if no record with this id exists for the tenant, or
                it has been soft-deleted.
        """
        result = await self.db.execute(
            select(Record)
            .options(
                selectinload(Record.burial_info),
                selectinload(Record.family_contacts),
                selectinload(Record.plot).selectinload(Plot.section),
            )
            .where(
                and_(
                    Record.id == record_id,
                    Record.tenant_id == tenant_id,
                    Record.deleted_at.is_(None),
                )
            )
        )
        record = result.scalar_one_or_none()
        if not record:
            raise NotFoundError("Record not found")
        return record

    async def create(
        self,
        tenant_id: UUID,
        data: dict,
        current_user=None,
        request=None,
    ) -> Record:
        """Create a record from ``data`` and return it with relations loaded.

        An audit entry is written only when both ``current_user`` and
        ``request`` are supplied.
        """
        record = Record(tenant_id=tenant_id, **data)
        self.db.add(record)
        await self.db.flush()
        if current_user and request:
            await AuditService.log(self.db, 'Record', record.id, 'create', current_user, request)
        return await self.get_by_id(record.id, tenant_id)

    async def update(
        self,
        record_id: UUID,
        tenant_id: UUID,
        data: dict,
        current_user=None,
        request=None,
    ) -> Record:
        """Apply ``data`` to an existing record and return the reloaded record.

        Keys in ``data`` that are not attributes of ``Record`` are ignored and
        are not included in the audit entry.

        Raises:
            NotFoundError: if the record does not exist for the tenant.
        """
        record = await self.get_by_id(record_id, tenant_id)
        applied = {
            field: value for field, value in data.items() if hasattr(record, field)
        }
        old_snapshot = {k: str(getattr(record, k, None)) for k in applied}
        for field, value in applied.items():
            setattr(record, field, value)
        await self.db.flush()
        if current_user and request:
            await AuditService.log(
                self.db, 'Record', record.id, 'update',
                current_user, request,
                old_value=old_snapshot,
                new_value={k: str(v) for k, v in applied.items()},
            )
        # Changing a foreign key (e.g. plot_id) does not refresh the already
        # loaded relationship, and a plain re-select does not overwrite loaded
        # attributes of an identity-mapped object. Expire so the reload below
        # fetches current relations.
        self.db.expire(record)
        return await self.get_by_id(record_id, tenant_id)

    async def soft_delete(
        self,
        record_id: UUID,
        tenant_id: UUID,
        current_user=None,
        request=None,
    ) -> None:
        """Mark a record deleted by stamping ``deleted_at`` (UTC); the row is kept.

        Raises:
            NotFoundError: if the record does not exist or is already deleted.
        """
        record = await self.get_by_id(record_id, tenant_id)
        record.deleted_at = datetime.now(timezone.utc)
        await self.db.flush()
        if current_user and request:
            await AuditService.log(self.db, 'Record', record.id, 'delete', current_user, request)

    async def _iter_export_records(
        self,
        tenant_id: UUID,
        record_ids: Optional[list],
        search: Optional[str],
        section: Optional[str],
        record_type: Optional[str],
    ):
        """Yield the records to export; filtered selections are fetched in batches."""
        if record_ids:
            query = self._ordered(
                self._with_list_relations(
                    select(Record).where(
                        and_(
                            Record.tenant_id == tenant_id,
                            Record.deleted_at.is_(None),
                            Record.id.in_(record_ids),
                        )
                    )
                )
            )
            result = await self.db.execute(query)
            for record in result.scalars().all():
                yield record
            return

        query = self._ordered(
            self._with_list_relations(
                self._apply_filters(
                    select(Record), tenant_id, search, section, record_type
                )
            )
        )
        batch_size = self._EXPORT_BATCH_SIZE
        offset = 0
        while True:
            result = await self.db.execute(query.offset(offset).limit(batch_size))
            batch = result.scalars().all()
            for record in batch:
                yield record
            if len(batch) < batch_size:
                return
            offset += batch_size

    async def export_csv(
        self,
        tenant_id: UUID,
        record_ids: Optional[list] = None,
        search: Optional[str] = None,
        section: Optional[str] = None,
        record_type: Optional[str] = None,
    ):
        """Asynchronously yield CSV text: the header line, then one line per record.

        When ``record_ids`` is given, exactly those (non-deleted, same-tenant)
        records are exported. Otherwise every record matching the
        ``list_records`` filters is exported, without a row cap.
        """
        import csv
        import io

        buffer = io.StringIO()
        writer = csv.writer(buffer)

        def render(row: list) -> str:
            # Reuse one buffer: write the row, take its text, then reset.
            writer.writerow(row)
            text = buffer.getvalue()
            buffer.seek(0)
            buffer.truncate(0)
            return text

        yield render([
            "id", "full_name", "maiden_name", "dob", "dod", "plot_ref",
            "section", "date_interred", "status", "is_veteran",
        ])
        async for r in self._iter_export_records(
            tenant_id, record_ids, search, section, record_type
        ):
            # Skip missing name parts so a NULL never renders as "None".
            full_name = " ".join(filter(None, (r.first_name, r.last_name))).strip()
            yield render([
                str(r.id),
                full_name,
                r.maiden_name or '',
                r.date_of_birth.isoformat() if r.date_of_birth else '',
                r.date_of_death.isoformat() if r.date_of_death else '',
                r.plot.plot_ref if r.plot else '',
                r.plot.section.code if r.plot and r.plot.section else '',
                r.burial_info.interment_date.isoformat() if r.burial_info and r.burial_info.interment_date else '',
                r.status,
                'true' if r.is_veteran else 'false',
            ])

    # --- Burial Info ---

    async def upsert_burial_info(
        self, record_id: UUID, tenant_id: UUID, data: dict
    ) -> BurialInfo:
        """Create or update the burial info row for a record and return it.

        On update, keys in ``data`` that are not ``BurialInfo`` attributes are
        ignored; on create, ``data`` is passed straight to the constructor.

        Raises:
            NotFoundError: if the record does not exist for the tenant.
        """
        # Ensure record exists and belongs to tenant
        await self.get_by_id(record_id, tenant_id)

        result = await self.db.execute(
            select(BurialInfo).where(BurialInfo.record_id == record_id)
        )
        burial = result.scalar_one_or_none()
        if burial:
            for field, value in data.items():
                if hasattr(burial, field):
                    setattr(burial, field, value)
        else:
            burial = BurialInfo(record_id=record_id, tenant_id=tenant_id, **data)
            self.db.add(burial)

        await self.db.flush()
        await self.db.refresh(burial)
        return burial

    # --- Family Contacts ---

    @staticmethod
    def _normalize_contact_data(data: dict) -> dict:
        """Return a copy of ``data`` with key ``relationship`` renamed to ``relationship_type``."""
        normalized = dict(data)
        if "relationship" in normalized:
            normalized["relationship_type"] = normalized.pop("relationship")
        return normalized

    async def add_family_contact(
        self, record_id: UUID, tenant_id: UUID, data: dict
    ) -> FamilyContact:
        """Attach a new family contact to a record and return it.

        Raises:
            NotFoundError: if the record does not exist for the tenant.
        """
        await self.get_by_id(record_id, tenant_id)
        contact = FamilyContact(
            record_id=record_id,
            tenant_id=tenant_id,
            **self._normalize_contact_data(data),
        )
        self.db.add(contact)
        await self.db.flush()
        await self.db.refresh(contact)
        return contact

    async def delete_family_contact(
        self, contact_id: UUID, tenant_id: UUID, record_id: Optional[UUID] = None
    ) -> None:
        """Hard-delete a family contact belonging to the tenant.

        Args:
            record_id: When given, the contact must also belong to this record,
                mirroring ``update_family_contact``.

        Raises:
            NotFoundError: if no matching contact exists.
        """
        conditions = [
            FamilyContact.id == contact_id,
            FamilyContact.tenant_id == tenant_id,
        ]
        if record_id is not None:
            conditions.append(FamilyContact.record_id == record_id)
        result = await self.db.execute(
            select(FamilyContact).where(and_(*conditions))
        )
        contact = result.scalar_one_or_none()
        if not contact:
            raise NotFoundError("Family contact not found")
        await self.db.delete(contact)
        await self.db.flush()

    async def update_family_contact(
        self, contact_id: UUID, tenant_id: UUID, record_id: UUID, data: dict
    ) -> FamilyContact:
        """Update a record's family contact and return it.

        Keys in ``data`` that are not ``FamilyContact`` attributes are ignored.

        Raises:
            NotFoundError: if no matching contact exists for the tenant/record.
        """
        result = await self.db.execute(
            select(FamilyContact).where(
                and_(
                    FamilyContact.id == contact_id,
                    FamilyContact.tenant_id == tenant_id,
                    FamilyContact.record_id == record_id,
                )
            )
        )
        contact = result.scalar_one_or_none()
        if not contact:
            raise NotFoundError("Family contact not found")
        for field, value in self._normalize_contact_data(data).items():
            if hasattr(contact, field):
                setattr(contact, field, value)
        await self.db.flush()
        await self.db.refresh(contact)
        return contact
