# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import json
import re
from collections import defaultdict
from datetime import datetime, timedelta
from functools import partial
from itertools import chain

import pytz
from django.conf import settings
from django.db.models import Q
from mongoengine.context_managers import switch_db

from common.apps.suburban_events import models, dynamic_params
from common.apps.suburban_events.forecast.events import Event
from common.apps.suburban_events.models import LVGD01_TR2PROC_feed
from common.apps.suburban_events.scripts.update_companies_crashes import check_time_without_events
from common.apps.suburban_events.utils import (
    ThreadEventsTypeCodes, ClockDirection,
    collect_rts_by_threads, collect_threads_by_number, get_rtstation_key, light_mongoengine_query
)
from common.dynamic_settings.default import conf
from common.models.geo import Station, StationCode, CodeSystem
from common.models.schedule import RTStation
from common.models.transport import TransportType
from common.settings.utils import define_setting
from common.utils.date import MSK_TZ
from travel.rasp.library.python.common23.date.environment import now
from travel.rasp.library.python.common23.logging import log_run_time

# Despite its name, the value is a number of SECONDS: 23 hours.
NUMBER_HOURS_FOR_ONE_DAY = 23 * 3600

# Event sources excluded from matching (see the SOURCE__nin filter in get_new_events).
# Overridden values are parsed with json.loads.
# https://st.yandex-team.ru/RASPFRONT-6470
define_setting('DISABLED_SOURCES', converter=json.loads, default=[LVGD01_TR2PROC_feed.Sources.SAI_PS])


def get_closest_rts(rtstations, event):
    """
    Pick the candidate whose scheduled event time is closest to the event's normative time.

    :param rtstations: non-empty iterable of ``(rts, thread_start_dt)`` pairs.
    :param event: event providing ``type`` and a naive ``dt_normative``. It is compared with
        the rtstation event time computed in MSK_TZ with the tzinfo stripped.
    :return: the ``(rts, thread_start_dt)`` pair with the smallest absolute time difference.
        On ties the first such pair is returned.
    :raises ValueError: if ``rtstations`` is empty.
    """
    def distance(rts_with_start_dt):
        rts, thread_start_dt = rts_with_start_dt
        rts_event_dt = rts.get_event_dt(event.type, thread_start_dt, MSK_TZ).replace(tzinfo=None)
        return abs(rts_event_dt - event.dt_normative)

    # min() returns the first minimal item, matching the stable sort it replaces.
    return min(rtstations, key=distance)


def _rzd_event_property(attr_name, doc):
    """Build a read-only property that returns ``attr_name`` of ``self.rzd_event``."""
    return property(lambda self: getattr(self.rzd_event, attr_name), doc=doc)


class RZDEventMixin(object):
    """Expose the attributes of the wrapped ``self.rzd_event`` on the host object."""

    @property
    def source_data(self):
        """The ``'id'`` item of the wrapped event."""
        rzd_event = self.rzd_event
        return rzd_event['id']

    dt_normative = _rzd_event_property('dt_normative', 'Normative event time of the wrapped event.')
    dt_fact = _rzd_event_property('dt_fact', 'Actual event time of the wrapped event.')
    twin_key = _rzd_event_property('twin_key', 'Twin key of the wrapped event.')
    type = _rzd_event_property('type', 'Event type of the wrapped event.')
    weight = _rzd_event_property('weight', 'Weight of the wrapped event.')


