import os
import codecs
import re
import json
import hashlib
import logging
import pwd
from collections import defaultdict

from google.protobuf import text_format
import six

from crypta.lib.python.logging import logging_helpers


logger = logging.getLogger(__name__)

# At module scope locals() is this module's global namespace, so writing into
# LOCALS creates module attributes: loaded config sections become reachable as
# attributes of this module (see use_configs / use_proto).
LOCALS = locals()

# Key in LOCALS recording where the loaded configuration came from; its
# presence is how use() detects that a configuration was already loaded.
CONFIG_SOURCE = '__config_source'

# Key in LOCALS holding the merged configuration dict built by use().
CONFIG_DICT = '__config_dict'


class ConfigAlreadyLoadedError(Exception):
    """Raised when config was already loaded."""


class MissedConfigError(Exception):
    """Raised when config is not found."""


class NoSuchParameter(AttributeError):
    """Raised when there is no such parameter.

    Derives from AttributeError (and therefore still from Exception) because it
    is raised from ``__getattr__`` of config objects: this keeps ``hasattr()``
    and ``getattr(obj, name, default)`` working on them instead of letting the
    error escape.
    """


class UnresolvedReference(Exception):
    """Raised when non-existent config field is referenced."""


class CyclicReference(Exception):
    """Raised when fields are referencing each other."""


class TooComplexReference(Exception):
    """Raised when reference is too complex, e.g.
    combining dict and string."""


class InvalidReference(Exception):
    """Raised when reference is invalid, e.g.
    contains invalid characters."""


def _get_username():
    return pwd.getpwuid(os.getuid())[0]


def setup_global_yt(configuration):
    """Configure the global ``yt.wrapper`` client from ``configuration``.

    Also makes YT's logger use the formatter of the first root logging
    handler, when the root logger has any handlers.

    :param configuration: nested dict of YT config overrides, passed on to
        :func:`setup_yt`
    """
    import yt.logger as yt_logger
    import yt.wrapper as yt_wrapper

    root_handlers = logging.root.handlers
    # The root logger may have no handlers yet (e.g. use_proto() called before
    # logging is configured); indexing [0] would raise IndexError then.
    if root_handlers:
        default_formatter = root_handlers[0].formatter
        yt_logger.set_formatter(default_formatter)
        yt_logger.BASIC_FORMATTER = default_formatter
    setup_yt(yt_wrapper.config, configuration)


def _yt_module_filter(module):
    if not module:
        return False
    if not hasattr(module, '__name__'):
        return True
    if module.__name__ == 'numpy':
        return False

    return True


def setup_yt(yt_client, configuration):
    """Setup YT settings such as proxy and some policies.

    Values from ``configuration`` are copied into ``yt_client.config``,
    recursing into nested dicts. Keys the client config does not already
    contain are ignored.

    :param yt_client: object exposing a mutable ``config`` mapping
        (``yt.wrapper.config`` when called from :func:`setup_global_yt`)
    :param configuration: nested dict of overrides; ``None`` means no overrides
    """

    def _merge_existing(src, dst):
        for key, value in src.items():
            if key not in dst:
                continue

            if isinstance(value, dict) and dst[key] is not None:
                _merge_existing(value, dst[key])
            else:
                # Scalars, and sections the client leaves unset (None), are
                # replaced wholesale; recursing into None would raise TypeError.
                dst[key] = value

    if configuration:
        _merge_existing(configuration, yt_client.config)
    yt_client.config['pickling']['module_filter'] = _yt_module_filter

    # for native operations; os.environ only accepts strings, so an unset
    # (None) token is not exported.
    token = yt_client.config['token']
    if token is not None:
        os.environ['YT_TOKEN'] = token


