# -*- coding: utf-8 -*-

import logging
from itertools import groupby

from common.models.compatibility import RouteUidMap, RThreadUidMap
from common.models.geo import Station
from common.models.schedule import Route, RThread, RTStation
from common.models.tariffs import ThreadTariff
from common.models_utils import fetch_related
from common.utils.iterrecipes import chunker
from common.utils.progress import PercentageStatus
from travel.rasp.admin.lib import tmpfiles
from travel.rasp.admin.lib.mysqlutils import MysqlModelUpdater, MysqlModelLoader


log = logging.getLogger(__name__)


def fast_get_route_by_ids_with_threads(route_ids):
    """
    Load routes by their ids together with their threads.

    :param route_ids: collection of Route primary keys (iterated twice,
        once per query).
    :return: dict mapping route id -> Route. Every returned route gets a
        ``threads`` list with its RThread objects, and each of those threads
        has ``route`` set to the already-loaded Route instance.
    """
    routes = Route.objects.filter(id__in=route_ids)
    route_by_ids = dict((r.id, r) for r in routes)

    for r in route_by_ids.itervalues():
        r.threads = []

    for rthread in RThread.objects.filter(route_id__in=route_ids):
        parent = route_by_ids[rthread.route_id]
        # Point the thread at the Route object we already hold.
        rthread.route = parent
        parent.threads.append(rthread)

    return route_by_ids


def fast_get_route_by_ids_with_threads_and_rtstations(route_ids):
    """
    Load routes by their ids together with their threads and the threads' stops.

    :param route_ids: collection of Route primary keys.
    :return: dict mapping route id -> Route. Each route gets a ``threads``
        list of its RThread objects; each thread gets ``route`` set to the
        loaded Route and an ``rtstations`` list of its RTStation objects in
        RTStation id order (an empty list for a thread without stops).
    """
    route_by_ids = {route.id: route for route in Route.objects.filter(id__in=route_ids)}
    for route in route_by_ids.itervalues():
        route.threads = []

    # A single query for the threads; their ids are then passed to the stop
    # query as a literal list.
    thread_by_ids = {thread.id: thread for thread in RThread.objects.filter(route_id__in=route_ids)}
    for thread in thread_by_ids.itervalues():
        # Initialise up front so threads without stops still expose an
        # (empty) ``rtstations`` list instead of lacking the attribute.
        thread.rtstations = []
        route = route_by_ids[thread.route_id]
        thread.route = route
        route.threads.append(thread)

    # Distribute stops via dict lookup instead of itertools.groupby: groupby
    # needs each thread's rows to be contiguous, and Django's order_by on an
    # FK ('thread') may sort by RThread's Meta.ordering rather than its id.
    # Ordering by id alone keeps each per-thread list in id order.
    thread_ids = list(thread_by_ids)
    for rts in RTStation.objects.filter(thread_id__in=thread_ids).order_by('id'):
        thread_by_ids[rts.thread_id].rtstations.append(rts)

    return route_by_ids


def fast_get_threads_with_rtstations(thread_query, fetch_stations=False, fetch_settlements=False):
    """
    Load threads together with their stops and, optionally, the stop stations.

    :param thread_query: RThread queryset selecting the threads to load; only
        the ids are read from it (its ordering is dropped).
    :param fetch_stations: if true, load every referenced Station in one
        query and assign it to ``rtstation.station``.
    :param fetch_settlements: if true (and only together with
        ``fetch_stations``), also fetch the stations' ``settlement`` relation
        via ``fetch_related``.
    :return: list of RThread objects, in no particular order. Each thread has
        an ``rtstations`` list of its RTStation objects in RTStation id order
        (empty for a thread without stops), and each stop has ``thread`` set
        to the owning loaded thread.
    """
    thread_ids = list(thread_query.order_by().values_list('id', flat=True))
    thread_by_ids = {thread.id: thread for thread in RThread.objects.filter(id__in=thread_ids)}

    # Every thread gets a list, so the station loop below (and callers) can
    # rely on ``thread.rtstations`` existing even for threads without stops.
    for thread in thread_by_ids.itervalues():
        thread.rtstations = []

    # Dict lookup instead of itertools.groupby: groupby requires each
    # thread's rows to be contiguous, which order_by('thread') does not
    # guarantee when RThread defines Meta.ordering.
    for rts in RTStation.objects.filter(thread_id__in=thread_ids).order_by('id'):
        thread = thread_by_ids[rts.thread_id]
        rts.thread = thread
        thread.rtstations.append(rts)

    threads = list(thread_by_ids.values())

    if fetch_stations:
        station_ids = {stop.station_id for thread in threads for stop in thread.rtstations}
        station_by_ids = {s.id: s for s in Station.objects.filter(id__in=station_ids)}
        if fetch_settlements:
            fetch_related(station_by_ids.values(), 'settlement', model=Station)
        for thread in threads:
            for stop in thread.rtstations:
                stop.station = station_by_ids[stop.station_id]

    return threads


@tmpfiles.clean_temp
def update_thread_tariff_uids(old_uid2new_uid_map):
    """
    Rewrite ``ThreadTariff.thread_uid`` from old thread uids to new ones.

    :param old_uid2new_uid_map: dict mapping old thread uid -> new thread uid.

    Tariffs are read in batches of 1000 old uids. Only ``id`` and
    ``thread_uid`` are loaded, and the changed ``thread_uid`` values are
    written through a MysqlModelUpdater. Progress is logged per batch.
    """
    batch_size = 1000
    log.info(u'Обновляем тарифы')
    progress = PercentageStatus(len(old_uid2new_uid_map), log)
    tmp_dir = tmpfiles.get_tmp_dir('update_thread_tariff_uids')

    with MysqlModelUpdater(ThreadTariff, tmp_dir, fields=('thread_uid',)) as updater:
        for uid_batch in chunker(old_uid2new_uid_map.iterkeys(), batch_size):
            tariffs = ThreadTariff.objects.filter(thread_uid__in=uid_batch).only('id', 'thread_uid')
            for tariff in tariffs:
                tariff.thread_uid = old_uid2new_uid_map[tariff.thread_uid]
                updater.add(tariff)
            progress.step(len(uid_batch))


@tmpfiles.clean_temp
def store_old2new_uid_maps(changed_route_uids, changed_thread_uids):
    """
    Persist old -> new uid mappings for routes and threads.

    :param changed_route_uids: dict old route uid -> new route uid, stored
        as RouteUidMap rows.
    :param changed_thread_uids: dict old thread uid -> new thread uid,
        stored as RThreadUidMap rows.

    Each mapping is bulk-loaded through its own MysqlModelLoader, routes
    first, then threads.
    """
    targets = (
        (RouteUidMap, 'route_uid_map', changed_route_uids),
        (RThreadUidMap, 'thread_uid_map', changed_thread_uids),
    )
    for model, tmp_dir_name, uid_map in targets:
        with MysqlModelLoader(model, tmpfiles.get_tmp_dir(tmp_dir_name)) as loader:
            for old_uid, new_uid in uid_map.iteritems():
                loader.add(model(old_uid=old_uid, new_uid=new_uid))
