import codecs
from nile.api.v1 import (
    aggregators as na,
    extractors as ne,
    utils as nu,
    filters as nf,
)
from qb2.api.v1 import (
    typing as qt,
    filters as qf,
    extractors as qe
)
from yandex.maps import geolib3
from yt.wrapper.ypath import ypath_join
from maps.wikimap.stat.libs.common.lib import geobase_region
import cond_constants


# Fallback region_tree used by set_region_tree when a record's region_id has
# no (or an empty) region_tree in major_regions. Tab-delimited path string;
# '10000' is presumably the geobase Earth region id (cf.
# geobase_region.EARTH_REGION_ID) -- TODO confirm.
DEFAULT_REGION_TREE = '\t10000\t'


def ymapsdf_stream(job, table_name, ymapsdf_path, regions):
    '''
    Concatenate the ymapsdf table `table_name` across all given regions.

    Reads <ymapsdf_path>/<region>/<table_name> for every region and labels
    the concatenated stream with `table_name`.

    regions: any iterable of region directory names. It is materialized
        once, so one-shot iterators (e.g. generators) are accepted too.

    Raises ValueError if `regions` is empty.
    '''
    # Materialize so emptiness can be checked for any iterable, and use an
    # explicit exception: an `assert` would be stripped under `python -O`.
    regions = list(regions)
    if not regions:
        raise ValueError(
            'ymapsdf_stream: no regions given for table {!r}'.format(table_name)
        )
    return job.concat(*(
        job.table(ypath_join(ymapsdf_path, region, table_name))
        for region in regions
    )).label(table_name)


def set_region_tree(stream, major_regions):
    '''
    Attach `region_tree` from major_regions to every record of `stream`.

    `stream` is left-joined with `major_regions` on region_id (the right side
    is declared small and its keys defined). Records that end up with no or
    an empty region_tree get DEFAULT_REGION_TREE instead.
    memory_limit=8 * 1024 is passed to the join (units presumably MB --
    confirm against the nile docs).
    '''
    def _tree_or_default(region_tree):
        # Falsy (missing/empty) trees fall back to the default.
        return region_tree or DEFAULT_REGION_TREE

    joined = stream.join(
        major_regions,
        by='region_id',
        type='left',
        assume_small_right=True,
        allow_undefined_keys=False,
        assume_defined=True,
        memory_limit=8 * 1024
    )
    return joined.project(
        ne.all(),
        region_tree=ne.custom(_tree_or_default, 'region_tree')
    )


def _dehex(shape):
    return codecs.decode(shape, 'hex')


def _shape_length(shape):
    '''
    Return the length of a hex-encoded EWKB line geometry.

    LineString: geolength() of the polyline.
    MultiLineString: sum of geolength() over its parts.
    Any other geometry type raises RuntimeError.
    (Units are whatever geolib3's geolength() returns -- presumably meters.)
    '''
    dehexed = _dehex(shape)
    geom_type = geolib3.get_geometry_type_from_wkb(dehexed)
    if geom_type == geolib3.GeometryType.EwkbLineString:
        # Reuse the already-decoded bytes instead of decoding the hex twice.
        return geolib3.Polyline2.from_EWKB(dehexed).geolength()
    if geom_type == geolib3.GeometryType.EwkbMultiLineString:
        polylines = geolib3.PolylinesVector.from_EWKB(dehexed)
        return sum(polyline.geolength() for polyline in polylines)
    raise RuntimeError('Unknown geometry type {}'.format(geom_type))


def point_region(shape):
    '''
    Return the geobase region id (as str) containing a hex-encoded EWKB point.

    A missing shape (None) maps to the Earth region id, consistent with
    region_by_polygon_shape and region_by_polyline_shape. Callers feed this
    from left joins (node, rd_jc), where an unmatched key yields no shape;
    previously that crashed inside the hex decoding.
    '''
    if shape is None:
        return str(geobase_region.EARTH_REGION_ID)
    point = geolib3.Point2.from_EWKB(_dehex(shape))
    return str(geobase_region.geobase_region_id(point.lon, point.lat))


