from . import processors as pr
from library.python import resource
from yandex.maps import geolib3
import yt.wrapper as yt

from nile.api.v1 import (
    aggregators as na,
    extractors as ne,
    stream as ns,
    Job
)
from qb2.api.v1 import (
    extractors as qe,
    filters as qf
)

import codecs
import json
import typing as tp
import xml.etree.ElementTree as ET


def get_merged_ft_type_ids() -> tp.Set[int]:
    """Load the set of ft_type_id values from the bundled ``merged_ft_type_ids.xml``.

    Every ``<ft-type-id>`` element in the document (at any depth) contributes
    its text content parsed as an ``int``. An element with missing or
    non-numeric text makes ``int`` raise ``ValueError``.
    """
    document = resource.find('merged_ft_type_ids.xml')
    tree_root = ET.fromstring(document)
    # `node.text` may be None; '' keeps the int() call well-typed (and fails loudly).
    return {int(node.text or '') for node in tree_root.iter('ft-type-id')}


def get_protected_ft_type_ids() -> tp.Set[int]:
    """Load the set of protected ft_type_id values from ``protected_ft_type_ids.json``.

    The resource is parsed as JSON; its ``protected_ft_type_ids`` entry is
    iterated and every item is converted with ``int``.
    """
    payload = json.loads(resource.find('protected_ft_type_ids.json'))
    raw_ids = payload['protected_ft_type_ids']
    return set(map(int, raw_ids))


def wkb_to_lon(wkb: tp.Optional[str]) -> tp.Optional[str]:
    """Return the X coordinate (longitude) of a hex-encoded EWKB point as a string.

    Args:
        wkb: hex string holding an EWKB point; empty strings and None are allowed.

    Returns:
        ``str(point.x)`` of the decoded point, or None when ``wkb`` is empty/None
        (hence the Optional return type: callers receive None for missing geometry).
    """
    if not wkb:
        return None
    point = geolib3.Point2.from_EWKB(codecs.decode(wkb, 'hex'))
    return str(point.x)


def wkb_to_lat(wkb: tp.Optional[str]) -> tp.Optional[str]:
    """Return the Y coordinate (latitude) of a hex-encoded EWKB point as a string.

    Args:
        wkb: hex string holding an EWKB point; empty strings and None are allowed.

    Returns:
        ``str(point.y)`` of the decoded point, or None when ``wkb`` is empty/None
        (hence the Optional return type: callers receive None for missing geometry).
    """
    if not wkb:
        return None
    point = geolib3.Point2.from_EWKB(codecs.decode(wkb, 'hex'))
    return str(point.y)


def _geom_distance(geom_a, geom_b, is_closed) -> tp.Optional[float]:
    """Return ``geolib3.fast_geodistance`` between two hex-encoded EWKB points.

    Returns None when either geometry is empty/None or when ``is_closed`` is
    truthy, so closed features never get a distance.
    """
    if not geom_a or not geom_b or is_closed:
        return None
    return geolib3.fast_geodistance(
        geolib3.Point2.from_EWKB(codecs.decode(geom_a, 'hex')),
        geolib3.Point2.from_EWKB(codecs.decode(geom_b, 'hex'))
    )


