# coding: utf-8

import travel.rasp.admin.scripts.load_project  # noqa

import argparse
import json
import logging
from collections import defaultdict
from datetime import datetime, time
from itertools import groupby

from django.db import transaction
from django.db.models import Q

from travel.rasp.admin.admin.red.models import MetaRoute
from common.models.geo import Station
from common.models.schedule import Route, RThread, RTStation, Supplier, RThreadType
from common.models.transport import TransportType
from travel.rasp.library.python.common23.date import environment
from common.utils.date import RunMask, timedelta2minutes
from common.utils.progress import PercentageStatus
from travel.rasp.admin.importinfo.models.two_stage_import import TwoStageImportPackage
from travel.rasp.admin.lib.logs import print_log_to_stdout, create_current_file_run_log, get_script_log_context, ylog_context
from travel.rasp.admin.scripts.schedule.utils.route_loader import BulkRouteSaver
from travel.rasp.admin.www.utils.mysql import fast_delete_routes


log = logging.getLogger(__name__)


# Transport type whose routes are scanned; the generated pseudo routes get this transport type as well.
AFFECTED_TRANSPORT_TYPE_ID = TransportType.BUS_ID
# Upper bound (minutes, i.e. 12 hours) for a forward trip; a longer trip makes its direction ineligible
# for a back pseudo route.
MAX_ALLOWED_MINUTES = 60 * 12


class ConstructBackDataError(Exception):
    """Raised when forward threads cannot be turned into back pseudo route data; that direction is skipped."""


