# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals

import pytz
from itertools import groupby
from datetime import datetime

from common.models.geo import Station
from common.models.schedule import Route, RThread, RTStation, RThreadType
from common.models.transport import TransportType
from travel.rasp.library.python.common23.date import environment
from common.utils.date import RunMask
from mapping.views.paths import draw_path, walk_segments
from common.models_admin.geo import RegionWhiteList

from travel.rasp.rasp_scripts.scripts.long_haul.export import base, validation
from travel.rasp.rasp_scripts.scripts.long_haul.export.formatters import get_station_title
from travel.rasp.rasp_scripts.scripts.long_haul.export.utils import is_valid_travel_times


# Module-level state shared between the generators below: each generator
# reads what an earlier one collected (ThreadsGenerator reads route_id_list,
# ThreadStopsGenerator reads thread_id_list, StopsGenerator reads
# station_id_set, ...).
# NOTE(review): this relies on the generators being run in that order by the
# export driver — confirm against the caller.
route_id_list = set()  # ids of routes yielded by RoutesGenerator (a set despite the name)
thread_id_list = set()  # ids of threads yielded by ThreadsGenerator (a set despite the name)
station_id_set = set()  # ids of stations of the stops yielded by ThreadStopsGenerator
white_region_ids = list()  # cache filled lazily by get_white_region_ids()


# https://st.yandex-team.ru/RASPEXPORT-189
# Companies whose threads are excluded from the export.
SKIP_COMPANIES = [
    59181,  # MKZhD (Moscow ring railway)
    59942,  # Moscow Metro
]


def add_prefix(str, prefix='lh_', sep='__'):
    u"""Insert *prefix* right after the first *sep* of a ``<type><sep><id>`` string.

    Example: ``add_prefix('station__123')`` returns ``'station__lh_123'``.

    The first parameter keeps its historical name ``str`` (which shadows the
    builtin) so that existing keyword callers keep working.

    Raises ValueError when *sep* does not occur in the string.
    """
    # maxsplit=1: only the first separator delimits the type; everything after
    # it, including further separators, belongs to the id part.
    s_type, ident = str.split(sep, 1)
    return s_type + sep + prefix + ident


class RoutesGenerator(base.DataGenerator):
    u"""Route generator.

    Yields non-hidden routes selected by their threads' transport types
    (suburban required, urban excluded) for which at least one visible,
    non-cancelled thread outside SKIP_COMPANIES stops at a station in a
    whitelisted region.  Each yielded route gets an ``accepted_thread``
    attribute (the first such thread) and its id is added to the module-level
    ``route_id_list``.
    """

    @classmethod
    def insert_into_validation_table(cls, obj, conn):
        validation.insert_route(obj, conn)

    def generate(self, log):
        # The whitelist does not depend on the route or thread, so look it up
        # once instead of once per thread.
        white_ids = get_white_region_ids()

        routes = (Route.objects
                       .exclude(rthread__t_type=TransportType.URBAN_ID)
                       .filter(hidden=False)
                       .filter(rthread__t_type=TransportType.SUBURBAN_ID)
                       .distinct())

        for route in routes:
            threads = (route.rthread_set
                            .filter(hidden=False)
                            .exclude(company_id__in=SKIP_COMPANIES)
                            .exclude(type=RThreadType.CANCEL_ID))

            for thread in threads:
                # Only existence matters here; EXISTS is cheaper than COUNT(*).
                if thread.rtstation_set.filter(station__region_id__in=white_ids).exists():
                    route_id_list.add(route.id)
                    route.accepted_thread = thread
                    yield route
                    # One accepted thread per route is enough.
                    break


def get_white_region_ids():
    u"""Return the ids of the whitelisted regions, loading them on first use.

    The ids come from the first row returned by ``RegionWhiteList.objects.all()``
    and are memoised in the module-level ``white_region_ids``.  An empty result
    is not memoised (it is re-queried on the next call), and an IndexError
    propagates when there is no ``RegionWhiteList`` row at all.
    """
    global white_region_ids

    if white_region_ids:
        return white_region_ids

    whitelist = RegionWhiteList.objects.all()[0]
    white_region_ids = [region.id for region in whitelist.regions.all()]
    return white_region_ids


