# -*- coding: utf-8 -*-

import hashlib
import logging
import struct
from datetime import datetime, timedelta

import numpy as np
from django.conf import settings
from django.db import transaction
from typing import List, Type

from common.models.geo import Station
from common.models.schedule import RThread, RTStation
from common.models.timestamp import Timestamp
from common.models_utils import model_iter
from common.utils.date import RunMask, timedelta2minutes, MSK_TZ
from mapping.drawers import TrainPathDrawer, LimePathDrawer, StationCoordsStorage
from mapping.generators.utils import FLOAT_PRECISION
from mapping.models import LiveBus, LiveMapObject, RouteMapBlacklist, Train


log = logging.getLogger(__name__)


class MapError(Exception):
    pass


class Processor(object):
    def __init__(self, now_aware, thread, path, drawer, live_object_class, expires_at):
        self.thread = thread
        self.path = path
        self.drawer = drawer
        self.live_object_class = live_object_class

        # Сейчас
        self.now_aware = now_aware

        # Будущее сейчас (данные обновляются раз в 10 минут)
        self.future_now = expires_at

        # Среднее положение паровозика, для bounds
        self.middle = self.now_aware + timedelta(minutes=settings.MAPPING_SERVER_EXPIRE / 2)

    def process(self):
        rtstations = self.path or list(self.thread.path)

        if len(rtstations) < 2:
            # Пустая нитка, пропускаем
            return

        start_date = self.now_aware.astimezone(self.thread.pytz).date()

        first_run = RunMask.first_run(self.thread.year_days, start_date)
        if not first_run:
            return

        naive_start_dt = datetime.combine(first_run, self.thread.tz_start_time)
        start_dt = self.thread.pytz.localize(naive_start_dt)

        # Cтанции и время, когда на них поезд должен находиться
        stations = []

        for rtstation in rtstations:
            if rtstation.tz_arrival == rtstation.tz_departure:
                # Станция без остановки, не учитываем время
                # Нужна для траектории
                stations.append((rtstation.station_id, None))

            elif rtstation.tz_arrival is None:
                # Начало движения
                departure = timedelta2minutes(rtstation.get_departure_dt(naive_start_dt) - start_dt)

                stations.append((rtstation.station_id, departure))

            elif rtstation.tz_departure is None:
                # Конец движения
                arrival = timedelta2minutes(rtstation.get_arrival_dt(naive_start_dt) - start_dt)

                stations.append((rtstation.station_id, arrival))

            else:
                # Остановка, пишем со временем
                arrival = timedelta2minutes(rtstation.get_arrival_dt(naive_start_dt) - start_dt)
                departure = timedelta2minutes(rtstation.get_departure_dt(naive_start_dt) - start_dt)

                stations.append((rtstation.station_id, arrival))
                stations.append((rtstation.station_id, departure))

        # Конечная, нужна для расчета прибытия
        last_rtstation = rtstations[-1]

        # Добавляем 10-минутные стоянки в начале и конце (RASP-3923)
        stations.insert(0, (stations[0][0], stations[0][1] - 10))
        stations.append((stations[-1][0], stations[-1][1] + 10))

        coords, times = self.draw_map(stations)

        for start_dt in self.ranges(first_run, stations):
            # Переводим в минуты от начала пути
            now = timedelta2minutes(self.now_aware - start_dt)

            future_now = timedelta2minutes(self.future_now - start_dt)

            # Находим индексы отрезков, пересекающих интересующий нас интервал
            left_index = np.searchsorted(times, now, side='right')
            right_index = np.searchsorted(times, future_now, side='left')

            chunk_coords = coords[max(left_index - 1, 0):right_index + 1]
            chunk_times = times[max(left_index - 1, 0):right_index + 1]

            live_object = self.live_object_class()

            live_object.thread = self.thread

            naive_start_dt = start_dt.astimezone(self.thread.pytz).replace(tzinfo=None)

            live_object.departure = start_dt.astimezone(MSK_TZ).replace(tzinfo=None)
            live_object.arrival = last_rtstation.get_arrival_dt(naive_start_dt, MSK_TZ).replace(tzinfo=None)

            live_object.set_data((chunk_times - now) * 60, chunk_coords)

            # Находим среднюю точку
            live_object.lng, live_object.lat = live_object.current_position(
                (self.middle - self.now_aware).total_seconds())

            live_object.station_from = Station(id=stations[0][0])
            live_object.station_to = Station(id=stations[-1][0])

            yield live_object

    def draw_map(self, stations):
        """Отрисовка маршрута"""

        coords = []
        lengths = []

        prev_station_id = None

        # Времена и индексы для интерполяции
        time_indices = []
        time_values = []

        current_index = 0

        for station_id, time in stations:
            if prev_station_id is not None and prev_station_id != station_id:
                # Движение от станции к станции
                curve_coords, curve_lengths = self.get_segment(prev_station_id, station_id)

                # Первый элемент пропускаем, он должен по-идее совпадать с предыдущей станцией
                coords.append(curve_coords[1:])
                lengths.append(curve_lengths[1:])

                current_index += len(curve_coords) - 1
            else:
                # Иначе просто положение
                position = self.drawer.get_station_coords(station_id)

                if not position:
                    log.warning(u'Station %s not found', station_id)
                    raise MapError

                coords.append(np.array([position]))
                lengths.append(np.array([0.0]))

                current_index += 1

            if time is not None:
                time_indices.append(current_index - 1)
                time_values.append(time)

            prev_station_id = station_id

        coords = np.concatenate(coords)
        lengths = np.concatenate(lengths)

        # Превращаем длины в расстояния от начала пути
        positions = lengths.cumsum()

        # Интерполируем времена прохождения точек по расстоянию от начала пути
        times = np.empty_like(positions)

        time_values = np.array(time_values)

        for i in range(len(time_indices) - 1):
            i1 = time_indices[i]
            i2 = time_indices[i + 1]

            xp = positions[[i1, i2]]
            fp = time_values[[i, i + 1]]

            if np.all(np.diff(xp) > 0):
                times[i1:i2] = np.interp(positions[i1:i2], xp, fp)
            else:
                times[i1:i2] = fp[0]

        # Последнее значение должно быть временем последней станции
        times[-1] = stations[-1][1]

        return coords, times

    def get_segment(self, station_from, station_to):
        segment = self.drawer.get_arc(station_from, station_to)

        if not segment:
            log.warning("Segment %s - %s not found" % (
                station_from, station_to
            ))
            raise MapError

        # Дельты в плавающих числах
        curve_deltas = np.frombuffer(segment[0], dtype=np.dtype('<i8')).reshape(-1, 2) / FLOAT_PRECISION
        # Координаты
        curve_coords = curve_deltas.cumsum(axis=0)

        # Длины отрезков кривой
        curve_lengths = np.sqrt(np.square(curve_deltas).sum(axis=1))

        return curve_coords, curve_lengths

    def ranges(self, first_run, stations):
        # Здесь из границ нужно получить время в минутах от начала маршрута
        # Если поезд идет несколько дней, то получится несколько наборов

        # Дни, в которые отправляется нитка с начальной станции
        start_days = set(RunMask(self.thread.year_days, today=first_run).dates())

        # Время появления и исчезновения поезда (от начала движения)
        appear_delta = timedelta(minutes=stations[0][1])
        vanish_delta = timedelta(minutes=stations[-1][1])

        # Время отправления маршрута для расчетов
        # Если оно окажется в будущем от правой границы,
        # то это не страшно, потому-что оно отсеется ниже
        start_date = (self.future_now - appear_delta).astimezone(self.thread.pytz).date()
        naive_start_dt = datetime.combine(start_date, self.thread.tz_start_time)
        start_dt = self.thread.pytz.localize(naive_start_dt)

        while True:
            appear = start_dt + appear_delta
            vanish = start_dt + vanish_delta

            # Если исчезает раньше текущего момента, то дальше смотреть не надо
            if vanish < self.now_aware:
                break

            # Если появляется до следующего пересчета и
            # По маске хождения отправляется в этот день, то всё ок
            if appear <= self.future_now and start_dt.date() in start_days:
                yield start_dt

            # Смотрим предыдущий день
            start_dt -= timedelta(days=1)