class ImmutableAttrDict(dict):
    """Dictionary with attribute-style access that rejects assignment.

    Nested dicts are converted to ImmutableAttrDict recursively on
    construction; dicts inside lists are left as they are. Only ``d[k] = v``
    and ``d.k = v`` are blocked; inherited dict methods such as ``update``
    bypass ``__setitem__`` and still mutate the mapping.
    """

    def __init__(self, dict_):
        super(ImmutableAttrDict, self).__init__(
            {k: self._create_recurse(v) for (k, v) in dict_.items()}
        )

    def to_dict(self):
        """Return an equivalent structure built from plain ``dict`` objects.

        Nested mappings are converted too; non-dict values are shared, not
        copied.
        """
        def _to_dict_recurse(value):
            if isinstance(value, dict):
                return {k: _to_dict_recurse(v) for (k, v) in value.items()}
            return value
        return _to_dict_recurse(self)

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails, so dict methods win
        # over same-named keys.
        try:
            return super(ImmutableAttrDict, self).__getitem__(key)
        except KeyError:
            raise NoSuchParameter(
                "No such parameter %s, available: %s" % (key, list(self.keys()))
            )

    def __setattr__(self, key, value):
        raise NotImplementedError("Can't set parameter %s" % (key))

    def __setitem__(self, key, value):
        raise NotImplementedError("Can't set parameter %s" % (key))

    def _I_know_what_I_do_set(self, key, value):
        # Escape hatch that bypasses the immutability guard above.
        super(ImmutableAttrDict, self).__setitem__(key, value)

    @classmethod
    def _create_recurse(cls, value):
        # __init__ already converts nested values, so wrap the dict directly.
        # Converting the children here as well rebuilt every nesting level
        # twice, making construction time grow as 2 ** depth.
        if isinstance(value, dict):
            return cls(value)
        return value


class EnvAccessor(object):
    """Provides access to environment variables.

    Supports ``acc['NAME']`` (KeyError when unset), ``acc.get('NAME')`` /
    ``acc.get('NAME', default)`` and ``'NAME' in acc``.
    """
    def __getitem__(self, key):
        return os.environ[key]

    def get(self, key, default=None):
        # default is optional, matching the dict.get / os.environ.get convention.
        return os.environ.get(key, default)

    def __contains__(self, key):
        # Without this, `in` falls back to calling __getitem__ with integer
        # indices 0, 1, ..., which errors out instead of answering.
        return key in os.environ


class FileAccessor(object):
    """Provides access to file contents.

    ``acc['~/some/file']`` returns an object whose ``str()`` is the file's
    contents with surrounding whitespace stripped. The file is read only when
    that string is requested. A missing file raises KeyError.
    """
    class LazyFileReader(object):
        def __init__(self, filename):
            self.filename = filename

        def __str__(self):
            with open(self.filename) as f:
                return f.read().strip()

    def __getitem__(self, filename):
        fq_filename = os.path.expanduser(filename)
        if not os.path.exists(fq_filename):
            # A LookupError is what jinja2's Environment.getitem catches to
            # substitute the template's undefined value, so `file['x']` on a
            # missing file behaves like any other undefined variable there.
            raise KeyError(filename)
        return self.LazyFileReader(fq_filename)


