# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

from collections import Counter, namedtuple

from common.models.geo import Direction
from common.precache.storage import BasePrecacheStorage
from travel.rasp.wizards.suburban_wizard_api.lib.schedule_cache import schedule_cache
from travel.rasp.wizards.wizard_lib.station.direction_type import DirectionType


# A (type, code) pair describing one suburban direction. ``type`` is a
# DirectionType member; ``code`` is a string identifier whose source depends on
# the type (a fixed name below, a Direction code, or a departure subdir value).
SuburbanDirection = namedtuple('SuburbanDirection', ['type', 'code'])

# Fixed directions that do not depend on station-specific data.
DEPARTURE_DIRECTION = SuburbanDirection(type=DirectionType.DEPARTURE, code='departure')
ARRIVAL_DIRECTION = SuburbanDirection(type=DirectionType.ARRIVAL, code='arrival')
ALL_DIRECTION = SuburbanDirection(type=DirectionType.ALL, code='all')


def _make_departure_direction_factory(use_direction, stops, directions):
    """Build a callable mapping a stop to its SuburbanDirection (or None).

    :param use_direction: the station's direction mode; falsy disables
        directions entirely, ``'dir'`` requests Direction-based grouping.
    :param stops: mapping whose values are iterables of stop objects
        (presumably keyed by thread — TODO confirm against schedule_cache).
    :param directions: mapping of Direction id -> Direction code.
    :returns: ``None`` when ``use_direction`` is falsy; otherwise a function
        ``stop -> SuburbanDirection | None``.

    Direction-based grouping is used only when ``use_direction == 'dir'`` AND
    the stops reference more than one distinct departure direction id; in every
    other case stops are grouped by ``departure_subdir``. In both modes a stop
    without a departure time is classified as ARRIVAL_DIRECTION.
    """
    if not use_direction:
        return None

    # Distinct non-null departure direction ids across all of the station's stops.
    distinct_direction_ids = set()
    for thread_stops in stops.itervalues():
        for stop in thread_stops:
            if stop.departure_direction_id is not None:
                distinct_direction_ids.add(stop.departure_direction_id)

    def by_direction(stop):
        if _is_last_stop(stop):
            return ARRIVAL_DIRECTION
        direction_id = stop.departure_direction_id
        if direction_id is None:
            return DEPARTURE_DIRECTION
        return SuburbanDirection(DirectionType.DIR, directions[direction_id])

    def by_subdir(stop):
        if _is_last_stop(stop):
            return ARRIVAL_DIRECTION
        subdir = stop.departure_subdir
        if subdir is None:
            # No subdir: the stop gets no specific direction.
            return None
        return SuburbanDirection(DirectionType.SUBDIR, subdir)

    if use_direction == 'dir' and len(distinct_direction_ids) > 1:
        return by_direction
    return by_subdir


def _is_last_stop(stop):
    return stop.tz_departure is None


class SuburbanDirectionsCache(BasePrecacheStorage):
    """Precached suburban direction data per stop and per station.

    The cached value produced by :meth:`build_cache` is a pair
    ``(stop_directions, station_directions)``:

    * ``stop_directions`` maps a stop id to its SuburbanDirection; stops for
      which no direction was resolved are absent.
    * ``station_directions`` maps a station id to a list of
      ``(SuburbanDirection, count)`` pairs in descending count order, optionally
      followed by ``(ALL_DIRECTION, total_scheduled_stops)``.

    NOTE(review): ``cls._cache`` is assumed to be populated by
    BasePrecacheStorage from ``build_cache()`` — confirm in the base class.
    """

    @classmethod
    def build_cache(cls):
        """Compute ``(stop_directions, station_directions)`` from schedule_cache.

        Only stops with ``in_station_schedule`` set are classified and counted.
        Stations whose ``use_direction`` is falsy, or whose scheduled stops all
        resolve to no direction, get no entry in ``station_directions``.
        """
        # Direction id -> Direction code.
        directions = dict(Direction.objects.values_list('id', 'code'))
        stop_directions = {}
        station_directions = {}

        for station_id, stops in schedule_cache._station_stops_cache.iteritems():
            station = schedule_cache._stations_cache[station_id]
            direction_factory = _make_departure_direction_factory(station.use_direction, stops, directions)
            if direction_factory is None:
                continue

            scheduled_count = 0  # stops shown in the station schedule
            directed_count = 0   # ...of which resolved to some direction
            direction_counts = Counter()
            for thread_stops in stops.itervalues():
                for stop in thread_stops:
                    if not stop.in_station_schedule:
                        continue
                    scheduled_count += 1
                    stop_direction = direction_factory(stop)
                    if stop_direction is None:
                        continue
                    stop_directions[stop.id] = stop_direction
                    direction_counts[stop_direction] += 1
                    directed_count += 1

            if not direction_counts:
                continue

            directions_with_counts = direction_counts.most_common()
            # Add an ALL entry when there is more than one direction, or when some
            # scheduled stops resolved to no direction and so are counted only there.
            if len(direction_counts) > 1 or scheduled_count > directed_count:
                directions_with_counts.append((ALL_DIRECTION, scheduled_count))
            station_directions[station_id] = directions_with_counts

        return stop_directions, station_directions

    @classmethod
    def list_directions_with_counts(cls, station):
        """Return ``[(SuburbanDirection, count), ...]`` for ``station``.

        Returns an empty list when the station has no precached directions.
        The result is a fresh list, so callers may sort or extend it without
        corrupting the shared precache.
        """
        _stop_directions, station_directions = cls._cache
        return list(station_directions.get(station.id, ()))

    @classmethod
    def iter_raw_segments_with_directions(cls, raw_segments):
        """Yield ``(raw_segment, direction)`` for each of ``raw_segments``.

        The direction is looked up by ``raw_segment.event_stop.id``; stops with
        no precached direction get ALL_DIRECTION.
        """
        stop_directions, _station_directions = cls._cache
        for raw_segment in raw_segments:
            event_stop = raw_segment.event_stop
            yield raw_segment, stop_directions.get(event_stop.id, ALL_DIRECTION)
