import datetime
import re
import json
import logging
import subprocess
import os.path
import itertools
import requests
from collections import defaultdict

from retry import retried

logger = logging.getLogger(__name__)

# Host groups of bsbackup machines. The hosts within one group hold duplicate
# backups, so a listing from any single host of a group is enough for that group.
BSBACKUP_GROUPS = ['yabs_bsbackup01_2', 'yabs_bsbackup02']


def get_all_table_sources(date_str=None, tmp_dir='.'):
    """Build the table -> source server mapping from all bsbackup groups.

    date_str: backup date as 'YYYYMMDD'; handed to iter_table_sources, which
        picks its own default when this is None.
    tmp_dir: directory where the per-group rsync listings are written.

    Returns a dict keyed by 'instance.database.table' whose values are the
    backed-up server names (bsdb/bstatXX/bsistatXX/bspstatXX/...). If several
    servers produce the same key, the last one yielded wins.
    """
    group_listings = (
        iter_bsbackup_group_lines(group, tmp_dir=tmp_dir)
        for group in BSBACKUP_GROUPS
    )
    all_lines = itertools.chain.from_iterable(group_listings)
    return dict(iter_table_sources(all_lines, date_str))


def iter_bsbackup_group_lines(group, tmp_dir='.'):
    """Yield the lines of an `rsync --list-only` listing for one bsbackup group.

    The hosts of the group are tried one after another until a listing
    succeeds; the listing is stored in '<tmp_dir>/<group>.rsync_list' and then
    streamed back line by line.

    group: name of the host group, resolved to hosts via _group2hosts.
    tmp_dir: directory for the listing file.

    Raises RuntimeError if the group resolves to no hosts, and
    subprocess.CalledProcessError if the listing failed on every host.
    """
    out_path = os.path.join(tmp_dir, '%s.rsync_list' % group)

    @retried(max_retries=5, initial_timeout=10, sleep=True, exceptions=(subprocess.CalledProcessError,))
    def run_rsync(host, timeout=120):
        """List the mysql_backups rsync module of `host` into out_path."""
        timeout = int(timeout)
        cmdline = [
            'timeout', '-sKILL', str(timeout + 5),  # rsync is known to hang!
            'rsync', '--list-only', '--timeout=%s' % timeout, '--contimeout=%s' % timeout, '--recursive', 'rsync://{}/mysql_backups'.format(host)
        ]

        # Opening with 'w' truncates the file, so every attempt starts clean.
        with open(out_path, 'w') as out:
            logger.info("Running %s ...", ' '.join(cmdline))
            subprocess.check_call(cmdline, stdout=out)

    hosts = _group2hosts(group)
    if not hosts:
        # Without this guard the loop below would be skipped silently and a
        # stale (or missing) listing file would be read instead.
        raise RuntimeError("No hosts found for bsbackup group %s" % group)

    for n, host in enumerate(hosts):
        try:
            run_rsync(host)
        except subprocess.CalledProcessError:
            if n == len(hosts) - 1:
                # That was the last candidate: nothing left to fall back to.
                raise
            logger.info("Host %s ultimately failed, trying the next one.", host)
        else:
            break

    with open(out_path) as out:
        for line in out:
            yield line


def iter_table_sources(bsbackup_lines_iter, needed_date_str=None):
    """Yield ('instance.database.table', server) pairs found in an rsync listing.

    bsbackup_lines_iter: iterable of `rsync --list-only` output lines. Each
        useful line has five whitespace-separated columns, the last being a
        path of the form 'server/<ignored>/arcYYYYMMDD/database/table.tgz'.
    needed_date_str: backup date as 'YYYYMMDD'. When None, the date ten hours
        before the current local time is used.

    Lines that do not have five columns, paths that do not have five
    components and entries for other dates are skipped. Servers whose instance
    is None (see get_server_instance) are not reported.
    """
    if needed_date_str is None:
        needed_date_str = (datetime.datetime.now() - datetime.timedelta(hours=10)).strftime('arc%Y%m%d')
    else:
        needed_date_str = 'arc' + needed_date_str

    table_servers = defaultdict(set)
    all_servers = set()

    for line in bsbackup_lines_iter:
        # Columns: permissions, size, date, time, path. Splitting at most four
        # times keeps a path containing spaces in one piece, and the length
        # check skips blank or banner lines instead of failing on unpacking.
        fields = line.split(None, 4)
        if len(fields) != 5:
            logger.debug("Skipping unparsable rsync line: %r", line)
            continue
        path = fields[4].rstrip()  # drop the trailing newline

        path_fields = path.split('/')
        if len(path_fields) != 5:
            continue
        server, _, date_str, database, filename = path_fields
        if date_str != needed_date_str:
            continue
        table, ext = os.path.splitext(filename)
        if ext != '.tgz':
            logger.warning("Strange file extension %s: %s", ext, path)
            continue

        dbtable = database + '.' + table
        table_servers[dbtable].add(server)
        all_servers.add(server)

    source_identifier = SourceIdentifier(all_servers)

    # .items() rather than .iteritems() so this runs on Python 2 and 3 alike.
    for dbtable, servers in table_servers.items():
        for source_instance, source_server in source_identifier.iter_real_servers(servers, dbtable):
            if source_instance is not None:
                yield source_instance + '.' + dbtable, source_server


