import stream_processors as sp
from nile.api.v1 import (
    aggregators as na,
    extractors as ne,
    utils as nu,
    Record,
)
from maps.wikimap.stat.libs.common.lib import geobase_region


# Geocoder toponyms table. It is read with job.table() in
# add_ad_region_and_shape_type and left-joined on mutable_source_id == ad_id
# to obtain each ad's geo_id.
GEOCODER_SOURCE_ID_MAP = '//home/maps/geocoder/geosrc/latest_state/toponyms'


@nu.with_hints(output_schema={
    'shape_type': str,
    'region_tree': str,
    'count': int,
})
def add_any_shape_ad_count_reducer(groups):
    '''
    Emit every record of each group unchanged, then one extra record with
    shape_type='all' whose count is the sum of the group's 'count' values.

    The extra record is built from the group's last record, so it inherits
    that record's remaining fields (e.g. region_tree).
    '''
    for _key, group_records in groups:
        total = 0
        for current in group_records:
            yield current
            total += current['count']
        # After the loop `current` is the last record of the group.
        yield Record(current, count=total, shape_type='all')


def set_ad_region_reducer(groups):
    '''
    Fill in geo_id for every ad in each group.

    An ad that already has a geo_id keeps it. An ad without one takes the
    geo_id of its nearest ancestor (following p_ad_id) that has one. If the
    chain reaches a root ad (p_ad_id is None) without finding a geo_id, every
    ad on that chain, the root included, gets geobase_region.EARTH_REGION_ID.

    All ancestors of an ad must be in the same group as the ad itself.

    Raises:
        AssertionError: a record has ad_id None.
        KeyError: a p_ad_id does not match any ad_id in the group.
        ValueError: the p_ad_id links form a cycle.
    '''
    for _, records in groups:
        ad2record = {}
        for record in records:
            assert record['ad_id'] is not None
            ad2record[record['ad_id']] = record

        def resolve_geo_id(ad_id):
            # Walk up the hierarchy iteratively, so there is no dependency on
            # the interpreter recursion limit. Collect every ad on the way
            # that still lacks a geo_id.
            unresolved = []
            visited = set()
            current_id = ad_id
            while True:
                if current_id in visited:
                    raise ValueError(
                        'cycle in ad hierarchy (p_ad_id) at ad_id {}'.format(current_id))
                visited.add(current_id)
                record = ad2record[current_id]
                geo_id = record.get('geo_id')
                if geo_id is not None:
                    break
                unresolved.append(current_id)
                parent_id = record.get('p_ad_id')
                if parent_id is None:
                    # Root ad without a geo_id: fall back to the whole Earth.
                    # The root itself is stored below as well, so it is not
                    # yielded with geo_id None.
                    geo_id = geobase_region.EARTH_REGION_ID
                    break
                current_id = parent_id
            # Memoize on the whole chain so later lookups stop early.
            for unresolved_id in unresolved:
                ad2record[unresolved_id] = Record(ad2record[unresolved_id], geo_id=geo_id)
            return geo_id

        for ad_id in ad2record:
            resolve_geo_id(ad_id)
            yield ad2record[ad_id]


