# -*- coding: utf-8

from sandbox.sdk2 import Task, Requirements, Parameters, ResourceData
from sandbox.sdk2.helpers.process import subprocess as sdk_subprocess
from sandbox.sdk2.ssh import Key as SshKey
from sandbox.sdk2.yav import Secret as YavSecret
from sandbox.sdk2.service_resources import SandboxTasksBinary
from sandbox.projects.resource_types import OTHER_RESOURCE as OtherResource, ARCADIA_PROJECT_TGZ
from sandbox.projects.common.arcadia import sdk as arcadiasdk
from sandbox.projects.common.constants import constants as sdk_constants
from sandbox import sdk2
import sandbox.common.types.client as ctc
import sandbox.common.types.task as ctt
from tempfile import mkdtemp
import os
from shutil import rmtree
import logging
import tarfile
import random
import subprocess


class PerfRunner:
    """Remote job that profiles the local ``app_host-*`` process with perf.

    An instance is handed to ``client.iter(hosts, task)`` by the Sandbox task;
    ``run`` is a generator that yields exactly one value: the result of
    ``copierShare`` for the textual ``perf script`` dump.

    Attributes:
        perf_resid: resource id passed to ``copierGet`` to download the perf
            tarball.
        pipe: pipe object; ``run`` blocks on ``pipe.get()`` after yielding so
            the report stays on disk until the other side signals.
        duration: profiling time in seconds (passed to ``sleep``).
        sessionTimeout: ``duration + 100``.
            NOTE(review): presumably read by cqueue as the session time
            limit - confirm.
    """

    # NOTE(review): presumably the account the job is executed under on the
    # remote host - confirm against the cqueue documentation.
    osUser = 'robot-ah-releases'

    def __init__(self, perf_resid, pipe, duration):
        self.perf_resid = perf_resid
        self.pipe = pipe
        self.duration = duration
        self.sessionTimeout = duration + 100

    def run(self):
        """Record a perf profile and yield the shared report.

        Steps: find the pid via ``pgrep -f app_host-``, download and unpack
        the perf binary into a fresh temp dir, run ``perf record`` for
        ``self.duration`` seconds, dump it with ``perf script`` into
        ``out.perf``, yield the ``copierShare`` result, wait on the pipe and
        finally remove the temp dir.

        Raises:
            subprocess.CalledProcessError: if pgrep or perf exits non-zero.
            ValueError: if pgrep prints anything but a single pid (e.g.
                several matching processes).
        """
        from library.copier.helpers import copierShare, copierGet

        pid = int(subprocess.check_output(['pgrep', '-f', 'app_host-']))
        temp_dir = mkdtemp()
        try:
            # 0o755 is the same mode as the old 0755 literal, but the 0o
            # spelling is valid on both Python 2.6+ and Python 3.
            os.chmod(temp_dir, 0o755)
            perf_archive = copierGet(self.perf_resid, temp_dir)[0]
            # The context manager closes the archive even if extract() fails.
            with tarfile.open(perf_archive) as tar:
                tar.extract('./perf', temp_dir)
            perf = temp_dir + '/perf'
            # argv lists instead of shell strings: no quoting issues and no
            # intermediate shell process.
            # NOTE(review): perf record/script use perf.data in the current
            # working directory (unchanged from the original); it is not
            # covered by the temp dir cleanup below.
            subprocess.check_call([
                perf, 'record', '-N', '-F', '99', '-p', str(pid),
                '-g', '--call-graph', 'lbr', '--', 'sleep', str(self.duration),
            ])
            perf_report = temp_dir + '/' + 'out.perf'
            with open(perf_report, 'wb') as report_file:
                subprocess.check_call([perf, 'script'], stdout=report_file)
            yield copierShare([(perf_report, 'out.perf')])
            # Block until the other side of the pipe signals; only then is it
            # safe to delete the shared report.
            self.pipe.get()
        finally:
            # Always clean up, including when a step above raised or the
            # generator was closed early.
            rmtree(temp_dir)