class PseudoRouteBuilder(object):
    """
    Rebuilds back pseudo routes for the affected transport type.

    For every forward direction (first station -> last station) of a thread belonging to an allowed
    route, a pseudo route in the opposite direction is built, unless some non-hidden thread of the
    affected transport type already runs in that opposite direction. Previously built pseudo routes
    (routes of the 'pseudo' supplier) are deleted first; `run` does everything in one transaction.
    """

    def __init__(self):
        self.today = environment.today()
        self.t_type = TransportType.objects.get(id=AFFECTED_TRANSPORT_TYPE_ID)
        self.pseudo_thread_type = RThreadType.objects.get(id=RThreadType.PSEUDO_BACK_ID)
        self.supplier = Supplier.objects.get(code='pseudo')

    @transaction.atomic
    def run(self):
        """Delete the old pseudo routes and build the new ones inside a single DB transaction."""
        self.delete_pseudo_directions()
        thread_ids_by_pseudo_directions = self.find_pseudo_directions()
        self.build_pseudo_routes(thread_ids_by_pseudo_directions)

    def delete_pseudo_directions(self):
        """Delete every route whose supplier is the 'pseudo' supplier."""
        log.info(u'Начинаем удаление обратных псевдорейсов.')
        fast_delete_routes(Route.objects.filter(supplier=self.supplier), log)
        log.info(u'Успешно удалили обратные псевдорейсы.')

    def find_pseudo_directions(self):
        """
        Find directions for which back pseudo routes may be built.

        All non-hidden threads of the affected transport type are loaded. This is needed to
        exclude from the candidate pseudo directions those where a real connection is already
        known. Pseudo directions themselves are derived only from threads of allowed routes.

        :return: dict mapping a back direction ``(first_station_id, last_station_id)`` to the set
                 of ids of the forward threads it was derived from.
        """

        log.info(u'Начинаем поиск направлений для обратных псевдорейсов.')

        allowed_route_ids = set(self.get_allowed_route_ids())
        all_route_ids = list(Route.objects.filter(t_type_id=AFFECTED_TRANSPORT_TYPE_ID, hidden=False).order_by()
                                          .values_list('id', flat=True))
        log.info(u'Получили %s разрешенных маршрутов', len(allowed_route_ids))
        log.info(u'Получили %s маршрутов', len(all_route_ids))

        can_add_pseudo_route_by_thread_id = {
            thread_id: route_id in allowed_route_ids
            for thread_id, route_id in RThread.objects.filter(route_id__in=all_route_ids).order_by()
                                                      .values_list('id', 'route_id')
        }
        log.info(u'Получили %s ниток', len(can_add_pseudo_route_by_thread_id))

        status = PercentageStatus(len(can_add_pseudo_route_by_thread_id), log)
        directions = set()
        back_directions = defaultdict(set)
        # groupby() merges only *adjacent* rows with equal keys, so the stops must be sorted by
        # thread first. Ordering by 'id' alone splits any thread whose stop ids are not
        # contiguous into several groups and yields bogus directions.
        # Ordering by the FK column 'thread_id' is meant to avoid an extra JOIN with RThread.
        stops = RTStation.objects.filter(thread_id__in=can_add_pseudo_route_by_thread_id)\
                                 .values('thread_id', 'station_id')\
                                 .order_by('thread_id', 'id')
        if log.isEnabledFor(logging.DEBUG):
            # Render the SQL only when it will actually be logged.
            query_text = unicode(stops.query)
            log.debug(u'RTStation query: %s...%s', query_text[:500], query_text[-500:])
        for thread_id, stops_iter in groupby(stops, key=lambda x: x['thread_id']):
            station_ids = [stop['station_id'] for stop in stops_iter]

            # Every thread's forward direction blocks a pseudo route along the same direction.
            directions.add((station_ids[0], station_ids[-1]))

            # Only threads of allowed routes may spawn a pseudo route in the reverse direction.
            if can_add_pseudo_route_by_thread_id[thread_id]:
                back_directions[(station_ids[-1], station_ids[0])].add(thread_id)

            status.step()

        thread_ids_by_pseudo_directions = {
            back_direction: thread_ids for back_direction, thread_ids in back_directions.iteritems()
            if back_direction not in directions
        }
        log.info(u'Нашли %s возможных направлений для обратных псевдорейсов.', len(thread_ids_by_pseudo_directions))

        return thread_ids_by_pseudo_directions

    def build_pseudo_routes(self, thread_ids_by_pseudo_directions):
        """
        Build and bulk-save one pseudo route per pseudo direction.

        Directions for which `construct_back_data_from_forward_threads` raises
        ConstructBackDataError are skipped.
        """
        log.info(u'Начинаем построение обратных псевдорейсов.')
        route_saver = BulkRouteSaver(log=log)
        thread_by_ids = self.prefetch_threads(thread_ids_by_pseudo_directions)

        status = PercentageStatus(len(thread_ids_by_pseudo_directions), log)
        pseudo_route_count = 0
        for pseudo_direction, thread_ids in thread_ids_by_pseudo_directions.iteritems():
            status.step()

            threads = [thread_by_ids[thread_id] for thread_id in thread_ids]
            try:
                mask, days, trip_minutes = self.construct_back_data_from_forward_threads(threads)
            except ConstructBackDataError:
                continue

            pseudo_data = {
                'min_trips_per_day': min(days.itervalues()),
                'max_trips_per_day': max(days.itervalues()),
                'min_trip_minutes': min(trip_minutes),
                'max_trip_minutes': max(trip_minutes),
                # Explicit floor division: the value is used as an integer minute offset (tz_arrival).
                'avg_trip_minutes': sum(trip_minutes) // len(trip_minutes),
            }
            self.build_pseudo_route(pseudo_direction, mask, pseudo_data, route_saver)
            pseudo_route_count += 1

        route_saver.load()
        log.info(u'Успешно построили %s обратных псевдорейсов.', pseudo_route_count)

    def build_pseudo_route(self, pseudo_direction, mask, pseudo_data, route_saver):
        """
        Create a pseudo route with one thread and two stops, and queue it in `route_saver`.

        :param pseudo_direction: ``(first_station_id, last_station_id)`` of the pseudo route.
        :param mask: run mask of the pseudo thread (stored as ``year_days``).
        :param pseudo_data: aggregated trip statistics, stored as JSON on the thread; its
                            ``avg_trip_minutes`` becomes the arrival offset at the last stop.
        :param route_saver: BulkRouteSaver collecting the objects to be saved.
        """
        first_station = Station.objects.get(id=pseudo_direction[0])
        last_station_id = pseudo_direction[1]
        # Both stops and the thread use the first station's time zone.
        timezone = first_station.get_tz_name()

        route = Route(t_type=self.t_type, supplier=self.supplier)

        thread = RThread(
            route=route,
            type=self.pseudo_thread_type,
            changed=True,
            pseudo_data=json.dumps(pseudo_data, ensure_ascii=False, indent=4, encoding='utf-8'),
            year_days=str(mask),
            time_zone=timezone,
            t_type=self.t_type,
            supplier=self.supplier,

            # fake field
            tz_start_time=time(0, 0),
        )

        first_stop = RTStation(thread=thread, station=first_station, time_zone=timezone,
                               tz_arrival=None, tz_departure=0)
        last_stop = RTStation(thread=thread, station_id=last_station_id, time_zone=timezone,
                              tz_arrival=pseudo_data['avg_trip_minutes'], tz_departure=None)

        route.threads = [thread]
        thread.rtstations = [first_stop, last_stop]

        route.route_uid = thread.gen_route_uid(use_stations=True)

        thread.ordinal_number = 0
        thread.gen_import_uid()
        thread.gen_uid()
        thread.gen_title()

        route.title = thread.title

        route_saver.save_route(route)
        route_saver.save_thread(thread)
        route_saver.save_rtstation(first_stop)
        route_saver.save_rtstation(last_stop)

    def construct_back_data_from_forward_threads(self, threads):
        """
        Aggregate the forward threads of one direction into data for its back pseudo route.

        :param threads: forward threads whose reversed direction is the pseudo direction.
        :return: ``(mask, days, trip_minutes)``:
                 ``mask`` is the union of the thread masks, each shifted by the arrival day shift
                 at its last stop; ``days`` maps each date of those shifted masks to the number of
                 threads running on it; ``trip_minutes`` lists each thread's trip duration.
        :raises ConstructBackDataError: if any thread's trip exceeds MAX_ALLOWED_MINUTES, or if
                 the threads yield no running dates at all.
        """
        mask = RunMask(today=self.today)
        days = defaultdict(int)
        trip_minutes = []
        for thread in threads:
            naive_start_dt = datetime.combine(self.today, thread.tz_start_time)
            # One query per thread for both the first and the last stop.
            thread_stops = list(thread.rtstation_set.all().order_by('id'))
            first_stop, last_stop = thread_stops[0], thread_stops[-1]
            departure_dt = first_stop.get_departure_dt(naive_start_dt)
            arrival_dt = last_stop.get_arrival_dt(naive_start_dt)
            thread_trip_minutes = int(timedelta2minutes(arrival_dt - departure_dt))

            if thread_trip_minutes > MAX_ALLOWED_MINUTES:
                # A single overly long forward trip disqualifies the whole direction.
                raise ConstructBackDataError()
            trip_minutes.append(thread_trip_minutes)

            # Compute the mask in the time zone of the last station for the arrival time;
            # that is the mask of the back pseudo route.
            thread_mask = thread.get_mask(today=self.today)
            shift = last_stop.calc_days_shift(event='arrival', start_date=self.today)
            if shift:
                thread_mask = thread_mask.shifted(shift)

            mask |= thread_mask
            for day in thread_mask.iter_dates():
                days[day] += 1

        if not days:
            raise ConstructBackDataError()

        return mask, days, trip_minutes

    def get_allowed_route_ids(self):
        """
        Return ids of routes allowed to spawn back pseudo routes.

        These are non-hidden routes of the affected transport type. Each must come from a
        two-stage import package or a red meta route that has ``allow_back_pseudo_routes`` set.
        The result is a lazy ``values_list`` queryset.
        """
        allowed_tsi_package_ids = list(
            TwoStageImportPackage.objects.filter(allow_back_pseudo_routes=True, t_type_id=AFFECTED_TRANSPORT_TYPE_ID)
                                         .values_list('id', flat=True)
        )
        allowed_red_metaroute_ids = list(
            MetaRoute.objects.filter(allow_back_pseudo_routes=True, t_type_id=AFFECTED_TRANSPORT_TYPE_ID)
                             .values_list('id', flat=True)
        )
        return Route.objects.filter(t_type_id=AFFECTED_TRANSPORT_TYPE_ID, hidden=False)\
                            .filter(Q(two_stage_package__in=allowed_tsi_package_ids) |
                                    Q(red_metaroute__in=allowed_red_metaroute_ids))\
                            .order_by().values_list('id', flat=True)

    def prefetch_threads(self, thread_ids_by_pseudo_directions):
        """Load every thread referenced by the pseudo directions in one query; return {thread id: RThread}."""
        thread_ids = set()
        for ids in thread_ids_by_pseudo_directions.itervalues():
            thread_ids |= ids

        return {thread.id: thread for thread in RThread.objects.filter(id__in=thread_ids).order_by()}


if __name__ == '__main__':
    with ylog_context(**get_script_log_context()):
        arg_parser = argparse.ArgumentParser()
        arg_parser.add_argument("-v", "--verbose", action="store_true", help="increase output verbosity")
        cli_args = arg_parser.parse_args()

        # Mirror the module logger to stdout only when requested.
        if cli_args.verbose:
            print_log_to_stdout(log)

        # Presumably sets up the run log for this script file; it is called before the rebuild starts.
        create_current_file_run_log()

        builder = PseudoRouteBuilder()
        builder.run()
