from nile.api.v1 import (
    extractors as ne,
    utils as nu,
    extended_schema,
)
from qb2.api.v1 import (
    extractors as qe,
    typing,
)


def region_id_to_region_name_tree(stream, major_regions_map):
    '''
    Swap the `region_id` column of `stream` for a `region_name_tree` column.

    stream:
    | region_id | <columns> |
    |-----------+-----------|
    | ...       | ...       |

    major_regions_map (only the columns read here are listed):
    | region_id (string) | region_name | ... |
    |--------------------+-------------+-----|
    | ...                | ...         | ... |

    result:
    | region_name_tree | <columns> |
    |------------------+-----------|
    | ...              | ...       |

    The map is attached with a left join on `region_id`, so every row of
    `stream` survives; `region_name_tree` is taken from `region_name` and
    coalesced with a default root-node literal when no name is available.
    Both `region_id` and `region_name` are dropped from the result.
    '''
    # The map's region_id is documented as a string; it is passed through
    # int() and typed as Int64 before being used as the join key.
    # NOTE(review): presumably stream.region_id is already Int64 - confirm.
    regions = major_regions_map.project(
        qe.custom('region_id', int, 'region_id').with_type(typing.Int64),
        'region_name',
    )

    joined = stream.join(regions, by='region_id', type='left')

    name_tree = qe.coalesce(
        'region_name_tree', 'region_name', '\tЗемля\t',
    ).with_type(typing.Unicode)

    return joined.project(
        ne.all(exclude=('region_id', 'region_name')),
        name_tree,
    )


def _coalesce_tree(stream, tree_name_col, value_col):
    '''
    Fill in `tree_name_col` for records where it is missing or None.

    Records that already carry a non-None `tree_name_col` pass through
    untouched. Every other record is emitted three times, once per fallback
    tree node: the root, the "unknown" node below it, and a leaf under
    "unknown" built from the stringified `value_col` value.

    stream:        stream whose records are mapped
    tree_name_col: name of the column that holds the tree path
    value_col:     name of the column whose value names the fallback leaf
    '''
    # The mapper output keeps the input schema, with the tree column
    # declared as Unicode.
    output_schema = extended_schema(**{tree_name_col: typing.Unicode})

    @nu.with_hints(output_schema=output_schema)
    def mapper(records):
        for record in records:
            if record.get(tree_name_col) is not None:
                yield record
                continue

            value = str(record[value_col])
            fallback_nodes = (
                '\tall\t',
                '\tall\tunknown\t',
                f'\tall\tunknown\t{value}\t',
            )
            for node in fallback_nodes:
                yield record.transform(**{tree_name_col: node})

    return stream.map(mapper)


def graded_by_to_graded_by_tree(stream, puid_map):
    '''
    Swap the `graded_by` column of `stream` for a `graded_by_tree` column.

    stream:
    | graded_by | <columns> |
    |-----------+-----------|
    | ...       | ...       |

    puid_map:
    | puid | puid_tree | ... |
    |------+-----------+-----|
    | ...  | ...       | ... |

    result:
    | graded_by_tree | <columns> |
    |----------------+-----------|
    | ...            | ...       |

    `puid_map` is left-joined on `graded_by` (its `puid` column), so rows
    without a match are kept; `_coalesce_tree` then fills the missing
    `graded_by_tree` values with fallback nodes derived from `graded_by`.
    '''
    # Fix the column types first, then rename the map columns to the names
    # used on the stream side of the join.
    typed_puids = puid_map.cast(
        puid=typing.Int64,
        puid_tree=typing.Unicode,
    )
    graders = typed_puids.project(
        graded_by='puid',
        graded_by_tree='puid_tree',
    )

    joined = stream.join(graders, type='left', by='graded_by')
    with_tree = joined.call(_coalesce_tree, 'graded_by_tree', 'graded_by')

    # The raw id is no longer needed once the tree column is filled in.
    return with_tree.project(
        ne.all(exclude=('graded_by', )),
    )


def task_id_to_task_name_tree(stream, tariffs_with_dates):
    '''
    Swap `task_id` and `action_date` of `stream` for a `task_name_tree` column.

    stream:
    | task_id | action_date | <columns> |
    |---------+-------------+-----------|
    | ...     | ...         | ...       |

    tariffs_with_dates:
    | date | task_id | task_name_tree | ... |
    |------+---------+----------------+-----|
    | ...  | ...     | ...            | ... |

    result:
    | task_name_tree | <columns> |
    |----------------+-----------|
    | ...            | ...       |

    The tariffs table is left-joined on the (`task_id`, `action_date`) pair,
    with its `date` column renamed to `action_date`. Rows without a match are
    kept, and `_coalesce_tree` fills their `task_name_tree` with fallback
    nodes derived from `task_id`.
    '''
    # Keep only the columns needed for the join, renaming `date` so that it
    # lines up with the stream's `action_date`.
    tariff_columns = tariffs_with_dates.project(
        'task_id',
        'task_name_tree',
        action_date='date',
    )
    tariffs = tariff_columns.cast(
        task_id=typing.Unicode,
        task_name_tree=typing.Unicode,
        action_date=typing.Unicode,
    )

    joined = stream.join(
        tariffs,
        type='left',
        by=('task_id', 'action_date'),
    )
    with_tree = joined.call(_coalesce_tree, 'task_name_tree', 'task_id')

    # Both join keys are dropped from the final result.
    return with_tree.project(
        ne.all(exclude=('task_id', 'action_date')),
    )
