# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging
from collections import defaultdict
from datetime import datetime
from functools import partial
from typing import List

from django.conf import settings
from mongoengine import Q
from pybreaker import CircuitBreaker
from pymongo import ReadPreference

from common.apps.suburban_events.models import ThreadStationState, ThreadState
from common.apps.suburban_events.utils import (
    get_rtstation_key, ThreadStationKey, ThreadKey, get_threads_suburban_keys, get_thread_type_and_clock_dir
)
from travel.library.python.tracing.instrumentation import traced_function

# Imported here so that it is part of this module's public API
from common.apps.suburban_events.utils import EventStateType  # noqa
from common.dynamic_settings.default import conf
from common.settings.utils import define_setting
from travel.rasp.library.python.common23.logging import log_run_time

log = logging.getLogger(__name__)
# Rebind the imported helper so every log_run_time(...) call in this module logs to `log` at DEBUG level.
log_run_time = partial(log_run_time, logger=log, log_level=logging.DEBUG)


# Mongo read preference for all state queries: read from a secondary when one is available,
# otherwise from the primary.
READ_PREF = ReadPreference.SECONDARY_PREFERRED

# Default circuit-breaker parameters. Presumably define_setting only installs this default when the
# Django settings do not already define SUBURBAN_EVENTS_BREAKER_PARAMS — confirm.
define_setting('SUBURBAN_EVENTS_BREAKER_PARAMS', default={'fail_max': 5, 'reset_timeout': 20})

# Shared breaker used as a decorator on the state-lookup functions of this module. With pybreaker,
# after `fail_max` consecutive failures the circuit opens and calls fail fast until `reset_timeout`
# seconds have passed.
suburban_breaker = CircuitBreaker(**settings.SUBURBAN_EVENTS_BREAKER_PARAMS)


class EventState(object):
    """Realtime state of a single event (a whole thread, an arrival or a departure).

    Every field is optional and defaults to None. `minutes_from` / `minutes_to` are copied
    verbatim from the stored state (presumably a delay range in minutes — confirm).
    """
    __slots__ = ['type', 'minutes_from', 'minutes_to', 'dt', 'tz']

    def __init__(self, type_=None, minutes_from=None, minutes_to=None, dt=None, tz=None):
        # type: (EventStateType, int, int, datetime, str) -> None
        self.type, self.dt, self.tz = type_, dt, tz
        self.minutes_from, self.minutes_to = minutes_from, minutes_to


class SegmentEventState(EventState):
    """EventState of one segment event, tagged with the serialized key it was looked up by."""
    # NOTE(review): re-declaring the base-class slots here wastes a descriptor per slot; kept as is
    # in case callers introspect __slots__ for the full field list.
    __slots__ = EventState.__slots__ + ['key']

    def __init__(self, key, *args, **kwargs):
        super(SegmentEventState, self).__init__(*args, **kwargs)
        # Callers pass the string form of a ThreadStationKey / ThreadKey (`.to_str()`).
        self.key = key  # type: str


class ScheduleRouteEventState(EventState):
    """EventState of one schedule-route event, tagged with the serialized key it was looked up by."""
    # NOTE(review): re-declaring the base-class slots here wastes a descriptor per slot; kept as is
    # in case callers introspect __slots__ for the full field list.
    __slots__ = EventState.__slots__ + ['key']

    def __init__(self, key, *args, **kwargs):
        super(ScheduleRouteEventState, self).__init__(*args, **kwargs)
        # Callers pass the string form of a ThreadStationKey / ThreadKey (`.to_str()`).
        self.key = key  # type: str


class StationEventState(object):
    """Arrival and departure states of one thread run at one station (both None until set)."""
    __slots__ = ['key', 'arrival', 'departure']

    def __init__(self, key=None, *args, **kwargs):
        # Any extra positional / keyword arguments are accepted and ignored.
        self.key = key  # type: str  # serialized ThreadStationKey
        self.arrival = self.departure = None  # type: EventState


class ThreadEventState(EventState):
    """EventState of a whole thread run, tagged with the serialized ThreadKey it was looked up by."""
    # NOTE(review): re-declaring the base-class slots here wastes a descriptor per slot; kept as is
    # in case callers introspect __slots__ for the full field list.
    __slots__ = EventState.__slots__ + ['key']

    def __init__(self, key=None, *args, **kwargs):
        super(ThreadEventState, self).__init__(*args, **kwargs)
        # Callers pass the string form of a ThreadKey (`.to_str()`).
        self.key = key  # type: str


