from maps.wikimap.stat.assessment.datasets.fixed_units.lib.dataset import (
    _get_incorrect_units,
    _add_puid_info,
    _count_unit_fixes,
)
from maps.wikimap.stat.libs import nile_ut
from nile.api.v1 import Record
from qb2.api.v1 import typing


def test_should_get_incorrect_units_by_basic_grades():
    """Without an expert grade, only units with more basic 'incorrect' than 'correct' grades are returned.

    Unit N lives in ``domainN`` and was acted on by puid N on 2022-07-02.
    The expected output excludes the unit with fewer incorrect grades (domain4),
    the all-zero unit (domain5) and the tie (domain6).
    """
    # (basic_incorrect, basic_correct, fixed_at) for unit N, numbered from 1.
    grades = [
        (2, 1, None),
        (1, 0, '2022-07-03suffix2'),
        (2, 0, '2022-07-02suffix3'),
        (1, 2, None),
        (0, 0, '2022-07-03suffix5'),
        (1, 1, '2022-07-02suffix6'),
    ]
    units = [
        Record(
            basic_incorrect=incorrect, basic_correct=correct, last_expert_value=None, fixed_at=fixed_at,
            entity_domain=f'domain{n}', action_by=n, action_at=f'2022-07-02suffix{n}',
            test_column=f'test{n}'
        )
        for n, (incorrect, correct, fixed_at) in enumerate(grades, start=1)
    ]
    # An explicit schema is required: last_expert_value is None everywhere,
    # so its type could not be inferred from the data alone.
    schema = {
        'basic_incorrect':   typing.Int64,
        'basic_correct':     typing.Int64,
        'last_expert_value': typing.Optional[typing.Yson],
        'fixed_at':          typing.Optional[typing.Unicode],
        'entity_domain':     typing.Unicode,
        'action_by':         typing.Int64,
        'action_at':         typing.Unicode,
        'test_column':       typing.Unicode,
    }
    expected = [
        Record(fixed_at=None,                entity_domain='domain1', puid=1, date='2022-07-02'),
        Record(fixed_at='2022-07-03suffix2', entity_domain='domain2', puid=2, date='2022-07-02'),
        Record(fixed_at='2022-07-02suffix3', entity_domain='domain3', puid=3, date='2022-07-02'),
    ]

    result = nile_ut.yql_run(
        _get_incorrect_units,
        unit_table=nile_ut.Table(data=units, schema=schema),
        min_date='2022-07-02',
        max_date='2022-07-02',
    )

    assert sorted(result) == sorted(expected)


def test_should_get_incorrect_units_by_last_expert_grade():
    """A present expert grade decides the verdict regardless of the basic grade counters.

    Units 1-3 are graded 'incorrect' by an expert and are returned even when basic
    grades lean 'correct'. Units 4-5 are graded 'correct' by an expert and are
    dropped even though their basic grades lean 'incorrect'.
    """
    # (basic_incorrect, basic_correct, last_expert_value, fixed_at) for unit N, numbered from 1.
    grades = [
        (1, 2, b'incorrect', None),
        (0, 1, b'incorrect', '2022-07-03suffix2'),
        (0, 0, b'incorrect', '2022-07-02suffix3'),
        (2, 1, b'correct',   None),
        (1, 0, b'correct',   '2022-07-03suffix5'),
    ]
    units = [
        Record(
            basic_incorrect=incorrect, basic_correct=correct, last_expert_value=expert, fixed_at=fixed_at,
            entity_domain=f'domain{n}', action_by=n, action_at=f'2022-07-02suffix{n}',
            test_column=f'test{n}'
        )
        for n, (incorrect, correct, expert, fixed_at) in enumerate(grades, start=1)
    ]
    schema = {
        'basic_incorrect':   typing.Int64,
        'basic_correct':     typing.Int64,
        'last_expert_value': typing.Optional[typing.Yson],
        'fixed_at':          typing.Optional[typing.Unicode],
        'entity_domain':     typing.Unicode,
        'action_by':         typing.Int64,
        'action_at':         typing.Unicode,
        'test_column':       typing.Unicode,
    }
    expected = [
        Record(fixed_at=None,                entity_domain='domain1', puid=1, date='2022-07-02'),
        Record(fixed_at='2022-07-03suffix2', entity_domain='domain2', puid=2, date='2022-07-02'),
        Record(fixed_at='2022-07-02suffix3', entity_domain='domain3', puid=3, date='2022-07-02'),
    ]

    result = nile_ut.yql_run(
        _get_incorrect_units,
        unit_table=nile_ut.Table(data=units, schema=schema),
        min_date='2022-07-02',
        max_date='2022-07-02',
    )

    assert sorted(result) == sorted(expected)


