# coding: utf-8

import datetime
import tempfile

from sandbox import sdk2
from sandbox.projects.common import link_builder as lb
from sandbox.projects.common.binary_task import LastBinaryTaskRelease, LastBinaryReleaseParameters
from sandbox.projects.common.utils2 import resource_redirect_link
from sandbox.projects.yabs.qa.pipeline.stage import stage
from sandbox.projects.yabs.yt_logfeller_sampler import YtLogFellerSamplers


class LinearModelsFlameGraphReport(sdk2.Resource):
    """HTML flame graph report (flame.html) produced by LinearModelsBuildFlameGraph."""
    pass


# Log samplers the task can read from; their names populate the `sampler`
# radio group of the task parameters.
SAMPLERS = (YtLogFellerSamplers.Hours, YtLogFellerSamplers.Days)

# Reverse index: sampler name (as selected in the UI) -> sampler object.
SAMPLERS_INV = dict((smp.name, smp) for smp in SAMPLERS)

DEFAULT_SAMPLER = YtLogFellerSamplers.Days.name

# Format shared by parsing (parse_datetime) and parameter descriptions.
DATETIME_FORMAT = '%Y-%m-%d %H:%M'

# Key template passed to library.python.resource.find (see render_request).
TEMPLATE_FORMAT = 'sandbox/projects/runtime_models/BuildLinearModelsFlameGraph/templates/{}'

# Default top-of-stack frames to filter out: mostly wait/sleep primitives
# plus unresolved ('??') frames.
DEFAULT_SKIP_TOP = (
    '??',
    'do_futex_wait.constprop',
    'do_sigwait',
    'epoll_wait',
    'nanosleep',
    'pthread_cond_timedwait@@GLIBC_2.3.2',
    'pthread_cond_wait@@GLIBC_2.3.2',
    'sigwait',
    'syscall',
)


def parse_datetime(data):
    """Parse ``data`` using DATETIME_FORMAT and return a naive datetime.

    Raises ValueError (from ``strptime``) when the string does not match.
    """
    parsed = datetime.datetime.strptime(data, DATETIME_FORMAT)
    return parsed


class DateTimeParameter(sdk2.parameters.String):
    # String parameter that must match DATETIME_FORMAT. The raw string is
    # what gets stored; parsing is done only to validate it.
    description = 'Datetime in format: `{}`'.format(DATETIME_FORMAT)

    @classmethod
    def cast(cls, value):
        raw = value
        # Raises if `raw` does not match DATETIME_FORMAT; the parsed result is
        # intentionally discarded.
        parse_datetime(raw)
        return raw


def render_request(**kwargs):
    """Render the bundled ``request.sql`` Jinja2 template with ``kwargs``.

    The template text is loaded via ``library.python.resource.find`` using
    TEMPLATE_FORMAT; the rendered string is returned.
    """
    import library.python.resource as lpr
    from jinja2 import Template

    template_key = TEMPLATE_FORMAT.format('request.sql')
    template = Template(lpr.find(template_key))
    return template.render(**kwargs)


def get_file_name(result):
    """Dump result rows into a new temporary file and return its path.

    Each row is written on its own line as ``bt=<Backtrace>``, a TAB, then
    ``c=<Counter>``. The file is created with ``delete=False`` so it outlives
    this call; removing it is the caller's responsibility.

    :param result: iterable of mappings with 'Backtrace' and 'Counter' keys,
        or None.
    :return: path of the written file, or None when ``result`` is None.
    """
    if result is None:
        return None
    # Text mode so formatted ``str`` rows can be written on Python 3 as well
    # as Python 2; the temp-file object is itself the writable handle.
    with tempfile.NamedTemporaryFile(mode='w', delete=False) as out:
        file_name = out.name
        for row in result:
            out.write('bt={}\tc={}\n'.format(row['Backtrace'], row['Counter']))
    return file_name


def write_table(f, rows):
    """Write result rows to the file at ``f``, replacing existing content.

    Each row is written on its own line as ``bt=<Backtrace>``, a TAB, then
    ``c=<Counter>``.

    :param f: path-like object with an ``open(mode)`` method, e.g.
        ``pathlib.Path`` (assumed; no caller is visible here).
    :param rows: iterable of mappings with 'Backtrace' and 'Counter' keys.
    """
    # ``open()`` defaults to read mode and a path object has no ``write``;
    # all writes must go through the handle opened for writing.
    with f.open('w') as out:
        for row in rows:
            out.write('bt={}\tc={}\n'.format(row['Backtrace'], row['Counter']))