class ResultEvent(RZDEventMixin, Event):
    """
    One of our events built from an RZD event for one concrete thread run.

    Event attributes (times, type, weight, ...) come from the wrapped ``rzd_event``.
    The thread, its start date and the rtstation are fixed at construction.
    """

    def __init__(self, rzd_event, thread, thread_start_date, rtstation, passed_several_times):
        self.rzd_event = rzd_event
        self._rtstation = rtstation
        self._thread = thread
        self._thread_start_date = thread_start_date
        self._passed_several_times = passed_several_times

    @property
    def thread(self):
        """Matched thread."""
        return self._thread

    @property
    def thread_start_date(self):
        """Start date of the matched thread run."""
        return self._thread_start_date

    @property
    def rtstation(self):
        """Matched rtstation of the thread."""
        return self._rtstation

    @property
    def passed_several_times(self):
        """Whether the thread passes this station more than once."""
        return self._passed_several_times

    @property
    def time(self):
        """Value of the rtstation's ``tz_arrival`` / ``tz_departure`` attribute for this event type."""
        attr_name = 'tz_{}'.format(self.type)
        return getattr(self.rtstation, attr_name)

    @property
    def station_key(self):
        """Key of the matched rtstation (see ``get_rtstation_key``)."""
        rtstation = self.rtstation
        return get_rtstation_key(rtstation)


def station_to_str(station):
    """Format a station as ``'<short title>(<id>)'`` for log messages; ``'None'`` for no station."""
    if not station:
        return u'None'
    title = station.L_short_title()
    return u'{}({})'.format(title, station.id)