class SegmentStates(object):
    """Realtime event states collected for a single segment; every field defaults to None."""
    __slots__ = ['arrival', 'departure', 'from_arrival', 'to_departure', 'thread']

    def __init__(self):
        # Departure from the segment's departure station.
        self.departure = None     # type: SegmentEventState
        # Arrival at the segment's arrival station.
        self.arrival = None       # type: SegmentEventState
        # Arrival at the segment's departure station.
        self.from_arrival = None  # type: SegmentEventState
        # Departure from the segment's arrival station.
        self.to_departure = None  # type: SegmentEventState
        # State of the whole thread run (get_states_by_segment_keys stores a ThreadEventState here).
        self.thread = None        # type: ThreadEventState


class ScheduleRouteStates(object):
    """Realtime event states collected for a single schedule route; every field defaults to None."""
    __slots__ = ['arrival', 'departure', 'thread']

    def __init__(self):
        # State of the whole thread run.
        self.thread = None     # type: ScheduleRouteEventState
        # Departure from the route's station.
        self.departure = None  # type: ScheduleRouteEventState
        # Arrival at the route's station.
        self.arrival = None    # type: ScheduleRouteEventState


class ThreadCancelsState(object):
    """Cancellation summary for one thread run.

    :param is_fully_cancelled: whether the whole path is cancelled
    :param cancelled_segments: cancelled stretches of the path (List[CancelledSegment])
    """

    def __init__(self, is_fully_cancelled, cancelled_segments):
        # type: (bool, List[CancelledSegment]) -> None
        self.cancelled_segments = cancelled_segments
        self.is_fully_cancelled = is_fully_cancelled


class CancelledSegment(object):
    """A cancelled stretch of a thread path, bounded by its first and last route stations."""

    def __init__(self, rtstation_from, rtstation_to):
        # type: (RTStation, RTStation) -> None
        self.rtstation_to = rtstation_to
        self.rtstation_from = rtstation_from


def is_events_enabled():
    """Tell whether suburban realtime events should be served.

    Enabled when the dynamic SUBURBAN_EVENTS_ENABLED flag is set, or when the dynamic
    SUBURBAN_EXPORT_EVENTS_ENABLED flag is set and the Django setting ENABLE_SUBURBAN_STATES
    (default False) is truthy. The flag values themselves are returned, not coerced to bool.
    """
    export_flag = conf.SUBURBAN_EXPORT_EVENTS_ENABLED
    if export_flag:
        export_flag = getattr(settings, 'ENABLE_SUBURBAN_STATES', False)

    main_flag = conf.SUBURBAN_EVENTS_ENABLED
    return main_flag if main_flag else export_flag


@suburban_breaker
@traced_function
def get_segments_keys(segments):
    """Build realtime-state lookup keys for every segment.

    :param segments: RThreadSegment-like objects (.thread, .start_dt, .rtstation_from, .rtstation_to)
    :return: {segment: {'thread': ThreadKey, 'from': ThreadStationKey, 'to': ThreadStationKey}};
        segments whose thread has no suburban key are left out.
    """
    with log_run_time('get subkeys'):
        suburban_keys = get_threads_suburban_keys([s.thread for s in segments])

    keys_by_segment = {}
    with log_run_time('get_states_for_segments: create keys'):
        for segment in segments:
            suburban_key = suburban_keys.get(segment.thread.id)
            if not suburban_key:
                continue

            # From RZD's point of view, running days of the threads of one trip do not overlap in
            # Moscow time, so the key uses the Europe/Moscow start date. (start_dt is presumably
            # already in Moscow time — only its tzinfo is dropped here.)
            start_dt = segment.start_dt.replace(tzinfo=None)
            thread_type, clock_dir = get_thread_type_and_clock_dir(segment.thread)

            keys = {'thread': ThreadKey(suburban_key, start_dt, thread_type, clock_dir)}
            for side, rts in (('from', segment.rtstation_from), ('to', segment.rtstation_to)):
                keys[side] = ThreadStationKey(
                    suburban_key, start_dt, get_rtstation_key(rts),
                    arrival=rts.tz_arrival,
                    departure=rts.tz_departure,
                    thread_type=thread_type,
                    clock_direction=clock_dir,
                )
            keys_by_segment[segment] = keys

    return keys_by_segment


