from . import constants as c
from . import ad_processors as ap
from . import ft_processors as fp
from . import stream_processors as sp
from nile.api.v1 import (
    aggregators as na,
    clusters,
    datetime as nd,
    extractors as ne,
    filters as nf,
    statface as ns
)


def make_ad_report_data(job, date, ymapsdf_path, regions, major_regions):
    """Assemble the report stream for the ``ad`` table.

    The ``ad`` stream is de-duplicated by ``ad_id``, passed through
    ``sp.set_count_field`` and ``ap.add_ad_region_and_shape_type``, and then
    fed to three counters whose outputs are concatenated:

    * ``ap.count_ad_by_shape`` on the base stream;
    * ``sp.count_stream_by_field`` on field ``town`` (on the stream returned
      by ``ap.add_ad_town``);
    * ``sp.count_stream_by_field`` on field ``level_kind`` (on the base stream).

    Returns the concatenated stream labelled ``result_ad``.
    """
    ad_stream = sp.ymapsdf_stream(job, 'ad', ymapsdf_path, regions)
    ad_stream = sp.set_count_field(ad_stream.unique('ad_id'))
    ad_stream = ap.add_ad_region_and_shape_type(job, ad_stream, ymapsdf_path, regions)

    # Only the town counter uses this branch; the other counters read ad_stream.
    ad_with_town = ap.add_ad_town(job, ad_stream, ymapsdf_path, regions)

    by_town = sp.count_stream_by_field(
        job, ad_with_town, 'ad', 'town', 'count', date, major_regions
    )
    by_level_kind = sp.count_stream_by_field(
        job, ad_stream, 'ad', 'level_kind', 'count', date, major_regions
    )
    by_shape = ap.count_ad_by_shape(job, ad_stream, date, major_regions)

    combined = job.concat(by_shape, by_town, by_level_kind)
    return combined.label('result_ad')


def make_ft_report_data(job, date, ymapsdf_path, regions, major_regions):
    """Assemble the report stream for the ``ft`` table.

    The ``ft`` stream is de-duplicated by ``ft_id`` and passed through
    ``fp.add_ft_region``, ``sp.set_count_field`` and ``fp.add_ft_category``.
    A second stream is obtained from it with ``fp.filter_poi``.

    Returns the concatenation of the four ``fp.count_*`` outputs, labelled
    ``result_ft``.
    """
    ft_stream = sp.ymapsdf_stream(job, 'ft', ymapsdf_path, regions)
    ft_stream = fp.add_ft_region(job, ft_stream.unique('ft_id'), ymapsdf_path, regions)
    ft_stream = fp.add_ft_category(sp.set_count_field(ft_stream))

    poi_stream = fp.filter_poi(job, ft_stream, ymapsdf_path, regions)

    # The first two counters read the full stream, the last two the POI-only one.
    parts = [
        fp.count_ft_by_region(ft_stream, date, major_regions),
        fp.count_ft_by_name(job, ft_stream, ymapsdf_path, regions, major_regions, date),
        fp.count_poi(job, poi_stream, ymapsdf_path, regions, major_regions, date),
        fp.count_poi_by_position_quality(job, poi_stream, ymapsdf_path, regions, major_regions, date),
    ]
    return job.concat(*parts).label('result_ft')


def make_ft_ft_report_data(job, date, ymapsdf_path, regions, major_regions):
    """Assemble the report stream for ``ft_ft`` rows with role ``poi-entrance-assigned``.

    Two streams of distinct ``ft_id`` values are derived from the filtered
    relation, one from ``master_ft_id`` and one from ``slave_ft_id``. Each is
    passed through ``fp.add_ft_region`` and ``sp.set_count_field`` and then
    counted with ``sp.count_stream``.

    Returns the concatenated stream labelled ``result_poi_entrance_assigned``.
    """
    assigned = sp.ymapsdf_stream(job, 'ft_ft', ymapsdf_path, regions)
    assigned = assigned.filter(nf.equals('role', 'poi-entrance-assigned'))

    def distinct_side(id_field):
        # Expose one side of the relation as ``ft_id``, de-duplicate it,
        # then attach the region and the count field.
        side = assigned.project(ft_id=id_field).unique('ft_id')
        side = fp.add_ft_region(job, side, ymapsdf_path, regions)
        return sp.set_count_field(side)

    pois = distinct_side('master_ft_id')
    entrances = distinct_side('slave_ft_id')

    pois_count = sp.count_stream(
        job, pois, 'ft poi-with-entrances', 'count', date, major_regions
    )
    entrances_count = sp.count_stream(
        job, entrances, 'ft entrances-for-poi', 'count', date, major_regions
    )
    return job.concat(pois_count, entrances_count).label('result_poi_entrance_assigned')


