from maps.wikimap.stat.libs.common.lib.geobase_region import (
    EARTH_REGION_ID,
    FILES as GEOBASE_FILES,
    GEOBASE_JOB_MEMORY_LIMIT,
    geobase_region_id
)
from maps.wikimap.stat.tasks_payment.dictionaries.tariff.schema import TARIFF_RUB_PER_SEC

from nile.api.v1 import aggregators, extractors, filters

import re


_DEFAULT_BASIC_UNIT_TO_RUB = 1.0
_DEFAULT_REGION_TREE = b'\t10000\t'
_DEFAULT_SCHEDULE = b'unknown'
_EMPLOYEE_ACL_START_DATE = b'2020-04-27'
_FULL_ACCESS_GROUP = b'full_access_group'
_HOLIDAY_PAY_RATE_FACTOR = 2
_SCHEDULE_TREE_ROOT = b'all'
_MRC_REGION_PATTERN = b'\t__MRC_REGION__\t'
_TASK_ID_CURRENT_EMPLOYEE = b'service/current-employee'


def _is_puid_tree_contain_staff_login(puid_tree, staff_login):
    return re.search(b'\\(' + staff_login + b'\\)\t', puid_tree) is not None


def _extract_value_or_default(column_name, default_value):
    '''
    Returns a nile extractor that reads `column_name` and substitutes
    `default_value` when the value is None.
    '''
    def value_or_default(value):
        if value is None:
            return default_value
        return value

    return extractors.custom(value_or_default, column_name)


def _is_employee_current(date, quit_at):
    if not quit_at:
        return True
    return date < quit_at


def _make_current_employees_log(date, staff_dump, linked_accounts):
    '''
    Builds a synthetic log with one zero-quantity `_TASK_ID_CURRENT_EMPLOYEE`
    record per current employee, keyed by the employee's primary puid.

    An employee is current when `_is_employee_current(date, quit_at)` holds.
    Employees without a primary linked account are dropped by the inner join.

    staff_dump:
    | login | quit_at | ... |
    |-------+---------+-----|
    | ...   | ...     | ... |

    linked_accounts:
    | staff_login | puid | is_primary_link | ... |
    |-------------+------+-----------------+-----|
    | ...         | ...  | ...             | ... |

    Result:
    | iso_datetime | puid | task_id                   | mrc_region | quantity | geom | lat_min | lon_min | lat_max | lon_max |
    |--------------+------+---------------------------+------------+----------+------+---------+---------+---------+---------|
    | date         | ...  | _TASK_ID_CURRENT_EMPLOYEE | None       |      0.0 | None | None    | None    | None    | None    |
    '''
    def is_current(quit_at):
        return _is_employee_current(date, quit_at)

    current_staff = staff_dump.filter(
        filters.custom(is_current, 'quit_at')
    ).project(
        staff_login='login'
    )

    primary_accounts = linked_accounts.filter(
        filters.equals('is_primary_link', True)
    ).project(
        'staff_login', 'puid'
    )

    current_primary_accounts = current_staff.join(
        primary_accounts,
        by='staff_login',
        type='inner',
        assume_small=True,
        assume_unique=True,
    )

    # Same column set as the real task logs so the tables can be concatenated.
    return current_primary_accounts.project(
        iso_datetime=extractors.const(date),
        puid='puid',
        task_id=extractors.const(_TASK_ID_CURRENT_EMPLOYEE),
        mrc_region=extractors.const(None),
        quantity=extractors.const(0.0),
        geom=extractors.const(None),
        lat_min=extractors.const(None),
        lon_min=extractors.const(None),
        lat_max=extractors.const(None),
        lon_max=extractors.const(None),
    )