def build_binary(arcadia, target):
    """Build ``target`` from the ``arcadia`` source root via ``arcadiasdk.do_build``.

    The results directory is a new temporary directory created inside the
    current working directory.

    Args:
        arcadia: path used as ``source_root`` for the build.
        target: single build target; passed as a one-element ``targets`` list.

    Returns:
        Path of the results directory handed to the build.
    """
    results_dir = mkdtemp(dir=os.getcwd())
    build_options = {
        'build_system': sdk_constants.YA_MAKE_FORCE_BUILD_SYSTEM,
        'source_root': arcadia,
        'targets': [target],
        'results_dir': results_dir,
        'clear_build': False,
    }
    arcadiasdk.do_build(**build_options)
    return results_dir


def fold_report(arcadia, report):
    """Collapse a ``perf script`` dump into folded stacks.

    Runs ``<arcadia>/contrib/tools/flame-graph/stackcollapse-perf.pl <report>``
    with its stdout redirected to a sibling ``.folded`` file.

    Args:
        arcadia: root of the Arcadia checkout that contains the tool.
        report: path to the perf dump; a trailing ``.perf`` is replaced with
            ``.folded``, any other name just gets ``.folded`` appended.

    Returns:
        Path of the folded file. A non-zero exit status of the tool is logged
        but does not raise, so the file may be empty or partial in that case.

    Raises:
        OSError: if the tool cannot be executed or the output file cannot be
            created.
    """
    fold = arcadia + '/contrib/tools/flame-graph/stackcollapse-perf.pl'
    suffix = '.perf'
    # Strip the suffix only when it is really there instead of blindly
    # cutting the last five characters.
    stem = report[:-len(suffix)] if report.endswith(suffix) else report
    folded = stem + '.folded'
    # argv list + explicit stdout redirection: paths are never parsed by a
    # shell, so spaces or metacharacters in them are harmless.
    with open(folded, 'wb') as folded_file:
        returncode = sdk_subprocess.call([fold, report], stdout=folded_file)
    if returncode != 0:
        logging.warning('%s exited with code %s for %s', fold, returncode, report)
    return folded


def generate_flame_graph(arcadia, report, out_dir):
    """Render a folded-stacks file into an SVG flame graph.

    Runs ``<arcadia>/contrib/tools/flame-graph/flamegraph.pl <report>`` with
    its stdout redirected to ``<out_dir>/<name>.svg``, where ``<name>`` is the
    base name of ``report`` without a trailing ``.folded``.

    Args:
        arcadia: root of the Arcadia checkout that contains the tool.
        report: path to the folded-stacks file.
        out_dir: existing directory that receives the SVG.

    Returns:
        Path of the SVG file. A non-zero exit status of the tool is logged
        but does not raise, so the file may be empty or partial in that case.

    Raises:
        OSError: if the tool cannot be executed or the output file cannot be
            created.
    """
    flamegraph_tool = arcadia + '/contrib/tools/flame-graph/flamegraph.pl'
    suffix = '.folded'
    name = os.path.basename(report)
    # Strip the suffix only when it is really there instead of blindly
    # cutting the last seven characters.
    if name.endswith(suffix):
        name = name[:-len(suffix)]
    flamegraph = out_dir + '/' + name + '.svg'
    # argv list + explicit stdout redirection: paths are never parsed by a
    # shell, so spaces or metacharacters in them are harmless.
    with open(flamegraph, 'wb') as svg_file:
        returncode = sdk_subprocess.call([flamegraph_tool, report], stdout=svg_file)
    if returncode != 0:
        logging.warning('%s exited with code %s for %s', flamegraph_tool, returncode, report)
    return flamegraph


def generate_resource(task, path, desc):
    """Create an ``OtherResource`` for ``path`` on ``task`` and mark its data ready.

    Args:
        task: task that will own the resource.
        path: file or directory to publish.
        desc: resource description.
    """
    # Arguments go to the resource constructor in (task, description, path) order.
    ResourceData(OtherResource(task, desc, path)).ready()


def get_last_perf():
    """Return the ``skynet_id`` of a stable-released static perf tarball.

    Looks up ``ARCADIA_PROJECT_TGZ`` resources by ``arcadia_path`` and
    ``released`` attributes and takes the first match.

    Returns:
        The ``skynet_id`` attribute of the resource that was found.

    Raises:
        RuntimeError: if no matching resource exists (instead of the opaque
            ``AttributeError`` on ``None`` that would occur otherwise).
    """
    attrs = {
        'arcadia_path': 'infra/kernel/tools/perf/build/perf-static.tar.gz',
        'released': 'stable'
    }
    resource = ARCADIA_PROJECT_TGZ.find(attrs=attrs).first()
    if resource is None:
        raise RuntimeError(
            'No ARCADIA_PROJECT_TGZ resource found with attrs {}'.format(attrs)
        )
    return resource.skynet_id


