# -*- coding: utf-8 -*-

"""
RASPFRONT-1825
Скрипт для заполнения дополнительных полей в ThreadTariff
"""

import travel.avia.admin.init_project  # noqa

import logging
from collections import defaultdict
from datetime import datetime
from optparse import OptionParser

from django.db import transaction

from travel.avia.library.python.common.models.geo import Station
from travel.avia.library.python.common.models.schedule import RThread
from travel.avia.library.python.common.models.tariffs import ThreadTariff
from travel.avia.library.python.common.utils import environment
from travel.avia.library.python.common.utils.date import RunMask, timedelta2minutes
from travel.avia.library.python.common.utils.progress import PercentageStatus
from travel.avia.admin.lib.logs import add_stdout_handler
from travel.avia.admin.lib.mysqlutils import MysqlModelUpdater
from travel.avia.admin.lib.tmpfiles import clean_temp, get_tmp_dir
from travel.avia.admin.www.utils.mysql import fast_delete_tariffs_by_uids


# Module-level logger; main() passes it to fill_thread_tariff and, with
# --verbose, attaches a stdout handler to it.
log = logging.getLogger(__name__)


@clean_temp
@transaction.atomic
def fill_thread_tariff(log):
    """
    Fill the denormalized fields of every ThreadTariff row from its RThread.

    Tariff ids are grouped by ``thread_uid``. For each uid the matching
    RThread is loaded, and its tariffs are recomputed by ``_update_tariffs``
    and queued into a MysqlModelUpdater that writes ``new_tariff_fields``.
    Tariffs whose thread uid no longer resolves to an RThread are removed via
    ``fast_delete_tariffs_by_uids``. The function runs inside
    ``transaction.atomic``.

    :param log: logger used for progress and status messages.
    """
    log.info(u'Начинаем заполнение полей ThreadTariff')
    today = environment.today()

    # Columns written back by MysqlModelUpdater. This list must cover every
    # attribute assigned in _update_tariffs; otherwise the computed value is
    # silently dropped (``duration`` was previously computed but not listed).
    new_tariff_fields = [
        'settlement_from',
        'settlement_to',
        'supplier',
        'number',
        't_type',
        'year_days_from',
        'time_from',
        'time_zone_from',
        'time_to',
        'time_zone_to',
        'duration',
    ]
    tmp_dir = get_tmp_dir()

    # One pass over the table: thread_uid -> [tariff ids].
    tariff_ids_by_thread_uid = defaultdict(list)
    tariff_query = ThreadTariff.objects.all().values_list('id', 'thread_uid')
    for tariff_id, thread_uid in tariff_query:
        tariff_ids_by_thread_uid[thread_uid].append(tariff_id)

    status = PercentageStatus(len(tariff_ids_by_thread_uid), log)
    outdated_thread_uids = set()
    with MysqlModelUpdater(ThreadTariff, tmp_dir, fields=new_tariff_fields) as tariff_updater:
        for thread_uid, tariff_ids in tariff_ids_by_thread_uid.iteritems():
            try:
                thread = RThread.objects.get(uid=thread_uid)
            except RThread.DoesNotExist:
                # The thread is gone; its tariffs are deleted after the update pass.
                outdated_thread_uids.add(thread_uid)
            else:
                _update_tariffs(thread, tariff_ids, tariff_updater, today)
            status.step()

    log.info(u'Удаляем устаревшие тарифы из ThreadTariff')
    # Skip the delete call entirely when there is nothing to remove.
    if outdated_thread_uids:
        fast_delete_tariffs_by_uids(list(outdated_thread_uids))
    log.info(u'Успешно завершили заполнение полей ThreadTariff')


