# coding: utf8

import logging
from collections import defaultdict, Counter
from datetime import datetime, timedelta
from functools import partial
from itertools import chain

from common.apps.suburban_events import models
from common.apps.suburban_events.utils import (
    ThreadKey, collect_rts_by_threads, get_threads_suburban_keys, get_rtstation_key, get_thread_type_and_clock_dir, prefix_for_dict_keys
)
from common.db.mongo.bulk_buffer import BulkBuffer
from common.models.schedule import RTStation, RThread
from common.models.transport import TransportType
from common.utils.date import MSK_TZ
from travel.rasp.library.python.common23.date.environment import now
from travel.rasp.library.python.common23.logging import log_run_time

log = logging.getLogger(__name__)
log_run_time = partial(log_run_time, logger=log)


def precalc_all_expected_thread_events(start_date=None, last_run_time=None):
    """
    Функция высчитывает предполагаемые события прибытия и отправления ниток на станции во время пути,
    учитываются только станции на которых нитки останавливаются.
    Просчитанные события сохраняются в монгу в коллекцию ThreadEvents в поле stations_expected_events.
    Отмененные и измененные нитки сохраняются в коллекцию UpdatedThread.
    UpdatedThread используется для фильтрации ниток в прогнозаторе.
    Если при создании/удалении нитки изменения оказывается, что у заменяемой нитки уже есть поматченные события,
    то происходит перематчинг этих событий на новую нитку.
    :param start_date: дата старта ниткок, на которую необходимо пересчитать события ниток
    :param last_run_time: при передаче будут учтены нитки с modified_at >= last_run_time
    """
    if not start_date:
        start_date = now()
    start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
    start_dates = [start_date]
    threads = RThread.objects.filter(t_type=TransportType.SUBURBAN_ID)

    if last_run_time:
        threads = threads.filter(modified_at__gte=last_run_time)
        start_dates.append(start_date - timedelta(days=1))

    threads = list(threads.select_related(
        'basic_thread', 'type'
    ))

    threads = [t for t in threads if any(t.runs_at(d) for d in start_dates) or (t.cancel or t.update)]

    with log_run_time('get rtstations of {} threads'.format(len(threads))):
        rtstations = list(
            RTStation.objects.filter(
                thread_id__in=[t.id for t in threads]
            ).select_related(
                'thread', 'station'
            ).only(
                'tz_arrival', 'tz_departure', 'time_zone',
                'thread__id', 'thread__tz_start_time', 'thread__time_zone', 'thread__uid', 'thread__company',
                'station__id', 'station__time_zone', 'station__settlement_id',
                'station__use_in_departure_forecast', 'station__use_in_forecast'
            ).order_by('id')
        )

    with log_run_time('get pathes for threads of {} rtstaions'.format(len(rtstations))):
        thread_rts = collect_rts_by_threads(rtstations)

    updated_uids_by_date = defaultdict(set)
    for obj in models.UpdatedThread.objects.all():
        updated_uids_by_date[obj.start_date].add(obj.uid)

    skip_thread_uids_by_date = defaultdict(list)
    not_skip_thread_uids_by_date = defaultdict(list)
    delete_threads_by_date = defaultdict(list)
    rematch_threads = defaultdict(lambda: defaultdict(dict))
    start_date_threads = defaultdict(set)

    for start_date in start_dates:
        # Если нам встречается нитка отмены, то не учитываем ее и ее основную нитку.
        # Если втречается нитка изменение, то не учитываем ее основную нитку.
        skip_threads = set()

        for t in threads:
            if t.runs_at(start_date):
                start_date_threads[start_date].add(t)
                if t.cancel:
                    skip_threads.update([t, t.basic_thread])
                    delete_threads_by_date[start_date].append(t.basic_thread)
                elif t.update:
                    skip_threads.add(t.basic_thread)
                    delete_threads_by_date[start_date].append(t.basic_thread)
                    rematch_threads[start_date][t] = t.basic_thread
            else:
                if (t.cancel or t.update) and (t.basic_thread.uid in updated_uids_by_date[start_date]):
                    start_date_threads[start_date].add(t.basic_thread)
                    not_skip_thread_uids_by_date[start_date].append(t.basic_thread.uid)

                    if t.update:
                        delete_threads_by_date[start_date].append(t)
                        rematch_threads[start_date][t.basic_thread] = t

                    if t.cancel:
                        not_skip_thread_uids_by_date[start_date].append(t.uid)

        start_date_threads[start_date] = [t for t in start_date_threads[start_date] if t not in skip_threads]
        for skip_thread in skip_threads:
            skip_thread_uids_by_date[start_date].append(skip_thread.uid)

    days = ', '.join(str(d) for d in start_dates)
    with log_run_time('calc_date_thread_events for days {}'.format(days)):
        thread_events = get_expected_thread_events_on_date(thread_rts, start_date_threads)

    with log_run_time('rematch threads events for days {}'.format(days)):
        rematched_threads_events = rematch_threads_events(rematch_threads, thread_events)

    with log_run_time('delete predicted events'):
        delete_expected_events(delete_threads_by_date)

    with log_run_time('save predicted events for days {}'.format(days)):
        save_expected_events(thread_events, rematched_threads_events)

    with log_run_time('save canceled and changed threads'):
        save_skip_threads(skip_thread_uids_by_date, not_skip_thread_uids_by_date)