def test_should_get_incorrect_units_by_date_range():
    """Only units whose action_at date lies in [min_date, max_date] (inclusive) are returned.

    Every unit is graded 'incorrect' by an expert, so only the date filter varies.
    The output ``date`` is the date prefix of ``action_at``. ``fixed_at`` is
    passed through unchanged, even when it precedes the action date (domain3).
    """
    # (action date, fixed_at) for unit N, numbered from 1.
    rows = [
        ('2022-07-01', None),
        ('2022-07-02', '2022-07-03suffix2'),
        ('2022-07-03', '2022-07-02suffix3'),
        ('2022-07-04', None),
        ('2022-07-05', '2022-07-03suffix5'),
    ]
    units = [
        Record(
            basic_incorrect=0, basic_correct=0, last_expert_value=b'incorrect', fixed_at=fixed_at,
            entity_domain=f'domain{n}', action_by=n, action_at=f'{action_date}suffix{n}',
            test_column=f'test{n}'
        )
        for n, (action_date, fixed_at) in enumerate(rows, start=1)
    ]
    # Units 1 and 5 fall just outside the [2022-07-02, 2022-07-04] window.
    expected = [
        Record(fixed_at='2022-07-03suffix2', entity_domain='domain2', puid=2, date='2022-07-02'),
        Record(fixed_at='2022-07-02suffix3', entity_domain='domain3', puid=3, date='2022-07-03'),
        Record(fixed_at=None,                entity_domain='domain4', puid=4, date='2022-07-04'),
    ]

    result = nile_ut.yql_run(
        _get_incorrect_units,
        unit_table=nile_ut.Table(units),
        min_date='2022-07-02',
        max_date='2022-07-04',
    )

    assert sorted(result) == sorted(expected)


def test_should_add_puid_info():
    """Units are enriched with puid_info rows that match on both puid and date.

    The expected results show three cases:
    - puid 1 matches its 2022-07-02 info row.
    - puid 2 has info only for 2022-07-03, so it falls back to the 'common user'
      defaults.
    - ``staff_puid`` has no info and is reported as 'staff' with unknown
      payment/group.

    The ``puid`` column is not part of the expected output.
    """
    staff_puid = 1100000000000001

    units = nile_ut.Table([
        Record(puid=1,          date='2022-07-02', test_column='test1'),
        Record(puid=2,          date='2022-07-02', test_column='test2'),
        Record(puid=staff_puid, date='2022-07-02', test_column='test3'),
    ])
    puid_info = nile_ut.Table([
        Record(date='2022-07-02', puid=1, involvement='inv1', payment='pay1', group='grp1', person='usr1'),
        Record(date='2022-07-03', puid=2, involvement='inv2', payment='pay2', group='grp2', person='usr2'),
        Record(date='2022-07-03', puid=1, involvement='inv3', payment='pay3', group='grp3', person='usr3'),
    ])
    expected = [
        Record(date='2022-07-02', involvement='inv1',  payment='pay1',    group='grp1',    person='usr1',                    test_column='test1'),
        Record(date='2022-07-02', involvement='user',  payment='free',    group='common',  person='common user',             test_column='test2'),
        Record(date='2022-07-02', involvement='staff', payment='unknown', group='unknown', person=f'unknown ({staff_puid})', test_column='test3'),
    ]

    result = nile_ut.yql_run(_add_puid_info, units=units, puid_info=puid_info)

    assert sorted(result) == sorted(expected)


def test_should_count_unit_fixes():
    """Units are grouped by (entity_domain, date, payment, involvement, group, person).

    ``fixed_count`` counts the rows in each group whose ``fixed_at`` is set.
    ``total_count`` counts all rows in the group.
    """
    def unit(date, fixed_at):
        # Keyword order mirrors the table's column order.
        return Record(entity_domain='domain1', date=date, payment='pay1', involvement='inv1', group='grp1', person='usr1', fixed_at=fixed_at)

    units = nile_ut.Table([
        # date1: three units, two of them fixed.
        unit('date1', None),
        unit('date1', 'fx1'),
        unit('date1', 'fx2'),
        # date2: a single unfixed unit.
        unit('date2', None),
    ])
    expected = [
        Record(entity_domain='domain1', date='date1', payment='pay1', involvement='inv1', group='grp1', person='usr1', fixed_count=2, total_count=3),
        Record(entity_domain='domain1', date='date2', payment='pay1', involvement='inv1', group='grp1', person='usr1', fixed_count=0, total_count=1),
    ]

    result = nile_ut.yql_run(_count_unit_fixes, units)

    assert sorted(result) == sorted(expected)
