# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from collections import defaultdict
import six
from six.moves import map
try:
    import pathlib
except ImportError:
    import pathlib2 as pathlib

from travel.library.python.dicts.station_repository import StationRepository
from travel.library.python.dicts.station_to_settlement_repository import StationToSettlementRepository

import config
from util.point_key import PointKey
from util.lazy_setuper import LazySetuper


# Per-module logger; PointRelationsProvider reports unparseable point keys through it.
logger = logging.getLogger(name=__name__)


class PointRelationsProvider(LazySetuper):
    """Settlement <-> station relations built from the RASP dictionaries.

    "Children" of a settlement point key are the station point keys linked to
    it; "parents" of a station point key are the settlement point keys it is
    linked to. Both maps are built in ``_setup`` from two sources: each
    station's own ``SettlementId`` and the station-to-settlement table.
    """

    def __init__(self):
        super(PointRelationsProvider, self).__init__()
        # Filled by _setup: PointKey -> set of PointKey.
        self._children = {}
        self._parents = {}

    @staticmethod
    def _load_repo(repo, repo_cls, filename):
        """Return ``repo`` as-is, or a new ``repo_cls`` loaded from ``filename``.

        :param repo: a preloaded repository, or None to load one from disk.
        :param repo_cls: repository class to instantiate when ``repo`` is None.
        :param filename: file name resolved against ``config.RASP_DATA_PATH``.
        """
        if repo is not None:
            return repo
        repo = repo_cls()
        repo.load_from_file(six.text_type(pathlib.Path(config.RASP_DATA_PATH) / filename))
        return repo

    def _setup(self, station_repo=None, station2settlement_repo=None):
        """Build the children and parents relation maps.

        :param station_repo: optional preloaded StationRepository; when None it
            is loaded from ``station.bin`` under ``config.RASP_DATA_PATH``.
        :param station2settlement_repo: optional preloaded
            StationToSettlementRepository; when None it is loaded from
            ``settlement2station.bin`` under ``config.RASP_DATA_PATH``.
        """
        children = defaultdict(set)

        station_repo = self._load_repo(station_repo, StationRepository, "station.bin")
        for station in station_repo.itervalues():
            # Stations with a falsy SettlementId get no link from this source.
            if station.SettlementId:
                children[PointKey.settlement(station.SettlementId)].add(PointKey.station(station.Id))

        # Loaded only after the station pass, preserving the original load order.
        station2settlement_repo = self._load_repo(
            station2settlement_repo, StationToSettlementRepository, "settlement2station.bin")
        for link in station2settlement_repo.itervalues():
            children[PointKey.settlement(link.SettlementId)].add(PointKey.station(link.StationId))

        # Parents are the exact inverse of children, so both directions agree.
        parents = defaultdict(set)
        for settlement_key, station_keys in six.iteritems(children):
            for station_key in station_keys:
                parents[station_key].add(settlement_key)

        # Plain dicts so lookups of unknown keys do not insert empty entries.
        self._children = dict(children)
        self._parents = dict(parents)

    @staticmethod
    def _parse_point_key(raw_point_key):
        """Parse ``raw_point_key`` with ``PointKey.load``.

        Returns None, after logging a warning, when ``PointKey.load`` raises
        ValueError.
        """
        try:
            return PointKey.load(raw_point_key)
        except ValueError:
            logger.warning('Invalid point_key: %s', raw_point_key)
            return None

    def _related(self, relations, raw_point_key):
        """Return related point keys from ``relations`` as text.

        The result is a lazy ``map`` iterator, so it can be consumed only once.
        It is empty for unknown or unparseable keys.
        """
        return map(six.text_type, relations.get(self._parse_point_key(raw_point_key), ()))

    @LazySetuper.setup_required
    def get_children(self, raw_point_key):
        """Return text station keys linked to the settlement ``raw_point_key``.

        Lazy, single-pass iterator; empty when the key is unknown or invalid.
        """
        return self._related(self._children, raw_point_key)

    @LazySetuper.setup_required
    def get_parents(self, raw_point_key):
        """Return text settlement keys linked to the station ``raw_point_key``.

        Lazy, single-pass iterator; empty when the key is unknown or invalid.
        """
        return self._related(self._parents, raw_point_key)


# Shared module-level instance. Construction only initialises empty maps; the
# relation data is built in _setup, which the LazySetuper.setup_required
# decorator on the getters presumably triggers on first use (confirm in LazySetuper).
point_relations_provider = PointRelationsProvider()