class RzdEvent(object):
    """
    A raw RZD event record together with the results of matching it to our stations and threads.

    Raw fields of the record are available via item access, e.g. ``event['NOMPEX']``.
    """

    # Extracts the first run of digits from the RZD train number (NOMPEX).
    _THREAD_NUMBER_RE = re.compile(r'[^\d]*(\d+)[^\d]*')

    def __init__(self, event_data):
        self.event_data = event_data

        # Stations found for each role by each code system. 'from'/'to' are the first/last
        # stations of the train and 'oper' is the station of the event. 'matched' is the
        # station accepted by validate_stations().
        self.stations = {
            station_type: {'rzd_esr': None, 'express': None, 'matched': None}
            for station_type in ('from', 'to', 'oper')
        }

        # Intermediate results of the thread matching steps, filled in by the matchers.
        self.threads = {
            'match_by_number': [],
            'match_by_first_last': [],
            'match_by_oper_station': [],
            'rts_with_dates': [],
        }

        self.thread_match_problems = []
        self.station_match_problems = []

    def __getitem__(self, item):
        return self.event_data[item]

    def get_result_events(self):
        """One RZD event may correspond to several of our threads, so yield one or more of our events."""
        for rts, start_date in self.threads['rts_with_dates']:
            yield ResultEvent(
                rzd_event=self,
                thread=rts.thread,
                thread_start_date=start_date,
                rtstation=rts,
                passed_several_times=getattr(rts, 'station_passed_several_times', False),
            )

    @property
    def dt_normative(self):
        return self['TIMEOPER_N']

    @property
    def dt_fact(self):
        return self['TIMEOPER_F']

    @property
    def twin_key(self):
        # Key that separates two identical stations within one thread.
        # The RZD thread number is used because in their system a thread with one number
        # cannot pass the same station twice.
        return self['NOMPEX']

    @property
    def type(self):
        """``'arrival'`` for KODOP 1, ``'departure'`` for KODOP 3; KeyError otherwise."""
        return {
            1: 'arrival',
            3: 'departure',
        }[self['KODOP']]

    @property
    def weight(self):
        return self['PRIORITY_RATING']

    @property
    def is_valid(self):
        """True when at least one (rtstation, start date) pair has been matched."""
        return bool(self.threads['rts_with_dates'])

    def is_valid_to_match_thread(self):
        """Thread matching may be attempted when the event station and at least one terminal station are found."""
        return bool(self.station_oper and (self.station_from or self.station_to))

    def validate_stations(self):
        """
        Choose the matched station for every role from the rzd_esr and express candidates.

        A role gets no matched station (and a problem is recorded) when neither code system
        found a station, or when the two found different stations.
        """
        for station_type in ['from', 'to', 'oper']:
            esr_st = self.stations[station_type]['rzd_esr']
            express_st = self.stations[station_type]['express']
            self.stations[station_type]['matched'] = None
            if not esr_st and not express_st:
                self.station_match_problems.append({
                    'type': ('station', station_type),
                    'problem': 'no_match',
                    'description': u'Не нашли станцию {} по rzd_esr и express: {}'.format(
                        station_type, self.get_orig_name(station_type))
                })
            else:
                if esr_st and express_st and esr_st != express_st:
                    # NOTE(review): 'type' here is a plain string, unlike the tuple used
                    # for 'no_match' above; kept as is for consumers of these records.
                    self.station_match_problems.append({
                        'type': station_type,
                        'problem': 'esr_express_mismatch',
                        'description': u'Найденные станции {} различаются: по rzd_esr {}, по express {}'.format(
                            station_type, station_to_str(esr_st), station_to_str(express_st)
                        )
                    })
                else:
                    self.stations[station_type]['matched'] = esr_st or express_st

    def validate_thread(self):
        """Record the first failed thread matching step (if any) in ``thread_match_problems``."""
        if not self.threads['match_by_number']:
            self.thread_match_problems.append({
                'problem': 'no_match_by_number',
                'description': u'Не нашли нитку по номеру {}'.format(self.thread_number),
            })
            return

        if not self.threads['match_by_first_last']:
            self.thread_match_problems.append({
                'problem': 'no_match_by_first_last',
                'description': u'Не нашли нитку по первой/последней станции {} -> {}'.format(
                    station_to_str(self.station_from), station_to_str(self.station_to)
                ),
            })
            return

        if not self.threads['match_by_oper_station']:
            self.thread_match_problems.append({
                'problem': 'no_match_by_oper_station',
                'description': u'Не нашли нитку по станции события {}'.format(
                    station_to_str(self.station_oper),
                ),
            })
            return

        if not self.threads['rts_with_dates']:
            self.thread_match_problems.append({
                'problem': 'no_match_by_date',
                'description': u'Не нашли нитку по дате старта'
            })
            return

    @property
    def thread_number(self):
        """
        First run of digits of the RZD train number (NOMPEX), or None if it has none.

        An invalid number is recorded in ``thread_match_problems`` only once, however many
        times the property is read.
        """
        match = self._THREAD_NUMBER_RE.match(self['NOMPEX'])
        if match:
            return match.group(1)

        already_reported = any(
            problem.get('problem') == 'bad_thread_number' for problem in self.thread_match_problems
        )
        if not already_reported:
            self.thread_match_problems.append({
                'problem': 'bad_thread_number',
                'description': u'Невалидный номер нитки {}'.format(self['NOMPEX']),
            })
        return None

    def get_clock_dir(self):
        """
        Trains run clockwise or counter-clockwise.

        RZD gives odd numbers to clockwise trains and even numbers to counter-clockwise ones.
        On our side the direction is kept in the thread titles.
        Returns None when the event has no valid thread number.
        """
        thread_number = self.thread_number
        if thread_number is None:
            return None
        return ClockDirection.C_CLOCK_WISE if int(thread_number) % 2 == 0 else ClockDirection.CLOCK_WISE

    def get_matched_station(self, station_type):
        return self.stations[station_type]['matched']

    @property
    def station_from(self):
        return self.get_matched_station('from')

    @property
    def station_to(self):
        return self.get_matched_station('to')

    @property
    def station_oper(self):
        return self.get_matched_station('oper')

    def get_orig_name(self, station_type):
        """Station name for the given role exactly as it is in the raw RZD record."""
        return self[EventsMatcher.STATION_FIELDS['name'][station_type]]

    def __unicode__(self):
        orig_name_from = self.get_orig_name('from')
        orig_name_to = self.get_orig_name('to')
        orig_name_oper = self.get_orig_name('oper')

        def name(s):
            return s['matched'].L_short_title() if s['matched'] else ''

        return u'{} {} at {}->{} from {}->{} to {}->{}; id: {}'.format(
            self.thread_number,
            self.type,
            orig_name_oper, name(self.stations['oper']),
            orig_name_from, name(self.stations['from']),
            orig_name_to, name(self.stations['to']),
            self['_id']
        )

    def __str__(self):
        text = self.__unicode__()
        # Python 2 requires __str__ to return bytes; Python 3 requires text.
        return text.encode('utf8') if str is bytes else text


