from .tree_dimensions import (
    action_by_to_action_by_tree,
    action_to_action_tree,
    region_id_to_region_name_tree,
    task_id_to_task_name_tree,
)
from nile.api.v1 import (
    aggregators as na,
    extractors as ne,
)
from qb2.api.v1 import (
    filters as qf,
    typing,
)


# Column name -> type mapping; `make_dataset` passes it to the final `cast`.
DATASET_SCHEMA = dict(
    # Columns the final dataset is grouped by.
    region_name_tree=typing.Unicode,
    task_name_tree=typing.Unicode,
    action_by_tree=typing.Unicode,
    action_tree=typing.Unicode,
    action_date=typing.Unicode,

    # Summed counters.
    correct=typing.Int64,
    incorrect=typing.Int64,
    conflicts=typing.Int64,
    total=typing.Int64,

    # Summed counters divided by the graded / total ratio.
    correct_adjusted=float,
    incorrect_adjusted=float,
    conflicts_adjusted=float
)


def _date_extractor(column):
    '''
    Build a Unicode-typed extractor that keeps only the leading
    len('YYYY-MM-DD') characters of the value stored in `column`.
    '''
    prefix_length = len('YYYY-MM-DD')

    def _take_date_prefix(value):
        return value[:prefix_length]

    extractor = ne.custom(_take_date_prefix, column)
    return extractor.with_type(typing.Unicode)


def _add(a, b, c):
    '''
    Build an extractor that adds up the values of fields `a`, `b` and `c`.
    '''
    def _sum_of_three(first, second, third):
        return first + second + third

    return ne.custom(_sum_of_three, a, b, c)


def _div(a, b):
    '''
    Build a float-typed extractor that divides field `a` by field `b`.

    The extracted value is 0.0 when `a` is falsy (zero or missing) or when
    `b` is missing or not strictly positive, so a row with nothing to divide
    by never raises. Otherwise it is the true (non-truncating) quotient.
    '''
    def _safe_ratio(numerator, denominator):
        # Check `denominator` for truthiness before comparing it with 0:
        # comparing None with an int raises TypeError on Python 3.
        if numerator and denominator and denominator > 0:
            # float() forces true division; both operands may be integer
            # counters, and int / int truncates on Python 2.
            return float(numerator) / denominator
        # Return a float, matching the declared column type.
        return 0.0

    return ne.custom(_safe_ratio, a, b).with_type(float)


def _count_tasks(tasks_log):
    '''
    Replace `action_at` with `action_date` and count the records of every
    (action_date, region_id, task_id, action, action_by) group.

    tasks_log:
    | action_at | region_id | task_id | action | action_by | ... |
    |-----------+-----------+---------+--------+-----------+-----|
    | ...       | ...       | ...     | ...    | ...       | ... |

    result:
    | action_date | region_id | task_id | action_by | action | total   |
    |-------------+-----------+---------+-----------+--------+---------|
    | ...         | ...       | ...     | ...       | ...    | count() |
    '''
    group_keys = (
        'action_date',
        'region_id',
        'task_id',
        'action',
        'action_by',
    )

    # `action_at` is dropped; its date prefix is kept as `action_date`.
    dated_log = tasks_log.project(
        ne.all(exclude='action_at'),
        action_date=_date_extractor('action_at'),
    )
    grouped = dated_log.groupby(*group_keys)
    return grouped.aggregate(total=na.count())


def _count_graded_units(graded_units_log):
    '''
    Replace `action_at` with `action_date` and, for every
    (action_date, region_id, task_id, action, action_by) group, count the
    records separately for each grade stored in `value`.

    graded_units_log:
    | action_at | region_id | task_id | action | action_by | value | ... |
    |-----------+-----------+---------+--------+-----------+-------+-----|
    | ...       | ...       | ...     | ...    | ...       | ...   | ... |

    result:
    | action_date | region_id | task_id | action_by | action | correct              | incorrect              | conflicts             |
    |-------------+-----------+---------+-----------+--------+----------------------+------------------------+-----------------------|
    | ...         | ...       | ...     | ...       | ...    | count(value=correct) | count(value=incorrect) | count(value=conflict) |
    '''
    group_keys = (
        'action_date',
        'region_id',
        'task_id',
        'action',
        'action_by',
    )
    # (output column, grade in `value`); note the plural column name
    # `conflicts` counts records whose grade is the singular `conflict`.
    grade_columns = (
        ('correct', 'correct'),
        ('incorrect', 'incorrect'),
        ('conflicts', 'conflict'),
    )
    counters = {
        column: na.count(predicate=qf.equals('value', grade))
        for column, grade in grade_columns
    }

    dated_log = graded_units_log.project(
        ne.all(exclude='action_at'),
        action_date=_date_extractor('action_at'),
    )
    return dated_log.groupby(*group_keys).aggregate(**counters)


