# -*- coding: utf-8 -*-

from sandbox import sdk2

from sandbox.projects.mssngr.common import build
from sandbox.projects import resource_types as common_resources
from sandbox.sandboxsdk import environments

from sandbox.projects.release_machine import security as rm_sec
from sandbox.projects.release_machine.core import const as rm_const
from sandbox.projects.release_machine.helpers import startrek_helper as st_helper
import sandbox.projects.release_machine.core.task_env as task_env
import sandbox.projects.release_machine.components.all as rmc

import sandbox.common.types.task as ctt
from sandbox.common.errors import TaskFailure
from multiprocessing import Pool

from sandbox.projects.common.nanny import nanny

import logging
import subprocess as sp
import requests
import re
import time
import tarfile
import datetime


def _init_inflight_hgram():
    return [[x, 0] for x in list(xrange(0, 10, 1)) + list(xrange(10, 100, 10)) + list(xrange(100, 1000, 100))]


def _init_times_hgram():
    return [[x, 0] for x in list(xrange(0, 100, 10)) + list(xrange(100, 1000, 100)) + list(xrange(1000, 10000, 1000))]


def _get_static_yasm_image(chart, start, end, attempt=0):
    """Resolve a static snapshot of a Yasm chart for a time range.

    Asks the ``s.yasm`` renderer for a static image and expects an HTTP 302
    redirect whose ``Location`` header is the image URL. Requests are retried
    while ``attempt`` (the number of attempts already made) is at most 2, so a
    default call makes up to four requests. Network errors count as failed
    attempts instead of aborting the caller. If every attempt fails, a
    placeholder image URL is returned so the caller can still post its comment.

    :param chart: chart path appended to the Yasm host.
    :param start: range start, unix timestamp in seconds.
    :param end: range end, unix timestamp in seconds.
    :param attempt: attempts already made (kept for backward compatibility).
    :return: ``(interactive_chart_url, image_url)``.
    """
    # Yasm expects millisecond timestamps.
    query = "/{}/?from={}&to={}".format(chart, start * 1000, end * 1000)
    yasm_url = "https://yasm.yandex-team.ru" + query
    static_url = "https://s.yasm.yandex-team.ru" + query + "&width=1610&static=1"
    while True:
        try:
            # A timeout keeps a stuck renderer from hanging the task forever.
            r = requests.get(static_url, allow_redirects=False, timeout=30)
            if r.status_code == 302:
                return yasm_url, r.headers["Location"]
        except requests.RequestException:
            logging.getLogger("yasm").exception("failed to render static yasm chart")
        if attempt > 2:
            return yasm_url, "https://yasm.yandex-team.ru/assets/2.192.0/images/hal9000.525e03.png"
        attempt += 1


def st_format_comment(start, end, href, image):
    """Render the StarTrek comment posted after a shooting finishes.

    Timestamps are converted with ``datetime.fromtimestamp``, i.e. local time.

    :param start: shooting start, unix timestamp in seconds.
    :param end: shooting end, unix timestamp in seconds.
    :param href: link to the interactive chart.
    :param image: static chart image URL. It is placed right after ``href``
        inside ``((...))``; presumably it renders as the link's label.
    :return: the comment text.
    """
    clock = "%H:%M"
    began = datetime.datetime.fromtimestamp(start)
    finished = datetime.datetime.fromtimestamp(end)
    template = "<{{График стрельбы {} от {} до {}\n(({} {}))\n}}>"
    details = template.format(
        began.strftime("%Y-%m-%d"),
        began.strftime(clock),
        finished.strftime(clock),
        href,
        image,
    )
    return "Стрельба завершена\n\n" + details