class RequestParametersGroup(sdk2.Parameters):
    # One profiling request: a time range plus filters. The task instantiates
    # it with a prefix ('base_' / 'test_'), and
    # LinearModelsBuildFlameGraph.get_params reads the fields back as
    # '<prefix>_<field>', e.g. base_begin_time.
    # NOTE(review): exact sdk2 prefixing semantics inferred from that usage.

    # Start of the range; converted to a sampler table name by get_time_tables.
    begin_time = DateTimeParameter(
        'Begin time in format `{}`'.format(DATETIME_FORMAT),
        required=True,
    )
    # End of the range; converted to a sampler table name by get_time_tables.
    end_time = DateTimeParameter(
        'End time in format `{}`'.format(DATETIME_FORMAT),
        required=True,
    )
    # column_name -> value; forwarded to the YQL template unchanged.
    whitelist_filtration = sdk2.parameters.Dict('Whitelist filtration: column_name -> value')
    # Frame names forwarded to the YQL template; presumably dropped from the
    # top of each stack there. Defaults to DEFAULT_SKIP_TOP.
    blacklist_filtration_top = sdk2.parameters.List('Blacklist to filter top trace', default=DEFAULT_SKIP_TOP)
    # Column names; forwarded to the template under the key `separations`.
    separation = sdk2.parameters.List('Separate graph by columns list')


