#! /usr/bin/env python3

import re
import sys
import yaml
import time
import shlex
import subprocess
import argparse
import requests
import getpass

from enum import Enum
from os.path import exists


# Deployment environment a command targets. Each member's value equals its
# name, so argparse can convert a '--env corp' argument with Env('corp').
Env = Enum('Env', [('corp', 'corp'), ('prod', 'prod')])


# Mechanism used to control a service. AppInfo picks its command templates by
# this value, and the config selects a member by name (Ctl[<name>]).
Ctl = Enum(
    'Ctl',
    [
        ('nothing', 0),
        ('portoctl', 1),
        ('supervisorctl', 2),
        ('touch', 3),
        ('suspend', 4),
        ('its', 5),
        ('awacs', 6),
    ],
)


# Action requested with '-c/--command'. Member values equal their names, so
# argparse can build a member directly from the argument string.
Cmd = Enum('Cmd', [('start', 'start'), ('stop', 'stop'), ('status', 'status')])


class AppInfo:
    """Description of one managed service and how to control it.

    Instances are normally created by ``AppInfo.load`` from the parsed YAML
    config. The capitalised attribute names (``Name``, ``App``, ``Targets``,
    ...) are read by other code and are part of the interface. ``App`` and
    ``Components`` are dicts keyed by ``Env``. ``Corp`` and ``Prod`` hold the
    environment names used when formatting groups, stages and handles.
    """

    # Service-discovery resolver queried by get_deploy_endpoints.
    SD_RESOLVE_URL = 'http://sd.yandex.net:8080/resolve_endpoints/json'
    # Seconds to wait for the resolver. requests has no default timeout, so
    # without one a stalled connection would hang the whole run.
    SD_RESOLVE_TIMEOUT = 30

    # Command templates per control mechanism, returned by get_cmd_pattern.
    # Placeholders: {targets}, {pod}, {dc}, {handle}. Ctl.nothing has no entry.
    # NOTE(review): the awacs start template does not pass {dc}, while the
    # stop/status templates do; confirm this is intended.
    _CMD_PATTERNS = {
        Ctl.portoctl: {
            Cmd.start: 'portoctl start {targets}',
            Cmd.stop: 'portoctl stop {targets}',
            Cmd.status: 'portoctl get {targets} state',
        },
        Ctl.supervisorctl: {
            Cmd.start: 'supervisorctl start {targets}',
            Cmd.stop: 'supervisorctl stop {targets}',
            Cmd.status: 'supervisorctl status {targets}; ctl_exit_status=$?; if [ $ctl_exit_status -eq 3 ]; then exit 0; else exit $ctl_exit_status; fi',
        },
        Ctl.touch: {
            Cmd.start: 'rm -f {targets}',
            Cmd.stop: 'touch {targets}',
            Cmd.status: 'test -f {targets} && echo service disabled || echo service enabled',
        },
        Ctl.suspend: {
            Cmd.start: 'ya tool yp update pod {pod} --set /spec/pod_agent_payload/spec/target_state active --address {dc}',
            Cmd.stop: 'ya tool yp update pod {pod} --set /spec/pod_agent_payload/spec/target_state suspended --address {dc}',
            Cmd.status: 'ya tool yp get pod {pod} --selector /spec/pod_agent_payload/spec/target_state --address {dc}',
        },
        Ctl.its: {
            Cmd.start: './its/tune --handle {handle} --on',
            Cmd.stop: './its/tune --handle {handle} --off',
            Cmd.status: './its/tune --handle {handle} --status',
        },
        Ctl.awacs: {
            Cmd.start: './l7heavy/tune --balancer {handle}',
            Cmd.stop: './l7heavy/tune --balancer {handle} --off {dc}',
            Cmd.status: './l7heavy/tune --balancer {handle} --status {dc}',
        },
    }

    def __init__(
        self,
        name=None,
        app=None,
        corp=None,
        prod=None,
        components=None,
        targets=None,
        control=Ctl.supervisorctl,
        manage_balancer=None,
        threads=None,
        group_format=None,
        deploy_endpoints=None,
        user=None,
        stage_format=None,
        handles=None,
        dcs=None,
    ):
        """Store and sanity-check a service description.

        ``targets`` defaults to a fresh empty list, which avoids a mutable
        default shared between instances. ``user`` defaults to the current
        login name. The checks are ``assert`` statements, so they are skipped
        when Python runs with ``-O``.
        """
        self.Name = name
        self.App = app
        self.Corp = corp
        self.Prod = prod
        self.Components = components
        self.Targets = [] if targets is None else targets
        self.Control = control
        self.ManageBalancer = manage_balancer
        self.Threads = threads
        self.GroupFormat = group_format
        self.DeployEndpoints = deploy_endpoints
        self.User = getpass.getuser() if user is None else user
        self.StageFormat = stage_format
        self.Handles = handles
        self.Dcs = dcs

        assert self.Name is not None
        assert isinstance(self.App, dict)
        assert isinstance(self.App[Env.prod], str) and self.App[Env.prod]
        assert isinstance(self.App[Env.corp], str) and self.App[Env.corp]
        assert self.Corp is None or isinstance(self.Corp, list)
        assert self.Prod is None or isinstance(self.Prod, list)
        assert isinstance(self.Components, dict)
        assert self.Components[Env.prod] is None or isinstance(self.Components[Env.prod], list)
        assert self.Components[Env.corp] is None or isinstance(self.Components[Env.corp], list)
        assert isinstance(self.Targets, list)
        assert isinstance(self.Control, Ctl)
        assert isinstance(self.Threads, int)
        assert self.User is not None
        assert isinstance(self.Dcs, list)

    @staticmethod
    def load(config):
        """Build one AppInfo per entry of ``config['config']['services']``.

        Each entry is a mapping whose first key is the service name and whose
        value holds its attributes. Apart from the optional ``app``, every
        attribute read below must be present, or a KeyError is raised.
        """
        def read_list(v, default=None):
            # A YAML list is used as-is; a string is split on commas.
            if v is None:
                return default
            if isinstance(v, list):
                return v
            return re.split(r'\s*,\s*', v)

        def read_app(name, values):
            # ``values`` is one value for both envs or a {'prod', 'corp'}
            # mapping. '{app}' is replaced by the service name, and a missing
            # or empty value falls back to the name itself.
            per_env = {
                Env.prod: values.get('prod') if isinstance(values, dict) else values,
                Env.corp: values.get('corp') if isinstance(values, dict) else values,
            }
            return {k: (v.format(app=name) if v else name) for k, v in per_env.items()}

        def read_components(values):
            # Same shape rules as read_app, but each value is read as a list.
            return {
                Env.prod: read_list(values.get('prod') if isinstance(values, dict) else values),
                Env.corp: read_list(values.get('corp') if isinstance(values, dict) else values),
            }

        return [
            AppInfo(
                name=name,
                app=read_app(name, attrs.get('app', {})),
                corp=read_list(attrs['corp']),
                prod=read_list(attrs['prod']),
                components=read_components(attrs['components']),
                targets=read_list(attrs['targets'], []),
                control=Ctl[attrs['ctl']],
                manage_balancer=attrs['manage_balancer'],
                threads=attrs['threads'],
                group_format=attrs['group_format'],
                deploy_endpoints=attrs['deploy_endpoints'],
                user=attrs['user'],
                stage_format=attrs['stage_format'],
                handles=attrs['handles'],
                dcs=attrs['dcs'],
            )
            for name, attrs in [next(iter(app.items())) for app in config['config']['services']]
        ]

    def _env_names(self, env):
        """Return the configured environment names for ``env`` (may be None)."""
        return self.Corp if env == Env.corp else self.Prod

    def get_deploy_endpoints(self, env, dc):
        """Resolve endpoint FQDNs of every component of ``env`` in ``dc``.

        Returns ``[fqdns]``, a single group of formatted FQDNs. If nothing
        was resolved, or ``env`` has no environment names, it returns ``[]``
        so that no command is built for an empty group.
        """
        env_names = self._env_names(env)
        if not env_names:
            return []

        fqdns = []
        for component in self.Components[env] or []:
            endpoint_set_id = self.DeployEndpoints.format(
                app=self.Name,
                env=env_names[0],
                component=component,
            )
            response = requests.get(
                self.SD_RESOLVE_URL,
                json={
                    "cluster_name": dc,
                    "endpoint_set_id": endpoint_set_id,
                    "client_name": "mail-sre",
                },
                timeout=self.SD_RESOLVE_TIMEOUT,
            )
            payload = response.json()
            if not payload or 'endpoint_set' not in payload:
                continue
            endpoint_set = payload['endpoint_set']
            if 'endpoints' in endpoint_set:
                for endpoint in endpoint_set['endpoints']:
                    fqdns.append(self.GroupFormat.format(fqdn=endpoint['fqdn']))
        return [fqdns] if fqdns else []

    def get_deploy_pods(self, env, dc):
        """List pod ids of every deploy unit (component) of ``env`` in ``dc``.

        Runs ``ya tool yp select pod`` piped into ``jq`` through the shell.
        Empty output lines are dropped. A failing pipeline is not detected
        here; it simply contributes no pods.
        """
        components = self.Components[env] or []
        env_names = self._env_names(env)
        if not components or not env_names:
            return []
        stage_id = self.StageFormat.format(app=self.Name, env=env_names[0])

        pods = []
        for du in components:
            pod_filter = "[/meta/pod_set_id] = '{stage_id}.{du}'".format(stage_id=stage_id, du=du)
            # The shell is needed for the jq pipe. Every interpolated value is
            # quoted, including the jq program, whose brackets are glob chars.
            cmd = 'ya tool yp select pod --filter {flt} --address {dc} --selector /meta/id --format json --no-tabular | jq -r {prog}'.format(
                flt=shlex.quote(pod_filter),
                dc=shlex.quote(dc),
                prog=shlex.quote('.[][]'),
            )
            result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE)
            pods.extend(result.stdout.decode("utf-8").split("\n"))

        return [pod for pod in pods if pod]

    def get_conductor_groups(self, env, dc):
        """Return host-group lists for ``env`` in ``dc``.

        When ``DeployEndpoints`` is configured, the groups come from
        get_deploy_endpoints. Otherwise each environment name yields one list
        of '%'-prefixed groups, one per component. Identical lists are
        dropped, keeping the first-seen order.
        """
        if self.DeployEndpoints:
            return self.get_deploy_endpoints(env, dc)

        env_names = self._env_names(env) or []
        if not env_names:
            return []

        components = self.Components[env]
        # Components are substituted with a leading '_' ('' when there are
        # none); leading '_' of the formatted group is stripped afterwards.
        suffixes = ['_' + c for c in components] if components else ['']

        groups = []
        seen = set()
        for environment in env_names:
            group = [
                '%' + self.GroupFormat.format(
                    app=self.App[env], env=environment, component=suffix, dc='@' + dc
                ).lstrip('_')
                for suffix in suffixes
            ]
            key = tuple(group)
            if key not in seen:
                seen.add(key)
                groups.append(group)
        return groups

    def get_cmd_pattern(self):
        """Return the {Cmd: template} mapping for ``Control``, or None (Ctl.nothing)."""
        return self._CMD_PATTERNS.get(self.Control)

    def get_instance_main_command(self, cmd):
        """Return ``[command]`` for the per-host control command, or ``[]`` if none."""
        cmd_pattern = self.get_cmd_pattern()
        if not cmd_pattern:
            return []
        return [cmd_pattern[cmd].format(targets=' '.join(self.Targets))]

    def get_instance_command(self, cmd):
        """Build the per-host shell command, joining its parts with ' && '.

        With ``ManageBalancer`` set, the matching ``manage-balancer`` command
        runs after the main command for start, and before it for stop/status.
        """
        commands = self.get_instance_main_command(cmd)
        if self.ManageBalancer is not None:
            manage_balancer_cmd = {
                Cmd.start: 'sudo manage-balancer open {target}',
                Cmd.stop: 'sudo manage-balancer close {target}',
                Cmd.status: 'sudo manage-balancer list {target}',
            }[cmd].format(target=self.ManageBalancer).strip()

            if cmd == Cmd.start:
                commands.append(manage_balancer_cmd)
            else:
                commands.insert(0, manage_balancer_cmd)

        return ' && '.join(commands)

    def get_commands(self, env, dc, cmd):
        """Return the command lines that apply ``cmd`` to this service in ``env``/``dc``.

        Returns ``[]`` when ``dc`` is not in ``Dcs``. suspend/its/awacs
        services get one command per pod or handle. All other control types
        get one ``executer`` invocation per host-group list.
        """
        if dc not in self.Dcs:
            return []

        pattern = self.get_cmd_pattern()

        if self.Control == Ctl.suspend:
            return [pattern[cmd].format(pod=pod, dc=dc) for pod in self.get_deploy_pods(env, dc)]

        if self.Control == Ctl.its:
            env_names = self._env_names(env)
            # Handles are formatted with the first env name, so None or an
            # empty list means there is nothing to control.
            if not env_names:
                return []
            return [
                pattern[cmd].format(
                    handle=handle.format(
                        dc=dc,
                        app=self.App[env],
                        env=env_names[0],
                    ),
                )
                for handle in self.Handles or []
            ]

        if self.Control == Ctl.awacs:
            # awacs handles do not use env names; only an unconfigured (None)
            # environment disables the service.
            if self._env_names(env) is None:
                return []
            return [pattern[cmd].format(dc=dc, handle=handle) for handle in self.Handles or []]

        command = self.get_instance_command(cmd)
        return [
            'executer {threads} --user {user} -c --quiet p_exec {cgroup} {command}'.format(
                threads=('-t {}'.format(self.Threads) if self.Threads else ''),
                user=self.User,
                cgroup=','.join(cgroup_list),
                command=command,
            )
            for cgroup_list in self.get_conductor_groups(env, dc)
        ]


