import constants as c
import ft_constants as ftc
import stream_processors as sp
from nile.api.v1 import (
    aggregators as na,
    extractors as ne,
    filters as nf,
    utils as nu,
    Record,
)
from maps.wikimap.stat.libs.common.lib import geobase_region


def make_supercategory_name(ft_name):
    '''
    Return the supercategory label for ft_name:
    the name itself followed by ' all'
    '''
    suffix = ' all'
    return ft_name + suffix


def get_ft_type_name(ft_type_id):
    '''
    Return the name registered for ft_type_id in ftc.FT_MAP.
    Ids that are absent from the map are returned as str(ft_type_id)
    '''
    fallback_name = str(ft_type_id)
    return ftc.FT_MAP.get(ft_type_id, fallback_name)


def get_supercategories(category):
    '''
    Return supercategory names the given category contributes to.

    The names are built with make_supercategory_name from:
      * every prefix returned by ftc.extract_ft_name_prefixes(category)
      * the category itself, if it is listed in ftc.SUPERCATEGORIES

    The result is whatever the builtin map() returns.
    '''
    base_names = ftc.extract_ft_name_prefixes(category)
    # A category may be a supercategory on its own, then it is counted
    # in its own ' all' bucket as well.
    if category in ftc.SUPERCATEGORIES:
        base_names.append(category)
    return map(make_supercategory_name, base_names)


def get_poi_categories(category):
    '''
    Generate poi statistic categories for the given ft category:
      * 'poi-landmark' if the category is in ftc.POI_LANDMARKS, followed by
        at most one of 'poi-landmark-religion', 'poi-landmark-sport',
        'poi-landmark-culture' (the first matching one, in that order)
      * 'poi' if the category is in ftc.POIS

    Nothing is generated for categories outside both sets.
    '''
    if category in ftc.POI_LANDMARKS:
        yield 'poi-landmark'
        landmark_kinds = (
            (ftc.POI_LANDMARKS_RELIGION, 'poi-landmark-religion'),
            (ftc.POI_LANDMARKS_SPORT, 'poi-landmark-sport'),
            (ftc.POI_LANDMARKS_CULTURE, 'poi-landmark-culture'),
        )
        for members, label in landmark_kinds:
            if category in members:
                yield label
                # Only the first matching landmark kind is reported.
                break
    if category in ftc.POIS:
        yield 'poi'


def add_ft_category(stream):
    '''
    Add 'category' field with the ft_type name (see get_ft_type_name)
    to every record of the input stream. All other fields are kept.

    stream:
    | ft_id | ft_type_id | ... |
    |-------+------------+-----|
    |   ... |        ... | ... |


    Result (labelled 'ft_with_categories'):
    | ft_id | ft_type_id | category | ... |
    |-------+------------+----------+-----|
    |   ... |        ... |      ... | ... |
    '''
    keep_all_fields = ne.all()
    category_by_type_id = ne.custom(get_ft_type_name, 'ft_type_id')
    with_category = stream.project(keep_all_fields, category=category_by_type_id)
    return with_category.label('ft_with_categories')


def add_ft_region(job, stream, ymapsdf_path, regions):
    '''
    Add 'region_id' field to records of the input stream.

    The stream is joined with ymapsdf table 'ft_center' by ft_id and then
    with ymapsdf table 'node' by node_id; region_id is computed from
    the 'shape' field with sp.point_region.

    stream:
    | ft_id |  ... |
    |-------+------|
    |    .. |  ... |

    ft_center:
    | ft_id | node_id |
    |-------+---------|
    |   ... |     ... |

    node:
    | node_id | shape | ... |
    |---------+-------+-----|
    |     ... |   ... | ... |

    Result (labelled 'ft_with_region'):
    | ft_id | region_id | ... |
    |-------+-----------+-----|
    |   ... |       ... | ... |
    '''
    # Both joins are performed with the same set of options.
    join_options = dict(
        assume_unique=True,
        allow_undefined_keys=False,
        assume_defined=True
    )

    ft_centers = sp.ymapsdf_stream(job, 'ft_center', ymapsdf_path, regions).unique('ft_id')
    with_node_id = stream.join(ft_centers, by='ft_id', **join_options)

    nodes = sp.ymapsdf_stream(job, 'node', ymapsdf_path, regions).unique('node_id')
    with_shape = with_node_id.join(nodes, by='node_id', **join_options)

    # NOTE(review): files and memory_limit are presumably required by
    # sp.point_region to work with geobase data - confirm.
    with_region = with_shape.project(
        ne.all(),
        region_id=ne.custom(sp.point_region, 'shape'),
        files=geobase_region.FILES,
        memory_limit=geobase_region.GEOBASE_JOB_MEMORY_LIMIT
    )
    return with_region.label('ft_with_region')


