# -*- coding: utf-8 -*-
import numpy as np
from collections import defaultdict, namedtuple
from library.python.geohash import encode, decode

import yt.wrapper as yt
from datacloud.dev_utils.data.data_utils import array_tostring, array_fromstring
from datacloud.dev_utils.geo.geo_features import haversine


@yt.with_context
def fetch_geo_logs_reducer(key, recs, context):
    """
        Pairs each base record with the geo-log points logged before it.

        Rows read while ``context.table_index == 0`` are collected as base
        records. For any other row: if no base record has been collected yet,
        the rest of the group is skipped. Otherwise, for every base record
        whose ``timestamp`` is greater than the row's ``timestamp``, one output
        row is yielded. It carries the base ``external_id`` and
        ``original_timestamp`` plus the log row's ``lon``/``lat``.

        ``timestamp_of_log`` is the negated log timestamp, presumably so that
        an ascending sort puts the most recent log first for
        ``filter_geo_logs`` (confirm against the pipeline's sort spec).
    """
    anchors = []
    for row in recs:
        # table_index reflects the table of the row currently being read,
        # so it must be checked per row.
        if context.table_index == 0:
            anchors.append(row)
            continue

        if not anchors:
            break

        for anchor in anchors:
            if row['timestamp'] < anchor['timestamp']:
                yield {
                    'external_id': anchor['external_id'],
                    'lon': row['lon'],
                    'lat': row['lat'],
                    'timestamp_of_log': -row['timestamp'],
                    'original_timestamp': anchor['timestamp']
                }


def filter_geo_logs(key, recs):
    """
        Yields the leading run of records that share the first
        ``timestamp_of_log`` value, then stops at the first different value.

        Records are expected to arrive ordered by ``timestamp_of_log``. Only
        the first value's records are kept, which presumably corresponds to
        the most recent log (see ``fetch_geo_logs_reducer``).
    """
    first_ts = None
    for row in recs:
        current = row['timestamp_of_log']
        if first_ts is None:
            first_ts = current
        if current != first_ts:
            break
        yield row


# Outcome of PointsToGeohashReducerBase.check_point: whether to keep the point,
# its full geohash, and the geohash prefix used for de-duplication.
PointCheckResult = namedtuple('PointCheckResult', ['use_point', 'geohash', 'prefix'])


class PointsToGeohashReducerBase:
    """
        Shared logic for reducers that thin (lon, lat) points via geohashes.

        Args:
            precision: geohash length requested from ``encode``.
            prefix_size: number of leading geohash characters used as the
                de-duplication key.
            max_points: cap on emitted points (applied by subclasses).
    """

    def __init__(self, precision=10, prefix_size=8, max_points=1000):
        self.max_points = max_points
        self.prefix_size = prefix_size
        self.precision = precision

    def in_mother_russia(self, point):
        """
            Rough bounding-box test for Russia; ``point`` is (lon, lat).

            Bounds come from
            https://en.wikipedia.org/wiki/List_of_extreme_points_of_Russia.
            The longitude range crosses the antimeridian, so it is expressed
            as "east of the western extreme OR west of the eastern extreme".
        """
        lon, lat = point[0], point[1]
        return (41.220556 < lat < 81.843056) and (lon > 19.639167 or lon < -169.016667)

    def check_point(self, point, geohashes, prefixes):
        """
            Geohashes ``point`` and decides whether it should be kept.

            The point is kept when it lies inside the Russia bounding box and
            its prefix is not already in ``prefixes``. ``geohashes`` is
            accepted for interface compatibility but is not consulted here.
            Neither container is modified.
        """
        lon, lat = point[0], point[1]
        full_hash = encode(lat, lon, precision=self.precision)
        head = full_hash[:self.prefix_size]
        keep = bool(self.in_mother_russia(point) and head not in prefixes)
        return PointCheckResult(use_point=keep, geohash=full_hash, prefix=head)


