import json
import time
from sandbox import sdk2
import logging
import sandbox.common.types.task as ctt
from sandbox.projects.catboost.util.resources import (
    CatBoostBinary,
    CatBoostPythonPackageWheel,
    CatBoostRunPythonPackageTrain,
    CatBoostPerfTestJsonTask,
    CatBoostDataPrepParamsJson)
from sandbox.projects.catboost.run_n_perf_tests_on_two_binaries import CatBoostRunNPerfTestsOnTwoBinaries
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from sandbox.sdk2.service_resources import SandboxTasksBinary
from sandbox.sandboxsdk import svn

_logger = logging.getLogger(__name__)


# Resource type under which RunCatBoostPerfCompareAndPushScore publishes
# 'metrics.json' - the JSON list of per-dataset comparison rows.
class CatBoostPerfCompareMetricsReport(sdk2.Resource):
    pass


# Resource type under which RunCatBoostPerfCompareAndPushScore publishes
# 'metrics_tables.txt' - the human-readable per-dataset comparison tables.
class CatBoostPerfCompareTablesReport(sdk2.Resource):
    pass


# Input parameters of RunCatBoostPerfCompareAndPushScore: which two CatBoost
# builds to compare (by prebuilt resource or by revision number), which perf
# test description to run, and whether/where to push the resulting metrics to
# a YT table.  The declaration order below is kept as is on purpose.
class RunCatBoostPerfCompareAndPushScoreParameters(sdk2.Task.Parameters):
    with sdk2.parameters.Group("CatBoost performance test parameters"):
        # How CatBoost is driven in the perf test: 'cli' (default) or 'python-package'.
        with sdk2.parameters.RadioGroup('API type') as api_type:
            api_type.values['cli'] = api_type.Value('cli', default=True)
            api_type.values['python-package'] = api_type.Value('python-package')
        # Selects the tasks archive in on_save(): either the latest
        # SandboxTasksBinary of the chosen release type, or a custom resource.
        use_last_binary = sdk2.parameters.Bool(
            'use last binary archive',
            default=True, )
        with use_last_binary.value[True]:
            with sdk2.parameters.RadioGroup('Binary release type') as release_type:
                release_type.values['stable'] = release_type.Value('stable', default=True)
                release_type.values['test'] = release_type.Value('test')
        with use_last_binary.value[False]:
            custom_tasks_archive_resource = sdk2.parameters.Resource(
                'task archive resource',
                default=None, )

        # Optional training script; forwarded to the child task only in
        # 'python-package' mode and only when set.
        with api_type.value['python-package']:
            run_python_package_train = sdk2.parameters.Resource(
                'run_python_package_train script (not arcadia build python binary!)',
                resource_type=CatBoostRunPythonPackageTrain,
                default=None, )
        # --- first CatBoost under comparison ---
        # True: use a prebuilt resource (binary or wheel, depending on API type);
        # False: identify the build by revision number.
        use_first_catboost_binary = sdk2.parameters.Bool(
            'use first catboost binary',
            default=False, )
        with use_first_catboost_binary.value[False]:
            first_revision_number = sdk2.parameters.Integer(
                'first revision number',
                default=None, )
        # NOTE(review): `a and b` evaluates to a single object (`b` when `a` is
        # truthy), so this `with` enters only one condition object, not both.
        # If the field is meant to depend on both conditions this probably needs
        # `with a, b:` or nested `with` blocks - confirm against sdk2 semantics.
        # The same pattern is repeated for the three similar blocks below.
        with use_first_catboost_binary.value[True] and api_type.value['cli']:
            first_catboost_binary = sdk2.parameters.Resource(
                'first catboost binary',
                resource_type=CatBoostBinary,
                default=None, )
        with use_first_catboost_binary.value[True] and api_type.value['python-package']:
            first_catboost_python_package_wheel = sdk2.parameters.Resource(
                'first catboost python package wheel',
                resource_type=CatBoostPythonPackageWheel,
                default=None,)
        # Display name of the first CatBoost; passed to the child task and used
        # as a column header in the comparison tables.
        first_catboost_name = sdk2.parameters.String(
            'first catboost name',
            default='catboost_1', )
        # --- second CatBoost under comparison (mirrors the first one) ---
        use_second_catboost_binary = sdk2.parameters.Bool(
            'use second catboost binary',
            default=False, )
        with use_second_catboost_binary.value[False]:
            second_revision_number = sdk2.parameters.Integer(
                'second revision number',
                default=None, )
        with use_second_catboost_binary.value[True] and api_type.value['cli']:
            second_catboost_binary = sdk2.parameters.Resource(
                'second catboost binary',
                resource_type=CatBoostBinary,
                default=None, )
        with use_second_catboost_binary.value[True] and api_type.value['python-package']:
            second_catboost_python_package_wheel = sdk2.parameters.Resource(
                'second catboost python package wheel',
                resource_type=CatBoostPythonPackageWheel,
                default=None,)
        second_catboost_name = sdk2.parameters.String(
            'second catboost name',
            default='catboost_2', )
        # --- publishing of results to YT ---
        push_result_to_yt_table = sdk2.parameters.Bool(
            'push result to yt table',
            default=False, )
        with push_result_to_yt_table.value[True]:
            metrics_table_proxy = sdk2.parameters.String(
                'metrics table proxy',
                default='hahn.yt.yandex.net', )
            metrics_table_path = sdk2.parameters.String(
                'metrics table path',
                required=True, )
            # When set, the table is created before rows are appended.
            create_metrics_table = sdk2.parameters.Bool(
                'create metrics table',
                default=False, )
        # Name of the Vault entry (looked up for the task owner) holding the YT
        # token; it is also forwarded to the child perf test task.
        token_name = sdk2.parameters.String(
            'token name',
            default='yt_token', )
        # JSON description of the perf test; the task reads the 'name' field of
        # each of its entries to get the dataset names.
        perf_test_json_task = sdk2.parameters.Resource(
            'json task',
            resource_type=CatBoostPerfTestJsonTask,
            required=True, )
        # Data preparation params; each one is forwarded to the child task only
        # in 'python-package' mode when the corresponding prebuilt wheel is used.
        with api_type.value['python-package']:
            first_data_prep_params_json = sdk2.parameters.Resource(
                'first data prep params json',
                resource_type=CatBoostDataPrepParamsJson,
                default=None, )
            second_data_prep_params_json = sdk2.parameters.Resource(
                'second data prep params json',
                resource_type=CatBoostDataPrepParamsJson,
                default=None, )
        # Both values below are forwarded unchanged to the child perf test task.
        number_of_runs = sdk2.parameters.Integer(
            'number of runs',
            default=1, )
        required_fraction_of_successful_tasks = sdk2.parameters.Float(
            'required fraction of successful tasks',
            default=0.7, )