@suburban_breaker
@traced_function
def gen_station_states_by_keys(station_keys, cancels_as_possible_delay=True):
    """Load non-outdated ThreadStationState documents for the given station keys.

    :param station_keys: sized collection of ThreadStationKey
    :param cancels_as_possible_delay: rewrite CANCELLED arrival/departure states to POSSIBLE_DELAY
    :return: defaultdict(list) {ThreadStationKey: [raw state dict, ...]}; empty when events are disabled
    """
    grouped = defaultdict(list)
    if not is_events_enabled():
        return grouped

    with log_run_time('get station states for {} station_keys'.format(len(station_keys))):
        # Deserializing results into mongoengine documents is too expensive;
        # aggregate() yields plain dicts instead.
        query = ThreadStationState.objects.read_preference(READ_PREF).filter(
            key__in=[station_key.to_mongo_dict() for station_key in station_keys],
            outdated__ne=True,
        )
        raw_states = list(query.aggregate())

    with log_run_time('gen states_by_key for {} station_states'.format(len(raw_states))):
        for raw_state in raw_states:
            if cancels_as_possible_delay:
                replace_cancelled_with_possible_delay(raw_state)
            grouped[ThreadStationKey(**raw_state['key'])].append(raw_state)
    return grouped


@suburban_breaker
@traced_function
def gen_thread_states_by_keys(thread_keys):
    """Load ThreadState documents for the given thread keys.

    :param thread_keys: sized collection of ThreadKey
    :return: defaultdict(dict) {ThreadKey: raw state dict}; when several documents share a key
        the last one wins. Empty when events are disabled.
    """
    states_by_key = defaultdict(dict)
    if not is_events_enabled():
        return states_by_key

    with log_run_time('get thread states for {} thread_keys'.format(len(thread_keys))):
        # Deserializing results into mongoengine documents is too expensive;
        # aggregate() yields plain dicts instead.
        mongo_keys = [thread_key.to_mongo_dict() for thread_key in thread_keys]
        query = ThreadState.objects.read_preference(READ_PREF).filter(key__in=mongo_keys)
        raw_states = list(query.aggregate())

    with log_run_time('gen threads_by_key for {} thread_states'.format(len(raw_states))):
        states_by_key.update(
            (ThreadKey(**raw_state['key']), raw_state) for raw_state in raw_states
        )

    return states_by_key


@suburban_breaker
@traced_function
def get_states_by_segment_keys(segments_keys, all_keys=False, cancels_as_possible_delay=True):
    """Resolve realtime states for segments whose lookup keys are already built.

    :param segments_keys: {segment_key: {'thread': ThreadKey, 'from': ThreadStationKey,
        'to': ThreadStationKey}}; any inner entry may be missing
    :param all_keys: when True, an event that has a key but no stored state gets an UNDEFINED
        state instead of staying None
    :param cancels_as_possible_delay: rewrite CANCELLED station states to POSSIBLE_DELAY
    :return: defaultdict(SegmentStates) {
        segment_key_1: SegmentStates(),
        segment_key_2: SegmentStates(),
        ...
    }; only segment keys that received at least one state are present, and the result is
    empty when events are disabled.
    """
    states = defaultdict(SegmentStates)
    if not is_events_enabled():
        return states

    thread_keys, station_keys = [], []
    with log_run_time('organize keys'):
        for keys in segments_keys.values():
            for kind, key in keys.items():
                (thread_keys if kind == 'thread' else station_keys).append(key)

    station_states_by_key = gen_station_states_by_keys(
        station_keys=station_keys, cancels_as_possible_delay=cancels_as_possible_delay
    )
    thread_states_by_key = gen_thread_states_by_keys(thread_keys=thread_keys)

    # Station events that land in a slot other than the event's own name.
    renamed_slots = {('from', 'arrival'): 'from_arrival', ('to', 'departure'): 'to_departure'}

    with log_run_time('match segments_states'):
        for segment_key, keys in segments_keys.items():
            thread_key = keys.get('thread')
            if thread_key:
                thread_state = thread_states_by_key.get(thread_key)
                if thread_state:
                    raw = thread_state['state']
                    states[segment_key].thread = ThreadEventState(
                        key=thread_key.to_str(),
                        type_=raw['type'],
                        minutes_from=raw.get('minutes_from'),
                        minutes_to=raw.get('minutes_to'),
                        dt=raw.get('dt'),
                        tz=thread_state['tz'],
                    )
                elif all_keys:
                    states[segment_key].thread = ThreadEventState(
                        key=thread_key.to_str(),
                        type_=EventStateType.UNDEFINED
                    )

            for side in ('from', 'to'):
                station_key = keys.get(side)
                if not station_key:
                    continue

                station_state = get_tss_by_key(station_states_by_key, station_key)
                station_key_str = station_key.to_str()

                for event in ('departure', 'arrival'):
                    # Only events the station actually has (key carries a time) are reported.
                    if getattr(station_key, event, None) is None:
                        continue

                    event_state = station_state.get(event + '_state') if station_state else None
                    if event_state:
                        new_state = SegmentEventState(
                            key=station_key_str,
                            type_=event_state['type'],
                            minutes_from=event_state.get('minutes_from'),
                            minutes_to=event_state.get('minutes_to'),
                            tz=station_state['tz'],
                            dt=event_state.get('dt'),
                        )
                    elif all_keys:
                        new_state = SegmentEventState(key=station_key_str, type_=EventStateType.UNDEFINED)
                    else:
                        continue

                    slot = renamed_slots.get((side, event), event)
                    setattr(states[segment_key], slot, new_state)

    return states


