import ast
import geobase6
import os
import tempfile
from qb2.api.v1 import (
    resources as qr,
    extractors as se,
    typing as qt
)
from nile.api.v1 import (
    extractors as ne,
    files as nf
)

# Region id that geobase_region_id() substitutes when the lookup returns
# either the root or the unknown region.
EARTH_REGION_ID = 10000
ROOT_REGION_ID = 0
UNKNOWN_REGION_ID = -2

# Passed as `memory_limit` to table.project() in add_region_id_field.
# Presumably expressed in megabytes (i.e. 8 GB) -- TODO confirm against nile.
GEOBASE_JOB_MEMORY_LIMIT = 8 * 1024

# File name handed to geobase6.Lookup by init_geodata_from_resource() and also
# used as the name of the local temp copy in init_geodata_from_yt_copy().
GEODATA_RESOURCE = 'geodata6.bin'
# YT path of the geodata file attached to jobs and downloaded by
# init_geodata_from_yt_copy().
# NOTE(review): the path says geodata5.bin while the resource is geodata6.bin;
# confirm the mismatch is intentional.
GEODATA_YT_STATBOX = '//statbox/statbox-dict-last/geodata5.bin'
GEODATA_YT_FILES = [nf.RemoteFile(GEODATA_YT_STATBOX)]

# Geobase lookup object shared by the module; stays None until one of the
# init_* functions or lookup() fills it in.
LOOKUP = None
# Files passed to table.project() by add_region_id_field; replaced with an
# empty list by init_geodata_from_resource().
FILES = GEODATA_YT_FILES


def init_geodata_from_resource():
    """Initialise the shared lookup from the ``GEODATA_RESOURCE`` file.

    Does nothing when ``LOOKUP`` is already set. Otherwise builds a
    ``geobase6.Lookup`` over ``GEODATA_RESOURCE`` and resets ``FILES`` to an
    empty list, so no remote geodata file is passed to ``table.project``.
    """
    global LOOKUP, FILES
    if LOOKUP:
        # Already initialised by an earlier call; keep the existing state.
        return
    LOOKUP = geobase6.Lookup(GEODATA_RESOURCE)
    FILES = []


def init_geodata_from_yt_copy(yt_client):
    """Download the geodata file from YT and initialise the shared lookup.

    Does nothing when ``LOOKUP`` is already set. Otherwise the file at
    ``GEODATA_YT_STATBOX`` is streamed chunk by chunk into
    ``<tempdir>/<GEODATA_RESOURCE>`` and a ``geobase6.Lookup`` is opened on
    that local copy.

    :param yt_client: client object providing ``read_file(path)`` whose
        result exposes ``chunk_iter()``.
    """
    global LOOKUP
    if LOOKUP:
        return

    temp_dir = tempfile.gettempdir()
    local_copy = os.path.abspath(os.path.join(temp_dir, GEODATA_RESOURCE))

    with open(local_copy, "wb") as sink:
        remote_stream = yt_client.read_file(GEODATA_YT_STATBOX)
        for piece in remote_stream.chunk_iter():
            sink.write(piece)

    LOOKUP = geobase6.Lookup(local_copy)


def _init_geodata():
    """Set the shared ``LOOKUP`` to the qb2 ``ArcGeobase`` resource.

    The assignment is unconditional; ``lookup()`` only calls this when
    ``LOOKUP`` is still unset.
    """
    global LOOKUP
    # According to the resource description, 'ArcGeobase' is backed by the
    # file 'geodata5.bin':
    # https://a.yandex-team.ru/arc/trunk/arcadia/statbox/qb2/qb2_extensions/resources/location.py?rev=7135091#L248
    arc_geobase = qr.get('ArcGeobase')
    LOOKUP = arc_geobase


def lookup():
    """Return the shared geobase lookup, initialising it on first use.

    When none of the explicit ``init_geodata_*`` functions has populated
    ``LOOKUP`` yet, ``_init_geodata()`` is called to fill it in.
    """
    if LOOKUP:
        return LOOKUP
    _init_geodata()
    # Re-read the module global: _init_geodata() has just rebound it.
    return LOOKUP