def make_raw_data(
    ft: ns.Stream,
    ft_poi_attr: ns.Stream,
    ft_center: ns.Stream,
    node: ns.Stream,
    altay_data: ns.Stream,
    ft_source: ns.Stream
) -> ns.Stream:
    """Build the per-feature table that all statistics in this module are computed from.

    Joins features with their POI attributes, the ``source_id`` of sources with
    ``source_type_id == 1`` (as ``permalink_id_ymapsdf``), the parent feature
    type, the point geometry (``geom_nk``, via ``pr.add_point_geom``) and
    ``altay_data`` (left joins on ``ft_id``). Only features whose
    ``ft_type_id`` is in ``get_merged_ft_type_ids()`` are kept. A ``dist``
    column holds the distance between ``geom_nk`` and ``geom_altay`` (None for
    closed features or missing geometry), and ``pr.add_region_id`` adds a region
    id based on ``geom_nk``.
    """
    # Sources of type 1 (presumably organizations, per the original variable
    # name) provide the permalink recorded in ymapsdf.
    ymapsdf_permalinks = ft_source.filter(
        qf.equals('source_type_id', 1)
    ).project(
        ft_id='ft_id',
        permalink_id_ymapsdf='source_id'
    )

    # The feature table renamed so it can be joined back as the parent.
    parents = ft.project(p_ft_id='ft_id', p_ft_type_id='ft_type_id')

    with_attrs = ft.join(ft_poi_attr, by='ft_id', type='left', assume_unique=True)
    with_permalinks = with_attrs.join(
        ymapsdf_permalinks, by='ft_id', type='left', assume_unique=True
    )
    with_parents = with_permalinks.join(parents, by='p_ft_id', type='left')

    nk_features = pr.add_point_geom(with_parents, ft_center, node, 'geom_nk')

    allowed_types = get_merged_ft_type_ids()

    with_altay = nk_features.join(altay_data, by='ft_id', type='left', assume_unique=True)
    merged_only = with_altay.filter(qf.one_of('ft_type_id', allowed_types))

    projected = merged_only.project(
        'ft_id', 'ft_type_id',
        'p_ft_id', 'p_ft_type_id',
        'permalink_id', 'permalink_id_ymapsdf',
        'geom_nk', 'geom_altay',
        'position_quality', 'is_closed',
        dist=ne.custom(_geom_distance, 'geom_nk', 'geom_altay', 'is_closed').with_type(float)
    )
    return pr.add_region_id(projected, 'geom_nk')


# Integer metric columns emitted by calculate_metrics, in output order. Each is
# coalesced from the column of the same name, defaulting to 0 when absent
# (e.g. a region with no ambiguous permalinks after the left joins).
_METRIC_COLUMNS = (
    'total_count',
    'verified_count',
    'to_close',
    'no_permalink',
    'ambiguous_permalinks',
    'ambiguous_ft_ids',
    'ambiguous_permalinks_ymapsdf',
    'ambiguous_ft_ids_ymapsdf',
    'dist_equal',
    'dist_over_half_m',
    'dist_over_5m',
    'dist_over_10m',
    'dist_over_20m',
    'dist_over_50m',
    'dist_over_100m',
    'dist_over_500m',
    'pos_quality_user',
    'pos_quality_precise',
    'pos_quality_empty',
)


def _ambiguous_permalink_metrics(
    raw_data: ns.Stream,
    permalink_field: str,
    permalinks_column: str,
    ft_ids_column: str
) -> ns.Stream:
    """Count, per region, permalinks shared by more than one non-closed feature.

    Args:
        raw_data: stream with ``region_name``, ``is_closed`` and ``permalink_field``.
        permalink_field: column whose duplicated values are counted.
        permalinks_column: output column with the number of such permalinks.
        ft_ids_column: output column with the total number of features sharing them.

    Returns:
        Stream with ``region_name`` plus the two output columns.
    """
    return raw_data.filter(
        qf.and_(
            qf.defined(permalink_field),
            qf.not_(qf.nonzero('is_closed'))
        )
    ).groupby(
        'region_name', permalink_field
    ).aggregate(
        count=na.count()
    ).filter(
        qf.compare('count', '>', 1)
    ).groupby(
        'region_name'
    ).aggregate(
        **{
            permalinks_column: na.count(),
            ft_ids_column: na.sum('count'),
        }
    )


