# -*- coding: utf-8 -*-

import base64
from collections import namedtuple
from datetime import timedelta
import json
import logging
from os import urandom
from time import time

import cryptography.exceptions
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.hmac import HMAC
import formencode
from passport.backend.core.conf import settings
from passport.backend.core.crypto.key_storage import SigningKeyStorage
from passport.backend.core.lazy_loader import (
    lazy_loadable,
    LazyLoader,
)
import six


log = logging.getLogger(__name__)


def simple_sign(data, version_id=None, registry=None, version=None):
    """
    Sign ``data`` with HMAC and return a self-describing signature.

    The result is ``pack(pack(hmac, salt), version.id)``: it carries the
    version id and the random salt needed to check it later with
    simple_is_correct_signature().

    data -- bytes to sign.
    version_id -- id of the signing version to use; None selects the
        registry's default version.
    registry -- registry to look the version up in; defaults to
        get_signing_registry(). Not consulted when ``version`` is given.
    version -- explicit Version to sign with, bypassing any registry lookup.
    """
    if version is None:
        source = get_signing_registry() if registry is None else registry
        version = source.get(version_id)
    salt = urandom(version.salt_length)
    mac = sign(pack(data, salt), version.secret, version.algorithm)
    return pack(pack(mac, salt), version.id)


def simple_is_correct_signature(signature, data, registry=None):
    """
    Check a signature produced by simple_sign() against ``data``.

    signature -- bytes as returned by simple_sign().
    data -- the bytes that were supposedly signed.
    registry -- registry used to resolve the version id embedded in the
        signature; defaults to get_signing_registry().

    Returns True if the signature matches. Returns False if the signature
    cannot be unpacked, if the registry raises VersionNotFoundSigningError
    for the embedded version id, or if the HMAC does not match.
    Raises TypeError (from unpack/pack) if ``signature`` or ``data`` is not
    bytes.
    """
    try:
        inner, version_id = unpack(signature)
        digest, salt = unpack(inner)
    except ValueError:
        log.debug('Malformed signature')
        return False
    salted_data = pack(data, salt)
    if registry is None:
        registry = get_signing_registry()
    try:
        version = registry.get(version_id)
    except VersionNotFoundSigningError:
        # version_id is taken verbatim from the (untrusted) signature, so it
        # may contain arbitrary bytes: decode leniently, otherwise a strict
        # ASCII decode would turn "unknown version" into UnicodeDecodeError.
        log.debug('Unknown version: %s', version_id.decode('ascii', 'replace'))
        return False
    correct = is_correct_signature(digest, salted_data, version.secret, version.algorithm)
    if not correct:
        log.debug('Invalid signature')
    return correct


def get_signing_registry():
    """
    Return the instance LazyLoader provides under the name 'SigningRegistry'.

    In this module that name is registered by YandexSigningRegistry.
    """
    registry_name = 'SigningRegistry'
    return LazyLoader.get_instance(registry_name)


def get_yandex_and_yandex_team_signing_registry():
    """
    Return the instance LazyLoader provides under the name
    'YandexAndYandexTeamSigningRegistry'.

    In this module that name is registered by
    YandexAndYandexTeamRotatingSigningRegistry.
    """
    registry_name = 'YandexAndYandexTeamSigningRegistry'
    return LazyLoader.get_instance(registry_name)


def pack(bytes1, bytes2):
    """
    Join two byte strings so that unpack() can split them apart again.

    Layout: ASCII decimal ``len(bytes2)``, a ``b'.'``, then ``bytes2``
    immediately followed by ``bytes1``.

    Raises TypeError if either argument is not bytes.
    """
    both_bytes = isinstance(bytes1, six.binary_type) and isinstance(bytes2, six.binary_type)
    if not both_bytes:
        raise TypeError('data must be bytes-like')
    length_prefix = str(len(bytes2)).encode('ascii')
    return b''.join([length_prefix, b'.', bytes2, bytes1])


