from io import BytesIO
from xml.sax.saxutils import escape

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from reportlab.lib import colors
from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle
from sqlalchemy.orm import Session

from app.database.main.mysql import get_db
from app.dependency.authantication import JWTPayloadSchema, get_current_student
from app.api.module_4_report.service import Module4ReportService
from app.api.module_4_report.schema import Module4ReportPreview

# Router for the Module 4 report endpoints in this module: JSON preview,
# direct PDF download, and server-side PDF generation. Every endpoint
# requires an authenticated student (get_current_student). Presumably mounted
# via include_router elsewhere — confirm the prefix there.
module_4_report_router = APIRouter()


@module_4_report_router.get("/preview4/{group_id}", response_model=Module4ReportPreview)
async def get_module4_report_preview(
    group_id: int,
    db: Session = Depends(get_db),
    token: JWTPayloadSchema = Depends(get_current_student)
):
    """Preview Module 4 report data.

    Delegates to ``Module4ReportService.get_report_preview`` for ``group_id``.
    The response is validated against ``Module4ReportPreview``.

    Raises:
        HTTPException: an ``HTTPException`` raised by the service is re-raised
            unchanged, so the status it chose reaches the client. Any other
            exception is reported as a 500 with the error text as detail.
    """
    try:
        service = Module4ReportService(db, token)
        return await service.get_report_preview(group_id)
    except HTTPException:
        # Deliberate HTTP errors must not be masked as 500s by the handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


def _escape_text(value) -> str:
    """Return ``str(value)`` with ``&``, ``<`` and ``>`` XML-escaped.

    reportlab's ``Paragraph`` parses its text as a small XML-like markup.
    Free text from the database containing ``<`` or ``&`` would otherwise be
    misparsed, or would make ``Paragraph`` raise.
    """
    return escape(str(value))


def _key_value_table(rows, styles):
    """Build one bordered two-column table from ``(label, value)`` pairs.

    Both cells are ``Paragraph`` objects, so long text wraps inside its column.
    Row heights are left to reportlab, so each row grows to fit its wrapped
    text. The old fixed 40pt height let longer answers spill over the grid
    lines.

    NOTE(review): reportlab does not split inside a table row by default.
    A single answer taller than a whole page would therefore make
    ``doc.build`` raise. Cap answer length upstream if that can happen.
    """
    normal = styles['Normal']
    cells = [
        [Paragraph(_escape_text(label), normal), Paragraph(_escape_text(value), normal)]
        for label, value in rows
    ]
    table = Table(cells, colWidths=[300, 200])
    table.setStyle(TableStyle([
        ('GRID', (0, 0), (-1, -1), 1, colors.black),
        ('FONTNAME', (0, 0), (-1, -1), 'Helvetica'),
        ('FONTSIZE', (0, 0), (-1, -1), 9),
        ('VALIGN', (0, 0), (-1, -1), 'TOP'),
        ('TOPPADDING', (0, 0), (-1, -1), 8),
        ('BOTTOMPADDING', (0, 0), (-1, -1), 8)
    ]))
    return table


# Maps a store format's merchandise level (lower-cased) to the attribute of
# ``report.pre_selected_categories`` that holds its category count.
_MERCHANDISE_COUNT_FIELDS = {
    'low': 'low_value',
    'medium': 'medium_value',
    'high': 'high_value',
}


def _category_count(merchandise, pre_selected):
    """Return the category count for a merchandise level, as a string.

    The level is matched case-insensitively against low/medium/high. An
    unknown level, or a missing/falsy count, yields ``'0'``. Only the
    attribute for the matched level is read.
    """
    field = _MERCHANDISE_COUNT_FIELDS.get(merchandise.lower())
    if field is None:
        return '0'
    return str(getattr(pre_selected, field) or 0)


def _category_management_table(report):
    """Build the Store Format / Merchandise level / Category count table.

    Store formats missing a type or merchandise level are skipped. If none
    remain, a single N/A row is shown.
    """
    rows = [['Store Format', 'Level of Category or Merchandise', 'Number of Categories Available']]
    for store_format in report.store_formats or []:
        if store_format.store_format_type and store_format.merchandise:
            rows.append([
                store_format.store_format_type,
                store_format.merchandise,
                _category_count(store_format.merchandise, report.pre_selected_categories),
            ])

    if len(rows) == 1:
        rows.append(['N/A', 'N/A', 'N/A'])

    table = Table(rows, colWidths=[150, 200, 200], rowHeights=[30] * len(rows))
    table.setStyle(TableStyle([
        ('GRID', (0, 0), (-1, -1), 1, colors.black),
        ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
        ('FONTNAME', (0, 1), (-1, -1), 'Helvetica'),
        ('FONTSIZE', (0, 0), (-1, -1), 9),
        ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
        ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
        ('TOPPADDING', (0, 0), (-1, -1), 6),
        ('BOTTOMPADDING', (0, 0), (-1, -1), 6)
    ]))
    return table