def save_skip_threads(skip_thread_uids_by_date, not_skip_thread_uids_by_date):
    with BulkBuffer(models.UpdatedThread._get_collection(), max_buffer_size=200, logger=log) as coll:
        for start_date, uids in skip_thread_uids_by_date.items():
            for uid in uids:
                coll.update_one({
                    'uid': uid,
                    'start_date': start_date},
                    {'$set': {
                        'uid': uid,
                        'start_date': start_date
                    }},
                    upsert=True,
                )

        for start_date, uids in not_skip_thread_uids_by_date.items():
            for uid in uids:
                coll.delete_one({
                    'uid': uid,
                    'start_date': start_date
                })


def get_expected_thread_events_on_date(thread_rts, start_date_threads):
    thread_events = defaultdict(lambda: defaultdict(lambda: {'path': [], 'pass_count': Counter()}))

    for start_date, threads in start_date_threads.items():
        for thread in threads:
            thread_tz_start = datetime.combine(start_date, thread.tz_start_time)
            for rts in thread_rts[thread]:
                rts_arrival_loc_dt = rts.get_event_loc_dt('arrival', thread_tz_start)
                rts_departure_loc_dt = rts.get_event_loc_dt('departure', thread_tz_start)

                # Отбрасываем станции на которых нет остановок.
                if rts.is_no_stop():
                    continue

                if rts_arrival_loc_dt:
                    rts_arrival_loc_dt = rts_arrival_loc_dt.astimezone(MSK_TZ).replace(tzinfo=None)
                    thread_events[start_date][thread]['path'].append(
                        {
                            'rtstation': rts,
                            'type': 'arrival',
                            'dt_normative': rts_arrival_loc_dt
                        }
                    )

                if rts_departure_loc_dt:
                    rts_departure_loc_dt = rts_departure_loc_dt.astimezone(MSK_TZ).replace(tzinfo=None)
                    thread_events[start_date][thread]['path'].append(
                        {
                            'rtstation': rts,
                            'type': 'departure',
                            'dt_normative': rts_departure_loc_dt
                        }
                    )

                thread_events[start_date][thread]['pass_count'][rts.station.id] += 1
    return thread_events


