from pydantic import Field
from sqlalchemy import INTEGER, VARCHAR, BigInteger, ForeignKey
from sqlalchemy.orm import Mapped, Session, mapped_column, relationship
from app.models.main import Base
from app.utils.schemas_utils import CustomModel

class CustomerLocationBase(CustomModel):
    """Request/response schema mirroring the columns of ``TblCustomerLocation``.

    Every field is optional and defaults to ``None``, so the same schema can be
    used both for creation payloads and for partial updates.
    """

    # Primary key; ``None`` when the record has not been persisted yet.
    customer_location_id: int | None = Field(default=None)
    ISEC_Segment: int | None = Field(default=None)
    sample_size: int | None = Field(default=None)
    average_age: int | None = Field(default=None)
    income_levels: str | None = Field(default=None)
    occupation_mode: str | None = Field(default=None)
    education_level_mode: str | None = Field(default=None)
    shopping_frequency_per_month: int | None = Field(default=None)
    household_consumption_per_month: int | None = Field(default=None)
    total_household_consumption_per_month: int | None = Field(default=None)
    # Optional link to a group (maps to the ``group_id`` foreign key column).
    group_id: int | None = Field(default=None)

class TblCustomerLocation(Base):
    """ORM model for rows of the ``tbl_customer_location`` table.

    A row may optionally reference a group through ``group_id`` (nullable
    foreign key to ``tbl_group.group_id``).

    Transaction semantics of the helpers below differ on purpose:
    ``create`` only flushes (the caller is presumably responsible for the
    commit), while ``update`` and ``delete`` commit immediately.
    """

    __tablename__ = "tbl_customer_location"

    customer_location_id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    ISEC_Segment: Mapped[int] = mapped_column("ISEC_Segment", INTEGER, nullable=False)
    sample_size: Mapped[int] = mapped_column(INTEGER, nullable=False)
    average_age: Mapped[int] = mapped_column(INTEGER, nullable=False)
    # Nullable columns are annotated ``X | None`` so the Python type matches
    # the schema; the explicit ``nullable=True`` keeps the DDL unchanged.
    income_levels: Mapped[str | None] = mapped_column(VARCHAR(255), nullable=True)
    occupation_mode: Mapped[str | None] = mapped_column(VARCHAR(100), nullable=True)
    education_level_mode: Mapped[str | None] = mapped_column(VARCHAR(100), nullable=True)
    shopping_frequency_per_month: Mapped[int] = mapped_column(INTEGER, nullable=False)
    household_consumption_per_month: Mapped[int] = mapped_column(BigInteger, nullable=False)
    total_household_consumption_per_month: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
    group_id: Mapped[int | None] = mapped_column("group_id", INTEGER, ForeignKey("tbl_group.group_id"), nullable=True)

    # Many-to-one link to the owning group; the reverse side is
    # ``TblGroup.customerlocation``.
    group = relationship("TblGroup", back_populates="customerlocation")

    @classmethod
    def create(cls, data: CustomerLocationBase, db: Session) -> "TblCustomerLocation":
        """Add a new row built from ``data`` to the session and flush it.

        The flush sends the INSERT so the autoincrement primary key is
        populated on the returned object, but nothing is committed here.

        Args:
            data: Field values for the new row. A ``customer_location_id`` of
                ``None`` lets the database assign the key.
            db: Active SQLAlchemy session.

        Returns:
            The newly added (flushed, uncommitted) instance.
        """
        new_record = cls(**data.model_dump())
        db.add(new_record)
        db.flush()
        return new_record

    @classmethod
    def get_by_id(cls, customer_location_id: int, db: Session) -> "TblCustomerLocation | None":
        """Return the row with the given primary key, or ``None`` if absent."""
        return db.query(cls).filter(cls.customer_location_id == customer_location_id).first()

    @classmethod
    def update(cls, customer_location_id: int, data: CustomerLocationBase, db: Session) -> "TblCustomerLocation | None":
        """Apply the non-``None`` fields of ``data`` to an existing row and commit.

        Fields whose value in ``data`` is ``None`` are left untouched, so this
        method cannot clear a column to NULL. The primary key is never taken
        from ``data``; the target row is identified solely by
        ``customer_location_id``.

        Args:
            customer_location_id: Primary key of the row to update.
            data: New field values; ``None`` means "keep current value".
            db: Active SQLAlchemy session.

        Returns:
            The refreshed instance, or ``None`` if no such row exists.
        """
        record = cls.get_by_id(customer_location_id, db)
        if record is None:
            return None
        # Excluding the PK prevents a payload that echoes a different id from
        # rewriting the primary key of the row being updated.
        changes = data.model_dump(exclude={"customer_location_id"})
        for key, value in changes.items():
            if value is not None:
                setattr(record, key, value)
        db.commit()
        db.refresh(record)
        return record

    @classmethod
    def delete(cls, customer_location_id: int, db: Session) -> bool:
        """Delete the row with the given primary key and commit.

        Returns:
            ``True`` if a row was found and deleted, ``False`` otherwise.
        """
        record = cls.get_by_id(customer_location_id, db)
        if record is None:
            return False
        db.delete(record)
        db.commit()
        return True



