#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ./shell_on_deploy_units.py --tag autotest --unit_box PerlApp --cmd 'echo 1'
#
#   Скрипт генерирует файл inventory_deploy.json
#   и запускает команды на выбранных группах стэйджей через ansible shell

#   --reset_inventory_file - принудительно пересоздает только inventory_file
#   --reset_config_file    - принудительно пересоздает config_file и inventory_file
#   --ansible_tag          - tag из inventory_file
#   --print_tags           - показать доступные тэги
#   --print_ansible_tags   - показать доступные ansible тэги
#   --tag                  - при первом запуске скрипт собирает все стэйджи и их таги, для выбора стэйджей по тагу
#   --deploy_unit          - добавить в фильтр Deploy Unit
#   --unit_box             - добавить в фильтр Unit Box
#   --debug                - дополнительный вывод при работе
#   --cmd_stdin            - спросить команду перед запуском
#   --cmd                  - команда для запуска на стэйджах
#   --script               - скрипт для запуска на стэйджах
#   --autotest_number      - для --tag=autotest, указать номера стэйджей
#   --download_file        - путь к файлу для скачивания
#   --upload_file          - путь к файлу для загрузки
#   --upload_to            - директория для сохранения. default = '/tmp/'
#

import argparse
import argcomplete
import json
import logging
import os
import paramiko
import re
import rich.traceback
import subprocess
import time
import yaml

from concurrent.futures import ThreadPoolExecutor
from rich.console import Console
from rich.logging import RichHandler
from rich.prompt import Confirm, Prompt
from rich.table import Table

FORMAT = "%(message)s"
# Send every log record through rich's handler. NOTSET lets all levels
# through; verbose script output is gated by debug() on --debug instead.
logging.basicConfig(
    level="NOTSET", format=FORMAT, datefmt="[%X]", handlers=[RichHandler()]
)

console = Console()
log = logging.getLogger("rich")
# paramiko is noisy below WARNING; keep its output to warnings and errors.
logging.getLogger("paramiko").setLevel(logging.WARNING)
rich.traceback.install()

# Allowed values for --deploy_unit.
deploy_units = ["Backend", "Crons", "Database", "Hourglass", "teamcity"]
# Allowed values for --unit_box.
unit_boxes = [
    "CronApp",
    "FrontendNode",
    "JavaAppINTAPI",
    "JavaAppJSONAPI",
    "JavaAppHourglass",
    "JavaAppTESTAPI",
    "Haproxy",
    "MockAPI",
    "MySQL",
    "PerlApp",
    "MemcachedBox",
    "teamcity",
]
# Stage tags to be ignored by thread_function().
ignore_tags = [""]

# Cached inventory: {"all": {"children": {group: {"hosts": {host: None}}}}}.
inventory_file = "inventory_deploy.json"
inventory = {"all": {"children": {}}}

# Cached stage metadata. "ttl" is the cache lifetime in seconds, "workers"
# is the thread-pool size used when refreshing it, and "debug_level" >= 2
# makes run() log every shell command and its output (when --debug is set).
config_file = "shell_config.json"
config = {
    "ttl": 3600,
    "workers": 10,
    "debug_level": 1,
}

# Environment passed to every subprocess started by run().
my_env = os.environ.copy()

my_parser = argparse.ArgumentParser(
    description="shell_on_deploy_units.py",
    epilog="Example: ./shell_on_deploy_units.py --tag production --unit_box CronApp --cmd 'grep -R ERROR /var/log/partners.yandex.ru/'"
)

my_parser.add_argument(
    "--reset_inventory_file",
    action="store_true",
    help="regenerate the inventory file",
)

my_parser.add_argument(
    "--reset_config_file",
    action="store_true",
    help="regenerate the config file and the inventory file",
)

my_parser.add_argument(
    "--ansible_tag",
    metavar="ansible_tag",
    nargs="?",
    type=str,
    help="name of inventory file tag",
)

