import yt.wrapper as yt_wrapper
from datacloud.config.yt import YT_PROXY
from datacloud.dev_utils.data import data_utils as du
from datacloud.dev_utils.yt import yt_utils
from datacloud.dev_utils.logging.logger import get_basic_logger


# Module-level logger used by the applyer classes below.
logger = get_basic_logger(__name__)


@yt_wrapper.with_context
class CombineFeaturesReducer(object):
    """YT reducer that joins per-feature tables into one feature row per key.

    Input table ``i`` of the reduce is handled by ``model_features[i]``, so the
    order of ``model_features`` must match the order of the input tables.
    A row is emitted only when every feature is either present in its table
    for this key or can be filled with that feature's default.

    Args:
        model_features: sequence of feature objects exposing ``extract(rec)``,
            ``has_default()`` and ``fill_with_default()``.
        ext_id_key: name of the id column copied from the reduce key to the
            output row.
    """

    def __init__(self, model_features, ext_id_key='cid'):
        self.model_features = model_features
        self.ext_id_key = ext_id_key
        self.n_tables = len(model_features)

    def check(self, features_presence):
        """Return True when every feature slot (one per input table) is filled."""
        return all(features_presence)

    def __call__(self, key, recs, context):
        """Yield ``{ext_id_key: ..., 'features': <serialized vector>}`` for ``key``.

        Nothing is yielded if some feature is missing and has no default.
        """
        features_presence = [False] * self.n_tables
        extracted_features = [None] * self.n_tables

        for rec in recs:
            # Index of the input table the current record came from; it is read
            # per record because records of one key may come from several tables.
            table_idx = context.table_index
            features_presence[table_idx] = True
            # If a table has several records for this key, the last one wins.
            extracted_features[table_idx] = self.model_features[table_idx].extract(rec)

        # Fill features absent for this key from their defaults, when available.
        for idx, is_present in enumerate(features_presence):
            if is_present:
                continue
            feature = self.model_features[idx]
            if feature.has_default():
                extracted_features[idx] = feature.fill_with_default()
                features_presence[idx] = True

        if self.check(features_presence):
            yield {
                self.ext_id_key: key[self.ext_id_key],
                'features': du.array_tostring(du.combine_features(extracted_features))
            }


class ApplyMap(object):
    """YT mapper that scores every combined-feature row with ``model``.

    Each output row holds the id column (``ext_id_key``) and ``score``.
    With ``save_info`` enabled it also keeps the serialized features (under
    ``feature_key``) plus the model's ``score_name`` and ``date_str``
    (written to the ``date`` column).
    """

    def __init__(self, model, ext_id_key, feature_key='features', save_info=False):
        self.model = model
        self.ext_id_key = ext_id_key
        self.feature_key = feature_key
        self.save_info = save_info

        # Local models load their binary right here, in the constructor.
        if model.is_local:
            model.load_binary()

    def __call__(self, rec):
        """Yield exactly one scored row for the input record ``rec``."""
        vector = du.array_fromstring(rec[self.feature_key])
        out = {self.ext_id_key: rec[self.ext_id_key]}
        out['score'] = self.model.apply(vector)

        if self.save_info:
            out.update({
                self.feature_key: rec[self.feature_key],
                'score_name': self.model.score_name,
                'date': self.model.date_str,
            })

        yield out


class ModelApplyer(object):
    """Base class for applying scoring models to feature tables on YT.

    Subclasses implement ``get_features`` (build ``combined_features_table``),
    ``apply`` (write scores to ``result_scores_table``) and ``init_folders``.

    Args:
        models_dict: mapping of models to apply.
        features_list: feature objects whose tables are combined.
        combined_features_table: YT path of the combined features table.
        result_scores_table: YT path of the output scores table.
        ext_id_key: name of the id column in inputs and outputs.
        is_yt_binary: stored as ``self.is_yt_binary`` for subclasses.
        yt_client: YT client; when None, a client for ``YT_PROXY`` is created.
    """

    def __init__(self, models_dict, features_list, combined_features_table,
                 result_scores_table, ext_id_key='cid',
                 is_yt_binary=False, yt_client=None):
        self.models_dict = models_dict
        self.features_list = features_list
        self.combined_features_table = combined_features_table
        self.result_scores_table = result_scores_table
        self.ext_id_key = ext_id_key
        # Previously accepted but silently dropped; keep it available.
        self.is_yt_binary = is_yt_binary
        if yt_client is None:
            yt_client = yt_utils.get_yt_client(YT_PROXY)
        self.yt_client = yt_client

        # Schema of the scores table. 'features', 'score_name' and 'date' are
        # only written by ApplyMap when its save_info flag is set.
        self.score_table_schema = [
            {'name': self.ext_id_key, 'type': 'string'},
            {'name': 'features', 'type': 'string'},
            {'name': 'score', 'type': 'double'},
            {'name': 'score_name', 'type': 'string'},
            {'name': 'date', 'type': 'string'}
        ]

    def configure_scores_table(self, yt_table_path):
        """Return ``yt_table_path`` as a TablePath carrying scores-table attributes.

        The attributes are the scores schema, scan-optimized storage, brotli_6
        compression and lrc_12_2_2 erasure coding.
        """
        logger.info(' yt table path is {}'.format(yt_table_path))
        return self.yt_client.TablePath(yt_table_path, attributes={
            'optimize_for': 'scan',
            'schema': self.score_table_schema,
            'compression_codec': 'brotli_6',
            'erasure_codec': 'lrc_12_2_2'
        })

    def get_features(self, yt_client, date_str):
        """Build ``combined_features_table``; must be implemented by subclasses."""
        raise NotImplementedError()

    def apply(self, yt_client):
        """Write model scores to ``result_scores_table``; subclass responsibility."""
        raise NotImplementedError()

    def init_folders(self, yt_client):
        """Prepare YT folders needed by the applyer; subclass responsibility."""
        raise NotImplementedError()