def region_by_polygon_shape(shape):
    '''
    Return the geobase region id (as str) of a hex-encoded EWKB polygon.

    The lookup uses the first vertex of the polygon's exterior ring.
    A None shape yields the Earth region id.
    '''
    if shape is not None:
        polygon = geolib3.Polygon2.from_EWKB(_dehex(shape))
        first_vertex = polygon.exterior_ring().point_at(0)
        region_id = geobase_region.geobase_region_id(
            lon=first_vertex.lon, lat=first_vertex.lat
        )
    else:
        region_id = geobase_region.EARTH_REGION_ID
    return str(region_id)


def region_by_polyline_shape(shape):
    '''
    Return the geobase region id (as str) of a hex-encoded EWKB line shape.

    LineString: the first point of the line is used.
    MultiLineString: the first point of its first part is used.
    None: the Earth region id is returned.
    Any other geometry type raises RuntimeError.
    '''
    if shape is None:
        return str(geobase_region.EARTH_REGION_ID)

    wkb = _dehex(shape)
    geometry_type = geolib3.get_geometry_type_from_wkb(wkb)
    if geometry_type == geolib3.GeometryType.EwkbLineString:
        line = geolib3.Polyline2.from_EWKB(wkb)
    elif geometry_type == geolib3.GeometryType.EwkbMultiLineString:
        line = geolib3.PolylinesVector.from_EWKB(wkb)[0]
    else:
        raise RuntimeError('Unknown geometry type {}'.format(geometry_type))

    start = line.point_at(0)
    region_id = geobase_region.geobase_region_id(lon=start.lon, lat=start.lat)
    return str(region_id)


def count_stream_by_field(job, stream, stream_name, field, field_to_aggregate, date, major_regions):
    '''
    Sum {field_to_aggregate} per (region_tree, {field}) pair.

    Values are first summed per (region_id, {field}). Each region_id is then
    mapped to its region_tree via set_region_tree, and the sums are added up
    again per (region_tree, {field}). `job` is not used.

    stream:
    | {field_to_aggregate} | region_id | {field} | ... |
    |----------------------+-----------+---------+-----|
    |                  ... |       ... |     ... | ... |

    major_regions:
    | region_id | region_tree | ... |
    |-----------+-------------+-----|
    |       ... |         ... | ... |

    Result:
    | region_tree | {field_to_aggregate} | fielddate | category                       |
    |-------------+----------------------+-----------+--------------------------------|
    |         ... |                  ... |      date | '<stream_name> <field>=<value>' |
    '''
    def _sum_spec():
        # A fresh aggregator mapping for each groupby.
        return {field_to_aggregate: na.sum(field_to_aggregate)}

    def _category(value):
        return '{} {}={}'.format(stream_name, field, value)

    per_region = stream.groupby('region_id', field).aggregate(**_sum_spec())
    per_region = per_region.label(stream_name + '_aggregated_' + field)

    with_tree = set_region_tree(per_region, major_regions)
    with_tree = with_tree.label(stream_name + '_region_tree')

    per_tree = with_tree.groupby('region_tree', field).aggregate(**_sum_spec())
    counted = per_tree.project(
        'region_tree',
        field_to_aggregate,
        fielddate=ne.const(date),
        category=ne.custom(_category, field)
    )
    return counted.label('counted_{}_by_{}'.format(stream_name, field))


def count_stream(job, stream, stream_name, field_to_aggregate, date, major_regions):
    '''
    Sum {field_to_aggregate} per region_tree and tag the totals with stream_name.

    Values are first summed per region_id. Each region_id is then mapped to
    its region_tree via set_region_tree, and the sums are added up again per
    region_tree. `job` is not used.

    stream:
    | {field_to_aggregate} | region_id | ... |
    |----------------------+-----------+-----|
    |                  ... |       ... | ... |

    major_regions:
    | region_id | region_tree | ... |
    |-----------+-------------+-----|
    |       ... |         ... | ... |

    Result:
    | region_tree | {field_to_aggregate} | fielddate | category    |
    |-------------+----------------------+-----------+-------------|
    |         ... |                  ... |      date | stream_name |
    '''
    def _sum_spec():
        # A fresh aggregator mapping for each groupby.
        return {field_to_aggregate: na.sum(field_to_aggregate)}

    per_region = stream.groupby('region_id').aggregate(**_sum_spec())
    per_region = per_region.label(stream_name + '_aggregated')

    with_tree = set_region_tree(per_region, major_regions)
    with_tree = with_tree.label(stream_name + '_region_tree')

    per_tree = with_tree.groupby('region_tree').aggregate(**_sum_spec())
    counted = per_tree.project(
        'region_tree',
        field_to_aggregate,
        fielddate=ne.const(date),
        category=ne.const(stream_name)
    )
    return counted.label('counted_' + stream_name)


