from datetime import datetime, timedelta

import travel.rasp.pathfinder_maps.maps_protos.common2.response_pb2 as response_pb2
import travel.rasp.pathfinder_maps.maps_protos.masstransit.route_pb2 as route_pb2
import travel.rasp.pathfinder_maps.maps_protos.masstransit.section_pb2 as section_pb2
import travel.rasp.pathfinder_maps.maps_protos.pm_protos.rasp_response_pb2 as rasp_response_pb2
import travel.rasp.pathfinder_maps.maps_protos.pm_protos.rasp_section_pb2 as rasp_section_pb2

from travel.rasp.pathfinder_maps.const import TTYPE, UTC_TZ
from travel.rasp.pathfinder_maps.models.route import Route
from travel.rasp.pathfinder_maps.utils import get_point_with_prefix, seconds_to_duration_text


def add_link(response, rasp_link, station_from_id, station_to_id, dtm):
    """Append a Rasp search link (and its button caption) to the reply metadata.

    The link points at ``rasp_link`` with ``fromId``/``toId`` query parameters
    built from the station ids (each prefixed with ``c``) and ``when`` set to
    ``dtm`` formatted as ``DD-MM-YYYY``.  A new metadata entry is added to
    ``response.reply.metadata`` and its RASP_RESPONSE_METADATA extension is
    filled in place.
    """
    when = dtm.strftime('%d-%m-%Y')
    url = '{}/?fromId=c{}&toId=c{}&when={}'.format(rasp_link, station_from_id, station_to_id, when)

    new_metadata = response.reply.metadata.add()
    rasp_meta = new_metadata.Extensions[rasp_response_pb2.RASP_RESPONSE_METADATA]
    rasp_meta.rasp_link = url
    rasp_meta.button_text = "Билеты на Яндекс.Расписаниях"


def build_rasp_section(section, route: Route, rasp_link: str):
    """Attach a Rasp link and the route's polling key to ``section``.

    The link has the form ``{rasp_link}/{transport}/?fromId=..&toId=..&when=..``.
    ``transport`` is the ``rasp_name`` of ``TTYPE(route.thread_info[2])``, the
    endpoints come from ``get_point_with_prefix`` applied to the departure and
    arrival stations, and ``when`` is the departure date as ``DD-MM-YYYY``.
    A new metadata entry is added to ``section.metadata`` and its
    RASP_SECTION_METADATA extension is filled in place.
    """
    when = route.departure_datetime.strftime('%d-%m-%Y')
    transport_name = TTYPE(route.thread_info[2]).rasp_name
    from_id = get_point_with_prefix(route.departure_station)
    to_id = get_point_with_prefix(route.arrival_station)
    url = '{}/{}/?fromId={}&toId={}&when={}'.format(rasp_link, transport_name, from_id, to_id, when)

    rasp_meta = section.metadata.add().Extensions[rasp_section_pb2.RASP_SECTION_METADATA]
    rasp_meta.rasp_link = url
    rasp_meta.polling_key = route.polling_key


def deltas_from_coords(coords):
    """Return the differences between consecutive items of ``coords``.

    ``coords`` must support slicing.  The result is a list with one element
    fewer than the input (empty for inputs of length 0 or 1), where item ``i``
    is ``coords[i + 1] - coords[i]``.
    """
    return [current - previous for previous, current in zip(coords[:-1], coords[1:])]


def form_polyline(points):
    """Encode ``points`` as fixed-point (1e-6 degree) first-value + delta sequences.

    Each point must expose ``lon`` and ``lat`` attributes in degrees.  Both
    coordinates are scaled by 1e6 and converted to integers.  Each of the two
    integer sequences is then represented by its first value plus the list of
    successive differences (``deltas_from_coords``).

    :param points: non-empty iterable of objects with ``lon``/``lat`` attributes.
    :return: tuple ``(lons_first, lons_deltas, lats_first, lats_deltas)``.
    :raises ValueError: if ``points`` is empty.
    """
    # round() instead of int(): int() truncates toward zero, so float noise such as
    # ``1.005 * 1e6 == 1004999.9999999999`` would silently drop a whole micro-degree.
    scaled = [(round(point.lon * 1e6), round(point.lat * 1e6)) for point in points]
    if not scaled:
        # Same exception type as before (tuple unpacking of an empty zip), clearer message.
        raise ValueError('form_polyline() requires at least one point')
    lons, lats = zip(*scaled)
    return lons[0], deltas_from_coords(lons), lats[0], deltas_from_coords(lats)