def unpack(_bytes):
    """
    Split a buffer produced by pack() back into its two parts.

    For ``pack(bytes1, bytes2) == b'<len(bytes2)>.' + bytes2 + bytes1`` this
    returns ``(bytes1, bytes2)``.

    Raises:
        TypeError: if ``_bytes`` is not bytes.
        ValueError: if the buffer is malformed, meaning any of:
            - there is no '.' separator;
            - the length prefix is not a plain run of ASCII digits, as
              pack() writes it;
            - the prefix is longer than the remaining payload.
    """
    # ``bytes`` is the same type as six.binary_type on both Python 2 and 3.
    if not isinstance(_bytes, bytes):
        raise TypeError('data must be bytes-like')
    length, separator, payload = _bytes.partition(b'.')
    if not separator:
        raise ValueError('Separator not found')
    # Accept only the canonical form pack() emits; int() alone would also
    # accept signs, surrounding whitespace and underscores, giving several
    # encodings of the same packed value.
    if not length.isdigit():
        raise ValueError('Malformed length prefix')
    length = int(length)
    if length > len(payload):
        # Slicing would silently truncate instead of reporting the corruption.
        raise ValueError('Length prefix exceeds data size')
    bytes2 = payload[:length]
    bytes1 = payload[length:]
    return bytes1, bytes2


def sign(data, secret, algorithm):
    """
    Return the raw HMAC digest of ``data`` keyed with ``secret``.

    ``algorithm`` is passed to _build_hmac(), which only accepts 'SHA256'.
    """
    mac = _build_hmac(secret, algorithm)
    mac.update(data)
    digest = mac.finalize()
    return digest


def is_correct_signature(signature, data, secret, algorithm):
    """
    Return True if ``signature`` is the HMAC of ``data`` under ``secret``.

    The comparison is done by cryptography's ``HMAC.verify()``, which is
    documented as a secure comparison. A mismatch
    (``InvalidSignature``) yields False. Any other error propagates.
    """
    mac = _build_hmac(secret, algorithm)
    mac.update(data)
    try:
        mac.verify(signature)
    except cryptography.exceptions.InvalidSignature:
        return False
    return True


def _build_hmac(secret, algorithm):
    if algorithm != 'SHA256':
        raise NotImplementedError('Not implemented algorithm: %s' % algorithm)
    algorithm = hashes.SHA256()
    if len(secret) < algorithm.digest_size:
        log.warning('Secret length less than digest size')
    return HMAC(secret, algorithm, default_backend())


def _formencode_invalid_to_message(invalid):
    return invalid.msg.replace('\n', r'\n')


class Version(namedtuple('Version', ['algorithm', 'salt_length', 'secret', 'id'])):
    """
    Immutable description of one signing version.

    algorithm -- HMAC hash name ('SHA256' is the only value _build_hmac accepts).
    salt_length -- number of random salt bytes simple_sign() generates.
    secret -- HMAC key (bytes after VersionForm validation).
    id -- version id (bytes) that simple_sign() embeds in signatures.
    """

    @classmethod
    def from_dict(cls, _dict):
        """
        Build a Version from a plain dict after validating it with VersionForm.

        Raises ConfigurationFailedSigningError if validation fails.
        """
        try:
            _dict = VersionForm().to_python(_dict)
        except formencode.Invalid as e:
            msg = 'Failed to create signing version: ' + _formencode_invalid_to_message(e)
            log.debug(msg)
            raise ConfigurationFailedSigningError(msg)
        # Instantiate through cls so that subclasses get their own type back.
        return cls(
            algorithm=_dict['algorithm'],
            salt_length=_dict['salt_length'],
            secret=_dict['secret'],
            id=_dict['id'],
        )