my_parser.add_argument(
    "--print_tags",
    action="store_true",
    help="print available stage tags and exit",
)

my_parser.add_argument(
    "--print_ansible_tags",
    action="store_true",
    help="print available inventory tags and exit",
)

my_parser.add_argument(
    "--tag", metavar="tag", nargs="?", type=str, help="name of deploy tag"
)

my_parser.add_argument(
    "--deploy_unit",
    metavar="deploy_unit",
    nargs="*",
    type=str,
    choices=deploy_units,
    help="name of DU (--deploy_unit=Backend)",
)

my_parser.add_argument(
    "--unit_box",
    metavar="unit_box",
    nargs="*",
    type=str,
    choices=unit_boxes,
    help="name of box (--unit_box=PerlApp)",
)

my_parser.add_argument(
    "--debug",
    action="store_true",
    help="enable debug output",
)

my_parser.add_argument(
    "--cmd_stdin",
    action="store_true",
    help="prompt for the command before running it",
)

my_parser.add_argument(
    "--cmd", metavar="cmd", nargs="?", type=str, help='cmd to execute (--cmd="echo 1")'
)

my_parser.add_argument(
    "--script",
    metavar="script",
    nargs="?",
    type=str,
    help='script to execute (--script="/path/to/script")',
)

my_parser.add_argument(
    "--autotest_number",
    metavar="autotest_number",
    nargs="?",
    type=str,
    help="autotest numbers (--autotest_number=12-15, --autotest_number=1,12,15)",
)

my_parser.add_argument(
    "--download_file",
    metavar="download_file",
    nargs="?",
    type=str,
    help="remote file for downloading",
)

my_parser.add_argument(
    "--upload_file",
    metavar="upload_file",
    nargs="?",
    type=str,
    help="local file for uploading",
)

my_parser.add_argument(
    "--upload_to",
    metavar="upload_to",
    nargs="?",
    type=str,
    help="remote directory for uploading (default: /tmp)",
)

# Shell tab-completion hook; argcomplete requires this before parse_args().
argcomplete.autocomplete(my_parser)


def debug(msg=""):
    """Log ``msg`` at DEBUG level, but only when --debug was passed."""
    if not args.debug:
        return
    log.debug(msg)


def truncate(filename=""):
    """Empty ``filename``, creating it if it does not exist.

    The handle is opened in a ``with`` block so it is closed immediately
    rather than left for the garbage collector.
    """
    with open(filename, "w"):
        pass  # opening in write mode already truncates to zero length


def exception(msg, type=""):
    """Log ``msg`` as an error (rich markup enabled) and terminate the process.

    ``type="usage"`` exits with ``os.EX_USAGE``; any other value exits with
    ``os.EX_OSERR``. ``os._exit`` ends the process immediately, without
    running ``finally`` blocks or atexit handlers.
    """
    log.error(msg, extra={"markup": True})
    exit_code = os.EX_USAGE if type == "usage" else os.EX_OSERR
    os._exit(exit_code)


def check_token():
    """Abort unless the token file ``~/.ya_token`` exists and is readable.

    The error message tells the user to create it with
    ``ya whoami --save-token``.
    """
    token_path = os.path.expanduser("~/.ya_token")
    # Checked in-process instead of spawning `cat` through a shell.
    if not (os.path.isfile(token_path) and os.access(token_path, os.R_OK)):
        exception('You should specify TOKEN. Please run "ya whoami --save-token"')

def parse_response(response=None, columns=None):
    """Extract columns from the ASCII table printed by ``ya tool dctl list``.

    The first three lines (presumably the top border, header and header
    separator) and the last line (presumably the bottom border) are dropped.
    Each remaining row is split on ``|`` with the surrounding whitespace
    removed. The leading ``|`` yields an empty cell 0, so data cells start
    at index 1.

    Args:
        response: list of output lines.
        columns: cell indexes to keep.

    Returns:
        With exactly one column, a flat list of that cell per row. Otherwise,
        a list of per-row lists holding the requested cells.
    """
    rows = (response or [])[3:-1]
    columns = columns or []
    output = []

    for line in rows:
        cells = re.split(r"\s*\|\s*", line)
        if len(columns) == 1:
            output.append(cells[columns[0]])
        else:
            output.append([cells[column] for column in columns])
    return output


