# -*- coding: utf-8 -*-

from itertools import imap
from collections import defaultdict
import time
import logging
import requests

from yandex.maps.proto.common2 import response_pb2
from yandex.maps.proto.masstransit.trajectory_pb2 import TRAJECTORY_METADATA, TRAJECTORY_SEGMENT_METADATA
from yandex.maps.geolib3.proto import decode_proto_polyline
from yandex.maps import geolib3


class PredictionError(RuntimeError):
    """Raised when a vehicle location cannot be predicted for a given moment."""


class Segment(object):
    """One piece of a vehicle trajectory: timing metadata plus its polyline."""

    def __init__(self, geo_object):
        """Extract timing and geometry from a trajectory-segment geo object.

        :param geo_object: protobuf geo object whose first metadata entry
            carries the TRAJECTORY_SEGMENT_METADATA extension and whose first
            geometry entry holds a polyline.
        """
        segment_meta = geo_object.metadata[0].Extensions[TRAJECTORY_SEGMENT_METADATA]
        encoded_polyline = geo_object.geometry[0].polyline

        # Units of `time` and `duration` are whatever the service reports;
        # they are only ever added to and compared with each other here.
        self.time = segment_meta.time
        self.duration = segment_meta.duration
        self.geometry = decode_proto_polyline(encoded_polyline)


class Vehicle(object):
    """A vehicle identified by `id` with a trajectory made of Segment objects."""

    def __init__(self, geo_object):
        """Read the vehicle id and its trajectory segments from a geo object.

        :param geo_object: protobuf geo object whose first metadata entry
            carries the TRAJECTORY_METADATA extension and whose child
            geo objects describe the trajectory segments.
        """
        self.id = geo_object.metadata[0].Extensions[TRAJECTORY_METADATA].id
        # Always a real list: the trajectory is indexed, not just iterated.
        self.trajectory = [Segment(child) for child in geo_object.geo_object]

    def predict_location(self, timestamp):
        """Return point where vehicle is supposed to be at given moment.

        The start time of every segment is derived by accumulating durations
        onto the `time` of the first segment; the `time` of later segments is
        not consulted.

        :param timestamp: moment to predict for, on the same scale as
            `Segment.time`.
        :returns: a point on the geometry of the segment covering `timestamp`.
        :raises PredictionError: if the trajectory is empty, if `timestamp`
            precedes the trajectory start, or if it lies past the last segment.
        """
        if not self.trajectory:
            # Without this guard an empty trajectory surfaced as IndexError,
            # which callers handling PredictionError do not expect.
            raise PredictionError("Not enough data for prediction")

        seg_start_time = self.trajectory[0].time
        if timestamp < seg_start_time:
            raise PredictionError("Trying to predict past")

        segment = None
        for seg in self.trajectory:
            if timestamp <= seg_start_time + seg.duration:
                segment = seg
                break
            seg_start_time += seg.duration
        if segment is None:
            raise PredictionError("Not enough data for prediction")

        if segment.duration == 0:
            # A zero-duration segment has no time extent to interpolate over
            # (dividing by it raised ZeroDivisionError); the vehicle is taken
            # to be at the first point of the segment geometry.
            return segment.geometry[0]
        return pos_in_polyline(segment.geometry, 1. * (timestamp - seg_start_time) / segment.duration)


def send_request(config, **kwargs):
    """Perform a GET request to the configured URL and return the raw body.

    :param config: mapping with 'request_url' (a format string filled from
        `kwargs`) and 'host' (sent as the HOST header). An optional 'timeout'
        entry overrides the default request timeout in seconds.
    :param kwargs: values substituted into config['request_url'].
    :returns: response body as bytes.
    :raises requests.exceptions.RequestException: on connection problems,
        timeouts, or an HTTP error status.
    """
    default_timeout = 60
    url = config['request_url'].format(**kwargs)
    headers = {'HOST': config['host'], 'Accept': 'application/x-protobuf'}
    try:
        timeout = config['timeout']
    except KeyError:
        timeout = default_timeout
    # Without a timeout requests waits indefinitely on a stalled server, so
    # the caller's retry/abort logic would never get a chance to run.
    response = requests.get(url, headers=headers, timeout=timeout)
    response.raise_for_status()
    return response.content


def fetch_vehicles(config, **kwargs):
    """Request vehicles from the server and index them by vehicle id.

    :param config: request configuration, passed through to send_request.
    :param kwargs: values substituted into the request URL by send_request.
    :returns: dict mapping vehicle id to Vehicle; if several geo objects
        share an id, the last one wins.
    """
    payload = send_request(config, **kwargs)
    parsed = response_pb2.Response()
    parsed.ParseFromString(payload)

    vehicles_by_id = {}
    for transport in parsed.reply.geo_object:
        vehicle = Vehicle(transport)
        vehicles_by_id[vehicle.id] = vehicle
    return vehicles_by_id


def segment_point_by_position(segment, t):
    """Linearly interpolate between the ends of `segment`.

    :param segment: object with `start` and `end` points exposing `x` and `y`.
    :param t: interpolation weight; 0 yields `start`, 1 yields `end`.
    :returns: geolib3.Point2 with coordinates start * (1 - t) + end * t.
    """
    start, end = segment.start, segment.end
    start_weight = 1 - t
    x = start.x * start_weight + end.x * t
    y = start.y * start_weight + end.y * t
    return geolib3.Point2(x, y)


