# -*- coding: utf-8 -*-
import time
from itertools import chain
from collections import defaultdict

import six

from .parser import Parser
from .expression import transform_literals, Literal
from .errors import HrHostCheckFailedError
from .helpers import cleanHostsDict, hostsDictsIntersection, hostsDictsUnion

from kernel.util.sys.user import getUserName
from kernel.util.functional import singleton
from kernel.util import logging

try:
    from api.logger import SkynetLoggingHandler
except ImportError:
    SkynetLoggingHandler = None


def _transform_hostnames(dnf, f):
    """Rewrite the name of every literal with prefix 'h' through ``f``.

    Literals with any other prefix are passed through unchanged. The rewrite
    is delegated to ``transform_literals``; if it reports any errors they are
    raised together as a single ``HrHostCheckFailedError``.

    Returns the transformed DNF.
    """
    def rewrite(literal):
        # Only 'h' literals carry a name that has to go through ``f``.
        if literal.prefix != 'h':
            return literal
        return Literal(f(literal.name), literal.prefix)

    transformed, errors = transform_literals(dnf, rewrite)
    if errors:
        raise HrHostCheckFailedError(errors)
    return transformed


def _extract_family(conf_id):
    pos = conf_id.rfind('-')
    return conf_id[0:pos] if pos > 0 else ''


def _split_family(family):
    items = family.split(':')
    return items[0], items[1] if len(items) > 1 else None


def _process_tags(value):
    # do we have tag?
    parts = value.split('?')
    return parts[0], set(parts[1].split(',')) if len(parts) > 1 else None


@singleton
def _get_metrics_log():
    """Return the 'resolver.metrics' logger, configuring it on first use.

    The logger does not propagate to its parents and is set to INFO level.
    If it has no handlers yet, one is attached: a ``SkynetLoggingHandler``
    writing to 'hostresolver.log' when that class could be imported, or a
    ``NullHandler`` otherwise.
    """
    metrics_log = logging.getLogger('resolver.metrics')
    metrics_log.propagate = False
    metrics_log.setLevel(logging.INFO)

    # Already configured (e.g. by an earlier call) - leave handlers alone.
    if metrics_log.handlers:
        return metrics_log

    if SkynetLoggingHandler is None:
        handler = logging.NullHandler()
    else:
        handler = SkynetLoggingHandler(app='hostresolver', filename='hostresolver.log')
        handler.setLevel(logging.INFO)
    metrics_log.addHandler(handler)

    return metrics_log


