# -*- coding: utf-8 -*-

# Based on pyAesCrypt 0.4.2 (with cosmetic changes)

import hashlib
import logging
from os import urandom
from secrets import compare_digest

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import (
    hashes,
    hmac,
)
from cryptography.hazmat.primitives.ciphers import (
    algorithms,
    Cipher,
    modes,
)
from passport.backend.utils.io import (
    LazyFileLikeObject,
    StrictSizeReader,
)


# Size of the chunks read from the input stream during encryption and
# decryption: 64 KiB.
BUFFER_SIZE = 64 << 10

# Longest accepted password, counted with len() (characters of the str).
MAX_PASSWORD_LENGTH = 1024

# AES operates on 128-bit blocks, i.e. 16 bytes.
AES_BLOCK_SIZE = 128 // 8


log = logging.getLogger('takeout.common.crypto')


# password stretching function
def stretch_password(password, iv1):
    # hash the external iv and the password 8192 times
    digest = iv1 + (16 * b"\x00")

    for i in range(8192):
        passHash = hashes.Hash(hashes.SHA256(), backend=default_backend())
        passHash.update(digest)
        passHash.update(bytes(password, "utf_16_le"))
        digest = passHash.finalize()

    return digest


def encrypt_stream(input_stream, keys, key_version, buffer_size=BUFFER_SIZE):
    """Wrap the AES Crypt encryption of ``input_stream`` in a file-like object.

    The encryption itself is done by the ``_encrypt_stream`` generator, so no
    input is read when this function is called; presumably
    ``LazyFileLikeObject`` pulls chunks from it on demand.
    """
    chunks = _encrypt_stream(input_stream, keys, key_version, buffer_size=buffer_size)
    return LazyFileLikeObject(chunks)


def _encrypt_stream(input_stream, keys, key_version, buffer_size=BUFFER_SIZE):
    """Yield the AES Crypt (file format version 2) encryption of ``input_stream``.

    Produced layout: ``b"AES"``, version byte ``0x02``, a reserved zero byte,
    a "CREATED_BY" extension, a zero-filled 128-byte container extension, the
    end-of-extensions tag, the external IV ``iv1``, the encrypted main IV and
    internal key, their HMAC-SHA256, the AES-256-CBC ciphertext, one byte with
    the plaintext size modulo 16 and the HMAC-SHA256 of the ciphertext.

    Args:
        input_stream: readable binary stream with the plaintext; it is wrapped
            in ``StrictSizeReader``.
        keys: mapping from ``str(key_version)`` to the password (``str``).
        key_version: version of the password to use.
        buffer_size: plaintext chunk size; must be a positive multiple of
            ``AES_BLOCK_SIZE``.

    Raises (on first iteration, since this is a generator):
        ValueError: invalid ``buffer_size`` or password longer than
            ``MAX_PASSWORD_LENGTH``.
        KeyError: ``keys`` has no entry for ``key_version``.
    """
    # The EOF branch below requires bytes_read < buffer_size, which can never
    # hold for buffer_size <= 0, so such a size would loop forever.
    if buffer_size <= 0:
        raise ValueError("Buffer size must be positive.")
    if buffer_size % AES_BLOCK_SIZE != 0:
        raise ValueError("Buffer size must be a multiple of AES block size.")

    input_stream = StrictSizeReader(input_stream)

    password = keys[str(key_version)]
    if len(password) > MAX_PASSWORD_LENGTH:
        raise ValueError("Password is too long.")

    # External IV: used (together with the password) to encrypt the main IV
    # and the internal key.
    iv1 = urandom(AES_BLOCK_SIZE)
    key = stretch_password(password, iv1)

    # Random main IV and internal key that encrypt the actual payload.
    iv0 = urandom(AES_BLOCK_SIZE)
    int_key = urandom(32)

    encryptor0 = Cipher(algorithms.AES(int_key), modes.CBC(iv0), backend=default_backend()).encryptor()
    hmac0 = hmac.HMAC(int_key, hashes.SHA256(), backend=default_backend())

    # iv0 + int_key is 48 bytes (three AES blocks), so no padding is needed.
    encryptor1 = Cipher(algorithms.AES(key), modes.CBC(iv1), backend=default_backend()).encryptor()
    c_iv_key = encryptor1.update(iv0 + int_key) + encryptor1.finalize()

    hmac1 = hmac.HMAC(key, hashes.SHA256(), backend=default_backend())
    hmac1.update(c_iv_key)

    # --- header (see https://www.aescrypt.com/aes_file_format.html) ---
    yield b"AES"
    yield b"\x02"  # file format version
    yield b"\x00"  # reserved

    # "CREATED_BY" extension: big-endian 2-byte length, then the payload
    # (identifier, NUL separator, value). The length is taken from the
    # encoded bytes, not from the str.
    created_by = "takeout"
    created_by_ext = bytes("CREATED_BY", "utf8") + b"\x00" + bytes(created_by, "utf8")
    yield len(created_by_ext).to_bytes(2, "big")
    yield created_by_ext

    # Zero-filled 128-byte "container" extension.
    yield b"\x00\x80"
    yield bytes(128)

    # End-of-extensions tag.
    yield b"\x00\x00"

    yield iv1
    yield c_iv_key
    yield hmac1.finalize()

    # --- payload ---
    acc_bytes_read = 0
    acc_bytes_updated = 0

    while True:
        chunk = input_stream.read(buffer_size)
        bytes_read = len(chunk)
        if bytes_read > buffer_size:
            log.error('Read %d bytes from input stream, but requested %d', bytes_read, buffer_size)
        acc_bytes_read += bytes_read

        # A short read means end of input; presumably StrictSizeReader only
        # returns fewer than buffer_size bytes at EOF.
        if bytes_read < buffer_size:
            break

        acc_bytes_updated += bytes_read
        c_text = encryptor0.update(chunk)
        hmac0.update(c_text)
        yield c_text

    # Padding and the size byte depend on the TOTAL plaintext length, not on
    # the last chunk: this keeps the ciphertext block-aligned even if a
    # previous read returned a non-aligned number of bytes.
    fs16 = acc_bytes_read % AES_BLOCK_SIZE
    pad_len = (AES_BLOCK_SIZE - fs16) % AES_BLOCK_SIZE
    # Padding bytes carry the pad length (this is NOT PKCS#7: nothing is
    # added when the data is already block-aligned).
    chunk += bytes([pad_len]) * pad_len
    acc_bytes_updated += len(chunk)

    try:
        c_text = encryptor0.update(chunk) + encryptor0.finalize()
    except ValueError:
        log.exception(
            'Updated encryptor with %d bytes so far. Read %d bytes.',
            acc_bytes_updated,
            acc_bytes_read,
        )
        raise
    hmac0.update(c_text)
    yield c_text

    # Plaintext size modulo 16, then the HMAC of the whole ciphertext.
    yield bytes([fs16])
    yield hmac0.finalize()