class JugglerClient(object):
    """Thin client for the Juggler HTTP API used to set alert downtimes during load tests."""

    address = "https://juggler-api.search.yandex.net"
    # Seconds to wait for Juggler. Without a timeout an unresponsive API hangs the task.
    timeout = 30

    def __init__(self, token, logger):
        """
        :param token: OAuth token sent in the ``Authorization`` header.
        :param logger: logger that receives the API response and any failure.
        """
        self.token = token
        self.logger = logger

    def set_downtimes(self, start_time, end_time):
        """Set a downtime on the mssngr testing router CPU checks.

        Covers the ``{man,vla,sas}_{cpu_*}`` services of both
        ``testing_router_workers`` and ``testing_router_balancers`` in the
        ``mssngr`` namespace.

        Best-effort: non-success HTTP answers and network failures are logged,
        not raised, so a Juggler outage does not abort the load test.

        :param start_time: downtime start, unix timestamp in seconds.
        :param end_time: downtime end, unix timestamp in seconds.
        :return: True if Juggler answered with a status below 400, False otherwise.
        """
        signals = ["cpu_throttled", "cpu_usage_max", "cpu_usage_med", "cpu_wait_max", "cpu_wait_med"]
        locations = ["man", "vla", "sas"]
        services = ["{}_{}".format(loc, sig) for sig in signals for loc in locations]
        data = {
            "description": "Автоматические стрельбы",
            "start_time": start_time,
            "end_time": end_time,
            "source": "sandbox",
            # All worker filters first, then all balancer filters.
            "filters": [
                {"host": host, "service": service, "namespace": "mssngr"}
                for host in ("testing_router_workers", "testing_router_balancers")
                for service in services
            ],
        }
        try:
            r = requests.post(
                self.address + "/v2/downtimes/set_downtimes",
                json=data,
                headers={"Authorization": "OAuth " + self.token},
                timeout=self.timeout,
            )
        except requests.RequestException:
            self.logger.exception("failed to set juggler downtimes")
            return False
        self.logger.debug("juggler respond %s: %s" % (r.status_code, r.text))
        if not r.ok:
            self.logger.warning("juggler rejected downtimes with status %s", r.status_code)
        return r.ok


class StatsFetcher(object):
    """Callable that fetches a host's ``/_golovan`` signals.

    It is a plain class rather than a closure so it can be pickled and passed
    to ``multiprocessing.Pool.map``.
    """

    def __init__(self, timeout=10):
        """
        :param timeout: per-request timeout in seconds. A hung host otherwise
            stalls ``Pool.map`` and with it the whole polling loop.
        """
        self.timeout = timeout

    def __call__(self, hostInfo):
        """Return the decoded JSON of ``http://host:port/_golovan``.

        :param hostInfo: mapping with ``"host"`` and ``"port"`` keys.
        :return: the decoded response, or ``{}`` on any failure. The caller
            feeds the response to ``dict()``, so it presumably is a list of
            ``[signal, value]`` pairs.
        """
        host = hostInfo["host"]
        port = hostInfo["port"]

        try:
            return requests.get("http://{}:{}/_golovan".format(host, port), timeout=self.timeout).json()
        except Exception:
            # Best-effort: one unreachable host must not break the polling loop.
            logging.getLogger("golovan").exception("Failed to fetch stats from %s:%s", host, port)
            return {}