@suburban_breaker
@traced_function
def get_states_by_segment_keys_plain(segments_str_keys, cancels_as_possible_delay=True):
    """
    Resolve realtime states for segments described by serialized keys.

    :param segments_str_keys: iterable of dicts with optional entries 'departure' and 'arrival'
        (ThreadStationKey strings of the segment's departure / arrival station) and 'thread'
        (a ThreadKey string)
    :param cancels_as_possible_delay: replace cancels with possible_delay
    :return: list of SegmentStates where item i belongs to input item i (events that have a key
        but no stored state are UNDEFINED); an empty list when nothing at all was resolved,
        e.g. when events are disabled
    """
    segments_keys = defaultdict(dict)
    total = 0
    for i, key_by_type in enumerate(segments_str_keys):
        total = i + 1

        key_from = key_by_type.get('departure')
        if key_from:
            segments_keys[i]['from'] = ThreadStationKey.from_str(key_from)

        key_to = key_by_type.get('arrival')
        if key_to:
            segments_keys[i]['to'] = ThreadStationKey.from_str(key_to)

        key_thread = key_by_type.get('thread')
        if key_thread:
            segments_keys[i]['thread'] = ThreadKey.from_str(key_thread)

    segments_states = get_states_by_segment_keys(
        segments_keys, all_keys=True, cancels_as_possible_delay=cancels_as_possible_delay
    )
    if not segments_states:
        # Keep the historical empty result when nothing was resolved (e.g. events are disabled).
        return []

    # Build the result by position instead of sorting the resolved entries: an input item without
    # any usable key gets no entry in segments_states, and skipping it would shift every following
    # state onto the wrong input. Such items get an empty SegmentStates instead.
    return [
        segments_states[i] if i in segments_states else SegmentStates()
        for i in range(total)
    ]


@suburban_breaker
@traced_function
def get_states_for_segments(segments, all_keys=False, cancels_as_possible_delay=True):
    """
    Realtime states for a list of segments.

    :param segments: list of RThreadSegment
    :param all_keys: when True, events that have a key but no stored state are reported as UNDEFINED
    :param cancels_as_possible_delay: replace cancels with possible_delay
    :return: defaultdict(SegmentStates) {
        RThreadSegment(): SegmentStates(),   # .arrival, .departure, .from_arrival, .to_departure, .thread
        ...
    }; segments without a suburban key, or without any state, are absent.
    """
    return get_states_by_segment_keys(
        get_segments_keys(segments),
        all_keys=all_keys,
        cancels_as_possible_delay=cancels_as_possible_delay,
    )


