# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging
from collections import defaultdict
from datetime import timedelta
from functools import partial

from pymongo import ReadPreference

from common.apps.suburban_events.models import LVGD01_TR2PROC_feed, LVGD01_TR2PROC_query, HourEventsRate, CompanyCrash
from common.db.mongo.bulk_buffer import BulkBuffer
from common.dynamic_settings.default import conf
from common.models.geo import CodeSystem
from common.models.schedule import CompanyMarker
from travel.rasp.library.python.common23.date.environment import now
from travel.rasp.library.python.common23.logging import log_run_time

log = logging.getLogger(__name__)
# Rebind the imported log_run_time so that every `with log_run_time(...)` block in this module
# reports to this module's logger.
log_run_time = partial(log_run_time, logger=log)


# Read preference for the Mongo reads in this module: use a secondary when one is available,
# otherwise fall back to the primary.
READ_PREF = ReadPreference.SECONDARY_PREFERRED


def update_companies_events_rate():
    """
    Recompute the average number of RZD feed events per query for every (company, hour) pair.

    Looks at the same weekday as today over the previous ``conf.SUBURBAN_WEEKS_DEEP`` weeks
    (1..N weeks back). For every company and hour of day:

    * per day, take the mean number of the company's events over the queries made in that hour
      (only queries that produced at least one event for the company are counted);
    * then average those daily means.

    Results are upserted into ``HourEventsRate`` as ``{company: <company id>, hour: <hour>, rate: <float>}``.
    Company/hour pairs without any events in the window get no document written.
    """
    weeks_deep = conf.SUBURBAN_WEEKS_DEEP
    if weeks_deep < 1:
        # MongoDB rejects an empty ``$or``; with no history there is nothing to average anyway.
        log.warning('SUBURBAN_WEEKS_DEEP=%s, nothing to compute events rate from', weeks_deep)
        return

    stations_by_company, _ = get_rzd_esr_stations_by_company()
    # Invert company -> stations into station -> companies so each event is matched with a single
    # dict lookup instead of scanning every company's station set.
    companies_by_station = defaultdict(list)
    for company, company_stations in stations_by_company.items():
        for station in company_stations:
            companies_by_station[station].append(company)

    today = now().replace(hour=0, minute=0, second=0, microsecond=0)
    query_days = [today - timedelta(weeks=i) for i in range(1, weeks_deep + 1)]
    queries_conditions = [{'queried_at': {'$gte': day, '$lt': day + timedelta(days=1)}} for day in query_days]

    with log_run_time('get queries'):
        queries = list(LVGD01_TR2PROC_query.objects.read_preference(READ_PREF)(
            __raw__={
                '$or': queries_conditions
            }
        ))

    with log_run_time('get feeds'):
        # aggregate() yields plain dicts, so feed['query'] is the raw query id.
        feeds = list(LVGD01_TR2PROC_feed.objects.read_preference(READ_PREF).
                     filter(query__in=queries).only('STOPER', 'query').aggregate())

    events_by_query = defaultdict(list)
    with log_run_time('get feeds by query'):
        for feed in feeds:
            events_by_query[feed['query']].append(feed)

    # company -> hour -> day -> query -> number of the company's events in that query
    company_events_count = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(int))))
    for query in queries:
        q_day = query.queried_at.replace(hour=0, minute=0, second=0, microsecond=0)
        q_hour = query.queried_at.hour

        for event in events_by_query.get(query.id, ()):
            for company in companies_by_station.get(event['STOPER'], ()):
                company_events_count[company][q_hour][q_day][query] += 1

    average_rate = defaultdict(dict)
    for company, by_hour in company_events_count.items():
        for hour, by_day in by_hour.items():
            # Every nested dict here is non-empty: entries are only created by the += 1 above.
            daily_means = [sum(per_query.values()) / len(per_query) for per_query in by_day.values()]
            average_rate[company][hour] = sum(daily_means) / len(daily_means)

    with BulkBuffer(HourEventsRate._get_collection(), max_buffer_size=200, logger=log) as coll:
        for company, hour_average in average_rate.items():
            for hour, rate in hour_average.items():
                coll.update_one(
                    {'company': company.id,
                     'hour': hour},
                    {'$set': {'rate': rate}},
                    upsert=True,
                )