class Command:
    """A single command line built by Manager and consumed by the executers."""

    def __init__(self, text):
        # The full command line, as one string.
        self.CommandText = text

    def __repr__(self):
        return '{}({!r})'.format(type(self).__name__, self.CommandText)


class Manager:
    """Expands a (command, services, environments, data centers) request into Commands."""

    def __init__(self, config):
        # Index services by name; if a name repeats, the last definition wins.
        self.AppInfoList = {info.Name: info for info in AppInfo.load(config)}

    def get_commands(self, cmd, svcs, envs, dcs):
        """Return Command objects for every (service, env, dc) combination.

        ``svcs`` may be ``'all'``, a single service name or a list of names.
        ``envs`` and ``dcs`` may each be a single value or a list. Commands
        are ordered by service, then environment, then data center.

        Raises KeyError listing every unknown service before any service is
        processed. Building commands can run subprocesses or HTTP queries, so
        failing early avoids partial work.
        """
        if svcs == 'all':
            svcs = list(self.AppInfoList)
        elif not isinstance(svcs, list):
            svcs = [svcs]
        envs = envs if isinstance(envs, list) else [envs]
        dcs = dcs if isinstance(dcs, list) else [dcs]

        unknown = [s for s in svcs if s not in self.AppInfoList]
        if unknown:
            raise KeyError('unknown service(s): {}; allowed: {}'.format(
                ', '.join(map(str, unknown)),
                ', '.join(map(str, self.AppInfoList)),
            ))

        commands = []
        for svc in svcs:
            app_info = self.AppInfoList[svc]
            for env in envs:
                for dc in dcs:
                    commands.extend(Command(text) for text in app_info.get_commands(env, dc, cmd))
        return commands


