# -*- coding: utf-8 -*-
import yt.wrapper as yt

from datacloud.dev_utils.data.data_utils import array_tostring
from datacloud.dev_utils.yt.yt_utils import get_yt_client, create_folders
from datacloud.features.credit_scoring_events.build_config import CSEBuildConfig


@yt.with_context
class CSEAgregateReducer:
    """Fan event rows out onto the external ids that share their reduce key.

    Within one key, rows from input table 0 are buffered (as plain dicts);
    each later row from another table must carry ``external_id_key`` and
    produces one output row per buffered table-0 row, containing that
    external id and the buffered row's ``target``. When ``is_retro`` is set,
    the buffered row's ``retro_date`` and ``partner`` are copied as well.

    NOTE(review): the logic assumes table-0 rows come before the other rows
    of a key; a non-table-0 row seen before any table-0 row stops processing
    of the whole key.
    """

    def __init__(self, is_retro=False, external_id_key='cid'):
        self.is_retro = is_retro
        self.external_id_key = external_id_key

    def __call__(self, key, recs, context):
        events = []
        for row in recs:
            if context.table_index == 0:
                events.append(dict(row))
                continue
            if not events:
                # No table-0 rows for this key: nothing to fan out, stop early.
                break
            ext_id = row[self.external_id_key]
            for event in events:
                out = {
                    self.external_id_key: ext_id,
                    'target': event['target'],
                }
                if self.is_retro:
                    out['retro_date'] = event['retro_date']
                    out['partner'] = event['partner']
                yield out


@yt.with_context
class CSERetroAgregateReducer:
    """Re-key aggregate rows onto the base rows they predate.

    Within one key, rows from input table 0 ("base rows", each carrying
    ``external_id_key`` and ``retro_date``) are buffered as plain dicts.
    Every later row (carrying ``retro_date`` and ``partner``) is paired with
    each buffered base row. When the row's ``retro_date`` compares strictly
    less than the base row's, and the row's ``partner`` is not in
    ``partners``, a copy of the row is emitted. In that copy,
    ``external_id_key`` is replaced by ``'<base id>_<base retro_date>'``.

    NOTE(review): ``retro_date`` values are compared with ``<``; presumably
    they are ISO-formatted date strings, so lexical order matches date
    order. Confirm.
    """

    def __init__(self, partners, external_id_key='external_id'):
        # Container of partner values whose rows are skipped (tested with ``in``).
        self.partners = partners
        self.external_id_key = external_id_key

    def __call__(self, key, recs, context):
        base_recs = []
        for rec in recs:
            if context.table_index == 0:
                base_recs.append(dict(rec))
            elif not base_recs:
                # Assumes table-0 rows precede the other rows of a key; without
                # any base row there is nothing to re-key, so stop early.
                break
            else:
                for base_rec in base_recs:
                    if rec['retro_date'] < base_rec['retro_date'] and rec['partner'] not in self.partners:
                        # Emit a fresh row per match instead of mutating and
                        # re-yielding ``rec``. Several base rows can match the
                        # same row, and a consumer holding on to yielded rows
                        # would otherwise see all of them with the last id.
                        out = dict(rec)
                        out[self.external_id_key] = '{}_{}'.format(
                            base_rec[self.external_id_key],
                            base_rec['retro_date']
                        )
                        yield out


@yt.with_context
class CSEFeaturesReducer:
    """Compute a feature vector per key from the targets joined to it.

    A row is emitted only for keys that have a table-0 row before any other
    row. Every non-table-0 row contributes its ``target`` value to a list.
    Each callable in ``checks`` is applied to that list, in order, and the
    results are serialized with ``array_tostring`` into the ``features``
    column. A key present only in table 0 is evaluated on an empty list.
    The output row also carries the reduce-key columns.
    """

    def __init__(self, checks):
        # Materialize first so any iterable (including a one-shot generator)
        # can be validated, stored, and counted.
        checks = list(checks)
        if not all(callable(check) for check in checks):
            # Raised explicitly (not via ``assert``) so validation still runs
            # under ``python -O``; the exception type is kept for callers.
            raise AssertionError('Checks should be callable!')
        self.checks = checks
        self.features_count = len(self.checks)

    def __call__(self, key, recs, context):
        has_base = False
        targets = []
        for rec in recs:
            if context.table_index == 0:
                has_base = True
            elif not has_base:
                # Rows without a preceding table-0 row: emit nothing for this key.
                return
            else:
                targets.append(rec['target'])

        features = [check(targets) for check in self.checks]

        rec2yield = dict(key)
        rec2yield['features'] = array_tostring(features)
        yield rec2yield