def _concat_logs(job, date, *logs):
    '''
    Concatenates all task logs into one table.

    The `geom` column is dropped because later steps do not use it. The
    `iso_datetime` column is replaced by a constant `fielddate` column equal to
    `date`.

    logs:
    | iso_datetime | puid | task_id | mrc_region | quantity | geom | lat_min | lon_min | lat_max | lon_max |
    |--------------+------+---------+------------+----------+------+---------+---------+---------+---------|
    |          ... |  ... |     ... |        ... |      ... |  ... |     ... |     ... |     ... |     ... |

    Result:
    | fielddate | puid | task_id | mrc_region | quantity | lat_min | lon_min | lat_max | lon_max |
    |-----------+------+---------+------------+----------+---------+---------+---------+---------|
    |       ... |  ... |     ... |        ... |      ... |     ... |     ... |     ... |     ... |
    '''
    merged = job.concat(*logs)
    return merged.project(
        extractors.all(exclude=['geom', 'iso_datetime']),
        fielddate=extractors.const(date),
    )


def _bbox_to_region_id(log):
    '''
    Replaces the bbox columns with a `region_id` column.

    `region_id` is the geobase region of the bbox centre. When any bbox
    coordinate is None, `EARTH_REGION_ID` is used instead.

    log:
    | fielddate | puid | task_id | mrc_region | quantity | lat_min | lon_min | lat_max | lon_max |
    |-----------+------+---------+------------+----------+---------+---------+---------+---------|
    |       ... |  ... |     ... |        ... |      ... |     ... |     ... |     ... |     ... |

    Result:
    | fielddate | puid | task_id | mrc_region | quantity | region_id |
    |-----------+------+---------+------------+----------+-----------|
    |       ... |  ... |     ... |        ... |      ... |       ... |
    '''
    def get_region_id(lat_min, lon_min, lat_max, lon_max):
        if any(coord is None for coord in (lat_min, lon_min, lat_max, lon_max)):
            return EARTH_REGION_ID

        center_lat = (lat_min + lat_max) / 2.0
        center_lon = (lon_min + lon_max) / 2.0
        # geobase_region_id takes longitude first.
        return geobase_region_id(center_lon, center_lat)

    bbox_columns = ['lat_min', 'lon_min', 'lat_max', 'lon_max']
    return log.project(
        extractors.all(exclude=bbox_columns),
        region_id=extractors.custom(get_region_id, *bbox_columns),
        files=GEOBASE_FILES,
        memory_limit=GEOBASE_JOB_MEMORY_LIMIT,
    )


def _add_staff_login(log, linked_accounts):
    '''
    Attaches `staff_login` to each record by a left join on `puid`. Records
    whose puid is not linked to a staff account get no staff login.

    log:
    | fielddate | puid | task_id | mrc_region | quantity | region_id |
    |-----------+------+---------+------------+----------+-----------|
    |       ... |  ... |     ... |        ... |      ... |       ... |

    linked_accounts (only the columns relevant here are shown):
    | puid | login | staff_uid | staff_login |
    |------+-------+-----------+-------------|
    |  ... |   ... |       ... |         ... |

    return:
    | fielddate | puid | task_id | mrc_region | quantity | region_id | staff_login |
    |-----------+------+---------+------------+----------+-----------+-------------|
    | ...       | ...  | ...     | ...        | ...      | ...       | ...         |
    '''
    output_columns = (
        'fielddate', 'puid', 'task_id', 'mrc_region', 'quantity', 'region_id', 'staff_login'
    )
    joined = log.join(linked_accounts, type='left', by='puid')
    return joined.project(*output_columns)