def calc_companies_crashes(rzd_feeds):
    """
    Detect carrier (company) crashes from the next batch of feed events received from RZD.

    Counts the events in the batch for each company. An event belongs to a company when its
    ``STOPER`` code is one of the company's rzd_esr station codes. Each count is then compared
    with the stored average rate (``HourEventsRate``) for the hour of the batch:

    * if the count is well below average (``avg * conf.SUBURBAN_COMPANY_CRASH_RATE > count``),
      a crash is opened, unless an open one already exists;
    * if a crash is open and the count is back to normal, the crash is closed;
    * a company that has an average rate for this hour but no events in the batch at all
      is treated as crashed;
    * a company with events but no stored average for this hour has no baseline. It is treated
      as healthy, so an open crash for it is closed.

    :param rzd_feeds: sequence of feed documents exposing ``STOPER`` and ``query.queried_at``.
        The batch time is taken from the first element. An empty sequence is a no-op
        (``check_time_without_events`` covers the no-events case).
    """
    if not rzd_feeds:
        return

    dt = rzd_feeds[0].query.queried_at

    stations_by_company, _ = get_rzd_esr_stations_by_company()
    # station code -> ids of the companies serving it, so each event needs one lookup.
    company_ids_by_station = defaultdict(list)
    for company, company_stations in stations_by_company.items():
        for station in company_stations:
            company_ids_by_station[station].append(company.id)

    company_events_count = defaultdict(int)
    for event in rzd_feeds:
        for company_id in company_ids_by_station.get(event.STOPER, ()):
            company_events_count[company_id] += 1

    saved_crashes = {crash.company: crash for crash in CompanyCrash.objects.read_preference(READ_PREF).all()}
    company_rates_objs = HourEventsRate.objects.read_preference(READ_PREF).filter(hour=dt.hour)
    company_rates = {o.company: o.rate for o in company_rates_objs}

    companies_update_dicts = {}
    for company, rate in company_events_count.items():
        saved_crash = saved_crashes.get(company)
        avr_rate = company_rates.get(company)
        # update_companies_events_rate only stores company/hour pairs that had events, so a
        # baseline may be missing; without one there is no evidence of a crash.
        is_below_average = avr_rate is not None and avr_rate * conf.SUBURBAN_COMPANY_CRASH_RATE > rate
        if is_below_average:
            # Open a new crash unless one is already open (an open crash has no last_dt).
            if not saved_crash or saved_crash.last_dt:
                companies_update_dicts[company] = {
                    'first_dt': dt,
                    'last_dt': None,
                    'last_rate': None,
                    'first_rate': rate,
                    'avr_rate': avr_rate
                }
        elif saved_crash and not saved_crash.last_dt:
            # Events are back to normal: close the open crash.
            companies_update_dicts[company] = {
                'last_dt': dt,
                'last_rate': rate
            }

    # Companies that normally send events at this hour but sent none in this batch.
    for company in set(company_rates) - set(company_events_count):
        saved_crash = saved_crashes.get(company)
        if not saved_crash or saved_crash.last_dt:
            companies_update_dicts[company] = {
                'first_dt': dt,
                'last_dt': None,
                'last_rate': None,
                'first_rate': 0,
                'avr_rate': company_rates[company]
            }

    with BulkBuffer(CompanyCrash._get_collection(), max_buffer_size=200, logger=log) as coll:
        for company, update_dict in companies_update_dicts.items():
            coll.update_one(
                {'company': company},
                {'$set': update_dict},
                upsert=True,
            )


def check_time_without_events(events, last_event_time):
    """
    Mark every company as crashed when no events have arrived for too long.

    Does nothing if ``events`` is non-empty. Otherwise, if ``last_event_time`` is more than
    ``conf.SUBURBAN_ALL_COMPANIES_CRASH_TIME`` minutes before ``now()``, it upserts an open crash
    for every company referenced by a ``CompanyMarker``. The crash has ``first_dt=last_event_time``,
    ``first_rate=0`` and ``last_dt``/``last_rate`` set to None.

    :param events: the latest batch of events; only its truthiness is checked.
    :param last_event_time: datetime of the last received event.
    """
    if events:
        return

    if now() - timedelta(minutes=conf.SUBURBAN_ALL_COMPANIES_CRASH_TIME) > last_event_time:
        markers = CompanyMarker.objects.all().select_related('company')
        # A company has one marker per station, so collect unique ids to avoid issuing
        # duplicate upserts for the same company.
        companies = {marker.company.id for marker in markers}
        with BulkBuffer(CompanyCrash._get_collection(), max_buffer_size=200, logger=log) as coll:
            for company in companies:
                coll.update_one(
                    {'company': company},
                    {'$set': {
                        'first_dt': last_event_time,
                        'first_rate': 0,
                        'last_rate': None,
                        'last_dt': None,
                    }},
                    upsert=True,
                )


def get_rzd_esr_stations_by_company():
    """
    Collect the rzd_esr station codes attached to each company through ``CompanyMarker``.

    :return: tuple ``(stations_by_company, stations)``.
        ``stations_by_company`` is a ``defaultdict(set)`` mapping a marker's company object
        to the set of integer rzd_esr codes of its marked stations.
        ``stations`` is the set of all those codes.
        Markers whose station has no (truthy) rzd_esr code are skipped.
    """
    all_markers = list(CompanyMarker.objects.all().select_related('company', 'station'))
    esr_system = CodeSystem.objects.get(code='rzd_esr')

    codes_by_company = defaultdict(set)
    all_codes = set()
    for company_marker in all_markers:
        raw_code = company_marker.station.get_code(system=esr_system)
        if not raw_code:
            continue
        esr_code = int(raw_code)
        codes_by_company[company_marker.company].add(esr_code)
        all_codes.add(esr_code)

    return codes_by_company, all_codes
