"""Catchment-potential API routes: catchment CRUD endpoints plus main-data Excel upload and lookup."""

import pandas as pd
from typing import Annotated, List
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from sqlalchemy.orm import Session
from app.api.catchment_potential import service
from app.database.main.mysql import get_db
from app.dependency.authantication import JWTPayloadSchema, get_current_student
from app.models.main.main_data import TblMainData
from .schemas import GetPotentialResponse, MainDataCreate, MainDataResponse, TotalPotentialResponse, UpdatePotential, catchmentPotentialCreat ,catchmentPotentialResponse

# Router that collects every endpoint in this module. It is presumably mounted by
# the application elsewhere (e.g. via ``include_router``); that is not visible here.
catchment_potential_router = APIRouter()

@catchment_potential_router.post("/catchment_creat", response_model_exclude_none=True)
async def create_catchment(
    request: catchmentPotentialCreat,
    db: Annotated[Session, Depends(get_db)],
    token: Annotated[JWTPayloadSchema, Depends(get_current_student)],
):
    """Create a catchment-potential entry from ``request``.

    The work is delegated to
    ``service.Catchment_Potential(db, token).create_catchment_potential``.
    ``token`` is the payload resolved by the ``get_current_student`` dependency.
    """
    potential_service = service.Catchment_Potential(db, token)
    return await potential_service.create_catchment_potential(request)

@catchment_potential_router.get(
    "/catchment_potential/{group_id}/total_potential",
    response_model=TotalPotentialResponse,
)
async def get_total_consumption_by_group(
    group_id: int,
    db: Session = Depends(get_db),
    token: JWTPayloadSchema = Depends(get_current_student),
):
    """Return the service's total for group ``group_id``, shaped by ``TotalPotentialResponse``.

    The work is delegated to
    ``service.Catchment_Potential(db, token).get_total_consumption_by_group``.
    """
    potential_service = service.Catchment_Potential(db, token)
    return await potential_service.get_total_consumption_by_group(group_id)

@catchment_potential_router.get("/catchment/{group_id}", response_model_exclude_none=True)
async def get_simulation(
    group_id: int,
    db: Session = Depends(get_db),
    token: JWTPayloadSchema = Depends(get_current_student),
):
    """Return the catchment data for group ``group_id``.

    The work is delegated to ``service.Catchment_Potential(db, token).get_catchment``.
    NOTE(review): the function name does not reflect what the route serves.
    It is kept unchanged so existing references do not break.
    """
    potential_service = service.Catchment_Potential(db, token)
    return await potential_service.get_catchment(group_id)

@catchment_potential_router.put("/catchment/update", response_model_exclude_none=True)
async def update_segment(
    request: List[UpdatePotential],
    db: Session = Depends(get_db),
    token: JWTPayloadSchema = Depends(get_current_student),
):
    """Apply a batch of ``UpdatePotential`` changes.

    The work is delegated to ``service.Catchment_Potential(db, token).update_catchment``.
    The whole list is passed through as-is.
    """
    potential_service = service.Catchment_Potential(db, token)
    return await potential_service.update_catchment(request)

@catchment_potential_router.get(
    "/catchment_potential/group/{group_id}",
    response_model=list[catchmentPotentialResponse],
    response_model_exclude_none=True,
)
async def get_customer_locations_by_group(
    group_id: int,
    db: Session = Depends(get_db),
    token: JWTPayloadSchema = Depends(get_current_student),
):
    """List the catchment-potential entries of group ``group_id``.

    Each item is serialised with ``catchmentPotentialResponse``.
    """
    potential_service = service.Catchment_Potential(db, token)
    # The misspelling ("poteneial") matches the service method's actual name.
    return await potential_service.get_catchment_poteneial_by_group(group_id)

@catchment_potential_router.delete("/catchment/delete/{catch_id}", response_model_exclude_none=True)
async def delete_catchment(
    catch_id: int,
    db: Session = Depends(get_db),
    token: JWTPayloadSchema = Depends(get_current_student),
):
    """Delete the catchment entry identified by ``catch_id``.

    The work is delegated to ``service.Catchment_Potential(db, token).delete_catchment``.
    """
    potential_service = service.Catchment_Potential(db, token)
    return await potential_service.delete_catchment(catch_id)