def _add_user_holidays(date, log, user_holidays):
    '''
    Adds a `holiday` column that is True for records of users who have a
    holiday on `date` and None for everyone else.

    log:
    | fielddate | puid | task_id | mrc_region | quantity | region_id | staff_login |
    |-----------+------+---------+------------+----------+-----------+-------------|
    | ...       | ...  | ...     | ...        | ...      | ...       | ...         |

    user_holidays:
    | staff_login | date |
    |-------------+------|
    | ...         | ...  |

    result:
    | fielddate | puid | task_id | mrc_region | quantity | region_id | staff_login | holiday   |
    |-----------+------+---------+------------+----------+-----------+-------------+-----------|
    | ...       | ...  | ...     | ...        | ...      | ...       | ...         | True/None |
    '''
    def holiday_flag(holiday_date):
        # After filtering by `date`, a match on the right side means the user
        # is on holiday; unmatched records have no holiday date.
        return None if holiday_date is None else True

    holidays_on_date = user_holidays.filter(filters.equals('date', date))

    joined = log.join(
        holidays_on_date,
        by='staff_login',
        type='left',
        assume_small_right=True,
        assume_unique_right=True,
    )
    return joined.project(
        extractors.all(exclude=['date']),
        holiday=extractors.custom(holiday_flag, 'date'),
    )


def _add_basic_units(log, basic_units):
    '''
    Adds each user's `basic_unit_to_rub` rate by a left join on `staff_login`.
    Missing or None rates are replaced by `_DEFAULT_BASIC_UNIT_TO_RUB`.

    log:
    | fielddate | puid | task_id | mrc_region | quantity | region_id | staff_login | holiday |
    |-----------+------+---------+------------+----------+-----------+-------------+---------|
    | ...       | ...  | ...     | ...        | ...      | ...       | ...         | ...     |

    basic_units:
    | staff_login | basic_unit_to_rub |
    |-------------+-------------------|
    | ...         | ...               |

    result:
    | fielddate | puid | task_id | mrc_region | quantity | region_id | staff_login | holiday | basic_unit_to_rub |
    |-----------+------+---------+------------+----------+-----------+-------------+---------+-------------------|
    | ...       | ...  | ...     | ...        | ...      | ...       | ...         | ...     | ...               |
    '''
    joined = log.join(
        basic_units,
        by='staff_login',
        type='left',
        assume_small_right=True,
        assume_unique_right=True,
    )
    rate_extractor = _extract_value_or_default('basic_unit_to_rub', _DEFAULT_BASIC_UNIT_TO_RUB)
    return joined.project(
        extractors.all(exclude=['basic_unit_to_rub']),
        basic_unit_to_rub=rate_extractor,
    )


def _add_work_schedule(log, work_schedule):
    '''
    Joins each record with the user's work schedule and replaces the
    `schedule` column by `schedule_tree`.

    Every input record is emitted twice:
      * with schedule_tree = <tab>all<tab> (the tree root);
      * with schedule_tree = <tab>all<tab><schedule><tab>.

    A schedule that is absent after the left join, or present but None, is
    replaced by `_DEFAULT_SCHEDULE`. A None value would otherwise make
    `_tree_name` fail on bytes concatenation.

    log:
    | fielddate | puid | task_id | mrc_region | quantity | region_id | staff_login | holiday | basic_unit_to_rub |
    |-----------+------+---------+------------+----------+-----------+-------------+---------+-------------------|
    | ...       | ...  | ...     | ...        | ...      | ...       | ...         | ...     | ...               |

    work_schedule:
    | staff_login | schedule |
    |-------------+----------|
    | ...         | ...      |

    result:
    | fielddate | puid | task_id | mrc_region | quantity | region_id | staff_login | holiday | basic_unit_to_rub | schedule_tree                 |
    |-----------+------+---------+------------+----------+-----------+-------------+---------+-------------------+-------------------------------|
    | ...       | ...  | ...     | ...        | ...      | ...       | ...         | ...     | ...               | root                          |
    | ...       | ...  | ...     | ...        | ...      | ...       | ...         | ...     | ...               | root/schedule or default      |
    '''
    # Same for every record, so compute it once.
    root_tree = _tree_name([_SCHEDULE_TREE_ROOT])

    def schedule_to_tree(records):
        for record in records:
            schedule = record.get('schedule')
            if schedule is None:
                schedule = _DEFAULT_SCHEDULE
            yield record.transform('schedule', schedule_tree=root_tree)
            yield record.transform('schedule', schedule_tree=_tree_name([_SCHEDULE_TREE_ROOT, schedule]))

    return log.join(
        work_schedule,
        type='left',
        by='staff_login'
    ).map(
        schedule_to_tree
    )