class PointsToGeohashReducerRetro(PointsToGeohashReducerBase):
    """
        Reducer that keeps a thinned-out set of one key's geo points.

        Each record's ``lon``/``lat`` columns form a (lon, lat) point that is
        screened with ``check_point``. A point is kept only when it passes the
        ``in_mother_russia`` bounding box and its ``prefix_size``-character
        geohash prefix has not been kept before for this key. Up to
        ``max_points`` kept points are emitted in input order. Their
        coordinates come from decoding the stored geohash, not from the raw
        input values.
    """

    def __call__(self, key, recs):
        prefixes = set()
        geohashes = []

        for rec in recs:
            point = rec['lon'], rec['lat']
            point_check_result = self.check_point(point, geohashes, prefixes)
            if not point_check_result.use_point:
                continue

            geohashes.append(point_check_result.geohash)
            prefixes.add(point_check_result.prefix)

            # Only the first ``max_points`` kept geohashes are ever emitted, so
            # stop scanning once that many are collected. This avoids hashing
            # (and remembering) every remaining point of the group. ``==``
            # rather than ``>=`` keeps None / non-positive limits on the
            # original path: no early exit, same slice below.
            if len(geohashes) == self.max_points:
                break

        for geohash in geohashes[:self.max_points]:
            # decode() is indexed as (lat, lon), mirroring encode(lat, lon).
            decoded = decode(geohash)
            yield {
                'external_id': key['external_id'],
                'lon': float(decoded[1]),
                'lat': float(decoded[0]),
            }


@yt.with_context
def distance_reducer(key, recs, context):
    """
        Emits the haversine distance between every origin and target of a key.

        Rows read while ``context.table_index == 0`` contribute an origin
        point (lon, lat). For any other row: if no origin has been collected
        yet, the rest of the group is skipped. Otherwise one row per collected
        origin is yielded, carrying the key's ``external_id``, the row's
        ``type`` and ``haversine(origin, (lon, lat))``. Units are whatever
        ``haversine`` returns (not visible here).
    """
    origins = []
    for row in recs:
        # table_index reflects the table of the current row, so check per row.
        if context.table_index == 0:
            origins.append((row['lon'], row['lat']))
            continue

        if not origins:
            break

        target = (row['lon'], row['lat'])
        for origin in origins:
            dist = haversine(origin, target)
            yield {
                'external_id': key['external_id'],
                'type': row['type'],
                'distance': dist
            }


def single_reducer(key, recs):
    """Yields only the first record of the group (nothing for an empty group)."""
    for first in recs:
        yield first
        return


class DistancesFilterReducer:
    """
        Reducer that passes through at most ``max_n`` records of each group.

        Records are forwarded in input order. A non-positive ``max_n`` yields
        nothing.
    """

    def __init__(self, max_n):
        self.max_n = max_n

    def __call__(self, key, recs):
        remaining = self.max_n
        for row in recs:
            if remaining <= 0:
                break
            yield row
            remaining -= 1


class DistanceToFMapper:
    """
        Mapper that replaces a row's ``distance`` with a ``feature`` value.

        Args:
            distance_thresh: default threshold used by the conversion methods
                whenever they are not given an explicit one.
    """

    def __init__(self, distance_thresh):
        self.distance_thresh = distance_thresh

    def distances2features(self, distances, distance_thresh=None, eps=1e-2):
        """
            Transforms given distances to soft geo-features.

            The feature is ``tanh(distance_thresh / (distance + eps))``,
            computed as ``2 / (1 + exp(-2z)) - 1``. For a positive threshold
            and positive distances ``z`` is positive, so the result lies in
            (0, 1). It rises rapidly towards 1 once the distance drops below
            the threshold and decays towards 0 well above it.

            Args:
                distances: a scalar, a numpy array, or a list/tuple of numbers
                    (lists and tuples are converted with ``np.asarray``; other
                    inputs are used as given).
                distance_thresh: threshold override; ``None`` means use
                    ``self.distance_thresh``. An explicit ``0`` is honoured.
                eps: small offset added to the distances, which keeps a
                    distance of 0 from dividing by zero.

            Returns:
                Features shaped like ``distances`` (scalar in, scalar out).

            Transform examples for distance_thresh=100 (approximate: computed
            ignoring ``eps``; the default eps shifts them slightly):
                distances 9 -> feature 0.9999999995532738
                distances 90 -> feature 0.8044548002984016
                distances 300 -> feature 0.32151273753163445
                distances 1500 -> feature 0.06656807650226271
                distances 5000 -> feature 0.01999733375993107
                distances 30000 -> feature 0.0033333209877091097
        """
        if distance_thresh is None:
            distance_thresh = self.distance_thresh

        # Plain sequences do not support ``+ eps``; arrays and scalars do.
        if isinstance(distances, (list, tuple)):
            distances = np.asarray(distances)

        reversed_distances = distance_thresh / (distances + eps)
        features = 2. / (1 + np.exp(-2 * reversed_distances)) - 1

        return features

    def distances2features_binary(self, distance, distance_thresh=None):
        """
            Returns 1.0 when ``distance`` is strictly below the threshold,
            otherwise 0.0.

            ``distance_thresh`` overrides ``self.distance_thresh`` unless it is
            ``None`` (an explicit ``0`` is honoured, not replaced).
        """
        if distance_thresh is None:
            distance_thresh = self.distance_thresh

        return float(distance < distance_thresh)

    def __call__(self, rec):
        """
            Yields ``rec`` with ``distance`` replaced by a binary ``feature``.

            The row is modified in place.
        """
        rec['feature'] = self.distances2features_binary(rec['distance'])
        rec.pop('distance')

        yield rec


