# coding: utf-8

import codecs
from collections import OrderedDict
import contextlib
from fnmatch import fnmatch
import itertools
import json
import os
import socket
import subprocess

import jinja2
import jmespath
from passport.backend.utils.file import (
    chown as passport_chown,
    get_gid,
    get_uid,
    parse_owner,
)
from passport.backend.utils.string import smart_str
import six
from vault_client_deploy.actions import (
    FileTemplateAction,
    MultiFileAction,
    SingleFileAction,
    UnknownFileAction,
)
import yenv


try:
    from library.python.vault_client import VaultClient
    from library.python.vault_client.auth import (
        RSAPrivateKeyAuth,
        RSASSHAgentHash,
    )
    from library.python.vault_client.errors import ClientError
    from library.python.vault_client.instances import (
        Production,
        Testing,
    )
except ImportError:
    # TODO: remove these top-level `vault_client` fallback imports once the migration to Arcadia is complete.
    from vault_client import VaultClient
    from vault_client.auth import (
        RSAPrivateKeyAuth,
        RSASSHAgentHash,
    )
    from vault_client.errors import ClientError
    from vault_client.instances import (
        Production,
        Testing,
    )

if six.PY2:
    from backports import configparser
else:
    import configparser


# File and folder names, resolved relative to Environment.confs_path.
ENVIRONMENTS_CONF = 'environments.conf'  # selects which deploy configs apply to this host
DEFAULTS_CONF = 'defaults.conf'  # last-resort config; also read as the parent of every deploy config
TEMPLATES_FOLDER = 'templates'  # root folder of the jinja2 template loader

# Environment variables consulted when the corresponding argument is not given.
OAUTH_TOKEN_ENV_VAR = 'YAV_TOKEN'
HOSTNAME_ENV_VAR = 'YAV_DEPLOY_HOSTNAME'  # checked before falling back to socket.getfqdn()

# Host tags: the comma-separated env var takes precedence over the tags file.
TAGS_ENV_VAR = 'YAV_DEPLOY_TAGS'
ENV_TAGS_FILENAME = 'env.tags'  # lines starting with '#' are ignored
# Tags starting with one of these prefixes are rejected (ValueError) by Environment.normalize_tag.
RESERVED_TAGS_PREFIXES = (
    'conductor.',
)

# Suffix of the temporary file that is written first and then renamed over the target.
TMP_EXT = '.tmp'


def mask_oauth_token(token):
    """
    Hide the second half of an OAuth token so it can be written to logs.

    :returns: None for an empty or missing token; otherwise the first
        ``len(token) // 2`` characters followed by the same number of ``*``.
    """
    if token:
        half = len(token) // 2
        return token[:half] + '*' * half
    return None


class VaultError(Exception):
    """Raised when a secret cannot be fetched from the vault."""


class CreateFileError(Exception):
    """Raised when creating, moving, chmod-ing or chown-ing a deployed file fails."""


class ConfigParser(configparser.ConfigParser):
    """
    INI parser bound to a file name, with optional parent files read first.

    Only ``=`` is accepted as a key/value delimiter. Option names keep their
    case for file actions and prefixed keys (see :meth:`optionxform`).
    """

    def __init__(self, filename=None, parents=None, encoding='utf-8'):
        super(ConfigParser, self).__init__(
            delimiters=('=', ),
        )
        self.filename = filename
        self.encoding = encoding
        # Files read before ``filename``; values in later files override earlier ones.
        self.parents = parents

    def read(self, filename=None, encoding=None):
        """
        Read the parent files and then ``filename`` (defaults to ``self.filename``).

        Missing files are skipped, as in the base class.

        :returns: the list of file names that were successfully parsed,
            matching the base ``configparser.ConfigParser.read`` contract.
        """
        # list() so that ``parents`` may be any iterable (e.g. a tuple).
        return super(ConfigParser, self).read(
            list(self.parents or []) + [filename or self.filename],
            encoding=encoding or self.encoding,
        )

    def optionxform(self, option):
        """
        Normalize an option name.

        Surrounding whitespace is stripped. Keys that are file actions (start
        with ``/``) or carry a prefix followed by a colon (a ``:`` after the
        first character) keep their case. Every other key is lower-cased,
        which is the parser's standard behaviour.
        """
        option = option.strip()
        if option.startswith('/') or option.find(':') > 0:
            return option
        return option.lower()