def get_arrival_time_from_geo_object(geo_object):
    """Return the estimated arrival time of ``geo_object`` as a ``UTC_TZ`` datetime.

    The timestamp is read from the ROUTE_METADATA extension of the object's
    first metadata entry (``estimation.arrival_time.value``).
    """
    route_meta = geo_object.metadata[0].Extensions[route_pb2.ROUTE_METADATA]
    timestamp = route_meta.estimation.arrival_time.value
    return datetime.fromtimestamp(timestamp, UTC_TZ)


def get_departure_time_from_geo_object(geo_object):
    """Return the estimated departure time of ``geo_object`` as a ``UTC_TZ`` datetime.

    The timestamp is read from the ROUTE_METADATA extension of the object's
    first metadata entry (``estimation.departure_time.value``).
    """
    route_meta = geo_object.metadata[0].Extensions[route_pb2.ROUTE_METADATA]
    timestamp = route_meta.estimation.departure_time.value
    return datetime.fromtimestamp(timestamp, UTC_TZ)


def get_empty_result():
    """Return a newly constructed, empty ``response_pb2.Response`` message."""
    empty_response = response_pb2.Response()
    return empty_response


def get_earliest_route(routes):
    """Return the element of ``routes`` whose journey arrives earliest.

    Each element is indexed with ``[0]``.  That item's ``reply.geo_object[-1]``
    supplies the arrival estimation used for the comparison.  If several
    elements tie, the one appearing first in ``routes`` is returned.

    :raises IndexError: if ``routes`` is empty (same as the former
        ``sorted(...)[0]`` implementation).
    """
    # min() is a single O(n) pass and selects the same element as sorted(...)[0]:
    # sorted() is stable and min() returns the first minimal item.
    # A None default cannot collide with a real element, because elements must be
    # indexable for the key function to work.
    earliest = min(
        routes,
        key=lambda variant: get_arrival_time_from_geo_object(variant[0].reply.geo_object[-1]),
        default=None,
    )
    if earliest is None:
        raise IndexError('get_earliest_route() got no routes')
    return earliest


def get_fastest_route(routes):
    """Return the element of ``routes`` with the shortest total travel time.

    Each element is indexed with ``[0]``, and that item is passed to
    ``get_travel_time_from_route``.  If several elements tie, the one appearing
    first in ``routes`` is returned.

    :raises IndexError: if ``routes`` is empty (same as the former
        ``sorted(...)[0]`` implementation).
    """
    # min() is a single O(n) pass and selects the same element as sorted(...)[0]:
    # sorted() is stable and min() returns the first minimal item.
    # A None default cannot collide with a real element, because elements must be
    # indexable for the key function to work.
    fastest = min(
        routes,
        key=lambda variant: get_travel_time_from_route(variant[0]),
        default=None,
    )
    if fastest is None:
        raise IndexError('get_fastest_route() got no routes')
    return fastest


def get_travel_time_from_route(route):
    """Return the duration of ``route`` as a ``timedelta``.

    The duration runs from the estimated departure of the first geo object in
    ``route.reply.geo_object`` to the estimated arrival of the last one.
    """
    geo_objects = route.reply.geo_object
    finish = get_arrival_time_from_geo_object(geo_objects[-1])
    start = get_departure_time_from_geo_object(geo_objects[0])
    return finish - start


def fix_estimation(meta, delta):
    """Move the timestamp stored in ``meta`` back by ``delta`` seconds.

    ``meta.value`` is replaced with ``int(meta.value) - delta`` (cast to int).
    ``meta.text`` is re-rendered as ``HH:MM``: the new timestamp is taken in
    ``UTC_TZ`` and then shifted by ``meta.tz_offset`` seconds.
    """
    shifted = int(meta.value) - delta
    meta.value = int(shifted)
    local_moment = datetime.fromtimestamp(shifted, UTC_TZ) + timedelta(seconds=meta.tz_offset)
    meta.text = local_moment.strftime('%H:%M')