class MCZKEvent(RzdEvent):
    """RZD event of an MCZK thread, matched to ThreadEvents records instead of rtstations."""

    def __init__(self, *args, **kwargs):
        # (thread_event, expected_event) pairs, filled in by MCZKEventsMatcher.
        self.th_events = []
        super(MCZKEvent, self).__init__(*args, **kwargs)

    def is_valid_to_match_thread(self):
        """Matching may be attempted as soon as the event station is found."""
        oper_station = self.get_matched_station('oper')
        return bool(oper_station)

    @property
    def is_valid(self):
        """True when at least one thread event has been matched."""
        return len(self.th_events) > 0

    def get_result_events(self):
        """Yield one MCZKResultEvent per matched (thread_event, expected_event) pair."""
        for thread_event, expected_event in self.th_events:
            result = MCZKResultEvent(
                rzd_event=self,
                th_event=thread_event,
                time=expected_event.time,
                station_key=expected_event.station_key,
                passed_several_times=expected_event.passed_several_times
            )
            yield result


class BaseEventsMatcher(object):
    """
    Shared part of matching raw RZD events to our stations and threads.

    Subclasses set ``event_class`` and implement ``match_threads``.
    """

    # Raw event fields holding station codes (per code system) and station names,
    # for each station role: 'from', 'to' and 'oper' (the station of the event).
    STATION_FIELDS = {
        'rzd_esr': {
            'from': 'STORASP',
            'to': 'STNRASP',
            'oper': 'STOPER',
        },
        'express': {
            'from': 'STOEX',
            'to': 'STNEX',
            'oper': 'STOPEREX',
        },
        'name': {
            'from': 'NAMESTO',
            'to': 'NAMESTN',
            'oper': 'STNAME',
        }
    }
    event_class = RzdEvent

    def __init__(self, raw_events, log=None):
        """
        :param raw_events: list of raw event dicts; prepare_raw_events() may modify them in place.
        :param log: logger to use; defaults to this module's logger (previously None, which
            crashed on the first log call).
        """
        if log is None:
            import logging
            log = logging.getLogger(__name__)
        # The logger is set up first so that prepare_raw_events() overrides can use it.
        self.log = log
        self.log_run_time = partial(log_run_time, logger=self.log)

        self.raw_events = raw_events
        self.prepare_raw_events()
        self.events = [self.event_class(e) for e in self.raw_events]

    def match(self):
        """
        Match stations and threads of all events.

        :return: tuple (valid events, invalid events).
        """
        if not self.events:
            self.log.info('No events - no matching.')
            return [], []

        self.match_stations()
        self.match_threads()

        return [e for e in self.events if e.is_valid], [e for e in self.events if not e.is_valid]

    def match_stations(self):
        """Find stations of every event by express and rzd_esr codes and validate them."""
        stations, not_matched_codes = self.match_stations_express()
        self.log.info('stations express match: {}, {}'.format(len(stations), len(not_matched_codes)))
        stations, not_matched_codes = self.match_stations_esr()
        self.log.info('stations rzd_esr match: {}, {}'.format(len(stations), len(not_matched_codes)))

        for event in self.events:
            event.validate_stations()

        total = len(self.events)
        not_matched = sum(1 for e in self.events if e.station_match_problems)
        matched = total - not_matched
        # Guard against division by zero when called directly with no events.
        not_matched_percent = 100.0 * not_matched / total if total else 0.0

        self.log.info('stations matched: {}, not matched: {} ({:.1f}%)'.format(
            matched, not_matched, not_matched_percent
        ))

    def match_threads(self):
        raise NotImplementedError

    def match_stations_by_code_system(self, code_system, stations_getter):
        """
        Fill ``event.stations[<role>][code_system]`` for every event.

        :param code_system: key of STATION_FIELDS ('express' or 'rzd_esr').
        :param stations_getter: callable taking a set of int codes and returning a dict
            {int code: station}.
        :return: tuple (stations dict, set of codes that were not found).
        """
        fields = self.STATION_FIELDS[code_system]

        codes = set()
        for event in self.events:
            for field in fields.values():
                codes.add(int(event[field]))

        stations = stations_getter(codes)
        for event in self.events:
            for field_type, field in fields.items():
                # Stations are keyed by int code, so normalize the raw value the same way as
                # when collecting codes; otherwise non-int raw codes would never match.
                event.stations[field_type][code_system] = stations.get(int(event[field]))

        not_matched_codes = codes - set(stations.keys())
        return stations, not_matched_codes

    def match_stations_express(self):
        def stations_getter(codes):
            stations = {}
            for station in Station.objects.filter(express_id__in=list(codes)):
                stations[int(station.express_id)] = station

            return stations

        return self.match_stations_by_code_system('express', stations_getter)

    def match_stations_esr(self):
        def stations_getter(codes):
            system = CodeSystem.objects.get(code='rzd_esr')
            station_codes = StationCode.objects.filter(system=system, code__in=list(codes)).select_related('station')
            stations = {}
            for code in station_codes:
                stations[int(code.code)] = code.station

            return stations

        return self.match_stations_by_code_system('rzd_esr', stations_getter)

    def prepare_raw_events(self):
        """Hook for subclasses to adjust ``self.raw_events`` before events are built."""
        pass

    def get_skip_thread_uids(self, threads):
        """
        Collect UpdatedThread uids of the given threads, grouped by start date.

        :return: defaultdict {start date: [thread uid, ...]}.
        """
        skip_thread_uids = defaultdict(list)
        for updated_thread in models.UpdatedThread.objects.filter(uid__in=[thread.uid for thread in threads]):
            skip_thread_uids[updated_thread.start_date.date()].append(updated_thread.uid)

        return skip_thread_uids


