# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

from datetime import date, datetime
from typing import List, Optional

from common.apps.suburban_events.utils import (
    collect_rts_by_threads, collect_threads_by_number, get_threads_suburban_keys, ThreadEventsTypeCodes, ThreadKey
)
from common.models.geo import Station
from common.models.schedule import RTStation, RThread
from common.models.transport import TransportType
from common.utils.date import MSK_TZ


def calculate_cancelled_stations(
    event,  # type: CppkEvent
    log,
):
    # type: (...) -> None
    """Fill ``event.cancelled_stations`` with the cancelled part of the thread path.

    The cancelled segment runs from the first occurrence of ``event.from_station``
    to the last occurrence of ``event.to_station`` in ``event.thread_path``, both
    inclusive. ``event.cancelled_stations`` is left untouched when the event is
    annulled or when either station is absent from the path.

    :param event: event whose ``thread_path``, ``from_station`` and ``to_station`` are already set.
    :param log: logger-like object; only its ``warning`` method is called.
    """
    # An annulment (cancellation of a cancellation) keeps the list of cancelled stations empty.
    if event.is_annuled:
        return

    thread_path = [rts.station for rts in event.thread_path]
    if event.from_station not in thread_path:
        log.warning('from_station {} is not in path'.format(event.from_station))
        return

    if event.to_station not in thread_path:
        log.warning('to_station {} is not in path'.format(event.to_station))
        return

    from_index = thread_path.index(event.from_station)
    # Exclusive end of the slice: one past the LAST occurrence of to_station.
    to_index = len(thread_path) - thread_path[::-1].index(event.to_station)
    if from_index >= to_index:
        # The slice below is empty in this case; report it instead of failing silently.
        log.warning('to_station {} precedes from_station {} in path'.format(
            event.to_station, event.from_station
        ))
    event.cancelled_stations = event.thread_path[from_index:to_index]


def calculate_thread_departure(event):
    # type: (CppkEvent) -> None
    """Set ``event.thread_start_dt`` from the matched start RTStation and thread.

    The start date is obtained from ``event.start_rts.calc_thread_start_date`` for
    the event's departure date, and is combined with ``event.thread.tz_start_time``.
    """
    start_rts = event.start_rts
    start_date = start_rts.calc_thread_start_date('departure', event.departure_date, MSK_TZ)
    start_time = event.thread.tz_start_time
    event.thread_start_dt = datetime.combine(start_date, start_time)


class CppkEvent(object):
    """A raw event together with the stations, thread and derived data matched to it."""

    def __init__(self, raw_event):
        # Raw event mapping; the properties below read its 'create_dt',
        # 'train_number' and 'departure_date' keys.
        self.raw_event = raw_event

        # Matched stations by role; None until matched.
        self.stations = {
            'start': None,
            'finish': None,
            'from': None,
            'to': None
        }
        self.thread = None            # type: Optional[RThread]
        self.thread_start_dt = None
        self.suburban_key = None
        self.thread_path = []         # type: List[RTStation]
        self.cancelled_stations = []  # type: List[RTStation]
        # True when the event is an annulment (cancellation of a cancellation).
        self.is_annuled = False
        self.start_rts = None         # type: Optional[RTStation]

        # Lazily built cache for the `thread_key` property.
        self._thread_key = None

    @property
    def thread_key(self):
        # type: (CppkEvent) -> ThreadKey
        """Return the ThreadKey of the matched thread, building it on first access.

        The key is cached: later changes to ``suburban_key`` or ``thread_start_dt``
        are not reflected once it has been built.
        """
        if self._thread_key is None:
            self._thread_key = ThreadKey(
                thread_key=self.suburban_key,
                thread_start_date=self.thread_start_dt,
                thread_type=ThreadEventsTypeCodes.SUBURBAN
            )
        return self._thread_key

    @property
    def create_dt(self):
        """Return the raw event's 'create_dt' value."""
        return self.raw_event['create_dt']

    @property
    def train_number(self):
        # type: (CppkEvent) -> unicode
        """Return the raw event's 'train_number' value."""
        return self.raw_event['train_number']

    @property
    def departure_date(self):
        # type: (CppkEvent) -> date
        """Return the date part of the raw event's 'departure_date' value."""
        return self.raw_event['departure_date'].date()

    @property
    def start_station(self):
        # type: (CppkEvent) -> Optional[Station]
        """Station stored under the 'start' role, or None if not matched."""
        return self.stations['start']

    @start_station.setter
    def start_station(self, station):
        self.stations['start'] = station

    @property
    def finish_station(self):
        # type: (CppkEvent) -> Optional[Station]
        """Station stored under the 'finish' role, or None if not matched."""
        return self.stations['finish']

    @finish_station.setter
    def finish_station(self, station):
        self.stations['finish'] = station

    @property
    def from_station(self):
        # type: (CppkEvent) -> Optional[Station]
        """Station stored under the 'from' role, or None if not matched."""
        return self.stations['from']

    @from_station.setter
    def from_station(self, station):
        self.stations['from'] = station

    @property
    def to_station(self):
        # type: (CppkEvent) -> Optional[Station]
        """Station stored under the 'to' role, or None if not matched."""
        return self.stations['to']

    @to_station.setter
    def to_station(self, station):
        self.stations['to'] = station

    def all_stations_matched(self):
        # type: (CppkEvent) -> bool
        """Return True when every station required for this event is set.

        An annulled event only needs 'start' and 'finish'; any other event
        needs all four roles.
        """
        if self.is_annuled:
            return self.start_station is not None and self.finish_station is not None
        return all(station is not None for station in self.stations.values())

    def is_whole_thread_matched(self):
        # type: (CppkEvent) -> bool
        """Return True for an annulled event or when cancelled stations were found."""
        # bool() so that the declared return type holds instead of leaking the list itself.
        return bool(self.is_annuled or self.cancelled_stations)