def _add_grades_to_tasks(job, task_counts, graded_unit_counts):
    '''
    Concatenate both count tables and sum their counters per
    (task_id, region_id, action_date, action_by, action) group, so that
    each group ends up with `total` next to its grade counters.

    task_counts:
    | action_date | region_id | task_id | action_by | action | total |
    |-------------+-----------+---------+-----------+--------+-------|
    | ...         | ...       | ...     | ...       | ...    | ...   |

    graded_unit_counts:
    | action_date | region_id | task_id | action_by | action | correct | incorrect | conflicts | ... |
    |-------------+-----------+---------+-----------+--------+---------+-----------+-----------+-----|
    | ...         | ...       | ...     | ...       | ...    | ...     | ...       | ...       | ... |

    result:
    | action_date | region_id | task_id | action_by | action | total | correct | incorrect | conflicts |
    |-------------+-----------+---------+-----------+--------+-------+---------+-----------+-----------|
    | ...         | ...       | ...     | ...       | ...    | ...   | ...     | ...       | ...       |
    '''
    group_keys = (
        'task_id',
        'region_id',
        'action_date',
        'action_by',
        'action',
    )
    counter_names = ('correct', 'incorrect', 'conflicts', 'total')

    # NOTE(review): the second positional argument of na.sum is presumably
    # the value used when the column is absent from a record (each input
    # table lacks some of the counters) -- confirm against the nile API.
    sums = {name: na.sum(name, 0) for name in counter_names}

    merged = job.concat(task_counts, graded_unit_counts)
    return merged.groupby(*group_keys).aggregate(**sums)


def _calc_adjusted_values(tasks_with_grades):
    '''
    Append an `*_adjusted` column for each grade counter: the counter
    divided by the share of graded records (`_graded / total`).

    tasks_with_grades:
    | total | correct | incorrect | conflicts | ... |
    |-------+---------+-----------+-----------+-----|
    | ...   | ...     | ...       | ...       | ... |

    result:
    | total | correct | incorrect | conflicts | correct_adjusted  | incorrect_adjusted | conflicts_adjusted | ... |
    |-------+---------+-----------+-----------+-------------------+--------------------+--------------------+-----|
    | ...   | ...     | ...       | ...       | correct / _ratio  | incorrect / _ratio | conflicts / _ratio | ... |
    where
    _ratio = _graded / total,
    _graded = correct + incorrect + conflicts.
    '''
    grade_counters = ('correct', 'incorrect', 'conflicts')
    adjusted_columns = {
        counter + '_adjusted': _div(counter, '_ratio')
        for counter in grade_counters
    }

    # `_graded` and `_ratio` are intermediate values marked with .hide().
    return tasks_with_grades.project(
        ne.all(),
        _graded=_add(*grade_counters).hide(),
        _ratio=_div('_graded', 'total').hide(),
        **adjusted_columns
    )


