import sandbox.sdk2 as sdk2
from sandbox.common.types.client import Tag
import logging
import sys
from sandbox.sandboxsdk import svn


class PipelineState:  # Ersatz enum: plain int constants, not enum.Enum members
    """Which model of the current pipeline iteration is being learned."""
    LEARN_LOLITA_MX = 0
    LEARN_TRAFARET_MX = 1


class IterationProcess:
    """Individual steps of the intrinsic positional CTR pipeline.

    Every step is a classmethod. Its heavy dependencies are imported inside the
    method, so they are only resolved when that step actually runs.
    """

    @classmethod
    def apply_step(cls, task_name, src_table, dst_table):
        """Apply the model of learn task ``task_name`` to one table via ``apply_mx``.

        :param task_name: learn task whose model is applied (``date='latest'``).
        :param src_table: input table.
        :param dst_table: output table; predictions go to the ``baseline``
            field, computed with ``apply_sigmoid=False``.
        """
        logging.info('apply_step')
        from simulate_auction_py.lib.intrinsic_ctr.utils import apply_mx
        apply_mx(task_name,
                 date='latest',
                 src_tables=[src_table],
                 dst_tables=[dst_table],
                 target_field='baseline',
                 apply_sigmoid=False)

    @classmethod
    def start_learn_step(cls, task_path, learn_log):
        """Read the learn-task description (work in progress).

        NOTE(review): ``task_path`` and ``learn_log`` are not used yet. The task
        file name is hard-coded, and the parsed task is only logged; nothing
        is launched.
        """
        logging.info('learn_step')
        from simulate_auction_py.lib.intrinsic_ctr.utils import NirvanaLearnTask
        learn_tasks_dir = svn.Arcadia.get_arcadia_src_dir(
            "arcadia:/arc/trunk/arcadia/yabs/utils/learn-tasks2/leshchev/BSDEV-67490/")
        # Without this check every call would append the same directory again.
        if learn_tasks_dir not in sys.path:
            sys.path.append(learn_tasks_dir)
        task, = NirvanaLearnTask.read_tasks(['search_premium_ctr_trafaret_factors.yml'])
        logging.info(str(task))

    @classmethod
    def wait_learn_step(cls, workflow_id, token):
        """Poll a Nirvana workflow instance once.

        :param workflow_id: id of the Nirvana workflow instance to check.
        :param token: token passed to ``NirvanaApi``.
        :return: ``(finished, workflow_id)``.
            ``finished`` is True only when the instance completed with a result
            other than ``ExecutionResult.failure``.
            A failed instance is cloned and the clone is started. The clone's id
            is returned, so the caller must poll that id from now on.
            While the instance is running, or in any other status,
            ``(False, workflow_id)`` is returned.
        """
        logging.info('wait_learn_step')
        from nirvana_api import NirvanaApi, ExecutionResult, ExecutionStatus
        nirvana = NirvanaApi(token)

        meta = nirvana.get_workflow_meta_data(workflow_instance_id=workflow_id)
        flow_url = 'https://nirvana.yandex-team.ru/flow/%s' % (meta['guid'],)
        instance_url = 'https://nirvana.yandex-team.ru/flow/%s/%s/graph' % (meta['guid'], meta['instanceId'])
        name = meta['name']

        execution_state = nirvana.get_execution_state(workflow_instance_id=workflow_id)
        status = execution_state['status']

        if status == ExecutionStatus.running:
            # Labels are listed in the same order as the arguments that fill them.
            logging.info('\nRunning flow instance\n'
                         '  name: %s\n'
                         '  instance url: %s\n'
                         '  flow url: %s\n'
                         'progress=%.2f%%', name, instance_url, flow_url, execution_state['progress'] * 100.)
            return False, workflow_id

        if status == ExecutionStatus.completed:
            result = execution_state['result']
            logging.info('\nFlow is completed\n'
                         '  name: %s\n'
                         '  instance url: %s\n'
                         '  flow url: %s\n'
                         'result: %s', name, instance_url, flow_url, result)

            if result == ExecutionResult.failure:
                # Retry: clone the failed instance and start the clone.
                new_workflow_instance_id = nirvana.clone_workflow_instance(workflow_instance_id=workflow_id)
                logging.info('\nStarting new flow instance')
                nirvana.start_workflow(workflow_instance_id=new_workflow_instance_id)
                return False, new_workflow_instance_id

            return True, workflow_id

        # Any other status: keep waiting instead of implicitly returning None,
        # which callers that unpack the result cannot handle.
        # NOTE(review): confirm which ExecutionStatus values can occur here.
        logging.warning('Flow instance %s has status %s, continue waiting', workflow_id, status)
        return False, workflow_id


