import pytz
import json
import retry
import pymongo
import requests

from mongo import db_topology_commits, mongo_db
from logger import Logger

import base64
import zlib

# Base URL of the gencfg API used by gencfg_api() and _retry_get_gencfg_url().
API_GENCFG_URL = 'https://api.gencfg.yandex-team.ru'
# Passed as the ``timeout`` argument of requests.get() for that API.
API_GENCFG_URL_TIMEOUT = 100

# Base URL of the second gencfg API endpoint, used by _get_gencfg_api().
NEW_API_GENCFG_URL = 'https://aping.gencfg.yandex-team.ru'
# Passed as the ``timeout`` argument of requests.get() for that endpoint.
NEW_API_GENCFG_URL_TIMEOUT = 100

# Base URL of the staff API used by _retry_get_staff_url().
API_STAFF_URL = 'https://staff-api.yandex-team.ru/v3'
# Passed as the ``timeout`` argument of requests.get() for the staff API.
API_STAFF_URL_TIMEOUT = 100

# Module-wide logger; supplies the time_counting()/exception_logger()
# decorators applied to the functions and methods below.
logger = Logger('states')


class PartialUpdate(RuntimeError):
    """RuntimeError subclass with a fixed ``message`` class attribute.

    It is not raised anywhere in the code visible in this module.
    """
    message = 'some updates not succeeded'