class CollectAppHostFlameGraphs(Task):
    """Collect perf profiles from app_host hosts and publish flame graphs.

    The task runs ``PerfRunner`` on the selected hosts, downloads the
    resulting perf dumps, converts each one into an SVG flame graph and
    publishes the SVGs plus the raw reports directory as resources.
    """

    class Requirements(Requirements):
        # Generic clients only, excluding LXC ones.
        client_tags = ctc.Tag.GENERIC & ~ctc.Tag.LXC

    class Parameters(Parameters):
        collect_time = sdk2.parameters.Integer("Time in seconds to collect flamegraphs", default=600)
        specify_hosts = sdk2.parameters.Bool("Specify hosts")
        # Explicit host list when "Specify hosts" is on ...
        with specify_hosts.value[True]:
            hosts = sdk2.parameters.List("Hosts", default=[])
        # ... otherwise hosts are sampled per vertical, with this many
        # candidates (first choice plus fallbacks) for each vertical.
        with specify_hosts.value[False]:
            attempts_per_vertical = sdk2.parameters.Integer("Attempts per vertical", default=5)

    class Report:
        """A downloaded perf dump together with where it came from."""

        def __init__(self, vertical, host, path):
            # vertical is None when hosts were specified explicitly.
            self.vertical = vertical
            self.host = host
            self.path = path

    def on_create(self):
        """Default the tasks binary to the first stable build of this task's target."""
        if self.Requirements.tasks_resource is None:
            self.Requirements.tasks_resource = SandboxTasksBinary.find(
                attrs={
                    'target': 'sandbox/projects/app_host/CollectAppHostFlameGraphs',
                    'released': ctt.ReleaseStatus.STABLE
                }
            ).first()

    def on_execute(self):
        """Pick hosts, collect perf reports and publish flame graph resources."""
        hosts_by_vertical = None
        if not self.Parameters.specify_hosts:
            hosts_by_vertical = self.get_hosts()

        reports_dir = self.path('reports').as_posix()
        logging.info('reports dir: {}'.format(reports_dir))

        reports = self.get_reports(hosts_by_vertical, reports_dir)

        if not reports:
            # Nothing to publish; make the empty outcome visible in the log.
            logging.warning('no reports were collected, no resources will be created')
            return

        self.generate_resources(reports_dir, reports)

    def generate_resources(self, reports_dir, reports):
        """Turn every report into an SVG resource and publish the raw reports.

        Args:
            reports_dir: directory holding the downloaded perf dumps; published
                as a single 'raw reports' resource.
            reports: iterable of ``Report`` objects.
        """
        with arcadiasdk.mount_arc_path('arcadia-arc:/#trunk') as arcadia:
            flamegraphs_dir = self.path('flamegraphs').as_posix()
            logging.info('flamegraphs dir: {}'.format(flamegraphs_dir))
            for report in reports:
                folded = fold_report(arcadia, report.path)
                # One output directory per host.
                flamegraph_dir = flamegraphs_dir + '/' + report.host
                os.makedirs(flamegraph_dir)
                flamegraph = generate_flame_graph(arcadia, folded, flamegraph_dir)
                generate_resource(self, flamegraph, '{} flamegraph'.format(report.vertical or report.host))
            generate_resource(self, reports_dir, 'raw reports')

    def get_hosts(self):
        """Sample candidate hosts for every production app_host vertical.

        Services are found through Nanny by the ``itype=apphost`` and
        ``ctype=prod`` labels and grouped by their ``prj`` label; services
        whose current state summary is ``OFFLINE`` are skipped.

        Returns:
            Dict mapping vertical name to a non-empty list of up to
            ``attempts_per_vertical`` randomly sampled hosts. Verticals without
            any candidate host are omitted.
        """
        from app_host.tools.nanny_tools.lib.common import NannyTools, Arguments, Service
        from library.sky.hostresolver import Resolver

        arguments = Arguments()
        arguments.confirmed = True
        arguments.category = None
        nanny = NannyTools(arguments)
        services = nanny.get_services_with_info(labels={
            'itype': 'apphost',
            'ctype': 'prod'
        })
        logging.debug('Found services: {}'.format(services.keys()))
        all_hosts_by_vertical = {}
        for service_id, service in services.items():
            vertical = service['labels']['prj']

            service_tools = Service(service_id, arguments)
            _, result = service_tools.pull('current_state')
            if result['summary']['value'] == 'OFFLINE':
                continue

            service_hosts = list(Resolver().resolveHosts('M@{}:ACTIVE'.format(service_id)))
            all_hosts_by_vertical.setdefault(vertical, []).extend(service_hosts)

        hosts_by_vertical = {}
        for vertical, hosts in all_hosts_by_vertical.items():
            # Clamp to [0, len(hosts)] so random.sample never sees a negative
            # or too large sample size.
            attempts = max(0, min(self.Parameters.attempts_per_vertical, len(hosts)))
            if attempts == 0:
                # An empty candidate list would make the later hosts.pop() fail.
                logging.warning('no candidate hosts for vertical %s, skipping it', vertical)
                continue
            hosts_by_vertical[vertical] = random.sample(hosts, attempts)

        return hosts_by_vertical

    def get_reports(self, hosts_by_vertical, reports_dir):
        """Run ``PerfRunner`` on the selected hosts and download their reports.

        With explicit hosts every listed host is tried once. Otherwise one
        host per vertical is tried, and on failure the next candidate of the
        same vertical is tried until its candidates run out.

        Args:
            hosts_by_vertical: dict of vertical -> candidate host list, or
                None when hosts are specified explicitly. The lists are
                consumed (popped from) by this method.
            reports_dir: directory under which ``<host>/`` subdirectories with
                the downloaded reports are created.

        Returns:
            List of ``Report`` objects for the hosts that succeeded.

        Raises:
            Exception: if a downloaded report path is not a regular file.
        """
        from api.cqueue import Client
        from library.copier.helpers import copierGet
        yav_secret = YavSecret('sec-01e4be5dpdkp19faymaz114m71')

        perf_resid = get_last_perf()
        logging.info('perf tool resid: {}'.format(perf_resid))

        vertical_by_host = {}

        if self.Parameters.specify_hosts:
            # Copy so the parameter value itself is never aliased or mutated.
            selected_hosts = list(self.Parameters.hosts or [])
        else:
            selected_hosts = []
            for vertical, hosts in hosts_by_vertical.items():
                if not hosts:
                    # Guard against an empty candidate list (pop would raise).
                    logging.warning('no candidate hosts for vertical %s, skipping it', vertical)
                    continue
                host = hosts.pop()
                selected_hosts.append(host)
                vertical_by_host[host] = vertical

        reports = []
        with SshKey(private_part=yav_secret.data()['private_key']), Client() as client:
            client.register_safe_unpickle(obj=subprocess.CalledProcessError)
            pipe = client.create_pipe()
            task = PerfRunner(perf_resid, pipe, self.Parameters.collect_time)
            # Each pass runs the current batch; failed verticals contribute
            # their next candidate to the following pass.
            while selected_hosts:
                with client.iter(selected_hosts, task) as session:
                    selected_hosts = []
                    for host, result, err in session.wait():
                        # StopIteration is reported when the remote generator
                        # finishes; it is not a failure.
                        if isinstance(err, StopIteration):
                            continue

                        vertical = None
                        if not self.Parameters.specify_hosts:
                            vertical = vertical_by_host[host]

                        if err is not None:
                            logging.info("Execution on %r failed: %s", host, err)
                            if vertical is not None and hosts_by_vertical[vertical]:
                                next_host = hosts_by_vertical[vertical].pop()
                                selected_hosts.append(next_host)
                                vertical_by_host[next_host] = vertical
                        else:
                            logging.info('report resid from {}: {}'.format(host, result))
                            report_path = copierGet(result, reports_dir + '/' + host)[0]
                            if not os.path.isfile(report_path):
                                raise Exception('{} not found'.format(report_path))
                            report = self.Report(vertical, host, report_path)
                            reports.append(report)
                            # Signal the remote side, which is blocked on
                            # pipe.get() until the report has been fetched.
                            pipe.put(None, [host])

        return reports
