from .tree_dimensions import (
    graded_by_to_graded_by_tree,
    region_id_to_region_name_tree,
    task_id_to_task_name_tree,
)
from nile.api.v1 import (
    aggregators as na,
    extractors as ne,
    filters as nf,
)
from qb2.api.v1 import (
    extractors as qe,
    typing,
)

# Output schema applied by make_dataset(): the grouping keys used in
# _count_grades (tree dimensions, action, grade date) followed by the five
# per-group counters it aggregates.
DATASET_SCHEMA = dict(
    region_name_tree=typing.Unicode,
    task_name_tree=typing.Unicode,
    graded_by_tree=typing.Unicode,
    action=typing.Unicode,

    grade_date=typing.Unicode,

    true_correct=typing.Int64,
    true_incorrect=typing.Int64,
    false_correct=typing.Int64,
    false_incorrect=typing.Int64,
    conflicts=typing.Int64,
)


def _date_extractor(column):
    '''
    Build a nile extractor that keeps only the first len('YYYY-MM-DD')
    characters of `column`, typed as Unicode.

    NOTE(review): assumes the column holds ISO-like datetime strings
    (e.g. '2020-01-31T12:00:00'); a NULL value is not handled here.
    '''
    prefix_len = len('YYYY-MM-DD')

    def _cut_to_date(timestamp):
        # Plain prefix slice: '2020-01-31T12:00:00' -> '2020-01-31'.
        return timestamp[:prefix_len]

    extractor = ne.custom(_cut_to_date, column)
    return extractor.with_type(typing.Unicode)


def _get_basic_grades(grade_table):
    '''
    Keep only grades with qualification 'basic' and reduce them to the
    columns needed for comparison against expert grades.

    `value` is parsed from YSON and cast to Utf8 via a YQL expression;
    `grade_date` is the date prefix of `graded_at`.

    grade_table:
    | unit_id | graded_by | graded_at | value | qualification | ... |
    |---------+-----------+-----------+-------+---------------+-----|
    | ...     | ...       | ...       | ...   | ...           | ... |

    result:
    | unit_id | graded_by | grade_date | value |
    |---------+-----------+------------+-------|
    | ...     | ...       | ...        | ...   |
    '''
    yson_value_as_utf8 = 'CAST(Yson::ConvertToString(Yson::Parse($p0)) as Utf8)'

    basic_only = grade_table.filter(nf.equals('qualification', 'basic'))
    return basic_only.project(
        'unit_id',
        'graded_by',
        qe.yql_custom('value', yson_value_as_utf8, 'value'),
        grade_date=_date_extractor('graded_at'),
    )


def _get_units_graded_by_experts(graded_units_log):
    '''
    Keep only log records with certainty 'expert' and project the columns
    needed to compare basic grades against expert verdicts.

    The expert's `value` is renamed to `expert_value` so it does not clash
    with the basic grade's `value` after the join in _count_grades.
    `action_date` is the date prefix of `action_at`.

    graded_units_log:
    | unit_id | entity_domain | action_at | action | task_id | region_id | value | certainty | ... |
    |---------+---------------+-----------+--------+---------+-----------+-------+-----------+-----|
    | ...     | ...           | ...       | ...    | ...     | ...       | ...   | ...       | ... |

    result:
    | unit_id | entity_domain | action_date | action | task_id | region_id | expert_value |
    |---------+---------------+-------------+--------+---------+-----------+--------------|
    | ...     | ...           | ...         | ...    | ...     | ...       | value        |
    '''
    return graded_units_log.filter(
        nf.equals('certainty', 'expert'),
    ).project(
        'unit_id',
        'entity_domain',
        'task_id',
        'region_id',
        'action',
        expert_value='value',
        action_date=_date_extractor('action_at'),
    )


def _grade_counter(expert_value, value):
    '''
    Aggregator counting rows whose `expert_value` column equals
    `expert_value` AND whose `value` column equals `value`.
    '''
    both_match = nf.and_(
        nf.equals('expert_value', expert_value),
        nf.equals('value', value),
    )
    return na.count(both_match)