def _operation_spec(build_config, title):
    """Return an operation spec: ``build_config.cloud_nodes_spec`` plus a ``[tag] title`` title."""
    return dict(
        title='[{}] {}'.format(build_config.tag, title),
        **build_config.cloud_nodes_spec
    )


def _sort_by_ext_id(yt_client, build_config, table, title):
    """Sort ``table`` in place by ``build_config.ext_id_key``."""
    yt_client.run_sort(
        table,
        sort_by=build_config.ext_id_key,
        spec=_operation_spec(build_config, title),
    )


def build_vectors(build_config=None, yt_client=None):
    """Build credit-scoring-event features on YT inside a single transaction.

    Steps:
      1. Reduce ``CREDIT_SCORING_EVENTS`` with ``id_value_to_ext_id`` by
         ``(id_type, id_value)`` into ``aggregates_table``, then sort it by
         ``ext_id_key``.
      2. Retro builds only: reduce the aggregates against ``raw_table``
         (joined as a foreign table) with ``CSERetroAgregateReducer`` into
         ``aggregates_table2``, then sort it by ``ext_id_key``.
      3. Reduce ``all_ext_id_table`` with the resulting aggregates using
         ``CSEFeaturesReducer`` into ``features_table``, then sort it.
      4. Non-retro builds only: remove ``aggregates_table`` and copy
         ``features_table`` over ``ready_table`` (``force=True``). Retro
         builds leave every intermediate table in place.

    :param build_config: build settings; ``CSEBuildConfig()`` is used when
        the argument is falsy.
    :param yt_client: YT client; ``get_yt_client()`` is used when the
        argument is falsy.
    """
    yt_client = yt_client or get_yt_client()
    build_config = build_config or CSEBuildConfig()
    ext_id_key = build_config.ext_id_key

    with yt_client.Transaction():
        create_folders((build_config.data_dir,), yt_client)
        if build_config.is_retro:
            # The reduce below consumes this table by (id_type, id_value).
            # NOTE(review): it is only sorted here for retro builds; presumably
            # the non-retro mapping table is already sorted. Confirm.
            yt_client.run_sort(
                build_config.id_value_to_ext_id,
                sort_by=['id_type', 'id_value'],
                spec=_operation_spec(build_config, 'Sort id_value to ext_id mapping'),
            )

        yt_client.run_reduce(
            CSEAgregateReducer(is_retro=build_config.is_retro, external_id_key=ext_id_key),
            [
                build_config.CREDIT_SCORING_EVENTS,
                build_config.id_value_to_ext_id,
            ],
            build_config.aggregates_table,
            reduce_by=['id_type', 'id_value'],
            spec=_operation_spec(build_config, 'Prepare credit scoring events agregates'),
        )
        _sort_by_ext_id(yt_client, build_config, build_config.aggregates_table,
                        'Sort scoring events agregates')

        aggregates_table = build_config.aggregates_table
        if build_config.is_retro:
            yt_client.run_reduce(
                # Pass the configured id column so this step reads and writes the
                # same column it is reduced, joined, and later sorted by, rather
                # than relying on the reducer's 'external_id' default.
                CSERetroAgregateReducer(partners=build_config.partners, external_id_key=ext_id_key),
                [
                    yt_client.TablePath(
                        build_config.raw_table,
                        attributes={'foreign': True}
                    ),
                    build_config.aggregates_table,
                ],
                build_config.aggregates_table2,
                reduce_by=ext_id_key,
                join_by=ext_id_key,
                spec=_operation_spec(build_config, 'Prepare credit scoring events agregates 2'),
            )
            _sort_by_ext_id(yt_client, build_config, build_config.aggregates_table2,
                            'Sort credit scoring events agregates 2')
            aggregates_table = build_config.aggregates_table2

        yt_client.run_reduce(
            CSEFeaturesReducer(build_config.target_checks),
            [
                build_config.all_ext_id_table,
                aggregates_table,
            ],
            build_config.features_table,
            reduce_by=ext_id_key,
            spec=_operation_spec(build_config, 'Build CSE Features'),
        )
        _sort_by_ext_id(yt_client, build_config, build_config.features_table,
                        'Sort CSE Features')

        if not build_config.is_retro:
            # Retro builds skip this cleanup and the copy to ready_table.
            yt_client.remove(build_config.aggregates_table)
            yt_client.copy(build_config.features_table, build_config.ready_table, force=True)