def print_tags():
    """With --print_tags, log the stage tags of every project and exit.

    Project names come from column 2 of ``ya tool dctl list project``. They
    are OR-ed into one filter for a ``ya tool yp select stage`` query, and
    the raw output of that query is logged. Does nothing when --print_tags
    is not set.
    """
    if not args.print_tags:
        return

    projects = run("ya tool dctl list project", array=True, parse=True, columns=[2])
    # Spaces around "or" keep the generated filter expression unambiguous.
    filter_param = " or ".join(
        f"[/annotations/project]='{project}'" for project in projects
    )

    tags = run(
        f'ya tool yp select stage --address xdc --selector /labels/tags --selector /meta/id --filter "{filter_param}"'
    )
    log.info(tags)
    os._exit(os.EX_OK)


def run(cmd, array=False, parse=False, columns=None):
    """Run ``cmd`` through the shell (``shell=True``, env ``my_env``).

    Args:
        cmd: shell command line.
        array: when true, split the right-stripped stdout into lines.
        parse: together with ``array`` and a non-empty ``columns``, pass the
            lines to ``parse_response`` and return its result.
        columns: column indexes forwarded to ``parse_response``.

    Returns:
        The decoded stdout (undecodable bytes are replaced), a list of lines,
        or the parsed rows, as described above. stderr is not returned. At
        ``debug_level`` >= 2 it is logged together with the exit code when
        the command fails.
    """
    columns = columns or []
    verbose = config["debug_level"] >= 2

    if verbose:
        debug(f"run: {cmd}")

    proc = subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=my_env
    )
    out, err = proc.communicate()
    out = out.decode("utf-8", errors="replace")

    if verbose:
        debug(f"run: {out}")
        if proc.returncode:
            # Surface failures that would otherwise vanish with the discarded stderr.
            debug(
                f"run: exit code {proc.returncode}, stderr: "
                f"{err.decode('utf-8', errors='replace')}"
            )

    if array:
        out = out.rstrip().split("\n")
        if parse and columns:
            return parse_response(response=out, columns=columns)

    return out


def get_stage_list():
    """Return the stage names: column 2 of ``ya tool dctl list stage``."""
    return run("ya tool dctl list stage", array=True, parse=True, columns=[2])


def get_fqdns(stage, deploy_unit):
    """Return column 3 of ``ya tool dctl list endpoint <stage>.<deploy_unit>``.

    get_deploy_units() uses these values as host names, prefixing each
    one with a box name.
    """
    endpoint_set = f"{stage}.{deploy_unit}"
    return run(
        f"ya tool dctl list endpoint {endpoint_set}",
        array=True,
        parse=True,
        columns=[3],
    )


def get_deploy_units(stage_yaml, stage):
    """Map each deploy unit of ``stage`` to ``{box: [host, ...]}``.

    Boxes are the keys of ``spec.deploy_units.<unit>.images_for_boxes``.
    A host is ``<box>.<fqdn>`` for every endpoint returned by get_fqdns().
    Every host is also recorded in ``config["host_stage"]`` -> ``stage``.
    A unit that fails (e.g. has no ``images_for_boxes``) is skipped with a
    debug message, best-effort.
    """
    units = stage_yaml["spec"]["deploy_units"]
    mapping = {}
    for unit_name in units.keys():
        try:
            box_names = list(units[unit_name]["images_for_boxes"].keys())
            endpoints = get_fqdns(stage, unit_name)
            mapping[unit_name] = {}
            for box_name in box_names:
                hosts = mapping[unit_name][box_name] = []
                for endpoint in endpoints:
                    address = f"{box_name}.{endpoint}"
                    hosts.append(address)
                    config["host_stage"][address] = stage
        except Exception as err:
            debug(f"Thread {stage}: error {err}")
    return mapping