class States(object):
    def __init__(self):
        self._cache = {}

    @logger.time_counting()
    def update(self):
        """Refresh every cache entry, one loader at a time.

        Entries are refreshed in a fixed order (owners, groups, test status,
        constants); each result is stored as soon as its loader returns, so a
        failure in a later loader leaves the earlier entries updated.
        """
        loaders = (
            ('groups_owners', self.get_groups_owners_trunk),
            ('groups', self.get_groups_trunk),
            ('gencfg_status', self.get_gencfg_test_status),
            ('constants', self.get_constancts_trunk),
        )
        for cache_key, loader in loaders:
            self._cache[cache_key] = loader()

    @staticmethod
    @logger.time_counting()
    def load_slb_names():
        """Fetch ``trunk/slbs`` and return the ``fqdn`` of every entry.

        The response may be either a dict holding the entries under ``slbs``
        or the list of entries itself; both shapes are accepted.
        """
        response = _retry_get_gencfg_url('trunk/slbs')
        entries = response.get('slbs', []) if isinstance(response, dict) else response

        fqdns = []
        for entry in entries:
            fqdns.append(entry['fqdn'])
        return fqdns

    @staticmethod
    @logger.time_counting()
    @logger.exception_logger()
    def get_group_card(group, tag=None):
        """Return the card of ``group`` for the requested state.

        :param group: group name.
        :param tag: ``None`` or ``'trunk'`` reads the trunk card through
            ``_get_gencfg_api``; ``'online'`` reads the online card; any other
            value is treated as a tag name.
        :return: the card dict, or ``{}`` when the response is empty or
            contains an ``'error'`` key.
        """
        if tag == 'online':
            card = States.get_group_card_online(group)
        elif tag is not None and tag != 'trunk':
            card = States.get_group_card_tag(group, tag)
        else:
            card = _get_gencfg_api('/', group=group, card=1)

        if card and 'error' not in card:
            return card
        return {}

    @staticmethod
    @logger.time_counting()
    @logger.exception_logger()
    def get_group_card_tag(group, tag):
        """Fetch the card of ``group`` as of ``tag`` from the gencfg API."""
        return _retry_get_gencfg_url('tags/{}/groups/{}/card'.format(tag, group))

    @staticmethod
    @logger.time_counting()
    @logger.exception_logger()
    def get_group_card_online(group):
        """Fetch the online card of ``group`` from the gencfg API."""
        return _retry_get_gencfg_url('online/groups/{}/card'.format(group))

    @staticmethod
    @logger.time_counting()
    @logger.exception_logger()
    def get_group_instances(group, tag=None):
        """Return the instances of ``group`` for the requested state.

        :param group: group name.
        :param tag: ``None`` or ``'trunk'`` takes the values of the
            ``_get_gencfg_api`` response; ``'online'`` reads online instances;
            any other value is treated as a tag name.
        :return: a list of instances, or ``{}`` when nothing was found or the
            result contains ``'error'``.
        """
        if tag == 'online':
            instances = States.get_group_instances_online(group)
        elif tag is not None and tag != 'trunk':
            instances = States.get_group_instances_tag(group, tag)
        else:
            instances = list(_get_gencfg_api('/', group=group).itervalues())

        if instances and 'error' not in instances:
            return instances
        return {}

    @staticmethod
    @logger.time_counting()
    @logger.exception_logger()
    def get_group_instances_tag(group, tag):
        """Return the ``instances`` list of ``group`` as of ``tag``.

        Uses the ``searcherlookup`` variant of the instances endpoint.
        """
        response = _retry_get_gencfg_url(
            'tags/{}/searcherlookup/groups/{}/instances'.format(tag, group)
        )
        return response.get('instances', [])

    @staticmethod
    @logger.time_counting()
    @logger.exception_logger()
    def get_group_instances_online(group):
        """Return the ``instances`` list from the online endpoint for ``group``."""
        response = _retry_get_gencfg_url('online/groups/{}/instances'.format(group))
        return response.get('instances', [])

    @logger.time_counting()
    @logger.exception_logger('status', raise_on_catch=False)
    def get_slb_names(self, tag=None):
        """Return SLB names: cached trunk names, or the ones of ``tag`` if given."""
        if not tag:
            return self._cache['constants']['slb_names']
        return self.get_slb_names_tag(tag)

    @logger.time_counting()
    @logger.exception_logger('status', raise_on_catch=False)
    def get_slb_names_tag(self, tag):
        """Fetch ``tags/<tag>/slbs`` and return the ``fqdn`` of each ``slbs`` entry."""
        response = _retry_get_gencfg_url('tags/{}/slbs'.format(tag))
        fqdns = []
        for entry in response.get('slbs', []):
            fqdns.append(entry['fqdn'])
        return fqdns

    @logger.time_counting()
    @logger.exception_logger('status', raise_on_catch=False)
    def get_group_diff_tags(self, group, tag_or_online, tag_or_trunk):
        """Compare one group between two states (tag / online / trunk).

        Both group cards are fetched and enriched with values derived from the
        instance lists (host names, average power, average porto memory
        limits, simplified volumes), then compared field by field.

        :param group: group name.
        :param tag_or_online: state used as the left side of the diff.
        :param tag_or_trunk: state used as the right side of the diff.
        :return: dict mapping each compared path to ``None`` (no difference)
            or to the value produced by ``States.diff_field``; ``None`` when
            a connection error occurs or any card / instance list is empty.
        """
        def enrich(card, instances):
            # Averages are taken over this side's own instances. The previous
            # code divided the right-hand memory sums by the left-hand count.
            count = len(instances)
            card['hosts'] = [x['hostname'] for x in instances]
            card['porto_power'] = sum(x['power'] for x in instances) / count
            card['porto_memory_guarantee'] = sum(
                x['porto_limits']['memory_guarantee'] for x in instances
            ) / count
            card['porto_memory_limit'] = sum(
                x['porto_limits']['memory_limit'] for x in instances
            ) / count
            card['volumes'] = self.simplify_volumes(card['reqs']['volumes'])

        print('Diff between: {} {}'.format(tag_or_online, tag_or_trunk))

        try:
            left_group_card = States.get_group_card(group, tag_or_online)
            right_group_card = States.get_group_card(group, tag_or_trunk)
            left_group_instances = States.get_group_instances(group, tag_or_online)
            right_group_instances = States.get_group_instances(group, tag_or_trunk)
        except requests.ConnectionError:
            return None

        # Nothing to compare (and no safe divisor) if either side is empty.
        if not left_group_card or not right_group_card or not left_group_instances or not right_group_instances:
            return None

        enrich(left_group_card, left_group_instances)
        enrich(right_group_card, right_group_instances)

        print('left_group_card: {}'.format(left_group_card['volumes']))
        print('right_group_card: {}'.format(right_group_card['volumes']))

        return States.count_group_diff(left_group_card, right_group_card, [
            'resources.ninstances',
            'owners',
            'reqs.instances.power',
            'reqs.instances.disk',
            'reqs.instances.ssd',
            'legacy.funcs.instancePort',
            'tags.ctype',
            'tags.itype',
            'tags.metaprj',
            'tags.itag',
            'tags.prj',
            'properties.hbf_parent_macros',
            'properties.internet_tunnel',
            'properties.ipip6_ext_tunnel',
            'properties.ipip6_ext_tunnel_v2',
            'properties.mtn.tunnels.hbf_slb_name',
            'properties.mtn.export_mtn_to_cauth',
            'properties.mtn.portovm_mtn_addrs',
            'properties.mtn.use_mtn_in_config',
            'hosts',
            'porto_power',
            'porto_memory_limit',
            'porto_memory_guarantee',
            'volumes'
        ])

    @logger.time_counting()
    @logger.exception_logger('status', raise_on_catch=False)
    def get_extra_groups_on_host(self, hostname):
        """Find groups reported on a host that the latest gencfg tag lacks.

        The host's instances from clusterstate are compared with the instance
        tags gencfg reports for the same host under the first displayed tag.

        :param hostname: short or fully qualified host name; a name without a
            dot gets ``.search.yandex.net`` appended.
        :return: dict ``group_name -> (port, trunk_port, tag_name)`` for every
            (group, port) pair present on the host but absent from gencfg.
            ``trunk_port`` is 0 unless gencfg knows the group on a different
            port. ``None`` when either response lacks the expected key.
        """
        def get_group_name_from_tags(tags):
            # The group name is encoded in a tag prefixed 'a_topology_group-'.
            for tag in tags:
                if tag.startswith('a_topology_group-'):
                    return tag.replace('a_topology_group-', '')
            return ''

        last_tag = _retry_get_gencfg_url('trunk/tags')['displayed_tags'][0]
        hostname = hostname if '.' in hostname else '{}.search.yandex.net'.format(hostname)
        # Explicit timeout: without one requests may wait indefinitely on an
        # unresponsive server. The gencfg timeout is reused as the bound.
        host_data = requests.get(
            'https://clusterstate.yandex-team.ru/api/v1/hosts/{}'.format(hostname),
            timeout=API_GENCFG_URL_TIMEOUT
        ).json()
        gencfg_data = _retry_get_gencfg_url('tags/{}/hosts/{}/instances_tags'.format(last_tag, hostname))

        if 'i' not in host_data or 'instances_tags' not in gencfg_data:
            return None

        # (group, port) -> tag name, as reported for the host itself.
        host_data_groups = {}
        for key, value in host_data['i'].iteritems():
            if not value.get('group_from_tag'):
                continue
            group_key = (value['group_from_tag'], key.split(':')[-1])
            host_data_groups[group_key] = self.tag_name_by_number(value['version_from_tag'])

        # (group, port) -> 'trunk', as known to gencfg for this host.
        gencfg_data_groups = {}
        for key, value in gencfg_data['instances_tags'].iteritems():
            group_name = get_group_name_from_tags(value)
            if not group_name:
                continue
            group_key = (group_name, key.split(':')[-1])
            gencfg_data_groups[group_key] = 'trunk'

        extra_groups = set(host_data_groups).difference(gencfg_data_groups)

        diff = {}
        for group_name, port in extra_groups:
            diff[group_name] = (port, 0, host_data_groups[(group_name, port)])
            for trunk_group_name, trunk_port in gencfg_data_groups:
                if group_name == trunk_group_name and port != trunk_port:
                    print('FOUND: {} {} {}'.format(group_name, port, trunk_port))
                    diff[group_name] = (port, trunk_port, host_data_groups[(group_name, port)])

        return diff

    @staticmethod
    @logger.time_counting()
    @logger.exception_logger()
    def get_staff_user_groups(user, token):
        """Return staff group urls for ``user``.

        For every person record in the staff response, the urls of the
        department ancestors come first, followed by the urls of the groups
        the person belongs to.

        :param user: staff login.
        :param token: OAuth token forwarded to the staff API helper.
        """
        query = 'persons?login={}&_fields=department_group.ancestors.url,groups.group.url'.format(user)
        response = _retry_get_staff_url(query, token)

        urls = []
        for person in response.get('result', []):
            urls.extend(ancestor['url'] for ancestor in person['department_group']['ancestors'])
            urls.extend(membership['group']['url'] for membership in person['groups'])
        return urls

    @staticmethod
    @logger.time_counting()
    @logger.exception_logger()
    def get_groups_trunk(tag=None):
        """Return ``group_names`` from trunk, or from ``tag`` when one is given."""
        prefix = 'tags/{}'.format(tag) if tag else 'trunk'
        response = _retry_get_gencfg_url('{}/groups'.format(prefix))
        return response.get('group_names', [])

    @staticmethod
    @logger.time_counting()
    @logger.exception_logger()
    def get_groups_owners_trunk():
        """Return the ``resolved_owners`` entry of ``trunk/groups_owners``."""
        response = _retry_get_gencfg_url('trunk/groups_owners')
        return response['resolved_owners']

    @staticmethod
    @logger.time_counting()
    @logger.exception_logger()
    def get_group_info(group, tag=None):
        """Fetch the info of ``group`` from trunk, or from ``tag`` when given."""
        prefix = 'tags/{}'.format(tag) if tag else 'trunk'
        return _retry_get_gencfg_url('{}/groups/{}'.format(prefix, group))

    @logger.time_counting()
    @logger.exception_logger()
    def get_group_hosts_info(self, group, tag=None):
        """Return hardware info for the hosts of ``group``.

        Instances without a ``hostname`` key are skipped.
        """
        hostnames = []
        for instance in self.get_group_instances(group, tag):
            if 'hostname' in instance:
                hostnames.append(instance['hostname'])
        return self.get_hosts_hardware(hostnames)

    @logger.time_counting()
    @logger.exception_logger('status', raise_on_catch=False)
    def get_group_diff(self, group):
        """Return the difference between trunk and online states of ``group``.

        Both instance lists are fetched first, then each is averaged with
        ``count_mid_states`` and the two results go to ``count_diff_states``
        (trunk as its first argument, online as its second).
        """
        trunk_instances = States.get_group_instances(group)
        online_instances = States.get_group_instances(group, 'online')

        trunk_summary = self.count_mid_states(trunk_instances)
        online_summary = self.count_mid_states(online_instances)

        return self.count_diff_states(trunk_summary, online_summary)

    @logger.time_counting()
    @logger.exception_logger('status')
    def get_group_online(self, group):
        """Return the averaged online state of ``group``; ``{}`` on an empty response."""
        response = _retry_get_gencfg_url('online/groups/{}/instances'.format(group))
        if response:
            return self.count_mid_states(response.get('instances', []))
        return {}

    @staticmethod
    @logger.time_counting()
    @logger.exception_logger()
    def get_hosts_info():
        """Load ``hosts_data.json`` and index its records by ``name + domain``.

        Reads the ``hosts_data`` list from the file in the current working
        directory; a missing list yields an empty dict.
        """
        with open('hosts_data.json', 'r') as src:
            payload = json.loads(src.read())

        by_fqdn = {}
        for record in payload.get('hosts_data', []):
            by_fqdn[record['name'] + record['domain']] = record
        return by_fqdn

    @staticmethod
    @logger.time_counting()
    @logger.exception_logger()
    def get_hosts_groups(substr):
        """Return ``groups_by_host`` for a host.

        :param substr: host name; ``.search.yandex.net`` is appended unless it
            already contains ``.yandex.net``.
        """
        if '.yandex.net' in substr:
            hostname = substr
        else:
            hostname = '{}.search.yandex.net'.format(substr)
        response = gencfg_api('/trunk/hosts/hosts_to_groups/{}'.format(hostname))
        return response.get('groups_by_host', [])

    @staticmethod
    @logger.time_counting()
    @logger.exception_logger()
    def get_hosts_hardware(list_hostnames):
        """Return ``hardware_by_host`` for the given host names.

        The names are sent as the JSON body ``{'hosts': [...]}``.
        """
        payload = {'hosts': list_hostnames}
        response = gencfg_api('/trunk/hosts/hosts_to_hardware', json=payload)
        return response['hardware_by_host']

    @staticmethod
    @logger.time_counting()
    @logger.exception_logger()
    def get_gencfg_test_status():
        """Return test status of the latest commits, highest commit first.

        Reads the 100 most recent ``commits`` records and annotates them with
        the tag name from the 50 most recent ``tags`` records when the commit
        matches. Each item carries ``test_passed``, ``time`` (record creation
        time in Europe/Moscow, ``%m-%d %H:%M``), ``author``, ``task_id``,
        ``commit`` and optionally ``tag``.
        """
        moscow = pytz.timezone('Europe/Moscow')

        by_commit = {}
        for record in db_topology_commits.commits.find(sort=[('_id', -1)]).limit(100):
            created = record['_id'].generation_time.astimezone(moscow)
            by_commit[record['commit']] = {
                'test_passed': record.get('test_passed'),
                'time': created.strftime('%m-%d %H:%M'),
                'author': record.get('author', ''),
                'task_id': record.get('task_id', ''),
            }

        for record in db_topology_commits.tags.find(sort=[('_id', -1)]).limit(50):
            commit = record['commit']
            if commit in by_commit:
                by_commit[commit]['tag'] = record['tag']

        ordered = []
        for commit in sorted(by_commit, reverse=True):
            ordered.append(dict(by_commit[commit], commit=commit))
        return ordered

    @staticmethod
    @logger.time_counting()
    @logger.exception_logger()
    def get_gencfg_tags(limit=1000):
        """Return tag names, newest record first.

        :param limit: maximum number of records to read; ``None`` reads all.
        Records without a ``tag`` key are skipped.
        """
        cursor = db_topology_commits.tags.find(sort=[('_id', -1)])
        if limit is not None:
            cursor = cursor.limit(limit)
        return [record['tag'] for record in cursor if 'tag' in record]

    @staticmethod
    @logger.time_counting()
    @logger.exception_logger()
    def get_constancts_trunk():
        """Load gencfg constants from the database, keyed by record ``type``.

        Records flagged with ``compression == "zlib_b64"`` hold a base64
        encoded, zlib compressed JSON list (abc_services was too big for
        mongo, so it is stored encoded). SLB names are added under
        ``slb_names``.
        """
        constants = {}
        for record in db_topology_commits.gencfg_constants.find():
            value = record['value']
            if record.get("compression") == "zlib_b64":
                raw = zlib.decompress(base64.b64decode(value))
                value = list(json.loads(raw))
            constants[record['type']] = value

        constants['slb_names'] = States.load_slb_names()
        return constants

    @staticmethod
    @logger.time_counting()
    @logger.exception_logger()
    def group_of_hostname(host):
        """Return the first non-empty ``group`` recorded for ``host``.

        Records are scanned from the highest ``commit`` down; ``None`` is
        returned when no record has a truthy group.
        """
        projection = {'_id': 0, 'hostname': 1, 'group': 1}
        cursor = db_topology_commits.gencfg_dns.find({'hostname': host}, projection)
        for record in cursor.sort('commit', -1):
            group = record['group']
            if group:
                return group
        return None

    @staticmethod
    @logger.time_counting()
    @logger.exception_logger()
    def get_requests(filter, limit=200):
        """Return background tasks matching ``filter``, newest ``added`` first.

        :param filter: query document passed to ``find``.
        :param limit: maximum number of tasks to return.
        """
        cursor = mongo_db.background_tasks.find(filter)
        cursor = cursor.sort('added', pymongo.DESCENDING).limit(limit)
        return list(cursor)

    @staticmethod
    def get_reserved_hosts(group):
        """Return the set of hosts of the ``<LOCATION>_RESERVED`` groups.

        Covers the MSK, MAN, SAS and VLA locations. The ``group`` argument is
        not used by the body.
        """
        hosts = set()
        for location in ('MSK', 'MAN', 'SAS', 'VLA'):
            info = States.get_group_info("{}_RESERVED".format(location))
            hosts.update(info['hosts'])
        return hosts

    def get_my_groups(self, user, token=None):
        groups = self._cache['groups_owners'].get(user, [])
        return sorted(groups)

    def get_all_groups(self, tag=None):
        if tag:
            return sorted(self.get_groups_trunk(tag))
        return sorted(self._cache['groups'])

    def get_gencfg_revisions(self):
        return self._cache['gencfg_status']

    def get_gencfg_status(self):
        return self._cache.get('gencfg_status', [True])[0]

    def get_trunk_ctypes(self):
        return sorted(self._cache['constants']['ctypes'])

    def get_trunk_itypes(self):
        return sorted(self._cache['constants']['itypes'])

    def get_trunk_metaprjs(self):
        return sorted(self._cache['constants']['metaprjs'])

    def get_trunk_dispenser_projects(self):
        return self._cache['constants']['dispenser_projects']
        # return {'yandex': {'resolved_acl': []}, 'gencfg': {'resolved_acl': ['shotinleg']}}

    def get_user_dispenser_projects(self, user):
        return {k: v for k, v in self.get_trunk_dispenser_projects().items() if user in v['resolved_acl']}

    def get_trunk_hbf_macroses(self):
        return sorted(self._cache['constants']['hbf_macroses'], key=lambda x: x['name'])

    def get_user_hbf_macroses(self, user):
        return sorted(
            [x for x in self._cache['constants']['hbf_macroses'] if user in x.get('resolved_owners', [])],
            key=lambda x: x['name']
        )

    def get_trunk_hbf_macro(self, macro_name):
        for macro in self._cache['constants']['hbf_macroses']:
            if macro['name'] == macro_name:
                return macro
        return {'name': '', 'parent_macro': None, 'resolved_owners': []}

    def get_trunk_hbf_ranges(self):
        return self._cache['constants']['hbf_ranges']

    def get_user_hbf_ranges(self, user):
        return {k: v for k, v in self.get_trunk_hbf_ranges().items() if k == '_GENCFG_SEARCHPRODNETS_ROOT_' or user in v['resolved_acl']}

    def get_trunk_available_owners(self):
        staff_users = self._cache['constants'].get('staff_users', [])
        staff_groups = self._cache['constants'].get('staff_groups', [])
        abc_services = self._cache['constants'].get('abc_services', [])
        abc_services = filter(lambda x: x.count(":") <= 1, abc_services)  # GENCFG-4567

        return staff_users + staff_groups + abc_services

    @staticmethod
    @logger.time_counting()
    def count_mid_states(states):
        """Summarise instances into host names and per-host averages.

        :param states: iterable of instance dicts; ``hostname`` is required,
            ``power`` and ``porto_limits`` default to zero when absent.
        :return: dict with ``hosts`` (list of host names) and the averages
            ``power``, ``porto_limits.memory_guarantee`` and
            ``porto_limits.memory_limit``. With no instances the totals are
            divided by 1.
        """
        hosts = []
        power_total = 0
        guarantee_total = 0
        limit_total = 0

        for instance in states:
            hosts.append(instance['hostname'])
            power_total += instance.get('power', 0)

            limits = instance.get('porto_limits', {'memory_guarantee': 0, 'memory_limit': 0})
            guarantee_total += limits.get('memory_guarantee', 0)
            limit_total += limits.get('memory_limit', 0)

        divisor = len(hosts) if hosts else 1

        def averaged(total):
            # Only plain int/float totals are divided; anything else is kept.
            if isinstance(total, (int, float)):
                return total / divisor
            return total

        return {
            'hosts': hosts,
            'power': averaged(power_total),
            'porto_limits.memory_guarantee': averaged(guarantee_total),
            'porto_limits.memory_limit': averaged(limit_total)
        }

    @staticmethod
    @logger.time_counting()
    def count_diff_states(new_state, old_state):
        """Build a per-key difference between two ``count_mid_states`` results.

        :return: dict of 3-tuples ``(old, new, delta)``. ``diff_count_hosts``
            compares host counts; ``new_hosts`` / ``del_hosts`` carry the set
            of added / removed hosts in the third slot; every other key of
            ``new_state`` is compared numerically, with the two porto memory
            keys divided by 1024**3 (GB).
        """
        memory_keys = ('porto_limits.memory_guarantee', 'porto_limits.memory_limit')

        def in_units(key, value):
            if key in memory_keys:
                return float(value) / (1024 * 1024 * 1024)  # GB
            return value

        new_hosts = new_state['hosts']
        old_hosts = old_state['hosts']

        difference = {
            'diff_count_hosts': (len(old_hosts), len(new_hosts), len(new_hosts) - len(old_hosts)),
            'new_hosts': ([], [], set(new_hosts) - set(old_hosts)),
            'del_hosts': ([], [], set(old_hosts) - set(new_hosts))
        }

        for key in new_state:
            if key == 'hosts':
                continue
            old_value = old_state[key]
            new_value = new_state[key]
            difference[key] = (
                in_units(key, old_value),
                in_units(key, new_value),
                in_units(key, new_value - old_value)
            )

        return difference

    @staticmethod
    @logger.time_counting()
    def count_group_diff(left_data, right_data, data_paths):
        """Compare two nested dicts along a list of dotted paths.

        :param left_data: left-hand nested dict.
        :param right_data: right-hand nested dict.
        :param data_paths: dotted key paths such as ``'tags.ctype'``.
        :return: dict mapping every path to ``None`` (equal, or the path could
            not be resolved on both sides) or to the result of
            ``States.diff_field`` for the two values.
        """
        diff_data = {x: None for x in data_paths}
        for path in data_paths:
            left_field = left_data
            right_field = right_data

            try:
                for key in path.split('.'):
                    left_field = left_field[key]
                    right_field = right_field[key]
            except (KeyError, TypeError):
                # KeyError: the key is missing on one side.
                # TypeError: an intermediate value is not a mapping (e.g.
                # None or a list). Skip this path, don't abort the whole diff.
                continue

            if left_field != right_field:
                diff_data[path] = States.diff_field(left_field, right_field)

        print('diff_data: {}'.format(diff_data))

        return diff_data

    @staticmethod
    def diff_field(left_field, right_field):
        if left_field == right_field:
            return left_field, right_field, None

        if isinstance(left_field, int) or isinstance(left_field, float):
            return left_field, right_field, right_field - left_field
        elif isinstance(left_field, list) and isinstance(right_field, list):
            diff_add = set(right_field).difference(set(left_field))
            diff_del = set(left_field).difference(set(right_field))

            diff = ['-{}'.format(x) for x in diff_del]
            diff.extend(['+{}'.format(x) for x in diff_add])

            if not diff:
                return None

            return left_field, right_field, diff
        return left_field, right_field, right_field

    @staticmethod
    def simplify_volumes(volumes):
        simplified = []
        for volume in sorted(volumes, key=lambda x: x['guest_mp']):
            new_volume = []
            for k in sorted(volume):
                if k not in ('guest_mp', 'host_mp_root', 'quota', 'symlinks'):
                    continue

                if k in ('quota',):
                    new_volume.append('{} = {:.2f}'.format(k, volume[k] / 1024. / 1024. / 1024.))
                elif k in ('symlinks',):
                    new_volume.append('{} = {}'.format(k, sorted(volume[k])))
                else:
                    new_volume.append('{} = {}'.format(k, volume[k]))
            simplified.append('<br>&nbsp;&nbsp;&nbsp;&nbsp;'.join(new_volume))
        return simplified

    @staticmethod
    def tag_name_by_number(tag_number):
        tag_number = str(tag_number)
        parts = [x for x in tag_number.split('00') if x]
        return 'stable-{}-r{}'.format(parts[0].strip('0'), parts[1].strip('0'))


