import logging
import sys
import subprocess
import sandbox.common.types.task as ctt
from sandbox.common import errors
from sandbox import sdk2
from sandbox.projects.yabs.panther.resource_types import PantherBannerFeatureExtractorBinary, PantherLmDumpIndexBuilderBinary, PantherLmDumpMergerBinary


class ImportLinearDumpType(object):
    """Choices for where the linear model dump index comes from.

    The values are used as radio-group keys of the ``lm_dump_import_type``
    parameter, so they must stay stable.
    """

    # Reuse an already built lm dump index table.
    existing = 'existing'
    # Build the index from raw linear model dump tables.
    raw_dump = 'raw_dump'


class ImportBannerFeaturesType(object):
    """Choices for where the banner features table comes from.

    The values are used as radio-group keys of the
    ``banner_features_import_type`` parameter, so they must stay stable.
    """

    # Reuse an already prepared banner features table.
    existing = 'existing'
    # Extract banner features from Caesar dumps.
    caesar = 'caesar'


class ImportLinearDumpStage(object):
    """Names of the execution stages stored in the task context.

    ``PantherImportLinearModelDump`` keeps the current stage in
    ``Context.stage`` and walks through them in the order listed below.
    """

    # Stage assigned on task creation; input table paths get resolved here.
    initial = 'initial'
    # Child import tasks are created and awaited.
    run_child = 'run_child'
    # Child results are checked and the merger binary is run.
    merge = 'merge'


class PantherImportBannerFeatures(sdk2.Task):
    """Task that runs the banner feature extractor binary over Caesar dumps.

    The binary is taken from the ``extractor_binary`` resource and is invoked
    with the Caesar dump paths and the output table path from the parameters.
    YT access settings are passed to the child process through environment
    variables.
    """

    class Parameters(sdk2.Task.Parameters):
        # 5 * 60 * 60 -- presumably seconds (i.e. 5 hours); confirm against sdk2 docs.
        kill_timeout = 5 * 60 * 60

        extractor_binary = sdk2.parameters.LastResource('Banner features extractor binary', resource_type=PantherBannerFeatureExtractorBinary, required=True)

        with sdk2.parameters.Group("YT params", collapse=False) as yt_settings:
            yt_cluster = sdk2.parameters.String("YT cluster", default="hahn", required=True)
            yt_token_vault_name = sdk2.parameters.String("YT Token vault name", default="yt_token", required=True)

        with sdk2.parameters.Group("Banner features import params", collapse=False) as banner_features_params:
            importer_banner_caesar_dump_path = sdk2.parameters.String("Caesar banners dump path", default='//home/bs/logs/AdsCaesarBannersFullDump/latest', required=True)
            importer_adgroup_caesar_dump_path = sdk2.parameters.String("Caesar ad groups dump path", default='//home/bs/logs/AdsCaesarAdGroupsFullDump/latest', required=True)
            importer_order_caesar_dump_path = sdk2.parameters.String("Caesar orders dump path", default='//home/bs/logs/AdsCaesarOrdersFullDump/latest', required=True)
            importer_output_banner_features_table_path = sdk2.parameters.String("Output banner features path", required=True)
            importer_force_create_output = sdk2.parameters.Bool("Force create tables", default=False, required=False)

    def _get_common_cmd(self, resource):
        """Return the command prefix shared by all binaries run by this task.

        :param resource: resource holding the executable to run.
        :return: list with the executable path, followed by ``--force`` when
            the "Force create tables" parameter is set.
        """
        bin_res = sdk2.ResourceData(resource)
        cmd = [str(bin_res.path)]
        if self.Parameters.importer_force_create_output:
            cmd.append('--force')
        return cmd

    def _get_feature_extractor_cmd(self):
        """Return the full command line for the feature extractor binary."""
        cmd = self._get_common_cmd(self.Parameters.extractor_binary)
        cmd.extend([
            '--banner-caesar-dump', self.Parameters.importer_banner_caesar_dump_path,
            '--adgroup-caesar-dump', self.Parameters.importer_adgroup_caesar_dump_path,
            '--order-caesar-dump', self.Parameters.importer_order_caesar_dump_path,
            '--output-yt-table', self.Parameters.importer_output_banner_features_table_path,
        ])
        return cmd

    def on_execute(self):
        """Run the extractor and fail the task if it does not exit cleanly.

        :raises errors.TaskError: when the extractor exits with a non-zero
            return code (including negative codes, which ``subprocess``
            reports when the child is terminated by a signal).
        """
        logging.info('Started')
        # The child receives only these variables as its environment.
        env = {
            'YT_PROXY': self.Parameters.yt_cluster,
            # Fall back to the default vault item name when the parameter is empty.
            'YT_TOKEN': sdk2.task.Vault.data(self.author, self.Parameters.yt_token_vault_name or 'yt_token'),
        }
        logging.info('Prepare extractor process')
        extractor_cmd = self._get_feature_extractor_cmd()
        retcode = subprocess.call(extractor_cmd, stdout=sys.stdout, stderr=sys.stderr, env=env)
        # Any non-zero code is a failure; a negative one means death by signal.
        if retcode != 0:
            raise errors.TaskError('Extractor process finished with exit code: {}'.format(retcode))