class ThreadsGenerator(base.DataGenerator):
    u"""Thread generator.

    Yields the non-hidden, non-cancelled threads (outside SKIP_COMPANIES) of
    the routes collected in ``route_id_list`` that have a non-empty
    ``RunMask(...).dates(past=False)`` relative to the thread-local today and
    whose stops pass ``is_valid_travel_times``.  Yielded thread ids are added
    to the module-level ``thread_id_list``.
    """

    @classmethod
    def insert_into_validation_table(cls, obj, conn):
        validation.insert_thread(obj, conn)

    def get_threads(self):
        queryset = RThread.objects.filter(route_id__in=route_id_list)
        queryset = queryset.filter(hidden=False)
        queryset = queryset.exclude(company_id__in=SKIP_COMPANIES)
        return queryset.exclude(type=RThreadType.CANCEL_ID)

    def generate(self, log):
        u"""Yield the threads to export (see the class docstring)."""
        now_aware = environment.now_aware()

        for thread in self.get_threads():
            local_today = now_aware.astimezone(thread.pytz).date()

            # Running days, not including past ones.
            running_days = RunMask(thread.year_days, today=local_today).dates(past=False)
            if not running_days:
                continue

            if not is_valid_travel_times(thread.rtstation_set.all()):
                continue

            thread_id_list.add(thread.id)
            yield thread


class ThreadStopsGenerator(base.DataGenerator, base.StationMixin):
    u"""Generator of links between threads and stops.

    For every thread in ``thread_id_list`` yields its ``RTStation`` rows,
    annotated with:

    * ``arrival`` / ``departure`` -- offsets in whole seconds from the start of
      the thread's first run (``None`` when the station has no such time);
    * ``index`` -- position among the *yielded* stops of the thread;
    * ``flags`` -- a ``Flags`` instance.

    Stops that have both an arrival and a departure time are skipped when the
    dwell time is not positive or the stop is a technical one.  Station ids of
    the yielded rows are added to the module-level ``station_id_set``.
    """

    @classmethod
    def insert_into_validation_table(cls, obj, conn):
        validation.insert_thread_stop(obj, conn)

    class Flags:
        u"""Stop flags; ``str()`` renders the set ones, e.g. ``'nb,imp'``."""

        def __init__(self):
            self.no_boarding = 0
            self.important = 0
            self.no_drop_off = 0

        def __str__(self):
            flags = []
            if self.no_boarding:
                flags.append('nb')

            if self.important:
                flags.append('imp')

            if self.no_drop_off:
                flags.append('nd')

            return ','.join(flags)

    @staticmethod
    def _offset_seconds(dt_aware, start_dt_aware):
        u"""Whole seconds from *start_dt_aware* to *dt_aware*; None if *dt_aware* is falsy."""
        if not dt_aware:
            return None
        return int((dt_aware - start_dt_aware).total_seconds())

    def generate(self, log):
        u"""Yield annotated ``RTStation`` rows (see the class docstring)."""
        rtstations = (RTStation.objects
                      .filter(thread__id__in=thread_id_list)
                      .select_related('thread', 'station')
                      # groupby() below only merges *adjacent* rows, so rows
                      # must be clustered by thread.  Ordering by id alone does
                      # not guarantee that (a thread's stops need not have
                      # consecutive ids) and would split one thread into several
                      # groups with wrong first/last flags and indexes.
                      .order_by('thread__id', 'id'))

        now_aware = environment.now_aware()

        for thread, rts_iter in groupby(rtstations, lambda rts: rts.thread):
            thread_tz_today = now_aware.astimezone(thread.pytz).date()
            start_date = thread.first_run(thread_tz_today)

            thread_start_dt_naive = datetime.combine(start_date, thread.tz_start_time)
            thread_tz_start_dt = thread.pytz.localize(thread_start_dt_naive)

            rts = list(rts_iter)
            last_rts = rts[-1]
            index = 0

            for rtstation in rts:
                rtstation.arrival = self._offset_seconds(
                    rtstation.get_arrival_dt(thread_start_dt_naive), thread_tz_start_dt)
                rtstation.departure = self._offset_seconds(
                    rtstation.get_departure_dt(thread_start_dt_naive), thread_tz_start_dt)

                # Skip thread stations with a zero (or negative) stop time and
                # technical stops.
                if rtstation.arrival is not None and rtstation.departure is not None:
                    if rtstation.departure - rtstation.arrival <= 0 or rtstation.is_technical_stop:
                        continue

                flags = self.Flags()

                if index == 0:
                    # First yielded stop of the thread.
                    flags.no_drop_off = 1
                    flags.important = 1

                if rtstation == last_rts:
                    flags.no_boarding = 1
                    flags.important = 1

                if rtstation.is_technical_stop:
                    flags.no_drop_off = 1
                    flags.no_boarding = 1

                rtstation.index = index
                rtstation.flags = flags

                station_id_set.add(rtstation.station.id)
                yield rtstation

                index += 1


class StopsGenerator(base.DataGenerator, base.StationMixin):
    u"""Stop generator: yields the stations collected in ``station_id_set``."""

    @classmethod
    def insert_into_validation_table(cls, obj, conn):
        validation.insert_stop(obj, conn)

    def generate(self, log):
        u"""Yield ``Station`` objects (with ``t_type`` joined in) for ``station_id_set``."""
        queryset = Station.objects.filter(id__in=station_id_set)
        for station in queryset.select_related('t_type'):
            yield station