def add_addr_region(job, stream, ymapsdf_path, regions):
    '''
    Add region_id to addr records from the shape of their node.

    `stream` is left-joined with the ymapsdf `node` table (deduplicated by
    node_id) on node_id. region_id is then computed with point_region from the
    joined `shape`. geobase_region.FILES and GEOBASE_JOB_MEMORY_LIMIT are
    passed to the project step, presumably so that geobase data is available
    to the job.
    '''
    nodes = ymapsdf_stream(job, 'node', ymapsdf_path, regions).unique('node_id')
    joined = stream.join(
        nodes,
        by='node_id',
        type='left',
        assume_unique=True,
        allow_undefined_keys=False,
        assume_defined=True
    )
    with_region = joined.project(
        ne.all(),
        region_id=ne.custom(point_region, 'shape'),
        files=geobase_region.FILES,
        memory_limit=geobase_region.GEOBASE_JOB_MEMORY_LIMIT
    )
    return with_region.label('addr_with_region')


def add_bld_region(job, stream, ymapsdf_path, regions):
    '''
    Add region_id to bld records from their polygon geometry.

    `stream` is left-joined with the ymapsdf `bld_geom` table (deduplicated by
    bld_id) on bld_id. region_id is then computed with region_by_polygon_shape,
    which maps a None shape to the Earth region. geobase_region.FILES and
    GEOBASE_JOB_MEMORY_LIMIT are passed to the project step.
    '''
    geobase_options = {
        'files': geobase_region.FILES,
        'memory_limit': geobase_region.GEOBASE_JOB_MEMORY_LIMIT,
    }
    geometries = ymapsdf_stream(job, 'bld_geom', ymapsdf_path, regions).unique('bld_id')
    joined = stream.join(geometries, by='bld_id', type='left', assume_unique=True)
    return joined.project(
        ne.all(),
        region_id=ne.custom(region_by_polygon_shape, 'shape'),
        **geobase_options
    ).label('bld_with_region')


def add_rd_region(job, stream, ymapsdf_path, regions):
    '''
    Add region_id to rd records from their line geometry.

    `stream` is left-joined with the ymapsdf `rd_geom` table (deduplicated by
    rd_id) on rd_id. region_id is then computed with region_by_polyline_shape,
    which maps a None shape to the Earth region. geobase_region.FILES and
    GEOBASE_JOB_MEMORY_LIMIT are passed to the project step.
    '''
    geometries = ymapsdf_stream(job, 'rd_geom', ymapsdf_path, regions).unique('rd_id')
    joined = stream.join(
        geometries,
        by='rd_id',
        type='left',
        assume_unique=True,
        assume_defined=True
    )
    geobase_options = {
        'files': geobase_region.FILES,
        'memory_limit': geobase_region.GEOBASE_JOB_MEMORY_LIMIT,
    }
    with_region = joined.project(
        ne.all(),
        region_id=ne.custom(region_by_polyline_shape, 'shape'),
        **geobase_options
    )
    return with_region.label('rd_with_region')


