import itertools
import logging
import os

import jinja2
from enum import Enum

from sandbox import sdk2
from sandbox.common.types.task import Status
from sandbox.common.urls import get_task_link
from sandbox.projects.common import constants as common_constants
from sandbox.projects.common.testenv_client import TEClient
from sandbox.projects.common.yabs.server.util import truncate_output_parameters
from sandbox.projects.common.yabs.server.util.general import check_tasks
from sandbox.projects.yabs.release.tasks.BuildYabsServer import BuildYabsServer
from sandbox.projects.yabs.qa.pipeline.stage import stage
from sandbox.projects.yabs.qa.resource_types import BS_RELEASE_TAR
from sandbox.projects.yabs.qa.tasks.YabsServerB2BFuncShootCmp import YabsServerB2BFuncShootCmp
from sandbox.projects.yabs.qa.tasks.YabsServerB2BFuncShootStability import YabsServerB2BFuncShootStability
from sandbox.projects.yabs.qa.tasks.YabsServerPerformanceMetaCmp import YabsServerPerformanceMetaCmp
from sandbox.projects.yabs.qa.tasks.YabsServerStatPerformanceBestCmp2 import YabsServerStatPerformanceBestCmp2
from sandbox.projects.yabs.qa.tasks.YabsServerStatPerformanceSanitize import YabsServerStatPerformanceSanitize
from sandbox.projects.yabs.qa.template_utils import get_template
from sandbox.sdk2.vcs.svn import Arcadia


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def build_yabs_server(
    task,
    base_revision,
    arcadia_patch,
    **kwargs
):
    """Create and enqueue a BuildYabsServer child task for a patched build.

    :param task: parent task; its tags (plus ``WITH-PATCH``) and hints are
        copied to the child task
    :param base_revision: revision used both to pin the trunk checkout URL and
        as ``rebase_onto_revision``; when falsy the plain trunk URL is used
    :param arcadia_patch: value passed to the build task as ``arcadia_patch``
    :param kwargs: extra BuildYabsServer parameters, forwarded unchanged
    :return: whatever ``enqueue()`` of the created build task returns
    """
    if base_revision:
        arcadia_url = Arcadia.replace(Arcadia.ARCADIA_TRUNK_URL, revision=base_revision)
    else:
        arcadia_url = Arcadia.ARCADIA_TRUNK_URL

    build_description = 'Build server on revision {revision} with patch {patch}'.format(
        revision=base_revision,
        patch=arcadia_patch,
    )
    bundle_packages = {
        "yabs/server/packages/yabs-server-bundle-brave.json": BS_RELEASE_TAR.name,
    }

    build_task = BuildYabsServer(
        task,
        tags=task.Parameters.tags + ['WITH-PATCH'],
        hints=list(task.hints),
        description=build_description,
        packages_to_build=bundle_packages,
        checkout_arcadia_from_url=arcadia_url,
        arcadia_patch=arcadia_patch,
        rebase_onto_revision=base_revision,
        yt_proxy="hahn",
        max_retries=0,
        **kwargs
    )
    return build_task.enqueue()


@stage(provides="build_yabs_server_release_task_id")
def build_yabs_server_release(
    task,
    base_revision=None,
    arcadia_patch=None,
):
    """Launch a release-type server build and return the build task id.

    :param task: parent task for the build
    :param base_revision: revision forwarded to ``build_yabs_server``
    :param arcadia_patch: patch forwarded to ``build_yabs_server``
    :return: id of the enqueued build task
    """
    release_build = build_yabs_server(
        task,
        base_revision=base_revision,
        arcadia_patch=arcadia_patch,
        build_type=common_constants.RELEASE_BUILD_TYPE,
    )
    return release_build.id


@stage(provides="build_yabs_server_sanitize_address_task_id")
def build_yabs_server_sanitize_address(
    task,
    base_revision=None,
    arcadia_patch=None,
):
    """Launch a profile-type build with ``sanitize='address'``.

    Tools and debug packages are requested as empty mappings.

    :param task: parent task for the build
    :param base_revision: revision forwarded to ``build_yabs_server``
    :param arcadia_patch: patch forwarded to ``build_yabs_server``
    :return: id of the enqueued build task
    """
    build_options = {
        "build_type": common_constants.PROFILE_BUILD_TYPE,
        "sanitize": 'address',
        "tools_to_build": {},
        "packages_to_build_debug": {},
    }
    return build_yabs_server(task, base_revision, arcadia_patch, **build_options).id


