# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import gzip
import json
import logging
import os
import traceback
import warnings
from time import sleep

import six
from library.python.vault_client.errors import ClientError
from cached_property import cached_property

from travel.library.python.rasp_vault import common
from travel.library.python.rasp_vault.common import (
    VaultVersion, VaultSecret, SecretError, SecretFetchError, SecretNotFoundError
)


# Logger shared by every provider in this module.
logger = logging.getLogger('rasp-vault')


# ClientError codes (looked up in ``exc.kwargs['code']``) that
# _retry_on_exception re-raises immediately instead of retrying.
NON_RETRIABLE_CLIENT_ERRORS = ('invalid_uuid_prefix',)


def _retry_on_exception(func, args=(), kwargs=None, timeout=1, retries=5):
    """
    Call ``func(*args, **kwargs)``, retrying failures with exponential backoff.

    :param func: callable to invoke.
    :param args: positional arguments for ``func``.
    :param kwargs: keyword arguments for ``func``; ``None`` means none.
    :param timeout: base delay in seconds. After failed attempt ``i`` (0-based)
        the function sleeps ``timeout * 2 ** i`` seconds before retrying.
    :param retries: total number of attempts. Values below 1 are treated as 1,
        so ``func`` is always called at least once.
    :return: the result of the first successful call.
    :raises ClientError: immediately, if its ``kwargs['code']`` is listed in
        ``NON_RETRIABLE_CLIENT_ERRORS``.
    :raises Exception: the exception of the last attempt when all attempts fail.
    """
    kwargs = kwargs or {}
    attempts = max(1, retries)
    for i in range(attempts):
        try:
            return func(*args, **kwargs)
        except ClientError as exc:
            exc_kwargs = getattr(exc, 'kwargs', None)
            if isinstance(exc_kwargs, dict) and exc_kwargs.get('code') in NON_RETRIABLE_CLIENT_ERRORS:
                raise
            if i == attempts - 1:
                raise
            # Log inside the handler: on Python 3 the active exception is
            # cleared when the except block exits, so exc_info=True outside
            # of it would record no traceback.
            logger.warning('Vault Connection Error, attempt #%s', i + 1, exc_info=True)
        except Exception:
            if i == attempts - 1:
                raise
            logger.warning('Vault Connection Error, attempt #%s', i + 1, exc_info=True)

        sleep(timeout * 2 ** i)


