from __future__ import print_function

import re
import os
import sys
import socket
import argparse
import multiprocessing

try:
    from urllib.parse import quote
except ImportError:
    from urllib import quote

import yaml
import requests
import yp.client


def _http(cache={}):
    if None not in cache:
        cache[None] = requests.Session()
    return cache[None]


def _yp(cluster, token, cache={}):
    clusters = {
        "sas": "sas.yp.yandex-team.ru:8090",
        "man": "man.yp.yandex-team.ru:8090",
        "vla": "vla.yp.yandex-team.ru:8090",
        "sas.test": "sas-test.yp.yandex-team.ru:8090",
        "man.pre": "man-pre.yp.yandex-team.ru:8090",
    }
    assert cluster in clusters, 'invalid YP cluster {!r}'.format(cluster)
    if cluster not in cache:
        cache[cluster] = yp.client.YpClient(clusters[cluster], config={"token": token})
    return cache[cluster]


def _mod_port(port, mod):
    return int(port) + int(mod) if mod.startswith("+") or mod.startswith("-") else int(mod)


def resolve_dns(scheme, host, port):
    """[scheme://]hostname[:port]

    Yield ``(scheme, address, port)`` for every IPv6/IPv4 entry returned by
    ``getaddrinfo``, with IPv6 results first. Other address families are
    skipped. Duplicates are not removed here.
    """
    wanted_families = (socket.AF_INET6, socket.AF_INET)
    infos = sorted(socket.getaddrinfo(host, int(port)), key=lambda info: info[0] != socket.AF_INET6)
    for family, _socktype, _proto, _canonname, sockaddr in infos:
        if family not in wanted_families:
            continue
        yield scheme, sockaddr[0], sockaddr[1]


def resolve_raw(scheme, host, port, path):
    """[scheme://]ip[:port]

    Return ``[(scheme, host, port)]`` when *host* is a literal IPv6/IPv4
    address; otherwise fall back to DNS resolution. A zero port (e.g. the
    ``"+0"`` default) becomes 443 for https/h2 and 80 for anything else.
    *path* is ignored.
    """
    def is_literal(family):
        try:
            socket.inet_pton(family, host)
        except Exception:
            return False
        return True

    effective_port = int(port) or (443 if scheme in ('https', 'h2') else 80)
    if is_literal(socket.AF_INET6) or is_literal(socket.AF_INET):
        return [(scheme, host, effective_port)]
    return resolve_dns(scheme, host, effective_port)


def resolve_gencfg(scheme, group, port, path):
    """[scheme://][group@]gencfg[:port or :+mod]

    Yield ``(scheme, ip, port)`` for the alive MTN instances of a gencfg group,
    as reported by clusterstate. Each instance's backbone IPv6 address is
    preferred over IPv4, and instances with neither are skipped. The port
    comes from the text after the first ``:`` of the instance key, adjusted
    by *port* (see ``_mod_port``).
    """
    # TODO /group/%s/stable-X-rY/mtn
    assert not path, 'resolving a specific tag to MTN addresses is not supported'
    url = "https://clusterstate.yandex-team.ru/group/%s/alive/mtn" % quote(group)
    instances = _http().get(url).json()
    for instance, info in instances.items():
        backbone = info["interfaces"].get("backbone", {})
        address = backbone.get("ipv6addr") or backbone.get("ipv4addr")
        if not address:
            continue
        instance_port = instance.partition(':')[2]
        yield scheme, address, _mod_port(instance_port, port)


def resolve_gencfg_nomtn_tag(scheme, group, port, path):
    """[scheme://][group@]gencfg-no-mtn[:port or :+mod]/trunk or /stable-X-rY

    Yield ``(scheme, ip, port)`` for the instances of a gencfg group as listed
    by the gencfg API for the given revision (``/trunk`` or a stable tag).
    An instance's IPv6 address is preferred and IPv4 is used as a fallback,
    consistent with ``resolve_gencfg``. Instances with neither are skipped.
    Each instance's own port is adjusted by *port* (see ``_mod_port``).

    Raises AssertionError (on first iteration) for a malformed tag.
    """
    if path != "/trunk":
        # Anchor the whole tag: *path* is interpolated into the API URL, so
        # trailing junk such as "/../x" must not pass validation.
        assert re.match(r"/stable-\d+-r\d+\Z", path), "invalid tag {!r}, expected '/trunk' or '/stable-X-rY'".format(path)
        path = "/tags" + path
    url = "http://api.gencfg.yandex-team.ru{}/searcherlookup/groups/{}/instances".format(path, quote(group))
    for d in _http().get(url).json()["instances"]:
        # Previously this read "ipv6addr" twice, so IPv4-only instances were dropped.
        ip = d.get("ipv6addr") or d.get("ipv4addr")
        if ip:
            yield scheme, ip, _mod_port(d["port"], port)