def thread_function(stage):
    """Fetch one stage's spec and store its tags and deploy units in ``config``.

    It is submitted to a thread pool, one call per stage, while the config
    cache is being refreshed. It writes ``config["stages"][stage]`` as
    ``{"tags": [...], "deploy_units": {...}}``. Tags listed in ``ignore_tags``
    are dropped, so they never become inventory groups.
    """
    debug(f"Thread {stage}: starting")

    stage_yaml = yaml.safe_load(run(f"ya tool dctl get stage {stage}"))
    tags = [tag for tag in stage_yaml["labels"]["tags"] if tag not in ignore_tags]

    config["stages"][stage] = {
        "tags": tags,
        "deploy_units": get_deploy_units(stage_yaml, stage),
    }
    debug(f"Thread {stage}: finishing")


def set_inventory_file():
    """Build inventory groups from ``config`` and write ``inventory_file``.

    Each ``<tag>_<deploy unit>_<box>`` combination across all stages becomes
    one group under ``inventory["all"]["children"]``. Its hosts are stored as
    keys mapped to ``None``, so duplicates collapse. Groups are merged into
    the module-level ``inventory``, and the whole structure is then written
    as JSON.
    """
    children = inventory["all"]["children"]

    for stage in config["stage_list"]:
        stage_info = config["stages"][stage]
        for tag in stage_info["tags"]:
            for unit, boxes in stage_info["deploy_units"].items():
                for box, hosts in boxes.items():
                    group = children.setdefault(f"{tag}_{unit}_{box}", {"hosts": {}})
                    for host in hosts:
                        group["hosts"][host] = None

    # Opening with "w" truncates the file, so no separate truncate() is needed
    # (and the old file stays intact if building the groups fails).
    with open(inventory_file, "w") as inventory_file_handle:
        json.dump(
            inventory,
            inventory_file_handle,
            ensure_ascii=True,
            indent=4,
            sort_keys=True,
        )


def _load_json_file(path):
    """Return the JSON object stored at ``path``, or None if it is missing,
    unreadable, not valid JSON, or not a JSON object."""
    try:
        with open(path, "r") as handle:
            data = json.load(handle)
    except (OSError, ValueError):
        return None
    return data if isinstance(data, dict) else None


def _rebuild_config():
    """Query every stage in parallel, refresh ``config`` and dump it to ``config_file``."""
    config["stage_list"] = get_stage_list()
    config["stages"] = {}
    config["host_stage"] = {}

    with ThreadPoolExecutor(max_workers=config["workers"]) as executor:
        futures = [
            executor.submit(thread_function, stage) for stage in config["stage_list"]
        ]
        for future in futures:
            try:
                future.result()
            except Exception as e:
                exception(e)

    # "w" truncates the file, so no separate truncate() call is needed.
    with open(config_file, "w") as config_file_handle:
        json.dump(
            config, config_file_handle, ensure_ascii=True, indent=4, sort_keys=True
        )


def _rebuild_inventory():
    """Regenerate ``inventory`` from ``config``, starting from an empty layout."""
    global inventory
    inventory = {"all": {"children": {}}}
    set_inventory_file()


