from maps.wikimap.stat.assessment.tasks_dataset.lib.tasks_dataset import (
    _add_grades_to_tasks,
    _date_extractor,
    _calc_adjusted_values,
    _count_tasks,
    _count_graded_units,
    _recount_for_dimensions,
    make_dataset,
)
from maps.wikimap.stat.libs import nile_ut
from nile.api.v1 import Record
from cyson import UInt


def test_should_extract_date():
    """`_date_extractor` keeps only the `YYYY-MM-DD` prefix of the column value.

    The trailing part of the datetime string (here the literal 'suffix')
    must be dropped from the projected `date` column.
    """
    def project_date(rows):
        return rows.project(date=_date_extractor('datetime'))

    source = nile_ut.Table([Record(datetime='2021-07-01suffix')])
    projected = nile_ut.yql_run(project_date, source)

    assert [Record(date='2021-07-01')] == projected


def test_should_count_graded_units():
    """`_count_graded_units` counts graded units per task/region/date/user/action.

    Within a group the `value` column is split into separate `correct`,
    `incorrect` and `conflicts` counters. Rows that differ only in an extra
    column (`test_column`) fall into the same group, and `action_at` is
    replaced by its date part (`action_date`). Counters for value kinds that
    never occur in a group are expected to be `None`.
    """
    record = Record(
        task_id='task_id1',
        region_id='1',
        action_at='2021-07-01suffix',
        action_by=11,
        action='action1',
        value='correct',
    )
    result_record = record.transform(
        'action_at',
        'value',
        action_date='2021-07-01',
        correct=None,
        incorrect=None,
        conflicts=None,
    )

    result = nile_ut.yql_run(
        _count_graded_units,
        graded_units_log=nile_ut.Table([
            record,
            record.transform(test_column=92),
            record.transform(task_id='task_id2'),
            record.transform(region_id='2'),
            # Keep the time suffix, like every other `action_at` in this
            # module, so the date extraction is exercised for this row too.
            record.transform(action_at='2021-07-02suffix'),
            record.transform(action_by=12),
            record.transform(action='action2'),
            record.transform(value='incorrect'),
            record.transform(value='conflict'),
        ]),
    )
    assert sorted([
        result_record.transform(correct=2, incorrect=1, conflicts=1),
        result_record.transform(correct=1, task_id='task_id2'),
        result_record.transform(correct=1, region_id='2'),
        result_record.transform(correct=1, action_date='2021-07-02'),
        result_record.transform(correct=1, action_by=12),
        result_record.transform(correct=1, action='action2'),
    ]) == sorted(result)


def test_should_count_tasks():
    """`_count_tasks` counts task-log rows per task/region/date/user/action.

    Rows differing only in `test_column` are merged into one group (and the
    column is absent from the output). Changing any grouping field starts a
    new group, and `action_at` is replaced by its date part.
    """
    base = Record(
        task_id='task_id1',
        region_id='1',
        action_at='2021-07-02suffix',
        action_by=11,
        action='action1',
    )
    expected_base = base.transform('action_at', action_date='2021-07-02')

    tasks_log = nile_ut.Table([
        base.transform(test_column=91),
        base.transform(test_column=92),
        base.transform(test_column=93, task_id='task_id2'),
        base.transform(test_column=94, region_id='2'),
        base.transform(test_column=95, action_at='2021-07-03suffix'),
        base.transform(test_column=96, action_by=12),
        base.transform(test_column=97, action='action2'),
    ])
    expected = [
        expected_base.transform(total=2),
        expected_base.transform(total=1, task_id='task_id2'),
        expected_base.transform(total=1, region_id='2'),
        expected_base.transform(total=1, action_date='2021-07-03'),
        expected_base.transform(total=1, action_by=12),
        expected_base.transform(total=1, action='action2'),
    ]

    counted = nile_ut.yql_run(_count_tasks, tasks_log=tasks_log)
    assert sorted(expected) == sorted(counted)