class DryRunExecuter:
    """Prints each command's text instead of running it."""

    def __init__(self, file=None):
        # sys.stdout is resolved when the executer is created, not when the
        # module is imported, so stdout redirected later is honoured.
        self.file = sys.stdout if file is None else file

    def run(self, commands):
        """Write every command's text on its own line to ``self.file``, flushing each line."""
        for command in commands:
            print(command.CommandText, file=self.file, flush=True)


class CmdExecuter:
    """Runs commands one after another and reports their results.

    A command exiting with a non-zero status is reported, and the remaining
    commands still run.
    """

    def __init__(self, pause=None):
        # Seconds to sleep between consecutive commands; None disables pausing.
        self.pause = pause

    def run(self, commands):
        """Execute ``commands``, a sized sequence of objects with ``CommandText``.

        Each command line is tokenised with ``shlex.split`` and run without a
        shell. On failure the full stderr and stdout are printed; on success
        only the last 10 lines of stdout are shown.

        Returns the number of commands that exited with a non-zero status.
        """
        total = len(commands)
        failed = 0
        for i, command in enumerate(commands, start=1):
            print("Run '{command}'...".format(command=command.CommandText), file=sys.stderr)
            args = shlex.split(command.CommandText)
            res = subprocess.run(args, capture_output=True, encoding="utf-8")
            if res.returncode != 0:
                failed += 1
                print("Command {command} exited with {code}".format(command=command.CommandText, code=res.returncode))
                print("stderr:")
                print(res.stderr)
                print("stdout:")
                print(res.stdout)
            else:
                if res.stdout:
                    print("stdout (last 10 lines):")
                    # splitlines() ignores the trailing newline, so ten real
                    # lines are shown rather than nine plus an empty string.
                    print('\n'.join(res.stdout.splitlines()[-10:]))
                print("\n\n")

            if self.pause is not None and i < total:
                if self.pause > 0.5:
                    print("Pause for {pause} seconds...".format(pause=self.pause), file=sys.stderr)
                time.sleep(self.pause)

        return failed


