import os
import json
import logging
import tempfile
import collections

import requests
import retry

import plugin
import utils


# Timeout, in seconds, passed as ``timeout=`` to the ``requests.get`` calls
# this module makes against the intl2 instance (shutdown and getconfig).
_REQUESTS_TIMEOUT = 30


class Plugin(plugin.BasePlugin):
    """Reconfigure an intl2 instance from a per-slot mapping.

    Uses ``self.host`` / ``self.port`` to address the instance (presumably
    populated by ``plugin.BasePlugin`` -- confirm).
    """

    def apply_config(self, content):
        """Render a new config from *content*; shut the instance down if its mapping differs.

        Only entries whose value is truthy and contains ``'group_number'`` are used.
        After writing the new config, the mapping reported by the running instance
        is compared with the requested one; on mismatch ``stop()`` is called.

        Raises:
            RuntimeError: if *content* has no usable entries.
        """
        wanted = dict(
            (slot, settings)
            for slot, settings in content.iteritems()
            if settings and 'group_number' in settings
        )
        if not wanted:
            logging.debug('EMPTY CONFIG, SKIP!')
            raise RuntimeError('EMPTY CONFIG')

        _put_config(self.host, self.port, wanted)

        if self.collect_status() == wanted:
            logging.info('config did not change')
            return

        logging.debug('config %s', json.dumps(content, indent=4))
        logging.debug('mapping %s', json.dumps(wanted, indent=4))
        logging.info('reconfiguring')
        self.stop()

    def collect_status(self):
        """Return the slot mapping currently reported by the running instance."""
        return get_intl2_status(self.host, self.port)

    def stop(self):
        """Ask the instance to shut down via ``/admin?action=shutdown``.

        Raises:
            requests.HTTPError: if the instance answers with an error status.
        """
        shutdown_url = _url(self.host, self.port, 'admin', action='shutdown')
        response = requests.get(shutdown_url, timeout=_REQUESTS_TIMEOUT)
        response.raise_for_status()


def _put_config(host, port, mapping):
    """Render the instance's config template with *mapping* and save it beside the template.

    The template is read from ``'<shortname>:<port>.cfg'`` (relative to the current
    directory). The rendered config is written to a temporary file in the current
    directory and then renamed to ``'<template>.changed'`` with mode 0644. On POSIX a
    rename within one directory is atomic, so the target never holds a partial file.
    If anything fails, the temporary file is removed and the exception propagates.
    """
    template_filename = '{}:{}.cfg'.format(utils.get_shortname(host), port)
    logging.debug('config template %s', template_filename)
    with open(template_filename, 'r') as template_file:
        # Create the temp file only once the template is open, so a missing
        # template cannot leave an orphaned temp file behind.
        config_file = tempfile.NamedTemporaryFile(mode='w', dir='./', delete=False)
        installed = False
        try:
            try:
                prepare_intl2_config(template_file, config_file, mapping)
            finally:
                # close() flushes and releases the descriptor before the rename.
                config_file.close()
            # 0o644 == rw-r--r--; the 0o prefix is valid on Python 2.6+ and 3.
            os.chmod(config_file.name, 0o644)
            os.rename(config_file.name, template_filename + '.changed')
            installed = True
        finally:
            if not installed and os.path.exists(config_file.name):
                os.remove(config_file.name)


# Header of the Collection section whose SearchSource subsections
# apply_mapping() rewrites and get_mapping() reads back.
COLLECTION = 'Collection autostart="must" meta="yes" id="yandsearch"'
# Key of the marker entry apply_mapping() adds to every mapped SearchSource;
# its value is the JSON-encoded SlotMapping, decoded again by get_mapping().
# (Presumably the leading '#' makes the config reader treat it as a comment -- confirm.)
DATA_COMMENT = '#DATA'

# Mapping recorded for one slot: slot (SearchSource occurrence index as a
# string), target group number, and the timestamp used for RequiredBaseTimestamp.
SlotMapping = collections.namedtuple('SlotMapping', ['slot', 'group_number', 'timestamp'])


def get_path_str(path):
    """Join a list of section headers into one lookup key.

    Every component is prefixed with ``'->'``: ``['', 'A']`` gives ``'->->A'``
    and an empty list gives ``''``.
    """
    # str.join avoids the quadratic cost of repeated string concatenation.
    return ''.join('->' + part for part in path)