@nu.with_hints(output_schema=dict(
    category=str,
    region_tree=str,
    count=int,
))
def count_ft_type_supercategories_reducer(groups):
    '''
    Reducer that adds per-supercategory totals to every group.

    Every input record is passed through unchanged. In addition, for each
    supercategory name returned by get_supercategories for the record
    'category', the record 'count' is accumulated, and after the whole
    group has been read one record per supercategory is emitted with
    'category' set to the supercategory name and 'count' set to the total.

    For example (see make_supercategory_name), counts of ft_types with
    prefix 'urban' are summed up in the record with category 'urban all'.
    '''
    for _, group_records in groups:
        totals = {}
        for rec in group_records:
            yield rec
            for name in get_supercategories(rec['category']):
                accumulated = totals.get(name)
                if accumulated is None:
                    # First record of this supercategory starts the total.
                    totals[name] = Record(rec, category=name)
                else:
                    totals[name] = Record(
                        rec,
                        count=accumulated['count'] + rec['count'],
                        category=name
                    )
        for total in totals.values():
            yield total


def count_ft_by_region(stream, date, major_regions):
    '''
    Calculate ft counts by category (including supercategories,
    see count_ft_type_supercategories_reducer) and region_tree.

    Counts are first summed by (category, region_id), then region_tree
    is set with sp.set_region_tree and counts are summed again
    by (region_tree, category).

    stream:
    | category | count | region_id | ... |
    |----------+-------+-----------+-----|
    |      ... |   ... |       ... | ... |

    major_regions:
    | region_id | region_tree | ... |
    |-----------+-------------+-----|
    |       ... |         ... | ... |

    Result (labelled 'result_ft_by_region'),
    category is prefixed with 'ft ft_type=':
    | region_tree | count | fielddate | category |
    |-------------+-------+-----------+----------|
    |         ... |   ... |      date |      ... |
    '''
    per_region = stream.groupby('category', 'region_id').aggregate(
        count=na.sum('count')
    )
    with_region_tree = sp.set_region_tree(per_region, major_regions)

    per_region_tree = with_region_tree.groupby('region_tree', 'category').aggregate(
        count=na.sum('count')
    )
    with_supercategories = per_region_tree.groupby('region_tree').reduce(
        count_ft_type_supercategories_reducer
    ).label('ft_with_supercategories')

    result = with_supercategories.project(
        'region_tree', 'count',
        fielddate=ne.const(date),
        category=ne.custom(lambda name: 'ft ft_type=' + name, 'category')
    )
    return result.label('result_ft_by_region')


def add_ft_named_flag(job, stream, ymapsdf_path, regions):
    '''
    Keep only records with 'category' from ftc.NAMED_CATEGORIES
    and add boolean 'named' field to them.

    The records are left-joined by ft_id with ymapsdf table 'ft_nm';
    named=True if nm_id is not None after the join and named=False otherwise.

    stream:
    | ft_id | category | ... |
    |-------+----------+-----|
    |   ... |      ... | ... |

    ft_nm:
    | ft_id | nm_id | ... |
    |-------+-------+-----|
    |   ... |   ... | ... |

    Result (labelled 'ft_with_named'):
    | ft_id | named | ... |
    |-------+-------+-----|
    |   ... |   ... | ... |
    '''
    in_named_categories = stream.filter(
        nf.custom(lambda category: category in ftc.NAMED_CATEGORIES, 'category')
    ).label('filtered_ft_by_named_categories')

    ft_names = sp.ymapsdf_stream(job, 'ft_nm', ymapsdf_path, regions).unique('ft_id')
    with_nm_id = in_named_categories.join(
        ft_names,
        by='ft_id',
        type='left',
        assume_unique=True,
    ).label('ft_with_nm_id')

    with_named = with_nm_id.project(
        ne.all(),
        named=ne.custom(lambda nm_id: nm_id is not None, 'nm_id')
    )
    return with_named.label('ft_with_named')