def add_ad_region_and_shape_type(job, stream, ymapsdf_path, regions):
    '''
    Add region_id and shape_type properties to each ad.

    geo_id is taken from the geocoder_source_id_map table, whose
    mutable_source_id field is matched against ad_id in a left join. Ads
    without a match then get geo_id from their p_ad_id hierarchy in
    set_ad_region_reducer. region_id is that geo_id converted to str.

    shape_type is 'polygon' if the left join with ad_face yields a face_id
    for the ad (the ad has geometry), and 'point' otherwise.

    stream:
    | ad_id | count | level_kind | ... |
    |-------+-------+------------+-----|
    |   ... |     1 |        ... | ... |

    ad_face:
    | ad_id | face_id | ... |
    |-------+---------+-----|
    |   ... |     ... | ... |

    geocoder_source_id_map:
    | mutable_source_id | geo_id | ... |
    |-------------------+--------+-----|
    |               ... |    ... | ... |

    Result:
    | ad_id | count | shape_type | region_id | level_kind |
    |-------+-------+------------+-----------+------------|
    |   ... |   ... |        ... |       ... |        ... |
    '''
    unique_ads = stream.unique('ad_id')
    geocoder_map = job.table(GEOCODER_SOURCE_ID_MAP).label('geocoder_source_id_map')
    ads_with_geocoder = unique_ads.join(
        geocoder_map,
        by_left='ad_id',
        by_right='mutable_source_id',
        type='left',
        assume_defined=True,
        assume_unique=True,
        memory_limit=1024,
    )

    # Crutch: a key-less groupby() puts the whole table into one group, so the
    # reducer sees every ad and can follow p_ad_id links between them.
    ads_with_geoid = ads_with_geocoder.groupby().reduce(
        set_ad_region_reducer,
        memory_limit=1024 * 10,
    ).label('ad_with_geoid')

    ad_faces = sp.ymapsdf_stream(job, 'ad_face', ymapsdf_path, regions).unique('ad_id')
    ads_with_faces = ads_with_geoid.join(
        ad_faces,
        by='ad_id',
        type='left',
        assume_unique=True,
        assume_small_right=True,
        memory_limit=1024,
    )

    return ads_with_faces.project(
        'ad_id', 'count', 'level_kind',
        region_id=ne.custom(lambda geo_id: str(geo_id), 'geo_id'),
        shape_type=ne.custom(
            lambda face_id: 'polygon' if face_id is not None else 'point',
            'face_id',
        ),
    ).label('ad_with_region_and_shape_type')


def add_ad_town(job, stream, ymapsdf_path, regions):
    '''
    Attach the ymapsdf locality row (e.g. town) to each ad by ad_id.

    The join is a left join, so ads without a locality row are kept.

    stream:
    | ad_id | ... |
    |-------+-----|
    |   ... | ... |

    locality:
    | ad_id | town | ... |
    |-------+------+-----|
    |   ... |  ... | ... |

    Result:
    | ad_id | town | ... |
    |-------+------+-----|
    |   ... |  ... | ... |
    '''
    localities = sp.ymapsdf_stream(job, 'locality', ymapsdf_path, regions).unique('ad_id')
    joined = stream.join(
        localities,
        by='ad_id',
        type='left',
        assume_unique_left=True,
        assume_small_right=True,
        memory_limit=1024 * 8,
    )
    return joined.label('ad_with_town')


def count_ad_by_shape(job, stream, date, major_regions):
    '''
    Count ads per region tree and shape type, plus an 'all' shape total.

    Counts are summed per (region_id, shape_type), mapped to region_tree
    via sp.set_region_tree, and summed again per (region_tree, shape_type).
    add_any_shape_ad_count_reducer then appends the per-tree 'all' row.
    category is 'ad ' + shape_type. The job argument is not used here.

    stream:
    | ad_id | count | shape_type | region_id |
    |-------+-------+------------+-----------|
    |   ... |   ... |        ... |       ... |

    major_regions:
    | region_id | region_tree | ... |
    |-----------+-------------+-----|
    |       ... |         ... | ... |

    Result:
    | region_tree | count | fielddate | category |
    |-------------+-------+-----------+----------|
    |         ... |   ... |      date |      ... |
    '''
    per_region = stream.groupby('region_id', 'shape_type').aggregate(
        count=na.sum('count'),
    ).label('region_id_aggregated_ad')

    with_tree = sp.set_region_tree(per_region, major_regions)

    per_tree = with_tree.groupby('region_tree', 'shape_type').aggregate(
        count=na.sum('count'),
    ).label('counted_ad_by_shape_type')

    with_totals = per_tree.groupby('region_tree').reduce(
        add_any_shape_ad_count_reducer,
    ).label('ad_after_add_any_shape_ad_count_reducer')

    return with_totals.project(
        'region_tree', 'count',
        fielddate=ne.const(date),
        category=ne.custom(lambda shape_type: 'ad ' + shape_type, 'shape_type'),
    ).label('counted_ad_shape_type')
