import base64
from hashlib import md5

from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from fastapi import HTTPException

from app.config import settings
from app.utils.log_utils import CUSTOM_LOGGER

# Module-wide AES-CTR cipher, built once at import time from settings. The module-level
# encrypt()/decrypt() functions below create a new encryptor/decryptor context from it per call.
# NOTE(review): every message is encrypted under the same key AND the same CTR nonce, so all
# plaintexts are XORed with an identical keystream. Equal inputs give equal outputs, and XOR-ing
# two ciphertexts reveals the XOR of their (overlapping) plaintexts. Moving to a random
# per-message nonce stored with the ciphertext (or to AES-GCM) would break decryption of
# already-stored values, so it needs a data migration rather than an in-place change.
# NOTE(review): both settings values are str and are UTF-8 encoded as-is. AES needs a
# 16/24/32-byte key and CTR a 16-byte nonce, so these are presumably ASCII strings of exactly
# those lengths — confirm.
_encryption_iv = settings.MAIN_ENCRYPTION_IV  # Initialization vector used for encryption
_encryption_key = settings.MAIN_ENCRYPTION_KEY  # Key used for encryption


_backend = default_backend()
_cipher = Cipher(algorithms.AES(_encryption_key.encode()), modes.CTR(_encryption_iv.encode()), backend=_backend)


def encrypt(data: str | None) -> str:
    """
    Encrypt a value with the module-wide AES-CTR cipher and return it as Base64 text.

    The value is converted with ``str()`` first, so ``None`` is encrypted as the text ``"None"``
    and any other object as its string form. The cipher uses a fixed key and nonce, so equal
    inputs always produce equal outputs.

    Args:
        data (str | None): The value to encrypt.

    Returns
    -------
        str: Base64 encoding of the ciphertext of the UTF-8 encoded value.
    """
    plaintext = str(data).encode()
    encryptor = _cipher.encryptor()
    ciphertext = b"".join([encryptor.update(plaintext), encryptor.finalize()])
    return base64.b64encode(ciphertext).decode("utf-8")


def decrypt(encrypted_data_str: str) -> str:
    """
    Decrypt a Base64 string produced by :func:`encrypt`, falling back to the input on failure.

    The input is Base64-decoded, run through a new decryptor from the module-wide AES-CTR
    cipher, and decoded as UTF-8. Any ``ValueError`` along the way (invalid Base64, non-UTF-8
    plaintext) is logged and the original string is returned unchanged. This presumably lets
    callers pass values that may already be plaintext.

    NOTE(review): ``b64decode`` runs without ``validate=True``, so characters outside the Base64
    alphabet are discarded. CTR decryption itself never fails. A plaintext input that happens
    to decode can therefore come back as garbage rather than unchanged — confirm callers never
    rely on the fallback for such values.

    Args:
        encrypted_data_str (str): Base64-encoded ciphertext.

    Returns
    -------
        str: The decrypted text, or ``encrypted_data_str`` itself if decryption raised ``ValueError``.
    """
    try:
        raw = base64.b64decode(encrypted_data_str.encode("utf-8"))
        decryptor = _cipher.decryptor()
        recovered = decryptor.update(raw)
        recovered += decryptor.finalize()
        text = recovered.decode("utf-8")
    except ValueError as e:
        CUSTOM_LOGGER.error_log(__name__, f"Unable to decrypt {e!s}")
        return encrypted_data_str
    return text


class SecureEncryption:
    """Communicates with the frontend to encrypt and decrypt data."""

    def __init__(self) -> None:
        self.BLOCK_SIZE = 16
        self.secret_key = b"01234567890123456789012345678901"
        self.SALT_LENGTH = 8

    def pad(self, data: bytes) -> bytes:
        """Pad function."""
        length = self.BLOCK_SIZE - (len(data) % self.BLOCK_SIZE)
        return data + bytes([length] * length)

    def unpad(self, data: bytes) -> bytes:
        """Unpad function."""
        return data[: -(data[-1] if isinstance(data[-1], int) else ord(data[-1]))]

    def bytes_to_key(self, data: bytes, salt: bytes, output: int = 48) -> bytes:
        """Bytes to key function."""
        if len(salt) != self.SALT_LENGTH:
            raise HTTPException(status_code=400, detail=f"Salt must be {self.SALT_LENGTH} bytes long, not {len(salt)}")

        data += salt

        # Md5 alogo used by Frontend and result in failure if not used here
        key = md5(data).digest()  # noqa: S324
        final_key = key
        while len(final_key) < output:
            key = md5(key + data).digest()  # noqa: S324
            final_key += key
        return final_key[:output]

    def encrypt(self, message: str) -> str:
        """Encrypt function."""
        salt = get_random_bytes(8)
        key_iv = self.bytes_to_key(self.secret_key, salt, 32 + 16)
        key = key_iv[:32]
        iv = key_iv[32:]
        cipher = AES.new(key, AES.MODE_CBC, iv)
        encrypted_message = cipher.encrypt(self.pad(message.encode()))
        return base64.b64encode(b"Salted__" + salt + encrypted_message).decode()

    def decrypt(self, encrypted_message: str) -> str:
        """Decrypt function."""
        encrypted_message_bytes = base64.b64decode(encrypted_message)
        if encrypted_message_bytes[0:8] != b"Salted__":
            raise HTTPException(status_code=200, detail="Failed to decrypt")

        salt = encrypted_message_bytes[8:16]
        key_iv = self.bytes_to_key(self.secret_key, salt, 32 + 16)
        key = key_iv[:32]
        iv = key_iv[32:]
        cipher = AES.new(key, AES.MODE_CBC, iv)
        decrypted_bytes = self.unpad(cipher.decrypt(encrypted_message_bytes[16:]))
        return decrypted_bytes.decode("utf-8")