class CppkMatcher(object):
    """Match raw events to stations and to suburban threads.

    Build it with the raw events and a logger-like object, then call ``match()``
    to get the events for which stations and a thread were found.
    """

    def __init__(self, raw_events, log):
        self.log = log
        self.events = [CppkEvent(e) for e in raw_events]

    def match(self):
        # type: (CppkMatcher) -> List[CppkEvent]
        """Run station and thread matching and return the events that survived both."""
        if not self.events:
            self.log.info('no events - no matching.')
            return []

        self.match_stations()
        self.match_threads()

        return self.events

    def match_stations(self):
        """Resolve express ids of the raw events to Station objects.

        Events with neither 'from_express_id' nor 'to_express_id' are marked as
        annulled. Events whose required stations could not be resolved are logged
        and removed from ``self.events``.
        """
        express_station_ids = set()
        for event in self.events:
            if event.raw_event['from_express_id'] is None and event.raw_event['to_express_id'] is None:
                event.is_annuled = True
            event_express_stations = {event.raw_event['start_express_id'], event.raw_event['finish_express_id']}
            if not event.is_annuled:
                event_express_stations.update({event.raw_event['from_express_id'], event.raw_event['to_express_id']})
            express_station_ids.update(event_express_stations)

        # Index the found stations by integer express id, the form used for the lookups below.
        stations = {}
        for station in Station.objects.filter(express_id__in=list(express_station_ids)):
            stations[int(station.express_id)] = station

        not_matched_station_ids = express_station_ids - set(stations.keys())
        if not_matched_station_ids:
            self.log.warning('not matched express station ids: {}'.format(list(not_matched_station_ids)))

        matched_events = []
        for event in self.events:
            event.start_station = stations.get(event.raw_event['start_express_id'])
            event.finish_station = stations.get(event.raw_event['finish_express_id'])
            if not event.is_annuled:
                event.from_station = stations.get(event.raw_event['from_express_id'])
                event.to_station = stations.get(event.raw_event['to_express_id'])

            if event.all_stations_matched():
                matched_events.append(event)
            else:
                # An annulled event never has 'from'/'to' stations, so they must not
                # be reported as missing - only the roles the event really needs.
                required_types = ('start', 'finish') if event.is_annuled else tuple(event.stations)
                self.log.warning("can't match {} (raw id: {}). Missing stations: {}".format(
                    event.train_number,
                    event.raw_event['_id'],
                    [station_type for station_type in required_types if event.stations[station_type] is None]
                ))

        self.events = matched_events

    def _match_threads_by_stations(self):
        """Pick a thread for every event by train number and start/finish stations.

        On success ``event.thread`` and ``event.start_rts`` are set; otherwise a
        warning is logged and both are left as they were.
        """
        station_ids = {
            station.id
            for e in self.events
            for station in [e.start_station, e.finish_station]
        }

        # Collect all RTStations of suburban trains whose Station is one of the
        # previously matched departure and arrival stations received from Movista.
        rtstations = RTStation.objects.filter(
            thread__t_type=TransportType.SUBURBAN_ID,
            station_id__in=list(station_ids),
            thread__type__code='basic'
        ).exclude(
            thread__uid__contains='MCZK',
        ).select_related(
            'thread', 'station'
        ).only(
            'thread__id', 'thread__number', 'station__id'
        ).order_by('id')

        thread_paths = collect_rts_by_threads(rtstations)
        threads_by_number = collect_threads_by_number(thread_paths)

        for event in self.events:
            # NOTE(review): assumes the mapping yields an empty collection for an
            # unknown train number (e.g. a defaultdict) - confirm against the helper.
            for thread in threads_by_number[event.train_number]:
                # Work on locals so that a failed attempt leaves no stale start_rts on the event.
                start_rts = None
                finish_found = False
                for rts in thread_paths[thread]:
                    # The start station of a Movista thread may differ from the start station of our thread.
                    # https://st.yandex-team.ru/RASPFRONT-9587
                    if start_rts is None and rts.station == event.start_station:
                        start_rts = rts
                    # Not `elif`: the finish only counts once the start has been seen,
                    # and the same stop may serve as both.
                    if start_rts is not None and rts.station == event.finish_station:
                        finish_found = True
                        break

                if finish_found:
                    event.start_rts = start_rts
                    event.thread = thread
                    break

            if event.thread is None:
                self.log.warning("can't match {} (raw event: {}). Missing thread".format(
                    event.train_number,
                    event.raw_event['_id']
                ))

    def _fill_events_with_thread_data(self):
        """Load full thread paths and fill path, cancelled stations, start datetime and suburban key.

        Expects every event in ``self.events`` to already have ``thread`` set.
        """
        rtstations = RTStation.objects.filter(
            thread__id__in=[e.thread.id for e in self.events]
        ).select_related(
            'thread', 'station'
        ).order_by('id')
        thread_paths = collect_rts_by_threads(rtstations)

        suburban_keys = get_threads_suburban_keys([c.thread for c in self.events])
        for event in self.events:
            event.thread_path = thread_paths[event.thread]
            calculate_cancelled_stations(event, self.log)
            calculate_thread_departure(event)
            event.suburban_key = suburban_keys[event.thread.id]

    def match_threads(self):
        """Match threads, drop events without one, fill thread data, drop events with no cancelled part."""
        self._match_threads_by_stations()
        self.events = [e for e in self.events if e.thread]
        self._fill_events_with_thread_data()
        self.events = [e for e in self.events if e.is_whole_thread_matched()]