@stage(provides="build_yabs_server_sanitize_memory_task_id")
def build_yabs_server_sanitize_memory(
    task,
    base_revision=None,
    arcadia_patch=None,
):
    """Launch a profile-type build with ``sanitize='memory'``.

    Tools and debug packages are requested as empty mappings.

    :param task: parent task for the build
    :param base_revision: revision forwarded to ``build_yabs_server``
    :param arcadia_patch: patch forwarded to ``build_yabs_server``
    :return: id of the enqueued build task
    """
    build_options = {
        "build_type": common_constants.PROFILE_BUILD_TYPE,
        "sanitize": 'memory',
        "tools_to_build": {},
        "packages_to_build_debug": {},
    }
    return build_yabs_server(task, base_revision, arcadia_patch, **build_options).id


def run_shoot_task(
    task,
    baseline_shoot_task_id,
    shoot_task_params,
    task_type=None,
):
    """Enqueue a copy of a baseline shoot task with some parameters overridden.

    The baseline task's parameters are passed through
    ``truncate_output_parameters`` together with the target type's
    ``Parameters`` and then updated with ``shoot_task_params``.

    :param task: parent task; its tags (plus ``WITH-PATCH``) and hints are
        copied to the new task
    :param baseline_shoot_task_id: id of the task whose parameters are copied
    :param shoot_task_params: mapping of parameter overrides
    :param task_type: task class to instantiate; defaults to the baseline
        task's own type
    :return: whatever ``enqueue()`` of the created task returns
    """
    baseline = sdk2.Task[baseline_shoot_task_id]
    shoot_task_type = baseline.type if task_type is None else task_type

    parameters = truncate_output_parameters(
        dict(baseline.Parameters),
        shoot_task_type.Parameters,
    )
    parameters.update(shoot_task_params)

    copy_description = (
        'Copy of task {source_task_type} #{source_task_id} '
        'with updated parameters {updated_params}'
    ).format(
        source_task_type=baseline.type.name,
        source_task_id=baseline_shoot_task_id,
        updated_params=shoot_task_params,
    )

    return shoot_task_type(
        task,
        tags=task.Parameters.tags + ['WITH-PATCH'],
        hints=list(task.hints),
        description=copy_description,
        **parameters
    ).enqueue()


def _run_ft_shoot_tasks(
    task,
    baseline_shoot_task_ids,
    custom_task_parameters=None,
    task_type=None,
):
    shoot_task_ids = {}
    for role, baseline_task_id in baseline_shoot_task_ids.items():
        if not baseline_task_id:
            continue
        shoot_task = run_shoot_task(
            task,
            baseline_shoot_task_id=baseline_task_id,
            shoot_task_params=custom_task_parameters,
            task_type=task_type,
        )
        shoot_task_ids[role] = shoot_task.id

    return shoot_task_ids


@stage(provides="ft_meta_shoot_task_ids")
def run_ft_meta_shoot_tasks(
    task,
    build_yabs_server_release_task_id,
    baseline_shoot_task_ids=None,
):
    """Launch ft shoot tasks that replace only the meta server resource.

    :param task: parent task for the launched shoot tasks
    :param build_yabs_server_release_task_id: id of the release build task
        whose ``bs_release_tar_resource`` is used
    :param baseline_shoot_task_ids: mapping role -> baseline shoot task id
    :return: mapping role -> id of the launched shoot task
    """
    check_tasks(task, build_yabs_server_release_task_id)

    build_task = sdk2.Task[build_yabs_server_release_task_id]
    overrides = dict(
        meta_server_resource=build_task.Parameters.bs_release_tar_resource.id,
    )
    return _run_ft_shoot_tasks(
        task,
        baseline_shoot_task_ids,
        custom_task_parameters=overrides,
    )


