from helpers import cleanHostsDict, hostsDictsIntersection, hostsDictsUnion, getAlignedKwargs


class ClauseResolver(object):
    """Resolve a single clause into a mapping of hosts to their instances.

    The positive part of the clause selects hosts and instances; the
    negative part is then subtracted from that selection.

    Literal prefixes as they are used by this class: 'h' is a plain host
    name, the 'HKdl' prefix set denotes host groups, 's' is a shard, 'S' a
    shard tag, 'I' an instance tag and 'C' a configuration.
    """

    def __init__(self, clause, host_source, cms_source):
        """Remember the clause and the two data sources.

        :param clause: object exposing positive/negative names and literals
            grouped by prefix.
        :param host_source: object providing ``getHostsByGroups``.
        :param cms_source: object providing ``get_instances`` and
            ``get_instances_with_meta``.
        """
        self.clause = clause
        self.host_source = host_source
        self.cms_source = cms_source
        # None means "no host restriction at all"; an empty collection means
        # "the restriction matches no host" (see _is_empty_set).
        self.positive_hosts = None
        self.positive_instances = None
        self.meta_store = None

    def resolve(self):
        """Run every resolution step and return the resulting hosts dict.

        :return: ``{}`` when the positive host restriction matches nothing,
            otherwise the result of ``cleanHostsDict`` applied to the
            positive instances left after the negative hosts and negative
            instances have been removed.
        """
        self._resolve_positive_hosts()
        if self._is_empty_set():
            # No host can match, so there is no point in querying CMS.
            return {}

        self._resolve_positive_instances()
        self._resolve_negative_hosts()
        return self._resolve_negative_instances()

    def _resolve_positive_hosts(self):
        """Store the intersection of positive plain hosts and host groups."""
        self.positive_hosts = resolve_hosts_groups_intersection(
            self.host_source,
            self.clause.get_positive_names_by_prefixes()['h'],
            self.clause.get_p_literals_by_prefix_set('HKdl')
        )

    def _resolve_positive_instances(self):
        """Query CMS for the instances selected by the positive literals.

        Stores the resulting hosts dict in ``self.positive_instances`` and
        the metadata gathered along the way in ``self.meta_store``.
        """
        positive = self.clause.get_positive_names_by_prefixes()
        shards = positive['s']
        shard_tags = positive['S']
        instance_tags = positive['I']
        # Fall back to the 'HEAD' configuration when none is requested.
        configurations = positive['C'] or {'HEAD'}
        meta = {}

        if self.positive_hosts and len(self.positive_hosts) > 1:
            result = {}

            # Any configuration will do here: it is only used to collect
            # the instances of every positive host before intersecting.
            some_configuration = list(configurations)[0]
            for h in self.positive_hosts:
                result = hostsDictsUnion(
                    result,
                    self.cms_source.get_instances({'conf': some_configuration, 'host': h})
                )

            # Align the four kinds of literals and query CMS by quartets
            # instead of by single literals (a heuristic).
            kwargslist = getAlignedKwargs(
                zip(
                    ('conf', 'shardTagName', 'instanceTagName', 'shard'),
                    (configurations, shard_tags, instance_tags, shards)
                )
            )

            # This branch uses get_instances, so no metadata is collected
            # and ``meta`` stays empty.
            for kwargs in kwargslist:
                result = hostsDictsIntersection(
                    result,
                    self.cms_source.get_instances(kwargs)
                )
        else:
            # At most one positive host: the host takes part in the
            # alignment as a fifth kind of literal (same heuristic).
            kwargslist = getAlignedKwargs(
                zip(
                    ('conf', 'shardTagName', 'instanceTagName', 'shard', 'host'),
                    (configurations, shard_tags, instance_tags, shards, self.positive_hosts or set())
                )
            )

            # The first answer seeds the result, the remaining ones narrow
            # it down.
            # NOTE(review): assumes getAlignedKwargs returns a non-empty
            # list (configurations is never empty here) - confirm.
            result_with_meta = self.cms_source.get_instances_with_meta(kwargslist.pop())
            meta.update(result_with_meta['meta'])
            result = result_with_meta['instances']
            for kwargs in kwargslist:
                result_with_meta = self.cms_source.get_instances_with_meta(kwargs)
                meta.update(result_with_meta['meta'])
                result = hostsDictsIntersection(
                    result,
                    result_with_meta['instances']
                )

        self.positive_instances = result
        self.meta_store = meta

    def _resolve_negative_hosts(self):
        """Drop negative hosts and members of negative groups from the result.

        Mutates ``self.positive_instances`` in place; hosts that are not
        present in it are ignored.
        """
        negative_hosts = self.clause.get_negative_names_by_prefixes()['h']
        group_hosts = self.host_source.getHostsByGroups(
            self.clause.get_n_literals_by_prefix_set('HKdl')
        )

        # Build a fresh set instead of updating the clause's collection in
        # place, so that the clause's own data is never modified.
        excluded = set(negative_hosts)
        excluded.update(group_hosts)

        for h in excluded:
            self.positive_instances.pop(h, None)

    def _resolve_negative_instances(self):
        """Subtract the instances selected by the negative literals.

        :return: ``cleanHostsDict`` applied to ``self.positive_instances``
            after the negative instances were removed from every host that
            appears on both sides.
        """
        negative = self.clause.get_negative_names_by_prefixes()
        configurations = self.clause.get_positive_names_by_prefixes()['C'] or {'HEAD'}
        some_configuration = list(configurations)[0]
        instances = {}

        # Negative shards, shard tags and instance tags are looked up within
        # one of the positive configurations.
        per_configuration_queries = (
            ('s', 'shard'),
            ('S', 'shardTagName'),
            ('I', 'instanceTagName'),
        )
        for prefix, field in per_configuration_queries:
            for name in negative[prefix]:
                instances = hostsDictsUnion(
                    instances,
                    self.cms_source.get_instances({'conf': some_configuration, field: name})
                )

        # NOTE(review): positive configurations are read under 'C' while
        # negative ones are read under 'c' - confirm the lowercase key is
        # intentional.
        for c in negative['c']:
            instances = hostsDictsUnion(
                instances,
                self.cms_source.get_instances({'conf': c})
            )

        # Only hosts present on both sides need a subtraction.
        # The values are presumably sets of instances - they must support "-".
        hosts = set(self.positive_instances) & set(instances)
        for host in hosts:
            self.positive_instances[host] = self.positive_instances[host] - instances[host]

        return cleanHostsDict(self.positive_instances)

    def _is_empty_set(self):
        """Tell whether a host restriction exists and matches no host."""
        return self.positive_hosts is not None and len(self.positive_hosts) == 0