class EventsMatcher(BaseEventsMatcher):
    """
    Matches RZD events to suburban threads.

    Each step narrows the previous one: train number, first/last station, station of the
    event, and finally the thread start date.
    """

    def match_threads_by_number_and_stations(self, events):
        """
        Fill ``threads['match_by_number']`` and ``threads['match_by_first_last']`` of each event.

        Candidates are suburban threads that are neither MCZK nor cancelled and whose number
        matches the event's train number. Of those, threads whose first or last station
        matches the event's are kept, plus composite-number threads where the event number is
        a middle part.
        """
        with self.log_run_time('get first/last rts'):
            # An rtstation without arrival or without departure time is a terminal one.
            rtstations = list(RTStation.objects.filter(
                Q(tz_arrival__isnull=True) | Q(tz_departure__isnull=True),
                thread__t_type=TransportType.SUBURBAN_ID,
            ).exclude(
                # Two separate exclude() calls: a single exclude(a, b) drops only rows that
                # match BOTH conditions. That would keep non-cancelled MCZK threads and
                # cancelled regular ones.
                thread__uid__contains='MCZK',
            ).exclude(
                thread__type__code='cancel',
            ).select_related(
                'thread', 'station', 'thread__type'
            ).order_by('id'))

        with self.log_run_time('rts_by_threads'):
            thread_paths = collect_rts_by_threads(rtstations, collect_stations=True)

        with self.log_run_time('threads_by_number'):
            threads_by_number = collect_threads_by_number(thread_paths)

        with self.log_run_time('match threads by number and first/last station'):
            for event in events:
                thread_number = event.thread_number
                if not thread_number:
                    continue

                number_threads = threads_by_number[thread_number]
                if number_threads:
                    event.threads['match_by_number'] = number_threads

                for thread in number_threads:
                    first_st, last_st = thread_paths[thread]

                    number_parts = thread.number.split('/')
                    if first_st == event.station_from or last_st == event.station_to:
                        event.threads['match_by_first_last'].append(thread)
                    elif len(number_parts) > 1:
                        # The middle part of a composite thread number cannot be matched by the
                        # first/last stations; it is matched by the station of the event later.
                        if thread_number in number_parts[1:-1]:
                            event.threads['match_by_first_last'].append(thread)

    @staticmethod
    def calc_rts_event_db_date(event, rts, thread):
        """
        Return the date on which the DB schedule places the event.

        The arrival/departure time of a station in our DB may differ from the RZD time. When
        the event crosses midnight, the event date has to be determined the way the DB does
        it, so that the thread start date is computed correctly.
        https://st.yandex-team.ru/RASPEXPORT-232
        """
        date = event.dt_normative.date()
        thread_time = thread.tz_start_time
        db_event_time = (datetime.combine(date, thread_time) +
                         timedelta(minutes=getattr(rts, 'tz_' + event.type)))
        db_event_time = datetime.combine(date, db_event_time.time())

        # If the difference between the later and the earlier of the RZD and DB times exceeds
        # NUMBER_HOURS_FOR_ONE_DAY (which is in seconds), the events belong to different days.
        diff_max_min_time = (max(event.dt_normative, db_event_time) - min(event.dt_normative, db_event_time)).seconds
        if event.dt_normative != db_event_time and diff_max_min_time > NUMBER_HOURS_FOR_ONE_DAY:
            # `.seconds` of a negative timedelta is the complement to a full day, so the min is
            # the distance between the two times across midnight.
            minutes_delta = min((db_event_time - event.dt_normative).seconds,
                                (event.dt_normative - db_event_time).seconds) / 60

            if (db_event_time + timedelta(minutes=minutes_delta)).date() == date + timedelta(days=1):
                date -= timedelta(days=1)
            elif (db_event_time - timedelta(minutes=minutes_delta)).date() == date - timedelta(days=1):
                date += timedelta(days=1)
        return date

    def match_threads_oper_station_and_date(self, events):
        """
        Narrow first/last-station candidates down to rtstations at the event station, with dates.

        Fills ``threads['match_by_oper_station']`` and ``threads['rts_with_dates']``. A
        candidate is kept only if the thread runs on the computed start date and is not listed
        in UpdatedThread for that date.
        """
        threads = set(chain.from_iterable(e.threads['match_by_first_last'] for e in events))
        skip_thread_uids = self.get_skip_thread_uids(threads)

        with self.log_run_time('get all suburban rtstations'):
            rtstations = list(RTStation.objects.filter(
                thread_id__in=[t.id for t in threads]
            ).exclude(
                thread__type__code='cancel'
            ).select_related(
                'thread', 'station'
            ).order_by('id'))

        with self.log_run_time('thread_paths'):
            thread_paths = defaultdict(list)
            for rts in rtstations:
                thread_paths[rts.thread].append(rts)

        with self.log_run_time('match threads by date'):
            for event in events:
                possible_with_rts = defaultdict(list)
                for thread in event.threads['match_by_first_last']:
                    thread_path = thread_paths[thread]
                    for rts in thread_path:
                        if event.station_oper == rts.station and getattr(rts, 'tz_' + event.type) is not None:
                            event.threads['match_by_oper_station'].append(rts)
                            date = self.calc_rts_event_db_date(event, rts, thread)
                            thread_start_date = rts.calc_thread_start_date(event.type, date, MSK_TZ)
                            thread_start_dt = datetime.combine(thread_start_date, rts.thread.tz_start_time)
                            if thread.runs_at(thread_start_date):
                                if thread.uid not in skip_thread_uids[thread_start_date]:
                                    possible_with_rts[thread].append((rts, thread_start_dt))

                # Pick one (rts, start date) pair for each found thread.
                rts_with_dates = []
                for thread, thread_rtstations in possible_with_rts.items():
                    # If the thread passes the station two or more times, choose by time.
                    if len(thread_rtstations) > 1:
                        rts, start_dt = get_closest_rts(thread_rtstations, event)
                        rts.station_passed_several_times = True
                        rts_with_dates.append((rts, start_dt))
                    else:
                        rts_with_dates.append(thread_rtstations[0])

                # Several suitable threads were found.
                if len(rts_with_dates) > 1:
                    event_times = set()
                    for rts, thread_start_dt in rts_with_dates:
                        rts_event_dt = rts.get_event_dt(event.type, thread_start_dt, MSK_TZ).replace(tzinfo=None)
                        event_times.add(rts_event_dt)

                    if len(event_times) == 1:
                        # Two threads with overlapping numbers pass the same station at the
                        # same time, so they are a base thread plus an agreed one and
                        # physically a single train. The event belongs to all of them.
                        event.threads['rts_with_dates'] = rts_with_dates
                    else:
                        # Threads and times differ, but numbers and terminal stations overlap,
                        # so these are composite-number threads whose numbers collided for
                        # different directions. Pick the thread closest to the schedule.
                        event.threads['rts_with_dates'] = [get_closest_rts(rts_with_dates, event)]
                else:
                    event.threads['rts_with_dates'] = rts_with_dates

    def match_threads(self):
        """Run all thread matching steps for events that have enough stations matched."""
        events = [e for e in self.events if e.is_valid_to_match_thread()]
        self.match_threads_by_number_and_stations(events)
        self.match_threads_oper_station_and_date(events)

        for event in events:
            event.validate_thread()

    def prepare_raw_events(self):
        """
        Convert Kazakhstan-time events to MSK in place.

        https://st.yandex-team.ru/RASPEXPORT-221
        RZD and KTZ lease some sections to each other. Some sections (around Lokot,
        Petropavlovsk and Iletsk) use Kazakhstan time. Express codes of such stations start
        with 27. For those events TIMEOPER_N is treated as Asia/Almaty time and replaced with
        naive MSK time.

        NOTE(review): only TIMEOPER_N is converted, TIMEOPER_F is left as is — confirm this
        is intended.
        """
        almaty_tz = pytz.timezone('Asia/Almaty')
        for event in self.raw_events:
            oper_station_express = event.get('STOPEREX')
            if oper_station_express and str(oper_station_express).startswith('27'):
                event['TIMEOPER_N'] = almaty_tz.localize(event['TIMEOPER_N']).astimezone(MSK_TZ).replace(tzinfo=None)


