from __future__ import print_function
from datacloud.dev_utils.yt import yt_config_table
from datacloud.dev_utils.logging.logger import get_basic_logger

from datacloud.model_applyer.lib.model_config import ModelConfig
from datacloud.model_applyer.lib.models import (
    PredictModel, PredictProbaModel, InvertedPredictProbaModel, CatboostPredictProbaModel)

logger = get_basic_logger(__name__)


# YT path of the production API models config table; it is the default
# table_path of ApiModelsConfigTable.
DEFAULT_API_MODELS_CONFIG_TABLE_PATH = (
    '//home/x-products/production/config/datacloud/api-models-config'
)

# Model classes that may be referenced (by class name) in the 'model_class' column.
__all_models = [
    PredictModel,
    PredictProbaModel,
    InvertedPredictProbaModel,
    CatboostPredictProbaModel,
]

# Lookup table: class name -> model class.
all_models = dict(
    (model_cls.__name__, model_cls) for model_cls in __all_models
)

# TODO: Write tests.


# Names exported by ``from <this module> import *``.
__all__ = [
    'ApiModelsConfigTable'
]


class ApiModelsConfigTable(yt_config_table.ConfigTable):
    """YT config table describing the scoring models served by the API.

    Rows are keyed by ``(partner_id, score_name)``. Besides the fixed columns,
    the free-form ``additional`` column holds per-model flags whose keys are
    the ``*_KEY`` class constants below.
    """

    # Keys stored inside the 'additional' column.
    TRANSFER_ON_KEY = 'transfer_on'
    SCORE_BLEND_ON_KEY = 'score_blend_on'
    CRYPTA_BLEND_ON_KEY = 'crypta_blend_on'

    def __init__(self, table_path=DEFAULT_API_MODELS_CONFIG_TABLE_PATH, yt_client=None):
        """Bind to the config table at ``table_path``.

        ``table_path``, the schema below and ``yt_client`` are passed straight
        to ``ConfigTable.__init__``. NOTE(review): presumably ``yt_client=None``
        makes the base class pick a default client -- confirm.
        """
        schema = [
            {'name': 'partner_id', 'type': 'string', 'sort_order': 'ascending'},
            {'name': 'score_name', 'type': 'string', 'sort_order': 'ascending'},
            {'name': 'model_class', 'type': 'string'},
            {'name': 'features', 'type': 'string'},
            {'name': 'save_info', 'type': 'boolean'},
            {'name': 'is_active', 'type': 'boolean'},
            {'name': 'additional', 'type': 'any'},
        ]
        super(ApiModelsConfigTable, self).__init__(
            table_path, schema, yt_client
        )

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _quote_ql_string(value):
        """Escape ``value`` for use inside a double-quoted query string literal.

        Backslashes are escaped first so the escape added for ``"`` is not
        itself doubled. Values without ``\\`` or ``"`` are returned unchanged.
        """
        return '{0}'.format(value).replace('\\', '\\\\').replace('"', '\\"')

    def _update_model(self, partner_id, score_name, **fields):
        """Read a record, overwrite the given top-level columns, write it back.

        Raises RuntimeError (via ``get_model_or_raise``) if the record is missing.
        """
        model = self.get_model_or_raise(partner_id, score_name)
        for column, value in fields.items():
            model[column] = value
        self.insert_records([model])

    def _set_additional_flag(self, partner_id, score_name, key, value):
        """Set ``additional[key] = value`` on a record and write it back.

        A missing or empty/None ``additional`` value is replaced by a fresh dict
        instead of failing. Raises RuntimeError if the record is missing.
        """
        model = self.get_model_or_raise(partner_id, score_name)
        additional = model.get('additional') or {}
        additional[key] = value
        model['additional'] = additional
        self.insert_records([model])

    # ------------------------------------------------------------- basic CRUD

    def add_model(self, partner_id, score_name, model_class, features,
                  save_info=False, is_active=True, additional=None):
        """Insert (or overwrite) the config record for ``(partner_id, score_name)``.

        ``additional`` defaults to a new empty dict; ``None`` is never stored.
        ``model_class`` is stored as given and is not validated here.
        """
        additional = additional or {}

        record = {
            'partner_id': partner_id,
            'score_name': score_name,
            'model_class': model_class,
            'features': features,
            'save_info': save_info,
            'is_active': is_active,
            'additional': additional
        }
        self.insert_records([record])

    def remove_model(self, partner_id, score_name):
        """Delete the record keyed by ``(partner_id, score_name)``."""
        self.remove_records([{'partner_id': partner_id, 'score_name': score_name}])

    def list_models(self):
        """Lazily yield every record of the table."""
        for record in self.list_records():
            yield record

    def list_active_models(self):
        """Lazily yield records whose ``is_active`` is truthy (missing -> inactive)."""
        for model_rec in self.list_models():
            if model_rec.get('is_active', False):
                yield model_rec

    def get_model(self, partner_id, score_name):
        """Return the record for ``(partner_id, score_name)``, as ``get_record`` does.

        NOTE(review): presumably ``None`` when no row matches, since
        ``get_model_or_raise`` checks for that -- confirm in ConfigTable.
        Key values are escaped so quotes/backslashes cannot break the query.
        """
        request = \
            '*' \
            ' FROM [{table_path}]' \
            ' WHERE partner_id = "{partner_id}" AND score_name = "{score_name}"'.format(
                table_path=self.table_path,
                partner_id=self._quote_ql_string(partner_id),
                score_name=self._quote_ql_string(score_name))
        return self.get_record(request)

    def get_model_or_raise(self, partner_id, score_name):
        """Like ``get_model`` but raise RuntimeError when the record is missing."""
        config_rec = self.get_model(partner_id, score_name)
        if config_rec is None:
            raise RuntimeError('Not found {partner_id} {score_name} in config!'.format(
                partner_id=partner_id,
                score_name=score_name
            ))

        return config_rec

    # ------------------------------------------------------- model resolution

    def get_model_class_and_config(self, partner_id, score_name):
        """Return ``(model_class, ModelConfig)`` for ``(partner_id, score_name)``.

        Raises RuntimeError if the record is missing (an explicit check rather
        than ``assert``, so it also works under ``python -O``), and KeyError if
        its ``model_class`` is not registered in ``all_models``.
        """
        record = self.get_model_or_raise(partner_id, score_name)
        model_config = ModelConfig.from_json(record)
        class_name = record['model_class']
        try:
            model_class = all_models[class_name]
        except KeyError:
            raise KeyError('Unknown model_class {0!r} for {1} {2}; known classes: {3}'.format(
                class_name, partner_id, score_name, sorted(all_models)))
        return model_class, model_config

    def get_all_active_model_classes_and_configs(self):
        """Yield ``(model_class, ModelConfig)`` for every active record."""
        for record in self.list_active_models():
            # TODO: hack -- remove models with ClusterFeatureBase763 from the table.
            # ``or ''`` keeps a NULL 'features' value from raising TypeError.
            if 'ClusterFeatureBase763' in (record.get('features') or ''):
                continue
            yield self.get_model_class_and_config(record['partner_id'], record['score_name'])

    # ---------------------------------------------------------- flag accessors

    def get_score_blending_on(self, partner_id, score_name):
        """Return the score-blending flag; False when unset or ``additional`` is empty."""
        model = self.get_model_or_raise(partner_id, score_name)
        additional = model.get('additional') or {}
        return additional.get(self.SCORE_BLEND_ON_KEY, False)

    def save_info_on(self, partner_id, score_name):
        """Set ``save_info`` to True for the record."""
        self._update_model(partner_id, score_name, save_info=True)

    def save_info_off(self, partner_id, score_name):
        """Set ``save_info`` to False for the record."""
        self._update_model(partner_id, score_name, save_info=False)

    def turn_on(self, partner_id, score_name):
        """Mark the record active (``is_active = True``)."""
        self._update_model(partner_id, score_name, is_active=True)

    def turn_off(self, partner_id, score_name):
        """Mark the record inactive (``is_active = False``)."""
        self._update_model(partner_id, score_name, is_active=False)

    def change_model_class(self, partner_id, score_name, model_class):
        """Replace the record's ``model_class`` (not validated against ``all_models``)."""
        self._update_model(partner_id, score_name, model_class=model_class)

    def turn_transfer_on(self, partner_id, score_name):
        """Set ``additional[TRANSFER_ON_KEY] = True``."""
        self._set_additional_flag(partner_id, score_name, self.TRANSFER_ON_KEY, True)

    def turn_transfer_off(self, partner_id, score_name):
        """Set ``additional[TRANSFER_ON_KEY] = False``."""
        self._set_additional_flag(partner_id, score_name, self.TRANSFER_ON_KEY, False)

    def turn_score_blending_on(self, partner_id, score_name):
        """Set ``additional[SCORE_BLEND_ON_KEY] = True``."""
        self._set_additional_flag(partner_id, score_name, self.SCORE_BLEND_ON_KEY, True)

    def turn_score_blending_off(self, partner_id, score_name):
        """Set ``additional[SCORE_BLEND_ON_KEY] = False``."""
        self._set_additional_flag(partner_id, score_name, self.SCORE_BLEND_ON_KEY, False)


def sample_add_model_to_table(is_create_table=False, is_upload_models=False, is_list_models=False):
    """Demonstrate typical operations on the API models config table.

    Each flag enables one step; enabled steps run in this order:

    * ``is_create_table`` -- create the config table;
    * ``is_upload_models`` -- insert a sample 'mkb' model record;
    * ``is_list_models`` -- print every record of the table.
    """
    table = ApiModelsConfigTable()

    if is_create_table:
        # NOTE(review): the original hint suggests create_table() also accepts
        # force=True; confirm its semantics before passing it.
        table.create_table()

    if is_upload_models:
        # NOTE(review): 'PartnerModel' is not a key of all_models, so
        # get_model_class_and_config() would raise KeyError for this record --
        # confirm the intended class name.
        sample_model = dict(
            partner_id='mkb',
            score_name='mkb_xprod_960_m3',
            model_class='PartnerModel',
            features='DSSMFeatureBase400 ClusterFeatureBase767',
            is_active=True,
        )
        table.add_model(**sample_model)

    if not is_list_models:
        return

    print('Models are')
    for row in table.list_models():
        print(row)
