# -*- coding: utf-8 -*-
import os

from nile.api.v1 import aggregators as na
from passport.backend.profile import get_cluster
from passport.backend.profile.utils.helpers import to_date_str
from retrying import retry
import yt.wrapper as yt
from yt.wrapper.errors import YtIncorrectResponse


# ASCII-art banner reading "CLASSES SIZE". prepare_balanced_dataset prints it
# right before the per-class row counts so they don't get lost in a long
# terminal log.
ASCII_CLASSES_SIZE = """
 ██████╗██╗      █████╗ ███████╗███████╗███████╗███████╗    ███████╗██╗███████╗███████╗
██╔════╝██║     ██╔══██╗██╔════╝██╔════╝██╔════╝██╔════╝    ██╔════╝██║╚══███╔╝██╔════╝
██║     ██║     ███████║███████╗███████╗█████╗  ███████╗    ███████╗██║  ███╔╝ █████╗
██║     ██║     ██╔══██║╚════██║╚════██║██╔══╝  ╚════██║    ╚════██║██║ ███╔╝  ██╔══╝
╚██████╗███████╗██║  ██║███████║███████║███████╗███████║    ███████║██║███████╗███████╗
 ╚═════╝╚══════╝╚═╝  ╚═╝╚══════╝╚══════╝╚══════╝╚══════╝    ╚══════╝╚═╝╚══════╝╚══════╝

"""


def _count_targets(date_start, date_end, input_dir, output_path, yt_client, skip_count):
    """Return the number of rows per ``target`` value over a range of tables.

    Unless ``skip_count`` is true, a nile job groups every row of the
    ``{start..end}`` table range under ``input_dir`` by ``target`` and writes
    ``(target, count)`` rows to ``output_path``. The counts are then always
    read back from ``output_path``, so with ``skip_count`` a table produced by
    an earlier run is reused as-is.

    :param date_start: first date of the range (formatted with ``to_date_str``).
    :param date_end: last date of the range (formatted with ``to_date_str``).
    :param input_dir: YT directory holding the per-date tables.
    :param output_path: YT table where the counts are written / read from.
    :param yt_client: client used to read ``output_path``.
    :param skip_count: if true, do not run the counting job.
    :return: dict mapping each ``target`` value to its row count.
    """
    date_range = '{%s..%s}' % (to_date_str(date_start), to_date_str(date_end))
    input_path = os.path.join(input_dir, date_range)

    if not skip_count:
        cluster_job = get_cluster().job()
        grouped = cluster_job.table(input_path).groupby('target')
        grouped.aggregate(count=na.count()).put(output_path)
        cluster_job.run()

    count_rows = yt.read_table(output_path, format='yson', raw=False, client=yt_client)
    return {row['target']: row['count'] for row in count_rows}


def _prepare_balanced_dataset(date_start, date_end, input_dir, output_path, target_0_fraction, target_1_fraction):
    """Write a per-class random subsample of a table range to ``output_path``.

    Rows with ``target == 0`` are sampled with ``target_0_fraction`` and rows
    with ``target == 1`` with ``target_1_fraction``. Rows with any other
    ``target`` are dropped. Both samples go into the same ``output_path``
    table; the class-1 sample is written with ``append=True``.

    :param date_start: first date of the ``{start..end}`` table range.
    :param date_end: last date of the table range.
    :param input_dir: YT directory holding the per-date tables.
    :param output_path: destination YT table.
    :param target_0_fraction: sampling fraction for class 0.
    :param target_1_fraction: sampling fraction for class 1.
    """
    input_path = os.path.join(
        input_dir,
        '{%s..%s}' % (to_date_str(date_start), to_date_str(date_end)),
    )

    def split_by_target(rows, class_0_out, class_1_out):
        # Route each row to the output stream of its class; other labels are skipped.
        for row in rows:
            label = row.target
            if label == 0:
                class_0_out(row)
            elif label == 1:
                class_1_out(row)

    cluster_job = get_cluster().job()
    class_0_stream, class_1_stream = cluster_job.table(input_path).map(split_by_target)

    class_0_stream.random(fraction=target_0_fraction).put(output_path)
    class_1_stream.random(fraction=target_1_fraction).put(output_path, append=True)

    cluster_job.run()


