# -*- coding: utf-8 -*-
import uuid

from datacloud.dev_utils.yt import yt_utils, converter
from datacloud.stability import stability
from datacloud.dev_utils.logging.logger import get_basic_logger

logger = get_basic_logger(__name__)

# Public converters exported by this module.
__all__ = [
    'CExpandBySegment',
    'CStabilitySegment'
]

# Tag placed in square brackets at the start of YT operation titles in this module.
TAG = 'StabilitySegments'
# Minimum segment size: segments whose record count (over the whole input
# features table) is not strictly greater than this are skipped by CStabilitySegment.
BORDER = 1000


def _expand_by_segment(rec):
    if 'stability_segments' in rec:
        for segment in rec['stability_segments']:
            yield {
                'partner_id': rec['partner_id'],
                'score_id': rec['score_id'],
                'features': rec['features'],
                'segment': segment
            }


class CExpandBySegment(converter.TableConverter):
    """Explode each input row into one row per stability segment, then sort.

    Output rows carry ``partner_id``, ``score_id``, ``features`` and
    ``segment`` and end up sorted by ``(partner_id, score_id, segment)``.
    """

    def __init__(self, yt_client):
        super(CExpandBySegment, self).__init__(yt_client)

    def _convert(self, input_table, output_table):
        client = self._yt_client
        map_title = '[{}] Expand by segment'.format(TAG)
        sort_title = '[{}] Sort expanded by segment'.format(TAG)

        # Map stage: one output row per element of ``stability_segments``.
        client.run_map(_expand_by_segment, input_table, output_table, spec=dict(title=map_title))

        # Sort the freshly written output table in place.
        client.run_sort(
            output_table,
            sort_by=('partner_id', 'score_id', 'segment'),
            spec=dict(title=sort_title),
        )


def _count_segments_reducer(_, recs):
    counter = 0
    for rec in recs:
        counter += 1
    yield {'segment': rec['segment'], 'counter': counter}


class CStabilitySegment(converter.TableConverter):
    """Build per-segment stability bins for one (partner_id, score_id) pair.

    The input features table is sorted by ``(partner_id, score_id, segment)``.
    Every segment with more than ``min_segment_size`` records is fed through
    ``stability.FeaturesToStabilityBinsConverter``, and the per-segment results
    are merged into ``output_table`` sorted by ``(segment, feature)``.
    """

    def __init__(self, yt_client, stability_rec, n_bins=100, min_segment_size=BORDER):
        """
        :param yt_client: YT client used for every operation.
        :param stability_rec: object exposing ``partner_id`` and ``score_id``;
            only input rows with these key values are converted.
        :param n_bins: number of bins passed to FeaturesToStabilityBinsConverter.
        :param min_segment_size: segments with at most this many records are skipped.
        """
        super(CStabilitySegment, self).__init__(yt_client)
        self._stability_rec = stability_rec
        self._n_bins = n_bins
        self._min_segment_size = min_segment_size
        self._output_table_schema = [  # TODO: Create schema for output table

        ]

    def _convert(self, input_table, output_table):
        """Compute per-segment stability tables and merge them into ``output_table``.

        All work runs inside one YT transaction. Intermediate per-segment tables
        get a per-run unique name, so concurrent runs do not contend for the
        same //tmp paths, and they are removed once merged. If no segment
        qualifies, ``output_table`` is created empty unless it already exists.
        """
        with self._yt_client.Transaction():
            # Sorting enables the exact_key range reads below.
            self._yt_client.run_sort(
                input_table,
                sort_by=['partner_id', 'score_id', 'segment'],
                spec={'title': '[Stability] Sort features'})
            run_id = uuid.uuid4().hex
            stability_tables = []
            for segment in self._get_available_segments(input_table):
                table = '//tmp/xprod-stability-part-table-{}-{}'.format(segment, run_id)
                stability.FeaturesToStabilityBinsConverter(self._yt_client, n_bins=self._n_bins)(
                    yt_utils.TablePath(  # Extract part of table by exact_key
                        input_table,
                        exact_key=[self._stability_rec.partner_id, self._stability_rec.score_id, segment]),
                    table
                )
                stability_tables.append(table)
            if stability_tables:
                # NOTE(review): rows from different segments are distinguishable in the
                # merged output only if FeaturesToStabilityBinsConverter keeps the
                # 'segment' column; confirm.
                self._yt_client.run_sort(
                    stability_tables,
                    output_table,
                    sort_by=('segment', 'feature'),
                    spec={'title': '[{}] Merge segment stability tables'.format(TAG)}
                )
                # Intermediate tables are no longer needed once merged; drop them
                # instead of leaving them behind in //tmp.
                for table in stability_tables:
                    self._yt_client.remove(table, force=True)
            else:
                logger.warning('No input tables to create segment stability table')
                if not self._yt_client.exists(output_table):
                    self._yt_client.create('table', output_table)

    def _get_available_segments(self, input_features_table):
        """Return the segments having more than ``min_segment_size`` records.

        Groups the ``segment`` column of ``input_features_table`` with a
        reduce-only map-reduce into a temporary table, then reads back the
        per-segment counts.

        NOTE(review): counts are taken over the whole input table, not only the
        rows of this partner_id/score_id; confirm the input is already
        restricted to a single pair.
        """
        with self._yt_client.TempTable('//tmp') as segments_table:
            self._yt_client.run_map_reduce(
                None,  # no mapper: rows go straight to the reducer
                _count_segments_reducer,
                yt_utils.TablePath(input_features_table, columns=['segment']),
                segments_table,
                reduce_by='segment',
                spec={'title': '[{}] Count records in each segment'.format(TAG)}
            )
            # Materialize the list before the temp table is removed on exit.
            return [
                rec['segment']
                for rec in self._yt_client.read_table(segments_table)
                if rec['counter'] > self._min_segment_size
            ]