def add_rd_el_region(job, stream, *unused_args):
    '''
    Add region_id to every rd_el record, derived from its own shape.

    region_id comes from region_by_polyline_shape (the first point of the
    line; a None shape maps to the Earth region). `job` and any extra
    arguments are ignored. They presumably exist so this function can be
    called like the other add_*_region helpers.

    stream:
    | shape | ... |
    |-------|-----|
    |   ... | ... |

    Result:
    | shape | region_id | ... |
    |-------+-----------+-----|
    |   ... |       ... | ... |
    '''
    geobase_options = {
        'files': geobase_region.FILES,
        'memory_limit': geobase_region.GEOBASE_JOB_MEMORY_LIMIT,
    }
    with_region = stream.project(
        ne.all(),
        region_id=ne.custom(region_by_polyline_shape, 'shape'),
        **geobase_options
    )
    return with_region.label('rd_el_with_region')


def add_rd_el_named_field(job, stream, ymapsdf_path, regions):
    '''
    Add a boolean `named` flag to rd_el records.

    `stream` is left-joined with ymapsdf `rd_rd_el` (deduplicated by rd_el_id)
    on rd_el_id. `named` is True when the join produced an rd_id, i.e. the
    rd_el belongs to some rd; otherwise it is False.

    stream:
    | rd_el_id | ... |
    |----------+-----|
    |      ... | ... |

    rd_rd_el:
    | rd_el_id | rd_id |
    |----------+-------|
    |      ... |   ... |

    Result:
    | rd_el_id | rd_id | named | ... |
    |----------+-------+-------+-----|
    |      ... |   ... |   ... | ... |
    '''
    rd_links = ymapsdf_stream(job, 'rd_rd_el', ymapsdf_path, regions).unique('rd_el_id')
    joined = stream.join(
        rd_links,
        by='rd_el_id',
        type='left',
        assume_unique=True
    )
    named_flag = qe.expression('named', qf.defined('rd_id'))
    return joined.project(ne.all(), named_flag).label('rd_el_with_named')


def add_rd_el_restriction_field(job, stream, ymapsdf_path, regions):
    '''
    Add a boolean `restricted` flag to rd_el records.

    `stream` is left-joined with ymapsdf `rd_el_vehicle_restriction`
    (deduplicated by rd_el_id) on rd_el_id. `restricted` is True when the join
    produced a vehicle_restriction_id; otherwise it is False.

    stream:
    | rd_el_id | ... |
    |----------+-----|
    |      ... | ... |

    rd_el_vehicle_restriction:
    | rd_el_id | vehicle_restriction_id |
    |----------+------------------------|
    |      ... |                    ... |

    Result:
    | rd_el_id | vehicle_restriction_id | restricted | ... |
    |----------+------------------------+------------+-----|
    |      ... |                    ... |        ... | ... |
    '''
    restriction_links = ymapsdf_stream(
        job, 'rd_el_vehicle_restriction', ymapsdf_path, regions
    ).unique('rd_el_id')
    joined = stream.join(restriction_links, by='rd_el_id', type='left')
    restricted_flag = qe.expression('restricted', qf.defined('vehicle_restriction_id'))
    return joined.project(ne.all(), restricted_flag).label('rd_el_with_restriction')


def add_rd_el_universal_field(job, stream, ymapsdf_path, regions):
    '''
    Attach the vehicle restriction `universal_id` to rd_el records.

    `stream` is left-joined with ymapsdf `rd_el_vehicle_restriction` on
    rd_el_id, then with `vehicle_restriction` on vehicle_restriction_id.
    No unique() is applied, so an rd_el linked to several restrictions
    presumably yields one record per link. Any falsy universal_id (missing,
    None, empty) is replaced with the string 'False'.

    stream:
    | rd_el_id | ... |
    |----------+-----|
    |      ... | ... |

    rd_el_vehicle_restriction:
    | rd_el_id | vehicle_restriction_id |
    |----------+------------------------|
    |      ... |                    ... |

    vehicle_restriction:
    | vehicle_restriction_id | universal_id |
    |------------------------+--------------|
    |                    ... |          ... |

    Result:
    | rd_el_id | vehicle_restriction_id | universal_id | ... |
    |----------+------------------------+--------------+-----|
    |      ... |                    ... |          ... | ... |
    '''
    def _universal_or_false(universal_id):
        return universal_id or 'False'

    link_table = ymapsdf_stream(job, 'rd_el_vehicle_restriction', ymapsdf_path, regions)
    restriction_table = ymapsdf_stream(job, 'vehicle_restriction', ymapsdf_path, regions)

    with_links = stream.join(link_table, by='rd_el_id', type='left')
    with_restrictions = with_links.join(
        restriction_table, by='vehicle_restriction_id', type='left'
    )
    return with_restrictions.project(
        ne.all(),
        universal_id=ne.custom(_universal_or_false, 'universal_id')
    ).label('rd_el_with_restriction_universal_id')


