import logging
from typing import Optional

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from src.core.config import settings

logger = logging.getLogger(__name__)

# Request paths for which TenantMiddleware skips tenant resolution.
TENANT_EXEMPT_PATHS = {
    # Liveness / readiness probes
    "/health",
    "/health/live",
    "/health/ready",
    # API documentation UIs and the OpenAPI schema
    "/openapi.json",
    "/api/docs",
    "/api/redoc",
    # Login and public signup endpoints
    "/api/public/signup",
    "/api/v1/auth/login",
    "/api/site-admin/auth/login",
}


def extract_subdomain(host: str) -> Optional[str]:
    """Extract the tenant subdomain from a Host header value.

    Only hosts strictly below ``settings.APP_DOMAIN`` yield a subdomain. The
    leftmost label is returned, lower-cased, unless it is empty or listed in
    ``settings.SYSTEM_SUBDOMAINS``.

    With APP_DOMAIN = "indelis.com":

    riverside.indelis.com       → "riverside"
    riverside.indelis.com:8000  → "riverside"
    indelis.com                 → None
    admin.indelis.com           → None (system subdomain)
    localhost, 10.0.0.5         → None (dev / direct IP, use header fallback)
    riverside.example.org       → None (not under APP_DOMAIN)

    Args:
        host: Raw Host header value, optionally carrying a ``:port`` suffix.

    Returns:
        The tenant subdomain, or ``None`` if the host does not identify one.
    """
    # Strip the port and any trailing root dot: "riverside.indelis.com." names
    # the same host as "riverside.indelis.com".
    hostname = host.split(":")[0].lower().rstrip(".")
    domain = settings.APP_DOMAIN.lower().strip(".")

    # Require the host to sit under APP_DOMAIN. Merely counting labels would
    # turn an IP address ("10.0.0.5" -> "10") or an unrelated host into a
    # bogus tenant slug that then costs a database lookup per request.
    suffix = "." + domain
    if not hostname.endswith(suffix):
        return None

    # Leftmost label of whatever precedes the domain, e.g. "riverside".
    subdomain = hostname[: -len(suffix)].split(".")[0]
    if not subdomain or subdomain in settings.SYSTEM_SUBDOMAINS:
        return None
    return subdomain


class TenantMiddleware(BaseHTTPMiddleware):
    """Resolve the tenant for each request and expose it on ``request.state``.

    After this middleware has run, ``request.state.tenant_id`` and
    ``request.state.tenant_slug`` are always present (possibly ``None``).

    The tenant is identified, in priority order, from:

    1. the ``X-Tenant-ID`` header (used as the tenant id directly),
    2. the ``X-Tenant-Slug`` header,
    3. the subdomain of the ``Host`` header (see ``extract_subdomain``).

    Slugs are resolved to ids through a Redis cache, falling back to the
    database. Resolution is best-effort: failures are logged and the request
    proceeds with ``tenant_id = None``.
    """

    # Path prefixes that skip tenant resolution in addition to
    # TENANT_EXEMPT_PATHS. Site-admin routes are not tenant-scoped.
    _UNSCOPED_PREFIXES = ("/api/site-admin",)

    async def dispatch(self, request: Request, call_next):
        path = request.url.path

        if self._is_exempt(path):
            # Keep the state contract on exempt paths too, without clobbering
            # a value an outer middleware may already have set.
            request.state.tenant_id = getattr(request.state, "tenant_id", None)
            request.state.tenant_slug = getattr(request.state, "tenant_slug", None)
            return await call_next(request)

        tenant_id, tenant_slug = self._identify_tenant(request)

        if not tenant_id and tenant_slug:
            try:
                tenant_id = await self._resolve_tenant_id(request, tenant_slug)
            except Exception as e:
                # Best-effort: an unresolvable tenant must not fail the request.
                logger.warning("Tenant resolution failed for slug '%s': %s", tenant_slug, e)

        request.state.tenant_id = tenant_id
        request.state.tenant_slug = tenant_slug

        # Tenant presence is NOT enforced here: requests without a tenant still
        # reach the route. For tenant-scoped API paths (everything under
        # /api/v1/ except the site-wide auth endpoints) we only log it.
        if not tenant_id and path.startswith("/api/v1/") and not path.startswith("/api/v1/auth"):
            logger.debug("No tenant resolved for path: %s", path)

        return await call_next(request)

    @classmethod
    def _is_exempt(cls, path: str) -> bool:
        """Return True if ``path`` skips tenant resolution.

        Entries match whole path segments: "/health" exempts "/health" and
        "/health/live" but not "/healthcheck".
        """
        prefixes = (*TENANT_EXEMPT_PATHS, *cls._UNSCOPED_PREFIXES)
        return any(path == p or path.startswith(p + "/") for p in prefixes)

    @staticmethod
    def _identify_tenant(request: Request):
        """Return ``(tenant_id, tenant_slug)`` taken from headers or subdomain.

        At most one of the two is non-``None``; both are ``None`` when nothing
        on the request identifies a tenant. Empty header values are ignored.
        """
        # NOTE(review): both headers are taken from the client verbatim;
        # confirm that a downstream authorization layer checks that the caller
        # actually belongs to the tenant named here.
        tenant_id = request.headers.get("X-Tenant-ID")
        if tenant_id:
            return tenant_id, None

        tenant_slug = request.headers.get("X-Tenant-Slug")
        if tenant_slug:
            return None, tenant_slug

        return None, extract_subdomain(request.headers.get("host", "")) or None

    async def _resolve_tenant_id(self, request: Request, slug: str) -> Optional[str]:
        """Resolve ``slug`` to a tenant id via the Redis cache, then the DB.

        Redis is a best-effort cache. If it is unavailable, or a cache read or
        write fails, the error is logged and the database is used. A failed
        cache write does not discard a successful database lookup.

        Returns:
            The tenant id as a string, or ``None`` if no tenant has this slug.

        Raises:
            Whatever the database session or ``TenantService.get_by_slug``
            raises. The caller treats such errors as best-effort failures.
        """
        # Lazy import to avoid circular imports at startup
        from src.database.session import get_redis
        from src.apps.tenants.services.tenant_service import TenantService

        cache_key = f"tenant:subdomain:{slug}"

        try:
            redis = await get_redis()
        except Exception as e:
            logger.warning("Redis unavailable for tenant lookup: %s", e)
            redis = None

        if redis:
            try:
                cached = await redis.get(cache_key)
            except Exception as e:
                logger.warning("Tenant cache read failed for '%s': %s", cache_key, e)
            else:
                if cached:
                    return cached.decode() if isinstance(cached, bytes) else cached

        # Fallback to DB. Read account.id while the session is still open.
        async with request.app.state.db_session() as db:
            account = await TenantService(db).get_by_slug(slug)
            tenant_id = str(account.id) if account else None

        if tenant_id is not None and redis:
            try:
                await redis.setex(cache_key, settings.REDIS_TENANT_CACHE_TTL, tenant_id)
            except Exception as e:
                logger.warning("Tenant cache write failed for '%s': %s", cache_key, e)

        return tenant_id