def calculate_metrics(
    raw_data: ns.Stream,
    major_regions: ns.Stream,
    group_name: str
) -> ns.Stream:
    """Aggregate per-region quality metrics for one POI group.

    Args:
        raw_data: rows produced by ``make_raw_data`` (possibly pre-filtered).
        major_regions: regions used by ``pr.add_region_name`` to label rows.
        group_name: value written to the constant ``group`` column.

    Returns:
        One row per ``region_name`` with ``group`` and every column in
        ``_METRIC_COLUMNS`` as an int (missing values become 0).
    """
    raw_data = pr.add_region_name(raw_data, major_regions)

    metrics = raw_data.groupby('region_name').aggregate(
        total_count=na.count(),
        # position_quality 5 ("user") and 4 ("precise") count as verified.
        verified_count=na.count(predicate=qf.or_(
            qf.equals('position_quality', 5),
            qf.equals('position_quality', 4)
        )),

        to_close=na.count(predicate=qf.nonzero('is_closed')),
        no_permalink=na.count(predicate=qf.not_(qf.defined('permalink_id'))),

        # Distance buckets; bucket names imply metres. `default=False`
        # presumably keeps rows without a `dist` out of every bucket.
        dist_equal=na.count(predicate=qf.compare('dist', '<=', value=0.5, default=False)),
        dist_over_half_m=na.count(predicate=qf.compare('dist', '>', value=0.5, default=False)),
        dist_over_5m=na.count(predicate=qf.compare('dist', '>', value=5.0, default=False)),
        dist_over_10m=na.count(predicate=qf.compare('dist', '>', value=10.0, default=False)),
        dist_over_20m=na.count(predicate=qf.compare('dist', '>', value=20.0, default=False)),
        dist_over_50m=na.count(predicate=qf.compare('dist', '>', value=50.0, default=False)),
        dist_over_100m=na.count(predicate=qf.compare('dist', '>', value=100.0, default=False)),
        dist_over_500m=na.count(predicate=qf.compare('dist', '>', value=500.0, default=False)),

        pos_quality_user=na.count(predicate=qf.equals('position_quality', 5)),
        pos_quality_precise=na.count(predicate=qf.equals('position_quality', 4)),
        pos_quality_empty=na.count(predicate=qf.or_(
            qf.not_(qf.defined('position_quality')),
            qf.equals('position_quality', 0))
        )
    )

    ambiguous_altay = _ambiguous_permalink_metrics(
        raw_data,
        'permalink_id',
        'ambiguous_permalinks',
        'ambiguous_ft_ids'
    )
    ambiguous_ymapsdf = _ambiguous_permalink_metrics(
        raw_data,
        'permalink_id_ymapsdf',
        'ambiguous_permalinks_ymapsdf',
        'ambiguous_ft_ids_ymapsdf'
    )

    return metrics.join(
        ambiguous_altay,
        by='region_name',
        type='left',
        assume_unique=True
    ).join(
        ambiguous_ymapsdf,
        by='region_name',
        type='left',
        assume_unique=True
    ).project(
        qe.const('group', group_name),
        'region_name',
        # Generated from one name list so each output column is always filled
        # from its own source column (the *_ymapsdf metrics must not read the
        # altay columns).
        *[qe.coalesce(column, column, 0).with_type(int) for column in _METRIC_COLUMNS]
    )


def calculate_dist_data(
    raw_data: ns.Stream,
    major_regions: ns.Stream,
    group_name: str
) -> ns.Stream:
    """List features whose ``dist`` is defined and greater than 0.5.

    Output columns: ``ft_id``, ``permalink_id``, ``dist``, constant ``group``,
    ``region_name`` and string lon/lat for both ``geom_nk`` and ``geom_altay``.
    """
    labelled = pr.add_region_name(raw_data, major_regions)
    displaced = labelled.filter(
        qf.and_(qf.defined('dist'), qf.compare('dist', '>', 0.5))
    )

    coordinates = {
        'lon_nk': ne.custom(wkb_to_lon, 'geom_nk').with_type(str),
        'lat_nk': ne.custom(wkb_to_lat, 'geom_nk').with_type(str),
        'lon_altay': ne.custom(wkb_to_lon, 'geom_altay').with_type(str),
        'lat_altay': ne.custom(wkb_to_lat, 'geom_altay').with_type(str),
    }
    return displaced.project(
        'ft_id', 'permalink_id', 'dist',
        qe.const('group', group_name),
        'region_name',
        **coordinates
    )


