from contextlib import contextmanager
from itertools import chain
from collections import defaultdict

from clause_resolver import ClauseResolver, resolve_hosts_groups_intersection
from parser import parse
from expression import transform_literals, Literal
from errors import HrHostCheckFailedError, HrConfigurationChanged

from helpers import cleanHostsDict, check_host_fqdn
from hostsource import CmsSource, HostSource


class ResolverBase(object):
    """
    Resolves textual host expressions into hosts, instances or shards.

    An expression is parsed into an object exposing ``clauses``; every clause is
    resolved on its own and the per-clause results are merged together.
    """

    def __init__(self):
        self.cms_source = CmsSource()
        self.host_source = HostSource()
        # Applied to the name of every 'h'-prefixed literal during parsing.
        self.transform_hostname = check_host_fqdn
        # Meta store of the most recent ClauseResolver run; None until one happens.
        self.resolver_meta = None

    def resolve_hosts(self, expression):
        """
        :param str expression: expression
        :rtype: set
        :return: set of possibly transformed hostnames
        """

        dnf = self._parse(expression)
        result = set()

        with self._cms_guard(dnf):
            for clause in dnf.clauses:
                result.update(self._resolve_hosts(clause))

        return result

    def resolve_instances(self, expression):
        """
        :param str expression: expression
        :rtype: dict
        :return: dict: {host -> set of pairs (shard, instance)}
        """

        return cleanHostsDict(self._collect_instances(expression))

    def resolve_instances2(self, expression):
        """
        Same resolution as :meth:`resolve_instances`, but the merged mapping is
        returned without passing through ``cleanHostsDict`` and is accompanied
        by the resolver meta.

        :param str expression: expression
        :rtype: dict
        :return: dict with two keys: ``'instances'`` ({host -> set of data}) and
                 ``'meta'`` (current value of ``self.resolver_meta``)
        """

        instances = self._collect_instances(expression)

        return {
            'instances': dict(instances),
            # NOTE: resolver_meta is overwritten by every clause resolution, so
            # this is the meta store of the last clause resolved by this object.
            'meta': self.resolver_meta,
        }

    def resolve_shards(self, expression):
        """
        :param str expression: expression
        :rtype: dict
        :return: dict: {host -> set of shards}
        """
        result = self.resolve_instances(expression)
        for host, value in result.iteritems():
            # Keep only the first element of each entry and drop the 'none' placeholder.
            result[host] = {item[0] for item in value} - {'none'}
        return cleanHostsDict(result)

    def _collect_instances(self, expression):
        """
        Parse the expression and merge instance data of all its clauses.

        :param str expression: expression
        :rtype: defaultdict
        :return: {host -> set of data merged over all clauses}
        """
        dnf = self._parse(expression)
        merged = defaultdict(set)

        with self._cms_guard(dnf):
            for clause in dnf.clauses:
                for host, data in self._resolve_instances(clause).iteritems():
                    merged[host].update(data)

        return merged

    def _parse(self, expression):
        """
        Parse the expression and run ``self.transform_hostname`` over its
        'h'-prefixed literals.

        :param str expression: expression
        :return: parsed expression with transformed hostnames
        :raises HrHostCheckFailedError: if the hostname transformation reports errors
        """
        dnf = parse(expression)
        return transform_hostnames(dnf, self.transform_hostname)

    def _resolve_hosts(self, clause):
        """
        Resolve a single clause into hosts.

        :param clause: one clause of the parsed expression
        :rtype: set
        :return: set of hosts matching the clause
        """
        if _clause_requires_cms(clause):
            # Always hand back a set so both branches agree on the return type.
            return set(self._resolve_instances(clause).keys())

        hosts = resolve_hosts_groups_intersection(
            self.host_source,
            clause.get_positive_names_by_prefixes()['h'],
            clause.get_p_literals_by_prefix_set('HKdl')
        )

        if not hosts:
            return set()

        # Build new sets instead of using ``-=``: the set returned above may still
        # be referenced by the host source and must not be modified in place.
        hosts = hosts - clause.get_negative_names_by_prefixes()['h']
        hosts = hosts - self.host_source.getHostsByGroups(clause.get_n_literals_by_prefix_set('HKdl'))

        return hosts

    def _resolve_instances(self, clause):
        """
        Resolve a single clause with ``ClauseResolver`` and remember its meta
        store in ``self.resolver_meta``.

        :param clause: one clause of the parsed expression
        :return: whatever ``ClauseResolver.resolve()`` returns (used here as a
                 host-keyed mapping)
        """
        clause_resolver = ClauseResolver(clause, self.host_source, self.cms_source)
        result = clause_resolver.resolve()
        self.resolver_meta = clause_resolver.meta_store
        return result

    @contextmanager
    def _cms_guard(self, dnf):
        """
        Context manager wrapping the resolution of ``dnf``: when any clause
        requires CMS the body runs inside ``SafeResolve``, otherwise it runs
        unguarded.
        """
        if _dnf_requires_cms(dnf):
            with SafeResolve(self.cms_source, dnf.clauses):
                yield
        else:
            yield