@stage(provides="ft_shoot_task_ids")
def run_ft_shoot_tasks(
    task,
    build_yabs_server_release_task_id,
    baseline_shoot_task_ids=None,
):
    """Launch ft shoot tasks that replace both meta and stat server resources.

    :param task: parent task for the launched shoot tasks
    :param build_yabs_server_release_task_id: id of the release build task
        whose ``bs_release_tar_resource`` is used
    :param baseline_shoot_task_ids: mapping role -> baseline shoot task id
    :return: mapping role -> id of the launched shoot task
    """
    check_tasks(task, build_yabs_server_release_task_id)

    build_task = sdk2.Task[build_yabs_server_release_task_id]
    patched_resource_id = build_task.Parameters.bs_release_tar_resource.id
    overrides = dict(
        meta_server_resource=patched_resource_id,
        stat_server_resource=patched_resource_id,
    )
    return _run_ft_shoot_tasks(
        task,
        baseline_shoot_task_ids,
        custom_task_parameters=overrides,
    )


@stage(provides="ft_stability_task_ids")
def run_ft_stability_tasks(
    task,
    build_yabs_server_release_task_id,
    baseline_shoot_task_ids=None,
):
    """Launch YabsServerB2BFuncShootStability tasks on the patched build.

    Both meta and stat server resources are replaced, and the stability
    parameters (2 runs, shoot task reuse, empty YT upload prefix) are set.

    :param task: parent task for the launched tasks
    :param build_yabs_server_release_task_id: id of the release build task
        whose ``bs_release_tar_resource`` is used
    :param baseline_shoot_task_ids: mapping role -> baseline shoot task id
    :return: mapping role -> id of the launched stability task
    """
    check_tasks(task, build_yabs_server_release_task_id)

    build_task = sdk2.Task[build_yabs_server_release_task_id]
    patched_resource_id = build_task.Parameters.bs_release_tar_resource.id
    overrides = dict(
        meta_server_resource=patched_resource_id,
        stat_server_resource=patched_resource_id,
        stability_runs=2,
        reuse_shoot_task=True,
        upload_to_yt_prefix="",
    )
    return _run_ft_shoot_tasks(
        task,
        baseline_shoot_task_ids,
        custom_task_parameters=overrides,
        task_type=YabsServerB2BFuncShootStability,
    )


def _run_ft_shoot_cmp_tasks(
    task,
    ft_shoot_task_ids,
    baseline_shoot_task_ids,
    arcanum_review_id=0,
):
    """Launch a YabsServerB2BFuncShootCmp task per patched shoot task.

    :param task: parent task; its tags and hints are copied to the compare
        tasks
    :param ft_shoot_task_ids: mapping role -> patched shoot task id
    :param baseline_shoot_task_ids: mapping role -> baseline shoot task id;
        must contain every role present in ``ft_shoot_task_ids``
    :param arcanum_review_id: value forwarded to the compare tasks
    :return: mapping role -> id of the launched compare task
    """
    check_tasks(task, ft_shoot_task_ids.values())

    description_template = 'compare tasks {pre_task_id} (original) and {test_task_id} (patched)'

    cmp_task_ids = {}
    for role in ft_shoot_task_ids:
        patched_task_id = ft_shoot_task_ids[role]
        baseline_task_id = baseline_shoot_task_ids[role]
        # The test name is taken from the baseline task's TESTENV-JOB- tag.
        baseline_tags = sdk2.Task[baseline_task_id].Parameters.tags
        cmp_task = YabsServerB2BFuncShootCmp(
            task,
            tags=task.Parameters.tags,
            hints=list(task.hints),
            description=description_template.format(
                pre_task_id=baseline_task_id,
                test_task_id=patched_task_id,
            ),
            pre_task=baseline_task_id,
            test_task=patched_task_id,
            test_name=get_testenv_job_name_from_tags(baseline_tags),
            arcanum_review_id=arcanum_review_id,
        )
        cmp_task_ids[role] = cmp_task.enqueue().id

    return cmp_task_ids


@stage(provides="ft_shoot_cmp_task_ids")
def run_ft_shoot_cmp_tasks(
    task,
    ft_shoot_task_ids,
    baseline_shoot_task_ids=None,
    arcanum_review_id=0,
):
    """Launch compare tasks for the ft shoot tasks.

    :param task: parent task for the compare tasks
    :param ft_shoot_task_ids: mapping role -> patched ft shoot task id
    :param baseline_shoot_task_ids: mapping role -> baseline shoot task id
    :param arcanum_review_id: value forwarded to the compare tasks
    :return: mapping role -> id of the launched compare task
    """
    return _run_ft_shoot_cmp_tasks(
        task,
        ft_shoot_task_ids=ft_shoot_task_ids,
        baseline_shoot_task_ids=baseline_shoot_task_ids,
        arcanum_review_id=arcanum_review_id,
    )