def make_flat_range_report_data(job, date, ymapsdf_path, regions, major_regions):
    """Assemble the report stream for the ``entrance_flat_range`` table.

    Three counters are concatenated:

    * entrances (``ft_id``) split in two by whether the minimum of
      ``is_exact`` over their ranges equals ``False``;
    * the number of flats, summed per ``ft_id`` from the
      ``flat_first``/``flat_last`` bounds of each range.

    Returns the concatenated stream labelled ``result_flat_range``.
    """
    def count_flats_in_range(first, last):
        # If flat number contains letter - it will be both in 'first' and 'last'
        if first == last:
            return 1
        return int(last) - int(first) + 1

    # Per-entrance exactness: ``all_exact`` is the minimum of ``is_exact``.
    ranges = sp.ymapsdf_stream(job, 'entrance_flat_range', ymapsdf_path, regions)
    ranges = ranges.groupby('ft_id').aggregate(all_exact=na.min('is_exact'))
    ranges = fp.add_ft_region(job, ranges, ymapsdf_path, regions)
    ranges = sp.set_count_field(ranges)

    # NOTE(review): the names assume split() yields the non-matching records
    # first and the matching ones second - confirm against the Nile docs.
    not_exact_filter = nf.equals('all_exact', False)
    exact_entrances, inexact_entrances = ranges.split(not_exact_filter)

    # Per-entrance flat totals, computed from a fresh read of the same table.
    flats = sp.ymapsdf_stream(job, 'entrance_flat_range', ymapsdf_path, regions)
    flats = flats.project(
        'ft_id',
        count=ne.custom(count_flats_in_range, 'flat_first', 'flat_last')
    )
    flats = flats.groupby('ft_id').aggregate(count=na.sum('count'))
    flats = fp.add_ft_region(job, flats, ymapsdf_path, regions)

    parts = [
        sp.count_stream(
            job, exact_entrances, 'ft ft_type=urban-entrance flat_range exact',
            'count', date, major_regions
        ),
        sp.count_stream(
            job, inexact_entrances, 'ft ft_type=urban-entrance flat_range not exact',
            'count', date, major_regions
        ),
        sp.count_stream(job, flats, 'flat', 'count', date, major_regions),
    ]
    return job.concat(*parts).label('result_flat_range')


def make_rd_report_data(job, date, ymapsdf_path, regions, major_regions, report_descriptor):
    """Assemble the report stream for the ``rd`` table.

    The ``rd`` stream is de-duplicated by ``rd_id``, passed through
    ``sp.add_rd_region`` and ``report_descriptor.preprocess_stream``, and
    labelled ``rd_preprocessed``.

    Returns the result of ``sp.calc_rd_by_rd_type`` for
    ``report_descriptor.field_to_aggregate``.
    """
    rd_stream = sp.ymapsdf_stream(job, 'rd', ymapsdf_path, regions)
    rd_stream = sp.add_rd_region(job, rd_stream.unique('rd_id'), ymapsdf_path, regions)
    rd_stream = report_descriptor.preprocess_stream(rd_stream)
    rd_stream = rd_stream.label('rd_preprocessed')

    field = report_descriptor.field_to_aggregate
    return sp.calc_rd_by_rd_type(job, rd_stream, date, major_regions, field)


def prepare_stream(job, table_name, ymapsdf_path, regions):
    """Read a ymapsdf table and attach region information to it.

    Tables listed in ``c.TABLE_PROCESSORS`` go through their optional
    ``preprocessor`` and then their ``region_getter``. Any other table goes
    through ``sp.add_default_region``.

    Returns the resulting stream.
    """
    raw = sp.ymapsdf_stream(job, table_name, ymapsdf_path, regions)

    if table_name not in c.TABLE_PROCESSORS:
        return sp.add_default_region(job, raw, ymapsdf_path, regions)

    processor = c.TABLE_PROCESSORS[table_name]
    preprocess = processor.preprocessor
    if preprocess is not None:
        raw = preprocess(job, raw, ymapsdf_path, regions)
    return processor.region_getter(job, raw, ymapsdf_path, regions)