class SingleModelApplyer(ModelApplyer):
    """Model applyer for a ``models_dict`` that holds exactly one model.

    ``get_features`` reduces the per-feature tables into
    ``combined_features_table``. ``apply`` scores that table with the single
    model into ``result_scores_table`` and then sorts it by ``ext_id_key``.
    """

    def __init__(self, models_dict, features_list, combined_features_table,
                 result_scores_table, ext_id_key='cid', is_yt_binary=False,
                 yt_client=None):
        # ``yt_client`` is optional and forwarded to the base class, which
        # creates its default client when it is None (the previous behavior).
        super(SingleModelApplyer, self).__init__(
            models_dict=models_dict,
            features_list=features_list,
            combined_features_table=combined_features_table,
            result_scores_table=result_scores_table,
            ext_id_key=ext_id_key,
            is_yt_binary=is_yt_binary,
            yt_client=yt_client)

    def _get_single_model(self):
        """Return the (expected single) model stored in ``models_dict``.

        ``dict.values()[0]`` only works on Python 2, where ``values()`` is a
        list. ``next(iter(...))`` works on both Python 2 and 3.

        Raises:
            ValueError: if ``models_dict`` is empty.
        """
        if not self.models_dict:
            raise ValueError('models_dict is empty, expected exactly one model')
        # With more than one model, the first one in iteration order is used.
        return next(iter(self.models_dict.values()))

    def get_features(self, yt_client, date_str=None):
        """Combine the feature tables into ``combined_features_table``.

        Args:
            yt_client: client used to run the reduce operation.
            date_str: if truthy, each feature's table for that date is used;
                otherwise each feature's last weekly table is used.
        """
        logger.info(' Start get_features')
        if date_str:
            logger.info(' Features date: {}'.format(date_str))
            tables = [f.get_table_for_date(date_str) for f in self.features_list]
        else:
            logger.info(' No Features date provided, will use default `yt_path`')
            tables = [f.get_last_weekly_table() for f in self.features_list]
        # ``tables`` follows the order of ``features_list``, which the reducer
        # relies on to map table_index -> feature.
        yt_client.run_reduce(
            CombineFeaturesReducer(self.features_list, self.ext_id_key),
            tables,
            self.combined_features_table,
            # The reducer reads ``key[self.ext_id_key]``, so the reduce key
            # must be that same column; a hard-coded 'cid' made every other
            # ext_id_key fail with KeyError.
            reduce_by=self.ext_id_key,
            spec={'title': 'Combine features'}
        )
        logger.info(' Done get_features')

    def apply(self, yt_client):
        """Score ``combined_features_table`` with the single model.

        Writes ``result_scores_table`` using ``configure_scores_table`` and
        then runs a sort on it by ``ext_id_key``.

        Raises:
            ValueError: if ``models_dict`` is empty.
        """
        logger.info(' Start apply')
        logger.info('Self models dict is {}'.format(self.models_dict))
        model = self._get_single_model()
        scores_table = self.configure_scores_table(self.result_scores_table)

        # Non-local models are not loaded in ApplyMap's constructor, so their
        # binary is shipped to the map jobs as a YT file.
        yt_files = [] if model.is_local else [model.binary_path]

        yt_client.run_map(
            ApplyMap(model, self.ext_id_key, save_info=model.save_info),
            self.combined_features_table,
            scores_table,
            yt_files=yt_files,
            spec={'title': 'Apply {} model'.format(model.model_name)}
        )
        yt_client.run_sort(
            scores_table,
            sort_by=self.ext_id_key,
            spec={'title': 'Sort after apply model'}
        )
        logger.info(' Done apply')