@suburban_breaker
@traced_function
def get_states_for_thread(thread, start_date, rtstations, all_keys=False, cancels_as_possible_delay=True):
    """Realtime states of one thread run and of its route stations.

    :param thread: thread whose run is looked up
    :param start_date: date of the run; combined with thread.tz_start_time into the key's start datetime
    :param rtstations: route stations to report states for
    :param all_keys: when True, a missing thread state becomes UNDEFINED and every rtstation gets a
        StationEventState even when no event state is stored for it
    :param cancels_as_possible_delay: rewrite CANCELLED station states to POSSIBLE_DELAY
    :return: (ThreadEventState or None, {rtstation: StationEventState});
        (None, {}) when events are disabled or the thread has no suburban key
    """
    if not is_events_enabled():
        return None, {}

    suburban_key = get_threads_suburban_keys([thread]).get(thread.id)
    if not suburban_key:
        return None, {}

    thread_type, clock_dir = get_thread_type_and_clock_dir(thread)
    thread_start_dt = datetime.combine(start_date, thread.tz_start_time)
    thread_key = ThreadKey(suburban_key, thread_start_dt, thread_type, clock_dir)

    try:
        stored = ThreadState.objects.read_preference(READ_PREF).get(key=thread_key.to_mongo_dict())
    except ThreadState.DoesNotExist:
        thread_state = (
            ThreadEventState(key=thread_key.to_str(), type_=EventStateType.UNDEFINED) if all_keys else None
        )
    else:
        thread_state = ThreadEventState(
            key=stored.key.to_str(),
            type_=stored.state.type,
            minutes_from=stored.state.minutes_from,
            minutes_to=stored.state.minutes_to,
            dt=stored.state.dt,
            tz=stored.tz,
        )

    # Thread type / clock direction only narrow the station query when a direction is known.
    type_filter = Q() if clock_dir is None else Q(key__thread_type=thread_type, key__clock_direction=clock_dir)
    raw_states = list(ThreadStationState.objects.read_preference(READ_PREF).filter(
        Q(key__thread_key=thread_key.thread_key,
          key__thread_start_date=thread_key.thread_start_date,
          outdated__ne=True) & type_filter
    ).aggregate())

    station_states_by_key = defaultdict(list)
    for raw_state in raw_states:
        if cancels_as_possible_delay:
            replace_cancelled_with_possible_delay(raw_state)
        station_states_by_key[ThreadStationKey(**raw_state['key'])].append(raw_state)

    states_by_rtstation = {}
    for rts in rtstations:
        tss_key = ThreadStationKey(
            thread_key=thread_key.thread_key,
            thread_start_date=thread_start_dt,
            station_key=get_rtstation_key(rts),
            arrival=rts.tz_arrival,
            departure=rts.tz_departure,
            thread_type=thread_type,
            clock_direction=clock_dir
        )
        station_state = get_tss_by_key(station_states_by_key, tss_key)

        found = {}
        for event_name in ('arrival', 'departure'):
            if getattr(rts, 'tz_' + event_name, None) is None:
                continue
            event_state = station_state.get(event_name + '_state') if station_state else None
            if event_state:
                found[event_name] = EventState(
                    type_=event_state['type'],
                    minutes_from=event_state.get('minutes_from'),
                    minutes_to=event_state.get('minutes_to'),
                    dt=event_state.get('dt'),
                    tz=station_state['tz']
                )

        if not (found or all_keys):
            continue

        station_event_state = StationEventState(key=tss_key.to_str())
        station_event_state.arrival = found.get('arrival')
        station_event_state.departure = found.get('departure')
        states_by_rtstation[rts] = station_event_state

    return thread_state, states_by_rtstation


@suburban_breaker
@traced_function
def get_schedule_routes_keys(schedule_routes):
    """Build a ThreadStationKey for every schedule route.

    :param schedule_routes: ScheduleRoute-like objects (.thread, .rtstation, .naive_start_dt)
    :return: {schedule_route: ThreadStationKey}; routes whose thread has no suburban key are left out.
    """
    with log_run_time('get suburban_keys'):
        suburban_keys = get_threads_suburban_keys([route.thread for route in schedule_routes])

    keys_by_route = {}
    with log_run_time('calc schedule route keys'):
        for route in schedule_routes:
            suburban_key = suburban_keys.get(route.thread.id)
            if not suburban_key:
                continue

            thread_type, clock_dir = get_thread_type_and_clock_dir(route.thread)
            rts = route.rtstation
            keys_by_route[route] = ThreadStationKey(
                suburban_key, route.naive_start_dt, get_rtstation_key(rts),
                arrival=rts.tz_arrival,
                departure=rts.tz_departure,
                thread_type=thread_type,
                clock_direction=clock_dir,
            )

    return keys_by_route


