# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

from collections import defaultdict
from contextlib2 import contextmanager

from common.models.geo import StationCode, Settlement, Station, StationMajority
from common.models_utils.geo import Point
from travel.rasp.wizards.train_wizard_api.lib.train_city_provider import TrainCityProvider, train_city_provider  # noqa: UnusedImport


class ExpressSystemProvider(object):
    """Map rasp points (stations and settlements) to Express system codes.

    Lookups are served from in-memory maps built by ``build_cache`` from the
    ``StationCode`` rows whose system code is ``'express'``. Lookup methods
    need the cache to be built first (via ``using_precache()`` or an explicit
    ``build_cache()``); otherwise they fail on the unset (``None``) maps.
    """

    def __init__(self, train_city_provider):
        # type: (TrainCityProvider) -> None
        # station_id -> Express code (int); None until build_cache() runs.
        self._cache = None
        # Despite the name, keyed by settlement_id: settlement_id -> set of the
        # Express codes of that settlement's stations. Only settlements that
        # contain a station with majority EXPRESS_FAKE_ID (and at least one
        # station with an Express code) appear. None until build_cache() runs.
        self._city_express_code_to_stations_express_codes = None
        self._train_city_provider = train_city_provider

    @contextmanager
    def using_precache(self):
        """Keep the cache built for the duration of a ``with`` block.

        If the cache is already built (e.g. by an enclosing ``using_precache``),
        it is reused and left intact on exit. Otherwise it is built on entry
        and dropped on exit, even if the body raises.
        """
        if self.is_precached():
            # Someone else owns the cache; neither rebuild nor drop it.
            yield
            return

        self.build_cache()
        try:
            yield
        finally:
            self.clean()

    def is_precached(self):
        """Return True if ``build_cache`` has populated the lookup maps."""
        return self._cache is not None

    def clean(self):
        """Drop both lookup maps, returning to the not-precached state."""
        self._cache = None
        self._city_express_code_to_stations_express_codes = None

    def build_cache(self):
        """Load Express codes from the database and (re)build both lookup maps.

        Both maps are assigned only after every query has finished. A failure
        partway through therefore leaves the previous state untouched, instead
        of a half-built cache that ``is_precached()`` would report as ready.

        :raises ValueError: if an ``'express'`` station code is not an integer.
        """
        station_to_express = {
            station_id: int(code)
            for station_id, code in StationCode.objects.filter(
                system__code='express',
            ).values_list(
                'station_id',
                'code',
            )
        }

        # Settlements that contain at least one EXPRESS_FAKE_ID-majority station.
        fake_settlement_ids = set(
            Station.objects.filter(
                majority_id=StationMajority.EXPRESS_FAKE_ID,
            ).values_list('settlement_id', flat=True)
        )

        settlement_to_express = defaultdict(set)
        stations_of_fake_settlements = Station.objects.filter(
            settlement_id__in=fake_settlement_ids,
        ).values_list('id', 'settlement_id')
        for station_id, settlement_id in stations_of_fake_settlements:
            express_id = station_to_express.get(station_id)
            if express_id is not None:
                settlement_to_express[settlement_id].add(express_id)

        # Publish both maps together so the cache is never half-initialised.
        self._cache = station_to_express
        self._city_express_code_to_stations_express_codes = dict(settlement_to_express)

    def find_express_id(self, point_key):
        """Return the Express code for ``point_key``, or None if it has none.

        A station key is looked up directly. A settlement key is first resolved
        to a station id via the train city provider's ``find_train_station_id``,
        and that station's code is returned.

        :raises RuntimeError: if the key is neither a Settlement nor a Station.
        """
        model, pk = Point.parse_key(point_key)
        pk = int(pk)

        if model is Settlement:
            station_id = self._train_city_provider.find_train_station_id(pk)
            return self._cache.get(station_id)
        if model is Station:
            return self._cache.get(pk)

        raise RuntimeError('PointType is not supported')

    def find_related_express_ids(self, point_key):
        """Return the set of all Express codes associated with ``point_key``.

        The set holds the code from ``find_express_id`` (when truthy). For a
        settlement key, it also holds the station codes collected for that
        settlement by ``build_cache``.

        :raises RuntimeError: propagated from ``find_express_id`` for
            unsupported point types.
        """
        result = set()

        express_id = self.find_express_id(point_key)
        # Truthiness check: a missing code (None), and also a code of 0, is skipped.
        if express_id:
            result.add(express_id)

        model, pk = Point.parse_key(point_key)
        pk = int(pk)

        if model is Settlement:
            result.update(
                self._city_express_code_to_stations_express_codes.get(pk, ())
            )

        return result


# Process-wide provider instance wired to the shared train city provider.
# Its cache is empty until build_cache()/using_precache() is used.
express_system_provider = ExpressSystemProvider(train_city_provider)