def _replace_missing_counts_with_zeroes(grade_counts, *counter_columns):
    '''
    Keep every column of `grade_counts`, rewriting each of `counter_columns`
    in place via qe.coalesce(column, column, 0) typed as Int64, so that a
    NULL counter becomes 0.

    grade_counts:
    | <counter_column> | ... |
    |------------------+-----|
    | int or NULL      | ... |

    result:
    | <counter_column> | ... |
    |------------------+-----|
    | int or 0         | ... |
    '''
    zero_filled = [
        qe.coalesce(name, name, 0).with_type(typing.Int64)
        for name in counter_columns
    ]
    return grade_counts.project(ne.all(), *zero_filled)


def _count_grades(basic_grades, units_graded_by_experts, puid_map, tariffs_with_dates, major_regions_map):
    '''
    Compare basic grades with expert verdicts for the same unit and count
    agreements/disagreements per (action, grade_date, graded_by_tree,
    task_name_tree, region_name_tree).

    Rows are joined on `unit_id` (nile's default join type; presumably
    inner — confirm). The tree dimensions are attached by the helpers from
    .tree_dimensions. Counters:
      true_correct    -- expert 'correct',   basic grade 'correct'
      true_incorrect  -- expert 'incorrect', basic grade 'incorrect'
      false_correct   -- expert 'incorrect', basic grade 'correct'
      false_incorrect -- expert 'correct',   basic grade 'incorrect'
      conflicts       -- expert 'conflict',  any basic grade
    Missing counters are replaced with 0.

    NOTE: grouping is by the basic grade's `grade_date`; the expert's
    `action_date` is not a grouping key and is dropped by the groupby.

    basic_grades:
    | unit_id | graded_by | grade_date | value |
    |---------+-----------+------------+-------|
    | ...     | ...       | ...        | ...   |

    units_graded_by_experts:
    | unit_id | entity_domain | action_date | action | task_id | region_id | expert_value |
    |---------+---------------+-------------+--------+---------+-----------+--------------|
    | ...     | ...           | ...         | ...    | ...     | ...       | ...          |

    puid_map:
    | puid | puid_tree | ... |
    |------+-----------+-----|
    | ...  | ...       | ... |

    tariffs_with_dates:
    | tariff_date | task_id | task_name_tree | ... |
    |-------------+---------+----------------+-----|
    | ...         | ...     | ...            | ... |

    major_regions_map:
    | region_id (string) | region_tree | ... |
    |--------------------+-------------+-----|
    | ...                | ...         | ... |

    result:
    | grade_date | action | graded_by_tree | task_name_tree | region_name_tree | true_correct | true_incorrect | false_correct | false_incorrect | conflicts |
    |------------+--------+----------------+----------------+------------------+--------------+----------------+---------------+-----------------+-----------|
    | ...        | ...    | ...            | ...            | ...              | ...          | ...            | ...           | ...             | ...       |
    '''
    return basic_grades.join(
        units_graded_by_experts,
        by='unit_id',
    ).call(
        graded_by_to_graded_by_tree, puid_map
    ).call(
        task_id_to_task_name_tree, tariffs_with_dates
    ).call(
        region_id_to_region_name_tree, major_regions_map
    ).groupby(
        'action',
        'grade_date',
        'graded_by_tree',
        'task_name_tree',
        'region_name_tree',
    ).aggregate(
        true_correct=_grade_counter(expert_value='correct', value='correct'),
        true_incorrect=_grade_counter(expert_value='incorrect', value='incorrect'),
        false_correct=_grade_counter(expert_value='incorrect', value='correct'),
        false_incorrect=_grade_counter(expert_value='correct', value='incorrect'),
        # Conflicts are counted regardless of the basic grade's value.
        conflicts=na.count(nf.equals('expert_value', 'conflict')),
    ).call(
        _replace_missing_counts_with_zeroes,
        'true_correct',
        'true_incorrect',
        'false_correct',
        'false_incorrect',
        'conflicts',
    )


def make_dataset(grade_table, graded_units_log, puid_map, tariffs_with_dates, major_regions_map):
    '''
    Build the grade-quality dataset: basic grades are matched against
    expert-graded units, counted per grouping key in _count_grades, and the
    result is cast to DATASET_SCHEMA.
    '''
    basic_grades = _get_basic_grades(grade_table)
    expert_units = _get_units_graded_by_experts(graded_units_log)
    grade_counts = _count_grades(
        basic_grades,
        expert_units,
        puid_map,
        tariffs_with_dates,
        major_regions_map,
    )
    return grade_counts.cast(**DATASET_SCHEMA)