@suburban_breaker
@traced_function
def get_states_by_schedule_routes(schedule_routes, all_keys=False, cancels_as_possible_delay=True):
    """
    Realtime states for schedule routes.

    :param schedule_routes: list of ScheduleRoute
    :param all_keys: when True, a missing thread state becomes UNDEFINED, and so does a missing
        station event whose route has the matching `<event>_dt`
    :param cancels_as_possible_delay: replace cancels with possible_delay
    :return: defaultdict(ScheduleRouteStates) {
        schedule_route_1: ScheduleRouteStates(),
        schedule_route_2: ScheduleRouteStates(),
        ...
    }; empty when events are disabled.
    """
    states = defaultdict(ScheduleRouteStates)
    if not is_events_enabled():
        return states

    keys_by_route = get_schedule_routes_keys(schedule_routes)
    station_keys = list(keys_by_route.values())
    station_states_by_key = gen_station_states_by_keys(
        station_keys=station_keys, cancels_as_possible_delay=cancels_as_possible_delay
    )
    thread_states_by_key = gen_thread_states_by_keys(
        thread_keys=[station_key.get_thread_key() for station_key in station_keys]
    )

    with log_run_time('match schedule routes states'):
        for route, tss_key in keys_by_route.items():
            thread_key = tss_key.get_thread_key()
            thread_state = thread_states_by_key.get(thread_key)
            if thread_state:
                raw = thread_state['state']
                states[route].thread = ScheduleRouteEventState(
                    key=thread_key.to_str(),
                    type_=raw['type'],
                    minutes_from=raw.get('minutes_from'),
                    minutes_to=raw.get('minutes_to'),
                    dt=raw.get('dt'),
                    tz=thread_state['tz'],
                )
            elif all_keys:
                states[route].thread = ScheduleRouteEventState(
                    key=thread_key.to_str(),
                    type_=EventStateType.UNDEFINED
                )

            station_state = get_tss_by_key(station_states_by_key, tss_key)
            for event in ('arrival', 'departure'):
                event_state = station_state.get(event + '_state') if station_state else None
                if event_state:
                    new_state = ScheduleRouteEventState(
                        key=tss_key.to_str(),
                        type_=event_state['type'],
                        minutes_from=event_state.get('minutes_from'),
                        minutes_to=event_state.get('minutes_to'),
                        tz=station_state['tz'],
                        dt=event_state.get('dt'),
                    )
                elif all_keys and getattr(route, event + '_dt', None):
                    new_state = ScheduleRouteEventState(key=tss_key.to_str(), type_=EventStateType.UNDEFINED)
                else:
                    continue
                setattr(states[route], event, new_state)

    return states


@suburban_breaker
@traced_function
def get_thread_station_keys_for_segments(segments, rtstations_by_segment):
    """Build a ThreadStationKey for every route station of every segment's thread run.

    :param segments: segments to build keys for
    :param rtstations_by_segment: {segment: [rtstation, ...]}
    :return: {segment: {rtstation: ThreadStationKey}}; segments whose thread has no suburban key,
        or that have no entry in rtstations_by_segment, are left out.
    """
    suburban_keys = get_threads_suburban_keys([s.thread for s in segments])
    keys_by_segment = {}

    for segment in segments:
        suburban_key = suburban_keys.get(segment.thread.id)
        if not suburban_key:
            continue

        start_dt = segment.start_dt.replace(tzinfo=None)
        thread_type, clock_dir = get_thread_type_and_clock_dir(segment.thread)
        rtstations = rtstations_by_segment.get(segment)
        if rtstations is None:
            continue

        keys_by_rts = {}
        for rts in rtstations:
            keys_by_rts[rts] = ThreadStationKey(
                suburban_key, start_dt, get_rtstation_key(rts),
                arrival=rts.tz_arrival,
                departure=rts.tz_departure,
                thread_type=thread_type,
                clock_direction=clock_dir,
            )
        keys_by_segment[segment] = keys_by_rts

    return keys_by_segment