@stage(provides="ft_meta_shoot_cmp_task_ids")
def run_ft_meta_shoot_cmp_tasks(
    task,
    ft_meta_shoot_task_ids,
    baseline_shoot_task_ids=None,
    arcanum_review_id=0,
):
    """Launch compare tasks for the ft_meta shoot tasks.

    :param task: parent task for the compare tasks
    :param ft_meta_shoot_task_ids: mapping role -> patched ft_meta shoot task id
    :param baseline_shoot_task_ids: mapping role -> baseline shoot task id
    :param arcanum_review_id: value forwarded to the compare tasks
    :return: mapping role -> id of the launched compare task
    """
    return _run_ft_shoot_cmp_tasks(
        task,
        ft_shoot_task_ids=ft_meta_shoot_task_ids,
        baseline_shoot_task_ids=baseline_shoot_task_ids,
        arcanum_review_id=arcanum_review_id,
    )


@stage(provides="stat_performance_shoot_task_ids")
def run_stat_performance_shoot_tasks(
    task,
    build_yabs_server_release_task_id,
    baseline_shoot_task_ids=None,
):
    """Launch stat performance shoot tasks on the patched build.

    Both meta and stat server resources of the baseline tasks are replaced
    with the build's ``bs_release_tar_resource``.

    :param task: parent task for the launched shoot tasks
    :param build_yabs_server_release_task_id: id of the release build task
    :param baseline_shoot_task_ids: mapping role -> baseline shoot task id;
        roles with a falsy id are skipped
    :return: mapping role -> id of the launched shoot task
    """
    check_tasks(task, build_yabs_server_release_task_id)

    build_task = sdk2.Task[build_yabs_server_release_task_id]
    patched_resource_id = build_task.Parameters.bs_release_tar_resource.id
    overrides = {
        "meta_server_resource": patched_resource_id,
        "stat_server_resource": patched_resource_id,
    }

    launched_ids = {}
    for role in baseline_shoot_task_ids:
        baseline_task_id = baseline_shoot_task_ids[role]
        if baseline_task_id:
            launched_ids[role] = run_shoot_task(
                task,
                baseline_shoot_task_id=baseline_task_id,
                shoot_task_params=overrides,
            ).id

    return launched_ids


@stage(provides="stat_performance_shoot_cmp_task_ids")
def run_stat_performance_shoot_cmp_tasks(
    task,
    stat_performance_shoot_task_ids,
    baseline_shoot_task_ids=None,
):
    """Launch a YabsServerStatPerformanceBestCmp2 task per patched shoot task.

    Request limit and RPS threshold are chosen per role from a fixed table;
    a role missing from that table raises ``KeyError``.

    :param task: parent task; its tags and hints are copied to the compare
        tasks
    :param stat_performance_shoot_task_ids: mapping role -> patched shoot
        task id
    :param baseline_shoot_task_ids: mapping role -> baseline shoot task id
    :return: mapping role -> id of the launched compare task
    """
    check_tasks(task, stat_performance_shoot_task_ids.values())

    # https://a.yandex-team.ru/arc/trunk/arcadia/testenv/jobs/yabs/declarative_generator.py?rev=r8615126#L316
    perf_settings_mode_sensitive = {
        "yabs": dict(stat_shoot_request_limit=10**6, rps_diff_threshold_percent=1.1),
        "bs": dict(stat_shoot_request_limit=5 * 10**5, rps_diff_threshold_percent=1.1),
        "bsrank": dict(stat_shoot_request_limit=15 * 10**5, rps_diff_threshold_percent=2.),
    }
    description_template = 'compare tasks {pre_task_id} (original) and {test_task_id} (patched)'

    launched_ids = {}
    for role in stat_performance_shoot_task_ids:
        patched_task_id = stat_performance_shoot_task_ids[role]
        baseline_task_id = baseline_shoot_task_ids[role]
        role_settings = perf_settings_mode_sensitive[role]
        cmp_task = YabsServerStatPerformanceBestCmp2(
            task,
            tags=task.Parameters.tags,
            hints=list(task.hints),
            description=description_template.format(
                pre_task_id=baseline_task_id,
                test_task_id=patched_task_id,
            ),
            pre_task=baseline_task_id,
            test_task=patched_task_id,
            test_name="",
            rps_diff_threshold_percent=role_settings["rps_diff_threshold_percent"],
            stat_shoot_request_limit=role_settings["stat_shoot_request_limit"],
            improvement_is_diff=True,
            rss_diff_threshold_percent=0,
            need_ram_check=False,
            use_check_task_shoot_parameters=False,
            stat_shoot_mode="finger",
        )
        launched_ids[role] = cmp_task.enqueue().id

    return launched_ids