def resolve_gencfg_nomtn(scheme, group, port, path):
    """[scheme://][group@]gencfg-no-mtn[:port or :+mod]

    With a *path* (a tag), delegate to ``resolve_gencfg_nomtn_tag``.
    Otherwise fetch the group's CURRENT instance list from clusterstate right
    away, and return a lazy iterator that DNS-resolves each instance's host
    (element 0) at its port (element 1) adjusted by *port*.
    """
    if path:
        return resolve_gencfg_nomtn_tag(scheme, group, port, path)
    url = "https://clusterstate.yandex-team.ru/api/v1/groups/%s/CURRENT" % quote(group)
    instances = _http().get(url).json()["current"]["instances"]
    return (
        entry
        for instance in instances
        for entry in resolve_dns(scheme, instance[0], _mod_port(instance[1], port))
    )


def resolve_yp(scheme, endpoint_set, cluster, port, path):
    """[scheme://][endpoint-set@]yp[-dc[-dc]][:port or :+mod]

    Yield ``(scheme, ip6_address, port)`` for every endpoint of *endpoint_set*.
    Plain ``yp`` queries sas, man and vla. ``yp-a-b`` queries the clusters
    named after the first ``-``. The token is read from the ``YP_TOKEN``
    environment variable; *path* is ignored.
    """
    token = os.environ.get('YP_TOKEN')
    assert token, "YP_TOKEN not specified, cannot resolve YP endpoint sets"
    query = '[/meta/endpoint_set_id] = "{}"'.format(endpoint_set)
    if cluster == 'yp':
        datacenters = ['sas', 'man', 'vla']
    else:
        datacenters = cluster.split('-')[1:]
    for dc in datacenters:
        rows = _yp(dc, token).select_objects("endpoint", selectors=["/meta", "/spec"], filter=query)
        for _meta, spec in rows:
            yield scheme, spec["ip6_address"], _mod_port(spec["port"], port)


def resolve(scheme, key, host, port, path):
    """Resolve one parsed group spec into a set of ``(scheme, ip, port)``.

    *host* selects the resolver:
    - ``gencfg``: gencfg MTN.
    - ``gencfg-no-mtn``: gencfg without MTN.
    - ``yp`` / ``yp-...``: YP endpoint sets.
    - anything else: a literal IP or a DNS name.
    """
    if host == 'gencfg':
        found = resolve_gencfg(scheme, key, port, path)
    elif host == 'gencfg-no-mtn':
        found = resolve_gencfg_nomtn(scheme, key, port, path)
    elif host == 'yp' or host.startswith('yp-'):
        found = resolve_yp(scheme, key, host, port, path)
    else:
        found = resolve_raw(scheme, host, port, path)
    return set(found)


def _resolve_mpcb(it):
    """Worker callback: return ``(it, resolve(*it))``, or ``(it, exception)``.

    The first failure is discarded and resolution is retried once
    (presumably to ride out transient errors). The exception from the second
    attempt is returned rather than raised, so one bad group does not abort
    the whole pool.
    """
    for attempt in (1, 2):
        try:
            return it, resolve(*it)
        except Exception as error:
            if attempt == 2:
                return it, error


def parse_pseudo_url(key, url, port=""):
    """[scheme://][key@]host[:port][/path]

    Split a backend group spec into ``(scheme, key, host, port, path)``.

    :param key: default key (group / endpoint set name) used when *url* has
        no ``key@`` part; ``main`` passes the backend name here.
    :param url: the spec string; IPv6 literals must be bracketed (``[::1]``).
    :param port: default port or modifier when *url* has none; ``"+0"`` is
        used if that is empty too.
    :returns: ``(scheme, key, host, port, path)``. *scheme* defaults to
        ``"http"``, brackets are stripped from *host*, and *path* keeps its
        leading ``/`` (or is ``""``).
    :raises ValueError: if *url* cannot be split (e.g. it is not a string).
    """
    try:
        # Split on the FIRST "://" so that a path containing "://" is kept
        # intact; with no scheme at all, the whole string is host (+ path).
        scheme, sep, rest = url.partition('://')
        if not sep:
            scheme, rest = '', url
        host, slash, path = rest.partition('/')
        # Strip "key@" before looking for ":port", so that "key@[::1]" (no
        # port) is not mistaken for host "key@[:" with port "1]".
        if '@' in host and not (host.startswith('[') and host.endswith(']')):
            key, _, host = host.partition('@')
        if ':' in host and not (host.startswith('[') and host.endswith(']')):
            host, _, port = host.rpartition(':')
        if host.startswith('[') and host.endswith(']'):
            host = host[1:-1]
        return scheme or 'http', key, host, port or "+0", slash + path
    except Exception:
        raise ValueError("invalid url: {!r}".format(url))


