import logging
import re
import time
from typing import Any

from fastapi import APIRouter, File, Form, UploadFile, HTTPException

from app.models.schemas import AnalyzeResponse
from app.services.ocr_service import extract_ocr_from_file, light_clean_ocr_text
from app.services.extraction_service import extract_tests_from_ocr
from app.services.validation_service import validate_and_enrich_tests
from app.services.interpretation_service import interpret_results

# Logger used by every stage of the analysis pipeline in this module.
logger = logging.getLogger("lab_analyzer")
# Router carrying the POST /analyze endpoint defined in this module.
router = APIRouter()

# Per-test keys emitted in the response, in output order. Only these keys are
# copied into each test row: "explanation" is always present (possibly ""),
# while the other keys are omitted when their value is None or blank.
OUTPUT_TEST_KEYS = (
    "test_name",
    "result",
    "unit",
    "normal_range",
    "status",
    "explanation",
)


def _interp_lookup(interpretations: list[dict]) -> dict[str, str]:
    lookup: dict[str, str] = {}
    for item in interpretations:
        if not isinstance(item, dict):
            continue
        name = str(item.get("test_name", "")).strip().lower()
        expl = str(item.get("explanation", "")).strip()
        if name and expl:
            lookup[name] = expl
    return lookup


def _match_explanation(test: dict, lookup: dict[str, str]) -> str:
    keys = [
        str(test.get("test_name", "")).strip().lower(),
        str(test.get("canonical_name", "")).strip().lower(),
    ]
    for key in keys:
        if key and key in lookup:
            return lookup[key]
    base = keys[0]
    for stored, expl in lookup.items():
        if base and (base in stored or stored in base):
            return expl
    return ""


def _normalize_test_fields(test: dict) -> dict[str, Any]:
    row: dict[str, Any] = {
        "test_name": test.get("test_name", ""),
        "result": test.get("result", ""),
        "unit": test.get("unit", ""),
        "normal_range": test.get("normal_range", ""),
        "status": test.get("status", ""),
    }
    if not str(row["result"]).strip() and test.get("percentage_value"):
        row["result"] = test.get("percentage_value", "")
        row["unit"] = test.get("percentage_unit", "") or row["unit"]
        row["normal_range"] = test.get("percentage_range", "") or row["normal_range"]
        if not str(row["status"]).strip():
            row["status"] = test.get("status_percentage", "")
    return row


def build_analyze_response(
    language: str,
    overall_summary: str,
    validated_tests: list[dict],
    interpretations: list[dict],
) -> dict[str, Any]:
    """Assemble the payload used to construct ``AnalyzeResponse``.

    For every dict in *validated_tests* (other items are skipped), the test
    is normalized and paired with its interpretation text. Only the keys in
    ``OUTPUT_TEST_KEYS`` are emitted, in that order:

    * ``explanation`` is always present (``""`` when nothing matched);
    * every other key is stringified and stripped, and left out entirely when
      its value is ``None`` or blank.

    Rows that end up without a ``test_name`` are dropped.

    Returns:
        ``{"language": ..., "overall_summary": ..., "tests": [...]}``.
    """
    explanation_index = _interp_lookup(interpretations)
    rows: list[dict[str, Any]] = []

    for raw_test in validated_tests:
        if not isinstance(raw_test, dict):
            continue

        fields = _normalize_test_fields(raw_test)
        explanation_text = _match_explanation(raw_test, explanation_index)

        out_row: dict[str, Any] = {}
        for field_name in OUTPUT_TEST_KEYS:
            if field_name == "explanation":
                out_row[field_name] = explanation_text
                continue
            value = fields.get(field_name, "")
            if value is None:
                continue
            text = str(value).strip()
            if text:
                out_row[field_name] = text

        if not out_row.get("test_name"):
            continue
        rows.append(out_row)

    return {"language": language, "overall_summary": overall_summary, "tests": rows}