@logger.time_counting()
@retry.retry(tries=5, delay=1)
def gencfg_api(url, udata=None, data=None, json=None, timeout=None):
    """GET a gencfg API path and return the decoded JSON response.

    :param url: path relative to ``API_GENCFG_URL``; a leading ``/`` is added
        when missing.
    :param udata: optional dict appended to the url as ``key=value`` query
        arguments (values are not url-escaped).
    :param data: forwarded to ``requests.get`` as ``data``.
    :param json: forwarded to ``requests.get`` as ``json``; inside this
        function the name shadows the ``json`` module.
    :param timeout: request timeout; falls back to ``API_GENCFG_URL_TIMEOUT``.
    :return: decoded JSON, or ``{}`` when a ``ValueError`` is raised (for
        example by a body that is not valid JSON).
    """
    if not url.startswith('/'):
        url = '/' + url
    query_args = udata or {}
    effective_timeout = timeout or API_GENCFG_URL_TIMEOUT
    try:
        for name, value in query_args.items():
            delimiter = '&' if '?' in url else '?'
            url = '{}{}{}={}'.format(url, delimiter, name, value)
        full_url = API_GENCFG_URL + url
        print('DEBUG {}: {}'.format(full_url, json))
        response = requests.get(full_url, data=data, json=json, timeout=effective_timeout)
        return response.json()
    except ValueError:
        return {}