def decrypt_stream(input_stream, keys, key_version, input_length, buffer_size=BUFFER_SIZE):
    """Wrap the AES Crypt decryption of ``input_stream`` in a file-like object.

    The decryption itself is done by the ``_decrypt_stream`` generator, so no
    input is read when this function is called; presumably
    ``LazyFileLikeObject`` pulls chunks from it on demand.
    """
    chunks = _decrypt_stream(input_stream, keys, key_version, input_length, buffer_size=buffer_size)
    return LazyFileLikeObject(chunks)


def _read_exact(input_stream, size):
    """Read exactly ``size`` bytes from ``input_stream``.

    Raises:
        ValueError: fewer than ``size`` bytes were available (truncated or
            corrupted input).
    """
    data = input_stream.read(size)
    if len(data) != size:
        raise ValueError("File is corrupted.")
    return data


def _decrypt_stream(input_stream, keys, key_version, input_length, buffer_size=BUFFER_SIZE):
    """Yield the plaintext of an AES Crypt (file format version 2) stream.

    Args:
        input_stream: readable binary stream positioned at the start of the
            encrypted data; its ``tell()`` must report offsets consistent
            with ``input_length``.
        keys: mapping from ``str(key_version)`` to the password (``str``).
        key_version: version of the password to use.
        input_length: total size in bytes of the encrypted data (anything
            accepted by ``int()``).
        buffer_size: ciphertext chunk size; must be a positive multiple of
            ``AES_BLOCK_SIZE``.

    Raises (while iterating, since this is a generator):
        ValueError: invalid ``buffer_size``, password too long, not an AES
            Crypt v2 file, wrong password, truncated/corrupted input or a
            ciphertext HMAC mismatch.
        KeyError: ``keys`` has no entry for ``key_version``.

    Note: plaintext chunks are yielded BEFORE the ciphertext HMAC is checked
    (the check happens after the last chunk). Consumers must treat the output
    as untrusted until the generator is exhausted without an error.
    """
    # A zero buffer size would make the bulk loop below read nothing and spin
    # forever; negative sizes are meaningless.
    if buffer_size <= 0:
        raise ValueError("Buffer size must be positive.")
    if buffer_size % AES_BLOCK_SIZE != 0:
        raise ValueError("Buffer size must be a multiple of AES block size")

    password = keys[str(key_version)]

    input_length = int(input_length)

    if len(password) > MAX_PASSWORD_LENGTH:
        raise ValueError("Password is too long.")

    # Magic bytes plus minimum length: 136 = 3 ("AES") + 1 (version)
    # + 1 (reserved) + 2 (end of extensions) + 16 (iv1) + 48 (encrypted
    # iv + key) + 32 (its HMAC) + 1 (size mod 16) + 32 (ciphertext HMAC).
    if input_stream.read(3) != b"AES" or input_length < 136:
        raise ValueError("File is corrupted or not an AES Crypt file.")

    # Only file format version 2 is supported.
    if _read_exact(input_stream, 1) != b"\x02":
        raise ValueError("Incompatible version.")

    # Reserved byte.
    _read_exact(input_stream, 1)

    # Skip the extensions: each is a big-endian 2-byte length followed by
    # that many bytes; a zero length terminates the list.
    while True:
        extension_length = int.from_bytes(_read_exact(input_stream, 2), "big")
        if extension_length == 0:
            break
        _read_exact(input_stream, extension_length)

    iv1 = _read_exact(input_stream, 16)
    c_iv_key = _read_exact(input_stream, 48)
    expected_hmac1 = _read_exact(input_stream, 32)

    # Stretching is expensive, so do it only once the header is complete.
    key = stretch_password(password, iv1)

    actual_hmac1 = hmac.HMAC(key, hashes.SHA256(), backend=default_backend())
    actual_hmac1.update(c_iv_key)
    # Constant-time comparison: timing must not reveal how many bytes match.
    if not compare_digest(expected_hmac1, actual_hmac1.finalize()):
        raise ValueError("Wrong password (or file is corrupted).")

    decryptor1 = Cipher(algorithms.AES(key), modes.CBC(iv1), backend=default_backend()).decryptor()
    iv_key = decryptor1.update(c_iv_key) + decryptor1.finalize()
    iv0, int_key = iv_key[:16], iv_key[16:]

    decryptor0 = Cipher(algorithms.AES(int_key), modes.CBC(iv0), backend=default_backend()).decryptor()
    actual_hmac0 = hmac.HMAC(int_key, hashes.SHA256(), backend=default_backend())

    # The ciphertext is followed by 1 byte (size mod 16) and a 32-byte HMAC.
    ciphertext_end = input_length - 32 - 1

    # Bulk: whole buffers while more than buffer_size ciphertext bytes remain.
    while input_stream.tell() < ciphertext_end - buffer_size:
        c_text = input_stream.read(buffer_size)
        # tell() does not advance on an empty read; fail instead of looping
        # forever when the stream is shorter than input_length.
        if not c_text:
            raise ValueError("File is corrupted.")
        actual_hmac0.update(c_text)
        yield decryptor0.update(c_text)

    # Block by block until only the last ciphertext block remains.
    while input_stream.tell() < ciphertext_end - AES_BLOCK_SIZE:
        c_text = input_stream.read(AES_BLOCK_SIZE)
        if not c_text:
            raise ValueError("File is corrupted.")
        actual_hmac0.update(c_text)
        yield decryptor0.update(c_text)

    # Last block; there is none for an empty plaintext.
    if input_stream.tell() != ciphertext_end:
        c_text = input_stream.read(AES_BLOCK_SIZE)
        if len(c_text) < AES_BLOCK_SIZE:
            raise ValueError("File is corrupted.")
    else:
        c_text = b""
    actual_hmac0.update(c_text)

    # Plaintext size modulo 16; this byte is not covered by any HMAC, so
    # at least reject values that cannot be a remainder modulo 16.
    fs16 = _read_exact(input_stream, 1)[0]
    if fs16 >= AES_BLOCK_SIZE:
        raise ValueError("File is corrupted.")

    p_text = decryptor0.update(c_text) + decryptor0.finalize()

    # Strip the padding added by the encryptor.
    to_remove = (AES_BLOCK_SIZE - fs16) % AES_BLOCK_SIZE
    if to_remove:
        p_text = p_text[:-to_remove]

    yield p_text

    expected_hmac0 = _read_exact(input_stream, 32)
    if not compare_digest(expected_hmac0, actual_hmac0.finalize()):
        raise ValueError("Bad HMAC (file is corrupted).")