class ResolverBase(object):
    """Resolve a textual command into hosts, instances, shards or slots.

    A command is parsed into a DNF (a list of clauses). Every clause is
    resolved on its own - positive literals are intersected, negative ones are
    subtracted - and the per-clause results are united.

    Literal names are grouped by a one-letter prefix. The prefixes this class
    looks at directly are 'h' (raw host names), 'G' (gencfg groups),
    'f' (families) and 'C' (configurations); every other prefix is handed to
    ``host_source.get_hosts_by_groups`` as a host group literal.
    """

    # Prefix sets passed to the clause to select host-group literals for each
    # kind of resolution.
    _HOSTS_GROUP_PREFIXES = "HKkPDdlGWwtfQqCMmzYpbISs"
    _INSTANCES_GROUP_PREFIXES = "HKkPDQqdlISs"
    _SLOTS_GROUP_PREFIXES = "HKkPDQqdl"

    def __init__(self, host_source, check_host_name, no_tag, use_service_resolve_policy=None):
        """Set up the parser and the HQ / GenCfg back ends.

        :param host_source: object providing ``get_hosts_by_groups(groups)``.
        :param check_host_name: callable applied to the name of every 'h'
            literal before resolution.
        :param no_tag: value forwarded to ``GenCfg.get_instances``.
        :param use_service_resolve_policy: forwarded to ``HQResolver`` when
            its constructor accepts it.
        """
        self.parser = Parser()
        self.host_source = host_source
        self.use_service_resolve_policy = use_service_resolve_policy
        try:
            from api.hq import HQResolver
            try:
                self.hq_resolver = HQResolver(use_service_resolve_policy=use_service_resolve_policy)
            except TypeError:  # no support for use_service_resolve_policy
                self.hq_resolver = HQResolver()
        except ImportError:
            # Skynet before 16.1.0 may not have this module
            self.hq_resolver = None

        # import GenCfg after HQ, to avoid import of old requests module
        from api.gencfg import GenCfg
        self.gencfg_source = GenCfg()

        self.check_host_name = check_host_name
        self.no_tag = no_tag

    @staticmethod
    def _write_stats(dnf):
        """Write the user name, current time and all literals to the metrics log."""
        exprs = {
            str(literal)
            for clause in dnf.clauses
            for literal in chain(clause.positive_literals, clause.negative_literals)
        }
        _get_metrics_log().info("[%s] %s %s", getUserName(), time.time(), ' '.join(exprs))

    # ------------------------------------------------------------------ #
    # shared private helpers
    # ------------------------------------------------------------------ #

    def __parse(self, command):
        """Parse ``command``, run the host name check and log the literals."""
        dnf = self.parser.parse(command)
        dnf = _transform_hostnames(dnf, self.check_host_name)
        self._write_stats(dnf)
        return dnf

    @staticmethod
    def __intersect_with(accumulated, resolved):
        """Fold one more resolved mapping into a running intersection.

        ``None`` means "nothing resolved yet". It is deliberately distinct
        from an empty mapping: once a positive literal has resolved to
        nothing, the whole conjunction must stay empty instead of being
        replaced by the next literal's result.
        """
        if accumulated is None:
            return resolved
        return hostsDictsIntersection(accumulated, resolved)

    def __gencfg_instances(self, group, hosts=None):
        """Fetch instances of a 'name[:tag]' gencfg group, optionally limited to ``hosts``."""
        items = group.split(':')
        name = items[0]
        tag = items[1] if len(items) > 1 else None
        if hosts:
            return self.gencfg_source.get_instances(name, tag, self.no_tag, hosts)
        return self.gencfg_source.get_instances(name, tag, self.no_tag)

    def __hq_instances(self, family, conf, tags, hosts=None):
        """Fetch instances from HQ, optionally limited to ``hosts``."""
        if hosts:
            return self.hq_resolver.get_instances(family, conf, hosts, tags)
        return self.hq_resolver.get_instances(family, conf, tags=tags)

    def __negative_hosts(self, clause, negative_dict, prefixes):
        """Return a fresh set of hosts excluded by the clause's negative literals."""
        # Copy first: the clause's own set must not be modified.
        excluded = set(negative_dict['h'])
        excluded.update(
            self.host_source.get_hosts_by_groups(clause.get_n_literals_by_prefix_set(prefixes))
        )
        return excluded

    # ------------------------------------------------------------------ #
    # hosts
    # ------------------------------------------------------------------ #

    def resolve_hosts(self, command):
        """Return the set of hosts matching ``command``."""
        resolved_hosts = set()
        for clause in self.__parse(command).clauses:
            resolved_hosts.update(self._resolve_hosts_by_clause(clause))
        return resolved_hosts

    def _resolve_hosts_by_clause(self, clause):
        """Return the set of hosts matching a single clause."""
        positive_dict = clause.get_positive_names_by_prefixes()
        negative_dict = clause.get_negative_names_by_prefixes()

        hosts = self.__get_hosts_groups_intersection(
            positive_dict['h'],
            clause.get_p_literals_by_prefix_set(self._HOSTS_GROUP_PREFIXES)
        )

        # None (no positive hosts/groups) and an empty intersection both
        # resolve to nothing here.
        if not hosts:
            return set()

        # ``hosts`` is a private copy, so in-place subtraction is safe.
        hosts -= negative_dict['h']
        hosts -= self.host_source.get_hosts_by_groups(
            clause.get_n_literals_by_prefix_set(self._HOSTS_GROUP_PREFIXES)
        )

        return hosts

    # ------------------------------------------------------------------ #
    # instances / shards
    # ------------------------------------------------------------------ #

    def resolve_shards(self, command):
        """Return a host -> set of shards mapping for ``command``.

        The shard is the first element of every resolved instance pair; the
        placeholder value 'none' is dropped.
        """
        resolved_hosts = self.resolve_instances(command)
        # Only values of existing keys are replaced, so iterating while
        # assigning is safe.
        for host, pairs in six.iteritems(resolved_hosts):
            resolved_hosts[host] = {pair[0] for pair in pairs} - {'none'}
        return cleanHostsDict(resolved_hosts)

    def resolve_instances(self, command):
        """Return a host -> set of instance pairs mapping for ``command``."""
        resolved_hosts = defaultdict(set)

        for clause in self.__parse(command).clauses:
            per_clause = self._resolve_instances_by_clause(clause)
            for host, set_shard_instance_pairs in six.iteritems(per_clause):
                resolved_hosts[host].update(set_shard_instance_pairs)

        return cleanHostsDict(resolved_hosts)

    def _resolve_instances_by_clause(self, clause):
        """Return a host -> instances mapping matching a single clause."""
        # this is a shortcut, whole clause resolves to false
        if clause.positive_literals & clause.negative_literals:
            return {}

        # resolve positive literals
        positive_dict = clause.get_positive_names_by_prefixes()

        hosts = self.__get_hosts_groups_intersection(
            positive_dict['h'],
            clause.get_p_literals_by_prefix_set(self._INSTANCES_GROUP_PREFIXES)
        )

        # if we have raw hosts and/or groups and have no intersection
        if hosts is not None and not hosts:
            return {}

        # resolve CMS entities
        configurations = positive_dict['C']
        gencfg_groups = positive_dict['G']
        families = positive_dict['f']

        # ``hosts`` is either None (no host restriction) or a non-empty set;
        # the helpers pass it on to the back end only in the latter case.
        accumulated = None
        for group in gencfg_groups:
            accumulated = self.__intersect_with(accumulated, self.__gencfg_instances(group, hosts))

        if self.hq_resolver:
            for family in families:
                family, tags = _process_tags(family)
                # family can contain configuration id, so we try exclude it here
                family, conf = _split_family(family)
                accumulated = self.__intersect_with(
                    accumulated, self.__hq_instances(family, conf, tags, hosts)
                )
            for conf in configurations:
                conf, tags = _process_tags(conf)
                family = _extract_family(conf)
                if not family:
                    # configuration id without a family part cannot be queried
                    continue
                accumulated = self.__intersect_with(
                    accumulated, self.__hq_instances(family, conf, tags, hosts)
                )

        result_instances = {} if accumulated is None else accumulated

        # subtract negative instances
        negative_dict = clause.get_negative_names_by_prefixes()

        # resolve negative hosts
        for host in self.__negative_hosts(clause, negative_dict, self._INSTANCES_GROUP_PREFIXES):
            result_instances.pop(host, None)

        negative_instances = {}

        if self.hq_resolver:
            for family in negative_dict['f']:
                family, tags = _process_tags(family)
                family, conf = _split_family(family)
                negative_instances = hostsDictsUnion(
                    negative_instances,
                    self.__hq_instances(family, conf, tags)
                )
            for conf in negative_dict['C']:
                conf, tags = _process_tags(conf)
                family = _extract_family(conf)
                if not family:
                    continue
                negative_instances = hostsDictsUnion(
                    negative_instances,
                    self.__hq_instances(family, conf, tags)
                )

        for group in negative_dict['G']:
            negative_instances = hostsDictsUnion(
                negative_instances,
                self.__gencfg_instances(group)
            )

        common_hosts = set(six.iterkeys(result_instances)) & set(six.iterkeys(negative_instances))
        for host in common_hosts:
            result_instances[host] = result_instances[host] - negative_instances[host]

        return cleanHostsDict(result_instances)

    # ------------------------------------------------------------------ #
    # slots
    # ------------------------------------------------------------------ #

    def resolve_slots(self, command):
        """Return the set of slots matching ``command``."""
        resolved_hosts = set()
        for clause in self.__parse(command).clauses:
            resolved_hosts.update(self._resolve_slots_by_clause(clause))
        return resolved_hosts

    def _resolve_slots_by_clause(self, clause):
        """Return the list of slots matching a single clause."""
        # this is a shortcut, whole clause resolves to false
        if clause.positive_literals & clause.negative_literals:
            return []

        # resolve literals
        positive_dict = clause.get_positive_names_by_prefixes()
        negative_dict = clause.get_negative_names_by_prefixes()

        hosts = self.__get_hosts_groups_intersection(
            positive_dict['h'],
            clause.get_p_literals_by_prefix_set(self._SLOTS_GROUP_PREFIXES)
        )

        # if we have raw hosts and/or groups and have no intersection
        if hosts is not None and not hosts:
            return []

        # resolve ISS entities
        slots_by_host = None
        for family in positive_dict['f']:
            slots_by_host = self.__intersect_with(
                slots_by_host, self.__get_slots_grouped_by_host(family, hosts)
            )
        if slots_by_host is None:
            slots_by_host = {}

        # resolve negative hosts
        for host in self.__negative_hosts(clause, negative_dict, self._SLOTS_GROUP_PREFIXES):
            slots_by_host.pop(host, None)

        result_slots = []
        for slots in six.itervalues(slots_by_host):
            result_slots.extend(slots)

        return result_slots

    def __get_hosts_groups_intersection(self, raw_hosts, groups):
        """
        Returns None if we have no hosts and groups
        Else returns an intersection of hosts and groups.

        The result is always a new set: neither ``raw_hosts``, ``groups`` nor
        the sets returned by ``host_source`` are modified, so callers may
        change the result in place.
        """
        if not raw_hosts and not groups:
            return None

        # Two different raw hosts in one conjunction can never both match.
        if len(raw_hosts) > 1:
            return set()

        if not groups:
            return set(raw_hosts)

        hosts = None
        for group in groups:
            resolved = self.host_source.get_hosts_by_groups([group])
            if hosts is None:
                hosts = set(resolved)
            else:
                hosts.intersection_update(resolved)

        if raw_hosts:
            hosts.intersection_update(raw_hosts)

        return hosts

    def __get_slots_grouped_by_host(self, family, hosts):
        """Return slots of ``family`` grouped by their first element.

        ``family`` has the form 'name[:conf][?tag1,tag2]'. An empty mapping is
        returned when no HQ resolver is available.
        """
        slots = defaultdict(list)

        if not self.hq_resolver:
            return slots

        family, tags = _process_tags(family)
        name, conf = _split_family(family)
        for slot in self.hq_resolver.get_slots(name, conf, hosts, tags):
            # slot[0] is used as the grouping key (the host - see callers,
            # which delete entries by host name).
            slots[slot[0]].append(slot)

        return slots