class MCZKEventsMatcher(BaseEventsMatcher):
    """
    Matching of MCZK threads.
    https://st.yandex-team.ru/RASPEXPORT-290

    Events are matched against the expected events stored in ThreadEvents documents by
    clock direction, event type, event station and normative time.
    """
    event_class = MCZKEvent

    @staticmethod
    def _possible_event_times(event):
        """Normative times (seconds zeroed) to look the event up at; arrivals are also tried one minute earlier."""
        event_time = event.dt_normative.replace(second=0)
        if event.type == 'arrival':
            return [event_time, event_time - timedelta(minutes=1)]
        return [event_time]

    def match_threads(self):
        """
        Fill ``th_events`` of each event with the matching (thread_event, expected_event) pairs.

        Only MCZK ThreadEvents with thread_start_date from the start of yesterday up to the
        end of today are queried.
        """
        # Events without a valid thread number are skipped. Their clock direction (derived
        # from number parity) is unknown, and get_clock_dir() used to crash on them.
        events = [
            e for e in self.events
            if e.is_valid_to_match_thread() and e.thread_number
        ]

        expected_events_keys_set = set()
        for event in events:
            station_key = str(event.stations['oper']['matched'].id)
            for possible_time in self._possible_event_times(event):
                expected_events_keys_set.add((possible_time, station_key, event.type))

        expected_events_keys = [
            {
                'dt_normative': dt_normative,
                'station_key': station_key,
                'type': event_type,
            }
            for dt_normative, station_key, event_type in expected_events_keys_set
        ]

        if not expected_events_keys:
            return

        msg = 'get thread_events for mczk matching ({} expected_events_keys)'.format(len(expected_events_keys))
        with self.log_run_time(msg):
            query_date = now().replace(hour=0, minute=0, second=0, microsecond=0)
            q = models.ThreadEvents.objects(
                __raw__={
                    'key.thread_start_date': {'$gte': query_date - timedelta(days=1),
                                              '$lt': query_date + timedelta(days=1)},
                    'key.thread_type': ThreadEventsTypeCodes.MCZK,
                    'stations_expected_events': {'$elemMatch': {
                        '$or': expected_events_keys
                    }},
                }
            )

            with switch_db(models.ThreadEvents, settings.SUBURBAN_EVENTS_DATABASE_NAME + '_no_timeout'):
                th_events = light_mongoengine_query(q)

        with self.log_run_time('create expected_event_index'):
            # clock direction -> event type -> station key -> dt_normative -> [(th_event, exp_event)]
            expected_events_index = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
            for th_event in th_events:
                for exp_event in th_event.stations_expected_events:
                    (
                        expected_events_index
                        [th_event.key.clock_direction]
                        [exp_event.type]
                        [exp_event.station_key]
                        [exp_event.dt_normative].append(
                            (th_event, exp_event)
                        )
                    )

        with self.log_run_time('mczk matching (events {}, th_events {})'.format(len(events), len(th_events))):
            for event in events:
                clock_dir = event.get_clock_dir()
                station_id = str(event.stations['oper']['matched'].id)
                for possible_time in self._possible_event_times(event):
                    for th_event, exp_event in expected_events_index[clock_dir][event.type][station_id][possible_time]:
                        event.th_events.append((th_event, exp_event))