class FeaturesCompactReducer:
    """
        Packs one key's per-type ``feature`` rows into a single binary string.

        Only rows whose ``type`` is among the configured types are used. The
        input does not need to be pre-sorted: for every type (in sorted order)
        the collected values are sorted here and then re-stepped by
        ``sort_order`` (a slice step, e.g. 1 ascending, -1 descending). They
        are cut to ``max_n`` values and right-padded with ``fill_na``. For a
        non-negative ``max_n`` the packed vector therefore always has
        ``len(self.types) * max_n`` entries.
    """

    def __init__(self, types, max_n, fill_na, sort_order):
        self.types = sorted(set(types))
        self.max_n = max_n
        self.fill_na = fill_na
        self.sort_order = sort_order

    def __call__(self, key, recs):
        grouped = defaultdict(list)
        for row in recs:
            row_type = row['type']
            if row_type in self.types:
                grouped[row_type].append(row['feature'])

        packed = []
        for feature_type in sorted(self.types):
            ordered = sorted(grouped[feature_type])[::self.sort_order]
            kept = ordered[:self.max_n]
            padding = [self.fill_na] * (self.max_n - len(kept))
            packed += kept + padding

        yield {
            'external_id': key['external_id'],
            'features': array_tostring(packed)
        }


@yt.with_context
class BinaryFeaturesReducer:
    """
        Appends per-type presence flags to a key's packed feature vector.

        Rows read while ``context.table_index == 0`` provide the packed
        ``features`` string (decoded with ``array_fromstring``; presumably the
        output of ``FeaturesCompactReducer``). Every row from any other table
        yields one output row. Its features are the current packed vector
        followed by one flag per type (in sorted order): 1.0 when the row's
        column for that type is truthy and non-empty after ``.strip()``,
        else 0.0. If no table-0 row has been seen yet, a vector of ``fill_na``
        of length ``len(self.types) * max_n`` is used instead.
    """

    def __init__(self, types, max_n, fill_na):
        self.types = sorted(set(types))
        # Size the default vector from the de-duplicated types so it matches
        # the layout FeaturesCompactReducer packs (len(sorted(set(types))) *
        # max_n). The raw ``types`` argument may contain duplicates.
        self.f_size = len(self.types) * max_n
        self.fill_na = fill_na

    def __call__(self, key, recs, context):
        original_features = np.full(self.f_size, self.fill_na, float)
        for rec in recs:
            # table_index reflects the table of the current row.
            if context.table_index == 0:
                original_features = array_fromstring(rec['features'])
                continue

            # self.types is already sorted in __init__, which fixes the flag order.
            binary_features = np.array(
                [float(bool(rec[addr_type] and rec[addr_type].strip()))
                 for addr_type in self.types],
                dtype=float,
            )

            all_features = np.hstack((original_features, binary_features))
            yield {
                'external_id': key['external_id'],
                'features': array_tostring(all_features)
            }