def _region_id_to_region_name_tree(log, major_regions_map):
    '''
    Replaces `region_id` with `region_name_tree`, the region name in tree
    format. Region ids absent from `major_regions_map` get
    `_DEFAULT_REGION_TREE`.

    `major_regions_map` stores region ids as strings, so they are converted to
    int before joining with the log's integer `region_id`.

    log:
    | fielddate | puid | task_id | mrc_region | quantity | region_id (int) | staff_login | holiday | basic_unit_to_rub | schedule_tree |
    |-----------+------+---------+------------+----------+-----------------+-------------+---------+-------------------+---------------|
    | ...       | ...  | ...     | ...        | ...      | ...             | ...         | ...     | ...               | ...           |

    major_regions_map:
    | region_id (string) | region_tree | ... |
    |--------------------+-------------+-----|
    |                ... |         ... | ... |

    Note. A region_id can have several region names (one entry per tree
    level). The join therefore emits one output record per level.

    Result:
    | fielddate | puid | task_id | mrc_region | region_name_tree | quantity | staff_login | holiday | basic_unit_to_rub | schedule_tree |
    |-----------+------+---------+------------+------------------+----------+-------------+---------+-------------------+---------------|
    | ...       | ...  | ...     | ...        | ...              | ...      | ...         | ...     | ...               | ...           |
    '''
    def to_int(region_id):
        return int(region_id)

    region_trees = major_regions_map.project(
        region_id=extractors.custom(to_int, 'region_id'),
        region_name_tree='region_tree',
    )

    joined = log.join(region_trees, type='left', by='region_id')
    return joined.project(
        extractors.all(exclude=['region_id']),
        region_name_tree=_extract_value_or_default('region_name_tree', _DEFAULT_REGION_TREE),
    )


def _tree_name(path):
    return b'\t' + b'\t'.join(path) + b'\t'


def _check_absent_tree_values(tree_name_col, value_col):
    def check_tree_names(records):
        for record in records:
            if record.get(tree_name_col) is None:
                value = record.get(value_col)
                raise RuntimeError(f'No {tree_name_col} for {value_col} = \'{value}\'')
            yield record

    return check_tree_names