@suburban_breaker
@traced_function
def get_full_states_for_segments(thread_station_keys_by_segment):
    """Load station states for every key in {segment: {rtstation: ThreadStationKey}}.

    Cancellations are kept as CANCELLED (no POSSIBLE_DELAY substitution).

    :return: defaultdict(list) {ThreadStationKey: [raw state dict, ...]}
    """
    unique_keys = {
        station_key
        for keys_by_rts in thread_station_keys_by_segment.values()
        for station_key in keys_by_rts.values()
    }
    return gen_station_states_by_keys(unique_keys, cancels_as_possible_delay=False)


@suburban_breaker
@traced_function
def get_cancelled_path_for_segments(segments):
    """Find the cancelled stretches of each segment's thread path.

    A stretch opens at a station whose departure is CANCELLED while its arrival is not, and closes
    at the next station whose arrival is CANCELLED while its departure is not. Station states are
    loaded with cancellations kept as CANCELLED.

    :param segments: segments whose ``thread.path`` is scanned
    :return: {segment: ThreadCancelsState} for segments with at least one cancelled stretch;
        segments whose thread has no suburban key are skipped.
    """
    cancels_by_segment = {}
    if not segments:
        return cancels_by_segment

    rtstations_by_segment = {segment: list(segment.thread.path) for segment in segments}

    thread_station_keys_by_segment = get_thread_station_keys_for_segments(segments, rtstations_by_segment)
    station_states_by_key = get_full_states_for_segments(thread_station_keys_by_segment)

    for segment in segments:
        # get_thread_station_keys_for_segments leaves out segments whose thread has no suburban
        # key; indexing it directly raised KeyError for such segments.
        station_keys_by_rts = thread_station_keys_by_segment.get(segment)
        rtstations = rtstations_by_segment.get(segment)
        if station_keys_by_rts is None or rtstations is None:
            continue

        cancelled_segments = _find_cancelled_segments(rtstations, station_keys_by_rts, station_states_by_key)
        if not cancelled_segments:
            continue

        # Fully cancelled: a single stretch running from the first to the last station of the path.
        is_fully_cancelled = (
            len(cancelled_segments) == 1 and
            rtstations[0] == cancelled_segments[0].rtstation_from and
            rtstations[-1] == cancelled_segments[0].rtstation_to
        )
        cancels_by_segment[segment] = ThreadCancelsState(
            is_fully_cancelled=is_fully_cancelled,
            cancelled_segments=cancelled_segments
        )

    return cancels_by_segment


def _find_cancelled_segments(rtstations, station_keys_by_rts, station_states_by_key):
    """Scan rtstations in path order and pair each cancellation start with the following end."""
    cancelled_segments = []
    cancelled_from = None

    for rts in rtstations:
        tss = get_tss_by_key(station_states_by_key, station_keys_by_rts[rts])
        if not tss:
            continue

        # An event state may be absent or stored as null; both mean "no state".
        arrival_cancelled = (tss.get('arrival_state') or {}).get('type') == EventStateType.CANCELLED
        departure_cancelled = (tss.get('departure_state') or {}).get('type') == EventStateType.CANCELLED

        if departure_cancelled and not arrival_cancelled:
            cancelled_from = rts
        elif arrival_cancelled and not departure_cancelled and cancelled_from is not None:
            cancelled_segments.append(CancelledSegment(cancelled_from, rts))
            cancelled_from = None

    return cancelled_segments


def get_tss_by_key(station_states_by_key, tss_key):
    """Return the first station state stored under `tss_key`, or None when there is none.

    An error is logged when the key matches more than one state, since it should be unique.
    """
    candidates = station_states_by_key.get(tss_key)
    if not candidates:
        return None

    if len(candidates) > 1:
        log.error('Ключ {} не определяет tss однозначно'.format(tss_key.to_str()))
    return candidates[0]


def replace_cancelled_with_possible_delay(station_state):
    # type: (dict) -> None
    """Downgrade CANCELLED arrival/departure states of a raw station state to POSSIBLE_DELAY, in place.

    :param station_state: raw ThreadStationState document as returned by aggregate()
    """
    for event_state in (station_state.get(name) for name in ('arrival_state', 'departure_state')):
        if event_state and event_state['type'] == EventStateType.CANCELLED:
            event_state['type'] = EventStateType.POSSIBLE_DELAY