@router.post(
    "/analyze",
    response_model=AnalyzeResponse,
    response_model_exclude_none=True,
)
async def analyze_lab_report(
    file: UploadFile = File(..., description="Lab report image or PDF"),
    language: str = Form("en", description='Explanation language: "en" or "ar"'),
):
    """Analyze an uploaded lab report: OCR -> extraction -> validation -> interpretation.

    Args:
        file: The uploaded lab report (image or PDF).
        language: Explanation language, ``"en"`` or ``"ar"``; surrounding
            whitespace and letter case are ignored.

    Returns:
        An ``AnalyzeResponse`` built from ``build_analyze_response``.

    Raises:
        HTTPException:
            * 400 for an unsupported language, an empty upload, or OCR that
              produced no text.
            * 502 when test extraction fails.
            * 500 for any other unexpected error.

            Interpretation failures are not fatal. The extracted tests are
            still returned, with a localized fallback summary.
    """
    lang = language.strip().lower()
    if lang not in ("en", "ar"):
        raise HTTPException(status_code=400, detail='language must be "en" or "ar"')

    filename = file.filename or "upload.png"

    try:
        file_bytes = await file.read()
        if not file_bytes:
            raise HTTPException(status_code=400, detail="Empty file uploaded")

        t0 = time.perf_counter()

        logger.info("[OCR] Processing %s", filename)
        raw_ocr = await extract_ocr_from_file(file_bytes, filename)
        t_ocr = time.perf_counter()

        if not raw_ocr.strip():
            raise HTTPException(
                status_code=400,
                detail="OCR produced no text. Try a clearer image or PDF.",
            )

        cleaned_ocr = light_clean_ocr_text(raw_ocr)

        try:
            # Only the second element of the returned pair is used here.
            _, raw_tests = await extract_tests_from_ocr(cleaned_ocr)
        except Exception as exc:
            logger.exception("[AI_EXTRACTION] Failed")
            # The traceback is logged above. Do not echo str(exc) to the client:
            # upstream SDK/HTTP error strings can embed request URLs or keys.
            raise HTTPException(
                status_code=502,
                detail="Extraction failed. Please try again later.",
            ) from exc
        t_extract = time.perf_counter()

        validated_dicts = validate_and_enrich_tests(raw_tests)

        overall_summary = ""
        interpretations: list[dict] = []
        t_interp = time.perf_counter()

        if validated_dicts:
            try:
                interp_tests, overall_summary = await interpret_results(
                    validated_dicts, lang
                )
                interpretations = [t for t in interp_tests if isinstance(t, dict)]
            except Exception:
                # Best-effort step: keep the extracted tests and return a
                # localized "discuss with your doctor" summary instead.
                logger.exception("[INTERPRETATION] Failed")
                overall_summary = (
                    "تعذر إنشاء التفسير. ناقش النتائج مع طبيبك."
                    if lang == "ar"
                    else "Interpretation unavailable. Discuss results with your doctor."
                )
            t_interp = time.perf_counter()

        payload = build_analyze_response(
            language=lang,
            overall_summary=overall_summary,
            validated_tests=validated_dicts,
            interpretations=interpretations,
        )

        total = time.perf_counter() - t0
        if validated_dicts:
            logger.info(
                "[TIMING] OCR %.1fs | extract %.1fs | interpret %.1fs | total %.1fs | tests=%d",
                t_ocr - t0,
                t_extract - t_ocr,
                t_interp - t_extract,
                total,
                len(validated_dicts),
            )
        else:
            logger.info("[TIMING] OCR %.1fs | total %.1fs | no tests", t_ocr - t0, total)

        return AnalyzeResponse(**payload)

    except HTTPException:
        raise
    except Exception as exc:
        # Full traceback goes to the server log. The client gets a generic
        # message so internal details are not disclosed.
        logger.exception("Pipeline error")
        raise HTTPException(
            status_code=500,
            detail="Internal error while analyzing the lab report.",
        ) from exc
    finally:
        # Release the file backing the upload as soon as processing ends.
        await file.close()