class Secrets(object):
    """
    Read-through cache of secret values fetched through a vault client.

    The cache dictionary is a class attribute, so it is shared by every
    ``Secrets`` instance in the process, whichever vault client fetched the
    value.
    """
    _cache = dict()

    def __init__(self, vault_client):
        self.vault_client = vault_client

    def get(self, secret_uuid, packed_value=True):
        """
        Return the value of ``secret_uuid``, fetching it on a cache miss.

        :param packed_value: when true, the raw value is passed through
            ``vault_client.pack_value`` before being returned.
        :raises VaultError: when the vault client raises ``ClientError``.
        """
        raw = self._cache.get(secret_uuid)
        if raw is None:
            try:
                raw = self.vault_client.get_version(secret_uuid, packed_value=False)['value']
                self._cache[secret_uuid] = raw
            except ClientError as e:
                raise VaultError(u'VaultError: {}: {} (req: {})'.format(
                    secret_uuid,
                    e.kwargs.get('message'),
                    e.kwargs.get('request_id'),
                ))

        if not packed_value:
            return raw
        return self.vault_client.pack_value(raw)


class Environment(object):
    """
    Deployment environment of the current host.

    Resolves the environment type/name, host name and tags, selects the
    config files that apply to this host, and provides the file helpers
    (write-then-rename under ``chroot``, chmod/chown) and vault clients used
    by the deploy configs.
    """

    def __init__(self, confs_path, chroot, logger, type_=None, name=None, tags=None, hostname=None, skip_chown=False,
                 config_files=None, oauth_token=None, rsa_login=None, rsa_private_key=None, rsa_agent_key_hash=None,
                 native_client=None, quiet=False, testing=False, debug=False):
        """
        :param confs_path: folder holding environments.conf, the deploy
            configs, the tags file and the templates folder.
        :param chroot: root folder every deployed file path is joined to.
        :param type_: environment type; defaults to ``yenv.type``.
        :param name: environment name; defaults to ``yenv.name``.
        :param tags: host tags; default to :meth:`fetch_tags`.
        :param hostname: defaults to HOSTNAME_ENV_VAR, then ``socket.getfqdn()``.
        :param skip_chown: never change file ownership.
        :param config_files: explicit list of configs; bypasses environments.conf.
        :param oauth_token: vault OAuth token; defaults to OAUTH_TOKEN_ENV_VAR.
        :param testing: use the testing vault instance instead of production.
        """
        self.confs_path = confs_path
        self.chroot = chroot
        self.skip_chown = skip_chown

        self.type = yenv.type.strip() if type_ is None else type_
        self.name = yenv.name.strip() if name is None else name

        self.env_tags_filename = os.path.join(self.confs_path, ENV_TAGS_FILENAME)
        self.tags = self.fetch_tags() if tags is None else tags

        hostname = hostname or os.environ.get(HOSTNAME_ENV_VAR)
        self.hostname = hostname or socket.getfqdn()

        self.confs = self.find_confs(
            os.path.join(self.confs_path, ENVIRONMENTS_CONF),
            config_files,
        )
        self.templates = jinja2.Environment(
            loader=jinja2.FileSystemLoader(os.path.join(self.confs_path, TEMPLATES_FOLDER)),
            autoescape=False,
            keep_trailing_newline=True,
            trim_blocks=True,
        )
        # Extra filters available in templates.
        self.templates.filters['from_json'] = json.loads
        self.templates.filters['json_query'] = lambda data, expr: jmespath.search(expr, data)

        self._native_client = native_client or VaultClient.create_native_client()

        self.oauth_token = oauth_token or os.environ.get(OAUTH_TOKEN_ENV_VAR, None)
        self.rsa_login = rsa_login
        self.rsa_private_key = rsa_private_key
        self.rsa_agent_key_hash = rsa_agent_key_hash

        self.debug = debug
        self.testing = testing
        self.quiet = quiet
        self.logger = logger

    def __repr__(self):
        return u'Environment<type={self.type}, name={self.name}, ' \
               u'hostname={self.hostname}, confs={self.confs}, ' \
               u'confs_path={self.confs_path}, chroot={self.chroot}>'.format(self=self)

    def run(self, force=False, valid_sections=None, skip_pre_update=False, skip_post_update=False):
        """
        Deploy every selected config in four phases: prepare, pre-update,
        commit and post-update. Each phase runs over all configs before the
        next phase starts.

        :returns: True when no config reported an error, False otherwise.
        """
        self.logger.debug('Load configs')
        confs_ = [
            DeployConfig(self, conf, valid_sections=valid_sections)
            for conf in self.confs
        ]

        self.logger.debug('Prepare files')
        for conf in confs_:
            conf.prepare()

        self.logger.debug('Pre update')
        for conf in confs_:
            conf.run_pre_update(skip_exec=skip_pre_update)

        self.logger.debug('Commit files (chroot: {})'.format(self.chroot))
        for conf in confs_:
            conf.commit(force=force)

        self.logger.debug('Post update')
        for conf in confs_:
            conf.run_post_update(skip_exec=skip_post_update)

        return not any(conf.has_errors for conf in confs_)

    @property
    def has_confs(self):
        return len(self.confs) > 0

    def conf_exists(self, conf_name):
        """Return True if ``conf_name`` exists inside ``confs_path``."""
        return os.path.exists(os.path.join(self.confs_path, conf_name))

    def fetch_tags(self):
        """
        Collect host tags from TAGS_ENV_VAR or, if it is empty, from the tags file.

        The env var holds comma-separated tags. The file may spread
        comma-separated tags over several lines; lines starting with ``#``
        are ignored. Tags are normalized with :meth:`normalize_tag`, and
        empty ones are dropped.

        :returns: a set of tags (possibly empty).
        :raises ValueError: if a tag uses a reserved prefix.
        """
        raw_tags = []
        env_tags = os.environ.get(TAGS_ENV_VAR, '').strip()
        if env_tags:
            raw_tags = env_tags.split(',')

        if not raw_tags and os.path.exists(self.env_tags_filename):
            with codecs.open(self.env_tags_filename, 'r', 'utf-8') as f:
                lines = [line for line in f if not line.startswith('#')]
            raw_tags = ','.join(lines).split(',')

        normalized = (self.normalize_tag(tag) for tag in raw_tags)
        return set(tag for tag in normalized if tag is not None)

    @staticmethod
    def normalize_tag(tag):
        """
        Strip and lower-case ``tag``.

        :returns: the normalized tag, or None if it is empty.
        :raises ValueError: if the tag starts with a RESERVED_TAGS_PREFIXES entry.
        """
        tag = tag.strip().lower()
        for prefix in RESERVED_TAGS_PREFIXES:
            if tag.startswith(prefix):
                raise ValueError('{} is a reserved tags prefix'.format(prefix))
        return tag or None

    @staticmethod
    def _multi_fnmatch(str_, patterns):
        """Return True if ``str_`` matches any of the comma-separated fnmatch ``patterns``."""
        return any(fnmatch(str_, pattern.strip()) for pattern in patterns.split(','))

    def _match_tags(self, tags):
        """
        Return True when every fnmatch pattern in ``tags`` matches at least one host tag.

        Duplicate patterns are allowed. An empty ``tags`` list matches any host.
        """
        return all(
            any(fnmatch(host_tag, pattern) for host_tag in self.tags)
            for pattern in tags
        )

    def find_confs(self, environments_conf_filename, config_files=None):
        """
        Choose the deploy configs for this host.

        - With ``config_files``: only those of them that exist (no fallbacks).
        - Otherwise: every section of ``environments_conf_filename`` whose
          ``type``/``name``/``hostname``/``tags`` patterns match this host
          and whose file exists.
        - If none match: ``{type}-{name}.conf``, else ``{type}.conf``, else
          DEFAULTS_CONF, whichever exists first.

        :returns: a list of config file names relative to ``confs_path``.
        """
        if config_files:
            return [f for f in config_files if self.conf_exists(f)]

        result = []
        conf = ConfigParser(environments_conf_filename, encoding='utf-8')
        conf.read()
        for c in conf.sections():
            v = dict(conf[c])
            env_type = v.get('type', '')
            env_name = v.get('name', '')
            env_hostname = v.get('hostname')
            # Normalize first and then drop empties: a whitespace-only item
            # (e.g. "tags = a, ") normalizes to None, which fnmatch rejects.
            env_tags = [
                tag for tag in (self.normalize_tag(t) for t in v.get('tags', '').split(','))
                if tag is not None
            ]

            if (
                self._multi_fnmatch(self.type, env_type) and
                self._multi_fnmatch(self.name, env_name) and
                (not env_hostname or self._multi_fnmatch(self.hostname, env_hostname)) and
                (not env_tags or self._match_tags(env_tags)) and
                self.conf_exists(c)
            ):
                result.append(c)

        if not result:
            static_confs = [
                '{type}-{name}.conf'.format(type=self.type, name=self.name),
                '{type}.conf'.format(type=self.type),
            ]
            for static_conf in static_confs:
                if self.conf_exists(static_conf):
                    result.append(static_conf)
                    break

        if not result and self.conf_exists(DEFAULTS_CONF):
            result.append(DEFAULTS_CONF)

        return result

    def get_secrets_storage(self, section=None, oauth_token=None, rsa_auth=True, rsa_login=None):
        """
        Build a :class:`Secrets` store backed by a new vault client.

        An OAuth token, when given, is used instead of RSA auth. ``section``
        is only used for debug logging.
        """
        if oauth_token:
            if section:
                self.logger.debug('[{}] VaultClient(OAuth={})'.format(
                    section.name, mask_oauth_token(oauth_token),
                ))
            return Secrets(self.create_vault_client(
                authorization='OAuth {}'.format(oauth_token),
                rsa_auth=False,
                testing=self.testing,
            ))

        if section:
            self.logger.debug(
                '[{}] VaultClient(rsa_auth={}, rsa_login={})'.format(section.name, rsa_auth, rsa_login),
            )
        return Secrets(self.create_vault_client(
            rsa_auth=rsa_auth,
            rsa_login=rsa_login,
            testing=self.testing,
        ))

    def create_vault_client(self, authorization=None, rsa_auth=True, rsa_login=None, testing=False):
        """Create a Testing or Production vault client sharing this environment's native client."""
        class_ = Testing if testing else Production
        return class_(
            check_status=False,
            native_client=self._native_client,
            authorization=authorization,
            rsa_auth=rsa_auth,
            rsa_login=rsa_login,
            decode_files=False,
        )

    def path_join(self, *args):
        """
        Join ``args`` under ``chroot``.

        Bytes parts are decoded as UTF-8. Leading and trailing separators are
        stripped from every part so that an absolute part does not reset the join.
        """
        return os.path.join(
            self.chroot,
            *[
                (a.decode('utf-8') if isinstance(a, six.binary_type) else a).strip(os.sep) for a in args
            ]
        )

    def compare_file_data(self, filename, data):
        """
        Return True if ``filename`` (under the chroot) exists and its content equals ``data``.

        Text ``data`` is UTF-8 encoded first, because the file is read as
        bytes and bytes never compare equal to text on Python 3.
        """
        filename = self.path_join(filename)
        if not os.path.exists(filename):
            return False

        if not isinstance(data, six.binary_type):
            data = data.encode('utf-8')

        with open(filename, 'rb') as f:
            return f.read() == data

    def save_file(self, filename, data, permissions=None, owner=None, force=False):
        """
        Write ``data`` to ``filename`` (under the chroot) if its content changed.

        Text is UTF-8 encoded before both the comparison and the write. When
        the content is unchanged (and ``force`` is false), only the
        permissions and owner are updated.

        :returns: True if the content, mode or owner of the file was changed.
        """
        if not isinstance(data, six.binary_type):
            data = data.encode('utf-8')

        if force or not self.compare_file_data(filename, data):
            self.logger.debug('save {}'.format(filename))
            with self.create_file(filename, permissions=permissions, owner=owner) as f:
                f.write(data)
            return True

        self.logger.debug('skip {} (not modified)'.format(filename))
        return self.update_permissions(filename, permissions, owner)

    def update_permissions(self, filename, permissions=None, owner=None):
        """
        Apply ``permissions`` and ``owner`` to an existing file under the chroot.

        The owner is left alone when ``skip_chown`` is set.

        :returns: True if anything was changed.
        """
        result = False
        full_filename = self.path_join(filename.encode('utf-8'))

        if permissions is not None:
            chmod_result = self.chmod(full_filename, permissions)
            result = result or chmod_result

        if owner is not None and not self.skip_chown:
            chown_result = self.chown(full_filename, owner)
            result = result or chown_result

        return result

    @contextlib.contextmanager
    def _wrap_create_file_error(self, description, skip_log=False):
        """
        Re-raise any exception from the ``with`` body as CreateFileError.

        The message is prefixed with ``description``. On success,
        ``description`` is logged at debug level unless ``skip_log`` is set.
        """
        try:
            yield
            if not skip_log:
                self.logger.debug(description)
        except Exception as e:
            raise CreateFileError(
                '{description}: {e}'.format(
                    description=description,
                    e=e,
                ),
            )

    @contextlib.contextmanager
    def create_file(self, filename, mode='wb', permissions=None, owner=None):
        """
        Context manager yielding a file object that replaces ``filename`` (under the chroot).

        Data is written to ``filename + TMP_EXT``. That file gets the requested
        mode and owner and is then renamed over the target. Missing parent
        folders are created, and the temporary file is removed if anything fails.
        Failures of mkdir/open/write/rename/chmod/chown, and exceptions raised
        in the caller's ``with`` body, surface as CreateFileError.
        """
        filename = self.path_join(filename.encode('utf-8'))
        temp_filename = filename + TMP_EXT
        folder = os.path.dirname(filename)

        if not os.path.isdir(folder):
            with self._wrap_create_file_error('mkdir -p ' + folder):
                os.makedirs(folder)

        if os.path.exists(temp_filename):
            with self._wrap_create_file_error('rm ' + temp_filename):
                os.remove(temp_filename)

        try:
            with self._wrap_create_file_error('open({}, {})'.format(temp_filename, mode), skip_log=True):
                with open(temp_filename, mode) as f:
                    yield f

            if permissions is not None:
                self.chmod(temp_filename, permissions)

            if owner is not None and not self.skip_chown:
                self.chown(temp_filename, owner)

            with self._wrap_create_file_error('mv {} {}'.format(temp_filename, filename)):
                os.rename(temp_filename, filename)
        finally:
            # Best-effort cleanup: the temporary file only remains if a step above failed.
            if os.path.exists(temp_filename):
                try:
                    os.remove(temp_filename)
                    self.logger.debug('remove temporary file: {}'.format(temp_filename))
                except OSError as e:
                    self.logger.debug('failed to remove temporary file {}: {}'.format(temp_filename, e))

    def chmod(self, filename, permissions):
        """
        Set the mode of ``filename`` to ``permissions`` if it differs.

        :returns: True if the mode was changed.
        """
        stat = os.lstat(filename)
        # Compare only the low 12 mode bits: rwx for all classes plus setuid/setgid/sticky.
        if permissions is None or (stat.st_mode & 0o7777) == permissions:
            return False

        with self._wrap_create_file_error('chmod {:04o} {}'.format(permissions, filename)):
            os.chmod(filename, permissions)
        return True

    def chown(self, filename, owner):
        """
        Change the owner of ``filename`` if the requested user or group differs.

        ``owner`` is parsed by ``parse_owner`` (presumably "user:group"; a bare
        ``:`` means "nothing to change").

        :returns: True if ownership was changed.
        """
        if owner is None or owner.strip() == ':':
            return False

        stat = os.lstat(filename)
        user, group = parse_owner(owner)
        new_uid, new_gid = get_uid(user), get_gid(group)

        if (
            (new_uid is None or stat.st_uid == new_uid)
            and (new_gid is None or stat.st_gid == new_gid)
        ):
            return False

        with self._wrap_create_file_error('chown {} {}'.format(owner, filename)):
            passport_chown(filename, user, group)
        return True


