# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging
from collections import defaultdict

from common.data_api.baris.instance import baris
from common.models.geo import Station, Station2Settlement
from common.models.schedule import RThread, RTStation
from common.models.transport import TransportType

from travel.rasp.suggests_tasks.suggests.generate import shared_objects
from travel.rasp.suggests_tasks.suggests.generate.utils import generate_parallel, retrieve_ids
from travel.rasp.suggests_tasks.suggests.utils import enumer, split_values, print_run_time


logger = logging.getLogger('generate')


def get_thread_ttypes():
    """
    Build a lookup of transport type per thread.

    Plane threads are excluded (planes are handled separately via BARIS).

    :return: dict thread id -> transport type id.
    """
    rows = (
        RThread.objects
        .exclude(t_type_id=TransportType.PLANE_ID)
        .values_list('id', 't_type_id')
    )
    # enumer yields the (id, t_type_id) rows unchanged, so they feed dict() directly.
    return dict(enumer(rows, each=10 ** 6))


def get_ttypes_for_ids(worker_args):
    """
    Worker for ``generate_parallel``: collect transport types per station
    for one chunk of RTStation ids.

    Reads the ``threads_ttypes`` mapping (thread id -> transport type id)
    that the caller published through ``shared_objects``.

    :param worker_args: pair ``(worker_id, rtstations_ids)``; the worker id
        is not used here.
    :return: defaultdict(set) station id -> set of transport type ids.
        Threads absent from ``threads_ttypes`` are skipped.
    """
    # Unpack in the body instead of using a tuple parameter in the signature:
    # tuple parameters are Python 2 only (removed by PEP 3113). This accepts
    # exactly the same single positional argument and also parses on Python 3.
    _worker_id, rtstations_ids = worker_args

    threads_ttypes = shared_objects.get_obj('threads_ttypes')

    rtstations = RTStation.objects.filter(id__in=rtstations_ids)
    stations_ttypes = defaultdict(set)
    for thread_id, station_id in enumer(rtstations.values_list('thread_id', 'station_id'), each=10 ** 6):
        t_type = threads_ttypes.get(thread_id)
        # Threads missing from the lookup (e.g. excluded plane threads) give None.
        if t_type:
            stations_ttypes[station_id].add(t_type)

    return stations_ttypes


def merge_stations_ttypes(stations_ttypes_iter):
    """
    Union several station -> transport types mappings into one.

    :param stations_ttypes_iter: iterable of mappings station id -> set of
        transport type ids (in this module: the per-worker results of
        get_ttypes_for_ids).
    :return: defaultdict(set) holding, for every station, the union of its
        sets across all input mappings. Input sets are left untouched.
    """
    merged = defaultdict(set)
    pairs = (pair for chunk in stations_ttypes_iter for pair in chunk.items())
    for station_id, t_types in pairs:
        # merged[...] is always a fresh set owned by this dict, so the in-place
        # union never aliases or mutates an input set.
        bucket = merged[station_id]
        bucket |= t_types
    return merged


def get_baris_station_ids():
    """
    Return the ids of all stations that have BARIS departures or arrivals.

    :return: set of station ids (union of the departure and arrival station
        collections from ``baris.get_station_summary()``).
    """
    departure_stations, arrival_stations = baris.get_station_summary()
    # set(...).union(...) replaces a redundant identity set comprehension and
    # works whether the summary returns sets or any other iterables of ids.
    return set(departure_stations).union(arrival_stations)


def get_settlements_ttypes(stations_ttypes):
    """
    Aggregate station transport types up to settlements.

    A station contributes its transport types to its own ``settlement_id``
    (when that is set) and to every settlement linked to it through
    ``Station2Settlement``.

    :param stations_ttypes: mapping station id -> set of transport type ids.
        It is only read, never modified; stations missing from it contribute
        nothing.
    :return: defaultdict(set) settlement id -> set of transport type ids.
        Every settlement reached by at least one station gets an entry,
        possibly an empty set.
    """
    settlements_ttypes = defaultdict(set)

    def add_station(station_id, sett_id):
        # Use .get() rather than [] here. With a defaultdict argument, indexing
        # would insert an empty set for every station in the database and
        # silently grow the caller's mapping.
        settlements_ttypes[sett_id].update(stations_ttypes.get(station_id, ()))

    for station_id, sett_id in Station.objects.values_list('id', 'settlement_id'):
        # Skip stations with an empty (falsy, e.g. NULL) settlement_id.
        if sett_id:
            add_station(station_id, sett_id)

    for station_id, sett_id in Station2Settlement.objects.values_list('station_id', 'settlement_id'):
        add_station(station_id, sett_id)

    return settlements_ttypes


def values_to_list(objs):
    """
    Replace every value of *objs* with a list of its elements, in place.

    Keys are unchanged; nothing is returned.
    """
    # Only existing keys are reassigned, so the mapping's size never changes.
    for key in list(objs):
        objs[key] = list(objs[key])


def get_ttypes(pool_size):
    """
    Build transport-type lookups for stations and settlements.

    Non-plane types come from the rasp DB (RThread / RTStation). The
    RTStation scan is split across ``pool_size`` parallel workers. Plane
    stations come from BARIS. Each step is timed via ``print_run_time``.

    :param pool_size: number of parallel workers for the RTStation scan.
    :return: ``(stations_ttypes, settlements_ttypes)``, both mappings
        id -> list of transport type ids.
    """
    with print_run_time('generate threads ttypes', logger=logger):
        thread_ttypes = get_thread_ttypes()
        # Published before the parallel step: workers read it by this name.
        shared_objects.set_objs(threads_ttypes=thread_ttypes)

    with print_run_time('retrieve_ids(RTStation)', logger=logger):
        rtstation_ids = retrieve_ids(RTStation, exclude_t_type_ids=[TransportType.PLANE_ID])

    with print_run_time('generate stations ttypes from rasp db', logger=logger):
        chunks = split_values(rtstation_ids, pool_size)
        # Merge inside this block so the timing covers consuming the worker results.
        stations_ttypes = merge_stations_ttypes(
            generate_parallel(get_ttypes_for_ids, chunks, pool_size)
        )

    with print_run_time('generate stations ttypes from BARIS', logger=logger):
        for baris_station_id in get_baris_station_ids():
            stations_ttypes[baris_station_id].add(TransportType.PLANE_ID)

    with print_run_time('generate settlements ttypes', logger=logger):
        settlements_ttypes = get_settlements_ttypes(stations_ttypes)

    for lookup in (stations_ttypes, settlements_ttypes):
        values_to_list(lookup)

    return stations_ttypes, settlements_ttypes
