import logging
import os
from datetime import datetime
from io import BytesIO
from xml.sax.saxutils import escape

import pandas as pd
from fastapi import HTTPException
from reportlab.lib import colors
from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.units import inch
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, Image
from sqlalchemy.orm import Session

from app.models.main.customer_location import TblCustomerLocation
from app.models.main.catchment_potential import TblCatchmentPotential
from app.models.main.competitor_intensity import TblCompetitorIntensity
from app.models.main.competitor_benchmarking import TblCompetitorBenchmarking
from app.models.main.location_spillage_factor_analysis import TblLocationSpillageFactor
from app.models.main.location_details import TblLocationDetails
from app.models.main.sales_estimate import TblSalesEstimate
from app.models.main.rent_expenses import TblRentExpenses
from app.models.main.summary import TblSummary, SummaryBase
from app.models.main.group import TblGroup

from app.api.module_2.schema import (
    CustomerSegmentEstimate,
    CatchmentPotential,
    CompetitionIntensity,
    CompetitorBenchmarking,
    LocationSpillageFactor,
    LocationDetails,
    SalesEstimate,
    RentExpenses,
    SalesSummaryComparison,
    Module2Summary,
    Module2ReportPreview,
)
from app.dependency.authantication import JWTPayloadSchema


class Module2ReportService:
    """Assemble, render and persist the Module 2 (location analysis) PDF report.

    All data access goes through the synchronous SQLAlchemy ``Session`` given
    to the constructor. The rendered PDF is written under ``OUTPUT_DIR`` and
    its path is recorded on the group's ``TblSummary`` row whose
    ``group_id_mod`` is ``"g<group_id>_m2"``.
    """

    # Logo drawn at the top of the report; resolved relative to the process
    # working directory. The report is still produced when it is missing.
    LOGO_PATH = "TS Logo.png"
    # Directory (relative to the working directory) that receives the PDFs.
    OUTPUT_DIR = "uploaded_files"

    _logger = logging.getLogger(__name__)

    def __init__(self, db: Session, token: JWTPayloadSchema = None):
        """Store the database session and the optional caller token.

        Args:
            db: Session used for every read and write performed by this service.
            token: Decoded JWT payload of the caller; kept on the instance but
                not used by the report-generation methods below.
        """
        self.db = db
        self.token = token

    # ------------------------------------------------------------------ data

    @staticmethod
    def _summary_mod_key(group_id: int) -> str:
        """Return the ``TblSummary.group_id_mod`` key used for Module 2 rows."""
        return f"g{group_id}_m2"

    @staticmethod
    def _row_to_dict(row) -> dict:
        """Return the loaded attribute values of an ORM row as a plain dict.

        ``vars(row)`` also holds SQLAlchemy bookkeeping such as
        ``_sa_instance_state``; underscore-prefixed keys are dropped so only
        column values are handed to the schema classes.
        """
        return {key: value for key, value in vars(row).items() if not key.startswith("_")}

    def _fetch_summary_record(self, group_id: int):
        """Return the group's Module 2 ``TblSummary`` row, or ``None``.

        Prefers the row keyed by ``group_id_mod == "g<group_id>_m2"`` (the key
        ``_save_file_path_to_db`` writes), because several modules can share
        one ``group_id`` in this table. Falls back to any row with the same
        ``group_id`` so rows stored without ``group_id_mod`` are still found.
        """
        record = (
            self.db.query(TblSummary)
            .filter(TblSummary.group_id_mod == self._summary_mod_key(group_id))
            .first()
        )
        if record is None:
            record = self.db.query(TblSummary).filter_by(group_id=group_id).first()
        return record

    async def _fetch_module2_data(self, group_id: int) -> Module2ReportPreview:
        """Load all Module 2 data for ``group_id`` into a report preview.

        Every list section falls back to a single empty schema instance tagged
        with ``group_id`` when the group has no rows for it, so the PDF always
        renders every section.

        Note: the queries run synchronously on ``self.db`` and therefore block
        the event loop despite the ``async`` signature.

        Raises:
            HTTPException: 400 if ``group_id`` is falsy; 404 if the group has
                no Module 2 data and no ``TblGroup`` row.
        """
        if not group_id:
            raise HTTPException(status_code=400, detail="group_id is required")

        def fetch_all(model):
            return self.db.query(model).filter_by(group_id=group_id).all()

        customer_records = fetch_all(TblCustomerLocation)
        catchment_records = fetch_all(TblCatchmentPotential)
        competition_record = self.db.query(TblCompetitorIntensity).filter_by(group_id=group_id).first()
        benchmarking_records = fetch_all(TblCompetitorBenchmarking)
        spillage_records = fetch_all(TblLocationSpillageFactor)
        location_records = fetch_all(TblLocationDetails)
        sales_records = fetch_all(TblSalesEstimate)
        rent_records = fetch_all(TblRentExpenses)

        # Any Module 2 table counts as "data exists"; only when all are empty
        # must the group itself be shown to exist.
        has_module_data = any([
            customer_records, catchment_records, competition_record is not None,
            benchmarking_records, spillage_records, location_records,
            sales_records, rent_records,
        ])
        if not has_module_data and self.db.query(TblGroup).filter_by(group_id=group_id).first() is None:
            raise HTTPException(status_code=404, detail="Group ID not found in database")

        summary_record = self._fetch_summary_record(group_id)

        def to_schemas(schema, records):
            """Convert ORM rows to ``schema`` objects, or one empty placeholder."""
            return [schema(**self._row_to_dict(r)) for r in records] or [schema(group_id=group_id)]

        return Module2ReportPreview(
            group_id=group_id,
            customer_segment_estimates=to_schemas(CustomerSegmentEstimate, customer_records),
            catchment_potential=to_schemas(CatchmentPotential, catchment_records),
            competition_intensity=(
                CompetitionIntensity(**self._row_to_dict(competition_record))
                if competition_record is not None
                else CompetitionIntensity(group_id=group_id)
            ),
            competitor_benchmarking=to_schemas(CompetitorBenchmarking, benchmarking_records),
            location_spillage_factor=to_schemas(LocationSpillageFactor, spillage_records),
            location_details=to_schemas(LocationDetails, location_records),
            sales_estimates=to_schemas(SalesEstimate, sales_records),
            rent_expenses=to_schemas(RentExpenses, rent_records),
            # No table backs this section yet; always a placeholder.
            sales_summary_comparison=[SalesSummaryComparison(group_id=group_id)],
            summary=(
                Module2Summary(**self._row_to_dict(summary_record))
                if summary_record is not None
                else Module2Summary(group_id=group_id)
            ),
            last_updated=pd.Timestamp.now(),
        )

    # ------------------------------------------------------------ rendering

    @staticmethod
    def _escape(value) -> str:
        """Return ``str(value)`` with ``&``, ``<`` and ``>`` escaped.

        ``Paragraph`` parses its text as markup, so user-entered data must be
        escaped before it is embedded or the text is misparsed.
        """
        return escape(str(value))

    @staticmethod
    def _fit_widths(widths, available: float) -> list:
        """Scale ``widths`` down proportionally so they sum to at most ``available``."""
        total = sum(widths)
        if total <= available:
            return list(widths)
        scale = available / total
        return [width * scale for width in widths]

    @staticmethod
    def _simple_table(rows: list, font_size: int) -> Table:
        """Return a gridded, left-aligned table with a bold header row."""
        table = Table(rows)
        table.setStyle(TableStyle([
            ('ALIGN', (0, 0), (-1, -1), 'LEFT'),
            ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
            ('FONTSIZE', (0, 0), (-1, -1), font_size),
            ('GRID', (0, 0), (-1, -1), 0.5, colors.black),
            ('VALIGN', (0, 0), (-1, -1), 'TOP'),
            ('WORDWRAP', (0, 0), (-1, -1), True),
        ]))
        return table

    def _append_logo(self, story: list) -> None:
        """Append the logo and a spacer to ``story`` if the logo can be loaded."""
        if not os.path.isfile(self.LOGO_PATH):
            return
        try:
            # lazy=0 reads the image now, so an unreadable file is caught here
            # instead of aborting the whole doc.build() later.
            logo = Image(self.LOGO_PATH, width=150, height=75, lazy=0)
        except Exception:
            self._logger.warning(
                "Could not load report logo %r; continuing without it", self.LOGO_PATH, exc_info=True
            )
            return
        logo.hAlign = 'LEFT'
        story.append(logo)
        story.append(Spacer(1, 10))

    def _details_paragraph(self, details, style: ParagraphStyle) -> Paragraph:
        """Render a details mapping as ``<b>key</b>: value`` lines, or "N/A" when empty.

        A non-mapping value is rendered as plain escaped text instead of failing.
        """
        if not details:
            return Paragraph("N/A", style)
        if not hasattr(details, "items"):
            return Paragraph(self._escape(details), style)
        text = "<br/>".join(
            f"<b>{self._escape(key)}</b>: {self._escape(value)}" for key, value in details.items()
        )
        return Paragraph(text, style)

    def _build_story(self, data: Module2ReportPreview, frame_width: float) -> list:
        """Return the ReportLab flowables for the whole report, in page order.

        Args:
            data: Report content, as produced by ``_fetch_module2_data``.
            frame_width: Printable frame width in points; the fixed-width
                tables are scaled down to fit inside it.
        """
        styles = getSampleStyleSheet()
        normal = styles['Normal']
        # Derived styles instead of assigning to styles['Normal'] in place:
        # every Paragraph built from 'Normal' shares that one style object,
        # so in-place edits would leak into unrelated paragraphs.
        cell_style = ParagraphStyle('Module2Cell', parent=normal, fontSize=7, leading=9)
        summary_cell_style = ParagraphStyle('Module2SummaryCell', parent=normal, fontSize=8, leading=10)
        title_style = ParagraphStyle(
            'CustomTitle',
            parent=styles['Heading1'],
            fontSize=16,
            spaceAfter=20,
            alignment=1,
            fontName='Helvetica-Bold',
        )
        section_style = ParagraphStyle(
            'SectionHeader',
            parent=styles['Heading2'],
            fontSize=14,
            spaceAfter=10,
            fontName='Helvetica-Bold',
        )

        def text_cell(value, style=cell_style):
            """Wrap ``value`` (or "N/A" when falsy) in an escaped Paragraph."""
            return Paragraph(self._escape(value or "N/A"), style)

        story = []
        self._append_logo(story)
        story.append(Paragraph("Module 2: Location Analysis", title_style))
        story.append(Spacer(1, 12))

        # Customer Segment Estimates
        story.append(Paragraph("Customer Segment Estimates", section_style))
        if data.customer_segment_estimates:
            customer_headers = [
                'ISEC<br/>Segment', 'Sample<br/>Size', 'Average<br/>Age', 'Income<br/>Levels',
                'Occupation<br/>Mode', 'Education<br/>Level Mode',
                'Shopping<br/>Frequency<br/>Per Month', 'Household<br/>Consumption<br/>Per Month',
                'Total Household<br/>Consumption<br/>Per Month',
            ]
            customer_rows = [[Paragraph(label, normal) for label in customer_headers]]
            for customer in data.customer_segment_estimates:
                customer_rows.append([
                    str(customer.ISEC_Segment or 0),
                    str(customer.sample_size or 0),
                    str(customer.average_age or 0),
                    str(customer.income_levels or "N/A"),
                    str(customer.occupation_mode or "N/A"),
                    str(customer.education_level_mode or "N/A"),
                    str(customer.shopping_frequency_per_month or 0),
                    str(customer.household_consumption_per_month or 0),
                    str(customer.total_household_consumption_per_month or 0),
                ])
            story.append(self._simple_table(customer_rows, 6))
        story.append(Spacer(1, 15))

        # Catchment Potential
        story.append(Paragraph("Catchment Potential", section_style))
        if data.catchment_potential:
            catchment_rows = [['ISEC Segment', 'Number of Households', 'Percentage of Segment', 'Potential Number of Households']]
            for catchment in data.catchment_potential:
                catchment_rows.append([
                    str(catchment.ISEC_Segment or 0),
                    str(catchment.number_of_households or 0),
                    str(catchment.percentage_of_segment or 0),
                    str(catchment.potential_number_of_households or 0),
                ])
            story.append(self._simple_table(catchment_rows, 8))
        story.append(Spacer(1, 15))

        # Competition Intensity (a single record, always rendered)
        story.append(Paragraph("Competition Intensity", section_style))
        competition = data.competition_intensity
        competition_rows = [
            ['Upload Image', 'Total Square Footage', 'Assumptions'],
            [
                str(competition.upload_image or "N/A"),
                str(competition.total_square_footage or "N/A"),
                str(competition.assumptions or "N/A"),
            ],
        ]
        story.append(self._simple_table(competition_rows, 7))
        story.append(Spacer(1, 15))

        # Competitor Benchmarking
        if data.competitor_benchmarking:
            story.append(Paragraph("Competitor Benchmarking", section_style))
            bench_rows = [['Reference Competitor', 'Bills Per Month', 'Items Pre Bill', 'Avg Price', 'Other Remark']]
            for bench in data.competitor_benchmarking:
                bench_rows.append([
                    str(bench.reference_competitor or "N/A"),
                    str(bench.bills_per_month or 0),
                    str(bench.items_pre_bill or 0),
                    str(bench.avg_price or 0),
                    str(bench.other_remark or "N/A"),
                ])
            story.append(self._simple_table(bench_rows, 6))
            story.append(Spacer(1, 15))

        # Location Spillage Factor
        story.append(Paragraph("Location Spillage Factor", section_style))
        if data.location_spillage_factor:
            spillage_rows = [['Estimated Spillage Factor', 'Assumptions']]
            for spillage in data.location_spillage_factor:
                spillage_rows.append([
                    str(spillage.estimated_spillage_factor or "N/A"),
                    str(spillage.assumptions or "N/A"),
                ])
            story.append(self._simple_table(spillage_rows, 8))
        story.append(Spacer(1, 15))

        # Location Details (Paragraph cells so long text wraps inside columns)
        if data.location_details:
            story.append(Paragraph("Location Details", section_style))
            story.append(Spacer(1, 10))
            location_headers = [
                'Trading<br/>Radius', 'Trading<br/>Rational', 'Adjacencies',
                'Adjacencies<br/>Rational', 'Location<br/>Characteristics', 'Corner<br/>Property<br/>Rational',
                'Rational', 'Details', 'Store<br/>Format<br/>Type',
            ]
            location_rows = [[Paragraph(label, cell_style) for label in location_headers]]
            for loc in data.location_details:
                location_rows.append([
                    text_cell(loc.trading_radius),
                    text_cell(loc.trading_rational),
                    text_cell(loc.adjacencies),
                    text_cell(loc.adjacencies_rational),
                    text_cell(loc.location_characteristics),
                    text_cell(loc.corner_property_rational),
                    text_cell(loc.rational),
                    self._details_paragraph(loc.details, cell_style),
                    text_cell(loc.store_format_type),
                ])
            location_widths = self._fit_widths([35, 60, 65, 65, 55, 55, 65, 80, 40], frame_width)
            location_table = Table(location_rows, colWidths=location_widths, repeatRows=1)
            location_table.setStyle(TableStyle([
                ('GRID', (0, 0), (-1, -1), 0.5, colors.black),
                ('BOX', (0, 0), (-1, -1), 1, colors.black),
                ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
                ('FONTSIZE', (0, 0), (-1, 0), 7),
                ('FONTNAME', (0, 1), (-1, -1), 'Helvetica'),
                ('FONTSIZE', (0, 1), (-1, -1), 7),
                ('LEFTPADDING', (0, 0), (-1, -1), 4),
                ('RIGHTPADDING', (0, 0), (-1, -1), 4),
                ('TOPPADDING', (0, 0), (-1, -1), 3),
                ('BOTTOMPADDING', (0, 0), (-1, -1), 3),
                ('VALIGN', (0, 0), (-1, -1), 'TOP'),
                ('ALIGN', (0, 0), (-1, -1), 'LEFT'),
            ]))
            story.append(location_table)
            story.append(Spacer(1, 15))

        # Sales Estimates
        if data.sales_estimates:
            story.append(Paragraph("Sales Estimates", section_style))
            sales_rows = [['Parameter', 'Value', 'Remark', 'Store Format Type']]
            for sales in data.sales_estimates:
                sales_rows.append([
                    str(sales.parameter or "N/A"),
                    str(sales.value or 0),
                    str(sales.remark or "N/A"),
                    str(sales.store_formate_type or "N/A"),
                ])
            story.append(self._simple_table(sales_rows, 7))
            story.append(Spacer(1, 15))

        # Rent Expenses
        if data.rent_expenses:
            story.append(Paragraph("Rent Expenses", section_style))
            rent_rows = [[
                'Store Size', 'Rental Advance Amount', 'Rental Advance Period',
                'Rent Per Month', 'Store Format Type', 'Rational For Rental Expenses',
            ]]
            for rent in data.rent_expenses:
                rent_rows.append([
                    str(rent.store_size or 0),
                    str(rent.rental_advance_amount or 0),
                    str(rent.rental_advance_period or 0),
                    str(rent.rent_per_month or 0),
                    str(rent.store_format_type or "N/A"),
                    str(rent.rational_for_rental_expenses or "N/A"),
                ])
            story.append(self._simple_table(rent_rows, 6))
            story.append(Spacer(1, 15))

        # Sales Summary Comparison
        story.append(Paragraph("Sales Summary Comparison", section_style))
        if data.sales_summary_comparison:
            comparison_rows = [['Store Format Type', 'Sales Potential', 'Sales Estimate']]
            for comparison in data.sales_summary_comparison:
                comparison_rows.append([
                    str(comparison.store_format_type or "N/A"),
                    str(comparison.sales_potential or 0),
                    str(comparison.sales_estimate or 0),
                ])
            story.append(self._simple_table(comparison_rows, 8))
        story.append(Spacer(1, 15))

        # Summary
        story.append(Paragraph("Summary", section_style))
        story.append(Spacer(1, 10))
        summary_rows = [
            [
                Paragraph('Promotions<br/>Submissions', summary_cell_style),
                Paragraph('Promotions<br/>Summarise', summary_cell_style),
            ],
            [
                text_cell(data.summary.promotions_submissions, summary_cell_style),
                text_cell(data.summary.promotions_summarise, summary_cell_style),
            ],
        ]
        summary_table = Table(summary_rows, colWidths=self._fit_widths([255, 255], frame_width), repeatRows=1)
        summary_table.setStyle(TableStyle([
            ('GRID', (0, 0), (-1, -1), 0.7, colors.black),
            ('BOX', (0, 0), (-1, -1), 1, colors.black),
            ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
            ('FONTSIZE', (0, 0), (-1, 0), 8),
            ('FONTNAME', (0, 1), (-1, -1), 'Helvetica'),
            ('FONTSIZE', (0, 1), (-1, -1), 8),
            ('LEFTPADDING', (0, 0), (-1, -1), 5),
            ('RIGHTPADDING', (0, 0), (-1, -1), 5),
            ('TOPPADDING', (0, 0), (-1, -1), 4),
            ('BOTTOMPADDING', (0, 0), (-1, -1), 4),
            ('VALIGN', (0, 0), (-1, -1), 'TOP'),
            ('ALIGN', (0, 0), (-1, -1), 'LEFT'),
        ]))
        story.append(summary_table)

        # Last Updated footer, in the same 8pt text as the summary table.
        story.append(Spacer(1, 20))
        story.append(Paragraph(
            f"Last Updated: {data.last_updated.strftime('%Y-%m-%d %H:%M:%S')}", summary_cell_style
        ))
        return story

    def _write_report_file(self, group_id: int, pdf_bytes: bytes) -> str:
        """Write ``pdf_bytes`` to a timestamped file in ``OUTPUT_DIR`` and return its path.

        The name has one-second resolution, so two reports for the same group
        generated within the same second overwrite each other.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"module2_complete_report_group_{group_id}_{timestamp}.pdf"
        os.makedirs(self.OUTPUT_DIR, exist_ok=True)
        filepath = os.path.join(self.OUTPUT_DIR, filename)
        with open(filepath, 'wb') as f:
            f.write(pdf_bytes)
        return filepath

    def _generate_pdf_report(self, data: Module2ReportPreview) -> str:
        """Render ``data`` as an A4 PDF, save it, and return the saved file path.

        The path is relative to the working directory:
        ``uploaded_files/module2_complete_report_group_<id>_<YYYYmmdd_HHMMSS>.pdf``.
        """
        buffer = BytesIO()
        doc = SimpleDocTemplate(buffer, pagesize=A4, leftMargin=50, rightMargin=50, topMargin=50, bottomMargin=50)
        doc.build(self._build_story(data, doc.width))
        return self._write_report_file(data.group_id, buffer.getvalue())

    # ------------------------------------------------------------ public API

    async def generate_and_save_pdf(self, group_id: int) -> str:
        """Generate the Module 2 PDF for ``group_id``, record its path, and return it.

        Recording the path in ``TblSummary`` is best-effort: a database error
        there is logged and the file path is still returned.

        Raises:
            HTTPException: propagated from ``_fetch_module2_data`` (400 / 404).
        """
        report_data = await self._fetch_module2_data(group_id)
        file_path = self._generate_pdf_report(report_data)
        self._save_file_path_to_db(group_id, file_path)
        return file_path

    def _save_file_path_to_db(self, group_id: int, file_path: str):
        """Upsert ``file_path`` into ``uploaded_files`` of the group's Module 2 summary row.

        The row is identified by ``group_id_mod == "g<group_id>_m2"`` and is
        created if missing. Best-effort: any exception rolls the session back
        and is logged, never raised.
        """
        group_id_mod = self._summary_mod_key(group_id)
        try:
            self._logger.debug("Looking for existing summary with group_id_mod=%s", group_id_mod)
            summary = self.db.query(TblSummary).filter(TblSummary.group_id_mod == group_id_mod).first()
            if summary is not None:
                summary.uploaded_files = file_path
            else:
                summary = TblSummary(
                    group_id=group_id,
                    group_id_mod=group_id_mod,
                    uploaded_files=file_path,
                )
                self.db.add(summary)
            self.db.commit()
            self.db.refresh(summary)
            self._logger.debug(
                "Saved report path for group_id_mod=%s (summary_id=%s)", group_id_mod, summary.summary_id
            )
        except Exception:
            self.db.rollback()
            self._logger.exception("Failed to save report path for group_id_mod=%s", group_id_mod)