class UnsafeMixin(object):
    """
    Mixin providing an error-containment context manager.

    Users of the mixin are expected to define ``self.failed`` and
    ``self.environment`` (with ``logger`` and ``debug`` attributes).
    """

    @contextlib.contextmanager
    def unsafe(self, fail_on_exception=False):
        """
        Swallow and log any ``Exception`` raised in the ``with`` body.

        When ``fail_on_exception`` is true, ``self.failed`` becomes True as well.
        """
        try:
            yield
        except Exception as exc:
            self.failed = bool(fail_on_exception) or self.failed
            self.log_exception(exc)

    def log_exception(self, exception):
        """
        Log ``exception`` at error level.

        The traceback is attached only in debug mode, and never for VaultError.
        """
        message = smart_str(exception)
        with_traceback = self.environment.debug and not isinstance(exception, VaultError)
        self.environment.logger.error(message, exc_info=with_traceback)


class DeployConfig(UnsafeMixin):
    """
    One deploy config file and the sections selected from it.

    Errors are contained with :meth:`unsafe`. A failure while loading, or in
    any phase, marks the whole config as failed, and later phases skip it.
    """

    def __init__(self, environment, filename, valid_sections=None):
        """
        :param environment: the owning Environment.
        :param filename: config file name relative to ``environment.confs_path``.
        :param valid_sections: if non-empty, only these sections are loaded.
        """
        self.failed = False
        self.environment = environment
        self.confs_path = self.environment.confs_path
        self.filename = filename
        self.full_path = os.path.join(self.confs_path, self.filename)
        self.valid_sections = valid_sections or []
        # Defaults so the object stays consistent when loading fails below.
        self.conf = None
        self.sections = OrderedDict()

        with self.unsafe(fail_on_exception=True):
            self.conf = self.read_conf()

            self.sections = self.parse_sections()
            if not self.sections:
                self.environment.logger.debug(
                    'Sections not found in the "{}" file. '
                    'The DEFAULT section is never executed directly'.format(self.filename),
                )

    @property
    def has_errors(self):
        """True if the config itself or any of its sections failed."""
        return self.failed or any(sec.has_errors for sec in self.sections.values())

    def read_conf(self):
        """
        Parse the config file, reading DEFAULTS_CONF first as its parent.

        Parser errors propagate. The caller's :meth:`unsafe` block logs them
        once and marks the config as failed; previously the error was
        swallowed here and followed by a spurious UnboundLocalError.
        """
        self.environment.logger.debug('{}:'.format(self.filename))
        conf = ConfigParser(
            os.path.join(self.confs_path, self.filename),
            parents=[os.path.join(self.confs_path, DEFAULTS_CONF)],
            encoding='utf-8',
        )
        conf.read()
        return conf

    def parse_sections(self):
        """Build a DeploySection for every section allowed by ``valid_sections``."""
        result = OrderedDict()
        for sec in self.conf.sections():
            if self.valid_sections and sec not in self.valid_sections:
                continue
            result[sec] = DeploySection(self.environment, self.conf[sec], name=sec)
        return result

    def foreach_sections(self):
        """Yield the sections that have at least one action."""
        for sec in self.sections.values():
            if not sec.has_actions:
                continue
            yield sec

    def _run_phase(self, method_name, **kwargs):
        """Call ``method_name(**kwargs)`` on every section with actions, unless the config failed."""
        if self.failed:
            return
        with self.unsafe(fail_on_exception=True):
            self.environment.logger.debug('Processing {}:'.format(self.full_path))
            for sec in self.foreach_sections():
                getattr(sec, method_name)(**kwargs)

    def prepare(self):
        self._run_phase('prepare')

    def run_pre_update(self, skip_exec=False):
        self._run_phase('run_pre_update', skip_exec=skip_exec)

    def commit(self, force=False):
        self._run_phase('commit', force=force)

    def run_post_update(self, skip_exec=False):
        self._run_phase('run_post_update', skip_exec=skip_exec)


