from pydantic import Field
from app.utils.schemas_utils import CustomModel
from sqlalchemy import INTEGER, VARCHAR, Float, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship, Session
from app.models.main import Base

class NetworkPlanningBase(CustomModel):
    """Payload schema mirroring the columns of ``tbl_network_planning``.

    Every field is optional and defaults to ``None``; the field names match
    the ``TblNetworkPlanning`` column attributes one-to-one.
    """

    # Identity and location descriptors.
    network_id: int | None = None
    location: str | None = None
    state: str | None = None

    # Household counts, area and the planned action.
    total_hh: int | None = None
    target_hh: int | None = None
    area: float | None = None
    action: str | None = None

    # "Format A" counts for years 1..5.
    format_a_y1: int | None = None
    format_a_y2: int | None = None
    format_a_y3: int | None = None
    format_a_y4: int | None = None
    format_a_y5: int | None = None

    # "Format B" counts for years 1..5.
    format_b_y1: int | None = None
    format_b_y2: int | None = None
    format_b_y3: int | None = None
    format_b_y4: int | None = None
    format_b_y5: int | None = None

    # Total store counts for years 1..5.
    total_store_count_y1: int | None = None
    total_store_count_y2: int | None = None
    total_store_count_y3: int | None = None
    total_store_count_y4: int | None = None
    total_store_count_y5: int | None = None

    # Owning group (foreign key to tbl_group.group_id on the ORM side).
    group_id: int | None = None

class TblNetworkPlanning(Base):
    """ORM mapping for ``tbl_network_planning`` with small CRUD helpers.

    All helpers take an explicit SQLAlchemy ``Session``; transaction
    boundaries differ per helper (see each method's docstring).
    """

    __tablename__ = "tbl_network_planning"

    network_id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    location: Mapped[str] = mapped_column(VARCHAR(255), nullable=False)
    state: Mapped[str] = mapped_column(VARCHAR(255), nullable=False)
    total_hh: Mapped[int] = mapped_column(INTEGER, nullable=False)
    target_hh: Mapped[int] = mapped_column(INTEGER, nullable=False)
    area: Mapped[float] = mapped_column(Float, nullable=False)
    action: Mapped[str | None] = mapped_column(VARCHAR(255), nullable=True)

    # "Format A" counts for years 1..5.
    format_a_y1: Mapped[int | None] = mapped_column(INTEGER, nullable=True)
    format_a_y2: Mapped[int | None] = mapped_column(INTEGER, nullable=True)
    format_a_y3: Mapped[int | None] = mapped_column(INTEGER, nullable=True)
    format_a_y4: Mapped[int | None] = mapped_column(INTEGER, nullable=True)
    format_a_y5: Mapped[int | None] = mapped_column(INTEGER, nullable=True)

    # "Format B" counts for years 1..5.
    format_b_y1: Mapped[int | None] = mapped_column(INTEGER, nullable=True)
    format_b_y2: Mapped[int | None] = mapped_column(INTEGER, nullable=True)
    format_b_y3: Mapped[int | None] = mapped_column(INTEGER, nullable=True)
    format_b_y4: Mapped[int | None] = mapped_column(INTEGER, nullable=True)
    format_b_y5: Mapped[int | None] = mapped_column(INTEGER, nullable=True)

    # Total store counts for years 1..5.
    total_store_count_y1: Mapped[int | None] = mapped_column(INTEGER, nullable=True)
    total_store_count_y2: Mapped[int | None] = mapped_column(INTEGER, nullable=True)
    total_store_count_y3: Mapped[int | None] = mapped_column(INTEGER, nullable=True)
    total_store_count_y4: Mapped[int | None] = mapped_column(INTEGER, nullable=True)
    total_store_count_y5: Mapped[int | None] = mapped_column(INTEGER, nullable=True)

    group_id: Mapped[int | None] = mapped_column(INTEGER, ForeignKey("tbl_group.group_id"), nullable=True)
    group = relationship("TblGroup", back_populates="network_planning")

    @classmethod
    def create(cls, data: NetworkPlanningBase, db: Session) -> "TblNetworkPlanning":
        """Add a new row built from ``data`` and flush it (no commit).

        Flushing sends the INSERT so the autoincrement ``network_id`` is
        populated on the returned instance; committing is left to the caller.
        """
        new_data = cls(**data.model_dump())
        db.add(new_data)
        db.flush()
        return new_data

    @classmethod
    def get_by_id(cls, network_id: int, db: Session) -> "TblNetworkPlanning | None":
        """Return the row with primary key ``network_id``, or ``None``."""
        return db.query(cls).filter(cls.network_id == network_id).first()

    @classmethod
    def get_by_group_id(cls, group_id: int, db: Session) -> "TblNetworkPlanning | None":
        """Return one row belonging to ``group_id``, or ``None``.

        If several rows share the group, the one with the lowest
        ``network_id`` is returned so the result is deterministic.
        """
        return (
            db.query(cls)
            .filter(cls.group_id == group_id)
            .order_by(cls.network_id)
            .first()
        )

    @classmethod
    def update(cls, network_id: int, data: NetworkPlanningBase, db: Session) -> "TblNetworkPlanning | None":
        """Partially update the row ``network_id`` from ``data`` and commit.

        Only fields whose value in ``data`` is not ``None`` are written, so a
        column cannot be reset to NULL through this method. The primary key is
        never overwritten, even when ``data.network_id`` is set.

        Returns the refreshed instance, or ``None`` if no row matches.
        """
        obj = cls.get_by_id(network_id, db)
        if obj is None:
            return None
        for key, value in data.model_dump().items():
            # Skip the primary key: a payload carrying a different network_id
            # would otherwise re-key this row or collide with another one.
            if key == "network_id" or value is None:
                continue
            setattr(obj, key, value)
        # NOTE(review): unlike ``create`` (flush only), this commits the
        # session's transaction; kept as-is because callers may rely on it.
        db.commit()
        db.refresh(obj)
        return obj

