# coding: utf-8

import travel.rasp.admin.scripts.load_project  # noqa

import sys
import logging
from optparse import OptionParser

from django.db import transaction, IntegrityError

from common.models.geo import ExternalDirection
from common.models.schedule import RThread
from common.models.tariffs import SuburbanTariff, AeroexTariff, TariffType
from common.models.transport import TransportType
from common.utils.progress import PercentageStatus
from common.utils.metrics import task_progress_report
from travel.rasp.admin.lib.maintenance.flags import flags
from travel.rasp.admin.lib.maintenance.scripts import job
from travel.rasp.admin.lib.logs import create_current_file_run_log, print_log_to_stdout, get_script_log_context, ylog_context
from travel.rasp.admin.lib.sql import fast_delete


log = logging.getLogger(__name__)


class TariffsSet(object):
    """Collects precalculated station-to-station tariffs and stores them as AeroexTariff rows.

    ``self.tariffs`` maps ``(station_from_id, station_to_id)`` to a
    ``(tariff, multiplier, synchronous)`` tuple. The first value registered
    for a pair wins; later registrations for the same pair are ignored.
    """

    # NOTE: both lookups query the database while the class body executes,
    # i.e. when this module is imported.
    ETRAIN = TariffType.objects.get(code='etrain')
    REDUCED = TariffType.objects.get(code='reduced')

    def __init__(self):
        self.tariffs = {}

    def set_tariffs(self, from_markers, to_markers, tariff, multiplier, synchronous=False):
        """Register ``tariff`` for every ordered pair of distinct markers.

        :param from_markers: markers whose ``station_id`` is the departure station.
        :param to_markers: markers whose ``station_id`` is the arrival station.
        :param tariff: base price for the pair.
        :param multiplier: factor for the reduced price; a falsy value means
            that no reduced-price row is written by :meth:`save`.
        :param synchronous: stored in the ``reverse`` field of the saved rows.
        """
        for f in from_markers:
            for t in to_markers:
                if f == t:
                    continue

                pair = (f.station_id, t.station_id)

                # Zone tariffs must not override a group tariff: whichever
                # tariff was registered first for the pair is kept.
                if pair not in self.tariffs:
                    self.tariffs[pair] = tariff, multiplier, synchronous

    def _build_record(self, station_from_id, station_to_id, synchronous, tariff, tariff_type):
        """Return an unsaved precalculated AeroexTariff filled with the given values."""
        record = AeroexTariff()

        record.station_from_id = station_from_id
        record.station_to_id = station_to_id
        record.reverse = synchronous
        record.precalc = True

        record.tariff = tariff
        record.type = tariff_type

        return record

    @transaction.atomic
    def save(self):
        """Replace all precalculated AeroexTariff rows with the collected tariffs.

        Everything runs in one transaction. A pair whose base-price row
        violates a database constraint (the price already exists) is skipped
        together with its reduced-price row.
        """
        log.info(u'Удаляем старые записи')
        fast_delete(AeroexTariff.objects.filter(precalc=True))

        log.info(u'Записываем маршруты в базу')

        progress = PercentageStatus(len(self.tariffs), log=log)
        for (station_from_id, station_to_id), (tariff, multiplier, synchronous) in self.tariffs.items():
            base = self._build_record(station_from_id, station_to_id, synchronous, tariff, self.ETRAIN)

            try:
                # The nested atomic block creates a savepoint. Catching a
                # database error without one leaves the outer transaction
                # marked for rollback, and every later query in this loop
                # would then fail.
                with transaction.atomic():
                    base.save()
            except IntegrityError:
                # Skip prices that already exist, but still count the pair
                # so the progress indicator reaches its total.
                progress.step()
                continue

            if multiplier:
                reduced = self._build_record(
                    station_from_id, station_to_id, synchronous, tariff * multiplier, self.REDUCED
                )
                reduced.save()

            progress.step()

        log.info(u'Готово')


def trim_path(path, stations_from, stations_to):
    """Narrow ``path`` to the shortest leading run from a from-station to a to-station.

    The path is scanned in order. The start is moved forward to every
    from-station that is met, and the scan stops at the first to-station found
    at or after a start. Both ends are included in the result.

    :param path: sequence of stops, each with a ``station_id`` attribute.
    :param stations_from: iterable of objects with ``id`` -- allowed start stations.
    :param stations_to: iterable of objects with ``id`` -- allowed end stations.
    :return: the slice of ``path`` between the chosen stops, or an empty slice
        when no from-station is followed by a to-station (including an empty
        ``path``).
    """
    from_ids = set(s.id for s in stations_from)
    to_ids = set(s.id for s in stations_to)

    # Index of the most recent from-station. Compared with ``is None`` on
    # purpose: index 0 is a valid position and must not read as "not found".
    first = None

    for i, rts in enumerate(path):
        if rts.station_id in from_ids:
            first = i

        # A to-station only counts once a start is known; to-stations that
        # precede every from-station cannot end the trip.
        if first is not None and rts.station_id in to_ids:
            return path[first:i + 1]

    return path[:0]


