from collections import defaultdict

import socket

from .mongo import searcher_lookup, tags, gencfg_trunk, gencfg_dns, instances_tags
from libraries.json_diff import merge
from libraries.topology.utils import unzipped, tag_to_version
from libraries.utils import memoize


class MongoBrokenData(ValueError):
    """Raised when a record read from Mongo does not have the layout this
    module expects, e.g. a tags record with neither 'fullstate' nor 'diff_to',
    or a searcher_lookup diff record with no recognised payload key."""


@memoize
def _get_tag_base_diff_commits(commit):
    """Resolve ``commit`` to its tag and the commits that hold its stored state.

    The commit is looked up (as a string) in the ``tags`` collection:

    * a record with ``fullstate`` is stored in full, so it is its own base and
      has no diff commit;
    * a record with ``diff_to`` is stored as a diff against the ``diff_to``
      commit.

    :param commit: commit identifier; converted with ``str()`` for the lookup.
    :returns: tuple ``(tag, base_commit, diff_commit)``; ``diff_commit`` is
        ``None`` for full-state records.
    :raises MongoBrokenData: if there is no tags record for ``commit``, or the
        record has neither ``fullstate`` nor ``diff_to``.
    """
    rec = tags().find_one({'commit': str(commit)})
    if rec is None:
        # Without this guard a missing record surfaced as an opaque
        # TypeError on rec['tag'].
        raise MongoBrokenData('no tags record for commit {}'.format(commit))

    tag = rec['tag']
    if 'fullstate' in rec:
        return tag, rec['commit'], None
    if 'diff_to' in rec:
        return tag, rec['diff_to'], rec['commit']
    raise MongoBrokenData()


def _get_base_diff_data(group, base_commit, diff_commit):
    """Fetch the searcher_lookup records of ``group`` at the base and diff commits.

    Both commits are converted with ``int()`` before querying. The diff
    record is only queried when ``diff_commit`` is truthy.

    :returns: ``(base, diff)``, each the matching record or ``None``.
    """
    def _lookup(commit_id):
        return searcher_lookup().find_one({'commit': int(commit_id), 'group': group})

    base_record = _lookup(base_commit)
    diff_record = _lookup(diff_commit) if diff_commit else None
    return base_record, diff_record


@memoize
def _get_instances_old_way(tag, group, base_commit, diff_commit):
    """Load the formatted instance list of ``group`` from the legacy layout.

    ``load_instances`` uses this for tags older than stable-102-r3. If there
    is a searcher_lookup record at ``diff_commit``, it takes precedence:

    * a ``dead`` key -> ``None``;
    * an ``instances`` key -> that full copy is used;
    * a ``diff`` key -> the delta is merged over the base record's
      ``instances``.

    Without a diff record, the base record's ``instances`` are used.

    :returns: list of instance dicts built by ``_format_instances_tag``;
        ``None`` for a ``dead`` record; ``memoize.NotSave(None)`` when no
        record was found (presumably so that the miss is not cached).
    :raises MongoBrokenData: if the diff record has none of the keys above,
        or it holds a ``diff`` but the base record is missing.
    """
    base, diff = _get_base_diff_data(group, base_commit, diff_commit)

    if diff:
        if 'dead' in diff:
            return None
        if 'instances' in diff:
            res = unzipped(diff['instances'], use_list=True)
            return _format_instances_tag(res, tag)
        if 'diff' in diff:
            if base is None:
                # A delta cannot be applied without the record it is relative
                # to. Report broken data instead of crashing with a TypeError.
                raise MongoBrokenData(
                    'diff for group {} at commit {} has no base record at commit {}'.format(
                        group, diff_commit, base_commit))
            res = merge(unzipped(base['instances'], use_list=True),
                        unzipped(diff['diff'], use_list=True))
            return _format_instances_tag(res, tag)
        raise MongoBrokenData()

    if base:
        res = unzipped(base['instances'], use_list=True)
        return _format_instances_tag(res, tag)
    return memoize.NotSave(None)


