from datetime import datetime, timedelta, timezone
from typing import Literal, Optional

from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from jose import ExpiredSignatureError, JWTError, jwt
from pydantic import BaseModel

from app.config import settings

# Bearer-token schemes for the two login flows: one for regular users/admins and
# one for students. Each ``tokenUrl`` names the login route for that flow.
user_oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="/FASTAPI/user/login",
)
student_oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="/FASTAPI/student/login",
)

# JWT Payload Schema for both user and student (can be extended)
class JWTPayloadSchema(BaseModel):
    user_id: Optional[int] = None
    student_id: Optional[int] = None
    user_type: Literal["student", "admin", "user"]  # User type: student, admin, or user
    user_role: Literal["student", "admin", "superadmin"]  # Role: student, admin, superadmin
    exp: datetime  

class JWTManager:
    """Create and verify JWT access tokens signed with ``settings.SECRET_KEY``."""

    @staticmethod
    def create_access_token(data: JWTPayloadSchema) -> str:
        """Encode ``data`` as a signed JWT that expires after the configured lifetime.

        The ``exp`` claim is always set to now + ``settings.ACCESS_TOKEN_EXPIRE_MINUTES``;
        any ``exp`` already present on ``data`` is ignored. ``data`` itself is not
        modified. Fields whose value is ``None`` are omitted from the token.

        Args:
            data: Claims to embed in the token.

        Returns:
            The encoded JWT string.
        """
        expire = datetime.now(timezone.utc) + timedelta(
            minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
        )
        # Work on a plain dict so the caller's model is left untouched.
        payload = data.dict(exclude_none=True)
        payload["exp"] = expire
        return jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)

    @staticmethod
    def verify_token(token: str) -> JWTPayloadSchema:
        """Decode ``token`` and return its validated claims.

        Args:
            token: The raw bearer token.

        Returns:
            The token's claims as a ``JWTPayloadSchema``.

        Raises:
            HTTPException: 401 when the token has expired, fails signature or
                claim verification, lacks a required claim, or carries claim
                values the schema rejects.
        """
        try:
            # jose verifies the signature and, when present, the ``exp`` claim.
            payload = jwt.decode(
                token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
            )
        except ExpiredSignatureError as exc:
            # Must precede JWTError (its base class). 401 tells the client to
            # re-authenticate, which is the right response to an expired token.
            raise HTTPException(status_code=401, detail="Token has expired") from exc
        except JWTError as exc:
            raise HTTPException(
                status_code=401, detail="Invalid or tampered token"
            ) from exc

        # Raised outside any ``try`` so these HTTPExceptions reach the client
        # as-is instead of being rewrapped as a 500.
        required_keys = {"user_type", "user_role", "exp"}
        if not required_keys.issubset(payload):
            raise HTTPException(status_code=401, detail="Invalid token structure")

        try:
            return JWTPayloadSchema(**payload)
        except ValueError as exc:
            # pydantic's ValidationError subclasses ValueError: the keys exist but
            # a value is unacceptable (e.g. an unknown user_role).
            raise HTTPException(
                status_code=401, detail="Invalid token structure"
            ) from exc

def get_current_user(token: str = Depends(user_oauth2_scheme)) -> JWTPayloadSchema:
    """FastAPI dependency: resolve the user bearer token into its verified claims.

    The raw token is intentionally never logged or printed: it is a live
    credential that anyone reading the logs could replay until it expires.

    NOTE(review): ``user_type`` is not checked here, so any validly signed token
    (including one issued by the student login) is accepted — confirm intended.

    Raises:
        HTTPException: propagated from ``JWTManager.verify_token`` for invalid tokens.
    """
    return JWTManager.verify_token(token)

def get_current_student(token: str = Depends(student_oauth2_scheme)) -> JWTPayloadSchema:
    """FastAPI dependency: resolve the student bearer token into its verified claims.

    The raw token is intentionally never logged or printed: it is a live
    credential that anyone reading the logs could replay until it expires.

    NOTE(review): ``user_type`` is not checked here, so any validly signed token
    (including one issued by the user login) is accepted — confirm intended.

    Raises:
        HTTPException: propagated from ``JWTManager.verify_token`` for invalid tokens.
    """
    return JWTManager.verify_token(token)

def allow_roles(*allowed_roles: str):
    """Build a dependency that admits only user tokens whose role is in ``allowed_roles``.

    The returned callable authenticates through ``get_current_user`` and then
    compares the token's ``user_role`` claim with ``allowed_roles``.

    Returns:
        A FastAPI dependency that yields the authenticated payload, or raises
        ``HTTPException`` (403) when the role is not permitted.
    """

    def role_checker(current_user: JWTPayloadSchema = Depends(get_current_user)):
        role = current_user.user_role
        if role in allowed_roles:
            return current_user
        raise HTTPException(status_code=403, detail="Access forbidden: insufficient role")

    return role_checker


def allow_student_roles(*allowed_roles: str):
    """Build a dependency for student routes that admits only the given roles.

    The returned callable authenticates through ``get_current_student`` and then
    compares the token's ``user_role`` claim with ``allowed_roles``.

    Returns:
        A FastAPI dependency that yields the authenticated payload, or raises
        ``HTTPException`` (403) when the role is not permitted.
    """

    def role_checker(current_student: JWTPayloadSchema = Depends(get_current_student)):
        role = current_student.user_role
        if role in allowed_roles:
            return current_student
        raise HTTPException(status_code=403, detail="Access forbidden: insufficient role")

    return role_checker