def add_ft_named_supercategories_mapper(records):
    '''
    Mapper that sets 'category' field for named/not named statistics.

    For every input record it emits:
      * one record with category 'ft ft_type=<ft_type name> named=<named>'
      * one more record for each item of ftc.NAMED_SUPERCATEGORIES
        the ft_type name starts with; its category is built
        from the supercategory name (see make_supercategory_name)
    '''
    category_template = 'ft ft_type={} named={}'
    for rec in records:
        ft_name = get_ft_type_name(rec['ft_type_id'])
        named = rec['named']
        yield Record(rec, category=category_template.format(ft_name, named))
        for prefix in ftc.NAMED_SUPERCATEGORIES:
            if not ft_name.startswith(prefix):
                continue
            supercategory = make_supercategory_name(prefix)
            yield Record(rec, category=category_template.format(supercategory, named))


def count_ft_by_name(job, stream, ymapsdf_path, regions, major_regions, date):
    '''
    Calculate named/not named ft statistics.

    Records are filtered and flagged with add_ft_named_flag, summed
    by (region_id, ft_type_id, named), extended with region_tree by
    sp.set_region_tree, categorized by add_ft_named_supercategories_mapper
    and finally summed by (region_tree, category).

    stream (ft_id and category are read by add_ft_named_flag):
    | ft_id | ft_type_id | category | count | region_id | ... |
    |-------+------------+----------+-------+-----------+-----|
    |   ... |        ... |      ... |   ... |       ... | ... |

    major_regions:
    | region_id | region_tree | ... |
    |-----------+-------------+-----|
    |       ... |         ... | ... |

    Result (labelled 'counted_ft_by_named'):
    | region_tree | count | fielddate | category |
    |-------------+-------+-----------+----------|
    |         ... |   ... |      date |      ... |
    '''
    with_named = add_ft_named_flag(job, stream, ymapsdf_path, regions)
    per_region = with_named.groupby(
        'region_id', 'ft_type_id', 'named'
    ).aggregate(
        count=na.sum('count')
    ).label('ft_aggregated_named_and_ft_id')

    with_region_tree = sp.set_region_tree(per_region, major_regions)
    categorized = with_region_tree.map(
        add_ft_named_supercategories_mapper
    ).label('ft_with_named_supercategories')

    totals = categorized.groupby('region_tree', 'category').aggregate(
        count=na.sum('count')
    )
    result = totals.project(
        'region_tree', 'count', 'category',
        fielddate=ne.const(date),
    )
    return result.label('counted_ft_by_named')


def filter_poi(job, stream, ymapsdf_path, regions):
    '''
    Keep only records which we consider as poi:
      * 'category' is in ftc.POIS
      * face_id is None after left join with ymapsdf table 'ft_face'
      * edge_id is None after left join with ymapsdf table 'ft_edge'

    stream:
    | ft_id | category | ... |
    |-------+----------+-----|
    |   ... |      ... | ... |

    ft_face:
    | ft_id | face_id | ... |
    |-------+---------+-----|
    |   ... |     ... | ... |

    ft_edge:
    | ft_id | edge_id |
    |-------+---------|
    |   ... |     ... |

    Result (labelled 'ft_only_poi'):
    | ft_id | category | ... |
    |-------+----------+-----|
    |   ... |      ... | ... |
    '''
    poi_typed = stream.filter(
        nf.custom(lambda category: category in ftc.POIS, 'category')
    ).label('filtered_ft_by_poi_categories')

    # Drop objects that have a face.
    faces = sp.ymapsdf_stream(job, 'ft_face', ymapsdf_path, regions).unique('ft_id')
    with_face_id = poi_typed.join(
        faces,
        by='ft_id',
        type='left',
        assume_unique=True,
    ).label('ft_with_face_id')
    without_face = with_face_id.filter(
        nf.custom(lambda face_id: face_id is None, 'face_id')
    )

    # Drop objects that have an edge.
    edges = sp.ymapsdf_stream(job, 'ft_edge', ymapsdf_path, regions).unique('ft_id')
    with_edge_id = without_face.join(
        edges,
        by='ft_id',
        type='left',
        assume_unique=True,
    ).label('ft_with_edge_id')
    without_edge = with_edge_id.filter(
        nf.custom(lambda edge_id: edge_id is None, 'edge_id')
    )
    return without_edge.label('ft_only_poi')