def _recount_for_dimensions(tasks_with_adjusted_grades, puid_map, tariffs_with_dates, major_regions_map):
    '''
    Apply the four `*_tree` dimension transformations to the stream and
    re-sum every counter per
    (action_tree, action_by_tree, action_date, region_name_tree, task_name_tree).

    tasks_with_adjusted_grades:
    | action_date | region_id | task_id | action_by | action | total | correct | incorrect | conflicts | correct_adjusted | incorrect_adjusted | conflicts_adjusted |
    |-------------+-----------+---------+-----------+--------+-------+---------+-----------+-----------+------------------+--------------------+--------------------|
    | ...         | ...       | ...     | ...       | ...    | ...   | ...     | ...       | ...       | ...              | ...                | ...                |

    puid_map:
    | puid | puid_tree | ... |
    |------+-----------+-----|
    | ...  | ...       | ... |

    tariffs_with_dates:
    | tariff_date | task_id | task_name_tree | ... |
    |-------------+---------+----------------+-----|
    | ...         | ...     | ...            | ... |

    major_regions_map:
    | region_id | region_name_tree | ... |
    |-----------+------------------+-----|
    | ...       | ...              | ... |

    result:
    | task_name_tree | region_name_tree | action_date | action_by_tree | action_tree |total|correct|incorrect|conflicts|correct_adjusted|incorrect_adjusted|conflicts_adjusted|
    |----------------+------------------+-------------+----------------+-------------+-----+-------+---------+---------+----------------+------------------+------------------|
    | ...            | ...              | ...         | ...            | ...         | ... | ...   | ...     | ...     | ...            | ...              | ...              |
    '''
    dimension_keys = (
        'action_tree',
        'action_by_tree',
        'action_date',
        'region_name_tree',
        'task_name_tree',
    )
    measure_names = (
        'total',
        'correct',
        'incorrect',
        'conflicts',
        'correct_adjusted',
        'incorrect_adjusted',
        'conflicts_adjusted',
    )
    sums = {name: na.sum(name) for name in measure_names}

    # The transformations are applied in the same order as before:
    # action_by, task_id, region_id, action.
    with_action_by_tree = tasks_with_adjusted_grades.call(
        action_by_to_action_by_tree, puid_map
    )
    with_task_name_tree = with_action_by_tree.call(
        task_id_to_task_name_tree, tariffs_with_dates
    )
    with_region_name_tree = with_task_name_tree.call(
        region_id_to_region_name_tree, major_regions_map
    )
    with_all_trees = with_region_name_tree.call(action_to_action_tree)

    return with_all_trees.groupby(*dimension_keys).aggregate(**sums)


def make_dataset(job, tasks_log, graded_units_log, puid_map, tariffs_with_dates, major_regions_map):
    '''
    Assemble the final dataset: count tasks and graded units, merge the
    counts, add the adjusted values, re-aggregate by the `*_tree`
    dimensions and cast the result to DATASET_SCHEMA.

    tasks_log:
    | action_at | region_id | task_id | action | action_by | ... |
    |-----------+-----------+---------+--------+-----------+-----|
    | ...       | ...       | ...     | ...    | ...       | ... |

    graded_units_log:
    | action_at | region_id | task_id | action | action_by | value | ... |
    |-----------+-----------+---------+--------+-----------+-------+-----|
    | ...       | ...       | ...     | ...    | ...       | ...   | ... |

    puid_map:
    | puid | puid_tree | ... |
    |------+-----------+-----|
    | ...  | ...       | ... |

    tariffs_with_dates:
    | tariff_date | task_id | task_name_tree | ... |
    |-------------+---------+----------------+-----|
    | ...         | ...     | ...            | ... |

    major_regions_map:
    | region_id | region_name_tree | ... |
    |-----------+------------------+-----|
    | ...       | ...              | ... |

    result:
    | task_name_tree | region_name_tree | action_date | action_by_tree | action_tree | correct | incorrect | conflicts | total | correct_adjusted | incorrect_adjusted | conflicts_adjusted |
    |----------------+------------------+-------------+----------------+-------------+---------+-----------+-----------+-------+------------------+--------------------+--------------------|
    | ...            | ...              | ...         | ...            | ...         | ...     | ...       | ...       | ...   | ...              | ...                | ...                |
    '''
    task_counts = _count_tasks(tasks_log)
    graded_unit_counts = _count_graded_units(graded_units_log)

    tasks_with_grades = _add_grades_to_tasks(
        job=job,
        task_counts=task_counts,
        graded_unit_counts=graded_unit_counts,
    )
    tasks_with_adjusted_grades = tasks_with_grades.call(_calc_adjusted_values)
    by_dimensions = tasks_with_adjusted_grades.call(
        _recount_for_dimensions,
        puid_map=puid_map,
        tariffs_with_dates=tariffs_with_dates,
        major_regions_map=major_regions_map,
    )

    return by_dimensions.cast(**DATASET_SCHEMA)