@memoize
def _get_all_data_new_way(tag, group, base_commit, diff_commit):
    """Load the combined data document of ``group`` with instances formatted.

    ``load_instances`` and ``load_card`` use this for tags >= stable-102-r3.
    If there is a searcher_lookup record at ``diff_commit``, it takes
    precedence:

    * a ``dead`` key -> ``None``;
    * a ``data`` key -> that full copy is used;
    * a ``data_diff`` key -> the delta is merged over the base record's
      ``data``.

    Without a diff record, the base record's ``data`` is used. In every
    non-``None`` result, ``data['instances']`` is replaced by the output of
    ``_format_instances_tag``.

    :returns: the data dict; ``None`` for a ``dead`` record;
        ``memoize.NotSave(None)`` when no record was found (presumably so
        that the miss is not cached).
    :raises MongoBrokenData: if the diff record has none of the keys above,
        or it holds a ``data_diff`` but the base record is missing.
    """
    base, diff = _get_base_diff_data(group, base_commit, diff_commit)

    if diff:
        if 'dead' in diff:
            return None
        if 'data' in diff:
            data = unzipped(diff['data'], use_list=True)
            data['instances'] = _format_instances_tag(data['instances'], tag)
            return data
        if 'data_diff' in diff:
            if base is None:
                # A delta cannot be applied without the record it is relative
                # to. Report broken data instead of crashing with a TypeError.
                raise MongoBrokenData(
                    'data_diff for group {} at commit {} has no base record at commit {}'.format(
                        group, diff_commit, base_commit))
            data = merge(unzipped(base['data'], use_list=True),
                         unzipped(diff['data_diff'], use_list=True))
            data['instances'] = _format_instances_tag(data['instances'], tag)
            return data
        raise MongoBrokenData()

    if base:
        data = unzipped(base['data'], use_list=True)
        data['instances'] = _format_instances_tag(data['instances'], tag)
        return data
    return memoize.NotSave(None)


def _get_all_data_new_way_no_instances(tag, group, base_commit, diff_commit):
    """Load the combined data document of ``group`` without formatting instances.

    This mirrors ``_get_all_data_new_way`` but is not memoized, and it leaves
    ``data['instances']`` exactly as stored. ``tag`` is accepted for signature
    parity but is not used. If there is a searcher_lookup record at
    ``diff_commit``, it takes precedence:

    * a ``dead`` key -> ``None``;
    * a ``data`` key -> that full copy is used;
    * a ``data_diff`` key -> the delta is merged over the base record's
      ``data``.

    :returns: the data dict, or ``None`` for a ``dead`` record or when no
        record was found.
    :raises MongoBrokenData: if the diff record has none of the keys above,
        or it holds a ``data_diff`` but the base record is missing.
    """
    base, diff = _get_base_diff_data(group, base_commit, diff_commit)

    if diff:
        if 'dead' in diff:
            return None
        if 'data' in diff:
            return unzipped(diff['data'], use_list=True)
        if 'data_diff' in diff:
            if base is None:
                # A delta cannot be applied without the record it is relative
                # to. Report broken data instead of crashing with a TypeError.
                raise MongoBrokenData(
                    'data_diff for group {} at commit {} has no base record at commit {}'.format(
                        group, diff_commit, base_commit))
            return merge(unzipped(base['data'], use_list=True),
                         unzipped(diff['data_diff'], use_list=True))
        raise MongoBrokenData()

    if base:
        return unzipped(base['data'], use_list=True)
    return None


def load_instances(commit, group):
    """Return the formatted instance list of ``group`` at ``commit``.

    Tags before stable-102-r3 are read from the legacy ``instances``/``diff``
    layout. Newer tags keep the instances inside the combined data document.
    When the loader yields a falsy value (no data, or a dead group), that
    value is returned unchanged.
    """
    tag, base_commit, diff_commit = _get_tag_base_diff_commits(commit)
    legacy_layout = tag_to_version(tag) < tag_to_version('stable-102-r3')
    if legacy_layout:
        return _get_instances_old_way(tag, group, base_commit, diff_commit)

    data = _get_all_data_new_way(tag, group, base_commit, diff_commit)
    return data['instances'] if data else data


def load_card(commit, group, memo=True):
    """Return the ``card`` of ``group`` at ``commit`` (tags >= stable-102-r3 only).

    :param memo: when true, use the memoized loader, which also formats the
        instances. When false, use the non-memoized loader, which skips
        instance formatting.
    :returns: the card, or the loader's falsy result when there is no data.
    :raises RuntimeError: for tags older than stable-102-r3, which store no cards.
    """
    tag, base_commit, diff_commit = _get_tag_base_diff_commits(commit)
    if not tag_to_version(tag) >= tag_to_version('stable-102-r3'):
        raise RuntimeError('cards stored for tags >= stable-102-r3')

    loader = _get_all_data_new_way if memo else _get_all_data_new_way_no_instances
    data = loader(tag, group, base_commit, diff_commit)
    return data['card'] if data else data