def _contributions(report):
    """Return ``(label, sales %, margin %)`` for each contribution that has a ``gross_id``.

    Labels use the row's position in the full list, so numbering skips rows
    without a ``gross_id``. Missing percentages count as 0. A ``None`` list is
    treated as empty.
    """
    return [
        (f"Category {i+1}",
         gm.contribution_to_total_sales or 0,
         gm.contribution_to_gross_margin or 0)
        for i, gm in enumerate(report.gross_margin_contributions or [])
        if gm.gross_id
    ]


def _category_role(sales_contrib, margin_contrib):
    """Classify a category from its sales and gross-margin contribution (percent).

    Rules are checked in order and the first match wins:
    Flagship (sales >= 30 and margin >= 30), Cash Machine (margin >= 35),
    Destination (sales >= 25), Core Traffic (sales >= 15),
    Maintain (sales >= 10), otherwise Under Fire.
    """
    if sales_contrib >= 30 and margin_contrib >= 30:
        return "Flagship"
    if margin_contrib >= 35:
        return "Cash Machine"
    if sales_contrib >= 25:
        return "Destination"
    if sales_contrib >= 15:
        return "Core Traffic"
    if sales_contrib >= 10:
        return "Maintain"
    return "Under Fire"


def _category_role_table(contributions):
    """Build the Category Role Management table (three N/A rows when empty)."""
    rows = [['Category Name', 'Category Role', 'Sales Contribution', 'Margin Contribution']]
    rows.extend(
        [name, _category_role(sales, margin), f"{sales:.1f}%", f"{margin:.1f}%"]
        for name, sales, margin in contributions
    )
    if not contributions:
        rows.extend(['N/A', 'N/A', 'N/A', 'N/A'] for _ in range(3))

    table = Table(rows, colWidths=[130, 130, 130, 130], rowHeights=[30] * len(rows))
    table.setStyle(TableStyle([
        ('GRID', (0, 0), (-1, -1), 1, colors.black),
        ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
        ('FONTNAME', (0, 1), (-1, -1), 'Helvetica'),
        ('FONTSIZE', (0, 0), (-1, -1), 10),
        ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
        ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
        ('TOPPADDING', (0, 0), (-1, -1), 8),
        ('BOTTOMPADDING', (0, 0), (-1, -1), 8)
    ]))
    return table


def _gross_margin_tables(contributions, styles):
    """Build the Gross Margin Contributions table and its totals row table.

    Each category's "Gross Margin Contribution" is sales% * margin% / 100.
    It is 0 when either value is 0. With no contributions, the main table
    gets three blank rows and the totals cells are left empty.

    Returns:
        ``(margin_table, total_table)``.
    """
    normal = styles['Normal']
    rows = [['Category Name', 'Contribution to Sales', 'Gross Margin', 'Gross Margin Contribution']]

    total_sales_contrib = 0
    total_margin_contrib = 0
    total_gross_margin_contrib = 0
    for name, sales_contrib, margin_contrib in contributions:
        gross_margin_calc = (sales_contrib * margin_contrib) / 100 if sales_contrib and margin_contrib else 0
        total_sales_contrib += sales_contrib
        total_margin_contrib += margin_contrib
        total_gross_margin_contrib += gross_margin_calc
        rows.append([
            Paragraph(name, normal),
            Paragraph(f"{sales_contrib:.1f}%", normal),
            Paragraph(f"{margin_contrib:.1f}%", normal),
            Paragraph(f"{gross_margin_calc:.2f}%", normal)
        ])

    has_data = bool(contributions)
    if not has_data:
        rows.extend(['', '', '', ''] for _ in range(3))

    margin_table = Table(rows, colWidths=[130, 130, 130, 130], rowHeights=[35] * len(rows))
    margin_table.setStyle(TableStyle([
        ('GRID', (0, 0), (-1, -1), 1, colors.black),
        ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
        ('FONTNAME', (0, 1), (-1, -1), 'Helvetica'),
        ('FONTSIZE', (0, 0), (-1, -1), 10),
        ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
        ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
        ('TOPPADDING', (0, 0), (-1, -1), 8),
        ('BOTTOMPADDING', (0, 0), (-1, -1), 8)
    ]))

    # NOTE(review): this row is labelled a weighted average, but its sales
    # and gross-margin cells are plain sums of the percentages above.
    # Confirm that this is the intended metric.
    small_style = ParagraphStyle(
        'SmallWrap', parent=styles['Normal'],
        fontSize=8, leading=10, wordWrap='CJK'
    )
    total_data = [[
        Paragraph('Total Weighted Average Gross Margin', small_style),
        f'{total_sales_contrib:.1f}%' if has_data else '',
        f'{total_margin_contrib:.1f}%' if has_data else '',
        f'{total_gross_margin_contrib:.2f}%' if has_data else ''
    ]]
    total_table = Table(total_data, colWidths=[130, 130, 130, 130], rowHeights=[35])
    total_table.setStyle(TableStyle([
        ('GRID', (0, 0), (-1, -1), 1, colors.black),
        ('FONTNAME', (0, 0), (-1, -1), 'Helvetica-Bold'),
        ('FONTSIZE', (0, 0), (-1, -1), 10),
        ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
        ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
        ('TOPPADDING', (0, 0), (-1, -1), 8),
        ('BOTTOMPADDING', (0, 0), (-1, -1), 8)
    ]))
    return margin_table, total_table