def make_report_data(job, stream, date, report_descriptor, table_name, major_regions):
    """Aggregate ``stream`` per category filter and region tree.

    For every category filter in ``report_descriptor.table_filters[table_name]``:

    1. apply the filter's ``predicate`` when it is not ``None``;
    2. sum ``report_descriptor.field_to_aggregate`` per ``region_id``;
    3. left-join ``major_regions`` by ``region_id`` and substitute
       ``sp.DEFAULT_REGION_TREE`` where ``region_tree`` is falsy;
    4. sum again per ``region_tree`` and attach ``fielddate`` (``date``) and
       ``category`` (the filter's name).

    Returns the concatenation of the per-category streams. When there is more
    than one category, the result is labelled ``result_<table_name>``.
    """
    field = report_descriptor.field_to_aggregate
    results = []

    for category_filter in report_descriptor.table_filters[table_name]:
        filtered_stream = stream
        if category_filter.predicate is not None:
            filtered_stream = stream.filter(
                category_filter.predicate
            ).label('filtered_{}'.format(category_filter.name))

        per_region = filtered_stream.groupby('region_id').aggregate(
            aggregated_field=na.sum(field)
        )
        with_region_tree = per_region.join(
            major_regions,
            by='region_id',
            type='left',
            assume_small_right=True,
            allow_undefined_keys=False,
            assume_defined=True,
            memory_limit=8 * 1024
        ).project(
            'aggregated_field',
            # Left join: rows with no region_tree fall back to the default one.
            region_tree=ne.custom(lambda rt: rt or sp.DEFAULT_REGION_TREE, 'region_tree')
        )
        per_region_tree = with_region_tree.groupby('region_tree').aggregate(
            **{field: na.sum('aggregated_field')}
        )
        results.append(per_region_tree.project(
            'region_tree', field,
            fielddate=ne.const(date),
            category=ne.const(category_filter.name)
        ).label('counted_{}'.format(category_filter.name)))

    report_data = job.concat(*results)
    if len(results) > 1:
        # Avoid multiple label() calls for the same stream
        # (presumably a single-input concat is already labelled 'counted_...').
        # label() returns the labelled stream, as the other call sites in this
        # module rely on, so re-bind it instead of discarding the result.
        report_data = report_data.label('result_{}'.format(table_name))
    return report_data


def make_rd_el_report_data(job, date, ymapsdf_path, regions, major_regions, report_descriptor):
    """Assemble the report stream for the ``rd_el`` table.

    The prepared and preprocessed ``rd_el`` stream is counted per category via
    ``make_report_data`` and additionally via ``sp.count_stream_by_field`` on
    the fields ``fc``, ``named``, ``restricted`` and ``universal_id``. The last
    three are counted on the streams returned by the matching
    ``sp.add_rd_el_*_field`` helpers.

    Returns the concatenated stream labelled ``result_rd_el_full``.
    """
    base = prepare_stream(job, 'rd_el', ymapsdf_path, regions)
    base = report_descriptor.preprocess_stream(base)

    with_named = sp.add_rd_el_named_field(job, base, ymapsdf_path, regions)
    with_restricted = sp.add_rd_el_restriction_field(job, base, ymapsdf_path, regions)
    with_universal = sp.add_rd_el_universal_field(job, base, ymapsdf_path, regions)

    # (source stream, field to count by), in the order they are concatenated.
    sources_and_fields = (
        (base, 'fc'),
        (with_named, 'named'),
        (with_restricted, 'restricted'),
        (with_universal, 'universal_id'),
    )
    field_counts = [
        sp.count_stream_by_field(
            job,
            source,
            'rd_el',
            field_name,
            report_descriptor.field_to_aggregate,
            date,
            major_regions
        )
        for source, field_name in sources_and_fields
    ]

    category_counts = make_report_data(job, base, date, report_descriptor, 'rd_el', major_regions)
    return job.concat(category_counts, *field_counts).label('result_rd_el_full')