def set_config_file():
    """Load the stage and inventory caches, regenerating whatever is missing or stale.

    The config cache (``config_file``) is rebuilt from scratch, querying
    every stage in parallel, in these cases:
      - --reset_config_file is set;
      - the file is missing, or older than ``config["ttl"]`` seconds (by
        st_ctime);
      - the file cannot be parsed, or lacks the "stage_list", "stages" or
        "host_stage" keys.
    Rebuilding the config always regenerates the inventory as well.

    When the cached config is usable, the inventory (``inventory_file``) is
    regenerated from it if --reset_inventory_file is set or the file is
    missing or invalid. Otherwise the inventory is loaded as-is.

    Rebinds the module-level ``config`` and ``inventory``.
    """
    global config, inventory

    start_time = time.time()

    cached_config = None
    if (
        not args.reset_config_file
        and os.path.isfile(config_file)
        and (time.time() - os.stat(config_file).st_ctime) <= config["ttl"]
    ):
        cached_config = _load_json_file(config_file)

    required_keys = ("stage_list", "stages", "host_stage")
    if cached_config is None or not all(key in cached_config for key in required_keys):
        log.info("start getting stage configurations (first start ~40sec)")
        _rebuild_config()
        _rebuild_inventory()
        log.info(f"done ({time.time() - start_time} secs)")
        return

    config = cached_config

    cached_inventory = (
        None if args.reset_inventory_file else _load_json_file(inventory_file)
    )
    if cached_inventory is not None and "all" in cached_inventory:
        inventory = cached_inventory
    else:
        _rebuild_inventory()


def validate_args():
    """Validate and normalize the parsed command-line ``args``.

    Side effects:
      - ``args.autotest_number`` is converted from "12-15" or "1,12,15" into
        a list of ints;
      - a ``--script`` that is not executable gets ``a+x`` added.

    Returns True early for --print_ansible_tags and for download/upload
    mode, where no command is required; otherwise returns None. Every
    validation failure goes through exception(..., type="usage"), which
    terminates the process.
    """
    if args.print_ansible_tags:
        return True

    if not (args.tag or args.deploy_unit or args.unit_box or args.ansible_tag):
        exception(
            "Target for executing should be determined! Possible keys: tag, deploy_unit, unit_box, ansible_tag",
            type="usage",
        )

    if args.autotest_number:
        spec = args.autotest_number
        if re.search(r"^\d+-\d+$", spec):
            # Range "a-b", accepted in either order.
            low, high = sorted(int(part) for part in spec.split("-"))
            args.autotest_number = list(range(low, high + 1))
        elif re.search(r"^(\d+|(\d+,)+\d+)$", spec):
            args.autotest_number = [int(part) for part in spec.split(",")]
        else:
            exception("Incorrect --autotest_number flag!", type="usage")

    if args.download_file and args.upload_file:
        exception(
            "Can't use --download_file and --upload_file at the same time!",
            type="usage",
        )

    if args.download_file or args.upload_file:
        if args.upload_file and not os.path.isfile(args.upload_file):
            exception(f"File {args.upload_file} does not exist!", type="usage")
        return True

    if args.cmd_stdin:
        # The command is prompted for later, so nothing else to check here.
        return None

    if not args.cmd and not args.script:
        exception("--cmd or --script should be determined!", type="usage")

    if args.cmd and args.script:
        exception("Use only one flag --cmd or --script!", type="usage")

    if args.script:
        script = args.script
        if not os.path.isfile(script):
            exception(f"File {script} does not exist!", type="usage")
        if not os.access(script, os.X_OK):
            log.info(f"Script {script} is not executable. Try change it...")
            try:
                # Same as `chmod a+x`, without going through a shell.
                os.chmod(script, os.stat(script).st_mode | 0o111)
            except OSError:
                pass  # the re-check below reports the failure
            if not os.access(script, os.X_OK):
                exception(
                    f'Script {script} should be executable. Please run "chmod a+x {script}"',
                    type="usage",
                )

        try:
            with open(script, "r", errors="replace") as handle:
                shebang = handle.readline()
        except OSError:
            shebang = ""
        if not re.search(r"^#!/", shebang):
            exception("Script validation error: shebang must be specified!", type="usage")
    return None