@module_4_report_router.get("/download4/{group_id}")
async def download_module4_report(
    group_id: int,
    db: Session = Depends(get_db),
    token: JWTPayloadSchema = Depends(get_current_student)
):
    """Render the Module 4 report for ``group_id`` as a downloadable A4 PDF.

    The report is built in memory and streamed back with an attachment
    ``Content-Disposition`` header.

    Raises:
        HTTPException: an ``HTTPException`` raised while fetching data is
            re-raised unchanged. Any other failure (fetching or rendering)
            becomes a 500 with the error text as detail.
    """
    try:
        service = Module4ReportService(db, token)
        # NOTE(review): relies on a private service method; consider exposing
        # a public accessor on Module4ReportService.
        report = await service._fetch_module4_data(group_id)

        buffer = BytesIO()
        doc = SimpleDocTemplate(buffer, pagesize=A4,
                                leftMargin=50, rightMargin=50,
                                topMargin=50, bottomMargin=50)
        styles = getSampleStyleSheet()
        title_style = ParagraphStyle(
            'CustomTitle', parent=styles['Heading1'],
            fontSize=16, spaceAfter=20,
            alignment=1, fontName='Helvetica-Bold')
        section_style = ParagraphStyle(
            'SectionHeader', parent=styles['Heading2'],
            fontSize=12, spaceAfter=10,
            fontName='Helvetica-Bold')

        competitor = report.competitor_analysis
        comp_rows = [
            ['Reference Competitor', competitor.reference_competitor or 'N/A'],
            ['Number of Categories Stocked', str(competitor.number_of_categories_stocked or 0)],
            ['Key Categories Stocked', competitor.key_categories_stocked or 'N/A'],
            ['What are some key observations around stocking and categories in the store? ', competitor.stocking_observations or 'N/A']
        ]
        cust_rows = [
            ['Define Research Methodology (Sample Size, Probes, and Method)', competitor.research_methodology or 'N/A'],
            ['What are some key observations around shopping behaviours around categories by customers?  ', competitor.shopping_behavior or 'N/A']
        ]
        contributions = _contributions(report)
        margin_table, total_table = _gross_margin_tables(contributions, styles)

        story = [
            Paragraph("Module 4: Category Management", title_style),
            Spacer(1, 20),

            Paragraph("Primary Research", section_style),
            Paragraph("Competitor Research:", styles['Normal']),
            _key_value_table(comp_rows, styles),
            Spacer(1, 10),
            Paragraph("Customer Research:", styles['Normal']),
            _key_value_table(cust_rows, styles),
            Spacer(1, 20),

            Paragraph("Category Management", section_style),
            _category_management_table(report),
            Spacer(1, 20),

            Paragraph("Category Role Management", section_style),
            _category_role_table(contributions),
            Spacer(1, 20),

            Paragraph("Gross Margin Contributions", section_style),
            margin_table,
            Spacer(1, 10),
            total_table,
        ]

        doc.build(story)
        buffer.seek(0)

        # The buffer is already rewound; stream it directly instead of copying it.
        return StreamingResponse(
            buffer,
            media_type="application/pdf",
            headers={"Content-Disposition": f"attachment; filename=module_4_report_group_{group_id}.pdf"}
        )

    except HTTPException:
        # Keep deliberate HTTP errors (and their status codes) intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@module_4_report_router.get("/{group_id}/preview-report-pdf4")
async def generate_and_serve_pdf(
    group_id: int,
    db: Session = Depends(get_db),
    token: JWTPayloadSchema = Depends(get_current_student)
):
    """Have the service build and save the Module 4 PDF, then report where it went.

    Returns:
        A JSON object ``{"group_id": ..., "file_path": ...}``. ``file_path`` is
        whatever ``Module4ReportService.generate_and_save_pdf`` returned.

    Raises:
        HTTPException: propagated unchanged if the service raises one. Any
            other failure is wrapped as a 500 whose detail is prefixed with
            "Error generating PDF: ".
    """
    try:
        saved_path = await Module4ReportService(db, token).generate_and_save_pdf(group_id)
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error generating PDF: {str(exc)}")

    return {"group_id": group_id, "file_path": saved_path}
