import collections
import cStringIO
import hashlib
import sys

try:
    from Crypto.Cipher import AES
    from Crypto.Util import Counter
    from Crypto import Random
except ImportError:
    # Hack: copier only has a linux build, so importing PyCryptodome from skybone/lib
    # will fail on macos. Try importing PyCrypto from skynet package directory instead.
    old_sys_path = sys.path
    sys.path = [x for x in sys.path if x.startswith('/skynet/python/lib')]
    try:
        from Crypto.Cipher import AES
        from Crypto.Util import Counter
        from Crypto import Random
    finally:
        sys.path = old_sys_path

from ..greenish.deblock import Deblock
from ..utils import timer, dummy_timer


# Identifiers of the encryption modes this module knows about.
PLAIN = 'plain'
AES_CTR = 'aes_ctr'
SUPPORTED_ENCRYPTION_MODES = (PLAIN, AES_CTR)

# Prefix of the capability string that lists encryption modes,
# e.g. 'enc:plain,aes_ctr'.
CAP_PREFIX = 'enc:'


def encryption_modes_to_capability(modes):
    """Encode an iterable of mode names as a single capability string.

    >>> encryption_modes_to_capability(['plain', 'aes_ctr'])
    'enc:plain,aes_ctr'
    """
    return CAP_PREFIX + ','.join(modes)


def capabilities_to_encryption_modes(caps):
    """Extract the encryption modes advertised in an iterable of capability strings.

    Only the first capability starting with CAP_PREFIX is considered.

    :param caps: iterable of capability strings.
    :returns: list of mode names (empty if the capability lists none), or
        None when no capability carries CAP_PREFIX.
    """
    for cap in caps:
        if cap.startswith(CAP_PREFIX):
            # An empty mode list encodes as bare 'enc:'; a plain split(',')
            # would decode that as [''] instead of [], so drop empty entries.
            return [mode for mode in cap[len(CAP_PREFIX):].split(',') if mode]
    return None


# Shared Deblock instance: key derivation (generate_key), encryption
# (AesEncryptor.encrypt) and decryption (AesDecryptor.finish) in this module
# are all dispatched through crypto_deblock.apply(func, *args).
# NOTE(review): presumably Deblock runs the callable outside the cooperative
# event loop so heavy crypto work does not stall other greenlets — confirm
# against greenish.deblock.
crypto_deblock = Deblock()


def make_aes_ctr_cipher(key, nonce):
    """Return an AES cipher object in CTR mode for ``key``.

    The counter block is ``nonce`` followed by a 64-bit counter field, so an
    8-byte nonce (as produced by generate_encryption_params) fills a full
    16-byte AES block.
    """
    ctr = Counter.new(64, prefix=nonce)
    return AES.new(key, AES.MODE_CTR, counter=ctr)

class AesEncryptor(object):
    """AES-CTR encryptor whose cipher calls are dispatched via crypto_deblock.

    One cipher object is created per instance and reused for every
    ``encrypt()`` call. CTR cipher objects are stateful, so consecutive calls
    continue a single keystream.

    :param key: symmetric AES key.
    :param nonce: CTR counter prefix (see make_aes_ctr_cipher).
    :param log: logger receiving per-call timing messages.
    :param measure_time: when true, time each encryption and log it at debug
        level; otherwise use dummy_timer and skip the log line.
    """

    def __init__(self, key, nonce, log, measure_time=True):
        self.log = log
        self._cipher = make_aes_ctr_cipher(key, nonce)
        self._measure_time = measure_time
        self._timer = dummy_timer
        if measure_time:
            self._timer = timer

    def encrypt(self, data):
        """Encrypt ``data`` and return the ciphertext."""
        ciphertext, elapsed = crypto_deblock.apply(self._encrypt_impl, data)
        if self._measure_time:
            # elapsed is in seconds; the log reports milliseconds.
            self.log.debug('aes_ctr encryption took %.3fms', elapsed * 1000)
        return ciphertext

    def _encrypt_impl(self, data):
        # Runs inside crypto_deblock.apply; returns (ciphertext, seconds spent).
        with self._timer() as stopwatch:
            ciphertext = self._cipher.encrypt(data)
        return ciphertext, stopwatch.spent


class AesDecryptor(object):
    """Buffering AES-CTR decryptor.

    Protocol: ``start(out_stream)``, any number of ``write(data)`` calls,
    then ``finish()``. Ciphertext is accumulated in memory and decrypted in a
    single call by ``finish()``, which writes the plaintext to ``out_stream``
    and resets the instance so another start/write/finish round can follow.
    The same cipher object is reused across rounds, so a later round
    continues the CTR keystream rather than restarting it.

    :param key: symmetric AES key.
    :param nonce: CTR counter prefix (see make_aes_ctr_cipher).
    :param log: logger receiving per-round timing messages.
    :param measure_time: when true, time each decryption and log it at debug
        level; otherwise use dummy_timer and skip the log line.
    """

    def __init__(self, key, nonce, log, measure_time=True):
        self.log = log
        self._cipher = make_aes_ctr_cipher(key, nonce)
        self._measure_time = measure_time
        self._timer = dummy_timer
        if measure_time:
            self._timer = timer

        # Plaintext destination; attached by start(), detached by finish().
        self._out_stream = None
        # Ciphertext collected by write() until finish().
        self._in_stream = cStringIO.StringIO()

    def start(self, out_stream):
        """Attach ``out_stream``; the previous round must already be finished."""
        assert self._in_stream.tell() == 0
        assert self._out_stream is None
        self._out_stream = out_stream

    def write(self, data):
        """Buffer a chunk of ciphertext."""
        self._in_stream.write(data)

    def finish(self):
        """Decrypt all buffered ciphertext, emit it, and reset for the next round."""
        ciphertext = self._in_stream.getvalue()
        plaintext, elapsed = crypto_deblock.apply(self._decrypt_impl, ciphertext)
        self._out_stream.write(plaintext)
        if self._measure_time:
            # elapsed is in seconds; the log reports milliseconds.
            self.log.debug('aes_ctr decryption took %.3fms', elapsed * 1000)

        # Detach the output and empty the buffer in place.
        self._out_stream = None
        self._in_stream.seek(0)
        self._in_stream.truncate()

    def _decrypt_impl(self, data):
        # Runs inside crypto_deblock.apply; returns (plaintext, seconds spent).
        with self._timer() as stopwatch:
            plaintext = self._cipher.decrypt(data)
        return plaintext, stopwatch.spent