class YavSecretProvider(object):
    """
    Secret provider that reads secrets through a vault client (``yav_client``).

    Resolved secrets, fetched versions and the full secret list used for
    alias lookups are memoized per instance.
    """

    def __init__(self, yav_client=None, oauth_token=None):
        """
        :param yav_client: ready-made client. When omitted, one is created
            lazily with ``common.get_vault_client(oauth_token)``.
        :param oauth_token: token passed to ``common.get_vault_client``.
        """
        self._get_secret_version_cache = {}
        self._get_secret_by_uuid_cache = {}
        # None means "not fetched yet"; an empty list is a valid fetched result.
        self._get_secret_list_cache = None
        self._yav_client = yav_client
        self._oauth_token = oauth_token

    @cached_property
    def yav_client(self):
        """Vault client, created on first access and reused afterwards."""
        return self._yav_client or common.get_vault_client(self._oauth_token)

    def get(self, secret_path):
        """
        Return a secret's value, or a single key of it.

        :param secret_path: ``<uuid or alias>`` or ``<uuid or alias>.<key>``.
        :return: the value of the secret's last version (or of ``<key>`` in it).
            Returns ``None`` when no client is available, or when the secret
            is missing and ``YA_TEST_RUNNER`` is set.
        :raises SecretNotFoundError: if the secret cannot be resolved.
        """
        if not self.yav_client:
            return None

        secret_key = secret_path
        key_name = None
        if '.' in secret_key:
            # More than one '.' makes this unpacking raise ValueError.
            secret_key, key_name = secret_path.split('.')

        secret = self._get_secret_by_uuid(secret_key)
        if not secret:
            if os.getenv('YA_TEST_RUNNER'):
                # logger.warn is a deprecated alias of logger.warning.
                logger.warning('Secret %s not found', secret_path)
                return None
            raise SecretNotFoundError('Secret {} not found'.format(secret_path))

        version = self._get_secret_version(secret)
        if key_name:
            return version.value[key_name]
        return version.value

    def _get_secret_list(self):
        """
        Return all secrets listed by the client, fetched in pages of 50.

        The list is cached only after every page was fetched. A failure
        mid-way therefore never leaves a truncated list behind for later
        calls.
        """
        if self._get_secret_list_cache is None:
            collected = []
            page = 0
            while True:
                chunk = _retry_on_exception(self.yav_client.list_secrets, kwargs=dict(page_size=50, page=page))
                if not chunk:
                    break
                collected.extend(chunk)
                page += 1
            self._get_secret_list_cache = collected

        return self._get_secret_list_cache

    def _fetch_secret(self, secret_key):
        """
        Resolve ``secret_key`` (a secret UUID or an alias) into a ``VaultSecret``.

        :return: the matching ``VaultSecret``, or ``None`` if nothing matches.
        """
        # All secret UUIDs start with 'sec'
        if secret_key.startswith('sec'):
            try:
                return VaultSecret.from_get_secret_method(
                    _retry_on_exception(self.yav_client.get_secret, args=(secret_key,), kwargs=dict(page_size=1))
                )
            except ClientError:
                # Not retrievable as a UUID (the key may be an alias that merely
                # starts with 'sec'); fall back to the alias lookup below.
                logger.debug('Secret %s not fetched by uuid, trying alias lookup', secret_key, exc_info=True)

        # If secret_key is an alias
        for secret_data in self._get_secret_list():
            if secret_data['name'] == secret_key:
                return VaultSecret.from_list_secrets_method(secret_data)

        return None

    def _get_secret_by_uuid(self, secret_uuid):
        """
        Return the cached ``VaultSecret`` for ``secret_uuid``, fetching it on a miss.

        Misses are not cached, so an unresolved key is looked up again next time.
        """
        if secret_uuid not in self._get_secret_by_uuid_cache:
            secret = self._fetch_secret(secret_uuid)
            if secret:
                self._get_secret_by_uuid_cache[secret_uuid] = secret

        return self._get_secret_by_uuid_cache.get(secret_uuid, None)

    def _get_secret_version(self, secret):
        """Return the (cached) ``VaultVersion`` for ``secret.last_version``."""
        if secret.last_version not in self._get_secret_version_cache:
            version = VaultVersion.from_get_version_method(
                secret,
                _retry_on_exception(self.yav_client.get_version, args=(secret.last_version,))
            )
            self._get_secret_version_cache[secret.last_version] = version

        return self._get_secret_version_cache[secret.last_version]


class FileBasedSecretProvider(object):
    """
    Secret provider backed by ``*.json.gz`` files in a local directory.

    Each secret is indexed both by its alias and by its ``secret`` UUID.
    """

    def __init__(self, search_path):
        """
        :param search_path: directory scanned (non-recursively) for ``*.json.gz`` files.
        """
        self._search_path = search_path
        self._secrets = {}
        self._load_file_secrets()

    def _load_file_secrets(self):
        # os.listdir order is arbitrary. When two files define the same alias
        # or UUID, the later file wins, so load in sorted order for determinism.
        for f in sorted(os.listdir(self._search_path)):
            if f.endswith('.json.gz'):
                self._load_file(os.path.join(self._search_path, f))

    def _load_file(self, fpath):
        """
        Load one gzip-compressed, UTF-8 JSON file of secrets into the index.

        secret_data example:
        {
          "secret-alias": {
            "secret": "sec-01cma50anqzshpfj1fg45zmzhd",
            "version": "ver-01cma50anvztsnq627b1mnkyaa",
            "value": {
              "password": "some_password",
              "user": "rasp"
            }
          }
        }
        """
        with gzip.open(fpath, 'rb') as f:
            # Decode explicitly: json.load on a binary stream depends on bytes
            # support that json only gained in Python 3.6.
            secrets_data = json.loads(f.read().decode('utf-8'))
        for alias, secret in secrets_data.items():
            self._secrets[alias] = secret
            self._secrets[secret['secret']] = secret

    def get(self, secret_path):
        """
        Return a secret's value, or a single key of it.

        :param secret_path: ``<uuid or alias>`` or ``<uuid or alias>.<key>``.
        :raises SecretNotFoundError: if the secret or the key is unknown.
        """
        if '.' in secret_path:
            secret_id, key_name = secret_path.split('.')
        else:
            secret_id = secret_path
            key_name = None

        if secret_id not in self._secrets:
            raise SecretNotFoundError('Secret {} not found'.format(secret_path))

        value = self._secrets[secret_id]['value']
        if not key_name:
            return value

        if key_name not in value:
            raise SecretNotFoundError('Secret {} not found'.format(secret_path))

        return value[key_name]