def render_config(str_config, **ctx):
    """Renders configuration stored in string with Jinja2,
    then parses it with YAML and resolves ``this`` references.

    Inside the template ``{{ this.a.b }}`` (or ``{{ this['a']['b'] }}``)
    renders a placeholder. After YAML parsing, the placeholder is replaced by
    the value found at path ``a`` -> ``b`` of the parsed document.

    A string that consists of exactly one placeholder is replaced by the
    referenced value itself, which may be a dict, list, number, and so on.
    Placeholders embedded in a longer string must resolve to strings.
    References are followed recursively and are resolved anywhere in the
    document, including inside lists.

    :param str_config: config template as a string
    :param ctx: variables made available to the template; ``this`` is reserved
    :returns: an ImmutableAttrDict when the document is a mapping (an empty one
        for an empty document); other documents are returned as rendered
    :raises ValueError: if ``ctx`` contains ``this``
    :raises InvalidReference: if a ``this`` path is empty or contains ``<``/``>``
    :raises UnresolvedReference: if a referenced field does not exist
    :raises CyclicReference: if references lead back to a value that is still
        being resolved
    :raises TooComplexReference: if a reference embedded in a longer string
        resolves to a non-string value
    """
    import yaml
    import jinja2

    REFERENCE = 'REFERENCE'
    SEPARATOR = '::'
    BEGIN_OF_LINE = r'^'
    END_OF_LINE = r'$'
    REFERENCE_PATH = r'<([^<>]+?)>'

    def _is_valid_reference(string):
        # The path must be non-empty and free of REFERENCE_PATH's delimiters;
        # otherwise the wrapped placeholder could never be matched back.
        return bool(string) and re.search(r'[<>]', string) is None

    def _wrap_reference(string):
        if not _is_valid_reference(string):
            raise InvalidReference(string)
        return REFERENCE + '<' + string + '>'

    def _as_simple_reference(string):
        return re.match(
            BEGIN_OF_LINE + REFERENCE + REFERENCE_PATH + END_OF_LINE,
            string
        )

    def _unwrap_reference(string, replacer):
        def _only_string_replacer(match):
            # re.sub can only splice strings into the surrounding text.
            replacement = replacer(match)
            if not isinstance(replacement, (str, six.text_type)):
                raise TooComplexReference(match.group(1))
            return replacement

        simple_reference = _as_simple_reference(string)
        if simple_reference:
            # The whole value is one reference: keep the referenced value's type.
            return replacer(simple_reference)
        return re.sub(REFERENCE + REFERENCE_PATH, _only_string_replacer, string)

    def _join_path(path):
        return SEPARATOR.join(path)

    def _split_path(path):
        return path.split(SEPARATOR)

    def _wrap_to_attrdict(value):
        if isinstance(value, dict):
            return ImmutableAttrDict(value)
        return value

    class _This(object):
        # Records the attribute/item path accessed in the template and renders
        # it as a placeholder string.
        def __init__(self, path=None):
            self.path = path or []

        def __getattr__(self, key):
            return self.__class__(path=self.path+[key])

        def __getitem__(self, key):
            return getattr(self, key)

        def __str__(self):
            return _wrap_reference(_join_path(self.path))

    def _resolve(root, path, active):
        element = root
        for edge in _split_path(path):
            try:
                element = element[edge]
            except (KeyError, TypeError):
                # TypeError: indexing a scalar or a list with a string key.
                raise UnresolvedReference(path)
        if id(element) in active:
            raise CyclicReference(path)
        return _render(root, element, active)

    def _render(root, element, active=()):
        # `active` holds the ids of the values being rendered further up the
        # call chain (all owned by `root`, hence alive); reaching one of them
        # again through a reference means the references form a cycle.
        active = active + (id(element),)

        def _substitute_resolved_match(match):
            return _resolve(root, match.group(1), active)

        if isinstance(element, (str, six.text_type)):
            return _unwrap_reference(element, _substitute_resolved_match)
        if isinstance(element, list):
            return [_render(root, each, active) for each in element]
        if isinstance(element, dict):
            return {
                name: _render(root, value, active)
                for name, value in six.iteritems(element)
            }
        return element

    def _resolve_references(root):
        # Rendering the whole document keeps empty mappings and key order.
        return _render(root, root)

    template = jinja2.Template(str_config, undefined=jinja2.StrictUndefined)

    if 'this' in ctx:
        raise ValueError('"this" is reserved')
    ctx['this'] = _This()

    # NOTE(review): FullLoader can construct some Python objects; the template
    # is assumed to come from trusted configuration files.
    raw = yaml.load(template.render(**ctx), Loader=yaml.FullLoader)
    if not raw:
        return _wrap_to_attrdict({})
    resolved = _resolve_references(raw)
    return _wrap_to_attrdict(resolved)


