from nile.api.v1 import (
    extractors as ne,
    utils as nu,
    extended_schema,
)
from qb2.api.v1 import (
    extractors as qe,
    typing,
)


def action_to_action_tree(stream):
    r'''
    Replace the `action` column with an unfolded `action_tree` column.

    Each input record produces one output record per tree path:
    the root path '\tall\t' and the leaf path '\tall\t<action>\t'.
    Every column other than `action` is carried over unchanged.

    stream:
    | action | ... |
    |--------+-----|
    | ...    | ... |

    result:
    | action_tree | ... |
    |-------------+-----|
    | ...         | ... |
    '''
    def expand_paths(action):
        # Root path first, then the action-specific leaf nested under it.
        root = '\tall\t'
        return [root, f'{root}{action}\t']

    # Temporary list-valued column; hidden so only the unfolded value survives.
    paths_column = qe.custom(
        '_action_tree_paths', expand_paths, 'action',
    ).with_type(typing.List[typing.Unicode]).hide()

    return stream.project(
        ne.all(exclude='action'),
        paths_column,
        qe.unfold('action_tree', '_action_tree_paths'),
    )


def region_id_to_region_name_tree(stream, major_regions_map):
    r'''
    Replace `region_id` with a `region_name_tree` looked up in `major_regions_map`.

    The map's `region_id` (stored as a string) is converted with `int` into an
    Int64 column. The stream is then left-joined to it on `region_id`, so every
    stream record is kept. `region_name_tree` is produced by coalescing
    `region_name` with the literal '\tЗемля\t' (Russian for "Earth").
    Both `region_id` and `region_name` are dropped from the result.

    stream:
    | region_id | <columns> |
    |-----------+-----------|
    | ...       | ...       |

    major_regions_map:
    | region_id (string) | region_name | ... |
    |--------------------+-------------+-----|
    | ...                | ...         | ... |

    result:
    | region_name_tree | <columns> |
    |------------------+-----------|
    | ...              | ...       |
    '''
    # Parse the string key into Int64 so it can be joined against the stream
    # (presumably the stream's region_id is an integer column -- confirm).
    regions = major_regions_map.project(
        qe.custom('region_id', int, 'region_id').with_type(typing.Int64),
        'region_name',
    )

    joined = stream.join(regions, type='left', by='region_id')

    # NOTE(review): '\tЗемля\t' is presumably the fallback for records with no
    # matching region; confirm qe.coalesce treats it as a constant value.
    return joined.project(
        ne.all(exclude=('region_id', 'region_name')),
        qe.coalesce('region_name_tree', 'region_name', '\tЗемля\t').with_type(typing.Unicode),
    )


def _coalesce_tree(stream, tree_name_col, value_col):
    r'''
    Fill a missing tree-path column with an "unknown" hierarchy.

    Records whose `tree_name_col` is not None pass through untouched.
    Every other record is emitted three times, with `tree_name_col` set to
    '\tall\t', '\tall\tunknown\t' and '\tall\tunknown\t<value>\t' in that
    order, where <value> is str(record[value_col]).

    :param stream: nile stream to transform.
    :param tree_name_col: name of the tree-path column to fill.
    :param value_col: column whose value names the "unknown" leaf.
    :return: the mapped stream.
    '''
    @nu.with_hints(output_schema=extended_schema(**{tree_name_col: typing.Unicode}))
    def mapper(records):
        for record in records:
            if record.get(tree_name_col) is not None:
                yield record
                continue

            # All three paths are built before anything is yielded.
            unknown_root = '\tall\tunknown\t'
            fallback_paths = (
                '\tall\t',
                unknown_root,
                f'{unknown_root}{str(record[value_col])}\t',
            )
            for path in fallback_paths:
                yield record.transform(**{tree_name_col: path})

    return stream.map(mapper)


def action_by_to_action_by_tree(stream, puid_map):
    r'''
    Replace `action_by` with an `action_by_tree` looked up in `puid_map`.

    `puid_map` is cast (`puid` to Int64, `puid_tree` to Unicode) and renamed to
    `action_by` / `action_by_tree`. It is then left-joined to the stream on
    `action_by`. Records left without a tree are expanded by `_coalesce_tree`
    into the '\tall\t' / '\tall\tunknown\t' / '\tall\tunknown\t<action_by>\t'
    paths. `action_by` is dropped from the result.

    stream:
    | action_by | <columns> |
    |-----------+-----------|
    | ...       | ...       |

    puid_map:
    | puid | puid_tree | ... |
    |------+-----------+-----|
    | ...  | ...       | ... |

    result:
    | action_by_tree | <columns> |
    |----------------+-----------|
    | ...            | ...       |
    '''
    # Align the lookup's column names with the stream's join key and output name.
    puid_trees = puid_map.cast(
        puid=typing.Int64,
        puid_tree=typing.Unicode,
    ).project(
        action_by='puid',
        action_by_tree='puid_tree',
    )

    joined = stream.join(puid_trees, by='action_by', type='left')
    filled = joined.call(_coalesce_tree, 'action_by_tree', 'action_by')

    return filled.project(ne.all(exclude='action_by'))


def task_id_to_task_name_tree(stream, tariffs_with_dates):
    r'''
    Replace `task_id` with a date-specific `task_name_tree`.

    `tariffs_with_dates` is joined to the stream by matching
    (task_id, tariff_date) against (task_id, action_date). The join is a right
    join, so every stream record is kept. Records without a matching tree are
    expanded by `_coalesce_tree` into the '\tall\t' / '\tall\tunknown\t' /
    '\tall\tunknown\t<task_id>\t' paths. `task_id` and `tariff_date` are
    dropped from the result.

    stream:
    | task_id | action_date | <columns> |
    |---------+-------------+-----------|
    | ...     | ...         | ...       |

    tariffs_with_dates:
    | tariff_date | task_id | task_name_tree |
    |-------------+---------+----------------|
    | ...         | ...     | ...            |

    result:
    | task_name_tree | action_date | <columns> |
    |----------------+-------------+-----------|
    | ...            | ...         | ...       |
    '''
    # The stream is the right side of the join, so none of its records are lost.
    matched = tariffs_with_dates.join(
        stream,
        by_left=('task_id', 'tariff_date'),
        by_right=('task_id', 'action_date'),
        type='right',
    )

    filled = matched.call(_coalesce_tree, 'task_name_tree', 'task_id')

    return filled.project(ne.all(exclude=('task_id', 'tariff_date')))