@logger.time_counting()
@retry.retry(tries=5, delay=1)
def _get_gencfg_api(url, **kwargs):
    """GET a path on ``NEW_API_GENCFG_URL`` and return the decoded JSON.

    :param url: path; a leading ``/`` is added when missing.
    :param kwargs: appended to the url as ``key=value`` query arguments
        (values are not url-escaped).
    :return: decoded JSON, or ``{}`` on any exception.
    """
    # NOTE(review): every exception is turned into {} inside the function, so
    # the retry.retry decorator presumably never sees a failure and never
    # retries -- confirm whether that is intended.
    if not url.startswith('/'):
        url = '/' + url
    try:
        for name, value in kwargs.iteritems():
            delimiter = '&' if '?' in url else '?'
            url = '{}{}{}={}'.format(url, delimiter, name, value)
        response = requests.get(NEW_API_GENCFG_URL + url, timeout=NEW_API_GENCFG_URL_TIMEOUT)
        return response.json()
    except Exception:
        return {}


@logger.time_counting()
@retry.retry(tries=5, delay=1)
def _retry_get_gencfg_url(pathquery):
    """GET ``API_GENCFG_URL`` + ``pathquery`` and return the decoded JSON.

    :param pathquery: path with optional query; a leading ``/`` is added when
        missing.
    :return: decoded JSON, or ``{}`` on any exception.
    """
    # NOTE(review): every exception is turned into {} inside the function, so
    # the retry.retry decorator presumably never sees a failure and never
    # retries -- confirm whether that is intended.
    target = pathquery if pathquery.startswith('/') else '/' + pathquery
    try:
        response = requests.get(API_GENCFG_URL + target, timeout=API_GENCFG_URL_TIMEOUT)
        return response.json()
    except Exception:
        return {}


@logger.time_counting()
@retry.retry(tries=3, delay=1)
def _retry_get_staff_url(pathquery, token):
    """GET ``API_STAFF_URL`` + ``pathquery`` with OAuth and return the JSON.

    :param pathquery: path with optional query; a leading ``/`` is added when
        missing.
    :param token: OAuth token sent in the ``Authorization`` header.
    :return: decoded JSON, or ``{}`` on any exception.
    """
    # NOTE(review): every exception is turned into {} inside the function, so
    # the retry.retry decorator presumably never sees a failure and never
    # retries -- confirm whether that is intended.
    target = pathquery if pathquery.startswith('/') else '/' + pathquery
    auth_headers = {'Authorization': 'OAuth {}'.format(token)}
    try:
        response = requests.get(
            API_STAFF_URL + target,
            headers=auth_headers,
            timeout=API_STAFF_URL_TIMEOUT
        )
        return response.json()
    except Exception:
        return {}
