from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from jose import jwt, JWTError

from app.config import settings

# Bearer-token dependency used by JWTManager.verify_token; `tokenUrl` is the
# login endpoint clients call to obtain a token.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="/FASTAPI/user/login",
)

class JWTManager:
    """Handles JWT access and refresh token creation and validation.

    Every token is signed with ``settings.SECRET_KEY`` using
    ``settings.ALGORITHM`` and carries:

    * ``user_id`` and ``sub`` (the user's email) identity claims,
    * an ``exp`` expiry claim,
    * a ``type`` claim (``"access"`` or ``"refresh"``) so that a refresh token
      is rejected where an access token is expected, and vice versa.
    """

    # Identity claims that must be present and non-empty on every token this
    # class issues or accepts.
    _REQUIRED_CLAIMS = ("user_id", "sub")

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _has_user_claims(claims: Dict[str, Any]) -> bool:
        """Return True if every required identity claim is present and truthy."""
        return all(claims.get(name) for name in JWTManager._REQUIRED_CLAIMS)

    @staticmethod
    def _unauthorized(detail: str) -> HTTPException:
        """Build a 401 error carrying ``detail`` and a Bearer challenge header.

        RFC 6750 (section 3) expects a ``WWW-Authenticate`` header on 401
        responses from bearer-protected endpoints.
        """
        return HTTPException(
            status_code=401,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    @staticmethod
    def _create_token(data: Dict[str, Any], lifetime: timedelta, token_type: str) -> str:
        """Sign a copy of ``data`` that expires ``lifetime`` from now.

        The copy is tagged with ``type=token_type``.

        Raises:
            ValueError: if ``data`` lacks a non-empty ``user_id`` or ``sub``.
        """
        # Same rule as verification, so we never mint a token that
        # verify_token / refresh_access_token would reject later.
        if not JWTManager._has_user_claims(data):
            raise ValueError("Token must include both `user_id` and `sub` (email).")

        to_encode = data.copy()  # never mutate the caller's dict
        # Timezone-aware "now"; datetime.utcnow() is deprecated since Python 3.12.
        expire = datetime.now(timezone.utc) + lifetime
        to_encode.update({"exp": expire, "type": token_type})
        return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)

    @staticmethod
    def _decode(token: str, error_detail: str) -> Dict[str, Any]:
        """Decode and verify ``token``; raise a 401 with ``error_detail`` on failure."""
        try:
            return jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
        except JWTError as err:
            raise JWTManager._unauthorized(error_detail) from err

    # ---------------------------------------------------------------- creation

    @staticmethod
    def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
        """Generate a short-lived access token.

        Args:
            data: Claims to embed. Must contain non-empty ``user_id`` and
                ``sub`` (email). The dict itself is not modified.
            expires_delta: Token lifetime. ``None`` selects
                ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes. Any explicit
                value, including ``timedelta(0)``, is used as given.

        Returns:
            The encoded JWT string, with ``type`` set to ``"access"``.

        Raises:
            ValueError: if ``user_id`` or ``sub`` is missing or empty.
        """
        if expires_delta is None:
            expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
        return JWTManager._create_token(data, expires_delta, "access")

    @staticmethod
    def create_refresh_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
        """Generate a long-lived refresh token.

        Args:
            data: Claims to embed. Must contain non-empty ``user_id`` and
                ``sub`` (email). The dict itself is not modified.
            expires_delta: Token lifetime. ``None`` selects
                ``settings.REFRESH_TOKEN_EXPIRE_DAYS`` days. Any explicit
                value, including ``timedelta(0)``, is used as given.

        Returns:
            The encoded JWT string, with ``type`` set to ``"refresh"``.

        Raises:
            ValueError: if ``user_id`` or ``sub`` is missing or empty.
        """
        if expires_delta is None:
            expires_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
        return JWTManager._create_token(data, expires_delta, "refresh")

    # -------------------------------------------------------------- validation

    @staticmethod
    def verify_token(token: str = Depends(oauth2_scheme)) -> Dict[str, Any]:
        """Validate an access token from the Authorization header.

        Args:
            token: The bearer token. A full ``"Bearer <token>"`` value is also
                accepted, e.g. when this is called directly rather than as a
                FastAPI dependency.

        Returns:
            The decoded claims of the access token.

        Raises:
            HTTPException: 401 if the token is missing, invalid or expired, is
                not an access token, or lacks ``user_id`` / ``sub``.
        """
        if not token:
            raise JWTManager._unauthorized("Missing token")

        token = token.strip()
        # Case-insensitive scheme prefix, as HTTP auth schemes are.
        if token.lower().startswith("bearer "):
            token = token[7:]

        payload = JWTManager._decode(token, "Invalid or expired token")

        if payload.get("type") != "access":
            raise JWTManager._unauthorized("Invalid token type for this operation")
        if not JWTManager._has_user_claims(payload):
            raise JWTManager._unauthorized("Invalid token: missing user data")

        return payload

    @staticmethod
    def refresh_access_token(refresh_token: str) -> Dict[str, str]:
        """Use a valid refresh token to create a new access token.

        Args:
            refresh_token: An encoded token previously issued by
                :meth:`create_refresh_token`.

        Returns:
            ``{"access_token": <new access JWT>, "token_type": "bearer"}``.

        Raises:
            HTTPException: 401 if the refresh token is missing, invalid or
                expired, is not a refresh token, or lacks ``user_id`` / ``sub``.
        """
        # Guard None/empty up front: None would otherwise crash inside the
        # decoder with a non-JWT error and surface as a 500.
        if not refresh_token:
            raise JWTManager._unauthorized("Invalid or expired refresh token")

        payload = JWTManager._decode(refresh_token, "Invalid or expired refresh token")

        if payload.get("type") != "refresh":
            raise JWTManager._unauthorized("Invalid refresh token type")
        if not JWTManager._has_user_claims(payload):
            raise JWTManager._unauthorized("Refresh token is missing user data")

        # Re-issue an access token with the same identity claims only.
        access_token = JWTManager.create_access_token({
            "user_id": payload["user_id"],
            "sub": payload["sub"],
        })

        return {
            "access_token": access_token,
            "token_type": "bearer",
        }

