"""Billing router — invoices and payments."""
from uuid import UUID
from typing import Optional

from fastapi import APIRouter, Depends, Query
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from src.core.dependencies import get_current_user, require_min_role, require_tenant
from src.core.constants import UserRole
from src.core.schemas.response import success, paginated
from src.core.exceptions import NotFoundError, ForbiddenError
from src.database.session import get_db
from src.apps.auth.models.user import User
from src.apps.billing.models.invoice import Invoice
from src.apps.billing.models.invoice_payment import InvoicePayment
from src.apps.billing.schemas.requests import (
    InvoiceCreateRequest,
    InvoiceUpdateRequest,
    RecordPaymentRequest,
)
from src.apps.billing.schemas.responses import InvoiceResponse, InvoicePaymentResponse

# Router for the billing endpoints: every route registered on it is served
# under the "/billing" prefix and carries the "billing" tag.
router = APIRouter(prefix="/billing", tags=["billing"])


async def _get_invoice(invoice_id: UUID, tenant_id: str, db: AsyncSession) -> Invoice:
    """Fetch one invoice by id, restricted to the given tenant.

    Args:
        invoice_id: Id of the invoice to look up.
        tenant_id: Tenant the invoice must belong to; the query filters on
            it, so an invoice of another tenant is treated as missing.
        db: Async session the query is executed on.

    Returns:
        The matching ``Invoice`` row.

    Raises:
        NotFoundError: If no invoice matches both the id and the tenant.
    """
    lookup = select(Invoice).where(
        Invoice.id == invoice_id,
        Invoice.tenant_id == tenant_id,
    )
    found = (await db.execute(lookup)).scalar_one_or_none()
    if found:
        return found
    raise NotFoundError("Invoice not found")