def _update_tariffs(thread, tariff_ids, tariff_updater, today):
    """
    Recompute the denormalized fields of one thread's tariffs and queue them.

    For each ThreadTariff in ``tariff_ids`` this sets the settlements of its
    from/to stations, the thread's supplier, number (falling back to
    ``hidden_number``) and transport type, the year-days mask shifted to the
    departure stop, local departure/arrival times with their time zones, and
    the duration in minutes. Each tariff is then passed to
    ``tariff_updater.add``.

    :param thread: RThread the tariffs belong to.
    :param tariff_ids: ids of the ThreadTariff rows to update.
    :param tariff_updater: object with an ``add(model)`` method; in this file,
        an open MysqlModelUpdater.
    :param today: reference date passed to ``thread.first_run`` and ``RunMask``.
    :raises ValueError: if a tariff's departure station is not on the thread
        path, or no matching arrival stop follows the departure stop.
    """
    # Materialize once: the tariffs are iterated several times below.
    tariffs = list(ThreadTariff.objects.filter(id__in=tariff_ids))

    station_ids = set()
    for tariff in tariffs:
        station_ids.add(tariff.station_from_id)
        station_ids.add(tariff.station_to_id)
    settlement_id_by_station_id = dict(
        Station.objects.filter(id__in=station_ids).values_list('id', 'settlement_id')
    )

    # Materialized because it is scanned once per tariff.
    path = list(thread.path)
    first_run = thread.first_run(today=today)
    # Departure/arrival moments are derived from the thread start on its first run day.
    naive_start_dt = datetime.combine(first_run, thread.tz_start_time)
    for tariff in tariffs:
        tariff.settlement_from_id = settlement_id_by_station_id[tariff.station_from_id]
        tariff.settlement_to_id = settlement_id_by_station_id[tariff.station_to_id]
        tariff.supplier_id = thread.supplier_id
        tariff.number = thread.number or thread.hidden_number
        tariff.t_type_id = thread.t_type_id

        rtstation_from = next(
            (rts for rts in path if rts.station_id == tariff.station_from_id), None)
        if rtstation_from is None:
            raise ValueError('ThreadTariff %s: station_from %s is not on the path of thread %s'
                             % (tariff.id, tariff.station_from_id, thread.uid))
        # NOTE(review): relies on RTStation ids increasing along the path so that
        # the arrival stop found is one after the departure stop — confirm.
        rtstation_to = next(
            (rts for rts in path
             if rts.station_id == tariff.station_to_id and rts.id > rtstation_from.id),
            None)
        if rtstation_to is None:
            raise ValueError('ThreadTariff %s: station_to %s does not follow station_from %s '
                             'on the path of thread %s'
                             % (tariff.id, tariff.station_to_id, tariff.station_from_id, thread.uid))

        # The tariff's year_days mask is shifted by the departure stop's day offset
        # from the thread start.
        shift = rtstation_from.calc_days_shift(event='departure', start_date=first_run,
                                               event_tz=rtstation_from.pytz)
        tariff.year_days_from = str(RunMask(mask=tariff.year_days, today=today).shifted(shift))

        departure_dt = rtstation_from.get_departure_dt(naive_start_dt, rtstation_from.pytz)
        tariff.time_from = departure_dt.time()
        tariff.time_zone_from = rtstation_from.time_zone

        arrival_dt = rtstation_to.get_arrival_dt(naive_start_dt, rtstation_to.pytz)
        tariff.time_to = arrival_dt.time()
        tariff.time_zone_to = rtstation_to.time_zone

        tariff.duration = int(timedelta2minutes(arrival_dt - departure_dt))

        tariff_updater.add(tariff)


def main():
    """
    Command-line entry point: parse options and run fill_thread_tariff.

    ``-v`` / ``--verbose`` attaches a stdout handler to the module logger
    before the run starts.
    """
    optparser = OptionParser()
    optparser.add_option('-v', '--verbose', action='store_true',
                         help='also write log messages to stdout')

    # Positional arguments are not used by this script.
    options, _ = optparser.parse_args()

    if options.verbose:
        add_stdout_handler(log)

    fill_thread_tariff(log)
