import textwrap
from yql.api.v1.client import YqlClient
from yt.wrapper import with_context
from datacloud.geo.lib.h3_lib import h3
from datacloud.geo.lib import general


def _base_geo_log_to_grid(yt_client, input_table, output_table, grid_mapper, key='yuid'):
    with yt_client.TempTable() as tmp_table:
        yt_client.run_map_reduce(
            grid_mapper,
            general.unique_reducer,
            input_table,
            tmp_table,
            reduce_by=[key, 'geo_id'],
            reduce_combiner=general.unique_reducer,
        )
        yt_client.run_sort(tmp_table, output_table, sort_by=key)


def lat_lon_to_h3(yt_client, resolution, input_table, output_table):
    yt_client.run_map(
        _LatLonToH3Mapper(resolution, is_replace=True),
        input_table,
        output_table,
    )


class _LatLonToH3Mapper(object):
    def __init__(self, resolution, is_replace):
        self._res = resolution
        self._is_replace = is_replace

    def __call__(self, rec):
        lat, lon = rec['lat'], rec['lon']
        if lat is not None and lon is not None:
            geo_id = h3.geo_to_h3(lat, lon, self._res)
            if self._is_replace:
                del rec['lat']
                del rec['lon']
            rec['geo_id'] = geo_id
            yield rec


def geo_log_to_grid_h3(yt_client, input_table, output_table, resolution=9, good_point_func=None):
    grid_mapper = GridMapperH3(resolution, good_point_func)
    _base_geo_log_to_grid(yt_client, input_table, output_table, grid_mapper, key='yuid')


class GridMapperH3(object):
    def __init__(self, resolution, good_point_func=None):
        assert 1 < resolution <= 15, 'Wrong resolution 1<={}<=15'.format(resolution)
        self._good_point_func = good_point_func
        self._res = resolution

    def __call__(self, rec):
        lat, lon = rec['lat'], rec['lon']
        if self._good_point_func is None or self._good_point_func(lat, lon):
            geo_id = h3.geo_to_h3(lat, lon, self._res)
            yield {'geo_id': geo_id, 'yuid': rec['yuid']}


def geo_hash_to_center(yt_client, input_table, output_table):
    yt_client.run_map(_center_mapper, input_table, output_table)


def _center_mapper(rec):
    geo_id = rec['geo_id']
    rec['lat'], rec['lon'] = h3.h3_to_center(geo_id)
    yield rec


def enhance_with_geobase(yt_client, yql_token, input_table, output_table):
    yql_client = YqlClient(db='hahn', token=yql_token)
    query = textwrap.dedent("""
        pragma yt.ForceInferSchema = '1';

        INSERT INTO `{output_table}`
        WITH TRUNCATE
        SELECT
            geo_id,
            lat,
            lon,
            Geo::RoundRegionByLocation(lat, lon, "country").en_name AS country,
            Geo::RoundRegionByLocation(lat, lon, "country").name AS country_ru,
            Geo::RoundRegionByLocation(lat, lon, "country").id AS country_id,
            Geo::RoundRegionByLocation(lat, lon, "city").en_name AS city,
            Geo::RoundRegionByLocation(lat, lon, "city").name AS city_ru,
            Geo::RoundRegionByLocation(lat, lon, "city").id AS city_id,
            Geo::RoundRegionByLocation(lat, lon, "town").en_name AS district,
            Geo::RoundRegionByLocation(lat, lon, "town").name AS district_ru,
            Geo::RoundRegionByLocation(lat, lon, "town").id AS district_id,
        FROM
            `{input_table}`
    """)

    query = query.format(input_table=input_table, output_table=output_table)
    request = yql_client.query(query, syntax_version=1)
    request.run()
    request.get_results()


def upscale_h3(yt_client, input_table, output_table, resolutions):
    yt_client.run_map(
        _UpscaleH3Mapper(resolutions),
        input_table,
        output_table,
    )


class _UpscaleH3Mapper(object):
    def __init__(self, resolutions):
        self._resolutions = resolutions

    def __call__(self, rec):
        for r in self._resolutions:
            rec['hash_{}'.format(r)] = h3.h3_to_parent(rec['geo_id'], r)
        yield rec


def enhance_with_boundary(yt_client, input_table, output_table, resolutions):
    yt_client.run_map(
        _EnhanceWithBoundaryMapperH3(resolutions),
        input_table,
        output_table,
    )


class _EnhanceWithBoundaryMapperH3(object):
    def __init__(self, resolutions):
        self._resolutions = resolutions

    def __call__(self, rec):
        for r in self._resolutions:
            key = 'hash_{}'.format(r)
            rec[key + '_coords'] = str(h3.h3_to_geo_boundary(rec[key]))
        yield rec


def enhance_geo_id_with_city(yt_client, input_table, geo_id_to_city_table, output_table):
    yt_client.run_reduce(
        _enhance_city_reducer,
        [geo_id_to_city_table, input_table],
        output_table,
        reduce_by='geo_id',
    )


@with_context
def _enhance_city_reducer(_, recs, context):
    city_id = None
    hash_9 = None
    for rec in recs:
        if context.table_index == 0:
            city_id = rec['city_id']
            hash_9 = rec['hash_9']
        elif city_id:
            rec['city_id'] = city_id
            rec['hash_9'] = hash_9
            yield rec
