import ConfigParser
import subprocess
import requests
import argparse
import logging
import logging.handlers
import base64
import shutil
import socket
import json
import time
import os


# Juggler push endpoint; update_juggler() POSTs JSON-encoded event batches here.
JUGGLER_URL = "http://juggler-push.search.yandex.net/api/1/batch"


def update_juggler(st, descr=None, timeout=10):
    """Push a single monitoring event to Juggler.

    The event is reported for the current host under the
    ``tickets_rotation_l7_balancer`` service. When ``NANNY_SERVICE_ID`` is
    set in the environment and a description is given, the service id is
    prepended to the description.

    Sending is best effort: network errors, non-JSON replies and rejected
    events are logged and never raised, so a monitoring outage cannot break
    the rotation loop that calls this function.

    :param st: event status passed to Juggler as-is (callers use 'OK'/'CRIT').
    :param descr: optional human-readable description of the event.
    :param timeout: seconds to wait for the Juggler API before giving up.
    :returns: True if Juggler acknowledged the event, False otherwise.
    """
    service_id = os.getenv('NANNY_SERVICE_ID', None)

    if descr is not None and service_id is not None:
        descr = "{0} {1}".format(str(service_id), descr)

    event = {
        "description": "" if descr is None else descr,
        "host": socket.gethostname(),
        "instance": "",
        "service": "tickets_rotation_l7_balancer",
        "status": st,
    }
    payload = json.dumps({'source': 'tickets_deploy', 'events': [event, ]})

    try:
        # Without a timeout a stuck Juggler endpoint would block rotation forever.
        resp = requests.post(JUGGLER_URL, data=payload, timeout=timeout)
        reply = resp.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError: the reply body was not valid JSON.
        logging.error("Error: failed to send data to juggler {0}".format(str(e)))
        return False

    # The reply is only trusted when it is a JSON object with the expected status.
    if isinstance(reply, dict) and reply.get("status") == "200 OK":
        logging.info("Juggler updated")
        return True

    logging.error("Error: failed to send data to juggler {0}".format(json.dumps(reply)))
    return False


def run_yav_deploy(path, section, oauth):
    """Run the instance's ``yav_deploy`` tool to fetch secrets for ``section``.

    Expects an executable ``<path>/yav_deploy`` and a ``<path>/yav-deploy.conf``.
    The OAuth token is read from the file ``oauth`` and handed to the tool
    through the ``YAV_TOKEN`` environment variable.

    :param path: balancer instance directory.
    :param section: section(s) of yav-deploy.conf to deploy.
    :param oauth: path to a file containing the OAuth token.
    :raises Exception: if the tool or its config is missing, the token file
        is empty, or ``yav_deploy`` exits with a non-zero status.
    """
    yav_deploy = os.path.join(path, "yav_deploy")

    if not os.path.isfile(yav_deploy) or not os.access(yav_deploy, os.X_OK):
        raise Exception("Could not find 'yav_deploy' or it is not executable")

    yav_conf = os.path.join(path, "yav-deploy.conf")

    if not os.path.isfile(yav_conf):
        raise Exception("Could not find 'yav-deploy.conf'")

    with open(oauth) as fd:
        # Token files usually end with a newline; a trailing "\n" must not
        # leak into the token value passed to the tool.
        token = fd.read().strip()

    if not token:
        raise Exception("OAuth token file {0} is empty".format(oauth))

    yav_env = os.environ.copy()
    yav_env["YAV_TOKEN"] = token

    proc = subprocess.Popen(
        [yav_deploy, "--skip-pkg", "-c", path, "--file", "yav-deploy.conf", "--sections", section],
        stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
        env=yav_env
    )
    # Empty stdin: the tool must not wait for interactive input.
    out, err = proc.communicate("")

    if proc.returncode != 0:
        # Fall back to stdout when the tool reported nothing on stderr.
        raise Exception("Failed to run yav_deploy: {0}".format(err or out))