def add_cond_restriction_field(job, stream, ymapsdf_path, regions):
    '''
    Add a boolean `restricted` flag to cond records.

    `stream` is left-joined with ymapsdf `cond_vehicle_restriction`
    (deduplicated by cond_id) on cond_id. `restricted` is True when the join
    produced a vehicle_restriction_id; otherwise it is False.

    stream:
    | cond_id | ... |
    |---------+-----|
    |     ... | ... |

    cond_vehicle_restriction:
    | cond_id | vehicle_restriction_id |
    |---------+------------------------|
    |     ... |                    ... |

    Result:
    | cond_id | vehicle_restriction_id | restricted | ... |
    |---------+------------------------+------------+-----|
    |     ... |                    ... |        ... | ... |
    '''
    restriction_links = ymapsdf_stream(
        job, 'cond_vehicle_restriction', ymapsdf_path, regions
    ).unique('cond_id')
    joined = stream.join(restriction_links, by='cond_id', type='left')
    restricted_flag = qe.expression('restricted', qf.defined('vehicle_restriction_id'))
    return joined.project(ne.all(), restricted_flag).label('cond_with_restriction')


def prepare_cond_restrict(job, stream, field_count, field_to_aggregate, date, major_regions, type_restriction):
    '''
    Count conds whose vehicle restriction field {field_count} is set.

    Records whose {field_count} is undefined or equal to False are dropped.
    For the remaining records {field_count} is rewritten:
    - type_restriction == 'all': replaced by the literal 'all', so every cond
      with this kind of restriction falls into one bucket;
    - any other type_restriction (callers pass 'value'): kept as is, giving
      one bucket per restriction value.
    The prepared stream is labeled 'cond_prepare_<field_count>_<type_restriction>'
    and summed with count_stream_by_field under the stream name 'cond'.

    stream:
    | cond_id | region_id | {field_to_aggregate} | {field_count} | ... |
    |---------+-----------+----------------------+---------------+-----|
    |     ... |       ... |                  ... |           ... | ... |

    Result (as produced by count_stream_by_field):
    | region_tree | {field_to_aggregate} | fielddate | category |
    |-------------+----------------------+-----------+----------|
    |         ... |                  ... |      date |      ... |
    where category is 'cond <field_count>=all' or 'cond <field_count>=<value>'.
    '''
    # type_restriction is fixed for the whole call, so choose the mapping once
    # instead of re-testing it for every record.
    if type_restriction == 'all':
        def _bucket(unused_value):
            return 'all'
    else:
        def _bucket(value):
            return value

    stream_prepare_values = stream.filter(
        qf.defined(field_count),
        nf.not_(nf.equals(field_count, False))
    ).project(
        ne.all(),
        **{field_count: ne.custom(_bucket, field_count)}
    ).label('cond_prepare_{}_{}'.format(field_count, type_restriction))
    return count_stream_by_field(
        job,
        stream_prepare_values,
        'cond',
        field_count,
        field_to_aggregate,
        date,
        major_regions
    )


