# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals

from django.db.models import Q

from common.models_utils import fetch_related
from common.models.schedule import Supplier, RTStation, TrainSchedulePlan
from common.models_utils.i18n import RouteLTitle
from common.utils.date import calculate_run_days
from common.models.transport import TransportType
from travel.rasp.library.python.common23.date import environment
from route_search.models import IntervalRThreadSegment
from route_search.shortcuts import search_routes

from travel.rasp.api_public.api_public.v3.search.base_search import BaseSearch
from travel.rasp.api_public.api_public.v3.search.helpers import make_segment_mask_schedule
from travel.rasp.api_public.api_public.v3.tariffs.shortcuts import add_tariffs


class BaseRaspDbSearch(BaseSearch):
    """Base search over the schedule (rasp) database.

    Holds the setup shared by the concrete searches: the transport types to
    search with (the plane type is always dropped) and a thread filter that
    excludes suppliers flagged with ``exclude_from_external_api``.
    """
    def __init__(self, query, request, query_points):
        # NOTE(review): BaseSearch is handed (request, query, ...) while this
        # class accepts (query, request, ...); confirm against BaseSearch.__init__.
        super(BaseRaspDbSearch, self).__init__(request, query, query_points)
        self.client_city = request.client_city

        # Fall back to every cached transport type when the query names none.
        requested_types = query["transport_types"] or TransportType.objects.all_cached()
        self.t_types = [
            t_type for t_type in requested_types
            if t_type.id != TransportType.PLANE_ID
        ]

        excluded_supplier_ids = Supplier.objects.filter(
            exclude_from_external_api=True
        ).values_list("id", flat=True)
        # Reject threads whose own supplier, or whose route's supplier, is excluded.
        own_supplier_allowed = ~Q(supplier_id__in=excluded_supplier_ids)
        route_supplier_allowed = ~Q(route__supplier_id__in=excluded_supplier_ids)
        self.threads_filter = own_supplier_allowed & route_supplier_allowed

    @staticmethod
    def _set_train_schedule_plan(segments):
        """Attach the next suburban-train schedule plan to every segment.

        ``TrainSchedulePlan.add_to_threads`` receives the threads of all
        segments and today's date. Only the second value it returns (the next
        plan) is stored, as ``segment.next_plan``.
        """
        threads = [segment.thread for segment in segments]
        _current_plan, upcoming_plan = TrainSchedulePlan.add_to_threads(threads, environment.today())
        for segment in segments:
            segment.next_plan = upcoming_plan

    def _set_segment_schedule(self, segments):
        """Fill ``segment.schedule`` with a run-days mask when ``self.add_days_mask`` is set."""
        if not self.add_days_mask:
            return
        for segment in segments:
            days = calculate_run_days(segment, days_ago=30, result_timezone=self.result_pytz)
            segment.schedule = make_segment_mask_schedule(days)

    def add_extra_data(self, segments):
        """Enrich found segments with additional data after the search.

        The steps run in this order:
        1. Attach the train schedule plan.
        2. Call ``fetch_related`` for the segments' ``rtstation_from`` (station, thread).
        3. Call ``RouteLTitle.fetch`` for the threads' L-titles.
        4. Add the optional run-days schedule.
        5. Register each segment's used points.
        """
        self._set_train_schedule_plan(segments)

        rtstations_from = [seg.rtstation_from for seg in segments]
        fetch_related(rtstations_from, "station", "thread", model=RTStation)
        RouteLTitle.fetch([seg.thread.L_title for seg in segments])

        self._set_segment_schedule(segments)

        for seg in segments:
            self.add_segment_used_points(seg)

    def set_segments(self, segments):
        """Split segments into interval and regular ones.

        ``IntervalRThreadSegment`` instances are appended to
        ``self.interval_segments``; every other segment goes to ``self.segments``.
        """
        for segment in segments:
            is_interval = isinstance(segment, IntervalRThreadSegment)
            target = self.interval_segments if is_interval else self.segments
            target.append(segment)


class OneDayRaspDbSearch(BaseRaspDbSearch):
    """Schedule database search for a single departure date."""
    def __init__(self, query, request, query_points, currency_info):
        super(OneDayRaspDbSearch, self).__init__(query, request, query_points)
        # Forwarded to add_tariffs when the found segments are priced.
        self.currency_info = currency_info

    def search(self):
        """Search segments for ``self.departure_date``.

        Any segments found are enriched, get tariffs added, and are then split
        into regular and interval segments. Nothing is stored when the search
        returns no segments.
        """
        found, _nears, _service_types = search_routes(
            point_from=self.point_from,
            point_to=self.point_to,
            departure_date=self.departure_date,
            transport_types=self.t_types,
            threads_filter=self.threads_filter,
        )
        if not found:
            return

        self.add_extra_data(found)
        add_tariffs(
            found,
            self.t_types,
            self.point_from,
            self.point_to,
            self.departure_date,
            self.currency_info,
            self.client_city,
            self.national_version,
        )
        self.set_segments(found)


class AllDaysRaspDbSearch(BaseRaspDbSearch):
    """Schedule database search across all days (no departure date is passed)."""
    def __init__(self, query, request, query_points):
        super(AllDaysRaspDbSearch, self).__init__(query, request, query_points)

    def search(self):
        """Search segments without a departure date.

        Any segments found are enriched and split into regular and interval
        segments. Unlike the one-day search, no tariffs are added.
        """
        route_search_result = search_routes(
            point_from=self.point_from,
            point_to=self.point_to,
            transport_types=self.t_types,
            threads_filter=self.threads_filter,
        )
        found, _nears, _service_types = route_search_result
        if not found:
            return

        self.add_extra_data(found)
        self.set_segments(found)
