# -*- coding: utf-8 -*-
import yt.wrapper as yt_wrapper
from datetime import datetime, timedelta

from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.dev_utils.time.patterns import FMT_DATE
from datacloud.stability.crypta_stability.constants import (
    DEFAULT_KEY, TMP_FOLDER, DEFAULT_TAG, PARTNERS_TO_EXCLUDE
)


logger = get_basic_logger(__name__)


def get_creation_date(yt_client, table):
    """Return the part of the table's 'creation_time' attribute before 'T'.

    The attribute is read through ``yt_client.get_attribute``; everything up
    to (not including) the first 'T' is returned. If the value contains no
    'T', it is returned unchanged.
    """
    creation_time = yt_client.get_attribute(table, 'creation_time')
    date_part, _, _ = creation_time.partition('T')
    return date_part


class EventsByDateMapper:
    """Mapper that passes through recent records of non-excluded partners.

    A record is emitted unchanged when its 'retro_date' is not less than the
    cut-off string and its 'partner_id' is not in ``exclude_partners``.
    The cut-off is "now minus ``days_to_take`` days", computed once at
    construction time and formatted with FMT_DATE.
    """

    def __init__(self, days_to_take=14, exclude_partners=PARTNERS_TO_EXCLUDE):
        self.exclude_partners = exclude_partners
        cutoff = datetime.now() - timedelta(days=days_to_take)
        # Stored as a string: records are compared against it as-is.
        # NOTE(review): presumably 'retro_date' uses the same FMT_DATE layout
        # so that string order matches date order - confirm.
        self.min_dt_str = cutoff.strftime(FMT_DATE)

    def __call__(self, rec):
        # Drop records older than the cut-off.
        if not rec['retro_date'] >= self.min_dt_str:
            return
        # Drop records of excluded partners.
        if rec['partner_id'] in self.exclude_partners:
            return
        yield rec


def count_reducer(key, _):
    """Emit a single ``{'count': 1}`` row, ignoring the key and the records."""
    marker_row = {'count': 1}
    yield marker_row


def count_external_id(
        yt_client,
        table_path,
        key=DEFAULT_KEY,
        tmp_folder=TMP_FOLDER,
        tag=DEFAULT_TAG):
    """Return the row count of a map-reduce over ``table_path`` grouped by ``key``.

    Only the ``key`` column of the input is read. The operation has no mapper
    and uses ``count_reducer`` (one output row per reducer call), writing to a
    temporary table created in ``tmp_folder``; the number of rows in that
    table is returned. ``tag`` is only used in the operation title.
    """
    with yt_client.TempTable(tmp_folder) as temp_table:
        source = yt_client.TablePath(table_path, columns=[key])
        operation_spec = {
            'title': '[{}] Count Distinct {}'.format(tag, key),
            'use_columnar_statistics': True
        }
        yt_client.run_map_reduce(
            None,
            count_reducer,
            source,
            temp_table,
            reduce_by=key,
            spec=operation_spec
        )
        # Read the count while the temporary table still exists.
        distinct_total = yt_client.row_count(temp_table)

    return distinct_total


@yt_wrapper.with_context
class JoinCidsFromScores:
    """Context-aware reducer emitting ``key_field`` for matched groups.

    Within one call, rows read while ``context.table_index`` is 0 only mark
    the group as interesting. Rows from any other table index are emitted
    (reduced to ``key_field``) if a table-0 row was already seen; otherwise
    processing of the group stops at the first such row.
    NOTE(review): this presumably relies on table-0 rows arriving first
    within a group - confirm against the operation's sort settings.
    """

    def __init__(self, key_field=DEFAULT_KEY):
        self.key_field = key_field

    def __call__(self, _, recs, context):
        seen_in_first_table = False
        for rec in recs:
            if context.table_index == 0:
                seen_in_first_table = True
                continue
            if not seen_in_first_table:
                # No table-0 row preceded this one: nothing to emit.
                return
            field = self.key_field
            yield {field: rec[field]}


def join_cids_with_cids(key, recs):
    """Emit ``{'cid': key['cid']}`` when the group holds more than one record.

    All records of the group are consumed to count them; groups with zero or
    one record produce no output.
    """
    group_size = 0
    for _ in recs:
        group_size += 1

    if group_size <= 1:
        return
    yield {'cid': key['cid']}


def unique_cid_reducer(key, _):
    """Emit one ``{'cid': ...}`` row per call, taken from the reduce key."""
    cid = key['cid']
    yield {'cid': cid}


def id_value_counter(key, recs):
    """Emit one row with the group's 'cid' and the number of its records."""
    cid = key['cid']

    records_seen = 0
    for _ in recs:
        records_seen += 1

    yield {
        'cid': cid,
        'count': records_seen
    }


@yt_wrapper.aggregator
def accumulate_counter(rows):
    """Sum the 'count' field over all given rows and emit a single total row.

    An empty input produces ``{'count': 0}``.
    """
    total = 0
    for row in rows:
        total = total + row['count']
    yield {'count': total}


def reduce_unique_cids(yt_client, input_paths, result, tag, sync=True):
    """Write the 'cid' values of ``input_paths``, one row per distinct cid, to ``result``.

    If ``result`` already exists and its creation date is not earlier than
    the latest creation date among the inputs, nothing is run and ``None``
    is returned. Otherwise ``result`` is (re)created and a map-reduce without
    a mapper, reduced by 'cid' with ``unique_cid_reducer``, is started.

    Dates are compared as the strings returned by ``get_creation_date``.
    ``tag`` is only used in the operation title; ``sync`` is forwarded to
    ``run_map_reduce``.

    Returns whatever ``yt_client.run_map_reduce`` returns, or ``None`` when
    the calculation is skipped.
    """
    if yt_client.exists(result):
        result_date = get_creation_date(yt_client, result)
        newest_input_date = max(
            get_creation_date(yt_client, path) for path in input_paths
        )

        # Result is at least as fresh as every input: reuse it.
        if result_date >= newest_input_date:
            logger.info('{} already calculated, skipping it'.format(result))
            return None

    # Only the 'cid' column is needed from each input.
    cid_only_inputs = []
    for path in input_paths:
        cid_only_inputs.append(yt_client.TablePath(path, columns=['cid']))

    yt_client.create('table', result, force=True)

    operation_spec = {
        'title': '[{}] get unique cids'.format(tag),
        'use_columnar_statistics': True
    }
    return yt_client.run_map_reduce(
        None,
        unique_cid_reducer,
        cid_only_inputs,
        result,
        reduce_by='cid',
        spec=operation_spec,
        sync=sync
    )