def walk(from_markers, to_markers, markers_by_station):
    """Find one suburban thread joining the two marker groups and price the trip along it.

    :param from_markers: markers of the departure stations.
    :param to_markers: markers of the arrival stations.
    :param markers_by_station: dict mapping a station id to its marker.
    :return: the sum of the tariffs of every zone entered after the first one,
        or ``None`` when no thread is found, a stop has no marker, or summing
        raises ``TypeError`` (presumably a missing tariff -- confirm against
        ``SuburbanTariff.get_tariff``).
    """
    departure_stations = [marker.station for marker in from_markers]
    arrival_stations = [marker.station for marker in to_markers]

    # Only one matching thread is needed; the empty order_by() leaves the
    # choice of which one to the database.
    candidates = list(
        RThread.objects.filter(
            znoderoute2__station_from__in=departure_stations,
            znoderoute2__station_to__in=arrival_stations,
            t_type_id=TransportType.SUBURBAN_ID,
        ).order_by()[:1]
    )
    if not candidates:
        return None

    stops = list(candidates[0].rtstation_set.order_by('id'))
    route = trim_path(stops, departure_stations, arrival_stations)

    # One marker per zone, in the order the zones are entered.
    crossed = []
    for stop in route:
        try:
            marker = markers_by_station[stop.station_id]
        except KeyError:
            log.error(u'Станция %s не привязана к направлению' % stop.station.title)
            return None

        entered_new_zone = not crossed or marker.zone != crossed[-1].zone
        if entered_new_zone:
            crossed.append(marker)

    # The departure zone itself is free; only the zones entered later are paid for.
    total = 0
    try:
        for marker in crossed[1:]:
            total = total + SuburbanTariff.get_tariff(marker.zprice)
    except TypeError:
        return None

    return total


def calculate_zone_tariffs():
    """Build a TariffsSet that holds the zone tariffs of every ExternalDirection.

    :return: the filled ``TariffsSet``; nothing is written to the database here.
    """
    collected = TariffsSet()

    log.info(u'Начинаем расчет тарифов')

    directions = ExternalDirection.objects
    progress = PercentageStatus(directions.count(), log=log)

    for direction in directions.all():
        calculate_zone_tariff_for_direction(direction, collected)
        progress.step()

    log.info(u'Расчет тарифов закончен')

    return collected


def calculate_zone_tariff_for_direction(direction, tariffs=None):
    """Add the group and zone tariffs of one direction to ``tariffs``.

    :param direction: the direction whose markers are processed.
    :param tariffs: set to fill; a new ``TariffsSet`` is created when omitted.
    :return: the ``TariffsSet`` that was filled.
    """
    if not tariffs:
        tariffs = TariffsSet()

    reduced_multiplier = direction.reduced_tariff

    by_zone = {}
    by_group = {}
    markers_by_station = {}

    marker_query = direction.externaldirectionmarker_set.select_related('station').all()
    for marker in marker_query:
        if marker.zgroup:
            by_group.setdefault(marker.zgroup, []).append(marker)

        if marker.zone:
            by_zone.setdefault(marker.zone, []).append(marker)

        markers_by_station[marker.station_id] = marker

    # Group tariffs are registered first; set_tariffs keeps the first value
    # per station pair, so the zone tariffs below cannot replace them.
    for members in by_group.values():
        group_tariff = SuburbanTariff.get_tariff(members[0].zgroup)

        if group_tariff:
            tariffs.set_tariffs(members, members, group_tariff, None)

    for from_zone, from_markers in by_zone.items():
        for to_zone, to_markers in by_zone.items():
            if from_zone == to_zone:
                # Trip inside a single zone.
                head = from_markers[0]
                tariff = SuburbanTariff.get_tariff(head.zprice_inner or head.zprice)
            elif len(by_zone) == 2:
                # With exactly two zones the price is read from the markers directly.
                tariff = SuburbanTariff.get_tariff(from_markers[0].zprice_exit or to_markers[0].zprice)
            else:
                # Otherwise the price is accumulated along a real thread.
                tariff = walk(from_markers, to_markers, markers_by_station)

            if not tariff:
                continue

            tariffs.set_tariffs(
                from_markers,
                to_markers,
                tariff,
                reduced_multiplier,
                synchronous=direction.use_synchronous_tariff,
            )

    return tariffs


if __name__ == '__main__':
    # Script entry point: recalculate all zone tariffs and store them, with
    # the maintenance flag set to this job's value for the duration.
    with ylog_context(**get_script_log_context()), task_progress_report('precalc_tarrifs'):
        create_current_file_run_log()

        option_parser = OptionParser()
        option_parser.add_option('-v', '--verbose', dest='verbose', action='store_true')
        options, _positional = option_parser.parse_args()

        if options.verbose:
            print_log_to_stdout(log)

        # Refuse to start while another run of this same job holds the flag.
        current_flag = flags['maintenance']
        if current_flag == job.PRECALC_TARIFFS.flag_value:
            log.error(u"Нельзя запускать пересчет тарифов, он уже запущен")
            sys.exit(1)

        saved_flag = flags['maintenance']

        try:
            flags['maintenance'] = job.PRECALC_TARIFFS.flag_value

            calculate_zone_tariffs().save()
        finally:
            # Put back whatever flag value was set before this run.
            flags['maintenance'] = saved_flag