def is_point(geom):
    """Tell whether ``geom`` is a single coordinate pair.

    A point is a two-element sequence whose first element is not itself a
    sequence. Both lists and tuples count as nested containers, because
    ``ast.literal_eval`` may produce either; otherwise a pair of tuple points
    such as ``[(0, 0), (1, 1)]`` would be mistaken for one point.

    :param geom: a sized, indexable geometry node (point or nested container).
    :return: True if ``geom`` is a point, False otherwise.
    """
    return len(geom) == 2 and not isinstance(geom[0], (list, tuple))


def traverse(geom):
    """Yield every point found in an arbitrarily nested geometry.

    A node that ``is_point`` accepts is yielded as is; any other node is
    iterated and each child is traversed recursively, in order.
    """
    if is_point(geom):
        yield geom
        return

    for child in geom:
        for point in traverse(child):
            yield point


def shape_center(shape):
    """Return the arithmetic mean of all points of a serialized geometry.

    :param shape: string with a Python-literal geometry (nested sequences of
        coordinate pairs), parsed with ``ast.literal_eval``.
    :return: ``(lon, lat)`` tuple, where ``lon`` averages the first element
        of every point and ``lat`` averages the second.
    :raises RuntimeError: if the parsed geometry is empty/falsy or contains
        no points at all (e.g. ``"[[]]"``).

    Errors raised by ``ast.literal_eval`` for malformed input propagate
    unchanged.
    """
    geom = ast.literal_eval(shape)

    lon_sum = 0
    lat_sum = 0
    points_count = 0

    # A falsy literal (None, empty list, ...) has nothing to traverse.
    if geom:
        for point in traverse(geom):
            lon_sum += point[0]
            lat_sum += point[1]
            points_count += 1

    # Covers both the empty geometry and nested-but-pointless ones, which
    # would otherwise fail below with an unhelpful ZeroDivisionError.
    if not points_count:
        raise RuntimeError(
            'Wrong geometry value: "{}". '.format(shape) +
            'Did you specify correct column name?')

    lon = lon_sum / points_count
    lat = lat_sum / points_count
    return (lon, lat)


def geobase_region_id(lon, lat):
    """Resolve a coordinate to a geobase region id.

    The shared lookup is queried with ``(lat, lon)`` -- note the swapped
    argument order relative to this function's signature. If the result is
    ``UNKNOWN_REGION_ID`` or ``ROOT_REGION_ID``, ``EARTH_REGION_ID`` is
    returned instead.
    """
    found = lookup().get_region_id_by_location(lat, lon)
    if found == UNKNOWN_REGION_ID or found == ROOT_REGION_ID:
        return EARTH_REGION_ID
    return found


def geobase_id_by_shape(shape):
    """Return the geobase region id for the center of a serialized geometry.

    The center is computed by ``shape_center`` and resolved with
    ``geobase_region_id``.
    """
    center = shape_center(shape)  # (lon, lat)
    return geobase_region_id(*center)


class add_region_id_field:
    """Callable that adds a geobase region id column to a table.

    Calling the instance projects the table with all existing fields plus a
    string column holding the region id of the geometry's center.
    """

    def __init__(self, geom_column, geoid_column):
        """Remember the source geometry column and the output column name."""
        self._geom_column, self._geoid_column = geom_column, geoid_column

    def __call__(self, table):
        """Return ``table`` projected with the extra region id column."""
        # The geometry value is stringified before parsing and the resulting
        # region id is stringified to match the declared String type.
        region_id_extractor = se.custom(
            self._geoid_column,
            lambda x: str(geobase_id_by_shape(str(x))),
            self._geom_column
        ).with_type(qt.String)

        # FILES is read here, at call time, so that a prior
        # init_geodata_from_resource() call is taken into account.
        job_options = dict(
            files=FILES,
            intensity='ultra_cpu',
            memory_limit=GEOBASE_JOB_MEMORY_LIMIT,
        )
        return table.project(ne.all(), region_id_extractor, **job_options)