class SigningRegistry(object):
    """
    In-memory mapping from version id to Version, with an optional default.
    """

    def __init__(self):
        self._version_id_to_version = dict()
        self._default_version_id = None

    @staticmethod
    def _format_version_id(version_id):
        """
        Render a version id for error messages.

        Bytes ids are decoded as ASCII with replacement characters. Ids can
        come straight from untrusted signatures, and a strict decode would
        turn "version not found" into an unexpected UnicodeDecodeError.
        """
        if isinstance(version_id, bytes):
            return version_id.decode('ascii', 'replace')
        return '%s' % (version_id,)

    def get(self, version_id=None):
        """
        Return the Version registered under ``version_id``.

        version_id -- id to look up; None means the default version id.

        Raises DefaultVersionNotFoundSigningError if ``version_id`` is None
        and no default is set; raises VersionNotFoundSigningError if the id
        is not registered.
        """
        if version_id is None:
            version_id = self.get_default_version_id()
            if version_id is None:
                raise DefaultVersionNotFoundSigningError()
        version = self._version_id_to_version.get(version_id)
        if version is None:
            raise VersionNotFoundSigningError(
                'Version not found: %s' % self._format_version_id(version_id),
            )
        return version

    def add(self, version):
        """Register ``version`` under ``version.id``, replacing any previous one."""
        self._version_id_to_version[version.id] = version

    def remove(self, version_id):
        """Forget ``version_id`` (no-op if unknown); clears the default if it pointed there."""
        self._version_id_to_version.pop(version_id, None)
        if self._default_version_id == version_id:
            self._default_version_id = None

    def get_default_version_id(self):
        return self._default_version_id

    def set_default_version_id(self, version_id):
        """
        Make ``version_id`` the default used by get().

        Raises VersionNotFoundSigningError if the id is not registered.
        """
        if version_id not in self._version_id_to_version:
            raise VersionNotFoundSigningError(
                'Version not found: %s' % self._format_version_id(version_id),
            )
        self._default_version_id = version_id

    def get_prev_version_id(self, version_id):
        # NOTE(review): on Python 3 this returns text (e.g. '4') while the ids
        # stored in this registry are bytes, so the result cannot be passed
        # to get() directly there. Kept as-is for caller compatibility; confirm
        # what callers expect.
        return str(int(version_id) - 1)

    def add_from_dict(self, _dict):
        """
        Validate ``_dict`` with SigningRegistryForm and register its versions.

        Example (ids and secrets must be bytes, as ByteSequence requires):

        {
            'default_version_id': b'1',
            'versions': [
                {
                    'id':   b'1',
                    'algorithm': 'SHA256',
                    'salt_length': 32,
                    'secret': b'too_short_secret'
                }
            ]
        }

        Raises ConfigurationFailedSigningError if validation fails or if
        'default_version_id' is given but names no registered version.
        """
        try:
            _dict = SigningRegistryForm().to_python(_dict)
        except formencode.Invalid as e:
            msg = _formencode_invalid_to_message(e)
            msg = 'Failed to add version to signing registry: %s' % msg
            log.debug(msg)
            raise ConfigurationFailedSigningError(msg)

        for version in _dict['versions']:
            self.add(
                Version(
                    algorithm=version['algorithm'],
                    id=version['id'],
                    salt_length=version['salt_length'],
                    secret=version['secret'],
                ),
            )
        if _dict['default_version_id']:
            try:
                self.set_default_version_id(bytes(_dict['default_version_id']))
            except VersionNotFoundSigningError:
                msg = 'Failed to set default version to signing registry: version not found'
                log.debug(msg)
                raise ConfigurationFailedSigningError(msg)