def _parse_config(config):
    """Flatten nested ``<Section>`` config text into a list of entries.

    Each entry is a tuple ``(path, index, key, value)``:
      * ``path``  -- list of enclosing section headers, starting with the root ``''``;
      * ``index`` -- occurrence number of that exact section path (0 for the first
        section with this path, 1 for the next one, ...);
      * ``key``   -- the first space-separated token of the line ('' for blank lines);
      * ``value`` -- the rest of the line, or ``None`` if the line has no space.

    Example entry::

        (['', 'Collection autostart="must" meta="yes" id="yandsearch"', 'ScatterOptions'],
         0,
         'TimeoutTable',
         "${InfoTimeoutTable or '10s'}")

    Section open/close lines produce no entries of their own. A root-level sentinel
    ``([''], 0, '', '')`` is appended at the end.

    :param config: iterable of text lines; surrounding whitespace is stripped.
    """
    stripped = [raw.strip() for raw in config]
    entries = []
    path = ['']
    occurrences = {get_path_str(path): 0}

    for line in stripped:
        if line.startswith('</'):
            # Closing tag: its name is not checked against the section being closed.
            path.pop()
        elif line.startswith('<'):
            path.append(line.split('<')[1].rsplit('>', 1)[0])
            path_key = get_path_str(path)
            occurrences[path_key] = occurrences.get(path_key, -1) + 1
        else:
            key, separator, rest = line.partition(' ')
            value = rest if separator else None
            entries.append((list(path), occurrences[get_path_str(path)], key, value))

    entries.append(([''], 0, '', ''))
    return entries


def replace_shard_group(shard, group, timestamp):
    """Rewrite a single PrimusList shard name for a mapped slot.

    First renames ``CallistoSlotsTier0`` to ``WebFreshTier``, then substitutes
    *timestamp* for the ``0000000000`` placeholder. Finally it replaces the second
    ``-``-separated field with *group*. Raises ``IndexError`` when the renamed
    shard has fewer than three ``-``-separated fields.
    """
    ts_text = str(timestamp)
    group_text = str(group)

    renamed = shard.replace('CallistoSlotsTier0', 'WebFreshTier').replace('0000000000', ts_text)
    fields = renamed.split('-', 2)
    return '-'.join([fields[0], group_text, fields[2]])


def apply_mapping(config, mapping):
    """Rewrite parsed config entries (from ``_parse_config``) for a slot mapping.

    Entries inside ``['', COLLECTION, 'SearchSource']`` sections are treated as
    per-slot sources; the section's occurrence index, as a string, is the slot key
    looked up in *mapping*:

      * right before the first such entry, a Collection-level
        ``UseRequiredBaseTimestamp`` entry is inserted;
      * entries of slots missing from *mapping* are dropped from the output;
      * ``PrimusList`` is emitted as three entries, in this order: a ``DATA_COMMENT``
        entry holding the JSON-encoded ``SlotMapping`` (read back by ``get_mapping``),
        a ``RequiredBaseTimestamp`` entry, then the PrimusList with each shard
        rewritten by ``replace_shard_group``;
      * ``Tier`` is set to ``WebFreshTier``.

    All other entries are passed through unchanged.

    :param mapping: ``{slot_str: {'group_number': ..., 'timestamp': ...}}``.
    :raises KeyError: if a mapped slot lacks ``group_number`` or ``timestamp``.
    """
    search_source_path = ['', COLLECTION, 'SearchSource']
    use_required_base_timestamp = False
    result = []
    for cur in config:
        res = cur
        if cur[0] == search_source_path:
            slot = str(cur[1])

            if not use_required_base_timestamp:
                # NOTE(review): cur[1] is the SearchSource index, reused as the
                # Collection-level index; both are 0 in the usual single-Collection
                # layout -- confirm for other templates.
                result.append((cur[0][:2], cur[1], 'UseRequiredBaseTimestamp',
                               '''${ UseRequiredBaseTimestamp or 'yes' }'''))
                use_required_base_timestamp = True

            if slot not in mapping:
                # Sources of unmapped slots are omitted from the output entirely.
                continue

            if cur[2] == 'PrimusList':
                slot_mapping = SlotMapping(
                    slot=slot,
                    group_number=mapping[slot]['group_number'],
                    timestamp=mapping[slot]['timestamp'],
                )
                primus_list = [
                    replace_shard_group(shard, slot_mapping.group_number, slot_mapping.timestamp)
                    for shard in cur[3].split(',')
                ]
                # _asdict() is the documented namedtuple API; the __dict__ property
                # does not exist on Python 3.5.1+.
                result.append((cur[0], cur[1], DATA_COMMENT, json.dumps(slot_mapping._asdict())))
                result.append((cur[0], cur[1], 'RequiredBaseTimestamp', str(slot_mapping.timestamp)))
                res = (cur[0], cur[1], cur[2], ','.join(primus_list))

            elif cur[2] == 'Tier':
                res = (cur[0], cur[1], cur[2], 'WebFreshTier')

        result.append(res)

    return result