# NB: hostnames encoded: '.' to '!'
def load_instances_tags(commit, host):
    """Return the instance-tags mapping stored for ``host`` at ``commit``.

    Host names are stored with '.' replaced by '!' (presumably because
    MongoDB field names may not contain dots). ``host`` is encoded that way
    for the query, and the keys of the stored ``data`` mapping are decoded
    back to dotted form.

    :raises MongoBrokenData: if no record exists for ``commit``/``host``.
    """
    record = instances_tags().find_one({'commit': int(commit), 'host': host.replace('.', '!')})
    if record is None:
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, which would turn this into a TypeError below.
        raise MongoBrokenData(
            'no instances_tags record for commit {} host {}'.format(commit, host))
    # ``items()`` works on both Python 2 and 3 (``iteritems`` is Python-2-only).
    return {key.replace('!', '.'): value for key, value in record['data'].items()}


def _format_instances_tag(instances, tag):
    """Convert an ``{'host:port': attrs}`` mapping into a list of instance dicts.

    Each key is split on the first ':' into host and port. '!' in the host is
    decoded back to '.', and the port is converted to ``int``. Both are added
    to a copy of the attribute dict as ``hostname``/``port``. Depending on
    ``tag``'s version, the ``tags`` list is also adjusted:

    * before stable-94-r1, ``'a_topology_' + tag`` is added (set union, so
      the resulting order is unspecified);
    * from stable-110-r57 on, the tags are sorted.
    """
    version = tag_to_version(tag)
    needs_topology_tag = version < tag_to_version('stable-94-r1')
    needs_sorted_tags = version >= tag_to_version('stable-110-r57')

    formatted = []
    for address in instances:
        hostname, _, port = address.partition(':')
        entry = dict(instances[address], hostname=hostname.replace('!', '.'), port=int(port))
        if needs_topology_tag:
            # Before commit 2548288, populate_searcher_lookup removed the
            # `a_topology_` tags, so they are restored here for old tags.
            entry['tags'] = list(set(entry['tags']) | {'a_topology_' + tag})
        if needs_sorted_tags:
            entry['tags'] = sorted(entry['tags'])
        formatted.append(entry)
    return formatted


def _format_instances_trunk(instances):
    res_ = []
    for key in instances:
        host, _, port = key.partition(':')
        rec = dict(instances[key], hostname=host.replace('!', '.'), port=int(port))
        res_.append(rec)
    return res_


def load_full_trunk(commit, sleep_method=None):
    """Load every gencfg_trunk record of ``commit``, keyed by group.

    Each record gets ``commit`` set to ``int(commit)`` and is wrapped in
    ``LazyData`` so that ``_format_trunk`` runs only on first ``get()``.

    :param sleep_method: optional callable invoked after each record
        (presumably to yield to other work during long loads).
    :returns: dict ``{group: LazyData}``.
    """
    commit_id = int(commit)
    groups = {}
    for record in gencfg_trunk().find({'commit': commit_id}):
        record['commit'] = commit_id
        groups[record['group']] = LazyData(record, _format_trunk)
        if sleep_method:
            sleep_method()
    return groups


def load_group_trunk(commit, group):
    """Return the first gencfg_trunk record for ``commit``/``group``, formatted.

    The record gets ``commit`` set to ``int(commit)`` and is passed through
    ``_format_trunk``. Returns ``None`` when no record matches.
    """
    commit_id = int(commit)
    cursor = gencfg_trunk().find({'commit': commit_id, 'group': group})
    first = next(iter(cursor), None)
    if first is None:
        return None
    first['commit'] = commit_id
    return _format_trunk(first)


def load_dns(sleep_method=None):
    """Build forward and reverse DNS maps from gencfg_dns, oldest commit first.

    Records lacking ``hostname`` or ``ipv6addr`` are skipped.

    :param sleep_method: optional callable invoked after every 100 processed
        records (presumably to yield to other work during the scan).
    :returns: ``(direct, invert)``, where ``direct`` maps the lowercased
        hostname to the set of its IPv6 address strings, and ``invert`` maps
        the 16-byte packed IPv6 address to a hostname. On conflicts,
        ``_update_with_newer`` keeps the shorter hostname.
    """
    direct = defaultdict(set)
    invert = {}

    records = gencfg_dns().find(
        {},
        {'_id': 0, 'hostname': 1, 'ipv6addr': 1}
    ).sort('commit', 1)

    processed = 0
    for rec in records:
        if not ('hostname' in rec and 'ipv6addr' in rec):
            continue
        hostname, ipv6addr = rec['hostname'], rec['ipv6addr']
        direct[hostname.lower()].add(ipv6addr)
        _update_with_newer(invert, ipv6_to_binary(ipv6addr), hostname)
        processed += 1
        if sleep_method and processed % 100 == 0:
            sleep_method()

    return direct, invert