class RunCatBoostPerfCompareAndPushScore(sdk2.Task):

    def _run_perf_test(
        self,
        use_first_binary,
        first_revision_number,
        use_second_binary,
        second_revision_number,
        json_task_id,
        number_of_runs
    ):
        """Create and enqueue the CatBoostRunNPerfTestsOnTwoBinaries child task.

        For each side ('first' / 'second') the child receives either the
        prebuilt artefact selected in this task's parameters (when the
        corresponding ``use_*_binary`` flag is truthy) or the given revision
        number.

        :param use_first_binary: use the prebuilt first binary/wheel instead of a revision
        :param first_revision_number: revision of the first build (used when the flag is falsy)
        :param use_second_binary: use the prebuilt second binary/wheel instead of a revision
        :param second_revision_number: revision of the second build (used when the flag is falsy)
        :param json_task_id: value passed to the child as ``perf_test_json_task``
        :param number_of_runs: value passed to the child as ``number_of_runs``
        :return: id of the enqueued child task
        """
        params = self.Parameters
        api_type = params.api_type

        subtask_args = {
            'api_type': api_type,
            'booster': 'catboost',
            'perf_test_json_task': json_task_id,
            'number_of_runs': number_of_runs,
            'required_fraction_of_successful_tasks': params.required_fraction_of_successful_tasks,
            'first_name': params.first_catboost_name,
            'second_name': params.second_catboost_name,
            'token_name': params.token_name,
            'kill_timeout': params.kill_timeout,
        }

        sides = (
            ('first', use_first_binary, first_revision_number),
            ('second', use_second_binary, second_revision_number),
        )
        for side, use_binary, revision_number in sides:
            subtask_args['use_{}_catboost_binary'.format(side)] = bool(use_binary)

            if not use_binary:
                subtask_args['{}_revision_number'.format(side)] = revision_number
                continue

            # Which parameters describe the prebuilt artefact depends on the API type.
            if api_type == 'cli':
                forwarded = ('catboost_binary',)
            elif api_type == 'python-package':
                forwarded = ('catboost_python_package_wheel', 'data_prep_params_json')
            else:
                forwarded = ()
            for suffix in forwarded:
                name = '{}_{}'.format(side, suffix)
                subtask_args[name] = getattr(params, name)

        # The training script is optional and only meaningful for the python package.
        if api_type == 'python-package' and params.run_python_package_train:
            subtask_args['run_python_package_train'] = params.run_python_package_train

        child_task = CatBoostRunNPerfTestsOnTwoBinaries(
            self,
            description="run perf test",
            **subtask_args)
        return child_task.enqueue().id

    def _append_metrics_to_table(self, yt_client, table_path, rows):
        """Write ``rows`` to the YT table at ``table_path`` in append mode, as JSON.

        :param yt_client: client object exposing ``write_table``
        :param table_path: path of the destination table
        :param rows: rows to write
        """
        from yt.wrapper import TablePath

        destination = TablePath(table_path, append=True)
        yt_client.write_table(destination, rows, format='json')

    def _create_metrics_table(self, yt_client, table_path):
        schema = [
            {"name": "date", "type": "double"},
            {"name": "dataset_name", "type": "string"},
            {"name": "first_arcadia_svn_revision", "type": "int64"},
            {"name": "first_binary_sandbox_id", "type": "int64"},
            {"name": "second_arcadia_svn_revision", "type": "int64"},
            {"name": "second_binary_sandbox_id", "type": "int64"},
            {"name": "number_of_runs", "type": "int64"},
            {"name": "number_of_successful_runs", "type": "int64"},
            {"name": "clean_time_coef", "type": "double"},
            {"name": "total_time_coef", "type": "double"}]

        for metric in ('clean_time', 'total_time', 'max_rss'):
            schema.append({"name": "first_" + metric, "type": "double"})
            schema.append({"name": "second_" + metric, "type": "double"})

        yt_client.create('table', table_path, attributes={"schema": schema, "optimize_for": "scan"})

    def _get_dataset_names(self, perf_test_json_task):
        """Return the ``'name'`` of every entry in the perf test JSON description.

        :param perf_test_json_task: resource whose data is the JSON task file
        :return: list of dataset names, in file order
        """
        resource_path = sdk2.ResourceData(perf_test_json_task).path
        with open(str(resource_path), 'r') as task_file:
            evaluation_tasks = json.load(task_file)
        return [evaluation_task['name'] for evaluation_task in evaluation_tasks]

    def _get_date_of_revision(self, revision_number):
        """Return the commit time of an Arcadia trunk revision as a Unix timestamp.

        The ``'date'`` field is parsed with a format ending in a literal ``Z``,
        i.e. it is a UTC time, so it is converted with ``calendar.timegm``.
        ``time.mktime`` would interpret it as local time and shift the result
        by the host's UTC offset.

        :param revision_number: revision to look up
        :return: seconds since the epoch as a float, or 0 when the svn info
            has no ``'date'`` field
        """
        import calendar

        info = svn.Arcadia.info('arcadia:/arc/trunk/arcadia@{}'.format(str(revision_number)))
        if 'date' not in info:
            return 0
        commit_time_utc = time.strptime(info['date'], '%Y-%m-%dT%H:%M:%S.%fZ')
        # Kept as float for compatibility with the previous mktime-based result.
        return float(calendar.timegm(commit_time_utc))

    def _get_metrics_from_dataset_name(self, dataset_name, result_metrics_dict):
        result = {
            'dataset_name': dataset_name,
            'clean_time_coef': result_metrics_dict[dataset_name]['clean_time_coef'],
            'total_time_coef': result_metrics_dict[dataset_name]['total_time_coef'],
            'number_of_runs': self.Parameters.number_of_runs,
            'number_of_successful_runs': result_metrics_dict[dataset_name]['number_of_successful_runs'],
            'date': self.Context.date_of_run,
            'first_arcadia_svn_revision': self.Parameters.first_revision_number,
            'second_arcadia_svn_revision': self.Parameters.second_revision_number}

        for metric in ('clean_time', 'total_time', 'max_rss'):
            result['first_' + metric] = float(result_metrics_dict[dataset_name]['first_' + metric])
            result['second_' + metric] = float(result_metrics_dict[dataset_name]['second_' + metric])

        if self.Parameters.api_type == 'cli':
            result['first_binary_sandbox_id'] = self.Context.first_catboost_binary_id
            result['second_binary_sandbox_id'] = self.Context.second_catboost_binary_id
        elif self.Parameters.api_type == 'python-package':
            result['first_python_package_wheel_sandbox_id'] = self.Context.first_catboost_python_package_wheel_id
            result['second_python_package_wheel_sandbox_id'] = self.Context.second_catboost_python_package_wheel_id

        return result

    def _get_pretty_table_from_dataset_name(self, dataset_name, result_metrics_dict):
        """Render a text table comparing both CatBoost builds on ``dataset_name``.

        :param dataset_name: key of the dataset in ``result_metrics_dict``
        :param result_metrics_dict: mapping of dataset name to its measured metrics
        :return: the dataset name, a newline, and the rendered table
        """
        from prettytable import PrettyTable

        header = ['', self.Parameters.first_catboost_name, self.Parameters.second_catboost_name]
        table = PrettyTable(header)

        # One row per metric: its name followed by the value for each side.
        for metric in ('total_time', 'clean_time', 'max_rss'):
            dataset_metrics = result_metrics_dict[dataset_name]
            first_value = dataset_metrics['first_' + metric]
            second_value = dataset_metrics['second_' + metric]
            table.add_row([metric, first_value, second_value])

        rendered = table.get_string()
        return '{}\n{}'.format(dataset_name, rendered)

    def _push_new_metrics_to_yt(self, yt_metrics_table_list):
        """Append the metric rows to the YT table configured in the parameters.

        The YT token is read from the Vault entry ``token_name`` of the task
        owner.  The table is created first when ``create_metrics_table`` is
        set; the append itself goes through ``retry.retry_call``.

        :param yt_metrics_table_list: rows to append
        """
        import library.python.retry as retry
        import yt.logger as yt_logger
        import yt.wrapper as yt

        yt_logger.LOGGER.setLevel(level='DEBUG')

        params = self.Parameters
        client = yt.YtClient(config={
            'proxy': {'url': params.metrics_table_proxy},
            'token': sdk2.Vault.data(self.owner, params.token_name),
        })

        if params.create_metrics_table:
            self._create_metrics_table(client, params.metrics_table_path)

        retry_conf = retry.RetryConf().waiting(delay=5., backoff=2., jitter=1., limit=120.).upto(minutes=5.)
        append_args = (client, params.metrics_table_path, yt_metrics_table_list)
        retry.retry_call(self._append_metrics_to_table, append_args, conf=retry_conf)

    def _agregate_metrics(self):
        """Turn the perf test report into published resources and, optionally, YT rows.

        Steps:
          1. locate the JSON report and read the dataset names from the task
             description;
          2. remember the ids of the compared artefacts in the task context;
          3. build one metrics row and one text table per dataset;
          4. publish 'metrics.json' and 'metrics_tables.txt' as resources;
          5. push the rows to YT when ``push_result_to_yt_table`` is set.

        Failures while reading or converting the report are logged and do not
        abort the task; whatever was collected before the failure is published.

        :raises SandboxTaskFailureError: when no report resource can be found
        """
        # NOTE(review): this lookup is not restricted to the child perf test
        # task, so it presumably may pick up a report produced by another run -
        # confirm whether it should be filtered by task.
        report_resource = sdk2.Resource["CAT_BOOST_NPERF_TESTS_REPORT"].find().first()
        if report_resource is None:
            raise SandboxTaskFailureError('CAT_BOOST_NPERF_TESTS_REPORT resource was not found')
        json_report_path = str(sdk2.ResourceData(report_resource).path)

        dataset_names = self._get_dataset_names(self.Parameters.perf_test_json_task)

        perf_test_task = sdk2.Task[self.Context.run_perf_test_task_id]
        self._store_compared_artifact_ids(perf_test_task)

        metrics_table_list = []
        pretty_tables = []
        try:
            with open(json_report_path, 'r') as result_metrics_file:
                result_metrics_dict = json.load(result_metrics_file)
            for dataset_name in dataset_names:
                metrics_table_list.append(self._get_metrics_from_dataset_name(dataset_name, result_metrics_dict))
                pretty_tables.append(self._get_pretty_table_from_dataset_name(dataset_name, result_metrics_dict))
        except Exception as e:
            # Deliberately best-effort: publish what has been collected so far.
            _logger.exception(e)

        # Every table is preceded by a newline, matching the previous output format.
        pretty_table_str = ''.join('\n' + table for table in pretty_tables)

        with open('metrics.json', 'w') as metrics_table_result:
            json.dump(metrics_table_list, metrics_table_result)

        metrics_resource = sdk2.ResourceData(CatBoostPerfCompareMetricsReport(
            self,
            "catbost performance test metrics resource", 'metrics.json'))
        metrics_resource.ready()

        with open('metrics_tables.txt', 'w') as metrics_table_result:
            metrics_table_result.write(pretty_table_str)

        metrics_tables_resource = sdk2.ResourceData(CatBoostPerfCompareTablesReport(
            self,
            "catbost performance compare tables resource", 'metrics_tables.txt'))
        metrics_tables_resource.ready()

        if self.Parameters.push_result_to_yt_table:
            self._push_new_metrics_to_yt(metrics_table_list)

    def _store_compared_artifact_ids(self, perf_test_task):
        """Save the ids of both compared artefacts into the task context.

        When a prebuilt artefact is used, its id comes from this task's own
        parameter: the wheel parameter in 'python-package' mode, the binary
        parameter otherwise.  Previously the binary parameter was always used,
        which is unset (None) in 'python-package' mode and failed on ``.id``.
        When a revision is used, the id is copied from the child task context.

        :param perf_test_task: the finished child perf test task
        """
        is_python_package = self.Parameters.api_type == 'python-package'
        for side in ('first', 'second'):
            binary_id_key = '{}_catboost_binary_id'.format(side)
            uses_prebuilt = getattr(self.Parameters, 'use_{}_catboost_binary'.format(side))

            if not uses_prebuilt:
                # NOTE(review): in 'python-package' mode the metrics row reads
                # '<side>_catboost_python_package_wheel_id'; the child context
                # key holding it is not visible here, so only the binary id is
                # copied, as before - confirm against the child task.
                setattr(self.Context, binary_id_key, getattr(perf_test_task.Context, binary_id_key))
            elif is_python_package:
                wheel = getattr(self.Parameters, '{}_catboost_python_package_wheel'.format(side))
                setattr(self.Context, '{}_catboost_python_package_wheel_id'.format(side), wheel.id)
            else:
                binary = getattr(self.Parameters, '{}_catboost_binary'.format(side))
                setattr(self.Context, binary_id_key, binary.id)

    # Resource requirements declared for this task; it only orchestrates a
    # child task and post-processes its report.
    class Requirements(sdk2.Requirements):
        cores = 1
        disk_space = 1024  # 1 GB
        ram = 512  # 512 MB

        # Empty caches declaration.
        # NOTE(review): presumably states that the task needs no caches -
        # confirm against the sdk2 documentation.
        class Caches(sdk2.Requirements.Caches):
            pass

    # The task's parameters are exactly those declared in
    # RunCatBoostPerfCompareAndPushScoreParameters; nothing is added here.
    class Parameters(RunCatBoostPerfCompareAndPushScoreParameters):
        pass

    def _check_for_success(self, task_id):
        """Fail this task unless the task ``task_id`` has status 'SUCCESS'.

        :param task_id: id of the task whose status is read from the server
        :raises SandboxTaskFailureError: when the status is anything but 'SUCCESS'
        """
        status = self.server.task[task_id].read()["status"]
        if status != 'SUCCESS':
            # The message previously had an unbalanced ')' and did not say
            # which status the task actually ended in.
            raise SandboxTaskFailureError('error in task {} (status: {})'.format(task_id, status))

    def on_save(self):
        """Choose the tasks archive resource this task will run with.

        With ``use_last_binary`` set, the first SandboxTasksBinary found for
        target 'catboost/bin' and the selected release type ('stable' when the
        release type is empty) is used; otherwise the custom archive resource
        from the parameters is used.
        """
        if not self.Parameters.use_last_binary:
            self.Requirements.tasks_resource = self.Parameters.custom_tasks_archive_resource
            return

        release = self.Parameters.release_type or 'stable'
        archive_query = SandboxTasksBinary.find(attrs={'target': 'catboost/bin', 'release': release})
        # NOTE(review): fails with AttributeError when no archive matches - confirm this is intended.
        self.Requirements.tasks_resource = archive_query.first().id

    def on_execute(self):
        """Run the perf test child task, then aggregate and publish its metrics.

        The run timestamp is stored in the context on every execution.  Inside
        the ``run_perf_test`` stage the child task is started and this task
        waits for it; afterwards the child is checked for success and its
        results are aggregated.
        """
        self.Context.date_of_run = time.time()
        params = self.Parameters

        # NOTE(review): presumably sdk2 runs a memoize_stage body only once, so
        # after the wait execution continues below this block - confirm.
        with self.memoize_stage.run_perf_test:
            child_task_id = self._run_perf_test(
                use_first_binary=params.use_first_catboost_binary,
                first_revision_number=params.first_revision_number,
                use_second_binary=params.use_second_catboost_binary,
                second_revision_number=params.second_revision_number,
                json_task_id=params.perf_test_json_task.id,
                number_of_runs=params.number_of_runs)
            self.Context.run_perf_test_task_id = child_task_id

            wait_statuses = ctt.Status.Group.SUCCEED + ctt.Status.Group.FINISH + ctt.Status.Group.BREAK
            raise sdk2.WaitTask(child_task_id, wait_statuses)

        finished_task_id = self.Context.run_perf_test_task_id
        self._check_for_success(finished_task_id)

        self._agregate_metrics()