def calculate_close_data(
    raw_data: ns.Stream,
    major_regions: ns.Stream,
    group_name: str
) -> ns.Stream:
    """List features whose ``is_closed`` flag is non-zero.

    Output columns: ``ft_id``, ``permalink_id``, constant ``group``,
    ``region_name`` and string lon/lat for both ``geom_nk`` and ``geom_altay``.
    """
    labelled = pr.add_region_name(raw_data, major_regions)
    closed = labelled.filter(qf.nonzero('is_closed'))

    coordinates = {
        'lon_nk': ne.custom(wkb_to_lon, 'geom_nk').with_type(str),
        'lat_nk': ne.custom(wkb_to_lat, 'geom_nk').with_type(str),
        'lon_altay': ne.custom(wkb_to_lon, 'geom_altay').with_type(str),
        'lat_altay': ne.custom(wkb_to_lat, 'geom_altay').with_type(str),
    }
    return closed.project(
        'ft_id', 'permalink_id',
        qe.const('group', group_name),
        'region_name',
        **coordinates
    )


def calculate_no_permalink_data(
    raw_data: ns.Stream,
    major_regions: ns.Stream,
    group_name: str
) -> ns.Stream:
    """List features that have no ``permalink_id``.

    Output columns: ``ft_id``, constant ``group``, ``region_name`` and string
    lon/lat of ``geom_nk`` (there is no matched altay geometry to report).
    """
    labelled = pr.add_region_name(raw_data, major_regions)
    unlinked = labelled.filter(qf.not_(qf.defined('permalink_id')))

    nk_coordinates = {
        'lon_nk': ne.custom(wkb_to_lon, 'geom_nk').with_type(str),
        'lat_nk': ne.custom(wkb_to_lat, 'geom_nk').with_type(str),
    }
    return unlinked.project(
        'ft_id',
        qe.const('group', group_name),
        'region_name',
        **nk_coordinates
    )


def calculate_ambiguous_data(
    raw_data: ns.Stream,
    major_regions: ns.Stream,
    permalink_field_name: str,
    group_name: str
) -> ns.Stream:
    """List features whose ``permalink_field_name`` value is shared by several features.

    A permalink is "ambiguous" when more than one non-closed row carries it.
    All rows of ``raw_data`` with such a permalink are then returned.

    NOTE(review): the inner join runs against unfiltered ``raw_data``, so closed
    rows sharing an ambiguous permalink are listed too, even though they are not
    counted when deciding ambiguity — confirm this is intended.
    """
    open_with_permalink = qf.and_(
        qf.defined(permalink_field_name),
        qf.not_(qf.nonzero('is_closed'))
    )
    per_permalink = raw_data.filter(open_with_permalink).project(
        permalink_field_name
    ).groupby(
        permalink_field_name
    ).aggregate(
        count=na.count()
    )
    duplicated = per_permalink.filter(qf.compare('count', '>', 1))

    affected = raw_data.join(
        duplicated,
        by=permalink_field_name,
        type='inner',
        assume_unique_right=True
    )
    labelled = pr.add_region_name(affected, major_regions)

    coordinates = {
        'lon_nk': ne.custom(wkb_to_lon, 'geom_nk').with_type(str),
        'lat_nk': ne.custom(wkb_to_lat, 'geom_nk').with_type(str),
        'lon_altay': ne.custom(wkb_to_lon, 'geom_altay').with_type(str),
        'lat_altay': ne.custom(wkb_to_lat, 'geom_altay').with_type(str),
    }
    return labelled.project(
        'ft_id',
        permalink_field_name,
        qe.const('group', group_name),
        'region_name',
        **coordinates
    )