# Spreadsheet schema for /upload-main-data/. Each entry is:
#   (MainDataCreate field, normalized column header, value used when the column is absent or the cell is blank)
# NOTE(review): there is no "spillage_m" entry, although the other L/M/H groups
# have all three levels. Confirm whether the sheet and schema should include it.
_MAIN_DATA_COLUMNS = (
    ("category_identifier", "category_identifier", ""),
    ("format_type", "format_type", ""),
    ("reference_retailer", "reference_retailer", ""),
    ("no_of_stores", "no_of_stores", None),
    ("description", "indicative_categories_description", ""),
    ("no_of_cats_l", "no_of_cats_l", None),
    ("no_of_cats_m", "no_of_cats_m", None),
    ("no_of_cats_h", "no_of_cats_h", None),
    ("indicative_size", "indicative_size_square_feet", None),
    ("capital_available", "indicative_capital_available_rs_in_crores", None),
    ("approx_investment", "approx_investment_per_store_in_rupees_lakhs", None),
    ("comp_intensity_l", "competition_intensity_l", None),
    ("comp_intensity_m", "competition_intensity_m", None),
    ("comp_intensity_h", "competition_intensity_h", None),
    ("tech_inv_h", "technology_investment_in_rs_lakhs_h", None),
    ("tech_inv_m", "technology_investment_in_rs_lakhs_m", None),
    ("tech_inv_l", "technology_investment_in_rs_lakhs_l", None),
    ("spillage_h", "spillage_factor_h", None),
    ("spillage_l", "spillage_factor_l", None),
    ("services_h", "services_h", None),
    ("services_m", "services_m", None),
    ("services_l", "services_l", None),
)


def _normalize_column_name(column):
    """Return *column* as a lower-case key suitable for matching ``_MAIN_DATA_COLUMNS``.

    Surrounding whitespace is stripped, and spaces, '/' and '-' are replaced by '_'.
    For example, "Format Type" becomes "format_type".
    """
    # str() first: Excel headers are not guaranteed to be strings.
    # A numeric header cell comes back as a number.
    return str(column).strip().lower().replace(" ", "_").replace("/", "_").replace("-", "_")


def _cell_or_default(row, key, default=None):
    """Return ``row[key]``, or *default* when the column is absent or the cell is null/NaN."""
    # Series.get returns None for a missing label, and pd.isnull(None) is True.
    value = row.get(key)
    return default if pd.isnull(value) else value


@catchment_potential_router.post("/upload-main-data/")
async def upload_excel(file: UploadFile = File(...), db: Session = Depends(get_db)):
    """Replace every ``TblMainData`` row with the rows of an uploaded Excel sheet.

    Processing steps:
    - Headers are normalized with ``_normalize_column_name`` and mapped to
      ``MainDataCreate`` fields through ``_MAIN_DATA_COLUMNS``.
    - Missing columns and blank cells fall back to the per-field default.
    - Deleting the old rows and inserting the new ones happen in a single
      transaction. If anything fails, the session is rolled back, so the
      existing data survives. This assumes a transactional storage engine.

    Returns:
        ``{"message": "<n> records replaced successfully"}``.

    Raises:
        HTTPException: 400 if pandas cannot parse the upload, or if a row
            fails ``MainDataCreate`` validation.
    """
    # NOTE(review): unlike the other routes in this module, this endpoint declares
    # no ``get_current_student`` dependency, yet it wipes the whole table.
    # Confirm it is protected at router/app level, or add auth.
    # NOTE(review): pandas and the SQLAlchemy Session used here are synchronous,
    # so this ``async def`` blocks the event loop while the upload is processed.
    try:
        df = pd.read_excel(file.file)
    except ValueError as exc:
        # pandas raises ValueError, e.g., when the file is not a recognisable Excel format.
        raise HTTPException(status_code=400, detail=f"Could not read uploaded file as Excel: {exc}") from exc

    df.columns = [_normalize_column_name(col) for col in df.columns]

    # Validate every row before touching the database.
    main_data_list = []
    for index, row in df.iterrows():
        values = {field: _cell_or_default(row, column, default) for field, column, default in _MAIN_DATA_COLUMNS}
        try:
            main_data_list.append(MainDataCreate(**values))
        except ValueError as exc:
            # pydantic's ValidationError subclasses ValueError, so this is bad client input, not a server error.
            raise HTTPException(status_code=400, detail=f"Invalid data at row index {index}: {exc}") from exc

    # Build ORM objects up front so a mapping error cannot strike after the delete.
    records = [TblMainData(**item.dict()) for item in main_data_list]

    # Delete and insert in ONE transaction. Previously the delete was committed
    # on its own, so a failed insert left the table empty.
    try:
        db.query(TblMainData).delete()
        db.add_all(records)
        db.commit()
    except Exception:
        db.rollback()
        raise

    return {"message": f"{len(main_data_list)} records replaced successfully"}

@catchment_potential_router.get("/main-data/", response_model=List[MainDataResponse])
def get_main_data_by_format_type(format_type: str, db: Session = Depends(get_db)):
    """Return every ``TblMainData`` row whose ``format_type`` equals the given value.

    Raises:
        HTTPException: 404 when no row matches.
    """
    matching_rows = (
        db.query(TblMainData)
        .filter(TblMainData.format_type == format_type)
        .all()
    )
    if matching_rows:
        return matching_rows
    raise HTTPException(status_code=404, detail="No data found for this format type")