def save_expected_events(thread_events, rematched_threads_events):
    suburban_keys = get_threads_suburban_keys(chain(*thread_events.values()))

    with BulkBuffer(models.ThreadEvents._get_collection(), max_buffer_size=2000, logger=log) as coll:
        for start_date, threads in thread_events.items():
            for thread, th_events in threads.items():
                try:
                    thread_start_dt = datetime.combine(start_date, thread.tz_start_time)
                    thread_type, clock_dir = get_thread_type_and_clock_dir(thread)
                    thread_key = ThreadKey(suburban_keys[thread.id], thread_start_dt, thread_type, clock_dir)
                except Exception as ex:
                    log.exception('thread {} failed: {}'.format(thread.uid, repr(ex)))
                    continue

                set_dict = {
                    'stations_expected_events': [
                        models.StationExpectedEvent(
                            station_key=get_rtstation_key(th_event['rtstation']),
                            type=th_event['type'],
                            dt_normative=th_event['dt_normative'],
                            time=getattr(th_event['rtstation'], 'tz_' + th_event['type']),
                            passed_several_times=th_events['pass_count'][th_event['rtstation'].station.id] > 1
                        ).to_mongo() for th_event in th_events['path']]
                }

                if thread.company:
                    set_dict['thread_company'] = thread.company.id

                date_rematch = rematched_threads_events.get(start_date)
                if date_rematch:
                    events = date_rematch.get(thread)
                    if events:
                        set_dict['stations_events'] = [se.to_mongo() for se in events]
                        set_dict['need_recalc'] = True

                coll.update_one(
                    {'key.thread_key': thread_key.thread_key,
                     'key.thread_start_date': thread_key.thread_start_date,
                     'key.thread_type': thread_key.thread_type,
                     'key.clock_direction': thread_key.clock_direction},
                    {'$set': set_dict},
                    upsert=True,
                )


def delete_expected_events(delete_threads_by_date):
    suburban_keys = get_threads_suburban_keys(chain(*delete_threads_by_date.values()))
    thread_keys = []
    for start_date, threads in delete_threads_by_date.items():
        for thread in threads:
            try:
                thread_start_dt = datetime.combine(start_date, thread.tz_start_time)
                thread_type, clock_dir = get_thread_type_and_clock_dir(thread)
                thread_keys.append(ThreadKey(suburban_keys[thread.id], thread_start_dt, thread_type, clock_dir))
            except Exception as ex:
                log.exception('thread {} failed: {}'.format(thread.uid, repr(ex)))

    with BulkBuffer(models.ThreadEvents._get_collection(), max_buffer_size=200, logger=log) as coll_th, \
            BulkBuffer(models.ThreadStationState._get_collection(), max_buffer_size=200, logger=log) as coll_th_st:
        for thread_key in thread_keys:
            delete_filter = {
                'key.thread_key': thread_key.thread_key,
                'key.thread_start_date': thread_key.thread_start_date,
                'key.thread_type': thread_key.thread_type,
                'key.clock_direction': thread_key.clock_direction
            }
            coll_th.delete_one(delete_filter)
            coll_th_st.delete_one(delete_filter)


def rematch_threads_events(rematch_threads, thread_events):
    """
    Функция получает пары ниток, для которых нужно сделать перематчинг для каждого дня.
    Из монги достаются поматченные события "старых ниток".
    Если в новой нитке присутствуют те же станции, что и в поматченных событиях,
    то события перематчиваются на эти станции новой нитки.
    """
    if not rematch_threads:
        return {}

    suburban_keys = get_threads_suburban_keys(chain(*(t.values() for t in rematch_threads.values())))
    thread_keys = []
    reverse_key = {}
    old_events_by_thread = defaultdict(dict)

    for start_date, threads in rematch_threads.items():
        for thread, old_thread in threads.items():
            try:
                thread_start_dt = datetime.combine(start_date, old_thread.tz_start_time)
                thread_type, clock_dir = get_thread_type_and_clock_dir(old_thread)
                old_thread_key = ThreadKey(suburban_keys[old_thread.id], thread_start_dt, thread_type, clock_dir)
                reverse_key[old_thread_key] = thread
            except Exception as ex:
                log.exception('rematch thread {} failed: {}'.format(old_thread.uid, repr(ex)))
                continue
            thread_keys.append(prefix_for_dict_keys(old_thread_key.to_mongo_dict(), 'key.'))

        if not thread_keys:
            return {}

        for th in models.ThreadEvents.objects(__raw__={'$or': thread_keys}):
            key = ThreadKey(th.key.thread_key, th.key.thread_start_date, th.key.thread_type, th.key.clock_direction)
            thread = reverse_key.get(key)
            if thread and th.stations_events:
                new_events = []
                for station_event in th.stations_events:
                    count = thread_events[start_date][thread]['pass_count'].get(int(station_event.station_key))
                    if not count:
                        continue

                    station_event.passed_several_times = station_event.passed_several_times or (count > 1)
                    new_events.append(station_event)
                old_events_by_thread[start_date][thread] = new_events

    return old_events_by_thread