def test_should_add_grades_to_tasks():
    """`_add_grades_to_tasks` attaches summed grade counters to task counts.

    Both grade rows of user 11 are summed into that user's task record, and a
    missing `conflicts` column counts as nothing. User 22 has no grades and is
    expected to get zero counters.
    """
    task_base = Record(
        task_id='task_id1',
        region_id='1',
        action_date='2021-07-02',
        action='action1',
        total=12,
    )
    grade_base = Record(
        task_id='task_id1',
        region_id='1',
        action_date='2021-07-02',
        action='action1',
    )

    task_counts = nile_ut.Table([
        task_base.transform(action_by=11),
        task_base.transform(action_by=22),
    ])
    graded_unit_counts = nile_ut.Table([
        grade_base.transform(action_by=11, correct=2, incorrect=1, conflicts=1),
        grade_base.transform(action_by=11, correct=1, incorrect=1),
    ])

    joined = nile_ut.yql_run(
        _add_grades_to_tasks,
        job=nile_ut.Job(),
        task_counts=task_counts,
        graded_unit_counts=graded_unit_counts,
    )

    expected = [
        task_base.transform(action_by=11, correct=3, incorrect=2, conflicts=1),
        task_base.transform(action_by=22, correct=0, incorrect=0, conflicts=0),
    ]
    assert sorted(expected) == sorted(joined)


def test_should_calc_adjusted_values():
    """`_calc_adjusted_values` rescales grade counters to the task total.

    In every expected row the three `*_adjusted` values sum up to `total`,
    proportionally to `correct`/`incorrect`/`conflicts`; unrelated columns
    (`test_column`) are passed through unchanged.
    """
    graded = [
        Record(total=12, correct=2, incorrect=1, conflicts=1, test_column=1),
        Record(total=12, correct=1, incorrect=1, conflicts=0, test_column=2),
        Record(total=12, correct=3, incorrect=2, conflicts=1, test_column=3),
    ]
    expected = [
        Record(
            total=12, correct=2, incorrect=1, conflicts=1,
            correct_adjusted=6.0, incorrect_adjusted=3.0, conflicts_adjusted=3.0,
            test_column=1,
        ),
        Record(
            total=12, correct=1, incorrect=1, conflicts=0,
            correct_adjusted=6.0, incorrect_adjusted=6.0, conflicts_adjusted=0.0,
            test_column=2,
        ),
        Record(
            total=12, correct=3, incorrect=2, conflicts=1,
            correct_adjusted=6.0, incorrect_adjusted=4.0, conflicts_adjusted=2.0,
            test_column=3,
        ),
    ]

    adjusted = nile_ut.yql_run(
        _calc_adjusted_values,
        tasks_with_grades=nile_ut.Table(graded),
    )

    assert sorted(expected) == sorted(adjusted)


def test_recount_for_dimensions():
    """`_recount_for_dimensions` maps ids to tree names and aggregates by action.

    The user id, region id and task id are expected to be replaced by the
    names from `puid_map`, `major_regions_map` and the tariff dated on the
    action date (2021-07-29 -> 'task_tree13'). Counters are reported both per
    action ('\\tall\\taction1\\t', ...) and summed for the whole
    '\\tall\\t' node.
    """
    task_keys = Record(
        action_by=11,
        action_date='2021-07-29',
        region_id=1,
        task_id='task_id1',
    )
    expected_keys = Record(
        action_by_tree='user1',
        action_date='2021-07-29',
        region_name_tree='region_tree1',
        task_name_tree='task_tree13',
    )

    tasks_with_adjusted_grades = nile_ut.Table([
        task_keys.transform(
            action='action1',
            correct=6, incorrect=2, conflicts=1, total=18,
            correct_adjusted=12.0, incorrect_adjusted=4.0, conflicts_adjusted=2.0,
        ),
        task_keys.transform(
            action='action2',
            correct=2, incorrect=1, conflicts=0, total=9,
            correct_adjusted=6.0, incorrect_adjusted=3.0, conflicts_adjusted=0.0,
        ),
    ])
    puid_map = nile_ut.Table([
        Record(puid=UInt(11), puid_tree='user1'),
        Record(puid=UInt(12), puid_tree='user2'),
    ])
    tariffs_with_dates = nile_ut.Table([
        Record(tariff_date='2021-07-01', task_id='task_id1', task_name_tree='task_tree11', test_column1=9991),
        Record(tariff_date='2021-07-28', task_id='task_id1', task_name_tree='task_tree12', test_column1=9993),
        Record(tariff_date='2021-07-29', task_id='task_id1', task_name_tree='task_tree13', test_column1=9995),
    ])
    # NOTE(review): the tasks above carry an integer region_id (1) while this
    # map uses the string '1'; presumably the pipeline normalizes the type
    # before joining -- confirm.
    major_regions_map = nile_ut.Table([
        Record(region_id='1', region_name='region_tree1', test_column2=99991),
        Record(region_id='2', region_name='region_tree2', test_column2=99993),
    ])

    recounted = nile_ut.yql_run(
        _recount_for_dimensions,
        tasks_with_adjusted_grades=tasks_with_adjusted_grades,
        puid_map=puid_map,
        tariffs_with_dates=tariffs_with_dates,
        major_regions_map=major_regions_map,
    )

    expected = [
        expected_keys.transform(
            action_tree='\tall\t',
            total=27, correct=8, incorrect=3, conflicts=1,
            correct_adjusted=18.0, incorrect_adjusted=7.0, conflicts_adjusted=2.0,
        ),
        expected_keys.transform(
            action_tree='\tall\taction1\t',
            total=18, correct=6, incorrect=2, conflicts=1,
            correct_adjusted=12.0, incorrect_adjusted=4.0, conflicts_adjusted=2.0,
        ),
        expected_keys.transform(
            action_tree='\tall\taction2\t',
            total=9, correct=2, incorrect=1, conflicts=0,
            correct_adjusted=6.0, incorrect_adjusted=3.0, conflicts_adjusted=0.0,
        ),
    ]
    assert sorted(expected) == sorted(recounted)