class TimetableGenerator(base.DataGenerator):
    u"""
    Generator of departure times for routes that run on individual schedules.
    @see https://jira.yandex-team.ru/browse/RASP-5923
    @see http://wiki.yandex-team.ru/JandeksKarty/development/fordevelopers/masstransit/DataFormat#timetable
    """

    @classmethod
    def insert_into_validation_table(cls, obj, conn):
        validation.insert_timetable(obj, conn)

    def generate(self, log):
        u"""Yield ``(thread.uid, 'HH:MM:SS')`` pairs with each thread's start time in UTC.

        Only threads from ``thread_id_list`` with a non-null ``tz_start_time``
        are used.  The start time is localized on the thread-local current
        date, so the UTC value reflects that day's UTC offset of the thread's
        time zone.
        """
        threads = RThread.objects.filter(id__in=thread_id_list).filter(tz_start_time__isnull=False)

        # Take "now" once so every thread is evaluated against the same
        # instant, as the other generators in this module do.
        now_aware = environment.now_aware()

        for thread in threads:
            thread_tz_today = now_aware.astimezone(thread.pytz).date()

            start_dt_aware = thread.pytz.localize(datetime.combine(thread_tz_today, thread.tz_start_time))

            utc_dt = start_dt_aware.astimezone(pytz.utc)

            yield thread.uid, utc_dt.strftime('%H:%M:%S')


class CalendarGenerator(base.DataGenerator):
    u"""Dates generator: yields every thread whose id is in ``thread_id_list``."""

    @classmethod
    def insert_into_validation_table(cls, obj, conn):
        validation.insert_calendar(obj, conn)

    def generate(self, log):
        u"""Yield the ``RThread`` objects selected by ``thread_id_list``."""
        selected_threads = RThread.objects.filter(id__in=thread_id_list)
        for selected in selected_threads:
            yield selected


class GeometryGenerator(base.DataGenerator):
    u"""Geometry generator for exported threads.

    For each thread yields ``(thread.uid, index, coords[1], coords[0], thread.id)``
    rows, one per coordinate pair produced by ``walk_segments`` over the path
    drawn by ``draw_path``.  Pairs that are None or contain None are skipped,
    and *index* counts only the emitted rows of a thread, starting at 0.
    """

    @classmethod
    def insert_into_validation_table(cls, obj, conn):
        validation.insert_geometry(obj, conn)

    def generate(self, log):
        u"""Yield geometry rows (see the class docstring)."""
        threads = (RThread.objects
                   .filter(route__rthread__t_type=TransportType.SUBURBAN_ID,
                           id__in=thread_id_list,
                           route__id__in=route_id_list)
                   .select_related('route')
                   # route__rthread is a multi-valued join: without distinct()
                   # a thread comes back once per suburban thread of its route,
                   # and its geometry would be emitted that many times.
                   .distinct())

        now_aware = environment.now_aware()

        for thread in threads:
            thread_tz_today = now_aware.astimezone(thread.pytz).date()
            start_date = thread.first_run(thread_tz_today)
            thread_start_dt = datetime.combine(start_date, thread.tz_start_time)

            route_map = draw_path(thread, thread_start_dt, list(thread.path.select_related('station')))
            if not route_map or not route_map.path_arcs:
                continue

            # Drop coordinate pairs that are missing or incomplete.
            points = (coords for coords, _level in walk_segments(route_map.path_arcs)
                      if coords is not None and None not in coords)

            for index, coords in enumerate(points):
                yield (
                    thread.uid,
                    index,
                    coords[1],
                    coords[0],
                    thread.id,
                )


def route_number(route):
    u"""Return the ``number`` of one thread of *route*, or ``''`` if it has no threads.

    The query has no explicit ``order_by`` here, so which thread is used
    depends on the model's default ordering (if any) / the database.
    """
    for thread in RThread.objects.filter(route=route)[:1]:
        return thread.number
    return ''


class L10nGenerator(base.DataGenerator):
    u"""Localisation generator for exported stops and routes.

    Yields ``(entity, uid, 'ru', form, text)`` rows where *form* is
    ``'nominative'`` or ``''``; every title is emitted once for each form.
    Stops use their station title, routes use ``route_number(route)``.
    """

    @classmethod
    def insert_into_validation_table(cls, obj, conn):
        pass

    def generate(self, log):
        u"""Yield localisation rows (see the class docstring)."""
        stations = Station.objects.filter(id__in=station_id_set)
        routes = Route.objects.filter(id__in=route_id_list)

        for station in stations:
            title = get_station_title(station)
            stop_uid = add_prefix(station.export_uid)
            for form in ('nominative', ''):
                yield ('stop', stop_uid, 'ru', form, title)

        for route in routes:
            number = route_number(route)
            prefixed_uid = u'lh_' + route.route_uid
            for form in ('nominative', ''):
                yield ('route', prefixed_uid, 'ru', form, number)