class RotatingSigningRegistry(object):
    """
    Signing registry whose secrets rotate every ``epoch_length`` seconds.

    The version id for a moment in time is its epoch number,
    ``int(timestamp / epoch_length)``, encoded as ASCII bytes. Secrets are
    fetched from ``key_storage`` on first use and then kept in an in-memory
    SigningRegistry.
    """

    def __init__(
        self,
        key_storage,
        epoch_length,
        foresight,
        foresight_cache_ttl,
        timer=None,
    ):
        """
        key_storage -- secret storage. This class calls its get_key(),
            is_key_available() and delete_key_from_cache() methods with an
            integer epoch number.
        epoch_length -- number of seconds between secret rotations.
        foresight -- check that the storage holds secrets covering at least
            the next ``foresight`` seconds.
        foresight_cache_ttl -- number of seconds to remember the result of the
            last foresight check.
        timer -- function returning the current unixtime (defaults to
            time.time).
        """
        self.signing_registry = SigningRegistry()
        self.key_storage = key_storage
        self.epoch_length = epoch_length

        self.foresight = foresight
        self.foresight_cache_ttl = foresight_cache_ttl
        self._foresight_cache = None
        self._foresight_cache_expires_at = None

        self.timer = timer or time

        self.check_future_version_available()

    def get(self, version_id=None):
        """
        Return the Version for ``version_id``; None means the current epoch.

        If the version is not yet in the in-memory registry, it is loaded
        from key_storage and added there. Every call also runs the (cached)
        foresight check.
        """
        if version_id is None:
            version_id = self.get_default_version_id()
        try:
            version = self.signing_registry.get(version_id)
        except VersionNotFoundSigningError:
            raw_key, _ = self.key_storage.get_key(int(version_id))
            try:
                self.signing_registry.add_from_dict(self._parse_key(raw_key))
            except Exception:
                # If a broken key stayed cached, it could not be replaced
                # with a good one without restarting Passport, so evict it
                # before propagating the error.
                self.key_storage.delete_key_from_cache(int(version_id))
                raise
            version = self.signing_registry.get(version_id)

        self.check_future_version_available()

        return version

    def _parse_key(self, key):
        """
        Convert a JSON-encoded key from key_storage into add_from_dict() input.

        The key must be a JSON object whose 'id' is an ASCII text string and
        whose 'secret' is standard base64. Both are converted to bytes.
        """
        key = json.loads(key)
        key['id'] = key['id'].encode('ascii')
        key['secret'] = base64.standard_b64decode(key['secret'])
        return dict(versions=[key])

    def get_default_version_id(self):
        """Return the version id of the epoch containing the current time."""
        return self.get_version_id_from_timestamp(self.timer())

    def _get_epoch_from_timestamp(self, timestamp):
        """Return the epoch number for ``timestamp``: int(timestamp / epoch_length)."""
        return int(timestamp / self.epoch_length)

    def get_version_id_from_timestamp(self, timestamp):
        """Return the version id (ASCII-encoded epoch number) for ``timestamp``."""
        return str(self._get_epoch_from_timestamp(timestamp)).encode('ascii')

    def get_prev_version_id(self, version_id):
        return self.signing_registry.get_prev_version_id(version_id)

    def check_future_version_available(self):
        """Log an error if key_storage is missing a secret inside the foresight window."""
        timestamp = self.get_timestamp_when_secrets_run_out()
        if timestamp is not None:
            time_left = timedelta(seconds=timestamp - self.timer())
            log.error('There are signing keys for less than %s', time_left)

    def get_timestamp_when_secrets_run_out(self):
        """
        Check that key_storage has secrets for at least the next ``foresight`` seconds.

        Returns None if enough secrets are available. Otherwise returns the
        earliest moment (unixtime) at which a secret is missing, i.e. the
        start of the first epoch without a key. The result is cached for
        ``foresight_cache_ttl`` seconds.
        """
        now = self.timer()

        if (
            self._foresight_cache_expires_at is not None and
            self._foresight_cache_expires_at > now
        ):
            return self._foresight_cache

        # Check every epoch that starts within (now, now + foresight]: from
        # the epoch after the current one up to and including the epoch that
        # contains now + foresight. Stepping timestamps by epoch_length from
        # now + epoch_length would miss that last epoch whenever now is not
        # aligned to an epoch boundary (and skip the check entirely when
        # foresight < epoch_length).
        current_epoch = self._get_epoch_from_timestamp(now)
        last_epoch = self._get_epoch_from_timestamp(now + self.foresight)
        retval = None
        for epoch in range(current_epoch + 1, last_epoch + 1):
            if not self.key_storage.is_key_available(epoch):
                # The moment the missing secret would have become active.
                retval = epoch * self.epoch_length
                break

        self._foresight_cache_expires_at = now + self.foresight_cache_ttl
        self._foresight_cache = retval

        return retval


@lazy_loadable('SigningRegistry')
class YandexSigningRegistry(SigningRegistry):
    """
    SigningRegistry populated from ``settings.SIGNING_REGISTRY``.

    Registered with the lazy loader under the name 'SigningRegistry'.
    """

    def __init__(self):
        super(YandexSigningRegistry, self).__init__()
        registry_config = settings.SIGNING_REGISTRY
        self.add_from_dict(registry_config)


