from pydantic import Field
from sqlalchemy import INTEGER, Float, ForeignKey, Integer, String
from sqlalchemy.orm import Mapped, Session, mapped_column, relationship
from app.models.main import Base
from app.utils.schemas_utils import CustomModel

class CatchmentPotentialBase(CustomModel):
    """Pydantic schema mirroring the columns of ``TblCatchmentPotential``.

    Every field is optional and defaults to ``None``. This lets the same schema
    describe both a full record and a partial set of values.
    """

    catch_id: int | None = None
    ISEC_Segment: int | None = None
    number_of_households: int | None = None
    percentage_of_segment: float | None = None
    potential_number_of_households: int | None = None
    # NOTE: optional `Assumptions` / `Sources` (str | None) fields are
    # currently disabled, matching the disabled ORM columns.
    group_id: int | None = None

class TblCatchmentPotential(Base):
    """ORM mapping for the ``tbl_catchment_potential`` table.

    Each row stores an ISEC segment, its household count, a percentage of the
    segment and a potential household count. It is optionally linked to a row
    of ``tbl_group`` through ``group_id``.

    Transaction note: ``create`` only flushes (the caller must commit), whereas
    ``update`` and ``delete`` commit on their own.
    """

    __tablename__ = "tbl_catchment_potential"

    catch_id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    ISEC_Segment: Mapped[int] = mapped_column("ISEC_Segment", INTEGER, nullable=False)
    number_of_households: Mapped[int] = mapped_column(Integer, nullable=False)
    percentage_of_segment: Mapped[float] = mapped_column(Float, nullable=False)
    potential_number_of_households: Mapped[int] = mapped_column(Integer, nullable=False)
    # NOTE: optional `Assumptions` / `Sources` String(100) columns are
    # currently disabled.
    # The FK is nullable, so the annotation admits None to match the column.
    group_id: Mapped[int | None] = mapped_column("group_id", INTEGER, ForeignKey("tbl_group.group_id"), nullable=True)

    group = relationship("TblGroup", back_populates="catchmentPotential")

    @classmethod
    def create(cls, data: CatchmentPotentialBase, db: Session) -> "TblCatchmentPotential":
        """Add a new row built from *data* to the session and flush it.

        The row is flushed, so database-generated values such as ``catch_id``
        are populated, but it is NOT committed; the caller owns the transaction.

        Args:
            data: Column values; every schema field is passed to the constructor.
            db: Active SQLAlchemy session.

        Returns:
            The newly created, flushed instance.
        """
        new_data = cls(**data.model_dump())
        db.add(new_data)
        db.flush()
        return new_data

    @classmethod
    def get_by_id(cls, group_id: int, db: Session) -> "list[TblCatchmentPotential]":
        """Return every row whose ``group_id`` equals *group_id*.

        Despite its name, this method looks rows up by group, not by
        ``catch_id``. It always returns a list, which is empty when nothing
        matches, and never returns ``None``.

        Args:
            group_id: Value matched against the ``group_id`` column.
            db: Active SQLAlchemy session.
        """
        return db.query(cls).filter(cls.group_id == group_id).all()

    @classmethod
    def update(cls, catch_id: int, data: CatchmentPotentialBase, db: Session) -> "TblCatchmentPotential | None":
        """Partially update the row identified by *catch_id* and commit.

        Only fields of *data* that are not ``None`` are applied. As a result, a
        column cannot be cleared to NULL through this method. Any ``catch_id``
        in *data* is ignored, because the primary key comes from the argument
        and must not be rewritten by the payload.

        Args:
            catch_id: Primary key of the row to update.
            data: New values; ``None`` fields are left unchanged.
            db: Active SQLAlchemy session (committed on success).

        Returns:
            The refreshed instance, or ``None`` if no row has that ``catch_id``.
        """
        get_data = db.query(cls).filter(cls.catch_id == catch_id).first()
        if get_data is None:
            return None
        for key, value in data.model_dump().items():
            # Skip unset fields and never let the body overwrite the primary key.
            if key == "catch_id" or value is None:
                continue
            setattr(get_data, key, value)
        db.commit()
        db.refresh(get_data)
        return get_data

    @classmethod
    def delete(cls, catch_id: int, db: Session) -> "TblCatchmentPotential | bool":
        """Delete the row identified by *catch_id* and commit.

        Args:
            catch_id: Primary key of the row to delete.
            db: Active SQLAlchemy session (committed on success).

        Returns:
            The deleted instance, or ``False`` when no row matches. ``False`` is
            kept instead of ``None`` for compatibility with existing callers.
        """
        obj = db.query(cls).filter(cls.catch_id == catch_id).first()
        if obj is None:
            return False
        db.delete(obj)
        db.commit()
        return obj

 