import logging
import datetime as dt
import os

import sandbox.common.types.task as ctt
from sandbox.common import errors
from sandbox import sdk2


class PantherBuildBasesCommon(sdk2.Task):
    """Common staged pipeline for building and sharing panther bases.

    ``on_execute`` walks through the stages listed in ``TaskStages``; the
    current position is stored in ``Context.current_stage``. Every stage
    handler advances the stored stage *before* it starts waiting for child
    tasks (``_wait_tasks`` raises ``sdk2.WaitTask``), so a later
    ``on_execute`` call continues from the following stage.

    Subclasses must implement ``_get_build_task_ids`` and extend ``_init``,
    ``_prepare``, ``_update_state`` and ``_build`` to spawn their own children.
    """

    class TaskStages(object):
        """Names of the pipeline stages, in execution order."""
        init = 'init'
        prepare = 'prepare'
        update_state = 'update_state'
        build_index = 'build_index'
        share_bases = 'share_bases'
        done = 'done'

    class Parameters(sdk2.Task.Parameters):
        # 20 hours, assuming the framework interprets the value as seconds.
        kill_timeout = 20 * 60 * 60
        push_tasks_resource = True

        with sdk2.parameters.Group("YT params", collapse=False) as yt_params:
            yt_cluster = sdk2.parameters.String("YT cluster", default="hahn", required=True)
            yt_token_vault_name = sdk2.parameters.String("YT Token vault name", default="yt_token", required=True)

        with sdk2.parameters.Group("Transport params", collapse=True) as transport_params:
            transport_yt_cluster = sdk2.parameters.String("Transport YT cluster", default="locke", required=True)
            transport_yt_path = sdk2.parameters.String("Path to yt transport dir", default="//tmp/transport_panther", required=True)
            transport_force_push = sdk2.parameters.Bool("Force push bases", default=False, required=False)
            transport_enable_fastbone = sdk2.parameters.Bool("Turn on fastbone option of yt sky_share", default=True, required=True)

        with sdk2.parameters.Group("Common builder params", collapse=False) as builders_params:
            common_builder_import_path = sdk2.parameters.String("Path to store builder inputs", required=True)
            common_builder_basegen_path = sdk2.parameters.String("Path to store builder outputs", required=True)
            # Optional, yet _init joins it with the banners table name unconditionally.
            common_builder_state_path = sdk2.parameters.String("Path to basegen state", required=False)
            common_builder_force_create = sdk2.parameters.Bool("Force create tables", default=False, required=False)

        with sdk2.parameters.Output():
            output_bases_info = sdk2.parameters.JSON('Bases info')

    def _check_subtasks(self):
        """Raise ``errors.TaskError`` if any awaited subtask did not succeed.

        The tasks whose ids are stored in ``Context.waitable_tasks`` are looked
        up and every one with a status outside ``ctt.Status.Group.SUCCEED`` is
        reported, one line per subtask.
        """
        failed_subtasks = [
            t for t in self.find(id=self.Context.waitable_tasks).limit(0)
            if t.status not in ctt.Status.Group.SUCCEED
        ]
        if not failed_subtasks:
            return

        raise errors.TaskError('\n'.join(
            'Subtask {} ({}) was finished with the status of {}'.format(t.type, t.id, t.status)
            for t in failed_subtasks
        ))

    def _wait_tasks(self, task_ids):
        """Remember ``task_ids`` for ``_check_subtasks`` and suspend on them.

        Always raises ``sdk2.WaitTask``; nothing after a call to this method
        is executed in the current ``on_execute`` run.
        """
        self.Context.waitable_tasks.extend(task_ids)
        # The third positional argument is presumably ``wait_all`` -- confirm
        # against the sdk2.WaitTask signature.
        raise sdk2.WaitTask(task_ids, list(ctt.Status.Group.FINISH + ctt.Status.Group.BREAK), True)

    def _get_build_task_ids(self):
        """Return ids of the child build tasks whose ``bases_info`` is shared.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def _init(self):
        """Initialise the context: wait list, build timestamps, state table path."""
        logging.info('Stage: {}'.format(self.Context.current_stage))
        self.Context.waitable_tasks = []
        self.Context.current_stage = self.TaskStages.prepare

        def floor_to_minutes(moment, interval):
            # Round ``moment`` down to a multiple of ``interval`` minutes.
            floored_minute = (moment.minute // interval) * interval
            return moment.replace(minute=floored_minute, second=0, microsecond=0)

        rounded_time = floor_to_minutes(dt.datetime.now(), 5)
        # NOTE(review): now() is naive local time and is measured against a
        # naive 1970-01-01, so build_mtime equals the Unix timestamp only when
        # the host clock is in UTC -- confirm this is what consumers expect.
        self.Context.build_mtime = int((rounded_time - dt.datetime(1970, 1, 1)).total_seconds())
        self.Context.build_time_str = rounded_time.strftime('%Y-%m-%dT%H:%M:%S')
        self.Context.banners_state_table_name = "banners"
        self.Context.banners_state_table_path = os.path.join(self.Parameters.common_builder_state_path, self.Context.banners_state_table_name)

    def _prepare(self):
        """Verify previously awaited subtasks and move on to ``update_state``."""
        logging.info('Stage: {}'.format(self.Context.current_stage))
        self._check_subtasks()
        self.Context.current_stage = self.TaskStages.update_state

    def _update_state(self):
        """Verify previously awaited subtasks and move on to ``build_index``."""
        logging.info('Stage: {}'.format(self.Context.current_stage))
        self._check_subtasks()
        self.Context.current_stage = self.TaskStages.build_index

    def _build(self):
        """Verify previously awaited subtasks and move on to ``share_bases``."""
        logging.info('Stage: {}'.format(self.Context.current_stage))
        self._check_subtasks()
        self.Context.current_stage = self.TaskStages.share_bases

    def _share(self):
        """Merge ``bases_info`` of the build tasks and start a share subtask.

        Raises ``errors.TaskError`` when the number of successful build tasks
        that produced ``bases_info`` differs from the number of expected build
        task ids. Finishes by waiting for the spawned ``PantherShareBases``.
        """
        from sandbox.projects.yabs.panther.share_bases import PantherShareBases

        logging.info('Stage: {}'.format(self.Context.current_stage))
        self._check_subtasks()
        # Advance first: _wait_tasks below raises, and the next run must
        # start from the ``done`` stage.
        self.Context.current_stage = self.TaskStages.done

        build_task_ids = self._get_build_task_ids()
        child_tasks_query = self.find(
            id=build_task_ids,
            status=ctt.Status.SUCCESS,
        ).limit(0)
        child_tasks = [t for t in child_tasks_query if t.Parameters.bases_info is not None]
        build_tasks_count = len(build_task_ids)
        if len(child_tasks) != build_tasks_count:
            raise errors.TaskError('Expected {} child tasks with bases_info output, got {}'.format(build_tasks_count, len(child_tasks)))

        # On a key clash the entry of the task iterated later wins.
        merged_bases_info = {}
        for child_task in child_tasks:
            merged_bases_info.update(child_task.Parameters.bases_info)

        share_bases_task = sdk2.Task[str(PantherShareBases)](
            self,
            push_tasks_resource=True,
            description="Child of {}".format(self.id),
            owner=self.owner,
            yt_cluster=self.Parameters.yt_cluster,
            yt_token_vault_name=self.Parameters.yt_token_vault_name,
            transport_enable_fastbone=self.Parameters.transport_enable_fastbone,
            transport_force_push=self.Parameters.transport_force_push,
            transport_yt_cluster=self.Parameters.transport_yt_cluster,
            transport_yt_path=self.Parameters.transport_yt_path,
            share_mtime=self.Context.build_mtime,
            share_bases=merged_bases_info,
            tags=self.Parameters.tags,
        )
        share_bases_task.enqueue()
        self.Context.share_bases_task = share_bases_task.id
        self._wait_tasks([share_bases_task.id])

    def _done(self):
        """Copy the share task's ``output_bases_info`` into this task's output."""
        logging.info('Stage: {}'.format(self.Context.current_stage))
        self._check_subtasks()
        # The share task id is part of Context.waitable_tasks, so
        # _check_subtasks above has already rejected a non-successful one.
        share_task = list(self.find(
            id=self.Context.share_bases_task,
            status=ctt.Status.SUCCESS,
        ).limit(1))[0]
        self.Parameters.output_bases_info = share_task.Parameters.output_bases_info

    def on_execute(self):
        """Run the pipeline starting from the stage stored in the context.

        Handlers are tried in pipeline order and each one is run only when the
        stored stage matches, so a handler that returns without waiting lets
        execution fall through to the next stage within the same run.
        """
        logging.info('Build bases')
        if not self.Context.current_stage:
            self.Context.current_stage = self.TaskStages.init

        logging.info('Current stage: {}'.format(str(self.Context.current_stage)))

        stage_handlers = (
            (self.TaskStages.init, self._init),
            (self.TaskStages.prepare, self._prepare),
            (self.TaskStages.update_state, self._update_state),
            (self.TaskStages.build_index, self._build),
            (self.TaskStages.share_bases, self._share),
            (self.TaskStages.done, self._done),
        )
        for stage, handler in stage_handlers:
            # Re-read the context each time: the previous handler may have
            # advanced the stage.
            if self.Context.current_stage == stage:
                handler()


class IndexBuildTypes(object):
    """String identifiers of the panther index build modes.

    Used as the keys of the ``build_type`` radio group of
    ``PantherBuildIndexBases``.
    """

    # Use an existing sharded panther table as the input.
    existing = 'existing'
    # Build the index from a linear model YT dump.
    lm_dump = 'lm_dump'
    # Build the index from YQL.
    yql = 'yql'


class PantherBuildIndexBases(PantherBuildBasesCommon):
    """Build panther index bases and, optionally, trigger a tsar bases build.

    Stages on top of the common pipeline:

    * ``_prepare`` -- picks the input table: an existing one, or one produced
      by an import subtask (YQL statistics or a linear model dump).
    * ``_update_state`` -- runs ``PantherUpdateBannersState`` on that input.
    * ``_build`` -- optionally starts ``PantherBuildTsarBases``, creates the
      output YT directory and runs ``PantherBuildPantherIndex``.
    """

    class Parameters(PantherBuildBasesCommon.Parameters):
        with sdk2.parameters.Group("Bases params", collapse=True) as bases_params:
            trigger_tsar_build = sdk2.parameters.Bool("Trigger tsar bases build", default=True, required=True)

            with sdk2.parameters.RadioGroup("Build type") as build_type:
                build_type.values[IndexBuildTypes.existing] = build_type.Value("Use existing sharded panther table")
                build_type.values[IndexBuildTypes.lm_dump] = build_type.Value("Build index from linear model yt dump")
                build_type.values[IndexBuildTypes.yql] = build_type.Value("Build index from yql", default=True)

            with build_type.value[IndexBuildTypes.lm_dump]:
                lm_dump_input_path = sdk2.parameters.String("Input linear model dump path", required=True)
                lm_dump_banner_features_table_path = sdk2.parameters.String("Banner features table path", required=True)
                lm_dump_output_index_dump_table_path = sdk2.parameters.String("Intermediate linear model dump index table path", required=True)

            with build_type.value[IndexBuildTypes.existing]:
                existing_input_path = sdk2.parameters.String("Input panther path in proper format", required=True)

            with build_type.value[IndexBuildTypes.yql]:
                with sdk2.parameters.Group("YQL params", collapse=False) as yql_params:
                    # Also read by _update_state regardless of the build type.
                    yql_token_vault_name = sdk2.parameters.String("YQL Token vault name", default="yql_token", required=True)

                with sdk2.parameters.Group("YQL run params", collapse=False) as yql_run_params:
                    yql_time_offset = sdk2.parameters.Integer('Time offset', default=12, required=True)
                    yql_interval_hours = sdk2.parameters.Integer('Statistics period in hours', default=24, required=True)
                    yql_sampling = sdk2.parameters.Bool("Use sampling", default=False, required=True)
                    with yql_sampling.value[True]:
                        yql_sampling_percent = sdk2.parameters.Float("Sampling percent", default=0.01, required=True)

        with sdk2.parameters.Group("Panther builders params", collapse=True) as panther_builders_params:
            panther_builder_samples_count = sdk2.parameters.Integer("Offroad samples count", default=1000, required=True)

        with sdk2.parameters.Group("Index params", collapse=False) as index_settings:
            index_per_feature_guts_limit = sdk2.parameters.Integer("Per feature guts limit", default=1000, required=True)
            index_guts_per_banner_limit = sdk2.parameters.Integer("Per banner guts limit", default=100, required=True)
            index_banners_per_gut_limit = sdk2.parameters.Integer("Per gut banners limit", default=2000, required=True)
            # NOTE(review): declared as String but the default is the integer 10;
            # the sibling limits are Integer -- confirm the intended type.
            index_shards_count = sdk2.parameters.String("Shards count", default=10, required=True)

    def _get_build_task_ids(self):
        """Return the id of the panther index build subtask as a one-item list."""
        return [self.Context.panther_index_build_task]

    def _init(self):
        """Extend the common init with the panther import table and output dir names."""
        super(PantherBuildIndexBases, self)._init()

        self.Context.import_table = 'panther'
        self.Context.basegen_dir = 'panther_{}'.format(self.Context.build_time_str)

    def _prepare(self):
        """Resolve ``Context.input_panther_path`` and start an import if needed.

        * ``existing`` -- the path is taken from ``existing_input_path`` and no
          subtask is started.
        * ``yql`` -- ``PantherImportStatisticalIndex`` is started and awaited.
        * ``lm_dump`` -- ``PantherImportLinearModelDump`` is started and awaited.
        """
        super(PantherBuildIndexBases, self)._prepare()

        build_type = self.Parameters.build_type

        self.Context.input_panther_path = None
        if build_type == IndexBuildTypes.existing:
            self.Context.input_panther_path = self.Parameters.existing_input_path
            return

        # Both import modes write their result into the same table.
        self.Context.input_panther_path = os.path.join(self.Parameters.common_builder_import_path, self.Context.import_table)
        if build_type == IndexBuildTypes.yql:
            from sandbox.projects.yabs.panther.import_statistical_index import PantherImportStatisticalIndex

            import_statistical_index_task = PantherImportStatisticalIndex(
                self,
                kill_timeout=16 * 60 * 60,
                push_tasks_resource=True,
                description="Child of {}".format(self.id),
                owner=self.owner,
                yt_cluster=self.Parameters.yt_cluster,
                mode='last_n',
                yql_token_vault_name=self.Parameters.yql_token_vault_name,
                index_result_table_path=self.Context.input_panther_path,
                index_sampling=self.Parameters.yql_sampling,
                index_sampling_percent=self.Parameters.yql_sampling_percent,
                index_time_offset=self.Parameters.yql_time_offset,
                index_interval_hours=self.Parameters.yql_interval_hours,
                tags=self.Parameters.tags,
            )
            logging.info('Run statistics import {}'.format(self.Requirements.tasks_resource))
            import_statistical_index_task.Requirements.tasks_resource = self.Requirements.tasks_resource
            import_statistical_index_task.enqueue()
            self.Context.import_statistical_index_task = import_statistical_index_task.id
            self._wait_tasks([import_statistical_index_task.id])

        elif build_type == IndexBuildTypes.lm_dump:
            from sandbox.projects.yabs.panther.import_linear_model_dump import PantherImportLinearModelDump

            import_lm_dump_task = PantherImportLinearModelDump(
                self,
                kill_timeout=10 * 60 * 60,
                push_tasks_resource=True,
                description="Child of {}".format(self.id),
                owner=self.owner,
                yt_cluster=self.Parameters.yt_cluster,
                yt_token_vault_name=self.Parameters.yt_token_vault_name,
                importer_force_create_output=self.Parameters.common_builder_force_create,
                banner_features_import_type='existing',
                lm_dump_import_type='raw_dump',
                importer_input_linear_dump_path=self.Parameters.lm_dump_input_path,
                importer_existing_banner_features_path=self.Parameters.lm_dump_banner_features_table_path,
                importer_output_index_dump_table_path=self.Parameters.lm_dump_output_index_dump_table_path,
                importer_output_table_path=self.Context.input_panther_path,
                index_per_feature_guts_limit=self.Parameters.index_per_feature_guts_limit,
                index_guts_per_banner_limit=self.Parameters.index_guts_per_banner_limit,
                index_banners_per_gut_limit=self.Parameters.index_banners_per_gut_limit,
                index_shards_count=self.Parameters.index_shards_count,
                tags=self.Parameters.tags,
            )
            logging.info('Run lm dump import with resource {}'.format(self.Requirements.tasks_resource))
            import_lm_dump_task.Requirements.tasks_resource = self.Requirements.tasks_resource
            import_lm_dump_task.enqueue()
            self.Context.import_lm_dump_task = import_lm_dump_task.id
            self._wait_tasks([import_lm_dump_task.id])

    def _update_state(self):
        """Start ``PantherUpdateBannersState`` on the input table and wait for it."""
        super(PantherBuildIndexBases, self)._update_state()
        from sandbox.projects.yabs.panther.update_banners_state import PantherUpdateBannersState

        update_state_task = PantherUpdateBannersState(
            self,
            kill_timeout=1 * 60 * 60,
            push_tasks_resource=True,
            description="Child of {}".format(self.id),
            owner=self.owner,
            yt_cluster=self.Parameters.yt_cluster,
            yt_token_vault_name=self.Parameters.yt_token_vault_name,
            yql_token_vault_name=self.Parameters.yql_token_vault_name,
            state_import_table_path=self.Context.input_panther_path,
            state_table_path=self.Context.banners_state_table_path,
            tags=self.Parameters.tags,
        )
        logging.info('Run state update {}'.format(self.Requirements.tasks_resource))
        update_state_task.Requirements.tasks_resource = self.Requirements.tasks_resource
        update_state_task.enqueue()
        self.Context.update_state_task = update_state_task.id
        self._wait_tasks([update_state_task.id])

    def _trigger_tsar_bases_build(self):
        """Enqueue a ``PantherBuildTsarBases`` child built from the banners state.

        The child id is stored in ``Context.build_tsar_bases_task``; waiting
        for it is left to the caller.
        """
        # PantherBuildTsarBases and TsarBuildTypes are defined further down in
        # this module; the names are resolved when this method is called.
        build_tsar_bases = sdk2.Task[str(PantherBuildTsarBases)](
            self,
            kill_timeout=20 * 60 * 60,
            push_tasks_resource=True,
            description="Child of {}".format(self.id),
            owner=self.owner,
            yt_cluster=self.Parameters.yt_cluster,
            yt_token_vault_name=self.Parameters.yt_token_vault_name,
            transport_yt_cluster=self.Parameters.transport_yt_cluster,
            transport_yt_path=self.Parameters.transport_yt_path,
            transport_force_push=self.Parameters.transport_force_push,
            transport_enable_fastbone=self.Parameters.transport_enable_fastbone,
            common_builder_import_path=self.Parameters.common_builder_import_path,
            common_builder_basegen_path=self.Parameters.common_builder_basegen_path,
            common_builder_state_path=self.Parameters.common_builder_state_path,
            common_builder_force_create=self.Parameters.common_builder_force_create,
            build_type=TsarBuildTypes.from_state,
            from_state_banner_table=self.Context.banners_state_table_name,
            tags=self.Parameters.tags,
        )
        build_tsar_bases.Requirements.tasks_resource = self.Requirements.tasks_resource
        build_tsar_bases.enqueue()
        self.Context.build_tsar_bases_task = build_tsar_bases.id

    def _build(self):
        """Start the panther index build (and the tsar build, if enabled) and wait.

        The output directory ``<common_builder_basegen_path>/<basegen_dir>`` is
        created on YT when missing before the index build subtask is enqueued.
        """
        super(PantherBuildIndexBases, self)._build()

        if self.Parameters.trigger_tsar_build:
            self._trigger_tsar_bases_build()

        from yt.wrapper import YtClient
        from sandbox.projects.yabs.panther.build_panther_index import PantherBuildPantherIndex

        self.Context.output_bases_path = os.path.join(self.Parameters.common_builder_basegen_path, self.Context.basegen_dir)
        client = YtClient(
            proxy=self.Parameters.yt_cluster,
            token=sdk2.task.Vault.data(self.author, self.Parameters.yt_token_vault_name))

        logging.info('Create output bases path: {}'.format(self.Context.output_bases_path))
        if not client.exists(self.Context.output_bases_path):
            client.create("map_node", path=self.Context.output_bases_path, recursive=True, ignore_existing=True)

        panther_index_build_task = sdk2.Task[str(PantherBuildPantherIndex)](
            self,
            push_tasks_resource=True,
            description="Child of {}".format(self.id),
            owner=self.owner,
            yt_cluster=self.Parameters.yt_cluster,
            yt_token_vault_name=self.Parameters.yt_token_vault_name,
            input_path=self.Context.input_panther_path,
            output_path=self.Context.output_bases_path,
            force_create=self.Parameters.common_builder_force_create,
            samples_count=self.Parameters.panther_builder_samples_count,
            force_mtime=self.Context.build_mtime,
            tags=self.Parameters.tags,
        )
        panther_index_build_task.Requirements.tasks_resource = self.Requirements.tasks_resource
        panther_index_build_task.enqueue()

        self.Context.panther_index_build_task = panther_index_build_task.id

        # The tsar build joins the wait list only when it was triggered above.
        tasks_to_wait = [panther_index_build_task.id]
        if self.Parameters.trigger_tsar_build:
            tasks_to_wait.append(self.Context.build_tsar_bases_task)
        self._wait_tasks(tasks_to_wait)


class TsarBuildTypes(object):
    """String identifiers of the tsar bases build modes.

    Used as the keys of the ``build_type`` radio group of
    ``PantherBuildTsarBases``.
    """

    # Build from the current banners state.
    from_state = 'from_state'
    # Use an existing tsar vectors table as the input.
    existing = 'existing'


class PantherBuildTsarBases(PantherBuildBasesCommon):
    """Build tsar bases from an existing vectors table or the banners state.

    Stages on top of the common pipeline:

    * ``_prepare`` -- picks the vectors table: an existing one, or one
      produced by a ``PantherTsarImporter`` subtask from the banners state.
    * ``_build`` -- creates the output YT directory and runs
      ``PantherBuildTsarIndex``.

    The task acquires the ``PANTHER_TSAR_BASES_SEM`` semaphore, declared with
    capacity 1.
    """

    class Requirements(sdk2.Requirements):
        semaphores = ctt.Semaphores(
            acquires=[
                ctt.Semaphores.Acquire(name='PANTHER_TSAR_BASES_SEM', weight=1, capacity=1)
            ],
            release=(ctt.Status.Group.BREAK, ctt.Status.Group.FINISH),
        )

    class Parameters(PantherBuildBasesCommon.Parameters):
        with sdk2.parameters.Group("Bases params", collapse=True) as raw_index_params:
            with sdk2.parameters.RadioGroup("Build type") as build_type:
                build_type.values[TsarBuildTypes.existing] = build_type.Value("Use existing tsar vectors table")
                build_type.values[TsarBuildTypes.from_state] = build_type.Value("Build index from current banners state", default=True)

            with build_type.value[TsarBuildTypes.existing]:
                existing_input_path = sdk2.parameters.String("Input panther path in proper format", required=True)

            with build_type.value[TsarBuildTypes.from_state]:
                from_state_banner_table = sdk2.parameters.String("Table name with banner state in state yt dir", default="banners", required=True)

        with sdk2.parameters.Group("Tsar builders params", collapse=True) as tsar_builders_params:
            # 8 << 10 == 8192; presumably MiB (8 GiB) -- confirm the unit of Requirements.ram.
            tsar_builder_requirements_ram = sdk2.parameters.Integer("Tsar task ram requirements", default=8 << 10, required=True)
            tsar_builder_vector_size = sdk2.parameters.Integer("Vector size", default=50, required=True)

    def _get_build_task_ids(self):
        """Return the id of the tsar index build subtask as a one-item list."""
        return [self.Context.tsar_index_build_task]

    def _init(self):
        """Extend the common init with the tsar import table and output dir names.

        In ``from_state`` mode the banners state table path is recomputed from
        the ``from_state_banner_table`` parameter.
        """
        super(PantherBuildTsarBases, self)._init()

        self.Context.import_table = 'tsar'
        self.Context.basegen_dir = 'tsar_{}'.format(self.Context.build_time_str)
        if self.Parameters.build_type == TsarBuildTypes.from_state:
            self.Context.banners_state_table_path = os.path.join(self.Parameters.common_builder_state_path, self.Parameters.from_state_banner_table)

    def _prepare(self):
        """Resolve ``Context.input_vectors_path`` and start an import if needed.

        * ``existing`` -- the path is taken from ``existing_input_path`` and no
          subtask is started.
        * ``from_state`` -- ``PantherTsarImporter`` is started and awaited.
        """
        super(PantherBuildTsarBases, self)._prepare()

        build_type = self.Parameters.build_type

        # Reset the field that _build actually consumes.
        self.Context.input_vectors_path = None
        if build_type == TsarBuildTypes.existing:
            self.Context.input_vectors_path = self.Parameters.existing_input_path
            return

        if build_type == TsarBuildTypes.from_state:
            from sandbox.projects.yabs.panther.import_tsar_vectors import PantherTsarImporter

            self.Context.input_vectors_path = os.path.join(self.Parameters.common_builder_import_path, self.Context.import_table)
            tsar_vectors_importer = sdk2.Task[str(PantherTsarImporter)](
                self,
                kill_timeout=10 * 60 * 60,
                push_tasks_resource=True,
                description="Child of {}".format(self.id),
                owner=self.owner,
                yt_cluster=self.Parameters.yt_cluster,
                yt_token_vault_name=self.Parameters.yt_token_vault_name,
                importer_input_banners_path=self.Context.banners_state_table_path,
                importer_output_table_path=self.Context.input_vectors_path,
                importer_force_create_output=self.Parameters.common_builder_force_create,
                tags=self.Parameters.tags,
            )
            tsar_vectors_importer.Requirements.tasks_resource = self.Requirements.tasks_resource
            tsar_vectors_importer.Requirements.ram = self.Parameters.tsar_builder_requirements_ram
            tsar_vectors_importer.enqueue()

            self.Context.importer_task_id = tsar_vectors_importer.id
            self._wait_tasks([tsar_vectors_importer.id])

    def _build(self):
        """Start the tsar index build subtask and wait for it.

        The output directory ``<common_builder_basegen_path>/<basegen_dir>`` is
        created on YT when missing before the subtask is enqueued.
        """
        super(PantherBuildTsarBases, self)._build()

        from yt.wrapper import YtClient
        from sandbox.projects.yabs.panther.build_tsar_index import PantherBuildTsarIndex

        self.Context.output_bases_path = os.path.join(self.Parameters.common_builder_basegen_path, self.Context.basegen_dir)
        client = YtClient(
            proxy=self.Parameters.yt_cluster,
            token=sdk2.task.Vault.data(self.author, self.Parameters.yt_token_vault_name))

        logging.info('Create output bases path: {}'.format(self.Context.output_bases_path))
        if not client.exists(self.Context.output_bases_path):
            client.create("map_node", path=self.Context.output_bases_path, recursive=True, ignore_existing=True)

        tsar_index_build_task = sdk2.Task[str(PantherBuildTsarIndex)](
            self,
            kill_timeout=10 * 60 * 60,
            push_tasks_resource=True,
            description="Child of {}".format(self.id),
            owner=self.owner,
            yt_cluster=self.Parameters.yt_cluster,
            yt_token_vault_name=self.Parameters.yt_token_vault_name,
            input_path=self.Context.input_vectors_path,
            output_path=self.Context.output_bases_path,
            force_create=self.Parameters.common_builder_force_create,
            vector_size=self.Parameters.tsar_builder_vector_size,
            force_mtime=self.Context.build_mtime,
            tags=self.Parameters.tags,
        )
        tsar_index_build_task.Requirements.tasks_resource = self.Requirements.tasks_resource
        tsar_index_build_task.Requirements.ram = self.Parameters.tsar_builder_requirements_ram
        tsar_index_build_task.enqueue()

        self.Context.tsar_index_build_task = tsar_index_build_task.id
        self._wait_tasks([tsar_index_build_task.id])