def resolve_hosts_groups_intersection(host_source, plain_hosts, groups):
    """
    Intersect explicitly named hosts with the members of every given group.

    Neither ``plain_hosts`` nor ``groups`` nor the collections returned by
    ``host_source`` are modified; the result is always a new set.

    :param host_source: object providing ``getHostsByGroups(list_of_groups)``.
    :param plain_hosts: collection of explicitly named hosts (may be empty
        or None).
    :param groups: collection of host groups (may be empty or None).
    :return: None if there are neither hosts nor groups (no restriction),
        otherwise the set of hosts that belong to every group and, when
        plain hosts are given, are also named explicitly.
    """
    if not plain_hosts and not groups:
        return None

    # All literals must hold at once, so two different explicit hosts can
    # never be satisfied together.
    if plain_hosts and len(plain_hosts) > 1:
        return set()

    if not groups:
        # Only a plain host is given; hand out a copy, not the caller's object.
        return set(plain_hosts)

    hosts = None
    for group in groups:
        members = host_source.getHostsByGroups([group])
        if hosts is None:
            # Copy the first answer so that the host source's own set is
            # never narrowed in place by the following intersections.
            hosts = set(members)
        else:
            hosts.intersection_update(members)

    if plain_hosts:
        hosts.intersection_update(plain_hosts)

    return hosts