def _add_task_name_and_cost(log, task_tariff_map):
    '''
    Replaces `task_id` with `task_name_tree` from `task_tariff_map`. Computes
    the time spent and its cost in roubles and in the user's basic units.

    Fails via `_check_absent_tree_values` if a task_id has no task_name_tree
    in `task_tariff_map`.

    Computed columns:
      * time_spent_sec  = quantity * seconds_per_task
      * cost_rub        = time_spent_sec * TARIFF_RUB_PER_SEC,
                          multiplied by _HOLIDAY_PAY_RATE_FACTOR if `holiday`
      * cost_basic_unit = time_spent_sec * TARIFF_RUB_PER_SEC / basic_unit_to_rub

    log:
    | fielddate | puid | task_id | mrc_region | region_name_tree | quantity | staff_login | holiday | basic_unit_to_rub | schedule_tree |
    |-----------+------+---------+------------+------------------+----------+-------------+---------+-------------------+---------------|
    | ...       | ...  | ...     | ...        | ...              | ...      | ...         | ...     | ...               | ...           |

    task_tariff_map:
    | task_id | task_name_tree | seconds_per_task |
    |---------+----------------+------------------|
    |     ... |            ... |              ... |

    Result:
    | fielddate | puid | task_name_tree | mrc_region | region_name_tree | quantity | staff_login | holiday | basic_unit_to_rub | schedule_tree | time_spent_sec | cost_rub | cost_basic_unit |
    |-----------+------+----------------+------------+------------------+----------+-------------+---------+-------------------+---------------+----------------+----------+-----------------|
    | ...       | ...  | ...            | ...        | ...              | ...      | ...         | ...     | ...               | ...           | ...            | ...      | ...             |
    '''
    def calc_cost_rub(time_spent_sec, holiday):
        result = time_spent_sec * TARIFF_RUB_PER_SEC
        if holiday:
            return result * _HOLIDAY_PAY_RATE_FACTOR
        return result

    # NOTE(review): unlike calc_cost_rub, this does not apply
    # _HOLIDAY_PAY_RATE_FACTOR on holidays. Confirm this is intentional.
    def calc_cost_basic_unit(time_spent_sec, basic_unit_to_rub):
        return time_spent_sec * TARIFF_RUB_PER_SEC / basic_unit_to_rub

    return log.join(
        task_tariff_map,
        type='left',
        by='task_id'
    ).map(
        _check_absent_tree_values(
            tree_name_col='task_name_tree', value_col='task_id'
        )
    ).project(
        extractors.all(exclude=['seconds_per_task', 'task_id']),
        time_spent_sec=extractors.custom(
            lambda seconds, quantity: quantity * seconds,
            'seconds_per_task', 'quantity'
        ),
        # The two extractors below read `time_spent_sec`, which is produced by
        # the extractor above in this same project call. Keep this order.
        # NOTE(review): relies on nile exposing earlier project outputs to
        # later extractors — confirm.
        cost_rub=extractors.custom(calc_cost_rub, 'time_spent_sec', 'holiday'),
        cost_basic_unit=extractors.custom(calc_cost_basic_unit, 'time_spent_sec', 'basic_unit_to_rub')
    )


def _embed_mrc_region_into_task_name_tree(log):
    '''
    Substitutes the record's `mrc_region` for the `_MRC_REGION_PATTERN`
    placeholder in `task_name_tree` and drops the `mrc_region` column.

    Task names without the placeholder are left unchanged. A task name with
    the placeholder requires a non-empty `mrc_region`; otherwise AssertionError
    is raised. The check is an explicit `raise` rather than an `assert`
    statement, so it is not stripped when Python runs with `-O`.

    log:
    | fielddate | puid | task_name_tree | mrc_region | region_name_tree | quantity | staff_login | holiday | basic_unit_to_rub | schedule_tree | time_spent_sec | cost_rub | cost_basic_unit |
    |-----------+------+----------------+------------+------------------+----------+-------------+---------+-------------------+---------------+----------------+----------+-----------------|
    | ...       | ...  | ...            | ...        | ...              | ...      | ...         | ...     | ...               | ...           | ...            | ...      | ...             |

    Result:
    | fielddate | puid | task_name_tree | region_name_tree | quantity | staff_login | holiday | basic_unit_to_rub | schedule_tree | time_spent_sec | cost_rub | cost_basic_unit |
    |-----------+------+----------------+------------------+----------+-------------+---------+-------------------+---------------+----------------+----------+-----------------|
    | ...       | ...  | ...            | ...              | ...      | ...         | ...     | ...               | ...           | ...            | ...      | ...             |
    '''
    def check_and_insert_mrc_region(task_name_tree, mrc_region):
        if _MRC_REGION_PATTERN not in task_name_tree:
            return task_name_tree
        if not mrc_region:
            raise AssertionError(f'No mrc_region for task {task_name_tree}')
        # The placeholder includes its surrounding tabs, so put them back
        # around the substituted region.
        return task_name_tree.replace(_MRC_REGION_PATTERN, b'\t' + mrc_region + b'\t')

    return log.project(
        extractors.all(exclude=['task_name_tree', 'mrc_region']),
        task_name_tree=extractors.custom(
            check_and_insert_mrc_region, 'task_name_tree', 'mrc_region'
        )
    )