class MssngrRouterLoadTest(sdk2.Task):
    """Runs load test.

    Discovers live router workers and balancers via Nanny, runs the ``shoot``
    binary from ``shoot_binary_package`` while polling every host's
    ``/_golovan`` endpoint, appends per-signal quantiles to a YT table and
    posts a summary comment to the release ticket.
    """

    class Requirements(build.CommonRequirements):
        disk_space = 100 * 1024
        cores = 12
        ram = 4 * 1024
        environments = [
            environments.PipEnvironment("numpy"),
            environments.PipEnvironment("psycopg2-binary"),
            environments.PipEnvironment("yandex-yt"),
            environments.PipEnvironment("yandex-yt-yson-bindings-skynet"),
            task_env.TaskRequirements.startrek_client,
        ]
        semaphores = ctt.Semaphores(
            acquires=[
                ctt.Semaphores.Acquire(name="mssngr/perftest-ydb")
            ],
            release=(ctt.Status.Group.BREAK, ctt.Status.Group.FINISH)
        )
        client_tags = task_env.TaskTags.startrek_client

    class Parameters(sdk2.Task.Parameters):
        yt_proxy = sdk2.parameters.String(
            label="YT proxy",
            group="YT",
            default="hahn",
        )

        yt_path = sdk2.parameters.String(
            label="Output table path",
            group="YT",
            default="//home/mssngr/perf/results",
        )

        shoot_binary_package = sdk2.parameters.Resource(
            label="Shoot binary package",
            group="Shoot options",
            resource_type=common_resources.YA_PACKAGE,
            default=1239492292
        )

        rps_plain = sdk2.parameters.Integer(
            label="Plain RPS",
            group="Shoot options",
            default=300
        )

        rps_seenmarkers = sdk2.parameters.Integer(
            label="Seenmarkers RPS",
            group="Shoot options",
            default=800
        )

        rps_guid_history = sdk2.parameters.Integer(
            label="History requests by Guid per second",
            group="Shoot options",
            default=400
        )

        rps_chat_id_history = sdk2.parameters.Integer(
            label="History requests by ChatId per second",
            group="Shoot options",
            default=800
        )

        duration = sdk2.parameters.Integer(
            label="Duration (seconds)",
            group="Shoot options",
            default=600
        )

        secret_owner = sdk2.parameters.String(
            label="Secrets owner",
            group="Secrets",
            default="MSSNGR"
        )

        nanny_token_secret = sdk2.parameters.String(
            label="Nanny token secret name",
            group="Secrets",
            default="nanny_oauth_token"
        )

        yt_token_secret = sdk2.parameters.String(
            label="YT token secret name",
            group="Secrets",
            default="yt-token"
        )

        st_token_secret = sdk2.parameters.String(
            label="StarTrek token secret name",
            group="Secrets",
            default="st-token"
        )

        juggler_token_secret = sdk2.parameters.String(
            label="Juggler token secret name",
            group="Secrets",
            default="juggler-token"
        )

        debug_secret = sdk2.parameters.String(
            label="Mssngr debug secret",
            group="Secrets",
            default="mssngr-debug-secret"
        )

        release_number = sdk2.parameters.Integer(
            label="Release number",
            group="Settings"
        )

        yasm_chart = sdk2.parameters.String(
            label="Yasm chart",
            group="Settings"
        )

    # Histogram templates: {golovan signal name: [[lower_bound, count], ...]}.
    # start_shooting() works on per-instance deep copies, so these class-level
    # dicts are never mutated and repeated runs do not share counters.
    # Inflight (``_ammv``) signals are sampled by update_stats() while shooting.
    worker_stats = {
        "handler_push_global_inflight_ammv":                       _init_inflight_hgram(),
        "handler_push_privileged_global_inflight_ammv":            _init_inflight_hgram(),
        "handler_message_info_global_inflight_ammv":               _init_inflight_hgram(),
        "handler_message_info_batch_global_inflight_ammv":         _init_inflight_hgram(),
        "handler_history_global_inflight_ammv":                    _init_inflight_hgram(),
        "handler_edit_history_global_inflight_ammv":               _init_inflight_hgram(),
        "handler_subscribe_global_inflight_ammv":                  _init_inflight_hgram(),
        "handler_push_payload_global_inflight_ammv":               _init_inflight_hgram(),
        "handler_whoami_global_inflight_ammv":                     _init_inflight_hgram(),
    }

    # Timing (``_hgram``) signals are fetched once after shooting and merged in
    # by update_hgram_stats().
    worker_timing_stats = {
        "conveyor_dbpostproc_processing_time_hgram":                _init_times_hgram(),
        "conveyor_dbsaver_processing_time_hgram":                   _init_times_hgram(),
        "conveyor_depresolv_processing_time_hgram":                 _init_times_hgram(),
        "conveyor_registrysync_processing_time_hgram":              _init_times_hgram(),
        "conveyor_sharefiles_processing_time_hgram":                _init_times_hgram(),
        "fin_history_message_processing_times_hgram":               _init_times_hgram(),
        "fin_other_message_processing_times_hgram":                 _init_times_hgram(),
        "fin_queue_sizes_hgram":                                    _init_times_hgram(),
        "fin_total_processing_times_hgram":                         _init_times_hgram(),
        "postproc_block_processing_time_hgram":                     _init_times_hgram(),
        "postproc_context_processing_time_hgram":                   _init_times_hgram(),
        "state_kikimr_add_message_processing_times_hgram":          _init_times_hgram(),
        "state_kikimr_chat_edit_history_processing_times_hgram":    _init_times_hgram(),
        "state_kikimr_chat_history_processing_times_hgram":         _init_times_hgram(),
        "state_kikimr_edit_message_processing_times_hgram":         _init_times_hgram(),
        "state_kikimr_init_chat_state_processing_times_hgram":      _init_times_hgram(),
        "state_kikimr_message_info_processing_times_hgram":         _init_times_hgram(),
        "state_kikimr_read_chat_state_processing_times_hgram":      _init_times_hgram(),
        "state_kikimr_update_chat_state_processing_times_hgram":    _init_times_hgram(),
        "state_kikimr_update_reaction_processing_times_hgram":      _init_times_hgram(),
        "state_kikimr_update_seen_markers_processing_times_hgram":  _init_times_hgram(),
        "uniproxy_processing_time_hgram":                           _init_times_hgram(),
    }

    balancer_stats = {
        "handler_push_global_inflight_ammv":                _init_inflight_hgram(),
        "handler_push_privileged_global_inflight_ammv":     _init_inflight_hgram(),
        "handler_message_info_global_inflight_ammv":        _init_inflight_hgram(),
        "handler_message_info_batch_global_inflight_ammv":  _init_inflight_hgram(),
        "handler_history_global_inflight_ammv":             _init_inflight_hgram(),
        "handler_unread_count_global_inflight_ammv":        _init_inflight_hgram(),
        "handler_edit_history_global_inflight_ammv":        _init_inflight_hgram(),
        "handler_subscribe_global_inflight_ammv":           _init_inflight_hgram(),
        "handler_push_payload_global_inflight_ammv":        _init_inflight_hgram(),
        "handler_whoami_global_inflight_ammv":              _init_inflight_hgram(),
    }

    balancer_timing_stats = {
        "fin_history_message_processing_times_hgram":  _init_times_hgram(),
        "fin_other_message_processing_times_hgram":    _init_times_hgram(),
        "fin_total_processing_times_hgram":            _init_times_hgram(),
        "postproc_block_processing_time_hgram":        _init_times_hgram(),
        "postproc_context_processing_time_hgram":      _init_times_hgram(),
    }
    shoot_address = "http://testing.l3.mssngr.yandex.net:31925"

    # Nanny services whose live instances are shot at and polled.
    _worker_services = (
        "testing_mssngr_router_workers_yp_sas",
        "testing_mssngr_router_workers_yp_man",
        "testing_mssngr_router_workers_yp_vla",
    )
    _balancer_services = (
        "testing_mssngr_router_balancers_yp_sas",
        "testing_mssngr_router_balancers_yp_man",
        "testing_mssngr_router_balancers_yp_vla",
    )
    # Balancers are probed/polled on this fixed port rather than the instance port.
    _balancer_port = 31925

    # Patterns for the svn section of ``/admin?action=version``.
    _TAG_RE = re.compile(r".*/arc/tags/fanout/(stable-.+)/arcadia")
    _REVISION_RE = re.compile(r".*Last Changed Rev:\s*([0-9]+)")

    @staticmethod
    def _copy_stats(stats):
        """Return a deep copy of a ``{signal: [[bound, count], ...]}`` mapping."""
        return {signal: [list(bucket) for bucket in buckets] for signal, buckets in stats.items()}

    @staticmethod
    def _bucket_index(buckets, value):
        """Return the index of the bucket that ``value`` falls into.

        Bucket ``i`` covers ``[buckets[i][0], buckets[i + 1][0])``. Values at or
        above the last lower bound go to the last bucket. Values below the first
        lower bound are clamped into the first bucket. Lower bounds must be
        ascending (as produced by the ``_init_*_hgram`` helpers).
        """
        idx = 0
        for i, bucket in enumerate(buckets):
            if bucket[0] > value:
                break
            idx = i
        return idx

    @staticmethod
    def _hgram_signals(raw_stats):
        """Keep only ``*_hgram`` entries of a golovan response.

        :param raw_stats: anything ``dict()`` accepts, i.e. a list of pairs or a dict.
        """
        return {name: value for name, value in dict(raw_stats).items() if name.endswith('_hgram')}

    def update_stats(self, stats, input_stats):
        """Record one sample per tracked signal into its histogram.

        :param stats: ``{signal: [[lower_bound, count], ...]}``, updated in place.
        :param input_stats: ``{signal: numeric value}``. Signals absent from
            ``stats`` are ignored.
        """
        for signal, value in input_stats.items():
            buckets = stats.get(signal)
            if buckets is not None:
                buckets[self._bucket_index(buckets, value)][1] += 1

    def update_hgram_stats(self, stats, input_stats):
        """Re-bucket source histograms into the tracked histograms.

        :param stats: ``{signal: [[lower_bound, count], ...]}``, updated in place.
        :param input_stats: ``{signal: [[lower_bound, weight], ...]}``. Each
            source bucket's weight is added to the tracked bucket containing its
            lower bound. Signals absent from ``stats`` are ignored.
        """
        for signal, input_buckets in input_stats.items():
            buckets = stats.get(signal)
            if buckets is None:
                continue
            for b in input_buckets:
                buckets[self._bucket_index(buckets, b[0])][1] += b[1]

    def start_shooting(self, shooter, env_spec, balancers, workers, rps_plain, rps_seenmarkers, rps_chat_id_history, rps_guid_history, duration):
        """Run the shooter binary and collect balancer/worker statistics.

        While the shooter runs, every host's inflight signals are sampled about
        every 2 s. Afterwards the timing histograms are fetched once per host.
        The results end up in ``self.balancer_stats`` / ``self.worker_stats``
        (per-instance copies of the class templates).

        :param shooter: path to the shoot executable.
        :param env_spec: path to the environment spec passed via ``-e``.
        :param balancers: list of ``{"host": ..., "port": ...}`` to poll.
        :param workers: list of ``{"host": ..., "port": ...}`` to poll.
        :param duration: shooting duration in seconds.
        :return: ``(startTs, endTs)``, unix timestamps in seconds.
        """
        logger = logging.getLogger("shoot")
        juggler_token = str(sdk2.Vault.data(self.Parameters.secret_owner, self.Parameters.juggler_token_secret))
        juggler = JugglerClient(juggler_token, logger)
        start_time = int(time.time())
        # Downtime spans the shooting plus a 300 s margin.
        juggler.set_downtimes(start_time, start_time + duration + 300)

        # Work on private copies so the class-level templates are never mutated.
        self.balancer_stats = self._copy_stats(type(self).balancer_stats)
        self.worker_stats = self._copy_stats(type(self).worker_stats)

        fetcher = StatsFetcher()
        with sdk2.helpers.ProcessLog(self, logger=logger) as pl:
            cmd = [
                shooter,
                "--silent",
                "--tvm-id", "2001089",
                "--tvm-secret", str(sdk2.Vault.data(self.Parameters.secret_owner, self.Parameters.debug_secret)),
                "--tvm-dst", "2001077",
                "-e", env_spec,
                "-M", str(rps_plain),
                "-S", str(rps_seenmarkers),
                "-H", str(rps_chat_id_history),
                "-G", str(rps_guid_history),
                "-d", "0",
                "-D", str(duration),
                "-a", self.shoot_address
            ]
            pool = Pool(8)
            p = None
            try:
                startTs = int(time.time())
                p = sp.Popen(cmd, stdout=pl.stdout, stderr=sp.STDOUT)
                while p.poll() is None:
                    for s in pool.map(fetcher, balancers):
                        self.update_stats(self.balancer_stats, dict(s))
                    for s in pool.map(fetcher, workers):
                        self.update_stats(self.worker_stats, dict(s))
                    time.sleep(2)
                returncode = p.wait()
            finally:
                # Always reap the pool workers. If polling failed midway, do not
                # leave the shooter running unattended.
                pool.terminate()
                pool.join()
                if p is not None and p.poll() is None:
                    p.kill()
                    p.wait()

            endTs = int(time.time())

        if returncode != 0:
            logger.warning("shooter exited with code %s", returncode)

        # Timing histograms are merged only now: the polling loop above must
        # only bucket the numeric inflight signals.
        self.worker_stats.update(self._copy_stats(type(self).worker_timing_stats))
        self.balancer_stats.update(self._copy_stats(type(self).balancer_timing_stats))

        for hosts, stats in ((balancers, self.balancer_stats), (workers, self.worker_stats)):
            for host_info in hosts:
                self.update_hgram_stats(stats, self._hgram_signals(fetcher(host_info)))
        return startTs, endTs

    @staticmethod
    def _host_available(host, port):
        """Return True if ``/admin?action=ping`` answers HTTP 200 within 2 s."""
        try:
            r = requests.get("http://{}:{}/admin?action=ping".format(host, port), timeout=2)
            return r.status_code == 200
        except Exception:
            return False

    def _discover_instances(self, nanny_client, service_ids, port=None):
        """Return ``[{"host": ..., "port": ...}]`` for pingable instances of Nanny services.

        :param port: port to probe and report. When None, each instance's own
            ``port`` from Nanny is used.
        """
        instances = []
        for service_id in service_ids:
            for h in nanny_client.get_service_current_instances(service_id)['result']:
                host = h["container_hostname"]
                instance_port = h["port"] if port is None else port
                if self._host_available(host, instance_port):
                    instances.append({"host": host, "port": instance_port})
        return instances

    def _detect_version(self, balancer):
        """Return ``(tag, revision)`` parsed from a balancer's ``/admin?action=version``.

        When a pattern matches several lines of the svn section, the last match wins.

        :raises TaskFailure: if the svn section, the tag or the revision is missing.
        """
        r = requests.get("http://{}:{}/admin?action=version".format(balancer["host"], balancer["port"]), timeout=30)
        version_info = r.text.split('\n')
        if "Svn info:" not in version_info:
            raise TaskFailure("failed to get svn revision: version info: %s" % version_info)

        tag = revision = None
        for line in version_info[version_info.index("Svn info:"):]:
            m = self._TAG_RE.match(line)
            if m:
                tag = m.group(1)
            m = self._REVISION_RE.match(line)
            if m:
                revision = m.group(1)

        if not tag or not revision:
            raise TaskFailure("failed to get svn revision: version info: %s" % version_info)
        return tag, revision

    def on_execute(self):
        # Row timestamp for every result written by this run.
        timestamp = int(time.time())
        logger = logging.getLogger("init")

        nanny_token = str(sdk2.Vault.data(self.Parameters.secret_owner, self.Parameters.nanny_token_secret))
        nanny_client = nanny.NannyClient(rm_const.Urls.NANNY_BASE_URL, nanny_token)

        workers = self._discover_instances(nanny_client, self._worker_services)
        balancers = self._discover_instances(nanny_client, self._balancer_services, port=self._balancer_port)

        logger.info("workers: %s" % workers)
        logger.info("balancers: %s" % balancers)

        if not balancers:
            raise TaskFailure("no available router balancers found")

        # Kept as attributes for backward compatibility with readers of self.tag/self.revision.
        self.tag, self.revision = self._detect_version(balancers[0])

        logger.debug("tag:{}, revision:{}".format(self.tag, self.revision))

        import yt.wrapper as yt
        yt.config.set_proxy(self.Parameters.yt_proxy)
        yt.config.config['token'] = str(sdk2.Vault.data(self.Parameters.secret_owner, self.Parameters.yt_token_secret))
        if not yt.exists(self.Parameters.yt_path):
            schema = [
                {"name": "Timestamp", "type": "uint64", "sort_order": "ascending"},
                {"name": "Revision", "type": "uint64"},
                {"name": "Tag", "type": "string"},
                {"name": "Task", "type": "uint64"},
                {"name": "Signal", "type": "string"},
                {"name": "Source", "type": "string"},
                {"name": "P50", "type": "uint64"},
                {"name": "P80", "type": "uint64"},
                {"name": "P90", "type": "uint64"},
                {"name": "P95", "type": "uint64"},
                {"name": "P99", "type": "uint64"},
            ]
            yt.create("table", self.Parameters.yt_path, attributes={"schema": schema, "optimize_for": "scan"})

        with tarfile.open(str(sdk2.ResourceData(self.Parameters.shoot_binary_package).path), 'r') as f:
            f.extractall(str(self.path(".")))

        # Shoot

        startTs, endTs = self.start_shooting(
            str(self.path("shoot/shoot")),
            str(self.path("env_spec.json")),
            balancers,
            workers,
            rps_plain=self.Parameters.rps_plain,
            rps_seenmarkers=self.Parameters.rps_seenmarkers,
            rps_chat_id_history=self.Parameters.rps_chat_id_history,
            rps_guid_history=self.Parameters.rps_guid_history,
            duration=self.Parameters.duration
        )

        # Collect stats
        from stats import quant
        table = "<append=%true>" + self.Parameters.yt_path
        quantiles = [0.5, 0.8, 0.9, 0.95, 0.99]

        def _make_row(signal, source, buckets):
            """Build one YT row with the P50..P99 quantiles of ``buckets``."""
            row = {
                "Source": source,
                "Signal": signal,
                "Timestamp": timestamp,
                "Task": int(self.id),
                "Revision": int(self.revision),
                "Tag": self.tag,
            }
            for q, value in zip(quantiles, quant(buckets, quantiles)):
                # round() guards against float artefacts such as int(0.29 * 100) == 28.
                row["P{}".format(int(round(q * 100)))] = value
            return row

        rows = [_make_row(signal, "balancer", buckets) for signal, buckets in self.balancer_stats.items()]
        rows += [_make_row(signal, "worker", buckets) for signal, buckets in self.worker_stats.items()]

        yt.write_table(table, rows, raw=False)

        # write StarTrek comment
        if self.Parameters.st_token_secret:
            st_token = str(sdk2.Vault.data(self.Parameters.secret_owner, self.Parameters.st_token_secret))
        else:
            st_token = str(sdk2.Vault.data(rm_const.COMMON_TOKEN_OWNER, rm_const.COMMON_TOKEN_NAME))
        st = st_helper.STHelper(st_token)
        c_info = rmc.get_component('fanout')
        href, image = _get_static_yasm_image(self.Parameters.yasm_chart, startTs, endTs)
        st.comment(self.Parameters.release_number, st_format_comment(startTs, endTs, href, image), c_info)