def get_mapping(config):
    """Recover the slot mapping that ``apply_mapping`` embedded in parsed *config*.

    Looks for ``DATA_COMMENT`` entries inside ``['', COLLECTION, 'SearchSource']``
    sections and decodes their JSON payload into ``SlotMapping`` fields.

    :returns: ``{slot: {'group_number': ..., 'timestamp': ...}}``.
    """
    search_source_path = ['', COLLECTION, 'SearchSource']
    recovered = {}
    for entry in config:
        if entry[0] != search_source_path or entry[2] != DATA_COMMENT:
            continue
        decoded = SlotMapping(**json.loads(entry[3]))
        recovered[decoded.slot] = dict(
            group_number=decoded.group_number,
            timestamp=decoded.timestamp,
        )
    return recovered


def _print_config(config):
    result = ''
    tab = '    '

    prev = ([''], 0, None, None)
    for cur in config:
        if prev[0] != cur[0] or prev[1] != cur[1]:
            prefix = ''
            for i in range(len(prev[0]) - 1):
                prefix += tab

            prev1 = list(prev[0])
            cur1 = list(cur[0])

            for x in prev[0]:
                if x in cur[0]:
                    prev1.remove(x)
            for i in reversed(prev1):
                prefix = prefix.replace(tab, '', 1)
                result += prefix + '</' + i.split(' ', 1)[0] + '>\n'

            for x in cur[0]:
                if x in prev[0]:
                    cur1.remove(x)
            for i in reversed(cur1):
                result += prefix + '<' + i + '>\n'
                prefix += tab

            if prev[0] == cur[0]:
                result += prefix.replace(tab, '', 1) + '</' + cur[0][-1].split(' ', 1)[0] + '>\n'
                result += prefix.replace(tab, '', 1) + '<' + cur[0][-1] + '>\n'

            prev = cur

        prefix = ''
        for i in range(len(cur[0]) - 1):
            prefix += tab

        line = ''
        if cur[2]:
            if cur[3]:
                line = cur[2] + ' ' + cur[3]
            else:
                line = cur[2]

        result += prefix + line + '\n'

    return result.strip()


@retry.retry(exceptions=(RuntimeError, requests.exceptions.ConnectionError), tries=5, delay=1)
def get_intl2_status(host, port):
    """Fetch the running config of an intl2 instance and return its slot mapping.

    Requests ``/yandsearch?info=getconfig``, parses the response with
    ``_parse_config`` and extracts the mapping with ``get_mapping``. The decorator
    retries up to 5 times, 1 second apart, on ``RuntimeError`` (error HTTP status)
    and on connection errors.

    :raises RuntimeError: if the instance keeps answering with an error status.
    """
    r = requests.get(_url(host, port, 'yandsearch', info='getconfig'), timeout=_REQUESTS_TIMEOUT)
    if not r.ok:
        # Format the message explicitly: exceptions do not apply logging-style
        # '%s' arguments, so the original message was never interpolated.
        raise RuntimeError('get_intl2_status failed: {} {}'.format(r.status_code, r.reason))

    config = _parse_config(r.text.splitlines())
    return get_mapping(config)


def prepare_intl2_config(template_file, output_file, mapping):
    """Render *template_file* with *mapping* applied and write it to *output_file*.

    :param template_file: readable file-like object holding the config template.
    :param output_file: writable file-like object; receives the rendered text,
        which has no trailing newline because ``_print_config`` strips its result.
    :param mapping: slot mapping as accepted by ``apply_mapping``.
    """
    entries = _parse_config(template_file.readlines())
    rendered = _print_config(apply_mapping(entries, mapping))
    # One write for the whole text; iterating a str would write one character at a time.
    output_file.write(rendered)


def _url(host, port, path, action=None, info=None):
    url = 'http://{}:{}/{}'.format(host, port, path)
    if action:
        url = '{}?action={}'.format(url, action)
    elif info:
        url = '{}?info={}'.format(url, info)
    return url