class MCZKResultEvent(RZDEventMixin):
    """
    Our event built from an MCZK RZD event and one matched thread event.

    Event attributes (times, type, weight, ...) come from the wrapped ``rzd_event``.
    Time, station key and the repeated-pass flag come from the matched expected event.
    """

    def __init__(self, rzd_event, th_event, time, station_key, passed_several_times):
        self.rzd_event = rzd_event
        self.th_event = th_event
        self._station_key = station_key
        self._time = time
        self._passed_several_times = passed_several_times

    @property
    def time(self):
        """Time of the matched expected event."""
        return self._time

    @property
    def station_key(self):
        """Station key of the matched expected event."""
        return self._station_key

    @property
    def passed_several_times(self):
        """Whether the thread passes this station more than once."""
        return self._passed_several_times

    @property
    def thread_start_date(self):
        """Start date taken from the matched thread event's key."""
        key = self.th_event.key
        return key.thread_start_date


def get_new_events(last_matched_event_id):
    """
    Load RZD events that are not yet matched and split them into regular and MCZK events.

    Side effects: stores the latest TIMEOPER_F of the feed as the
    'last_successful_query_time' dynamic param, and passes the loaded events to
    check_time_without_events.

    :param last_matched_event_id: only events with a greater id are loaded; a falsy value
        means no id filter.
    :return: tuple (regular events, MCZK events) of raw event dicts, in id order.
        MCZK events are those whose NOMPEX ends with 'МКЖД'.
    """
    newest_event = LVGD01_TR2PROC_feed.objects.order_by('-TIMEOPER_F')[0]
    last_event_time = newest_event['TIMEOPER_F']
    dynamic_params.set_param('last_successful_query_time', last_event_time)

    query = LVGD01_TR2PROC_feed.objects.all().order_by('id')
    if last_matched_event_id:
        query = query.filter(id__gt=last_matched_event_id)
    if settings.DISABLED_SOURCES:
        query = query.filter(SOURCE__nin=settings.DISABLED_SOURCES)
    query = query.limit(conf.SUBURBAN_MAX_EVENTS_TO_MATCH)

    all_events = list(query.aggregate())

    # Bucket by "is MCZK" in a single pass, keeping the original order inside each bucket.
    buckets = {True: [], False: []}
    for raw_event in all_events:
        buckets[raw_event['NOMPEX'].endswith(u'МКЖД')].append(raw_event)

    check_time_without_events(all_events, last_event_time)

    return buckets[False], buckets[True]