def main():
    """Command-line entry point: resolve backend groups and print a YAML list.

    Reads backend definitions: a YAML mapping of backend name to one pseudo-URL
    or a list of them (see ``parse_pseudo_url``). Every group is resolved in a
    pool of four worker processes. The output is a ``backends:`` mapping with
    one YAML anchor per non-empty name.

    With ``--update``, order from the previous output is preserved where
    possible. Names that refer to an unresolved group are copied from that
    output unchanged.

    Exits with a non-zero status listing the failures if any group could not
    be resolved.
    """
    # Best-effort default token: any failure to read the file means "no token".
    try:
        with open(os.path.expanduser('~/.yp/token')) as fd:
            yp_token = fd.read().strip()
    except Exception:
        yp_token = ""

    argp = argparse.ArgumentParser(
        description="Construct a backend list for the proxy action. To use the output, put "
                    "`#include: path` and `proxy: *backend_name` into your config file.")
    argp.add_argument(
        "--update", metavar="backends.yaml", nargs="?", type=argparse.FileType("r"),
        help="Read the previous version of the backend list and preserve positions of as "
             "many backends common to both versions as possible; using this mode with "
             "`proxy`'s hash-by option provides weighted rendezvous hashing. Also, if "
             "resolving a group fails, everything that refers to it is kept unchanged.")
    argp.add_argument(
        "--yp-only", action="store_true",
        help="Only resolve groups consisting purely of YP endpoint sets; the rest are "
             "ignored or copied from the previous version.")
    argp.add_argument(
        "--yp-token", metavar="TOKEN", action="store", default=os.environ.get('YP_TOKEN') or yp_token,
        help="An OAuth token authorized to read data from YP. Can also be set by "
             "~/.yp/token or the YP_TOKEN environment variable.")
    argp.add_argument(
        "input", metavar="backends.in.yaml", nargs="?", type=argparse.FileType("r"), default=sys.stdin,
        help="Path to the backend definitions. See example.yaml.")
    args = argp.parse_args()
    # resolve_yp() reads the token from the environment; set it before the
    # worker pool below is created.
    os.environ['YP_TOKEN'] = args.yp_token

    # NOTE(review): both YAML inputs are local files chosen by the user;
    # yaml.safe_load would suffice unless custom tags are relied upon.
    # An empty input document loads as None: treat it as "no definitions".
    definitions = yaml.load(args.input, Loader=yaml.FullLoader) or {}
    # name -> set of (scheme, key, host, port, path); the backend name is the
    # default key for specs without an explicit "key@".
    defns = {
        k: {parse_pseudo_url(k, group) for group in (v if isinstance(v, list) else [v])}
        for k, v in definitions.items()
    }
    if not defns:
        sys.exit("No backends defined.")
    update = {}
    if args.update is not None:
        # When every list was empty, this script prints a bare "backends:",
        # which loads as None; treat that as "nothing to preserve".
        update = yaml.load(args.update, Loader=yaml.FullLoader)['backends'] or {}

    # v[2] is the host part: with --yp-only, only YP specs are resolved.
    need_resolving = {v for vs in defns.values() for v in vs if not args.yp_only or v[2] == 'yp' or v[2].startswith('yp-')}
    resolved = {}
    failed = {}
    pool = multiprocessing.Pool(4)
    try:
        for expr, result in pool.imap_unordered(_resolve_mpcb, need_resolving):
            if isinstance(result, Exception):
                failed[expr] = "{} -- {}".format(type(result).__name__, result)
            elif not result:
                failed[expr] = "no backends"
            else:
                resolved[expr] = result
    finally:
        # Shut the workers down explicitly (Pool only became a context manager
        # in Python 3.3). All results have been consumed on the success path.
        pool.terminate()
        pool.join()
    # A name keeps its previous list if any of its groups is unresolved
    # (failed, or skipped by --yp-only) and a previous list exists; otherwise
    # it gets the union of its resolved groups, formatted as URLs.
    lists = {
        name: set(update[name]) if name in update and not all(group in resolved for group in groups) else {
            '{}://{}:{}'.format(scheme, '[' + ip + ']' if ':' in ip else ip, port)
            for group in groups for scheme, ip, port in resolved.get(group, [])
        } for name, groups in defns.items()
    }

    def merge(old, new):
        """Order the set *new* so that it keeps as many positions of *old* as possible.

        Walks *old* (a list, which is mutated: elements are popped off its
        tail) slot by slot. A backend still in *new* keeps its slot. A slot
        whose backend disappeared is filled by the next not-yet-seen backend
        (in sorted order), or, when none remain, by the tail of *old*. Tail
        elements that are also gone are discarded. Leftover new backends are
        appended at the end.
        """
        really_new = iter(sorted(new - set(old)))
        for b in old:
            p = b
            while b not in new:
                b = next(really_new, None) or old.pop(-1)
                if b == p:
                    break  # popped everything remaining in `old`
            else:
                yield b
        for b in really_new:
            yield b

    print("backends:")
    for name, backends in sorted(lists.items()):
        if backends:
            print("  {0}: &{0}".format(name), *merge(update.get(name, []), backends), sep="\n  - ")
    if failed:
        sys.exit("Failed:" + "".join("\n  - {}://{}@{}:{}{}: ".format(*k) + v for k, v in failed.items()))
