# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import itertools

from common.models.geo import Country, Region, Settlement, Station
from common.models.transport import TransportType
from common.utils.iterrecipes import unique_everseen
from geosearch.models import NameSearchIndex
from travel.library.python.tracing.instrumentation import traced_function
from travel.rasp.wizards.proxy_api.lib.station.settlement_stations_cache import SettlementStationsCache


def _find_stations(name):
    """Collect candidate stations for *name*.

    Combines non-hidden stations found by an exact-name search in
    ``NameSearchIndex`` with stations returned by
    ``Station.code_manager.get_list_by_code(name)``. Name matches come first.
    Duplicates are dropped with ``unique_everseen`` (presumably keeping the
    first occurrence, as in the itertools recipe — verify).

    Returns a tuple of stations.
    """
    by_name = NameSearchIndex.find('exact', Station, name).filtered(
        lambda station: not station.hidden
    )
    by_code = Station.code_manager.get_list_by_code(name)
    deduplicated = unique_everseen(itertools.chain(by_name, by_code))
    return tuple(deduplicated)


def _find_settlements(name):
    """Collect candidate settlements for *name*.

    Combines non-hidden settlements found by an exact-name search in
    ``NameSearchIndex`` with settlements whose ``iata`` equals *name*
    case-insensitively (queried through ``Settlement.hidden_manager``).
    Name matches come first. Duplicates are dropped with
    ``unique_everseen`` (presumably first occurrence wins — verify).

    Returns a tuple of settlements.
    """
    by_name = NameSearchIndex.find('exact', Settlement, name).filtered(
        lambda settlement: not settlement.hidden
    )
    by_iata = Settlement.hidden_manager.get_list(iata__iexact=name)
    deduplicated = unique_everseen(itertools.chain(by_name, by_iata))
    return tuple(deduplicated)


def _iter_t_type_ids(transport_code):
    """Yield the transport-type ids acceptable for *transport_code*.

    * ``None`` -> yields a single ``None`` (no transport restriction).
    * Otherwise -> yields the id of the ``TransportType`` with that code.
      If that id is ``TransportType.SUBURBAN_ID``, it also yields
      ``TransportType.TRAIN_ID``.

    ``TransportType.objects.get`` errors (e.g. an unknown code) propagate.
    """
    if transport_code is None:
        yield None
        return

    requested_id = TransportType.objects.get(code=transport_code).id
    yield requested_id

    # A suburban request is also allowed to match train stations.
    if requested_id == TransportType.SUBURBAN_ID:
        yield TransportType.TRAIN_ID


def _filter_stations_by_t_type_ids(stations, t_type_ids):
    """Return the stations whose ``t_type_id`` is one of *t_type_ids*.

    *t_type_ids* may be any iterable; it is consumed once into a set. The
    input order of *stations* is preserved. Returns a tuple.
    """
    allowed = frozenset(t_type_ids)
    matching = [station for station in stations if station.t_type_id in allowed]
    return tuple(matching)


def _get_by_key(items, key):
    """Return the single item of *items* for which *key* is truthy.

    Returns ``None`` when nothing matches and also when the match is
    ambiguous (two or more items satisfy *key*). Evaluation stops as soon as
    a second match is found, so the rest of *items* is not examined.
    """
    # A generator expression works on both Python 2 and 3, unlike
    # ``itertools.ifilter``, which no longer exists on Python 3.
    matches = tuple(itertools.islice((item for item in items if key(item)), 2))
    if len(matches) == 1:
        return matches[0]
    return None


def get_region_and_country_ids(geoid):
    """Resolve *geoid* to a ``(region_id, country_id)`` pair.

    Models are tried in order, each looked up by ``_geo_id=geoid``:

    * ``Settlement`` -> ``(settlement.region_id, settlement.country_id)``
    * ``Region``     -> ``(region.id, region.country_id)``
    * ``Country``    -> ``(None, country.id)``

    The first model that has a matching row wins. If none match, returns
    ``(None, None)``. Only each model's ``DoesNotExist`` is caught; any
    other error from ``.get()`` propagates.
    """
    lookups = (
        (Settlement, lambda point: (point.region_id, point.country_id)),
        (Region, lambda point: (point.id, point.country_id)),
        (Country, lambda point: (None, point.id)),
    )

    for model, to_ids in lookups:
        try:
            point = model.objects.get(_geo_id=geoid)
        except model.DoesNotExist:
            continue
        return to_ids(point)

    return None, None


def _iter_point_filters(client_geoid):
    """Yield point predicates ordered from the narrowest to the broadest.

    When *client_geoid* is truthy it is resolved with
    ``get_region_and_country_ids``. A same-region predicate is yielded if a
    region id was found, then a same-country predicate if a country id was
    found. The final predicate always accepts every point.

    Each predicate takes an object with ``region_id`` / ``country_id``
    attributes and returns a bool.
    """
    filters = []

    if client_geoid:
        region_id, country_id = get_region_and_country_ids(client_geoid)
        if region_id is not None:
            filters.append(lambda point: point.region_id == region_id)
        if country_id is not None:
            filters.append(lambda point: point.country_id == country_id)

    # Fallback: no geographic restriction.
    filters.append(lambda point: True)

    for point_filter in filters:
        yield point_filter


def find_aeroexpress_station(name):
    """Return the unique plane-type station matching *name*.

    Candidates come from ``_find_stations(name)``. Returns ``None`` if no
    candidate has ``t_type_id == TransportType.PLANE_ID``, or if more than
    one does.
    """
    def is_plane_station(station):
        return station.t_type_id == TransportType.PLANE_ID

    return _get_by_key(_find_stations(name), is_plane_station)


@traced_function
def find_station(name, transport_code=None, client_geoid=None):
    """Find the best station for *name*, optionally biased by client location.

    Candidates are the stations and settlements found for *name*. If
    *transport_code* is given, stations are restricted to that transport
    type (suburban also admits train).

    Point filters are tried from narrowest to broadest: the client's region,
    then the client's country (both only when *client_geoid* resolves to
    them), then anywhere. For each filter:

    1. If exactly one candidate settlement passes, return the first station
       that ``SettlementStationsCache.get(settlement, t_type_id)`` yields for
       the allowed transport-type ids.
    2. Otherwise, among the candidate stations that pass, return the one
       with the smallest ``(majority_id, title)``.

    Returns ``None`` if no filter produces a station.
    """
    stations = _find_stations(name)
    settlements = _find_settlements(name)
    t_type_ids = tuple(_iter_t_type_ids(transport_code))

    if transport_code:
        stations = _filter_stations_by_t_type_ids(stations, t_type_ids)

    for point_filter in _iter_point_filters(client_geoid):
        settlement = _get_by_key(settlements, point_filter)
        if settlement:
            for t_type_id in t_type_ids:
                station = SettlementStationsCache.get(settlement, t_type_id)
                if station:
                    return station

        # Filter into a separate list rather than rebinding ``stations``.
        # Rebinding would make the broader fallback filters run on an
        # already-narrowed (possibly empty) set, so they could never match.
        matching = [station for station in stations if point_filter(station)]
        if matching:
            return min(matching, key=lambda s: (s.majority_id, s.title))

    return None