def _remove_non_aggregatable_data_from_records_without_login(records):
    '''
    Clears the non-aggregatable columns `holiday` and `basic_unit_to_rub`.

    These values cannot be aggregated across records of different users. They
    are therefore kept only on records whose `puid_tree` contains the record's
    own `staff_login`. Every other record (no staff login, or an upper tree
    level) gets None in both columns.
    '''
    for record in records:
        staff_login = record.get('staff_login')
        keeps_personal_data = bool(staff_login) and _is_puid_tree_contain_staff_login(
            record.get('puid_tree'), staff_login
        )
        if keeps_personal_data:
            yield record
        else:
            yield record.transform(holiday=None, basic_unit_to_rub=None)


def _add_puid_tree(log, puid_map):
    '''
    Replaces `puid` with `puid_tree` from `puid_map`. Fails if a puid has no
    tree. Clears `holiday` and `basic_unit_to_rub` on records where they
    cannot be attributed to a single user.

    log:
    | fielddate | puid | task_name_tree | region_name_tree | quantity | staff_login | holiday | basic_unit_to_rub | schedule_tree | time_spent_sec | cost_rub | cost_basic_unit |
    |-----------+------+----------------+------------------+----------+-------------+---------+-------------------+---------------+----------------+----------+-----------------|
    | ...       | ...  | ...            | ...              | ...      | ...         | ...     | ...               | ...           | ...            | ...      | ...             |

    puid_map:
    | puid | puid_tree |
    |------+-----------|
    |  ... |       ... |

    Note. A puid has several puid_tree values (one entry per tree level), so
    each log record is emitted once per level.

    Result:
    | fielddate | puid_tree | task_name_tree | region_name_tree | quantity | staff_login | holiday   | basic_unit_to_rub  | schedule_tree | time_spent_sec | cost_rub | cost_basic_unit |
    |-----------+-----------+----------------+------------------+----------+-------------+-----------+--------------------+---------------+----------------+----------+-----------------|
    |           |           |                |                  |          |             | holiday/basic_unit_to_rub if   |               |                |          |                 |
    |           |           |                |                  |          |             | puid_tree contains staff_login |               |                |          |                 |
    | ...       | ...       | ...            | ...              | ...      | ...         | None otherwise                 | ...           | ...            | ...      | ...             |
    '''
    with_puid_tree = log.join(puid_map, type='left', by='puid')

    checked = with_puid_tree.map(
        _check_absent_tree_values(tree_name_col='puid_tree', value_col='puid')
    )

    without_puid = checked.project(extractors.all(exclude=['puid']))

    return without_puid.map(_remove_non_aggregatable_data_from_records_without_login)


def _make_acl(log):
    '''
    Grants access so that a person can see only the tasks they performed
    themselves. No personal access is granted to upper tree levels.

    Every record gets `_FULL_ACCESS_GROUP` in the `_acl` column. From
    `_EMPLOYEE_ACL_START_DATE` on, records whose `puid_tree` contains the
    record's own staff login additionally get `,@<staff_login>`. The
    `staff_login` column is dropped.

    Warning. This function relies on staff members' puid trees having the
    format `path<tab>staff name (staff login)<tab>nmaps login (puid)<tab>`.

    log:
    | fielddate | puid_tree | task_name_tree | region_name_tree | quantity | staff_login | holiday | basic_unit_to_rub | schedule_tree  | time_spent_sec | cost_rub | cost_basic_unit |
    |-----------+-----------+----------------+------------------+----------+-------------+---------+-------------------+----------------+----------------+----------+-----------------|
    | ...       | ...       | ...            | ...              | ...      | ...         | ...     | ...               | ...            | ...            | ...      | ...             |

    Result:
    | fielddate | puid_tree | task_name_tree | region_name_tree | quantity | holiday | basic_unit_to_rub | schedule_tree | time_spent_sec | cost_rub | cost_basic_unit | _acl |
    |-----------+-----------+----------------+------------------+----------+---------+-------------------+---------------+----------------+----------+-----------------+------|
    | ...       | ...       | ...            | ...              | ...      | ...     | ...               | ...           | ...            | ...      | ...             | ...  |
    '''
    def acl_extractor(fielddate, puid_tree, staff_login):
        # Dates are ISO strings, so lexicographic comparison is chronological.
        grants_personal_access = (
            fielddate >= _EMPLOYEE_ACL_START_DATE
            and staff_login
            and _is_puid_tree_contain_staff_login(puid_tree, staff_login)
        )
        if grants_personal_access:
            return _FULL_ACCESS_GROUP + b',@' + staff_login
        return _FULL_ACCESS_GROUP

    return log.project(
        extractors.all(exclude=['staff_login']),
        _acl=extractors.custom(acl_extractor, 'fielddate', 'puid_tree', 'staff_login'),
    )