def calculate_ambiguous_data_by_altay(
    raw_data: ns.Stream,
    major_regions: ns.Stream,
    group_name: str
) -> ns.Stream:
    """``calculate_ambiguous_data`` keyed on the ``permalink_id`` column."""
    return calculate_ambiguous_data(
        raw_data=raw_data,
        major_regions=major_regions,
        permalink_field_name='permalink_id',
        group_name=group_name,
    )


def calculate_ambiguous_data_by_ymapsdf(
    raw_data: ns.Stream,
    major_regions: ns.Stream,
    group_name: str
) -> ns.Stream:
    """``calculate_ambiguous_data`` keyed on the ``permalink_id_ymapsdf`` column."""
    return calculate_ambiguous_data(
        raw_data=raw_data,
        major_regions=major_regions,
        permalink_field_name='permalink_id_ymapsdf',
        group_name=group_name,
    )


def calculate_stat_to_table(
    job: Job,
    calculate_function: tp.Callable[[ns.Stream, ns.Stream, str], ns.Stream],
    poi_groups: tp.Dict[str, ns.Stream],
    major_regions: ns.Stream,
    result_path: str,
    table_name: str,
    date: str
) -> None:
    """Run ``calculate_function`` for every POI group and write the union to YT.

    Each group's result is computed with the group name as ``group_name``. The
    streams are concatenated, a constant ``date`` column is added, and the result
    is put at ``<result_path>/<table_name>/<date>``.
    """
    per_group_results = [
        calculate_function(group_stream, major_regions, group_label)
        for group_label, group_stream in poi_groups.items()
    ]
    destination = yt.ypath.ypath_join(result_path, table_name, date)

    combined = job.concat(*per_group_results)
    combined.project(
        qe.all(),
        qe.const('date', date)
    ).put(destination)


def calculate_stat(
    job: Job,
    raw_data: ns.Stream,
    major_regions: ns.Stream,
    result_path: str,
    date: str
) -> None:
    """Split ``raw_data`` into POI groups and write every statistic table for them.

    Groups (keys are the Russian labels written to the ``group`` column): all
    features, protected / unprotected ft types, verified (position_quality 4 or
    5), user-placed (position_quality 5) and features whose parent type is 2302
    (labelled as indoor plans). Six tables are written under ``result_path``:
    metrics, dist, close, no-permalink, ambiguous-permalinks and
    ambiguous-permalinks-ymapsdf.
    """
    protected_ft_type_ids = get_protected_ft_type_ids()

    poi_groups = {
        # All features.
        'Все': raw_data,
        # Protected ft types.
        'Защищённые': raw_data.filter(
            qf.one_of('ft_type_id', protected_ft_type_ids)
        ),
        # Unprotected ft types.
        'Незащищённые': raw_data.filter(
            qf.not_(qf.one_of('ft_type_id', protected_ft_type_ids))
        ),
        # Verified positions: user (5) or precise (4).
        'Верифицированные': raw_data.filter(
            qf.or_(
                qf.equals('position_quality', 5),
                qf.equals('position_quality', 4)
            )
        ),
        # User-placed positions.
        'Пользовательские': raw_data.filter(
            qf.equals('position_quality', 5)
        ),
        # Features whose parent has ft_type_id 2302 (indoor plans).
        'Схемы помещений': raw_data.filter(
            qf.equals('p_ft_type_id', 2302)
        ),
    }

    table_builders = (
        (calculate_metrics, 'metrics'),
        (calculate_dist_data, 'dist'),
        (calculate_close_data, 'close'),
        (calculate_no_permalink_data, 'no-permalink'),
        (calculate_ambiguous_data_by_altay, 'ambiguous-permalinks'),
        (calculate_ambiguous_data_by_ymapsdf, 'ambiguous-permalinks-ymapsdf'),
    )
    for builder, table in table_builders:
        calculate_stat_to_table(
            job,
            builder,
            poi_groups,
            major_regions,
            result_path,
            table,
            date)