@stage(provides="meta_performance_shoot_task_ids")
def run_meta_performance_shoot_tasks(
    task,
    build_yabs_server_release_task_id,
    baseline_shoot_task_ids=None,
):
    """Launch meta performance shoot tasks on the patched build.

    Only the meta server resource of the baseline tasks is replaced with the
    build's ``bs_release_tar_resource``.

    :param task: parent task for the launched shoot tasks
    :param build_yabs_server_release_task_id: id of the release build task
    :param baseline_shoot_task_ids: mapping role -> baseline shoot task id;
        roles with a falsy id are skipped
    :return: mapping role -> id of the launched shoot task
    """
    check_tasks(task, build_yabs_server_release_task_id)

    build_task = sdk2.Task[build_yabs_server_release_task_id]
    overrides = {
        "meta_server_resource": build_task.Parameters.bs_release_tar_resource.id,
    }

    launched_ids = {}
    for role in baseline_shoot_task_ids:
        baseline_task_id = baseline_shoot_task_ids[role]
        if baseline_task_id:
            launched_ids[role] = run_shoot_task(
                task,
                baseline_shoot_task_id=baseline_task_id,
                shoot_task_params=overrides,
            ).id

    return launched_ids


@stage(provides="meta_performance_shoot_cmp_task_ids")
def run_meta_performance_shoot_cmp_tasks(
    task,
    meta_performance_shoot_task_ids,
    baseline_shoot_task_ids=None,
):
    """Launch a YabsServerPerformanceMetaCmp task per patched shoot task.

    The ``bsrank`` role gets a 2.0 percent RPS threshold and a 200000 request
    limit; every other role gets 1.0 percent and 100000.

    :param task: parent task; its tags and hints are copied to the compare
        tasks
    :param meta_performance_shoot_task_ids: mapping role -> patched shoot
        task id
    :param baseline_shoot_task_ids: mapping role -> baseline shoot task id
    :return: mapping role -> id of the launched compare task
    """
    check_tasks(task, meta_performance_shoot_task_ids.values())

    description_template = 'compare tasks {pre_task_id} (original) and {test_task_id} (patched)'

    launched_ids = {}
    for role in meta_performance_shoot_task_ids:
        patched_task_id = meta_performance_shoot_task_ids[role]
        baseline_task_id = baseline_shoot_task_ids[role]
        if role == 'bsrank':
            rps_threshold, request_limit = 2.0, 200000
        else:
            rps_threshold, request_limit = 1.0, 100000
        cmp_task = YabsServerPerformanceMetaCmp(
            task,
            tags=task.Parameters.tags,
            hints=list(task.hints),
            description=description_template.format(
                pre_task_id=baseline_task_id,
                test_task_id=patched_task_id,
            ),
            pre_task=baseline_task_id,
            test_task=patched_task_id,
            test_name="",
            rps_diff_threshold_percent=rps_threshold,
            improvement_is_diff=True,
            rss_diff_threshold_percent=0,
            cmp_shoot_request_limit=request_limit,
        )
        launched_ids[role] = cmp_task.enqueue().id

    return launched_ids