class IntrinsicPositionalCtrPipeline(sdk2.Task):
    """Sandbox task driving the intrinsic positional CTR learning pipeline.

    Each pipeline iteration first learns a model on Lolita factors, then one on
    Trafaret factors (tracked by ``Context.task_state``). Progress is kept in
    ``Context`` between executions; ``sdk2.WaitTime`` is used to come back
    later while a learn step is still running.

    NOTE(review): this is still a dummy version. The learn and apply steps in
    ``on_execute`` are only logged, not executed.
    """

    class Requirements(sdk2.Requirements):
        client_tags = Tag.GENERIC

    class Parameters(sdk2.Parameters):
        description = "Intrinsic positional CTR learning pipeline: https://st.yandex-team.ru/BSDEV-67490"
        kill_timeout = 30 * 60  # seconds; TODO: confirm that 30 minutes is enough

        working_dir = sdk2.parameters.String('Path to working YT dir')
        learn_log = sdk2.parameters.String('Name of YT log for learning with both lolita and trafaret factors')
        pipeline_iterations = sdk2.parameters.Integer('Number of pipeline iterations')
        task_suffix = sdk2.parameters.String('Learning task specific suffix', default='__sandbox_pipeline')

        lolita_factors_task = sdk2.parameters.String('Path to base .yml file with ml-engine task (Lolita factors)')
        trafaret_factors_task = sdk2.parameters.String('Path to base .yml file with ml-engine task (Trafaret factors)')

    class Context(sdk2.Task.Context):
        # 0 until the first execution; afterwards the 1-based number of the current iteration.
        pipeline_iteration = 0
        # Which model of the current iteration is being learned (a PipelineState value).
        task_state = PipelineState.LEARN_LOLITA_MX

        # Per-iteration learning info, intended shape:
        # {1: {'lolita_factors_mx': {'task_name': 'some_task_name'}, 'trafaret_factors_mx': {...}}, 2: {...}, ...}
        # NOTE(review): not read or written by on_execute yet.
        learn_context = dict()

    def on_execute(self):
        # type: () -> None
        """Advance the pipeline state machine by one step.

        NOTE(review) known problems, left unfixed because the fix needs
        information this file does not provide yet (token source, where the
        Nirvana workflow id is stored):
          * ``IterationProcess.wait_learn_step`` takes ``workflow_id`` and
            ``token``, but is called with no arguments, so the ``else`` branch
            raises ``TypeError``.
          * ``wait_learn_step`` returns a ``(finished, workflow_id)`` tuple. A
            non-empty tuple is always truthy, so the ``if`` on it would treat
            every poll as finished. The tuple should be unpacked and the
            (possibly new) workflow id persisted in ``Context``.
          * Only the "not finished" path raises ``sdk2.WaitTime``. Every other
            path returns from ``on_execute``, presumably ending the task, so
            after the first execution (iteration 0 -> 1) the wait branch is
            never reached.
        """
        logging.info('on_execute (Dummy bin version)')

        if self.Context.pipeline_iteration == 0:
            # First execution: start learning the Lolita model of iteration 1 (dummy: only logged).
            logging.info('start_learn_step')
            self.Context.pipeline_iteration = 1

        # NOTE(review): with '>=' the pipeline stops as soon as the last iteration
        # has started, before its models are waited on and applied. Only
        # pipeline_iterations - 1 iterations are completed; confirm whether '>' was intended.
        elif self.Context.pipeline_iteration >= self.Parameters.pipeline_iterations:
            logging.info('C\'est tout')
            return

        else:
            logging.info('wait_learn_step')
            # NOTE(review): missing required arguments; the return value is an always-truthy tuple.
            if IterationProcess.wait_learn_step():
                logging.info('apply_step')
                logging.info('start_learn_step')

                if self.Context.task_state == PipelineState.LEARN_TRAFARET_MX:
                    # Trafaret model done -> this iteration is finished; next one starts with Lolita.
                    self.Context.task_state = PipelineState.LEARN_LOLITA_MX
                    self.Context.pipeline_iteration += 1
                else:
                    # Lolita model done -> learn the Trafaret model of the same iteration.
                    self.Context.task_state = PipelineState.LEARN_TRAFARET_MX

            else:
                # Learning is still in progress: suspend the task and poll again later.
                wait_time = 10   # presumably seconds; TODO: make it a task parameter
                raise sdk2.WaitTime(wait_time)