def _aggregate_tasks(log):
    '''
    Sums `quantity`, `time_spent_sec`, `cost_rub` and `cost_basic_unit`,
    grouping by all remaining columns.

    log:
    | fielddate | puid_tree | task_name_tree | region_name_tree | quantity | holiday | basic_unit_to_rub | schedule_tree | time_spent_sec | cost_rub | cost_basic_unit | _acl |
    |-----------+-----------+----------------+------------------+----------+---------+-------------------+---------------+----------------+----------+-----------------+------|
    | ...       | ...       | ...            | ...              | ...      | ...     | ...               | ...           | ...            | ...      | ...             | ...  |

    Result:
    | fielddate | puid_tree | task_name_tree | region_name_tree | schedule_tree | holiday | basic_unit_to_rub | quantity_total | time_spent_total_sec | cost_total_rub | cost_total_basic_unit | _acl |
    |-----------+-----------+----------------+------------------+---------------+---------+-------------------+----------------+----------------------+----------------+-----------------------+------|
    | ...       | ...       | ...            | ...              | ...           | ...     | ...               | SUM(...)       | SUM(...)             | SUM(...)       | SUM(...)              | ...  |
    '''
    group_keys = (
        'fielddate',
        'puid_tree',
        'task_name_tree',
        'region_name_tree',
        'schedule_tree',
        'holiday',
        'basic_unit_to_rub',
        '_acl',
    )
    grouped = log.groupby(*group_keys)
    return grouped.aggregate(
        quantity_total=aggregators.sum('quantity'),
        time_spent_total_sec=aggregators.sum('time_spent_sec'),
        cost_total_rub=aggregators.sum('cost_rub'),
        cost_total_basic_unit=aggregators.sum('cost_basic_unit'),
    )


def prepare_report(
    job,
    date,
    *logs,
    major_regions_map,
    task_tariff_map,
    puid_map,
    linked_accounts,
    work_schedule,
    user_holidays,
    basic_units,
    staff_dump,
):
    '''
    Builds the aggregated report for `date` (a str; it is UTF-8 encoded for
    comparison with the byte-valued columns).

    The task `logs`, plus one synthetic record per current employee, are
    enriched with region, staff login, holiday, basic unit, schedule, task
    tariff and puid tree data. ACLs are computed and the result is aggregated
    by `_aggregate_tasks`.
    '''
    date_bytes = date.encode('utf-8')

    employees_log = _make_current_employees_log(date_bytes, staff_dump, linked_accounts)

    result = _concat_logs(job, date_bytes, *logs, employees_log)
    result = _bbox_to_region_id(result)
    result = _add_staff_login(result, linked_accounts)
    result = _add_user_holidays(date_bytes, result, user_holidays)
    result = _add_basic_units(result, basic_units)
    result = _add_work_schedule(result, work_schedule)
    result = _region_id_to_region_name_tree(result, major_regions_map)
    result = _add_task_name_and_cost(result, task_tariff_map)
    result = _embed_mrc_region_into_task_name_tree(result)
    result = _add_puid_tree(result, puid_map)
    result = _make_acl(result)
    return _aggregate_tasks(result)