def make_argument_parser(argv=None):
    """Build the command-line parser.

    ``argv`` is the argument list that will later be parsed (default:
    ``sys.argv[1:]``). It is inspected up front because ``--list-services``
    makes every other option optional. Prefixes that argparse accepts as
    abbreviations of that flag (e.g. ``--list``) are recognised as well.
    """
    allowed_data_centers = ['iva', 'myt', 'sas', 'vla']
    argv = sys.argv[1:] if argv is None else argv

    def is_list_services(arg):
        # argparse accepts unambiguous prefixes of long options by default
        # (allow_abbrev). '--' on its own is the end-of-options marker.
        return len(arg) > 2 and '--list-services'.startswith(arg)

    is_command_mode = not any(is_list_services(arg) for arg in argv)
    parser = argparse.ArgumentParser(description='manage services')
    parser.add_argument(
        '-i', '--config',
        type=argparse.FileType('r'),
        dest='config',
        metavar='<config>',
        default='manage_services.yaml',
        help='Services configuration file',
    )
    parser.add_argument(
        '-c', '--command',
        type=Cmd,
        choices=Cmd,
        required=is_command_mode,
        dest='command',
        metavar='<command>',
        help='Allowed commands: ' + ', '.join([c.value for c in Cmd]),
    )
    parser.add_argument(
        '-s', '--service',
        nargs='+',
        required=is_command_mode,
        dest='service',
        metavar='<service>',
        help='Allowed services: all, ...',
    )
    parser.add_argument(
        '-e', '--env', '--environment',
        type=Env,
        choices=Env,
        nargs='+',
        required=is_command_mode,
        dest='environment',
        metavar='<environment>',
        help='Allowed environments: ' + ', '.join([e.value for e in Env])
    )
    parser.add_argument(
        '-d', '--dc', '--data-center',
        choices=allowed_data_centers,
        required=is_command_mode,
        dest='dc',
        metavar='<dc>',
        help='Allowed data centers: ' + ', '.join(allowed_data_centers)
    )
    parser.add_argument('--pause', type=float, default=0, help='Command invocation delay (secs)')
    parser.add_argument('--skip-cache-update', action='store_true', help='Do not try to update executer cache')
    parser.add_argument('--dry-run', action='store_true')
    parser.add_argument('--list-services', action='store_true', help='List allowed services')

    return parser