def pos_in_polyline(polyline, t):
    """Return point of polyline which divides it in the ratio of t : (1 - t).

    :param polyline: indexable polyline providing `geolength()` and
        `segments()`; each segment provides `geolength()`, `start` and `end`.
    :param t: fraction of the polyline geolength, expected within [0, 1]
        (a small tolerance is allowed on both sides).
    :returns: the first point when the target distance is negligible, an
        interpolated point on the segment that contains the target distance,
        or the last point if no segment reaches it.
    """
    # The same tolerance serves as slack for `t` and as the threshold below
    # which a distance or segment length is treated as zero.
    tolerance = 1e-4
    assert -tolerance <= t <= 1 + tolerance, "t must be in range [0, 1]"

    target = t * polyline.geolength()
    if target < tolerance:
        return polyline[0]

    covered = 0.
    for piece in polyline.segments():
        piece_length = piece.geolength()
        reach = covered + piece_length
        if target <= reach:
            if piece_length < tolerance:
                # Degenerate piece: interpolating would divide by ~0.
                logging.warning("Polyline contains segment with length %s", piece_length)
                return piece.start
            return segment_point_by_position(piece, (target - covered) / piece_length)
        covered = reach

    # Target lies beyond the summed segment lengths (e.g. t slightly above 1).
    last_index = len(polyline) - 1
    return polyline[last_index]


def displacement(old, new):
    """Return how far the newly reported position is from the old prediction.

    The old trajectory is asked where the vehicle should be at the start time
    of the new trajectory; the result is the geolib3.geodistance between that
    predicted point and the first point of the new trajectory.

    :param old: Vehicle from the previous response.
    :param new: Vehicle with the same id from the current response.
    :raises PredictionError: if either trajectory is empty or the old
        trajectory cannot be evaluated at the requested moment.
    """
    if not old.trajectory or not new.trajectory:
        # Indexing an empty trajectory would raise IndexError, which the
        # caller's PredictionError handler does not cover.
        raise PredictionError("Not enough data for prediction")

    first_segment = new.trajectory[0]
    prediction = old.predict_location(first_segment.time)
    pos = first_segment.geometry[0]
    return geolib3.geodistance(pos, prediction)


def calculate_jumps(old_vehicles, new_vehicles, jumps, measurments, threshold):
    """Update per-vehicle jump and measurement counters in place.

    Only vehicles present in both mappings are considered.

    :param old_vehicles: dict of vehicle id -> Vehicle from the previous fetch.
    :param new_vehicles: dict of vehicle id -> Vehicle from the current fetch.
    :param jumps: counter mapping, incremented when the displacement of a
        vehicle exceeds `threshold`.
    :param measurments: counter mapping, incremented once per shared vehicle.
    :param threshold: displacement above which a jump is counted.
    """
    shared_ids = old_vehicles.viewkeys() & new_vehicles.viewkeys()
    for vehicle_id in shared_ids:
        try:
            distance = displacement(old_vehicles[vehicle_id], new_vehicles[vehicle_id])
        except PredictionError as error:
            logging.warning('%s', error)
        else:
            if distance > threshold:
                jumps[vehicle_id] += 1
        # NOTE(review): counted even when the prediction failed above --
        # confirm that failed predictions are meant to count as measurements.
        measurments[vehicle_id] += 1


def measure(config, region):
    """Poll the server and count vehicle jumps between consecutive responses.

    Responses are compared pairwise: each successful fetch that directly
    follows another successful fetch forms one pair. The loop ends once
    config['iterations'] pairs have been compared.

    :param config: mapping providing 'step', 'iterations', 'sleep_time',
        'max_failures' and 'threshold'; it is also passed to fetch_vehicles.
    :param region: mapping providing 'lat', 'lon', 'w' and 'h' of the
        requested area.
    :returns: tuple (jumps, measurments) of defaultdict(int) keyed by
        vehicle id.
    :raises requests.exceptions.RequestException: re-raised once more than
        config['max_failures'] requests in a row have failed.
    """
    lat = region['lat']
    lon = region['lon']
    w = region['w']
    h = region['h']
    step = config['step']

    jumps = defaultdict(int)
    measurments = defaultdict(int)

    previous_vehicles = {}
    have_previous = False
    compared_pairs = 0
    failure_streak = 0

    while compared_pairs < config['iterations']:
        time.sleep(config['sleep_time'])

        # The requested area moves by `step` for every request that extends
        # the chain of successful fetches.
        offset = step * (compared_pairs + int(have_previous))
        try:
            current_vehicles = fetch_vehicles(config, lon=lon + offset, lat=lat + offset, w=w, h=h)
        except requests.exceptions.RequestException as ex:
            # A failed fetch breaks the chain: the next success starts a new pair.
            have_previous = False
            failure_streak += 1
            if failure_streak <= config['max_failures']:
                logging.warning('HTTP Error %s, retrying.', str(ex))
                continue
            logging.error('Failed to get response from server. Aborting.')
            raise

        failure_streak = 0
        if have_previous:
            compared_pairs += 1
            calculate_jumps(previous_vehicles, current_vehicles, jumps, measurments, config['threshold'])
        have_previous = True
        previous_vehicles = current_vehicles

    return jumps, measurments