class LinearModelsBuildFlameGraph(LastBinaryTaskRelease, sdk2.Task):
    """Build flame graph for yabs_runtime_models"""

    # Flow (see on_execute):
    #   1. resolve sampler table names for the requested time range(s);
    #   2. run the templated YQL query (or reuse debug_operation_id);
    #   3. wait for the operation and dump each result table to a temp file;
    #   4. render the HTML flame graph (a diff graph when diff_report is set)
    #      into a LinearModelsFlameGraphReport resource.

    class Parameters(LastBinaryReleaseParameters):

        # Baseline request (fields exposed with the `base_` prefix).
        base_request_params = RequestParametersGroup(prefix='base_')
        diff_report = sdk2.parameters.Bool('Create diff report', default=False)
        with diff_report.value[True]:
            # Second request to compare against; read only when diff_report is set.
            test_request_parameters = RequestParametersGroup(prefix='test_')
        with sdk2.parameters.Group('Yt Parameters') as yt_block:
            cluster = sdk2.parameters.String('YT cluster', default='hahn', required=True)
            yql_token = sdk2.parameters.YavSecret('YQL token secret', required=True)
            user_definitions = sdk2.parameters.String('User text inserted in yql. For example for pragmas',
                                                      default='', multiline=True)
            prefix = sdk2.parameters.String(
                'Prefix to tables',
                default='//logs/yabs-rt-models-linear-models-poormansprofiler',
                required=True,
            )
            sampler = sdk2.parameters.RadioGroup(
                'Sampler to tables in YT',
                choices=tuple(
                    (sampler.name, sampler.name)
                    for sampler in SAMPLERS
                ),
                default=DEFAULT_SAMPLER,
                required=True,
            )

        with sdk2.parameters.Output():
            report = sdk2.parameters.Resource(
                'Report page',
                resource_type=LinearModelsFlameGraphReport,
            )

        with sdk2.parameters.Group('Debug', collapse=True) as debug_block:
            # When set, the YQL query is not started and this operation is used instead.
            debug_operation_id = sdk2.parameters.String('OperationID of YQL')

    @stage(provides=('base_begin_table', 'base_end_table', 'test_begin_table', 'test_end_table'), result_is_dict=True)
    def get_time_tables(self):
        """Map the requested begin/end datetimes to sampler table names.

        Returns a dict keyed like ``provides``; the ``test_*`` entries stay
        None unless a diff report is requested.
        """
        parsed_base_begin_time = parse_datetime(self.Parameters.base_begin_time)
        parsed_base_end_time = parse_datetime(self.Parameters.base_end_time)

        sampler = SAMPLERS_INV[self.Parameters.sampler]
        result = dict(
            base_begin_table=sampler.get_table_name(
                parsed_base_begin_time,
            ),
            base_end_table=sampler.get_table_name(
                parsed_base_end_time,
            ),
            test_begin_table=None,
            test_end_table=None,
        )
        if self.Parameters.diff_report:
            result['test_begin_table'] = sampler.get_table_name(
                parse_datetime(self.Parameters.test_begin_time),
            )
            result['test_end_table'] = sampler.get_table_name(
                parse_datetime(self.Parameters.test_end_time),
            )
        return result

    def get_params(self, prefix):
        """Collect YQL template variables for one request group.

        :param prefix: 'base' or 'test'. Table names are read from
            ``Context.<prefix>_begin_table`` / ``<prefix>_end_table``
            (presumably stored there by the get_time_tables stage), filters
            from the ``<prefix>_*`` task parameters.
        """
        def get_context(name):
            return getattr(self.Context, prefix + '_' + name)

        def get_param(name):
            return getattr(self.Parameters, prefix + '_' + name)

        return dict(
            begin_time=get_context('begin_table'),
            blacklist_filtration_top=get_param('blacklist_filtration_top'),
            end_time=get_context('end_table'),
            separations=get_param('separation'),
            whitelist_filtration=get_param('whitelist_filtration'),
        )

    @stage(provides='operation_id')
    def get_operation_id(self):
        """Start the YQL query and return its operation id.

        Returns ``debug_operation_id`` as-is when it is set, without running
        anything. Otherwise renders the request template, runs it and posts
        the share link to the task info.
        """
        if self.Parameters.debug_operation_id:
            return self.Parameters.debug_operation_id

        client = self.get_yql_client()
        sampler = SAMPLERS_INV[self.Parameters.sampler]
        yt_prefix = sampler.get_folder_name(self.Parameters.prefix)

        # One entry per graph: base first, then (optionally) test; the order
        # is relied upon when reading result tables in on_execute.
        params = [self.get_params('base')]
        if self.Parameters.diff_report:
            params.append(self.get_params('test'))

        render_request_params = dict(
            author_url=lb.staff_link(self.author, link_type=lb.LinkType.plain),
            prefix=yt_prefix,
            task=self,
            task_url=lb.task_link(self.id, plain=True),
            params=params,
            user_definitions=self.Parameters.user_definitions,
        )

        request = client.query(
            render_request(**render_request_params),
            syntax_version=1,
            title='YQL BuildLinearModelsFlameGraph',
        )
        request.run()
        self.set_info(
            '<a href="{0}" target="_blank">{1}</a>'.format(
                request.share_url,
                'Ссылка ны Ыкль'
            ),
            do_escape=False,
        )
        return request.operation_id

    def get_yql_client(self):
        """Create a YqlClient for the configured cluster using the YQL token."""
        from yql.api.v1.client import YqlClient

        return YqlClient(
            db=self.Parameters.cluster,
            token=self.Parameters.yql_token.value(),
        )

    def set_yt_token(self):
        """Set the YQL client's global config token (despite the 'yt' in the name)."""
        from yql.api.v1.client import config as YqlConfig

        YqlConfig.token = self.Parameters.yql_token.value()

    def get_request(self, operation_id):
        """Fetch the results of a finished YQL operation.

        Raises ``sdk2.WaitTime(300)`` while the operation is still running,
        and a generic Exception listing the YQL errors when it failed.
        """
        from yql.client.operation import YqlOperationResultsRequest
        self.set_yt_token()

        request = YqlOperationResultsRequest(operation_id)
        request.run()
        if request.in_progress:
            raise sdk2.WaitTime(300)

        if not request.is_success:
            raise Exception('\n'.join(map(str, request.errors)))

        return request

    @staticmethod
    def _remove_files(file_names):
        """Best-effort removal of temporary files; already-missing files are ignored."""
        import os

        for file_name in file_names:
            try:
                os.remove(file_name)
            except OSError:
                # Cleanup must never mask the task's real outcome.
                pass

    def _dump_results(self, request):
        """Write every YQL result table to its own temporary file.

        Each row becomes one line: ``bt=<first column>``, a TAB, then
        ``c=<second column>`` (presumably Backtrace and Counter, as selected
        by the request template). Returns the file names in result-table
        order. If writing fails, files created so far are removed before the
        exception propagates.
        """
        file_names = []
        try:
            for table in request.get_results():
                # Text mode: formatted ``str`` rows are written (works on
                # Python 2 and 3). Record the name first so a partially
                # written file is still cleaned up.
                with tempfile.NamedTemporaryFile(mode='w', delete=False) as out:
                    file_names.append(out.name)
                    for row in table.get_iterator():
                        out.write('bt={}\tc={}\n'.format(*row))
        except Exception:
            self._remove_files(file_names)
            raise
        return file_names

    def on_execute(self):
        self.get_time_tables()
        operation_id = self.get_operation_id()
        # May raise sdk2.WaitTime while the YQL operation is still running.
        request = self.get_request(operation_id)

        result_files_names = self._dump_results(request)
        try:
            # Table 0 feeds the base graph; a diff report also needs table 1.
            expected_tables = 2 if self.Parameters.diff_report else 1
            if len(result_files_names) < expected_tables:
                raise RuntimeError(
                    'YQL operation {} returned {} result table(s), expected at least {}'.format(
                        operation_id, len(result_files_names), expected_tables,
                    )
                )

            result_resource = LinearModelsFlameGraphReport(
                task=self,
                description='yabs linear models flame graph',
                path='flame.html',
            )
            result_data = sdk2.ResourceData(result_resource)
            result_file_name = str(result_data.path)

            from yabs.poormansprofiler.flames.lib import make_flame_graph
            make_flame_graph(
                result_files_names[0],
                result_files_names[1] if self.Parameters.diff_report else None,
                result_file_name
            )

            result_data.ready()
        finally:
            # The dumps were created with delete=False, so remove them explicitly.
            self._remove_files(result_files_names)

        self.Parameters.report = result_resource
        self.set_info(
            resource_redirect_link(
                result_resource.id,
                'Link to graph',
            ),
            do_escape=False,
        )