def balance_classes(t0, t1, ratio):
    """Choose how many rows of each class to keep to reach a class-0 share.

    The scarcer class (relative to ``ratio``) is kept whole and the other one
    is trimmed, so that ``t0_pick / (t0_pick + t1_pick)`` approximates
    ``ratio`` as closely as whole row counts allow.

    :param t0: number of available class-0 rows.
    :param t1: number of available class-1 rows.
    :param ratio: desired fraction of class 0 in the result, within [0, 1].
        ``1.0`` keeps only class 0 and ``0.0`` keeps only class 1.
    :return: ``(t0_pick, t1_pick)`` with ``0 <= t0_pick <= t0`` and
        ``0 <= t1_pick <= t1``.
    :raises ValueError: if ``ratio`` is outside [0, 1] (or NaN).
    """
    if not 0.0 <= ratio <= 1.0:
        raise ValueError('ratio must be within [0, 1], got %r' % (ratio,))
    if ratio == 1.0:
        return t0, 0
    if ratio == 0.0:
        return 0, t1
    if t0 + t1 == 0:
        # Nothing to sample from; avoids dividing by zero below.
        return 0, 0

    ratio = float(ratio)
    t0_ratio = t0 / float(t0 + t1)
    if t0_ratio < ratio:
        # Class 0 is scarcer than requested: take all of it, trim class 1.
        t0_pick = t0
        t1_pick = t0 / ratio * (1 - ratio)
    else:
        # Class 1 is scarcer than requested: take all of it, trim class 0.
        t0_pick = t1 / (1 - ratio) * ratio
        t1_pick = t1

    # Round to the nearest count: plain int() truncation turns float noise
    # such as 0.9999999999999998 into 0. Clamp so rounding can never request
    # more rows than exist (or a negative count).
    t0_pick = min(t0, max(0, int(round(t0_pick))))
    t1_pick = min(t1, max(0, int(round(t1_pick))))
    return t0_pick, t1_pick


@retry(
    stop_max_attempt_number=3,
    wait_fixed=5000,
    # A callable predicate is accepted by every `retrying` release; a tuple of
    # exception types is only understood by newer ones.
    retry_on_exception=lambda exc: isinstance(exc, YtIncorrectResponse),
)
def prepare_balanced_dataset(date_start, date_end, class_balance, input_dir, output_dir, tmp_dir, yt_client, skip_count):
    """Build a class-balanced dataset table from a date range of input tables.

    Counts rows per ``target`` (see ``_count_targets``), picks per-class row
    counts with ``balance_classes`` and writes a random per-class subsample to
    a new table under ``output_dir``. The whole function is retried up to 3
    times, 5 seconds apart, when it raises ``YtIncorrectResponse``.

    :param date_start: first date of the input range.
    :param date_end: last date of the input range.
    :param class_balance: desired fraction of class 0, passed to
        ``balance_classes`` as ``ratio``.
    :param input_dir: YT directory holding the per-date input tables.
    :param output_dir: YT directory for the resulting dataset table.
    :param tmp_dir: YT directory for the intermediate per-class counts table.
    :param yt_client: client used to read the counts table.
    :param skip_count: reuse an existing counts table instead of recounting.
    :return: path of the written dataset table.
    """
    date_start_str = to_date_str(date_start)
    date_end_str = to_date_str(date_end)

    targets_count_path = os.path.join(tmp_dir, 'dataset-target-count-%s-%s' % (date_start_str, date_end_str))

    target_count = _count_targets(date_start, date_end, input_dir, targets_count_path, yt_client, skip_count)

    # A class that never occurs has no row in the counts table; treat it as empty
    # instead of failing with KeyError (e.g. class_balance=0.0 needs no class 0).
    class_0_total = target_count.get(0, 0)
    class_1_total = target_count.get(1, 0)

    # Print a banner first so the class sizes don't get lost in the terminal output.
    print(ASCII_CLASSES_SIZE)
    print('Class 0 size: {:,}'.format(class_0_total))
    print('Class 1 size: {:,}'.format(class_1_total))

    target_0_count, target_1_count = balance_classes(class_0_total, class_1_total, class_balance)

    # Guard empty classes: nothing can be sampled from them anyway.
    target_0_fraction = target_0_count / float(class_0_total) if class_0_total else 0.0
    target_1_fraction = target_1_count / float(class_1_total) if class_1_total else 0.0

    print('Taking {}% (~{:,}) of class 0'.format(target_0_fraction * 100, target_0_count))
    print('Taking {}% (~{:,}) of class 1'.format(target_1_fraction * 100, target_1_count))

    output_path = os.path.join(
        output_dir,
        'balanced-dataset-%d-vs-%d--%s--%s' % (
            target_0_count,
            target_1_count,
            date_start_str, date_end_str,
        ),
    )

    _prepare_balanced_dataset(date_start, date_end, input_dir, output_path, target_0_fraction, target_1_fraction)

    return output_path