def generate_train_paths(now, expires_at):
    TrainPathDrawer.storage.preload()

    threads = RThread.objects.filter(
        t_type__code__in=['train', 'suburban'],
        type__code__in=['basic', 'change'],
        route__hidden=False,
    ).exclude(
        uid__startswith='MCZK_',
    )

    def generate():
        threads_chunks = model_iter(threads, chunksize=50000, in_chunks=True)

        for threads_chunk in threads_chunks:
            live_objects = []

            for thread in threads_chunk:
                if not RouteMapBlacklist.is_thread_mapped(thread):
                    # Не отображать на картах
                    continue

                drawer = TrainPathDrawer(thread)

                try:
                    live_objects.extend(Processor(now, thread, None, drawer, Train, expires_at).process())
                except MapError:
                    pass

            yield live_objects

    save_objects_chunks(generate(), Train, now, expires_at)


def generate_bus_paths(now, expires_at):
    LimePathDrawer.storage.preload()
    StationCoordsStorage.preload()

    threads = RThread.objects.filter(
        t_type__code='bus',
        route__hidden=False,
    )

    def generate():
        threads_chunks = model_iter(threads, chunksize=50000, in_chunks=True)

        for threads_chunk in threads_chunks:
            paths = {}

            for rts in RTStation.objects.filter(thread__in=threads_chunk):
                paths.setdefault(rts.thread_id, []).append(rts)

            live_objects = []

            for thread in threads_chunk:
                if not RouteMapBlacklist.is_thread_mapped(thread):
                    # Не отображать на картах
                    continue

                drawer = LimePathDrawer()

                path = paths[thread.id]

                try:
                    live_objects.extend(Processor(now, thread, path, drawer, LiveBus, expires_at).process())
                except MapError:
                    pass

            yield live_objects

    save_objects_chunks(generate(), LiveBus, now, expires_at)