def load_config(template, filename):
    """Render a single config template with the standard template context.

    The template sees ``env`` (environment variables), ``file`` (file
    contents) and ``username`` (current user). ``filename`` is accepted for
    the caller's convenience but not used here.
    """
    context = dict(
        env=EnvAccessor(),
        file=FileAccessor(),
        username=_get_username(),
    )
    return render_config(template, **context)


class DirectoryLocator(object):
    """Locates configuration files in a directory tree on disk.

    :param path: directory to read configs from; stored as an absolute path
    :raises MissedConfigError: if ``path`` does not exist
    """
    def __init__(self, path):
        self.path = os.path.abspath(path)
        if not os.path.exists(self.path):
            raise MissedConfigError('No such directory "%s"' % self.path)
        self.pretty_path = 'directory "%s"' % (self.path)

    def walk(self):
        """Walk the directory tree, yielding ``os.walk`` triples."""
        return os.walk(self.path)

    def read(self, filename):
        """Return the UTF-8 decoded contents of ``filename``.

        :raises MissedConfigError: if ``filename`` is empty or does not exist
        """
        if not filename or not os.path.exists(filename):
            # abspath(None) would raise TypeError and abspath('') would report
            # the current directory, so only resolve a non-empty name.
            shown = os.path.abspath(filename) if filename else filename
            raise MissedConfigError('Missed config "{}"'.format(shown))

        with codecs.open(filename, 'r', 'utf-8') as config_file:
            return config_file.read()


class DictLocator(object):
    """Serves configuration files from an in-memory ``{path: contents}`` dict.

    Offers the same ``path`` / ``pretty_path`` / ``walk()`` / ``read()``
    attributes as DirectoryLocator, with ``path`` set to the directory
    containing the keys' common prefix.
    """
    def __init__(self, dictionary):
        self.dictionary = dictionary
        shared_prefix = os.path.commonprefix(list(self.dictionary))
        self.path = os.path.dirname(shared_prefix)
        self.pretty_path = 'dictionary from "%s"' % (self.path,)

    def walk(self):
        """Yield ``(directory, None, filenames)`` triples, one per directory."""
        files_by_directory = defaultdict(list)
        for full_path in self.dictionary:
            directory, name = os.path.split(full_path)
            files_by_directory[directory].append(name)
        for directory, names in files_by_directory.items():
            yield directory, None, names

    def read(self, filename):
        """Return the contents stored under ``filename``."""
        contents = self.dictionary[filename]
        return contents


def use_configs(locator, warnings):
    """Load every ``.yaml`` file exposed by ``locator`` into one nested dict.

    The file ``<locator.path>/a/b/name.yaml`` is stored under
    ``total['a']['b']['name']``. A file named ``main.yaml`` is instead merged
    into the node of its own directory. The result is also copied into this
    module's namespace through ``LOCALS``.

    :param locator: DirectoryLocator or DictLocator supplying the files
    :param warnings: list that receives ``(format, args)`` tuples, one per
        empty configuration file
    :returns: the merged configuration as a dict
    """
    total = {}
    for config_directory, _, config_filenames in locator.walk():
        for filename in config_filenames:
            basename, ext = os.path.splitext(filename)
            if not ext == '.yaml':
                continue

            relative_directory = os.path.relpath(config_directory, locator.path)
            # Use every path component ('a/b/c' -> ['a', 'b', 'c']):
            # os.path.split only separates the last one, which produced keys
            # like 'a/b' for directories nested three or more levels deep.
            # The locator root itself is reported as '.', which adds nothing.
            config_path = [
                part for part in relative_directory.split(os.sep)
                if part not in ('', os.curdir)
            ]
            config_path.append(basename)
            if config_path[-1] == 'main':
                config_path.pop(-1)

            full_config_path = os.path.abspath(
                os.path.join(config_directory, filename)
            )

            config = load_config(locator.read(full_config_path), full_config_path)
            if not config:
                warnings.append(('Empty configuration file: "%s"',
                                 [full_config_path]))

            element = total
            for path in config_path:
                element.setdefault(path, ImmutableAttrDict({}))
                element = element[path]
            # dict.update/setdefault bypass ImmutableAttrDict.__setitem__,
            # which is what lets this merge into the immutable nodes.
            element.update(config)
    LOCALS.update(total)
    return total