def inject_wait_section_before_start(builded_variant, dtm):
    """Absorb the idle time before the first transport leg into the route, in place.

    Walks the sections of ``builded_variant.reply.geo_object[0]`` until the first
    one whose SECTION_METADATA has ``transport`` set.  On the way it sums the
    ``weight.time`` of the preceding non-transport sections and collects those
    whose estimation passes ``IsInitialized()``.

    ``wait_seconds`` is the transport departure minus ``dtm`` minus that summed
    walking time.  It is handled as follows:

    * The last collected section (the one just before the transport) gets
      ``wait_seconds`` added to its weight.  Its departure estimation is moved
      ``wait_seconds`` earlier; its arrival estimation is kept.
    * Every earlier collected section has both its departure and arrival
      estimations moved ``wait_seconds`` earlier.
    * The route's departure estimation is copied from the earliest collected
      section, and the route's total time is recomputed as arrival minus
      departure.

    Nothing is changed if no section is collected or if there is no transport
    section at all.

    :param builded_variant: response whose ``reply.geo_object[0]`` is the route.
    :param dtm: timezone-aware datetime. It is subtracted from a ``UTC_TZ``-aware
        datetime, so a naive value raises ``TypeError``.
    """
    route = builded_variant.reply.geo_object[0]
    sections_to_fix = []

    seconds_before_transport = 0
    # Stays None when the route has no transport section at all.
    wait_seconds = None
    for section in route.geo_object:
        if not len(section.metadata) or not section.metadata[0].HasExtension(section_pb2.SECTION_METADATA):
            continue
        section_meta = section.metadata[0].Extensions[section_pb2.SECTION_METADATA]
        if not section_meta.transport:
            seconds_before_transport += section_meta.weight.time.value
            # NOTE(review): IsInitialized() checks required fields only; confirm it is
            # the intended "estimation is present" test (vs. HasField('estimation')).
            if section_meta.estimation.IsInitialized():
                sections_to_fix.append(section)
            continue
        segment_dt = datetime.fromtimestamp(section_meta.estimation.departure_time.value, UTC_TZ)
        # NOTE(review): this can be negative when dtm is later than the transport
        # departure allows; that case is not guarded here — confirm upstream prevents it.
        wait_seconds = int((segment_dt - dtm - timedelta(seconds=seconds_before_transport)).total_seconds())
        break

    # Without a transport section there is nothing to wait for.  Previously this
    # path reached the arithmetic below and raised UnboundLocalError on wait_seconds.
    if wait_seconds is None or not sections_to_fix:
        return

    # The non-transport section right before the transport absorbs the waiting time.
    wait_section = sections_to_fix.pop()
    wait_meta = wait_section.metadata[0].Extensions[section_pb2.SECTION_METADATA]

    wait_meta.weight.time.value += wait_seconds
    wait_meta.weight.time.text = seconds_to_duration_text(wait_meta.weight.time.value)
    # Only its departure moves; its arrival (at the transport) is left as is.
    fix_estimation(wait_meta.estimation.departure_time, wait_seconds)
    last_departure = wait_meta.estimation.departure_time

    # Earlier sections shift entirely.  Iterating backwards means the last one
    # visited is the earliest collected section, whose departure becomes the
    # route's departure.
    for section in reversed(sections_to_fix):
        section_meta = section.metadata[0].Extensions[section_pb2.SECTION_METADATA]
        fix_estimation(section_meta.estimation.departure_time, wait_seconds)
        fix_estimation(section_meta.estimation.arrival_time, wait_seconds)
        last_departure = section_meta.estimation.departure_time

    route_meta = route.metadata[0].Extensions[route_pb2.ROUTE_METADATA]
    route_meta.estimation.departure_time.CopyFrom(last_departure)
    route_meta.weight.time.value = route_meta.estimation.arrival_time.value - route_meta.estimation.departure_time.value
    route_meta.weight.time.text = seconds_to_duration_text(route_meta.weight.time.value)


def join_bounds(bounds, bounding_box):
    """Write the smallest box enclosing every item of ``bounds`` into ``bounding_box``.

    Each bound (and ``bounding_box``) exposes ``lower_corner`` and
    ``upper_corner`` objects with ``lon``/``lat`` attributes.  The resulting
    lower corner takes the minimum lower lon/lat and the upper corner takes the
    maximum upper lon/lat.  All values are read before any assignment is made.

    :raises ValueError: if ``bounds`` is empty.
    """
    corners = [(bound.lower_corner, bound.upper_corner) for bound in bounds]
    west = min(lower.lon for lower, _ in corners)
    south = min(lower.lat for lower, _ in corners)
    east = max(upper.lon for _, upper in corners)
    north = max(upper.lat for _, upper in corners)

    bounding_box.lower_corner.lon = west
    bounding_box.lower_corner.lat = south
    bounding_box.upper_corner.lon = east
    bounding_box.upper_corner.lat = north