class DeploySection(UnsafeMixin):
    """
    One section of a deploy config: its file actions plus shared settings.

    Recognised settings are ``prefix``, ``owner``, ``mode`` (an octal
    string), the ``pre-update``/``post-update`` shell commands, and the vault
    auth keys ``rsa-login``, ``rsa-private-key`` and ``rsa-agent-key-hash``.
    Remaining keys that an action handler accepts become :attr:`actions`;
    the others are exposed as :attr:`vars`.
    """

    def __init__(self, environment, conf_section, name=None):
        self.failed = False
        self.environment = environment
        # Tried in order; the first handler whose build_action returns an action wins.
        self._action_handlers = [
            MultiFileAction,
            SingleFileAction,
            FileTemplateAction,
            UnknownFileAction,
        ]

        conf_section = OrderedDict(conf_section)
        self.conf_section = OrderedDict(conf_section)  # Keep an untouched copy of the values read from the file
        self.name = name
        self.prefix = conf_section.pop('prefix', '/')
        self.owner = conf_section.pop('owner', None)

        self.mode = conf_section.pop('mode', None)
        if self.mode is not None:
            self.mode = int(self.mode, 8)

        self.pre_update = conf_section.pop('pre-update', None)
        self.post_update = conf_section.pop('post-update', None)

        # _get_auth_params consumes the auth keys, so they never reach actions or vars.
        self.secrets = self.environment.get_secrets_storage(
            self,
            **self._get_auth_params(conf_section)
        )

        self.actions = None
        self.vars = OrderedDict()
        with self.unsafe(fail_on_exception=True):
            self.actions = self.get_actions(conf_section)
            self.vars = OrderedDict(
                (key, value) for key, value in conf_section.items() if key not in self.actions
            )

        if self.has_actions:
            for action in self.actions.values():
                self.environment.logger.debug(repr(action))
        else:
            self.environment.logger.debug(
                '[{self.name}] Actions not found in the section. '
                'The action line should starts with a slash'.format(self=self),
            )

    @property
    def has_actions(self):
        return self.actions is not None and len(self.actions) > 0

    @property
    def need_post_update(self):
        """True if any action reports that post-update has to run."""
        return any(action.need_post_update for action in (self.actions or {}).values())

    @property
    def has_errors(self):
        return self.failed

    def get_actions(self, conf_section):
        """
        Map each key to the action built by the first handler that accepts it.

        Keys that no handler accepts are left out of the result.
        """
        result = OrderedDict()
        for k, v in conf_section.items():
            for action_class in self._action_handlers:
                action = action_class.build_action(self, self.secrets, k, v)
                if action is not None:
                    result[k] = action
                    break
        return result

    def path_join(self, *args):
        """Join ``args`` under the section prefix (``/`` if empty), stripping separators from each part."""
        return os.path.join(self.prefix or '/', *[a.strip(os.sep) for a in args])

    def _get_auth_params(self, conf_section):
        """
        Build the vault auth keyword arguments for ``get_secrets_storage``.

        ``rsa-login``, ``rsa-private-key`` and ``rsa-agent-key-hash`` are
        popped from ``conf_section`` in every case, so they are never treated
        as actions or template vars.

        The auth method is chosen in this order:
        1. the environment OAuth token;
        2. an RSA private key file;
        3. an ssh-agent key hash;
        4. plain RSA auth.

        Environment-level RSA settings override the section's settings.
        """
        section_rsa_login = conf_section.pop('rsa-login', None)
        section_private_key = conf_section.pop('rsa-private-key', None)
        section_agent_key_hash = conf_section.pop('rsa-agent-key-hash', None)

        if self.environment.oauth_token:
            return dict(
                oauth_token=self.environment.oauth_token,
                rsa_auth=False,
            )

        # Empty strings are normalized to None.
        rsa_login = self.environment.rsa_login or section_rsa_login or None

        rsa_private_key = self.environment.rsa_private_key or section_private_key
        if rsa_private_key:
            with open(os.path.expanduser(rsa_private_key)) as f:
                return dict(
                    rsa_login=rsa_login,
                    rsa_auth=RSAPrivateKeyAuth(f.read()),
                )

        rsa_agent_key_hash = self.environment.rsa_agent_key_hash or section_agent_key_hash
        if rsa_agent_key_hash:
            return dict(
                rsa_login=rsa_login,
                rsa_auth=RSASSHAgentHash(rsa_agent_key_hash),
            )

        return dict(rsa_auth=True, rsa_login=rsa_login)

    def prepare(self):
        """Run ``prepare`` on every action; any exception marks the section as failed."""
        if self.failed:
            return
        with self.unsafe(fail_on_exception=True):
            self.environment.logger.debug('[{}]'.format(self.name))
            for action in self.actions.values():
                action.prepare()

    def commit(self, force=False):
        """Run ``commit`` on every action; any exception marks the section as failed."""
        if self.failed:
            return
        with self.unsafe(fail_on_exception=True):
            self.environment.logger.debug('[{}]'.format(self.name))
            for action in self.actions.values():
                action.commit(force=force)

    def _run_commands(self, commands, raise_exception=True):
        """
        Run ``commands`` through the shell and log its stdout/stderr at debug level.

        NOTE: shell=True executes the config text verbatim, so config files
        must come from a trusted source.

        :returns: the process return code.
        :raises subprocess.CalledProcessError: on a non-zero exit when ``raise_exception`` is set.
        """
        process = subprocess.Popen(
            commands,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        stdout, stderr = process.communicate()
        self.environment.logger.debug(stdout)
        self.environment.logger.debug(stderr)
        if process.returncode != 0 and raise_exception:
            raise subprocess.CalledProcessError(process.returncode, commands)
        return process.returncode

    def run_pre_update(self, skip_exec=False):
        """
        Run the ``pre-update`` command, if any.

        A non-zero exit code marks the section as failed, so later phases skip it.
        """
        if self.failed:
            return
        with self.unsafe(fail_on_exception=True):
            if self.pre_update:
                self.environment.logger.debug(
                    '[{self.name}] Run {self.pre_update} (skipped: {skip_exec})'.format(
                        self=self,
                        skip_exec=skip_exec,
                    ),
                )
                if not skip_exec:
                    code = self._run_commands(self.pre_update, raise_exception=False)
                    if code != 0:
                        self.failed = True
                        self.environment.logger.debug(
                            '[{self.name}] pre-update returns {code}. Skip section'.format(
                                self=self,
                                code=code,
                            ),
                        )
            else:
                self.environment.logger.debug(
                    '[{self.name}] pre-update not found (skipped)'.format(
                        self=self,
                    ),
                )

    def run_post_update(self, skip_exec=False):
        """
        Run the ``post-update`` command, but only if some action needs it.

        A non-zero exit raises CalledProcessError. That exception is logged by
        :meth:`unsafe` and marks the section as failed.
        """
        if self.failed:
            return

        with self.unsafe(fail_on_exception=True):
            if not self.need_post_update:
                self.environment.logger.debug(
                    '[{self.name}] Post update doesn\'t required (skipped)'.format(
                        self=self,
                    ),
                )
                return

            if not self.post_update:
                self.environment.logger.debug(
                    '[{self.name}] post-update not found (skipped)'.format(
                        self=self,
                    ),
                )
                return

            self.environment.logger.debug(
                '[{self.name}] Run {self.post_update} (skipped: {skip_exec})'.format(
                    self=self,
                    skip_exec=skip_exec,
                ),
            )
            if not skip_exec:
                self._run_commands(self.post_update)