@lazy_loadable('YandexAndYandexTeamSigningRegistry')
class YandexAndYandexTeamRotatingSigningRegistry(RotatingSigningRegistry):
    """
    RotatingSigningRegistry configured from settings.

    Registered with the lazy loader under 'YandexAndYandexTeamSigningRegistry'.
    Any RotatingSigningRegistry constructor argument may be passed as a
    keyword to override its settings-based default.
    """

    def __init__(self, **kwargs):
        # Defaults are built lazily. An explicitly passed argument (e.g. a
        # custom key_storage) then neither creates an unused
        # SigningKeyStorage nor requires the corresponding setting to exist.
        default_factories = (
            ('key_storage', lambda: SigningKeyStorage(settings.YANDEX_AND_YANDEX_TEAM_SIGNING_REGISTRY_PATH)),
            ('epoch_length', lambda: settings.SIGNING_REGISTRY_KEY_TTL),
            ('foresight', lambda: settings.SIGNING_REGISTRY_FORESIGHT),
            ('foresight_cache_ttl', lambda: settings.SIGNING_REGISTRY_FORESIGHT_CACHE_TTL),
        )
        for name, make_default in default_factories:
            if name not in kwargs:
                kwargs[name] = make_default()
        super(YandexAndYandexTeamRotatingSigningRegistry, self).__init__(**kwargs)


class ByteSequence(formencode.validators.ByteString):
    """
    ByteString validator that accepts and returns bytes rather than text.

    Non-bytes input is rejected with the 'bytesExpected' message.
    """

    messages = {
        'bytesExpected': 'Invalid value type, bytes expected',
    }

    def to_python(self, value, state=None):
        if isinstance(value, six.binary_type):
            # latin1 maps every byte to exactly one code point, so the
            # bytes -> text -> bytes round trip around ByteString is lossless.
            as_text = value.decode('latin1')
            validated = super(ByteSequence, self).to_python(as_text, state)
            return validated.encode('latin1')
        raise formencode.Invalid(self.message('bytesExpected', state), value, state)


class VersionIdValidator(ByteSequence):
    """Validator for signing version ids: bytes (via ByteSequence) with not_empty set."""
    not_empty = True


class BytesFriendlyForm(formencode.schema.Schema):
    """Schema base class for forms whose field values are bytes."""

    def _value_is_iterator(self, value):
        # Deep inside, formencode treats bytes as an iterator; talk it out of that.
        # NOTE(review): this answers False for every value, not only bytes;
        # confirm no schema built on this base relies on iterator detection.
        return False


class VersionForm(BytesFriendlyForm):
    """Schema for one signing version dict (the keys Version is built from)."""
    # Version id: bytes, must not be empty.
    id = VersionIdValidator()
    # Only 'SHA256' is accepted, the one algorithm _build_hmac supports.
    algorithm = formencode.validators.OneOf(['SHA256'], not_empty=True)
    # Number of random salt bytes per signature; integer >= 1.
    salt_length = formencode.validators.Int(min=1, not_empty=True)
    # HMAC secret: bytes, must not be empty.
    secret = ByteSequence(not_empty=True)


class SigningRegistryForm(BytesFriendlyForm):
    """Schema for the dict accepted by SigningRegistry.add_from_dict()."""
    # Optional default version id (bytes); None when the key is missing.
    default_version_id = VersionIdValidator(if_missing=None)
    # List of version dicts, each validated with VersionForm.
    versions = formencode.foreach.ForEach(VersionForm())


class SigningError(Exception):
    """Base class for all errors raised by the signing helpers."""


class InvalidSignatureSigningError(SigningError):
    """Signature did not verify (not raised within this module; presumably for callers)."""


class VersionNotFoundSigningError(SigningError):
    """Requested signing version id is not registered."""


class DefaultVersionNotFoundSigningError(SigningError):
    """A default version was requested but none is configured."""


class ConfigurationFailedSigningError(SigningError):
    """Signing configuration failed validation or referenced an unknown version."""