# Fields common to every mode's parameter tuple.
EncryptionParams = collections.namedtuple('EncryptionParams', ['mode'])
# AES_CTR parameters: the common fields, then the key-derivation salt and the
# CTR counter nonce.
AesParams = collections.namedtuple(
    'AesParams',
    list(EncryptionParams._fields) + ['salt', 'nonce'],
)


def generate_encryption_params(mode):
    """Create fresh random encryption parameters for ``mode``.

    For AES_CTR this is a 32-byte random salt (fed to key derivation) and an
    8-byte random nonce (the CTR counter prefix).

    :returns: an AesParams tuple.
    :raises RuntimeError: for any other mode, including PLAIN.
    """
    if mode == AES_CTR:
        # Arguments evaluate left to right: salt is drawn before nonce.
        return AesParams(mode, Random.get_random_bytes(32), Random.get_random_bytes(8))
    raise RuntimeError('Unknown encryption mode {}'.format(mode))

def parse_encryption_params(payload):
    """Convert a message payload into an encryption-params namedtuple.

    ``payload`` is indexed with ``payload[0]`` for the mode and otherwise
    splatted positionally into AesParams (presumably a list/tuple decoded from
    a peer message — confirm against caller).

    :returns: an AesParams for AES_CTR payloads; None for any other mode or
        for any payload that fails to index or unpack.
    """
    # Deliberately best-effort: a malformed payload yields None, never an error.
    # NOTE(review): salt/nonce are not type- or length-checked here.
    try:
        mode = payload[0]
        return AesParams(*payload) if mode == AES_CTR else None
    except Exception:
        return None

def generate_key(enc_params, uid1, uid2, infohash, log):
    """Derive the symmetric key described by ``enc_params``.

    For AES_CTR the key is PBKDF2-HMAC-SHA256 over ``uid1 + uid2 + infohash``
    with ``enc_params.salt`` and 50000 iterations. No ``dklen`` is passed, so
    the key length is the SHA-256 digest size (32 bytes). The derivation runs
    through ``crypto_deblock``, and its duration is logged at debug level.

    :param enc_params: parsed parameters whose ``mode`` selects the
        derivation (an AesParams for AES_CTR).
    :param uid1, uid2, infohash: concatenated in this order into the PBKDF2
        password. The concatenation is order-sensitive, so swapping uid1 and
        uid2 yields a different key.
    :param log: logger for the timing message.
    :returns: the derived key.
    :raises RuntimeError: if ``enc_params.mode`` is not supported.
    """
    if enc_params.mode == AES_CTR:
        def impl():
            with timer() as t:
                key = hashlib.pbkdf2_hmac(
                    'sha256', uid1 + uid2 + infohash, enc_params.salt, iterations=50000
                )
            log.debug('aes_ctr key generation took %.3fms', t.spent * 1000)
            return key

        return crypto_deblock.apply(impl)
    # Report the mode from enc_params; there is no local named `mode` here.
    raise RuntimeError('Unknown encryption mode {}'.format(enc_params.mode))

def make_encryptor(enc_params, key, log):
    """Build the encryptor matching ``enc_params.mode``.

    :param enc_params: parsed encryption parameters. For AES_CTR, its
        ``nonce`` seeds the CTR counter.
    :param key: symmetric key.
    :param log: logger passed on to the encryptor.
    :returns: an AesEncryptor for AES_CTR.
    :raises RuntimeError: if ``enc_params.mode`` is not supported (including PLAIN).
    """
    if enc_params.mode == AES_CTR:
        return AesEncryptor(key, enc_params.nonce, log)
    # Report the mode from enc_params; there is no local named `mode` here.
    raise RuntimeError('Unknown encryption mode {}'.format(enc_params.mode))

def make_decryptor(enc_params, key, log):
    """Build the decryptor matching ``enc_params.mode``.

    :param enc_params: parsed encryption parameters. For AES_CTR, its
        ``nonce`` seeds the CTR counter.
    :param key: symmetric key.
    :param log: logger passed on to the decryptor.
    :returns: an AesDecryptor for AES_CTR.
    :raises RuntimeError: if ``enc_params.mode`` is not supported (including PLAIN).
    """
    if enc_params.mode == AES_CTR:
        return AesDecryptor(key, enc_params.nonce, log)
    # Report the mode from enc_params; there is no local named `mode` here.
    raise RuntimeError('Unknown encryption mode {}'.format(enc_params.mode))