def primary_key_from_hash(h, used_keys):
    """32-битный ключ со знаком для сортировки"""

    while True:
        digest = h.digest()

        sort_key = struct.unpack('!i', digest[:4])[0]

        if sort_key not in used_keys:
            used_keys.add(sort_key)
            return sort_key

        # Разрешаем коллизии
        h.update('\xff')


@transaction.atomic
def save_objects_chunks(generator, live_object_class, now, expires_at):
    live_object_class.objects.all().delete()

    for chunk in generator:
        save_objects(chunk, live_object_class)

    Timestamp.set(live_object_class.TIMESTAMP_ID, now.astimezone(MSK_TZ).replace(tzinfo=None))
    Timestamp.set(live_object_class.EXPIRE_TIMESTAMP_ID, expires_at.astimezone(MSK_TZ).replace(tzinfo=None))


def split_by_collision(to_write, live_object_class):
    # type: (List[LiveMapObject], Type[LiveMapObject]) -> (List[LiveMapObject], List[LiveMapObject])

    pks = [o.pk for o in to_write]
    collision_pks = set(live_object_class.objects.filter(pk__in=pks).values_list('pk', flat=True))
    if not collision_pks:
        return to_write, []

    conflicting = []
    non_conflicting = []

    for o in to_write:
        if o.pk in collision_pks:
            conflicting.append(o)
        else:
            non_conflicting.append(o)

    return non_conflicting, conflicting


def save_objects(live_objects, live_object_class):
    used_keys = set()
    to_write = []

    for o in live_objects:
        o._hash = hashlib.md5()
        o._hash.update(o.thread.uid)
        o._hash.update(str(o.departure.date()))

        o.pk = primary_key_from_hash(o._hash, used_keys)
        to_write.append(o)

    while to_write:
        non_conflicting, conflicting = split_by_collision(to_write, live_object_class)
        live_object_class.objects.bulk_create(non_conflicting)

        to_write = []
        for o in conflicting:
            o.pk = primary_key_from_hash(o._hash, used_keys)
            to_write.append(o)