def filter_autotest_hosts(host):
    """Return True if the stage of ``host`` is one of ``args.autotest_number``.

    The stage number is the part of the stage name after its last "-". A
    stage whose name ends in "-stage" counts as number 1.
    """
    suffix = config["host_stage"][host].rsplit("-", 1)[-1]
    number = 1 if suffix == "stage" else int(suffix)
    return number in args.autotest_number


def prepare_hosts():
    """Resolve the CLI filters into the list of target hosts.

    With --ansible_tag, hosts come from that single inventory group, and the
    other group filters are ignored. Otherwise the group names
    (``<tag>_<unit>_<box>``) are narrowed by --tag (prefix), --deploy_unit
    (middle) and --unit_box (suffix). The values are interpolated into
    regexes unescaped. Hosts of all selected groups are concatenated with
    duplicates removed (first occurrence wins). For ``--tag autotest`` with
    --autotest_number, they are further narrowed by filter_autotest_hosts().

    Also rebinds the module-level ``groups`` to the selected group names.
    """
    global groups

    children = inventory["all"]["children"]

    if args.ansible_tag:
        if args.ansible_tag not in children:
            exception("flag --ansible_tag incorrect!")
        groups = [args.ansible_tag]
    else:
        groups = list(children)
        if args.tag:
            groups = [g for g in groups if re.search(r"^%s_" % args.tag, g)]
        if args.deploy_unit:
            unit_re = r"_(?:%s)_" % "|".join(args.deploy_unit)
            groups = [g for g in groups if re.search(unit_re, g)]
        if args.unit_box:
            box_re = r"_(?:%s)$" % "|".join(args.unit_box)
            groups = [g for g in groups if re.search(box_re, g)]

    # A stage with several tags puts the same host into several groups;
    # de-duplicate so the command runs only once per host.
    hosts = list(
        dict.fromkeys(host for group in groups for host in children[group]["hosts"])
    )

    if args.tag == "autotest" and args.autotest_number:
        hosts = [host for host in hosts if filter_autotest_hosts(host)]

    return hosts


def run_ssh_shell(cli, host):
    """Run ``args.cmd`` over the open SSH client ``cli``.

    Returns stdout followed by stderr, as bytes. ``host`` is unused; it keeps
    the signature identical to run_script().
    """
    _, out_stream, err_stream = cli.exec_command(args.cmd)
    output = out_stream.read()
    output += err_stream.read()
    return output


def run_script(cli, host):
    """Upload ``args.script`` to ``/tmp`` on the remote side, run it, then delete it.

    Returns the script's stdout followed by its stderr, as bytes. ``host``
    is unused; it keeps the signature identical to run_ssh_shell().
    """
    import shlex

    remote_path = f"/tmp/{os.path.basename(args.script)}"

    sftp = cli.open_sftp()
    try:
        sftp.put(args.script, remote_path)
        # Explicit octal: owner-only read/write/execute.
        sftp.chmod(remote_path, 0o700)
    finally:
        sftp.close()

    quoted_path = shlex.quote(remote_path)
    _, stdout, stderr = cli.exec_command(quoted_path)
    # Drain the output first: this waits for the script to finish, so the
    # cleanup below cannot delete the file while it is still being run.
    output = stdout.read() + stderr.read()

    _, cleanup_out, _ = cli.exec_command(f"rm -f {quoted_path}")
    cleanup_out.read()  # wait for rm to complete before the caller closes cli
    return output


def execute(hosts):
    """Run ``args.cmd`` (or ``args.script``) on every host and print a results table.

    Hosts are processed one after another, each over its own SSH connection,
    which is closed even if the run fails. Output is decoded as UTF-8 with
    undecodable bytes replaced, so binary output cannot abort the loop.
    """
    table = Table(show_footer=False)
    table.title = f"Cmd: {args.cmd}" if args.cmd else f"Script: {args.script}"

    table.add_column("Stage", style="cyan", no_wrap=True)
    table.add_column("Host", style="cyan", no_wrap=True)
    table.add_column("Result", style="magenta")

    # Plain function dispatch instead of eval() on a constructed string.
    runner = run_ssh_shell if args.cmd else run_script

    for host in hosts:
        cli = get_cli(host)
        stage = config["host_stage"][host]
        box = host.split(".")[0]
        log.info(f"Executing on {stage} ({box})")
        try:
            res = runner(cli, host)
        finally:
            cli.close()
        table.add_row(stage, host, res.decode("utf-8", errors="replace"))

    console.print(table)