def run_sanitize_shoot_tasks(
    task,
    build_yabs_server_sanitize_address_task_id,
    baseline_shoot_task_ids=None,
    custom_params=None,
):
    """Launch a YabsServerStatPerformanceSanitize task per role.

    The dolbilka plan for each role is looked up among the TestEnv resources
    of the ``yabs-2.0`` database by the name
    ``YABS_SERVER_DOLBILKA_PLAN_<ROLE>`` with status ``OK``.

    :param task: parent task; its tags (plus ``WITH-PATCH``) and hints are
        copied to the launched tasks
    :param build_yabs_server_sanitize_address_task_id: id of the sanitizer
        build task whose ``bs_release_tar_resource`` is used; despite the
        name, the memory-sanitizer stage passes its own build id here too
    :param baseline_shoot_task_ids: mapping role -> baseline shoot task id;
        roles with a falsy id are skipped
    :param custom_params: optional parameter overrides applied on top of the
        defaults built here
    :return: mapping role -> id of the launched sanitize task
    :raises RuntimeError: when no ``OK`` dolbilka plan resource exists for a
        role that has a baseline task
    """
    check_tasks(task, build_yabs_server_sanitize_address_task_id)

    bs_release_tar_resource_id = sdk2.Task[build_yabs_server_sanitize_address_task_id].Parameters.bs_release_tar_resource.id

    testenv_resources = TEClient.get_resources("yabs-2.0")

    tasks = {}
    for role, baseline_task_id in baseline_shoot_task_ids.items():
        if not baseline_task_id:
            continue
        baseline_shoot_task = sdk2.Task[baseline_task_id]

        dolbilka_plan_resource_name = "YABS_SERVER_DOLBILKA_PLAN_{}".format(role.upper())
        # Reset per role so a plan found for a previous role is never reused.
        shoot_plan_resource = None
        for testenv_resource in testenv_resources:
            # No break: when several entries match, the last one wins.
            if testenv_resource["name"] == dolbilka_plan_resource_name and testenv_resource["status"] == "OK":
                shoot_plan_resource = testenv_resource["resource_id"]
        if shoot_plan_resource is None:
            raise RuntimeError(
                "No TestEnv resource {} with status OK found for role {}"
                .format(dolbilka_plan_resource_name, role)
            )

        params = {
            "shoot_plan_resource": shoot_plan_resource,
            "cache_daemon_stub_resource": baseline_shoot_task.Parameters.cache_daemon_stub_resource,
            "meta_binary_base_resources": baseline_shoot_task.Parameters.meta_binary_base_resources,
            "stat_binary_base_resources": baseline_shoot_task.Parameters.stat_binary_base_resources,
            "stat_shards": baseline_shoot_task.Parameters.stat_shards,
            "meta_server_resource": bs_release_tar_resource_id,
            "stat_server_resource": bs_release_tar_resource_id,
            "meta_mode": role,
            "stat_shoot_sessions": 1,
            "stat_shoot_request_limit": 500000,
            "shoot_request_limit": 100000,
            "store_plan_in_memory": True,
            "stat_store_request_log": True,
            "shoot_threads": 16,
            "generic_disk_space": 300,
            "shard_space": 220,
            "run_perf": False,
            "specify_cluster_set_config": True,
            "need_ram_check": False,
            "prepare_stat_dplan": True,
        }
        params.update(custom_params or {})
        shoot_task = YabsServerStatPerformanceSanitize(
            task,
            tags=task.Parameters.tags + ['WITH-PATCH'],
            hints=list(task.hints),
            **params
        ).enqueue()
        tasks[role] = shoot_task.id

    return tasks


@stage(provides="sanitize_address_shoot_task_ids")
def run_sanitize_address_shoot_tasks(
    task,
    build_yabs_server_sanitize_address_task_id,
    baseline_shoot_task_ids=None,
):
    """Launch sanitize shoot tasks on the address-sanitizer build.

    Both meta and stat get a custom environment with
    ``ASAN_OPTIONS=detect_stack_use_after_return=1``.

    :param task: parent task for the launched tasks
    :param build_yabs_server_sanitize_address_task_id: id of the
        address-sanitizer build task
    :param baseline_shoot_task_ids: mapping role -> baseline shoot task id
    :return: mapping role -> id of the launched sanitize task
    """
    # Each key gets its own environment dict, as in a literal definition.
    asan_environments = {
        param_name: {"ASAN_OPTIONS": "detect_stack_use_after_return=1"}
        for param_name in ("meta_custom_environment", "stat_custom_environment")
    }
    return run_sanitize_shoot_tasks(
        task,
        build_yabs_server_sanitize_address_task_id,
        baseline_shoot_task_ids,
        custom_params=asan_environments,
    )


@stage(provides="sanitize_memory_shoot_task_ids")
def run_sanitize_memory_shoot_tasks(
    task,
    build_yabs_server_sanitize_memory_task_id,
    baseline_shoot_task_ids=None,
):
    """Launch sanitize shoot tasks on the memory-sanitizer build.

    No custom parameters are added on top of the shared defaults.

    :param task: parent task for the launched tasks
    :param build_yabs_server_sanitize_memory_task_id: id of the
        memory-sanitizer build task
    :param baseline_shoot_task_ids: mapping role -> baseline shoot task id
    :return: mapping role -> id of the launched sanitize task
    """
    return run_sanitize_shoot_tasks(
        task,
        build_yabs_server_sanitize_memory_task_id,
        baseline_shoot_task_ids=baseline_shoot_task_ids,
    )