class PantherImportLmDumpIndex(sdk2.Task):
    """Task that runs the linear model dump index builder binary.

    The binary is taken from the ``dump_index_builder_binary`` resource and is
    invoked with the input dump path, the output table path and the
    per-feature limit from the parameters. YT access settings are passed to
    the child process through environment variables.
    """

    class Parameters(sdk2.Task.Parameters):
        # 3 * 60 * 60 -- presumably seconds (i.e. 3 hours); confirm against sdk2 docs.
        kill_timeout = 3 * 60 * 60

        dump_index_builder_binary = sdk2.parameters.LastResource('Lm dump index builder binary', resource_type=PantherLmDumpIndexBuilderBinary, required=True)

        with sdk2.parameters.Group("YT params", collapse=False) as yt_settings:
            yt_cluster = sdk2.parameters.String("YT cluster", default="hahn", required=True)
            yt_token_vault_name = sdk2.parameters.String("YT Token vault name", default="yt_token", required=True)

        with sdk2.parameters.Group("Lm dump import params", collapse=False) as lm_dump_params:
            importer_input_linear_dump_path = sdk2.parameters.String("Input linear model dump path", required=True)
            importer_output_index_dump_table_path = sdk2.parameters.String("Output index dump path", required=True)
            importer_force_create_output = sdk2.parameters.Bool("Force create tables", default=False, required=False)

        with sdk2.parameters.Group("Index params", collapse=False) as index_settings:
            index_per_feature_guts_limit = sdk2.parameters.Integer("Per feature guts limit", default=1000, required=True)

    def _get_common_cmd(self, resource):
        """Return the command prefix shared by all binaries run by this task.

        :param resource: resource holding the executable to run.
        :return: list with the executable path, followed by ``--force`` when
            the "Force create tables" parameter is set.
        """
        bin_res = sdk2.ResourceData(resource)
        cmd = [str(bin_res.path)]
        if self.Parameters.importer_force_create_output:
            cmd.append('--force')
        return cmd

    def _get_dump_index_builder_cmd(self):
        """Return the full command line for the dump index builder binary."""
        cmd = self._get_common_cmd(self.Parameters.dump_index_builder_binary)
        cmd.extend([
            '--input-yt-dump-path', self.Parameters.importer_input_linear_dump_path,
            '--output-yt-table', self.Parameters.importer_output_index_dump_table_path,
            # Command-line arguments must be strings, hence the explicit conversion.
            '--keep-top-n', str(self.Parameters.index_per_feature_guts_limit),
        ])
        return cmd

    def on_execute(self):
        """Run the index builder and fail the task if it does not exit cleanly.

        :raises errors.TaskError: when the builder exits with a non-zero
            return code (including negative codes, which ``subprocess``
            reports when the child is terminated by a signal).
        """
        logging.info('Started')
        # The child receives only these variables as its environment.
        env = {
            'YT_PROXY': self.Parameters.yt_cluster,
            # Fall back to the default vault item name when the parameter is empty.
            'YT_TOKEN': sdk2.task.Vault.data(self.author, self.Parameters.yt_token_vault_name or 'yt_token'),
        }
        logging.info('Prepare builder process')
        builder_cmd = self._get_dump_index_builder_cmd()
        retcode = subprocess.call(builder_cmd, stdout=sys.stdout, stderr=sys.stderr, env=env)
        # Any non-zero code is a failure; a negative one means death by signal.
        if retcode != 0:
            raise errors.TaskError('Builder process finished with exit code: {}'.format(retcode))