def decode_tickets(path):
    """Decode the base64-encoded ticket keys fetched by yav_deploy.

    Reads ``<path>/yav-deploy.conf``. Every option of its ``[tickets]``
    section whose name starts with ``/`` names a secret file relative to the
    section's ``prefix`` value; anything after a ``:`` in the name is ignored.
    Each such file is base64-decoded in place, made owner-read-only (0400)
    and renamed to ``<name without extension>.key`` in the same directory.

    :param path: balancer instance directory containing ``yav-deploy.conf``.
    :returns: set of resulting ``.key`` file names, relative to ``prefix``.
    :raises Exception: if the section/prefix is missing, a secret file is
        absent or not writable, or the destination directory is not writable.
    """
    config = ConfigParser.ConfigParser()
    config.read(os.path.join(path, "yav-deploy.conf"))

    if "tickets" not in config.sections():
        raise Exception("Could not find tickets section in yav config")

    if not config.has_option("tickets", "prefix"):
        raise Exception("Could not find tickets prefix value")

    tickets_names = set()
    prefix = config.get("tickets", "prefix")
    # NOTE(review): ConfigParser lower-cases option names by default, so
    # secret file names with upper-case letters will not be found — confirm
    # key names are always lower-case.
    for option in config.options("tickets"):
        if not option.startswith("/"):
            continue

        keyname = option[1:].split(":")[0]
        secpath = os.path.join(prefix, keyname)

        if not os.path.exists(secpath):
            raise Exception("Could not find base64 encoded key {0}".format(secpath))

        if not os.access(secpath, os.W_OK):
            raise Exception("Could not write to {0}".format(secpath))

        basename = "{0}.key".format(os.path.splitext(keyname)[0])
        dstpath = os.path.join(prefix, basename)

        # The decoded key is renamed into dstpath's directory, so that is the
        # directory that must be writable (not the instance path).
        dstdir = os.path.dirname(dstpath) or "."
        if not os.access(dstdir, os.W_OK):
            raise Exception("Path is not writable: {0}".format(dstdir))

        # Key material is binary: use binary mode for both read and write.
        with open(secpath, "rb") as fd:
            decoded = base64.b64decode(fd.read())

        with open(secpath, "wb") as fd:
            fd.write(decoded)

        # Decode in place and then rename: rename replaces an existing
        # read-only dstpath from a previous run, which opening it for
        # writing would not.
        os.chmod(secpath, 0o400)
        shutil.move(secpath, dstpath)
        tickets_names.add(basename)

    return tickets_names


def check_reloaded_keys_count(prefix, counter, resp):
    """Ensure all three ranks of a ticket key were reloaded equally often.

    Key files are named ``<rank>.<prefix>`` with rank ``1st``, ``2nd`` and
    ``3rd``. Each must appear in ``counter`` and all three must carry the
    same reload count. On failure the raw balancer response (``resp.text``)
    is logged and an exception is raised.
    """
    ranked = ["{0}.{1}".format(rank, prefix) for rank in ("1st", "2nd", "3rd")]

    for name in ranked:
        if name in counter:
            continue
        logging.error("Response:\n{0}".format(str(resp.text)))
        raise Exception("No keys was found: {0}".format(name))

    counts = [counter[name] for name in ranked]
    if not (counts[0] == counts[1] == counts[2]):
        logging.error("Response:\n{0}".format(str(resp.text)))
        raise Exception("Number of reloaded tickets has diff: {0} {1} {2}".format(*counts))


def reload_balancer(port, tickets_names, force, timeout=60):
    """Ask the local balancer to reload its TLS ticket keys and verify the result.

    Calls the admin ``reload_ticket`` event (``force_reload_ticket`` when
    ``force`` is true) and parses every ``Reloading ... OK`` line of the
    reply. Every reloaded key must be one of ``tickets_names``, every name in
    ``tickets_names`` must have been reloaded, and the 1st/2nd/3rd ranks of
    each key must have been reloaded equally often.

    :param port: balancer admin port on localhost.
    :param tickets_names: set of ``.key`` file names expected to be reloaded.
    :param force: use the forced reload event.
    :param timeout: seconds to wait for the admin endpoint to answer.
    :raises Exception: on an HTTP error, empty reply, failed or unknown key
        reload, or missing/unequal reload counts.
    """
    url = "http://localhost:{0}/admin/events/call/{1}reload_ticket".format(
        str(port), "force_" if force else "")
    resp = requests.get(url, timeout=timeout)

    if not resp.ok:
        logging.error("Response:\n{0}".format(resp.text))
        raise Exception("Balancer returned HTTP {0}".format(resp.status_code))
    # requests always gives a str for .text, so test for emptiness, not None.
    if not resp.text:
        raise Exception("Balancer empty response")

    counter = dict()
    wo_idx = set()
    for line in resp.text.split("\n"):
        parts = line.split()
        if len(parts) < 3 or "Reloading" not in parts[0]:
            continue

        if parts[-1] != "OK":
            raise Exception("Failed to reload ticket: {0}".format(line))

        # NOTE(review): the last 3 characters of the path token are dropped
        # before taking the basename — presumably a fixed trailer in the
        # balancer output; confirm against the admin reply format.
        key_name = os.path.basename(parts[1][:-3])
        if key_name not in tickets_names:
            raise Exception("Unknown key name {0}".format(key_name))

        # Drop the leading rank component, e.g. "1st.foo.key" -> "foo.key".
        wo_idx.add('.'.join(key_name.split('.')[1:]))
        counter[key_name] = counter.get(key_name, 0) + 1

    for prefix in wo_idx:
        check_reloaded_keys_count(prefix, counter, resp)

    for ticket in tickets_names:
        if ticket not in counter:
            logging.error("Response:\n{0}".format(str(resp.text)))
            raise Exception("Key {0} was not rotated".format(ticket))