def transform_hostnames(dnf, f):
    """
    Rewrite hostname literals of a parsed expression.

    A mapper is handed to ``transform_literals``: literals whose prefix is 'h'
    are replaced by a new ``Literal`` carrying ``f(name)`` and the same prefix,
    all other literals are passed through untouched.

    :param dnf: parsed expression
    :param f: callable applied to the name of every 'h' literal
    :return: the expression returned by ``transform_literals``
    :raises HrHostCheckFailedError: if ``transform_literals`` reports errors
    """
    def rewrite(literal):
        if literal.prefix != 'h':
            return literal
        return Literal(f(literal.name), literal.prefix)

    transformed, failures = transform_literals(dnf, rewrite)
    if failures:
        raise HrHostCheckFailedError(failures)
    return transformed


class SafeResolve(object):
    """
    Context manager guarding a multi-clause resolution against configuration
    changes: the modification time of every used configuration is recorded on
    entry and compared with the current one on exit.
    The extra lookups are not expected to be too expensive (according to nekt0n@).
    """

    def __init__(self, cms_source, clauses):
        self.cms_source = cms_source

        if len(clauses) > 1:
            names = []
            for clause in clauses:
                names.extend(clause.get_positive_names_by_prefixes()['C'])
            if not names:
                # No configuration named explicitly - watch 'HEAD' instead.
                names = ['HEAD']
        else:
            # A single request needs no guarding.
            names = []

        # {configuration name -> mtime seen on entry}; values are filled in __enter__.
        self.confs = dict.fromkeys(names)

    def _current_mtime(self, conf):
        """Fetch the present modification time of the given configuration."""
        return self.cms_source.list_confs(conf)[conf]['mtime']

    def __enter__(self):
        for conf in list(self.confs):
            self.confs[conf] = self._current_mtime(conf)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not exc_type:
            # Only a successful resolution is worth validating; with an
            # exception in flight the check is skipped and the error propagates.
            for conf in list(self.confs):
                if self._current_mtime(conf) != self.confs[conf]:
                    raise HrConfigurationChanged(conf)


def _dnf_requires_cms(dnf):
    """Return True if at least one clause of ``dnf`` requires CMS, else False."""
    for clause in dnf.clauses:
        if _clause_requires_cms(clause):
            return True
    return False


def _clause_requires_cms(clause):
    """
    Tell whether a clause mentions any CMS-dependent literal.

    Both positive and negative literals are inspected: a clause that only
    excludes something by a CMS-dependent prefix still requires CMS.

    :param clause: object providing ``get_positive_names_by_prefixes()`` and
                   ``get_negative_names_by_prefixes()``, each returning a mapping
                   {prefix -> names}
    :rtype: bool
    :return: True if any of the 's', 'S', 'I', 'C' prefixes has names in the clause
    """
    cms_dependent_keys = 's', 'S', 'I', 'C'

    positive = clause.get_positive_names_by_prefixes()
    if any(positive[key] for key in cms_dependent_keys):
        return True

    # Negated literals count as well (previously the positive names were
    # checked twice and the negative ones never).
    negative = clause.get_negative_names_by_prefixes()
    return any(negative[key] for key in cms_dependent_keys)