def add_poi_categories_mapper(records):
    '''
    Mapper that sets 'category' field for poi statistics.

    For every input record it emits one record per poi category
    generated by get_poi_categories for the record 'category';
    the emitted category is 'ft ft_type=<poi category>'.
    Records without poi categories produce no output.
    '''
    for rec in records:
        poi_categories = get_poi_categories(rec['category'])
        for poi_category in poi_categories:
            stat_category = 'ft ft_type={}'.format(poi_category)
            yield Record(rec, category=stat_category)


def count_poi(job, stream, ymapsdf_path, regions, major_regions, date):
    '''
    Calculate poi counts by poi category and region_tree.

    Counts are summed by (region_id, category), extended with region_tree
    by sp.set_region_tree, categorized by add_poi_categories_mapper
    and summed again by (region_tree, category).

    job, ymapsdf_path and regions are not used by this function.

    stream:
    | ft_id | region_id | count | category | ... |
    |-------+-----------+-------+----------+-----|
    |   ... |       ... |   ... |      ... | ... |

    Result (labelled 'counted_ft_poi'):
    | region_tree | count | fielddate | category |
    |-------------+-------+-----------+----------|
    |         ... |   ... |      date |      ... |
    '''
    per_region = stream.groupby(
        'region_id', 'category'
    ).aggregate(
        count=na.sum('count')
    ).label('ft_aggregated_poi')

    with_region_tree = sp.set_region_tree(per_region, major_regions)
    categorized = with_region_tree.map(
        add_poi_categories_mapper
    ).label('ft_with_poi_categories')

    totals = categorized.groupby('region_tree', 'category').aggregate(
        count=na.sum('count')
    )
    result = totals.project(
        'region_tree', 'count', 'category',
        fielddate=ne.const(date),
    )
    return result.label('counted_ft_poi')


def count_poi_by_position_quality(job, stream, ymapsdf_path, regions, major_regions, date):
    '''
    Calculate poi counts for different position_quality values.

    The stream is inner-joined by ft_id with ymapsdf table 'ft_poi_attr',
    split by position_quality into 'precise', 'user' and the remaining
    records, and every part is counted with sp.count_stream.

    stream:
    | ft_id | region_id | count | ... |
    |-------+-----------+-------+-----|
    |   ... |       ... |   ... | ... |

    ft_poi_attr:
    | ft_id | position_quality | ... |
    |-------+------------------+-----|
    |   ... |              ... | ... |

    Result (labelled 'counted_ft_poi_by_position_quality'):
    | region_tree | count | fielddate | category |
    |-------------+-------+-----------+----------|
    |         ... |   ... |      date |      ... |
    '''
    poi_attrs = sp.ymapsdf_stream(job, 'ft_poi_attr', ymapsdf_path, regions).unique('ft_id')
    with_quality = stream.join(
        poi_attrs,
        by='ft_id',
        type='inner',
        assume_unique=True,
    )

    quality_values = c.POI_POSITION_QUALITY
    # NOTE(review): the third stream presumably receives records that match
    # neither filter - confirm against the split() documentation.
    (precise, user, not_set) = with_quality.split(
        nf.equals('position_quality', quality_values['precise']),
        nf.equals('position_quality', quality_values['user']),
        multisplit=True,
        strategy='stop_if_true'
    )

    parts = (
        (precise, 'ft position_quality=precise'),
        (user, 'ft position_quality=user'),
        (not_set, 'ft position_quality=not_set'),
    )
    counted = [
        sp.count_stream(job, part, category, 'count', date, major_regions)
        for part, category in parts
    ]
    return job.concat(*counted).label('counted_ft_poi_by_position_quality')
