from maps.wikimap.stat.assessment.grades_dataset.lib.grades_dataset import (
    _get_basic_grades,
    _get_units_graded_by_experts,
    _replace_missing_counts_with_zeroes,
    _count_grades,
)
from maps.wikimap.stat.libs import nile_ut
from nile.api.v1 import Record


def test_should_get_basic_grades():
    """Only grades with 'basic' qualification are expected in the output.

    The expected row shows `graded_at` reduced to its leading date part
    (`grade_date`), the bytes `value` decoded to str, and the extra
    `qualification` / `test_column` fields dropped.
    """
    grades = nile_ut.Table([
        Record(unit_id=1, graded_by=11, graded_at='2021-07-01suffix', value=b'value1', qualification='basic',  test_column=91),
        Record(unit_id=2, graded_by=12, graded_at='2021-07-02suffix', value=b'value2', qualification='expert', test_column=92),
        Record(unit_id=3, graded_by=13, graded_at='2021-07-03suffix', value=b'value3', qualification=None,     test_column=93),
    ])
    expected = [
        Record(unit_id=1, graded_by=11, grade_date='2021-07-01', value='value1'),
    ]

    actual = nile_ut.yql_run(_get_basic_grades, grade_table=grades)

    assert sorted(expected) == sorted(actual)


def test_should_get_units_graded_by_experts():
    """Only log rows with certainty 'expert' are expected in the output.

    The expected row shows `action_at` reduced to its leading date part
    (`action_date`), `value` renamed to `expert_value`, and the `certainty`
    and extra `test_column` fields dropped.
    """
    def make_record(idx, certainty):
        # idx is embedded as the day digit of '2021-07-0{idx}', so only 1..9
        # yields a real calendar date.
        assert idx in range(1, 10)
        return Record(
            unit_id=idx,
            entity_domain=f'domain{idx}',
            action_at=f'2021-07-0{idx}suffix',
            action=f'action{idx}',
            task_id=f'task_id{idx}',
            region_id=10 + idx,
            value=f'value{idx}',
            certainty=certainty,
            # Extra column that is expected to be absent from the result.
            test_column=90 + idx,
        )

    result = nile_ut.yql_run(
        _get_units_graded_by_experts,
        graded_units_log=nile_ut.Table([
            make_record(1, certainty='expert'),
            make_record(2, certainty='any_other'),
        ]),
    )
    assert sorted([
        Record(
            unit_id=1,
            entity_domain='domain1',
            action_date='2021-07-01',
            action='action1',
            task_id='task_id1',
            region_id=11,
            expert_value='value1',
        ),
    ]) == sorted(result)


def test_should_replace_missing_counts_with_zeroes():
    """None becomes 0 only in the listed count columns.

    `cnt1` and `cnt2` are passed as the columns to fix; `cnt3` is not listed,
    so its None is expected to survive unchanged.
    """
    counts = nile_ut.Table([
        Record(cnt1=None, cnt2=0,    cnt3=10),
        Record(cnt1=10,   cnt2=None, cnt3=0),
        Record(cnt1=0,    cnt2=10,   cnt3=None),
    ])
    columns_to_fix = ('cnt1', 'cnt2')
    expected = [
        Record(cnt1=0,  cnt2=0,  cnt3=10),
        Record(cnt1=10, cnt2=0,  cnt3=0),
        Record(cnt1=0,  cnt2=10, cnt3=None),
    ]

    actual = nile_ut.yql_run(_replace_missing_counts_with_zeroes, counts, *columns_to_fix)

    assert sorted(expected) == sorted(actual)


def test_should_count_grades():
    """Basic grades are tallied against expert verdicts per unit.

    Fixture summary (unit -> expert verdict: basic grades):
      * unit 1 -> 'correct':   3 x 'correct', 1 x 'incorrect'
      * unit 2 -> 'incorrect': 1 x 'incorrect'
      * unit 3 -> 'conflict':  1 x 'incorrect', 1 x 'correct'

    The single expected row therefore holds:
      * true_correct=3 and false_incorrect=1 (from unit 1);
      * true_incorrect=1 (from unit 2);
      * conflicts=2 (from unit 3).
    Its dimension columns are resolved through puid_map, tariffs_with_dates
    and major_regions_map.
    """
    # A real calendar date (June has 30 days) is used for the expert action
    # date and the matching tariff date. This keeps the fixture valid even if
    # the dates are ever parsed as date values rather than compared as strings.
    expert_action_date = '2021-06-30'

    result = nile_ut.yql_run(
        _count_grades,
        basic_grades=nile_ut.Table([
            Record(unit_id=1, graded_by=11, grade_date='2021-07-01', value='correct'),
            Record(unit_id=1, graded_by=11, grade_date='2021-07-01', value='correct'),
            Record(unit_id=1, graded_by=11, grade_date='2021-07-01', value='correct'),
            Record(unit_id=1, graded_by=11, grade_date='2021-07-01', value='incorrect'),

            Record(unit_id=2, graded_by=11, grade_date='2021-07-01', value='incorrect'),

            Record(unit_id=3, graded_by=11, grade_date='2021-07-01', value='incorrect'),
            Record(unit_id=3, graded_by=11, grade_date='2021-07-01', value='correct'),
        ]),
        units_graded_by_experts=nile_ut.Table([
            Record(unit_id=1, action_date=expert_action_date, action='action1', task_id='task_id1', region_id=111, expert_value='correct'),
            Record(unit_id=2, action_date=expert_action_date, action='action1', task_id='task_id1', region_id=111, expert_value='incorrect'),
            Record(unit_id=3, action_date=expert_action_date, action='action1', task_id='task_id1', region_id=111, expert_value='conflict'),
        ]),
        puid_map=nile_ut.Table([
            Record(puid=11, puid_tree='puid_tree1'),
        ]),
        tariffs_with_dates=nile_ut.Table([
            Record(date=expert_action_date, task_id='task_id1', task_name_tree='task_name_tree1', test_column3=9991),
        ]),
        # NOTE(review): region_id is an int above but a str here. Presumably
        # _count_grades casts before joining; confirm.
        major_regions_map=nile_ut.Table([
            Record(region_id='111', region_name='region_name_tree1'),
        ])
    )

    assert sorted([
        Record(
            grade_date='2021-07-01',
            graded_by_tree='puid_tree1',
            task_name_tree='task_name_tree1',
            region_name_tree='region_name_tree1',
            action='action1',
            true_correct=3,
            false_correct=0,
            true_incorrect=1,
            false_incorrect=1,
            conflicts=2,
        ),
    ]) == sorted(result)