def dns_direct():
    """Yield ``(hostname, ipv6addr)`` pairs from gencfg_dns, newest commit first.

    Hostnames are lowercased. A pair is yielded only if neither its hostname
    nor its address appeared in an earlier (i.e. newer) yielded pair, so each
    host and each address is reported at most once. Records lacking
    ``hostname`` or ``ipv6addr`` are skipped.
    """
    # The seen-sets must persist across iterations. They used to be
    # re-created inside the loop, which made the de-duplication a no-op.
    seen_hosts = set()
    seen_ips = set()
    for rec in gencfg_dns().find(
            {},
            {'_id': 0, 'hostname': 1, 'ipv6addr': 1}
    ).sort('commit', -1):
        if 'hostname' not in rec or 'ipv6addr' not in rec:
            continue
        hostname, ipv6addr = rec['hostname'].lower(), rec['ipv6addr']
        if hostname not in seen_hosts and ipv6addr not in seen_ips:
            seen_hosts.add(hostname)
            seen_ips.add(ipv6addr)
            yield hostname, ipv6addr


def ips_of_hostname(host):
    """Lazily yield every ``ipv6addr`` recorded for ``host``, newest commit first.

    ``host`` is matched verbatim (no case folding).
    """
    cursor = gencfg_dns().find(
        {'hostname': host},
        {'_id': 0, 'hostname': 1, 'ipv6addr': 1}
    )
    for record in cursor.sort('commit', -1):
        yield record['ipv6addr']


def group_of_hostname(host):
    """Return the group of ``host`` from its newest gencfg_dns record that has one.

    Records are scanned newest commit first, and records whose ``group`` is
    missing or empty are skipped. ``host`` is matched verbatim. Returns
    ``None`` when no record carries a group.
    """
    for rec in gencfg_dns().find(
            {'hostname': host},
            {'_id': 0, 'hostname': 1, 'group': 1}
    ).sort('commit', -1):
        # A projection does not add absent fields (load_dns guards against
        # missing keys the same way). A record without ``group`` is skipped
        # instead of aborting the scan with a KeyError.
        group = rec.get('group')
        if group:
            return group
    return None


def hostname_of_ip(ip):
    """Return the hostname of the newest gencfg_dns record for ``ip``, or ``None``.

    ``ip`` is matched verbatim against the stored ``ipv6addr`` string.
    """
    cursor = gencfg_dns().find(
        {'ipv6addr': ip},
        {'_id': 0, 'hostname': 1, 'ipv6addr': 1}
    ).sort('commit', -1)
    newest = next(iter(cursor), None)
    return None if newest is None else newest['hostname']


# Dirty hack, should be removed after 1.11.2017
def _update_with_newer(coll, key, value):
    # just insert
    if key not in coll:
        coll[key] = value
        return

    # newer hostname is shorter
    old_value = coll[key]
    if len(value) < len(old_value):
        coll[key] = value


def _format_trunk(data):
    """Unpack and format the ``instances`` of a gencfg_trunk record in place.

    ``data['instances']`` is decoded with ``unzipped`` and replaced by the
    list built by ``_format_instances_trunk``. The same (mutated) ``data``
    object is returned.
    """
    decoded = unzipped(data['instances'], use_list=True)
    data['instances'] = _format_instances_trunk(decoded)
    return data


def ipv6_to_binary(s):
    """Pack the textual IPv6 address ``s`` into its 16-byte network form.

    Malformed input raises ``socket.error`` (``OSError`` on Python 3).
    """
    family = socket.AF_INET6
    return socket.inet_pton(family, s)


def canonize_ipv6(addr):
    """Return ``addr`` in the canonical textual IPv6 form produced by ``inet_ntop``.

    The address is packed to bytes and re-rendered. Malformed input raises
    ``socket.error`` (``OSError`` on Python 3).
    """
    packed = socket.inet_pton(socket.AF_INET6, addr)
    return socket.inet_ntop(socket.AF_INET6, packed)


class LazyData(object):
    """Hold ``data`` and apply ``handler`` to it on the first ``get()`` only.

    The handler result replaces the stored value. After the first ``get()``,
    ``get_raw()`` therefore also returns the processed value.
    """

    def __init__(self, data, handler):
        self._data = data
        self._handler = handler
        self._processed = False

    def get(self):
        """Return the processed data, running the handler on the first call."""
        if self._processed:
            return self._data
        result = self._handler(self._data)
        self._data = result
        self._processed = True
        return result

    def get_raw(self):
        """Return the currently stored value (processed once ``get()`` has run)."""
        return self._data