def wait_balancer(port, max_time, request_timeout=5):
    """Block until the balancer on ``localhost:port`` answers HTTP requests.

    Retries with exponential backoff (1, 2, 4, ... seconds). Gives up once
    the next backoff delay would exceed ``max_time`` seconds. This bounds
    each individual sleep, not the total time spent waiting.

    :param port: balancer port on localhost.
    :param max_time: largest acceptable single backoff delay, in seconds.
    :param request_timeout: seconds to wait for each probe request; a hung
        balancer counts as "not alive yet" instead of blocking forever.
    :raises Exception: when the backoff delay would exceed ``max_time``.
    """
    delay = 0
    while True:
        try:
            requests.get("http://localhost:{0}/".format(str(port)), timeout=request_timeout)
            return
        except (requests.ConnectionError, requests.Timeout):
            pass

        t = pow(2, delay)
        if t > max_time:
            raise Exception("Delay is greater than timeout")
        logging.error("Error: Balancer is not alive. Waiting for {0} s".format(str(t)))
        delay += 1
        time.sleep(t)


def run_loop(args):
    """Perform one rotation: fetch secrets, decode keys, reload the balancer.

    Any failure is logged rather than raised, so a long-running caller keeps
    going. When ``args.monitoring`` is true, the outcome is reported to
    Juggler: OK on success, CRIT with the error text on failure. Reporting
    is best effort; a Juggler failure is logged and never changes the
    rotation result.
    """
    try:
        run_yav_deploy(args.path, args.section, args.oauth)
        tickets_names = decode_tickets(args.path)
        reload_balancer(args.port, tickets_names, args.force_update)
    except Exception as e:
        logging.error("Error: {0}".format(str(e)))
        status, descr = 'CRIT', str(e)
    else:
        logging.info("Update completed")
        status, descr = 'OK', 'OK'

    # Respect --monitoring for failures too, and keep Juggler problems out
    # of the rotation result.
    if args.monitoring:
        try:
            update_juggler(status, descr)
        except Exception as e:
            logging.error("Error: failed to update juggler: {0}".format(str(e)))


def main(args):
    """Entry point: set up file logging, then rotate ticket keys.

    Validates the instance directory and attaches a WatchedFileHandler for
    ``args.log`` to the root logger. Optionally waits for the balancer to
    come up. With ``args.force_update`` it rotates once and returns;
    otherwise it rotates forever, sleeping ``args.timeout`` seconds between
    rounds.
    """
    instance_path = args.path
    if not os.path.isdir(instance_path):
        raise Exception("Incorrect instance path {0}".format(instance_path))

    file_handler = logging.handlers.WatchedFileHandler(args.log)
    file_handler.setFormatter(logging.Formatter("%(asctime)-15s %(message)s"))
    root_logger = logging.getLogger()
    root_logger.addHandler(file_handler)
    root_logger.setLevel(logging.INFO)

    if args.wait_balancer:
        wait_balancer(args.port, args.timeout)

    if args.force_update:
        # One-shot mode: a single rotation, then return.
        run_loop(args)
        return

    while True:
        run_loop(args)
        time.sleep(args.timeout)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Rotate tls tickets")

    parser.add_argument("--port", type=int, required=True, help="balancer admin port")
    parser.add_argument("--path", type=str, required=True, help="balancer instance path")
    parser.add_argument("--timeout", type=int, default=60*60,
                        help="seconds between rotations; also the max backoff delay while waiting for the balancer")
    # Monitoring and the startup balancer check are ON by default, and these
    # flags turn them OFF (action='store_false'). The historical spellings
    # are kept so existing launch scripts keep working; the --no-* aliases
    # name what the flag actually does.
    parser.add_argument("--monitoring", "--no-monitoring", dest="monitoring", action='store_false',
                        help="disable sending info to juggler (enabled by default)")
    parser.add_argument("--section", type=str, default="tickets", help="sections to be deployed")
    parser.add_argument("--wait-balancer", "--no-wait-balancer", dest="wait_balancer", action='store_false',
                        help="skip checking that the balancer is alive on start (checked by default)")
    parser.add_argument("--log", type=str, default="/logs/current-tickets-reload.log", help="log path")
    parser.add_argument("--oauth", type=str, default="/dev/shm/oauth", help="path to file with oauth token")
    parser.add_argument("--force-update", action='store_true', help="force tickets update and exit")

    main(parser.parse_args())