def add_cond_vehicle_restricted_field(job, stream, ymapsdf_path, regions, field_to_aggregate, date, major_regions):
    '''
    Compute cond statistics per vehicle restriction field, both overall and per value.

    The cond stream is left-joined with ymapsdf `cond_vehicle_restriction`
    (by cond_id), then with `vehicle_restriction` (by vehicle_restriction_id).
    No unique() is applied, so a cond linked to several restrictions presumably
    contributes once per link. For every field in
    cond_constants.COND_V_RESTRICTION_VALUE_FIELDS, prepare_cond_restrict is
    run twice: with 'all' (one bucket per field) and with 'value' (one bucket
    per field value). All resulting statistics are concatenated.

    stream:
    | cond_id | region_id | {field_to_aggregate} | ... |
    |---------+-----------+----------------------+-----|
    |     ... |       ... |                  ... | ... |

    cond_vehicle_restriction:
    | cond_id | vehicle_restriction_id |
    |---------+------------------------|
    |     ... |                    ... |

    vehicle_restriction:
    | vehicle_restriction_id | weight_limit | axle_weight_limit | max_weight_limit | height_limit | width_limit | length_limit | payload_limit | min_eco_class | trailer_not_allowed |
    |------------------------+--------------+-------------------+------------------+--------------+-------------+--------------+---------------+---------------+---------------------|
    |                    ... |          ... |               ... |              ... |          ... |         ... |          ... |           ... |           ... |                 ... |

    Result:
    | region_tree | {field_to_aggregate} | fielddate | category |
    |-------------+----------------------+-----------+----------|
    |         ... |                  ... |      date |      ... |
    where category is 'cond <field>=all' or 'cond <field>=<value>' for each
    field in COND_V_RESTRICTION_VALUE_FIELDS.
    '''
    cond_v_restriction = ymapsdf_stream(job, 'cond_vehicle_restriction', ymapsdf_path, regions)
    v_restriction = ymapsdf_stream(job, 'vehicle_restriction', ymapsdf_path, regions)
    stream_join = stream.join(
        cond_v_restriction,
        by='cond_id',
        type='left'
    ).join(
        v_restriction,
        by='vehicle_restriction_id',
        type='left'
    ).label('cond_join_restrictions')

    per_field_streams = []
    for field_rest in cond_constants.COND_V_RESTRICTION_VALUE_FIELDS:
        overall = prepare_cond_restrict(
            job, stream_join, field_rest, field_to_aggregate, date, major_regions, 'all'
        )
        by_value = prepare_cond_restrict(
            job, stream_join, field_rest, field_to_aggregate, date, major_regions, 'value'
        )
        per_field_streams.append(
            job.concat(overall, by_value).label('cond_restriction_{}'.format(field_rest))
        )

    return job.concat(
        *per_field_streams
    ).label('cond_with_restriction_type')


def calc_rd_by_rd_type(job, stream, date, major_regions, field_to_aggregate):
    '''
    Sum {field_to_aggregate} over roads, both per rd_type and overall.

    Per-type sums come from count_stream_by_field (category 'rd rd_type=<v>').
    The overall totals are those per-type sums added up per region_tree, with
    category 'rd all'. Both are concatenated into one stream. The original
    description refers to named roads; filtering presumably happens upstream.

    stream:
    | rd_type | {field_to_aggregate} | region_id | ... |
    |---------+----------------------+-----------+-----|
    |     ... |                  ... |       ... | ... |

    Result:
    | region_tree | {field_to_aggregate} | fielddate | category |
    |-------------+----------------------+-----------+----------|
    |         ... |                  ... |      date |      ... |
    '''
    by_type = count_stream_by_field(
        job, stream, 'rd', 'rd_type', field_to_aggregate, date, major_regions
    )
    totals = by_type.groupby('region_tree').aggregate(
        **{field_to_aggregate: na.sum(field_to_aggregate)}
    )
    totals = totals.project(
        'region_tree',
        field_to_aggregate,
        fielddate=ne.const(date),
        category=ne.const('rd all')
    )
    return job.concat(by_type, totals).label('result_rd')