def get_report_template():
    """Load ``report_template.html`` from the directory of this module.

    The template is requested with ``jinja2.StrictUndefined`` and with
    ``trim_blocks`` / ``lstrip_blocks`` enabled.

    :return: whatever ``get_template`` returns for these options
    """
    template_options = {
        "templates_dir": os.path.dirname(__file__),
        "undefined": jinja2.StrictUndefined,
        "trim_blocks": True,
        "lstrip_blocks": True,
    }
    return get_template("report_template.html", **template_options)


@stage(provides=("report_html", "has_diff"))
def generate_ft_report(
    task,
    ft_shoot_cmp_task_ids,
    ft_stability_task_ids,
    ft_meta_shoot_cmp_task_ids=None,
):
    """Render the functional-test report and compute the overall diff flag.

    :param task: task passed through to ``check_tasks``
    :param ft_shoot_cmp_task_ids: mapping role -> ft compare task id
    :param ft_stability_task_ids: mapping role -> ft stability task id
    :param ft_meta_shoot_cmp_task_ids: optional mapping role -> ft_meta
        compare task id; ``None`` means there are no ft_meta compare tasks
    :return: tuple ``(report_html, has_diff)``; ``has_diff`` is true when any
        compare task has a truthy ``Context.has_diff`` or any stability task
        has a status other than ``Status.SUCCESS``
    """
    # Only substitute the default: a mapping supplied by the caller must be
    # checked and reported, not discarded.
    ft_meta_shoot_cmp_task_ids = ft_meta_shoot_cmp_task_ids or {}

    # list() keeps the concatenation valid whether dict.values() returns a
    # list or a view.
    cmp_task_ids = list(ft_shoot_cmp_task_ids.values()) + list(ft_meta_shoot_cmp_task_ids.values())
    stability_task_ids = list(ft_stability_task_ids.values())

    check_tasks(task, cmp_task_ids)
    # raise_on_fail=False: an unsuccessful stability task is surfaced through
    # has_diff and the report entries below.
    check_tasks(task, stability_task_ids, raise_on_fail=False)

    has_diff = (
        any(sdk2.Task[task_id].Context.has_diff for task_id in cmp_task_ids)
        or any(sdk2.Task[task_id].status != Status.SUCCESS for task_id in stability_task_ids)
    )

    tests = {}

    class TestTypes(Enum):
        ft = "ft"
        ft_meta = "ft_meta (with stable stat)"
        ft_stability = "ft stability (AA-test with patched meta and stat)"

    for test_type, cmp_tasks in (
        (TestTypes.ft.value, ft_shoot_cmp_task_ids),
        (TestTypes.ft_meta.value, ft_meta_shoot_cmp_task_ids),
    ):
        for role, task_id in cmp_tasks.items():
            tests.setdefault(test_type, {})[role] = {
                "ok": not sdk2.Task[task_id].Context.has_diff,
                "task_link": get_task_link(task_id),
                "task_id": task_id,
                "misc_text": getattr(sdk2.Task[task_id].Context, "short_report_text", ""),
            }

    for role, task_id in ft_stability_task_ids.items():
        tests.setdefault(TestTypes.ft_stability.value, {})[role] = {
            "ok": sdk2.Task[task_id].status == Status.SUCCESS,
            "task_link": get_task_link(task_id),
            "task_id": task_id,
            "misc_text": "",
        }

    # One report section per test type, including types without any tests.
    report_data = [
        {
            "test_type": test_type.value,
            "tests": tests.get(test_type.value, {}),
        }
        for test_type in TestTypes
    ]

    template = get_report_template()
    return template.render(data=report_data), has_diff