@router.get("/invoices")
async def list_invoices(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    status: Optional[str] = Query(None),
    current_user: User = Depends(require_min_role(UserRole.STAFF)),
    tenant_id: str = Depends(require_tenant),
    db: AsyncSession = Depends(get_db),
):
    """List the current tenant's invoices, newest first, one page at a time.

    Args:
        page: 1-based page number (validated ``>= 1``).
        page_size: Rows per page (validated between 1 and 100).
        status: Optional exact-match filter on ``Invoice.status``; an empty
            or missing value applies no status filter.
        current_user: Resolved by ``require_min_role(UserRole.STAFF)``; not
            used in the body (presumably the dependency enforces the role).
        tenant_id: Tenant resolved by ``require_tenant``; every query is
            filtered on it.
        db: Async session the queries are executed on.

    Returns:
        The result of ``paginated(items, total, page, page_size)`` where
        ``items`` are ``InvoiceResponse`` objects for the requested page and
        ``total`` is the number of rows matching the filters.
    """
    filters = [Invoice.tenant_id == tenant_id]
    if status:
        filters.append(Invoice.status == status)

    # Total is computed with the same filters as the page so both agree.
    total_q = select(func.count()).select_from(Invoice).where(*filters)
    total = (await db.execute(total_q)).scalar_one()

    page_q = (
        select(Invoice)
        .where(*filters)
        # created_at values may tie; adding the id as a tie-breaker makes the
        # ordering total, so OFFSET/LIMIT pages cannot repeat or skip rows.
        .order_by(Invoice.created_at.desc(), Invoice.id.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    rows = (await db.execute(page_q)).scalars().all()
    items = [InvoiceResponse.model_validate(row) for row in rows]
    return paginated(items, total, page, page_size)


@router.get("/invoices/{invoice_id}")
async def get_invoice(
    invoice_id: UUID,
    current_user: User = Depends(require_min_role(UserRole.STAFF)),
    tenant_id: str = Depends(require_tenant),
    db: AsyncSession = Depends(get_db),
):
    """Return a single invoice belonging to the current tenant.

    Args:
        invoice_id: Id of the invoice, taken from the URL path.
        current_user: Resolved by ``require_min_role(UserRole.STAFF)``; not
            used in the body (presumably the dependency enforces the role).
        tenant_id: Tenant resolved by ``require_tenant``.
        db: Async session the lookup is executed on.

    Returns:
        ``success(...)`` wrapping the invoice as an ``InvoiceResponse``.

    Raises:
        NotFoundError: Propagated from ``_get_invoice`` when the invoice does
            not exist for this tenant.
    """
    payload = InvoiceResponse.model_validate(
        await _get_invoice(invoice_id, tenant_id, db)
    )
    return success(payload)


@router.post("/invoices", status_code=201)
async def create_invoice(
    body: InvoiceCreateRequest,
    current_user: User = Depends(require_min_role(UserRole.MANAGER)),
    tenant_id: str = Depends(require_tenant),
    db: AsyncSession = Depends(get_db),
):
    """Create a new draft invoice for the current tenant.

    The invoice starts with status ``"draft"`` and all monetary fields set
    to zero; its number is ``INV-`` followed by the tenant's invoice count
    plus one, zero-padded to five digits.

    Args:
        body: Request payload; ``record_id``, ``contract_id``, ``due_date``
            and ``notes`` are copied onto the invoice.
        current_user: Resolved by ``require_min_role(UserRole.MANAGER)``; not
            used in the body (presumably the dependency enforces the role).
        tenant_id: Tenant resolved by ``require_tenant``.
        db: Async session used for the count query and the insert.

    Returns:
        ``success(...)`` wrapping the new invoice as an ``InvoiceResponse``,
        with the message ``"Invoice created"``.
    """
    existing = (
        await db.execute(
            select(func.count())
            .select_from(Invoice)
            .where(Invoice.tenant_id == tenant_id)
        )
    ).scalar_one()
    # NOTE(review): numbering from a row count can produce a duplicate number
    # if two creates run concurrently or if invoices are ever removed —
    # confirm whether a uniqueness constraint or a sequence guards this.
    next_seq = existing + 1
    number = f"INV-{next_seq:05d}"

    # All monetary fields start at zero for a fresh draft.
    zero_amounts = dict.fromkeys(
        ("subtotal", "tax", "total", "amount_paid", "balance"), 0
    )
    invoice = Invoice(
        tenant_id=tenant_id,
        invoice_number=number,
        record_id=body.record_id,
        contract_id=body.contract_id,
        due_date=body.due_date,
        notes=body.notes,
        status="draft",
        **zero_amounts,
    )
    db.add(invoice)
    # Flush sends the INSERT; refresh reloads the row's attributes from the
    # database before serialising. No commit is issued in this function.
    await db.flush()
    await db.refresh(invoice)
    return success(InvoiceResponse.model_validate(invoice), "Invoice created")


@router.patch("/invoices/{invoice_id}")
async def update_invoice(
    invoice_id: UUID,
    body: InvoiceUpdateRequest,
    current_user: User = Depends(require_min_role(UserRole.MANAGER)),
    tenant_id: str = Depends(require_tenant),
    db: AsyncSession = Depends(get_db),
):
    """Partially update an invoice's due date and/or notes.

    Only fields whose value in ``body`` is not ``None`` are written, so a
    ``None`` value leaves the stored field untouched (it cannot be used to
    clear a field).

    Args:
        invoice_id: Id of the invoice, taken from the URL path.
        body: Request payload carrying optional ``due_date`` and ``notes``.
        current_user: Resolved by ``require_min_role(UserRole.MANAGER)``; not
            used in the body (presumably the dependency enforces the role).
        tenant_id: Tenant resolved by ``require_tenant``.
        db: Async session used for the lookup and the update.

    Returns:
        ``success(...)`` wrapping the updated invoice as an
        ``InvoiceResponse``.

    Raises:
        NotFoundError: Propagated from ``_get_invoice`` when the invoice does
            not exist for this tenant.
    """
    invoice = await _get_invoice(invoice_id, tenant_id, db)

    for field in ("due_date", "notes"):
        new_value = getattr(body, field)
        if new_value is not None:
            setattr(invoice, field, new_value)

    await db.flush()
    await db.refresh(invoice)
    return success(InvoiceResponse.model_validate(invoice))


@router.post("/invoices/{invoice_id}/payments", status_code=201)
async def record_payment(
    invoice_id: UUID,
    body: RecordPaymentRequest,
    current_user: User = Depends(require_min_role(UserRole.MANAGER)),
    tenant_id: str = Depends(require_tenant),
    db: AsyncSession = Depends(get_db),
):
    """Record a payment against an invoice and update the invoice totals.

    Adds an ``InvoicePayment`` row, increases the invoice's ``amount_paid``
    by ``body.amount``, recomputes ``balance`` as ``total - amount_paid`` and
    sets the status to ``"paid"`` (balance at or below zero) or ``"partial"``
    (some amount paid but a balance remains).

    Args:
        invoice_id: Id of the invoice, taken from the URL path.
        body: Request payload carrying ``amount``, ``payment_method``,
            ``reference`` and ``notes``.
        current_user: Resolved by ``require_min_role(UserRole.MANAGER)``; not
            used in the body (presumably the dependency enforces the role).
        tenant_id: Tenant resolved by ``require_tenant``.
        db: Async session used for the lookup and the writes.

    Returns:
        ``success(...)`` wrapping the new payment as an
        ``InvoicePaymentResponse``, with the message ``"Payment recorded"``.

    Raises:
        NotFoundError: Propagated from ``_get_invoice`` when the invoice does
            not exist for this tenant.
        ForbiddenError: If the invoice's balance is already zero or negative.
    """
    invoice = await _get_invoice(invoice_id, tenant_id, db)

    # Nothing left to pay: refuse before creating any payment row.
    if invoice.balance <= 0:
        raise ForbiddenError("Invoice is already fully paid")

    entry = InvoicePayment(
        invoice_id=invoice.id,
        amount=body.amount,
        payment_method=body.payment_method,
        reference=body.reference,
        notes=body.notes,
    )
    db.add(entry)

    # NOTE(review): body.amount is not capped at the outstanding balance in
    # this function, so an overpayment yields a negative balance — confirm
    # whether the request schema or business rules rule that out.
    # NOTE(review): this is a read-modify-write on the invoice row without a
    # lock; confirm how concurrent payments on one invoice are serialised.
    paid_before = invoice.amount_paid or 0  # treat an unset amount as zero
    invoice.amount_paid = paid_before + body.amount
    invoice.balance = invoice.total - invoice.amount_paid

    if invoice.balance <= 0:
        invoice.status = "paid"
    elif invoice.amount_paid > 0:
        invoice.status = "partial"

    await db.flush()
    await db.refresh(entry)
    return success(InvoicePaymentResponse.model_validate(entry), "Payment recorded")


@router.get("/invoices/{invoice_id}/payments")
async def list_payments(
    invoice_id: UUID,
    current_user: User = Depends(require_min_role(UserRole.STAFF)),
    tenant_id: str = Depends(require_tenant),
    db: AsyncSession = Depends(get_db),
):
    """List all payments recorded against an invoice, newest first.

    Args:
        invoice_id: Id of the invoice, taken from the URL path.
        current_user: Resolved by ``require_min_role(UserRole.STAFF)``; not
            used in the body (presumably the dependency enforces the role).
        tenant_id: Tenant resolved by ``require_tenant``.
        db: Async session the queries are executed on.

    Returns:
        ``success(...)`` wrapping a list of ``InvoicePaymentResponse``
        objects ordered by ``created_at`` descending.

    Raises:
        NotFoundError: Propagated from ``_get_invoice`` when the invoice does
            not exist for this tenant.
    """
    # Called only for its check: the payment query below filters on the
    # invoice id alone, so tenant ownership is verified here first.
    await _get_invoice(invoice_id, tenant_id, db)

    stmt = (
        select(InvoicePayment)
        .where(InvoicePayment.invoice_id == invoice_id)
        .order_by(InvoicePayment.created_at.desc())
    )
    rows = (await db.execute(stmt)).scalars().all()
    return success([InvoicePaymentResponse.model_validate(row) for row in rows])