@nu.with_hints(output_schema={
    'cond_type': int,
    'cond_id':   int,
    'cond_seq_id': int,
    'access_id': qt.Int16,
    'seq_num': int,
    'rd_jc_id': qt.Optional[qt.Int64],
    'rd_el_id': int,
    'vehicle_restriction_id': int,
    'restricted': bool
})
def exclude_autogenerated_conds_reducer(groups):
    '''
    Drop autogenerated maneuvers; they are not needed in statistics.

    Such maneuvers have cond_type == 1 (forbidden) and the same rd_el_id in
    every record. They are generated for forbidden u-turns.

    For each group, the group passes if some record has cond_type != 1, or
    has an rd_el_id different from the first (truthy) rd_el_id seen. For a
    passing group the record with seq_num == 0 is yielded. Failing groups
    yield nothing.

    NOTE(review): if a passing group contains no record with seq_num == 0,
    None is yielded. Confirm that every group is guaranteed to contain one.
    '''
    for _, records in groups:
        rd_el_id = None
        result_record = None
        filter_passed = False
        for record in records:
            # Keep scanning after the filter passed: the seq_num == 0 record
            # may come later in the group.
            if record['seq_num'] == 0:
                result_record = record
            if not filter_passed:
                # `or` keeps the first truthy rd_el_id seen; a falsy id (e.g. 0)
                # would be replaced by the next record's id.
                rd_el_id = rd_el_id or record['rd_el_id']
                filter_passed = record['cond_type'] != 1 or rd_el_id != record['rd_el_id']
        if filter_passed:
            yield result_record


def remove_autogenerated_conds(job, stream, ymapsdf_path, regions):
    '''
    Drop autogenerated conds, keeping one (seq_num == 0) record per cond.

    Conds (deduplicated by cond_id) are joined with ymapsdf `cond_rd_seq` on
    cond_seq_id. No join type is passed, so the nile default applies. The
    result is grouped by (cond_id, cond_seq_id) and filtered with
    exclude_autogenerated_conds_reducer.
    '''
    cond_rd_seq = ymapsdf_stream(job, 'cond_rd_seq', ymapsdf_path, regions)
    unique_conds = stream.unique('cond_id')
    with_sequence = unique_conds.join(
        cond_rd_seq,
        by='cond_seq_id',
        assume_defined=True,
        allow_undefined_keys=False
    ).label('cond_join_cond_rd_seq')
    grouped = with_sequence.groupby('cond_id', 'cond_seq_id')
    return grouped.reduce(
        exclude_autogenerated_conds_reducer
    ).label('excluded_autogenerated_conds')


def add_cond_region(job, stream, ymapsdf_path, regions):
    '''
    Add region_id to cond records from the point of their road junction.

    `stream` is left-joined with ymapsdf `rd_jc` (deduplicated by rd_jc_id) on
    rd_jc_id. region_id is then computed with point_region from the joined
    `shape`. geobase_region.FILES and GEOBASE_JOB_MEMORY_LIMIT are passed to
    the project step.
    '''
    junctions = ymapsdf_stream(job, 'rd_jc', ymapsdf_path, regions).unique('rd_jc_id')
    joined = stream.join(
        junctions,
        by='rd_jc_id',
        type='left',
        assume_unique_right=True,
        assume_defined=True,
        allow_undefined_keys=False
    ).label('cond_join_cond_rd_seq_join_rd_jc_id')
    geobase_options = {
        'files': geobase_region.FILES,
        'memory_limit': geobase_region.GEOBASE_JOB_MEMORY_LIMIT,
    }
    with_region = joined.project(
        ne.all(),
        region_id=ne.custom(point_region, 'shape'),
        **geobase_options
    )
    return with_region.label('cond_with_region')


def add_default_region(job, stream, *unused_args):
    '''
    Set region_id of every record to the Earth region id (as str).

    `job` and any extra arguments are ignored. They presumably exist so this
    function can be called like the other add_*_region helpers.
    '''
    earth_region = str(geobase_region.EARTH_REGION_ID)
    return stream.project(ne.all(), region_id=ne.const(earth_region))


def calc_shape_length(stream):
    '''
    Add a `length` column computed from each record's hex-encoded EWKB shape.

    The value comes from _shape_length: geolib3 geolength() of a LineString,
    or the sum over the parts of a MultiLineString.
    '''
    length_column = ne.custom(_shape_length, 'shape')
    return stream.project(ne.all(), length=length_column)


def set_count_field(stream):
    '''
    Add a constant `count` column equal to 1 to every record.

    Summing this column counts records (presumably it is used as
    field_to_aggregate by the counting helpers).
    '''
    one_per_record = ne.const(1)
    return stream.project(ne.all(), count=one_per_record)