class Secrets(object):
    """Facade that reads secrets through a provider and normalizes its errors."""

    def __init__(self, provider=None):
        """
        :param provider: object with a ``get(secret_path)`` method. When omitted,
            a ``FileBasedSecretProvider`` is used if the path in
            ``RASP_VAULT_PATH`` (default ``/etc/yav-secrets``) exists,
            otherwise a ``YavSecretProvider``.
        """
        if not provider:
            search_path = os.getenv('RASP_VAULT_PATH', '/etc/yav-secrets')
            if os.path.exists(search_path):
                provider = FileBasedSecretProvider(search_path)
            else:
                provider = YavSecretProvider()

        self._provider = provider

    def __repr__(self):
        # Previously misspelled as ``__rerp__``, so repr() fell back to the
        # default object representation.
        return '<Secrets>'

    def __str__(self):
        return '<Secrets>'

    def get(self, secret_path):
        """
        Return the provider's value for ``secret_path``.

        :raises SecretError: re-raised unchanged when the provider raises one.
        :raises SecretFetchError: wrapping any other exception; the message
            includes the exception type and the formatted traceback.
        """
        try:
            return self._provider.get(secret_path)
        except SecretError:
            raise
        except Exception as exc:
            try:
                tb = six.text_type(traceback.format_exc())
            except UnicodeDecodeError:
                # Python 2: format_exc() returns a byte string that may hold non-ASCII text.
                tb = six.text_type(traceback.format_exc(), encoding='utf-8', errors='ignore')
            raise SecretFetchError('Secret {} not found, error {}, traceback:\n{}'.format(
                secret_path, type(exc), tb
            ))


# Default instance used by get_secret(). It is created at import time, so the
# provider choice made in Secrets.__init__ happens once, when this module loads.
_secrets = Secrets()


def get_secret(secret_path, secret_stub=None):
    """
    Return a secret (or one of its keys) from the default ``Secrets`` instance.

    :param secret_path: <secret alias> - 'rasp-common' | <secret uuid> - 'sec-01sdfskljilkjaklj'
        | <secret alias>.<key> - 'rasp-common.MYSQL_PASSWORD'
        | <secret uuid>.<key> - 'sec-01sdfskljilkjaklj.MYSQL_PASSWORD'
    :param secret_stub: stub value returned when RASP_VAULT_STUB_SECRETS is enabled.
    :return: all keys of the secret, or a single key.

    When the RASP_VAULT_STUB_SECRETS environment variable is enabled (set to any
    non-empty value), ``secret_stub`` is returned. The deprecated
    RASP_VAULT_IGNORE_ERRORS variable behaves the same way but also emits a
    DeprecationWarning.
    """

    if os.getenv('RASP_VAULT_IGNORE_ERRORS'):
        # stacklevel=2 attributes the warning to the caller, not to this module.
        warnings.warn('Use RASP_VAULT_STUB_SECRETS environment variable', DeprecationWarning, stacklevel=2)
        return secret_stub

    if os.getenv('RASP_VAULT_STUB_SECRETS'):
        return secret_stub

    return _secrets.get(secret_path)