def use(config):
    """Load config from the specified directory or dict.

    This function should be called only once, raises
    :class:`ConfigAlreadyLoadedError` otherwise.

    The configuration files are expected to be in the YAML format, templated
    using Jinja2 syntax. Everything loaded from the directory or dict is put
    directly into this module's namespace, so that a file ``some.yaml``
    containing ::

        other:
            thing: Value

    is available as ``some.other.thing`` right in this module.

    :param config: path to the configuration directory (a string), or a dict
        mapping config file paths to their contents
    :raises ConfigAlreadyLoadedError: if a configuration was already loaded
    """

    logging_helpers.configure_deploy_logger(logging.getLogger())

    if CONFIG_SOURCE in LOCALS:
        raise ConfigAlreadyLoadedError(LOCALS[CONFIG_SOURCE])

    if isinstance(config, (str, six.text_type)):
        locator = DirectoryLocator(config)
    else:
        locator = DictLocator(config)

    warnings = []
    configuration = use_configs(locator, warnings)

    for message, context in warnings:
        logger.warning(message, *context)

    LOCALS[CONFIG_SOURCE] = locator.path
    LOCALS[CONFIG_DICT] = configuration

    if 'yt' in LOCALS:
        setup_global_yt(LOCALS.get('yt'))

    # hashlib needs bytes on Python 3. json.dumps escapes non-ASCII by default,
    # so the UTF-8 encoding is lossless on Python 2 as well. default=str lets
    # non-JSON YAML values (e.g. dates) take part in this fingerprint instead
    # of crashing it.
    serialized = json.dumps(configuration, sort_keys=True, default=str)
    configuration_hash = hashlib.sha1(serialized.encode('utf-8')).hexdigest()
    logger.info('Using configuration %s (SHA1 %s)', locator.pretty_path, configuration_hash)
    logger.debug('Configuration is %s', configuration)
    logger.debug('Configuration is %s', configuration)


def use_proto(proto, defaults=None):
    """Register a protobuf configuration message in this module's namespace.

    Also publishes legacy dict-style ``api`` / ``yt`` views derived from it and
    configures the global YT client from ``yt``.

    :param proto: protobuf message holding the configuration
    :param defaults: optional text-format message; when given, ``proto`` is
        merged on top of it with ``MergeFrom``
    """
    if defaults:
        merged = proto.__class__()
        text_format.Merge(defaults, merged)
        merged.MergeFrom(proto)
        proto = merged

    LOCALS['proto'] = proto

    # TODO: remove stuff below once transition is over
    if hasattr(proto, 'Api'):
        LOCALS['api'] = ImmutableAttrDict(dict(
            oauth=proto.Api.Token,
            url=proto.Api.Url,
        ))

    if hasattr(proto, 'Yt'):
        yt_section = proto.Yt
        # An empty Pool falls back to the pool of a previously loaded yt config.
        LOCALS['yt'] = ImmutableAttrDict(dict(
            token=yt_section.Token,
            pool=yt_section.Pool or LOCALS.get('yt', {}).get('pool'),
            proxy=dict(url=yt_section.Proxy),
            main_transaction_timeout=yt_section.TransactionTimeout,
        ))

    setup_global_yt(LOCALS.get('yt'))


def has_proto():
    """Return True if use_proto() has stored a (truthy) config message."""
    return bool(LOCALS.get('proto'))