class PantherImportLinearModelDump(sdk2.Task):
    """Parent task: prepare banner features and lm dump index, then merge them.

    Execution is split into stages stored in ``Context.stage``
    (see ``ImportLinearDumpStage``):

    * ``initial``   -- resolve which tables the merger will read;
    * ``run_child`` -- start ``PantherImportBannerFeatures`` and/or
      ``PantherImportLmDumpIndex`` child tasks when the corresponding import
      type requires it, and wait for them;
    * ``merge``     -- verify the children succeeded and run the merger binary.

    The stage is persisted in the context because waiting for children is done
    by raising ``sdk2.WaitTask``; presumably ``on_execute`` is entered again
    once the wait is over -- confirm against sdk2 docs.
    """

    class Parameters(sdk2.Task.Parameters):
        # 10 * 60 * 60 -- presumably seconds (i.e. 10 hours); confirm against sdk2 docs.
        kill_timeout = 10 * 60 * 60
        push_tasks_resource = True

        extractor_binary = sdk2.parameters.LastResource('Banner features extractor binary', resource_type=PantherBannerFeatureExtractorBinary, required=True)
        dump_index_builder_binary = sdk2.parameters.LastResource('Lm dump index builder binary', resource_type=PantherLmDumpIndexBuilderBinary, required=True)
        merger_binary = sdk2.parameters.LastResource('Banner features with lm dump merger binary', resource_type=PantherLmDumpMergerBinary, required=True)

        with sdk2.parameters.Group("YT params", collapse=False) as yt_settings:
            yt_cluster = sdk2.parameters.String("YT cluster", default="hahn", required=True)
            yt_token_vault_name = sdk2.parameters.String("YT Token vault name", default="yt_token", required=True)

        with sdk2.parameters.Group("Banner features import params", collapse=False) as banner_features_params:
            with sdk2.parameters.RadioGroup("Banner features import type") as banner_features_import_type:
                banner_features_import_type.values[ImportBannerFeaturesType.existing] = banner_features_import_type.Value("Use existing banner features table")
                banner_features_import_type.values[ImportBannerFeaturesType.caesar] = banner_features_import_type.Value("Import banner features from caesar")

            with banner_features_import_type.value[ImportBannerFeaturesType.existing]:
                importer_existing_banner_features_path = sdk2.parameters.String("Existing banner features table in proper format path", required=True)

            with banner_features_import_type.value[ImportBannerFeaturesType.caesar]:
                importer_banner_caesar_dump_path = sdk2.parameters.String("Caesar banners dump path", default='//home/bs/logs/AdsCaesarBannersFullDump/latest', required=True)
                importer_adgroup_caesar_dump_path = sdk2.parameters.String("Caesar ad groups dump path", default='//home/bs/logs/AdsCaesarAdGroupsFullDump/latest', required=True)
                importer_order_caesar_dump_path = sdk2.parameters.String("Caesar orders dump path", default='//home/bs/logs/AdsCaesarOrdersFullDump/latest', required=True)
                importer_output_banner_features_table_path = sdk2.parameters.String("Output banner features path", required=True)

        with sdk2.parameters.Group("Lm dump import params", collapse=False) as lm_dump_params:
            with sdk2.parameters.RadioGroup("Lm dump import type") as lm_dump_import_type:
                lm_dump_import_type.values[ImportLinearDumpType.existing] = lm_dump_import_type.Value("Use existing lm dump index table")
                lm_dump_import_type.values[ImportLinearDumpType.raw_dump] = lm_dump_import_type.Value("Import lm dump from raw tables")

            with lm_dump_import_type.value[ImportLinearDumpType.existing]:
                importer_existing_linear_dump_index_path = sdk2.parameters.String("Existing lm dump index table path", required=True)

            with lm_dump_import_type.value[ImportLinearDumpType.raw_dump]:
                importer_input_linear_dump_path = sdk2.parameters.String("Input linear model dump path", required=True)
                importer_output_index_dump_table_path = sdk2.parameters.String("Output index dump path", required=True)
                index_per_feature_guts_limit = sdk2.parameters.Integer("Per feature guts limit", default=1000, required=True)

        # NOTE(review): this group reuses the alias ``lm_dump_params`` of the
        # group above, rebinding the name. Looks like a copy-paste slip; left
        # unchanged because the effect of renaming on sdk2 is not verifiable here.
        with sdk2.parameters.Group("Merger params", collapse=False) as lm_dump_params:
            importer_output_table_path = sdk2.parameters.String("Output table path", required=True)
            importer_force_create_output = sdk2.parameters.Bool("Force create tables", default=False, required=False)

        with sdk2.parameters.Group("Index params", collapse=False) as index_settings:
            index_guts_per_banner_limit = sdk2.parameters.Integer("Per banner guts limit", default=100, required=True)
            index_banners_per_gut_limit = sdk2.parameters.Integer("Per gut banners limit", default=2000, required=True)
            # NOTE(review): declared as String with an integer default and not
            # read anywhere in the visible code -- confirm whether it is needed.
            index_shards_count = sdk2.parameters.String("Shards count", default=10, required=True)

    def _get_common_cmd(self, resource):
        """Return the command prefix shared by all binaries run by this task.

        :param resource: resource holding the executable to run.
        :return: list with the executable path, followed by ``--force`` when
            the "Force create tables" parameter is set.
        """
        bin_res = sdk2.ResourceData(resource)
        cmd = [str(bin_res.path)]
        if self.Parameters.importer_force_create_output:
            cmd.append('--force')
        return cmd

    def _get_merger_cmd(self):
        """Return the full command line for the merger binary.

        Input table paths are read from the context, where the ``initial``
        stage of ``on_execute`` stores them.
        """
        cmd = self._get_common_cmd(self.Parameters.merger_binary)
        cmd.extend([
            '--input-dump-index-yt-table', self.Context.lm_index_path,
            '--input-banners-features-yt-table', self.Context.banner_features_path,
            '--output-yt-table', self.Parameters.importer_output_table_path,
            # Command-line arguments must be strings, hence the explicit conversions.
            '--top-guts-per-banner', str(self.Parameters.index_guts_per_banner_limit),
            '--top-banners-per-gut', str(self.Parameters.index_banners_per_gut_limit),
        ])
        return cmd

    def _check_subtasks(self):
        """Raise if any awaited child task did not finish successfully.

        :raises errors.TaskError: with one line per child whose status is not
            in ``ctt.Status.Group.SUCCEED``.
        """
        failed_subtasks = [t for t in self.find(id=self.Context.waitable_tasks).limit(0) if t.status not in ctt.Status.Group.SUCCEED]
        if not failed_subtasks:
            return

        raise errors.TaskError(
            '\n'.join(['Subtask {} ({}) was finished with the status of {}'.format(t.type, t.id, t.status) for t in failed_subtasks])
        )

    def _wait_tasks(self, task_ids):
        """Remember the given child task ids and wait for them.

        Does nothing when ``task_ids`` is empty, so execution continues
        straight to the next stage.

        :param task_ids: list of child task ids to wait for.
        :raises sdk2.WaitTask: always, when ``task_ids`` is not empty.
        """
        if not task_ids:
            return
        self.Context.waitable_tasks.extend(task_ids)
        # The trailing True is presumably "wait for all tasks" -- confirm against sdk2 docs.
        raise sdk2.WaitTask(task_ids, list(ctt.Status.Group.FINISH + ctt.Status.Group.BREAK), True)

    def on_create(self):
        """Put a freshly created task into the ``initial`` stage."""
        self.Context.stage = ImportLinearDumpStage.initial

    def on_execute(self):
        """Drive the task through its stages (see the class docstring).

        :raises sdk2.WaitTask: when child tasks were started and must be awaited.
        :raises errors.TaskError: when a child task did not succeed or the
            merger exits with a non-zero return code (including negative
            codes, which ``subprocess`` reports for death by signal).
        """
        logging.info('Started')

        if self.Context.stage == ImportLinearDumpStage.initial:
            self.Context.stage = ImportLinearDumpStage.run_child
            self.Context.waitable_tasks = []
            # The merger reads either the table a child task will produce or
            # the pre-existing table, depending on the chosen import type.
            if self.Parameters.lm_dump_import_type == ImportLinearDumpType.raw_dump:
                self.Context.lm_index_path = self.Parameters.importer_output_index_dump_table_path
            else:
                self.Context.lm_index_path = self.Parameters.importer_existing_linear_dump_index_path
            if self.Parameters.banner_features_import_type == ImportBannerFeaturesType.caesar:
                self.Context.banner_features_path = self.Parameters.importer_output_banner_features_table_path
            else:
                self.Context.banner_features_path = self.Parameters.importer_existing_banner_features_path

        if self.Context.stage == ImportLinearDumpStage.run_child:
            # Advance the stage before waiting so that the next entry into
            # on_execute goes directly to the merge step.
            self.Context.stage = ImportLinearDumpStage.merge
            tasks_to_wait = []
            if self.Parameters.banner_features_import_type == ImportBannerFeaturesType.caesar:
                logging.info('Prepare features import task')
                extractor_task = PantherImportBannerFeatures(
                    self,
                    extractor_binary=self.Parameters.extractor_binary,
                    push_tasks_resource=True,
                    description="Child of {}".format(self.id),
                    owner=self.owner,
                    yt_cluster=self.Parameters.yt_cluster,
                    yt_token_vault_name=self.Parameters.yt_token_vault_name,
                    importer_banner_caesar_dump_path=self.Parameters.importer_banner_caesar_dump_path,
                    importer_adgroup_caesar_dump_path=self.Parameters.importer_adgroup_caesar_dump_path,
                    importer_order_caesar_dump_path=self.Parameters.importer_order_caesar_dump_path,
                    importer_output_banner_features_table_path=self.Parameters.importer_output_banner_features_table_path,
                    importer_force_create_output=self.Parameters.importer_force_create_output,
                    tags=self.Parameters.tags,
                )
                logging.info('Run extractor import {}'.format(self.Requirements.tasks_resource))
                # Run the child with the same tasks resource as the parent.
                extractor_task.Requirements.tasks_resource = self.Requirements.tasks_resource
                extractor_task.enqueue()
                self.Context.extractor_task_id = extractor_task.id
                tasks_to_wait.append(self.Context.extractor_task_id)

            if self.Parameters.lm_dump_import_type == ImportLinearDumpType.raw_dump:
                logging.info('Prepare index builder task')
                lm_index_builder = PantherImportLmDumpIndex(
                    self,
                    dump_index_builder_binary=self.Parameters.dump_index_builder_binary,
                    push_tasks_resource=True,
                    description="Child of {}".format(self.id),
                    owner=self.owner,
                    yt_cluster=self.Parameters.yt_cluster,
                    yt_token_vault_name=self.Parameters.yt_token_vault_name,
                    importer_input_linear_dump_path=self.Parameters.importer_input_linear_dump_path,
                    importer_output_index_dump_table_path=self.Parameters.importer_output_index_dump_table_path,
                    importer_force_create_output=self.Parameters.importer_force_create_output,
                    index_per_feature_guts_limit=self.Parameters.index_per_feature_guts_limit,
                    tags=self.Parameters.tags,
                )
                logging.info('Run lm index import {}'.format(self.Requirements.tasks_resource))
                # Run the child with the same tasks resource as the parent.
                lm_index_builder.Requirements.tasks_resource = self.Requirements.tasks_resource
                lm_index_builder.enqueue()
                self.Context.lm_index_builder_id = lm_index_builder.id
                tasks_to_wait.append(self.Context.lm_index_builder_id)
            # Raises sdk2.WaitTask unless there is nothing to wait for.
            self._wait_tasks(tasks_to_wait)

        if self.Context.stage == ImportLinearDumpStage.merge:
            self._check_subtasks()
            merger_cmd = self._get_merger_cmd()
            # The child receives only these variables as its environment.
            env = {
                'YT_PROXY': self.Parameters.yt_cluster,
                # Fall back to the default vault item name when the parameter is empty.
                'YT_TOKEN': sdk2.task.Vault.data(self.author, self.Parameters.yt_token_vault_name or 'yt_token'),
            }
            # subprocess.call (not check_call) so that a failing merger is
            # reported through the TaskError below instead of CalledProcessError.
            merger_returncode = subprocess.call(merger_cmd, stdout=sys.stdout, stderr=sys.stderr, env=env)
            if merger_returncode != 0:
                raise errors.TaskError('Merger process finished with exit code: {}'.format(merger_returncode))