def test_should_make_dataset_for_one_date():
    """End-to-end check of `make_dataset` on a single day of logs.

    NOTE: this is an integration test; it does not cover all cases.
    """
    def task_row(suffix, action, marker):
        # Field order matches the rows the pipeline is fed in production tests.
        return Record(
            action_at='2021-07-29' + suffix,
            region_id=1,
            task_id='task_id1',
            action=action,
            action_by=11,
            test_column1=marker,
        )

    def grade_row(suffix, action, value, marker):
        return Record(
            action_at='2021-07-29' + suffix,
            region_id=1,
            task_id='task_id1',
            action=action,
            action_by=11,
            value=value,
            test_column2=marker,
        )

    dataset = nile_ut.yql_run(
        make_dataset,
        job=nile_ut.Job(),
        tasks_log=nile_ut.Table([
            task_row('suffix1', 'action1', 93),
            task_row('suffix2', 'action1', 94),
            task_row('suffix3', 'action1', 95),
            task_row('suffix4', 'action2', 96),
        ]),
        graded_units_log=nile_ut.Table([
            grade_row('suffix5', 'action1', 'correct', 991),
            grade_row('suffix6', 'action1', 'incorrect', 992),
            grade_row('suffix7', 'action2', 'incorrect', 993),
        ]),
        puid_map=nile_ut.Table([
            Record(puid=UInt(11), puid_tree='user1'),
            Record(puid=UInt(12), puid_tree='user2'),
        ]),
        tariffs_with_dates=nile_ut.Table([
            Record(tariff_date='2021-07-01', task_id='task_id1', task_name_tree='task_tree11', test_column3=9991),
            Record(tariff_date='2021-07-28', task_id='task_id1', task_name_tree='task_tree12', test_column3=9993),
            Record(tariff_date='2021-07-29', task_id='task_id1', task_name_tree='task_tree13', test_column3=9995),
        ]),
        major_regions_map=nile_ut.Table([
            Record(region_id='1', region_name='region_tree1', test_column4=99991),
            Record(region_id='2', region_name='region_tree2', test_column4=99993),
        ]),
    )

    # No conflicting grades were logged, so conflicts stay zero everywhere.
    expected_base = Record(
        action_by_tree='user1',
        action_date='2021-07-29',
        region_name_tree='region_tree1',
        task_name_tree='task_tree13',
        conflicts=0,
        conflicts_adjusted=0.0,
    )
    expected = [
        expected_base.transform(
            action_tree='\tall\t',
            total=4, correct=1, correct_adjusted=1.5, incorrect=2, incorrect_adjusted=2.5,
        ),
        expected_base.transform(
            action_tree='\tall\taction1\t',
            total=3, correct=1, correct_adjusted=1.5, incorrect=1, incorrect_adjusted=1.5,
        ),
        expected_base.transform(
            action_tree='\tall\taction2\t',
            total=1, correct=0, correct_adjusted=0.0, incorrect=1, incorrect_adjusted=1.0,
        ),
    ]
    assert sorted(expected) == sorted(dataset)