def make_cond_report_data(job, date, ymapsdf_path, regions, major_regions, report_descriptor):
    """Assemble the report stream for the ``cond`` table.

    Three streams are concatenated:

    * the per-category counts from ``make_report_data``;
    * ``sp.count_stream_by_field`` on field ``restricted`` (on the stream
      returned by ``sp.add_cond_restriction_field``);
    * the output of ``sp.add_cond_vehicle_restricted_field``.

    Returns the concatenated stream labelled ``result_cond_full``.
    """
    cond_stream = prepare_stream(job, 'cond', ymapsdf_path, regions)
    cond_stream = report_descriptor.preprocess_stream(cond_stream)

    with_restriction = sp.add_cond_restriction_field(job, cond_stream, ymapsdf_path, regions)
    by_restricted = sp.count_stream_by_field(
        job, with_restriction, 'cond', 'restricted',
        report_descriptor.field_to_aggregate, date, major_regions
    )
    by_vehicle_restriction = sp.add_cond_vehicle_restricted_field(
        job, cond_stream, ymapsdf_path, regions,
        report_descriptor.field_to_aggregate, date, major_regions
    )
    by_category = make_report_data(job, cond_stream, date, report_descriptor, 'cond', major_regions)

    combined = job.concat(by_category, by_restricted, by_vehicle_restriction)
    return combined.label('result_cond_full')


def publish_data(report_data, report_name, scale, statface_client, job, date):
    """Publish ``report_data`` to the Statface report ``report_name``.

    The data is always published at ``scale``. When ``scale`` is ``'daily'``,
    it is also published at each of ``'weekly'``, ``'monthly'`` and
    ``'yearly'`` for which ``nd.round_period(date, <scale>)`` equals ``date``.

    Returns ``job`` unchanged.
    """
    def publish_at(target_scale):
        # Build a report handle for one scale and publish the data to it.
        report = ns.StatfaceReport()
        report = report.path(report_name).scale(target_scale).client(statface_client)
        report_data.publish(report, allow_change_job=True)

    publish_at(scale)

    if scale != 'daily':
        return job

    for coarser_scale in ('weekly', 'monthly', 'yearly'):
        if nd.round_period(date, coarser_scale) == date:
            publish_at(coarser_scale)
    return job


def make_job(
    job, scale, statface_client, date, report_descriptor,
    ymapsdf_path, regions, tables, result_path=None
):
    """Build and publish report data for every table in ``tables``.

    For each table the matching ``make_*_report_data`` builder is used. Tables
    without a dedicated builder go through ``prepare_stream``,
    ``report_descriptor.preprocess_stream`` and ``make_report_data``.
    When ``result_path`` is given, an empty list is first written to it on the
    Hahn cluster and each table's data is then appended to it. Each table's
    data is published through ``publish_data``.

    Every table name must be in ``report_descriptor.available_tables``
    (checked with ``assert``).

    Returns the job returned by the last ``publish_data`` call, or ``job``
    itself when ``tables`` is empty.
    """
    # Builders called as (job, date, ymapsdf_path, regions, major_regions).
    plain_builders = {
        'ad': make_ad_report_data,
        'ft': make_ft_report_data,
        'ft_ft': make_ft_ft_report_data,
        'entrance_flat_range': make_flat_range_report_data,
    }
    # Builders that additionally take report_descriptor as the last argument.
    descriptor_builders = {
        'rd': make_rd_report_data,
        'rd_el': make_rd_el_report_data,
        'cond': make_cond_report_data,
    }

    save_result = result_path is not None
    if save_result:
        # Presumably resets the table so the appends below start from empty.
        clusters.yt.Hahn().write(result_path, [])

    for table_name in tables:
        assert table_name in report_descriptor.available_tables

        major_regions = job.table(c.MAJOR_REGIONS_TABLE).label('major_regions')
        common_args = (job, date, ymapsdf_path, regions, major_regions)

        if table_name in plain_builders:
            report_data = plain_builders[table_name](*common_args)
        elif table_name in descriptor_builders:
            report_data = descriptor_builders[table_name](*(common_args + (report_descriptor,)))
        else:
            generic_stream = prepare_stream(job, table_name, ymapsdf_path, regions)
            generic_stream = report_descriptor.preprocess_stream(generic_stream)
            report_data = make_report_data(
                job, generic_stream, date, report_descriptor, table_name, major_regions
            )

        if save_result:
            report_data.put(result_path, append=True)

        job = publish_data(
            report_data, report_descriptor.report_name, scale, statface_client, job, date
        )
    return job