@stage(provides=("report_html", "has_diff"))
def generate_stat_performance_report(
    task,
    stat_performance_shoot_cmp_task_ids,
):
    """Render the stat performance report and compute the diff flag.

    :param task: task passed through to ``check_tasks``
    :param stat_performance_shoot_cmp_task_ids: mapping role -> compare
        task id
    :return: tuple ``(report_html, has_diff)``; ``has_diff`` is true when any
        compare task has a truthy ``Context.has_diff``
    """
    check_tasks(task, stat_performance_shoot_cmp_task_ids.values())

    role_entries = {}
    for role, cmp_task_id in stat_performance_shoot_cmp_task_ids.items():
        cmp_context = sdk2.Task[cmp_task_id].Context
        role_entries[role] = {
            "ok": not cmp_context.has_diff,
            "task_link": get_task_link(cmp_task_id),
            "task_id": cmp_task_id,
            "misc_text": getattr(cmp_context, "short_report_text", ""),
        }

    # An entry is "ok" exactly when its compare task reported no diff.
    has_diff = not all(entry["ok"] for entry in role_entries.values())

    sections = [{
        "test_type": "stat performance",
        "tests": role_entries,
    }]
    return get_report_template().render(data=sections), has_diff


@stage(provides=("report_html", "has_diff"))
def generate_meta_performance_report(
    task,
    meta_performance_shoot_cmp_task_ids,
):
    """Render the meta performance report and compute the diff flag.

    :param task: task passed through to ``check_tasks``
    :param meta_performance_shoot_cmp_task_ids: mapping role -> compare
        task id
    :return: tuple ``(report_html, has_diff)``; ``has_diff`` is true when any
        compare task has a truthy ``Context.has_diff``
    """
    check_tasks(task, meta_performance_shoot_cmp_task_ids.values())

    role_entries = {}
    for role, cmp_task_id in meta_performance_shoot_cmp_task_ids.items():
        cmp_context = sdk2.Task[cmp_task_id].Context
        role_entries[role] = {
            "ok": not cmp_context.has_diff,
            "task_link": get_task_link(cmp_task_id),
            "task_id": cmp_task_id,
            "misc_text": getattr(cmp_context, "short_report_text", ""),
        }

    # An entry is "ok" exactly when its compare task reported no diff.
    has_diff = not all(entry["ok"] for entry in role_entries.values())

    sections = [{
        "test_type": "meta performance",
        "tests": role_entries,
    }]
    return get_report_template().render(data=sections), has_diff


@stage(provides=("report_html", "has_diff"))
def generate_sanitize_report(
    task,
    sanitize_address_shoot_task_ids,
    sanitize_memory_shoot_task_ids,
):
    """Render the sanitizer report and compute the diff flag.

    :param task: task passed through to ``check_tasks``
    :param sanitize_address_shoot_task_ids: mapping role -> address-sanitizer
        shoot task id
    :param sanitize_memory_shoot_task_ids: mapping role -> memory-sanitizer
        shoot task id
    :return: tuple ``(report_html, has_diff)``; ``has_diff`` is true when any
        of the shoot tasks has a status other than ``Status.SUCCESS``
    """
    def get_tests_data(role_task_ids):
        # Build the per-role report entries for one sanitizer flavour.
        tests = {}
        for role, task_id in role_task_ids.items():
            tests[role] = {
                "ok": sdk2.Task[task_id].status == Status.SUCCESS,
                "task_link": get_task_link(task_id),
                "task_id": task_id,
                # Same entry schema as the other report generators in this
                # module, which share the StrictUndefined template.
                # NOTE(review): template not visible here - confirm it reads
                # misc_text.
                "misc_text": "",
            }
        return tests

    shoot_task_ids = list(itertools.chain(
        sanitize_address_shoot_task_ids.values(),
        sanitize_memory_shoot_task_ids.values(),
    ))
    check_tasks(task, shoot_task_ids)

    has_diff = any(sdk2.Task[task_id].status != Status.SUCCESS for task_id in shoot_task_ids)

    report_data = [
        {
            "test_type": "sanitize address",
            "tests": get_tests_data(sanitize_address_shoot_task_ids),
        },
        {
            "test_type": "sanitize memory",
            "tests": get_tests_data(sanitize_memory_shoot_task_ids),
        },
    ]

    template = get_report_template()
    return template.render(data=report_data), has_diff


def get_testenv_job_name_from_tags(tags):
    """Return the job name carried by the first ``TESTENV-JOB-`` tag.

    :param tags: iterable of tag strings
    :return: the part of the first matching tag after the ``TESTENV-JOB-``
        prefix, or an empty string when no tag has that prefix
    """
    prefix = 'TESTENV-JOB-'
    prefix_length = len(prefix)
    job_names = (tag[prefix_length:] for tag in tags if tag.startswith(prefix))
    return next(job_names, '')