def download_file(hosts):
    """Download ``args.download_file`` from every host into ``./<stage>/<basename>``.

    The file is fetched directly into the per-stage directory, so a file
    with the same name in the current directory is never touched. When a
    stage has several hosts, the last one downloaded wins. SFTP and SSH
    sessions are closed after each host.
    """
    fn = os.path.basename(args.download_file)

    for host in hosts:
        cli = get_cli(host)
        try:
            sftp = cli.open_sftp()
            try:
                stage = config["host_stage"][host]
                log.info(f"Downloading {fn} from {stage}")
                os.makedirs(stage, exist_ok=True)
                sftp.get(args.download_file, os.path.join(stage, fn))
            finally:
                sftp.close()
        finally:
            cli.close()


def upload_file(hosts):
    """Upload the local ``args.upload_file`` to every host.

    The destination directory is ``args.upload_to`` (``/tmp`` when not
    given), with one trailing "/" stripped. The file keeps its basename.
    SFTP and SSH sessions are closed after each host.
    """
    fn = os.path.basename(args.upload_file)
    # Loop-invariant: the destination is the same for every host.
    destination = re.sub("/$", "", args.upload_to or "/tmp")
    remote_path = f"{destination}/{fn}"

    for host in hosts:
        cli = get_cli(host)
        try:
            sftp = cli.open_sftp()
            try:
                stage = config["host_stage"][host]
                log.info(f"Uploading {fn} to {stage}:{destination}")
                sftp.put(args.upload_file, remote_path)
            finally:
                sftp.close()
        finally:
            cli.close()


def get_cli(host):
    """Open an SSH connection to ``host`` as root and return the paramiko client.

    Keys are loaded from the system known-hosts files. Unknown host keys are
    accepted automatically (AutoAddPolicy). On any connection error the
    error is logged and exception() is called with an ssh-agent hint.
    """
    client = paramiko.client.SSHClient()
    client.load_system_host_keys()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.connect(hostname=host, username="root")
    except Exception as err:
        log.error(f"SSHClient error: {err}.")
        exception("Try run ssh agent - eval `ssh-agent -s`; ssh-add ~/.ssh/id_rsa;")
    else:
        return client


if __name__ == "__main__":
    args = my_parser.parse_args()
    # Exits right away when --print_tags is given.
    print_tags()
    log.info("starting...")
    validate_args()
    debug(args)
    check_token()
    set_config_file()

    if args.print_ansible_tags:
        for tag in sorted(inventory["all"]["children"]):
            log.info(tag)
        os._exit(os.EX_OK)

    hosts = sorted(prepare_hosts())

    # Nothing to act on: fail fast instead of asking for confirmation and
    # printing an empty results table.
    if not hosts:
        exception("No hosts match the given filters!", type="usage")

    if args.download_file:
        download_file(hosts)
        os._exit(os.EX_OK)

    if args.upload_file:
        upload_file(hosts)
        os._exit(os.EX_OK)

    if args.cmd_stdin:
        while not args.cmd:
            args.cmd = Prompt.ask("Please enter command here")

    log.info(f"hosts: {hosts}")
    if args.cmd:
        log.info(f"command: {args.cmd}")
    else:
        log.info(f"script: {args.script}")

    if Confirm.ask("Are you sure?", default=True):
        execute(hosts)
        log.info("finished")
    else:
        log.info("execution skipped")