# Server-name prefix that SourceIdentifier treats specially in its replica
# heuristic; also a key of _SERVER_DESCRS below.
BSDB = 'bsdb'

# Known MySQL server-name prefixes -> prefix of the corresponding instance name.
# _split_server() recognises a server by these keys, and get_server_instance()
# appends the digits that follow the key in the server name to the mapped value.
# A value of None means get_server_instance() returns None for such servers.
# NOTE(review): 'bsistat01' and 'bsistat03' both map to 'yabsistat' - confirm
# this is intended.
_SERVER_DESCRS = {
    BSDB: 'yabs',
    'bstat': 'yabsst',
    'bspstat': 'yabstarget',
    'bscoll': None,
    'bsint': None,
    'bsistat01': 'yabsistat',
    'bsistat02': 'yabsistat02',
    'bsistat03': 'yabsistat',
}


class SourceIdentifier(object):
    """Decides which of the servers holding a table's backup are real sources.

    Constructed from every server seen in the backup listing, so that the
    number of servers carrying a given table can be compared with the total
    number of servers sharing the same name prefix.
    """

    def __init__(self, all_servers):
        """all_servers: iterable of all server names present in the listing."""
        self.all_servers_by_prefix = _group_servers_by_prefix(all_servers)

    def iter_real_servers(self, servers, dbtable_for_logging):
        """Yield (instance, server) for the servers accepted as table sources.

        servers: server names whose backups contain the table.
        dbtable_for_logging: 'database.table', used only in the log message.

        The instance comes from get_server_instance and may be None.
        """
        servers = set(servers)
        servers_by_prefix = _group_servers_by_prefix(servers)
        also_on_bsdb = BSDB in servers_by_prefix
        # .items() rather than .iteritems() so this runs on Python 2 and 3 alike.
        for prefix, prefix_servers in servers_by_prefix.items():
            all_prefix_servers = self.all_servers_by_prefix[prefix]
            # FIXME Wild heuristics.
            # A table that also exists on bsdb, but is present on fewer than
            # half (rounded down) of the servers with this prefix, is taken to
            # be a bsdb replica placed into backup by mistake and is ignored.
            # Floor division is spelled out to keep the threshold the same on
            # Python 2 and Python 3.
            if not also_on_bsdb or len(prefix_servers) >= len(all_prefix_servers) // 2:
                for server in prefix_servers:
                    yield get_server_instance(server), server
            else:
                logger.info(
                    "Suspicious table %s in backups of %s (also in %s), likely bsdb replica",
                    dbtable_for_logging, ', '.join(sorted(prefix_servers)), ', '.join(sorted(servers - prefix_servers))
                )


def _group_servers_by_prefix(servers):
    """Bucket server names by the known prefix they start with.

    Returns a defaultdict(set) mapping prefix -> set of server names.
    Propagates the RuntimeError raised by _split_server for unknown servers.
    """
    grouped = defaultdict(set)
    for name in servers:
        grouped[_split_server(name)[0]].add(name)
    return grouped


def _split_server(server):
    """Split a server name into (known prefix, digits following the prefix).

    The prefix is the first key of _SERVER_DESCRS that the name starts with;
    the second element is the (possibly empty) run of digits right after it.

    Raises RuntimeError when the name starts with none of the known prefixes.
    """
    known_prefix = next((p for p in _SERVER_DESCRS if server.startswith(p)), None)
    if known_prefix is None:
        raise RuntimeError("Unknown MySQL server %s" % server)

    remainder = server[len(known_prefix):]
    # The pattern can match the empty string, so a match object is always returned.
    digits = re.match(r'(\d*).*?', remainder).group(1)
    return known_prefix, digits


def get_server_instance(server):
    """Return the instance name for a server, or None if it has none.

    The name is the _SERVER_DESCRS entry for the server's prefix followed by
    the digits that _split_server extracted from the server name.
    """
    prefix, suffix = _split_server(server)
    inst_prefix = _SERVER_DESCRS[prefix]
    if inst_prefix is None:
        return None
    return inst_prefix + suffix


@retried(max_retries=5, initial_timeout=5, sleep=False, exceptions=(Exception,))
def _group2hosts(group, timeout=1):
    """Fetch the set of host names belonging to a host group.

    group: group name, appended to the groups2hosts API URL.
    timeout: timeout passed to requests.get.

    Returns the set of lines in the response body. A non-success HTTP status
    raises through raise_for_status().
    """
    url = "https://c.yandex-team.ru/api/groups2hosts/{}".format(group)
    reply = requests.get(url, timeout=timeout)
    reply.raise_for_status()
    host_lines = reply.text.splitlines()
    return set(host_lines)


if __name__ == '__main__':
    # Script mode: print the mapping for the default date to stdout as
    # indented JSON with sorted keys.
    import sys

    log_format = '%(asctime)s %(levelname)s %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_format)
    table_sources = get_all_table_sources()
    json.dump(table_sources, sys.stdout, indent=4, sort_keys=True)