if __name__ == '__main__':
    # Script entry point: parse the arguments, expand them into commands, then
    # either print them (--dry-run) or execute them.
    parser = make_argument_parser()
    args = parser.parse_args()

    # argparse.FileType returns an already-open file; close it once loaded.
    with args.config as config_file:
        manager = Manager(config=yaml.safe_load(config_file))

    if args.list_services:
        print(' '.join(manager.AppInfoList.keys()), flush=True)
        # sys.exit rather than quit(): quit() is injected by the site module
        # and is not guaranteed to exist (e.g. under `python -S`).
        sys.exit(0)

    # Build the 'tune' helpers if they are missing; the its/awacs command
    # templates call ./its/tune and ./l7heavy/tune by relative path.
    for tool_dir in ('./its', './l7heavy'):
        if not exists(tool_dir + '/tune'):
            subprocess.run("ya make", shell=True, cwd=tool_dir)

    commands = manager.get_commands(
        args.command,
        'all' if 'all' in args.service else args.service,
        args.environment,
        [args.dc]
    )

    if args.dry_run:
        DryRunExecuter(file=sys.stdout).run(commands=commands)
        sys.exit(0)

    if not args.skip_cache_update:
        print("Updating executer cache...", file=sys.stderr)
        subprocess.run(shlex.split("executer ''"))

    failures = CmdExecuter(pause=args.pause).run(commands=commands)
    # run() may report a failure count; a falsy or None result keeps status 0.
    sys.exit(1 if failures else 0)
