import logging
from sandbox import sdk2

from sandbox.common.utils import server_url
from datetime import datetime, timedelta
from operator import itemgetter
from sandbox.projects.modadvert import resource_types
from sandbox.projects.modadvert.common import modadvert
from sandbox.projects.common.apihelpers import get_last_resource_with_attrs
from sandbox.sandboxsdk.environments import PipEnvironment
from sandbox.projects.modadvert.common.constants import STARTREK_CLIENT_ENVIROMENTS


class ModadvertRunManualAutobudget(modadvert.ModadvertBaseYtTask):
    """
    MODADVERT-229: Autobudget for manual moderation

    Optionally runs the ``autobudget_data_gatherer`` binary and then runs the
    ``autobudget_solver`` binary once per configured weekly limit, once per
    required success rate and once for explicitly supplied factor values.
    """

    class Requirements(sdk2.Task.Requirements):
        # Startrek client environments plus the YT client packages; ``yt.wrapper``
        # is imported lazily inside get_autobudget_models().
        environments = STARTREK_CLIENT_ENVIROMENTS + (
            PipEnvironment('yandex-yt', '0.8.38a1', use_wheel=True),
            PipEnvironment('yandex-yt-yson-bindings-skynet', use_wheel=True),
        )

    class Context(sdk2.Task.Context):
        # Resource ids of the models to try; filled once in on_before_execute().
        enriched_dssm_models = []
        enriched_bsml_models = []
        # Table the solver reads from; set in on_execute_inner() and replaced with
        # a date-suffixed name by create_data_gatherer_command().
        output_table = ''

    class Parameters(modadvert.ModadvertBaseYtTask.Parameters):
        with sdk2.parameters.Group('Binary') as binary_group:  # TODO: Create custom resource types
            gatherer_binary = sdk2.parameters.Resource(
                'Resource with data gatherer', resource_type=resource_types.YA_PACKAGE
            )
            solver_binary = sdk2.parameters.Resource('Resource with solver', resource_type=resource_types.YA_PACKAGE)

        with sdk2.parameters.Group('ML models') as ml_models_group:
            # SDK2 does not support lists of resources
            try_vip = sdk2.parameters.Bool('Try vip factors for models during solving', default=True)
            try_latest_dssm_model = sdk2.parameters.Bool('Try latest DSSM model', default=True)
            try_production_dssm_model = sdk2.parameters.Bool('Try current DSSM model', default=True)
            dssm_models = sdk2.parameters.Resource(
                'DSSM models to try',
                resource_type=resource_types.MODADVERT_DSSM_MODEL,
                multiple=True,
                default=[]
            )

            try_latest_bsml_model = sdk2.parameters.Bool('Try latest BSML model', default=True)
            try_production_bsml_model = sdk2.parameters.Bool('Try current BSML model', default=True)
            bsml_models = sdk2.parameters.Resource(
                'BSML models to try',
                resource_type=resource_types.MODADVERT_BSML_MODEL,
                multiple=True,
                default=[]
            )

        yt_secondary_proxy_url = modadvert.YtCluster('YT Secondary proxy URL', default='hahn')
        run_gatherer = sdk2.parameters.Bool('Collect data first', default=True, required=True)
        # The parameters below are declared only under the run_gatherer=True branch.
        with run_gatherer.value[True]:
            prod_table = sdk2.parameters.String('prod-banners table', default='//home/direct-moderate/prod-banners')
            lyncher_results_table = sdk2.parameters.String('Lyncher results table', default='//home/modadvert/lyncher/Results')
            factors = sdk2.parameters.List('Lyncher factors')
            offset_days = sdk2.parameters.Integer('Offset in days', default=2)
            period_days = sdk2.parameters.Integer('Period in days', default=7)
            use_synchrophazotron = sdk2.parameters.Bool('Use synchrophazotron', default=True)

        output_table = sdk2.parameters.String('Output table', default='//home/modadvert/test/korneev/manual_autobudget_data')
        optimized_factors = sdk2.parameters.List('Optimized factors')
        autobudget_history_table = sdk2.parameters.String('Autobudget history table', default='//home/modadvert/test/korneev/manual_autobudget_history')
        autobudget_results_table = sdk2.parameters.String(
            'Autobudget results table (same as history) but does rewrite instead of append',
            default='//home/modadvert/test/korneev/manual_autobudget_results'
        )
        autobudget_solution_node = sdk2.parameters.String('Autobudget solution node', default='//home/modadvert/lyncher/rules')
        delta_step = sdk2.parameters.Integer('Autobudget delta step', default=100)
        limit = sdk2.parameters.List('Weekly limit')
        success_rate = sdk2.parameters.List('Required success rate')
        factor_values = sdk2.parameters.List('Factor values')

    def validate(self):
        """
        Run the base validation and check the explicit factor values.

        :raises ValueError: if factor values are given and their count differs
            from the count of optimized factors.
        """
        super(ModadvertRunManualAutobudget, self).validate()

        factor_values = self.Parameters.factor_values
        # An unset list parameter is treated as empty, as elsewhere in this class.
        optimized_factors = self.Parameters.optimized_factors or []
        if factor_values and len(optimized_factors) != len(factor_values):
            raise ValueError('Lenght of factor values must be equal length of optimized factors')

    @staticmethod
    def parse_model(row, model_prefix):
        """
        Collect the numeric suffixes of all keys in ``row`` that start with ``model_prefix``.

        :param row: mapping (or any iterable of string keys) to inspect.
        :param model_prefix: key prefix to match, e.g. ``'bsml'`` or ``'dssm'``.
        :return: set of ints parsed from the text after the last ``'_'`` of each matching key.
        :raises ValueError: if a matching key does not end with an integer.
        """
        return {
            int(key.rsplit('_', 1)[-1])
            for key in row
            if key.startswith(model_prefix)
        }

    def get_autobudget_models(self, model_prefix):
        """
        Read the autobudget results table and return the model ids of its best usable row.

        Rows are ordered by descending ``accuracy``; the first row that references
        exactly one BSML model and exactly one DSSM model is used.

        :param model_prefix: ``'bsml'`` or ``'dssm'`` - which model ids to return.
        :return: set of ints as produced by parse_model() for the chosen row.
        :raises ValueError: if no row references exactly one model of each kind.
        """
        # Imported here because the package comes from the task's PipEnvironment.
        import yt.wrapper
        yt_client = yt.wrapper.client.YtClient(config={
            'proxy': {'url': self.Parameters.yt_proxy_url},
            'token': self.get_yt_token(),
        })

        rows_by_accuracy = sorted(
            yt_client.read_table(self.Parameters.autobudget_results_table),
            key=itemgetter('accuracy'),
            reverse=True
        )

        for candidate in rows_by_accuracy:
            has_single_bsml = len(self.parse_model(candidate, 'bsml')) == 1
            has_single_dssm = len(self.parse_model(candidate, 'dssm')) == 1
            if has_single_bsml and has_single_dssm:
                return self.parse_model(candidate, model_prefix)

        raise ValueError('No row containins non-conflicting models')

    def _collect_model_ids(self, label, model_prefix, resource_type, explicit_resources, try_production, try_latest):
        """
        Build the de-duplicated list of model resource ids to try for one model kind.

        :param label: model kind as it appears in log messages (``'BSML'`` / ``'DSSM'``).
        :param model_prefix: prefix passed to get_autobudget_models().
        :param resource_type: resource type whose latest resource is looked up.
        :param explicit_resources: resources chosen in the task parameters (may be empty).
        :param try_production: whether to add the ids returned by get_autobudget_models().
        :param try_latest: whether to add the id of the latest resource of ``resource_type``.
        :return: sorted list of unique resource ids.
        """
        model_ids = [resource.id for resource in (explicit_resources or [])]

        if try_production:
            model_ids.extend(self.get_autobudget_models(model_prefix))
            logging.info('Production {} model resource is {}'.format(label, model_ids[-1]))

        if try_latest:
            latest_id = get_last_resource_with_attrs(resource_type.name, None).id
            model_ids.append(latest_id)
            logging.info('Latest {} model resource is {}'.format(label, latest_id))

        # Sorted so that the stored order does not depend on set iteration order.
        return sorted(set(model_ids))

    def on_before_execute(self):
        """
        Resolve the model resource ids into the Context and unpack both binaries.

        The model lists are computed only while the corresponding Context field is
        still empty, so an already filled Context is left untouched.
        """
        # The base hook is invoked exactly once (the original called it twice).
        super(ModadvertRunManualAutobudget, self).on_before_execute()

        if not self.Context.enriched_bsml_models:
            self.Context.enriched_bsml_models = self._collect_model_ids(
                label='BSML',
                model_prefix='bsml',
                resource_type=resource_types.MODADVERT_BSML_MODEL,
                explicit_resources=self.Parameters.bsml_models,
                try_production=self.Parameters.try_production_bsml_model,
                try_latest=self.Parameters.try_latest_bsml_model,
            )

        if not self.Context.enriched_dssm_models:
            self.Context.enriched_dssm_models = self._collect_model_ids(
                label='DSSM',
                model_prefix='dssm',
                resource_type=resource_types.MODADVERT_DSSM_MODEL,
                explicit_resources=self.Parameters.dssm_models,
                try_production=self.Parameters.try_production_dssm_model,
                try_latest=self.Parameters.try_latest_dssm_model,
            )

        self.untar_resource(self.Parameters.gatherer_binary)
        self.untar_resource(self.Parameters.solver_binary)

    def _run_solver(self, **solver_kwargs):
        """Run the solver binary with the command built from ``solver_kwargs``."""
        self.run_command(
            self.create_solver_command(**solver_kwargs),
            log_prefix='solver',
            env=self.create_env()
        )

    def on_execute_inner(self):
        """
        Run the data gatherer (if enabled) and then every configured solver mode.

        Empty entries of the ``limit`` and ``success_rate`` lists are skipped.
        """
        # Default data table; create_data_gatherer_command() replaces it with a
        # date-suffixed name when the gatherer is run.
        self.Context.output_table = self.Parameters.output_table
        if self.Parameters.run_gatherer:
            self.run_command(
                self.create_data_gatherer_command(),
                log_prefix='data_gatherer',
                env=self.create_env()
            )

        for limit in (self.Parameters.limit or []):
            if limit:
                self._run_solver(limit=limit)

        for success_rate in (self.Parameters.success_rate or []):
            if success_rate:
                self._run_solver(success_rate=success_rate)

        if self.Parameters.factor_values:
            self._run_solver(factor_values=self.Parameters.factor_values)

    @staticmethod
    def _join_model_factors(template, resource_ids):
        """Format ``template`` with every unique id (ascending) and join the results with ``'|'``."""
        return '|'.join(template.format(resource_id) for resource_id in sorted(set(resource_ids)))

    def create_solver_command(self, limit=None, success_rate=None, factor_values=None):
        """
        Build the ``autobudget_solver`` command line.

        Exactly one mode is selected, in this priority: ``limit`` (fixed_limit),
        ``success_rate`` (fixed_success_percentage), ``factor_values`` (fixed_params).

        :param limit: value for ``--limit``.
        :param success_rate: value for ``--success-percentage``.
        :param factor_values: list of strings joined with ``','`` for ``--factor-values``.
        :return: the command as a list of arguments.
        :raises ValueError: if none of the three mode arguments is given.
        """
        # Copy the parameter value: appending to it in place would modify the task
        # parameter and accumulate duplicated model factors across solver runs.
        factor_names = list(self.Parameters.optimized_factors or [])

        if self.Context.enriched_bsml_models:
            bsml_templates = ['bsml_{}']
            if self.Parameters.try_vip:
                bsml_templates.append('bsml_vip_{}')
            factor_names.extend(
                self._join_model_factors(template, self.Context.enriched_bsml_models)
                for template in bsml_templates
            )

        if self.Context.enriched_dssm_models:
            dssm_templates = ['dssm_class_yellow_{}']
            if self.Parameters.try_vip:
                dssm_templates.append('dssm_class_yellow_vip_{}')
            factor_names.extend(
                self._join_model_factors(template, self.Context.enriched_dssm_models)
                for template in dssm_templates
            )

        command = [
            './autobudget_solver',
            '--yt-proxy-url', self.Parameters.yt_proxy_url,
            '--yt-secondary-proxy-url', self.Parameters.yt_secondary_proxy_url,
            '--data-table-name', self.Context.output_table,
            '--output-table-name', self.Parameters.autobudget_history_table,
            '--output-solution-node', self.Parameters.autobudget_solution_node or '',
            '--factor-names', ','.join(factor_names),
            '--delta-step', str(self.Parameters.delta_step)
        ]

        if self.Parameters.autobudget_results_table:
            command += ['--results-table-name', self.Parameters.autobudget_results_table]

        if limit is not None:
            command += ['--mode', 'fixed_limit', '--limit', str(limit)]
        elif success_rate is not None:
            command += ['--mode', 'fixed_success_percentage', '--success-percentage', str(success_rate)]
        elif factor_values is not None:
            command += ['--mode', 'fixed_params', '--factor-values', ','.join(factor_values)]
        else:
            # Previously this surfaced as an opaque TypeError from ','.join(None).
            raise ValueError('One of limit, success_rate or factor_values must be given')

        return command

    def create_data_gatherer_command(self):
        """
        Build the ``autobudget_data_gatherer`` command line.

        The gathered period ends at local midnight ``offset_days`` days ago and
        lasts ``period_days`` days. As a side effect, ``Context.output_table`` is
        set to the output table parameter suffixed with the start and end dates.

        :return: the command as a list of arguments.
        """
        today_midnight = datetime.now().replace(minute=0, hour=0, second=0, microsecond=0)
        end_time = today_midnight - timedelta(days=self.Parameters.offset_days)
        start_time = end_time - timedelta(days=self.Parameters.period_days)

        self.Context.output_table = '{}_{}_{}'.format(
            self.Parameters.output_table,
            start_time.strftime('%Y%m%d'),
            end_time.strftime('%Y%m%d')
        )

        timestamp_format = '%Y-%m-%d %H:%M:%S'
        command = [
            './autobudget_data_gatherer',
            '--yt-proxy', self.Parameters.yt_proxy_url,
            '--prod-table', self.Parameters.prod_table,
            '--results-table', self.Parameters.lyncher_results_table,
            '--output-table', self.Context.output_table,
            '--start-time', start_time.strftime(timestamp_format),
            '--end-time', end_time.strftime(timestamp_format),
            '--sandbox-url', server_url()
        ]
        # '--factors' is always passed, even when no factors are configured.
        command.append('--factors')
        command.extend(self.Parameters.factors or [])

        if self.Context.enriched_bsml_models:
            command.append('--bsml-resource-ids')
            command.extend(str(resource_id) for resource_id in self.Context.enriched_bsml_models)

        if self.Context.enriched_dssm_models:
            command.append('--dssm-resource-ids')
            command.extend(str(resource_id) for resource_id in self.Context.enriched_dssm_models)

        if self.Parameters.use_synchrophazotron:
            command.extend(['--synchrophazotron-command', str(self.synchrophazotron)])

        return command
