# coding: utf-8

import logging
import os.path
from datetime import timedelta, datetime
from operator import itemgetter

from lxml import etree

from common.models.schedule import RThread
from common.utils.date import daterange
from cysix.two_stage_import import CysixMaskParser
from cysix.two_stage_import.factory import CysixTSIFactory
from cysix.two_stage_import.route_importer import CysixTSIRouteImporter
from travel.rasp.library.python.common23.date import environment
from common.utils.caching import cache_method_result_with_exception, cache_method_result
from travel.rasp.admin.lib.schedule import fast_get_threads_with_rtstations
from travel.rasp.admin.lib.xmlutils import get_sub_tag_text
from travel.rasp.admin.scripts.schedule.utils.olven import get_olven_file
from travel.rasp.admin.scripts.schedule.utils import RaspImportError
from travel.rasp.admin.scripts.schedule.utils.mask_builders import MaskBuilder
from travel.rasp.admin.scripts.schedule.utils.file_providers import PackageFileProvider
from travel.rasp.admin.scripts.schedule.utils.to_python_parsers import get_date, get_time, ParseError
from common.utils.date import timedelta2minutes, RunMask


# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Number of days after today used as the upper bound of the converter's MaskBuilder.
MAX_FORWARD_DAYS = 20
# When two threads leave the shared station this many minutes apart (or more),
# they are not treated as duplicates of each other.
MAX_DUPLICATE_DIFF_IN_MINUTES = 3


class OlvenCysixFactory(CysixTSIFactory):
    """Import factory that plugs in the Olven file provider and route importer."""

    def get_download_file_provider(self):
        """Log the start of the download and return a provider bound to the package."""
        log.info(u"Качаем данные")

        provider = OlvenHTTP2CysixFileProvider(self.package)
        return provider

    def get_route_importer(self):
        """Return a CysixOlvenTSIRouteImporter constructed with this factory."""
        importer = CysixOlvenTSIRouteImporter(self)
        return importer


class CysixOlvenTSIRouteImporter(CysixTSIRouteImporter):
    def after_import(self):
        """Remove duplicates unless the package's new_trusted flag is set, then call the base hook."""
        package = self.factory.package
        if not package.new_trusted:
            self.remove_duplicates()

        super(CysixOlvenTSIRouteImporter, self).after_import()

    def remove_duplicates(self):
        """Compare every pair of threads of the affected, unprotected routes.

        Threads are ordered by the integer prefix of hidden_number (the part
        before 'x'); each pair is passed to clean_excess_year_days, and
        afterwards clean_empty_threads is called on the whole list.
        """
        log.info(u"Удаляем дубликаты")

        unprotected_routes = self.get_affected_routes().filter(script_protected=False)
        routes_ids = list(unprotected_routes.values_list('id', flat=True))

        threads_qs = RThread.objects.filter(route_id__in=routes_ids)
        threads = fast_get_threads_with_rtstations(threads_qs, fetch_stations=True)

        def route_number(thread):
            return int(thread.hidden_number.split('x')[0])

        threads.sort(key=route_number)

        # Every unordered pair exactly once: each thread against all later ones.
        for position, first in enumerate(threads):
            for second in threads[position + 1:]:
                log.debug(u"Ищем дубликаты в %s %s", first.uid, second.uid)

                self.clean_excess_year_days(first, second)

        self.clean_empty_threads(threads)

    def clean_excess_year_days(self, thread1, thread2):
        """Strip from the shorter thread the run days already covered by the longer one.

        Nothing is changed when both threads have the same number of stations,
        when is_sub_path() rejects the shorter path, or when the departures
        from the shared station differ by MAX_DUPLICATE_DIFF_IN_MINUTES or more.
        Otherwise the shorter thread's year_days is rewritten and saved.

        :param thread1: thread object exposing rtstations, year_days, tz_start_time, pytz, uid
        :param thread2: second thread object of the same kind
        """
        path1 = [rts.station_id for rts in thread1.rtstations]
        path2 = [rts.station_id for rts in thread2.rtstations]

        # Decide which thread is the shorter one; equal lengths are never compared.
        if len(path1) < len(path2):
            short_thread = thread1
            short_path = path1

            long_thread = thread2
            long_path = path2

        elif len(path1) > len(path2):
            short_thread = thread2
            short_path = path2

            long_thread = thread1
            long_path = path1

        else:
            return

        if not self.is_sub_path(short_path, long_path):
            log.debug(u"%s не является подпутем %s",
                      short_thread.uid, long_thread.uid)
            return

        log.info(u"%s является подпутем %s проверяем подходят ли времена старта",
                 short_thread.uid, long_thread.uid)

        # Station of the long thread where the short thread's path begins.
        subpath_start_index = self.get_subpath_start_index(short_path, long_path)
        subpath_start_rts = long_thread.rtstations[subpath_start_index]

        today = environment.today()

        # Naive start of the long thread as if it ran today.
        naive_long_start_dt = datetime.combine(today, long_thread.tz_start_time)

        # Departure of the long thread from the shared station; astimezone() is
        # called on the result below, so it is expected to be timezone-aware.
        match_long_start_dt = subpath_start_rts.get_departure_dt(naive_long_start_dt)

        # Find the start date of the short thread closest to that departure - short_start_date
        try_short_start_date = match_long_start_dt.astimezone(short_thread.pytz).date()
        try_short_start_dates = daterange(try_short_start_date - timedelta(1),
                                          try_short_start_date + timedelta(1), include_end=True)
        diff_by_short_start_date = {}
        for start_date in try_short_start_dates:
            short_start_dt = short_thread.pytz.localize(datetime.combine(start_date, short_thread.tz_start_time))
            diff_by_short_start_date[start_date] = abs(timedelta2minutes(short_start_dt - match_long_start_dt))

        # Candidate date with the smallest absolute difference in minutes.
        short_start_date, common_station_departure_diff = min(
            ((start_date, diff) for start_date, diff in diff_by_short_start_date.iteritems()), key=itemgetter(1)
        )

        if common_station_departure_diff >= MAX_DUPLICATE_DIFF_IN_MINUTES:
            log.info(u"Слишком большая разница в отправлении с общей станции не чистим дни хождения")

            return

        # Shift the long thread's mask by the day offset between the chosen short
        # start date and today, so both masks are expressed in short-thread start dates.
        # NOTE(review): exact RunMask.shifted/difference semantics are defined elsewhere - confirm.
        day_shift = (short_start_date - today).days
        long_mask_shifted = RunMask(long_thread.year_days, today=today).shifted(day_shift)

        short_mask = RunMask(short_thread.year_days, today=today)
        short_mask = short_mask.difference(long_mask_shifted)

        short_thread.year_days = str(short_mask)
        short_thread.save()

        log.info(u"Удаляем из нитки %s дни хождения т.к. ее маршрут содержится в %s",
                 short_thread.uid, long_thread.uid)

    def is_sub_path(self, short_path, long_path):
        start_index = self.get_subpath_start_index(short_path, long_path)
        if not start_index:
            return False

        subpath = long_path[start_index:start_index + len(short_path)]

        if len(subpath) != len(short_path):
            return False

        for s1, s2 in zip(short_path, subpath):
            if s1 != s2:
                return False

        return True

    def get_subpath_start_index(self, short_path, long_path):
        """Return the index in long_path of short_path's first station.

        :param short_path: sequence of station ids of the shorter thread
        :param long_path: list of station ids of the longer thread
        :return: index of the first occurrence, or None when the station is
            absent from long_path or short_path is empty
        """
        if not short_path:
            return None

        try:
            return long_path.index(short_path[0])
        except ValueError:
            return None

    def clean_empty_threads(self, threads):
        """Delete every thread whose RunMask built from year_days is falsy."""
        for candidate in threads:
            has_run_days = bool(RunMask(candidate.year_days))
            if has_run_days:
                continue

            log.info(u"Удалили нитку %s %s", candidate.uid, candidate.title)
            candidate.delete()


class OlvenHTTP2CysixFileProvider(PackageFileProvider):
    """File provider that produces the package's cysix.xml through Converter."""

    @cache_method_result_with_exception
    def get_cysix_file(self):
        """Return the path of cysix.xml, running the conversion when the file is absent."""
        target = self.get_package_filepath('cysix.xml')

        if not os.path.exists(target):
            log.info(u"Конвертируем в общий xml %s", target)

            self.dowload_and_convert_data(target)

            log.info(u"Данные сконвертированы в общий xml %s", target)
        else:
            log.info(u"Данные уже были сконвертированы в общий xml %s", target)

        return target

    def dowload_and_convert_data(self, filepath):
        """Run a fresh Converter for this package and return its result."""
        return Converter().convert(filepath, self.package)


class Converter(object):
    def convert(self, filepath, package):
        """Build the common XML for the package and write it to filepath.

        Loads routes, cities, route stations and timetables into the cache_*
        attributes, parses all threads, and writes a <channel> document with a
        single 'default' <group>. The sections stations, vehicles, carriers,
        threads and fares are added in that order and only when non-empty.

        :param filepath: path of the XML file to write
        :param package: package object handed to get_olven_file by the loaders
        """
        self.package = package

        # Filled as a side effect of parse_threads().
        self.stations = []
        self.carriers = []
        self.vehicles = []
        self.fares = []

        self.cache_routes = self.get_routes()
        self.cache_cities = self.get_cities()
        self.cache_cities_on_route = self.get_cities_on_route()
        self.cache_timetables = self.get_timetables()

        self.today = environment.today()

        self.mask_builder = MaskBuilder(self.today - timedelta(1), self.today + timedelta(MAX_FORWARD_DAYS))

        channel = etree.Element('channel', {
            'version': '1.0',
            't_type': 'bus',
            'station_code_system': 'vendor',
            'carrier_code_system': 'local',
            'vehicle_code_system': 'local',
            'timezone': 'start_station'
        })

        group = etree.SubElement(channel, 'group', {
            'code': 'default'
        })

        threads = self.parse_threads()

        # Must be built after parse_threads(): the self.* lists are populated there.
        sections = (
            ('stations', self.stations),
            ('vehicles', self.vehicles),
            ('carriers', self.carriers),
            ('threads', threads),
            ('fares', self.fares),
        )
        for tag, elements in sections:
            if elements:
                etree.SubElement(group, tag).extend(elements)

        # Serialize before touching the target so a serialization failure does
        # not leave an empty file behind; tostring() with an encoding yields
        # bytes, hence the binary mode.
        xml_bytes = etree.tostring(channel, encoding='utf8', xml_declaration=True, pretty_print=True)

        with open(filepath, 'wb') as f:
            f.write(xml_bytes)

    @cache_method_result
    def get_routes(self):
        """Return a dict mapping the Olven number to its <route> element from the 'routes' file.

        When a number occurs more than once, the last element wins.
        """
        source = get_olven_file(self.package, 'routes')
        route_elements = etree.parse(source).findall('.//route')

        return {get_olven_number(route_el): route_el for route_el in route_elements}

    @cache_method_result
    def get_cities(self):
        """Return a dict mapping the stripped city_id to the stripped name from the 'cities' file."""
        source = get_olven_file(self.package, 'cities')

        title_by_code = {}
        for city_el in etree.parse(source).findall('.//city'):
            city_title = get_sub_tag_text(city_el, 'name').strip()
            city_code = get_sub_tag_text(city_el, 'city_id').strip()

            title_by_code[city_code] = city_title

        return title_by_code

    @cache_method_result
    def get_cities_on_route(self):
        """Return the ordered stations of every route from the 'cities_on_routes' file.

        :return: dict mapping the Olven number to a list of (order, city_id)
            tuples sorted by order; routes with fewer than two stations are
            logged and left out
        """
        source = get_olven_file(self.package, 'cities_on_routes')

        stations_by_number = {}
        for station_el in etree.parse(source).findall('.//station'):
            try:
                number = get_olven_number(station_el)
                olven_id = station_el.find('city_id').text
                olven_order = int(station_el.find('city_srt').text)
            except (AttributeError, TypeError, ValueError):
                # AttributeError/TypeError: a tag or its text is missing;
                # ValueError: the ordinal is present but not an integer.
                log.exception(u"Нехватает необходимых тегов в файле списка станций," +
                              u" или не правильно задан порядковый номер станции")
                continue

            stations_by_number.setdefault(number, []).append((olven_order, olven_id))

        usable = {}
        for number, stations in stations_by_number.iteritems():
            if len(stations) < 2:
                log.error(u"Для маршрута '%s' не хватает станций",
                          number)
                continue

            # Sort by the ordinal only, keeping file order for equal ordinals.
            stations.sort(key=itemgetter(0))
            usable[number] = stations

        return usable

    @cache_method_result
    def get_timetables(self):
        """Return a dict mapping the Olven number to the list of its <run> elements from the 'timetable' file."""
        source = get_olven_file(self.package, 'timetable')

        runs_by_number = {}
        for run_el in etree.parse(source).findall('.//run'):
            key = get_olven_number(run_el)
            if key not in runs_by_number:
                runs_by_number[key] = []
            runs_by_number[key].append(run_el)

        return runs_by_number

    def parse_threads(self):
        """Parse every cached route and return all thread elements collected.

        A route that fails is logged and skipped; the remaining routes are
        still processed.
        """
        collected = []

        for route_number in self.cache_routes:
            try:
                route_threads = self.parse_route(route_number)
            except RaspImportError as e:
                log.error(u"Ошибка при разборе %s: %s", route_number, unicode(e))
            except Exception:
                log.exception(u"Неожиданная ошибка при разборе %s", route_number)
            else:
                collected.extend(route_threads)

        return collected

    def parse_route(self, number):
        """Return the thread elements built from the timetable runs of one route.

        Runs whose run_status is one of the cancelled values are skipped; a run
        that fails to parse is logged and skipped. Missing route data raises
        RaspImportError from the get_* lookups.
        """
        route_el = self.get_route_el(number)
        route_stations = self.get_route_stations(number)
        timetable_els = self.get_timetable_els(number)

        cancelled_statuses = (u"запрет", u"отмена")
        parsed = []

        for run_el in timetable_els:
            status = get_sub_tag_text(run_el, 'run_status').strip().lower()
            if status in cancelled_statuses:
                log.info(u"Пропускаем отмененную нитку %s", number)
                continue

            try:
                thread_el = self.parse_thread(number, route_el, route_stations, run_el)
            except RaspImportError as e:
                log.error(u"Ошибка при разборе нитки %s: %s", number, unicode(e))
            except Exception:
                log.exception(u"Неожиданная ошибка при разборе нитки %s", number)
            else:
                parsed.append(thread_el)

        return parsed

    def parse_thread(self, number, route_el, route_stations, timetable_el):
        """Build a <thread> element for one timetable run.

        The element gets the route number, stop points (plus a fare when
        prices are known), one schedule carrying the departure time, and the
        carrier/vehicle codes when the run names them.
        """
        schedule_el = self.make_schedule(timetable_el)
        departure_time = get_sub_tag_text(timetable_el, 'departure_time').strip()
        arrival_time = get_sub_tag_text(timetable_el, 'arrival_time').strip()

        # Take per-stop details from the first run date that has any.
        stops_info = None
        for run_date in self.get_mask(schedule_el).iter_dates():
            stops_info = self.get_stops_info(number, run_date, departure_time)
            if stops_info:
                break

        thread_el = etree.Element('thread')
        thread_el.set('number', number)

        self.add_stops_and_fares(thread_el, route_stations, stops_info, departure_time, arrival_time)

        etree.SubElement(thread_el, 'schedules').append(schedule_el)
        schedule_el.set('times', departure_time)

        self.add_carrier(thread_el, timetable_el)
        self.add_vehicle(thread_el, timetable_el)

        return thread_el

    def get_mask(self, schedule_el):
        """Return the mask CysixMaskParser builds for schedule_el with this converter's mask builder."""
        mask = CysixMaskParser.parse_mask(schedule_el, self.mask_builder)
        return mask

    def make_schedule(self, timetable_el):
        """Build a <schedule> element from a timetable run.

        The period comes from begin_date/end_date. The 'days' attribute is
        run_days without commas when given; otherwise it is derived from
        run_period_id: 1 - odd days, 2 - even days, 10 - daily, 3..9 - every
        (id - 2) days.

        :param timetable_el: <run> element of the timetable file
        :return: the new schedule element
        :raises RaspImportError: when run_days is empty and run_period_id is
            not one of the known values
        """
        schedule_el = etree.Element('schedule')

        schedule_el.attrib['period_start_date'] = get_sub_tag_text(timetable_el, 'begin_date').strip()
        schedule_el.attrib['period_end_date'] = get_sub_tag_text(timetable_el, 'end_date').strip()
        run_days = get_sub_tag_text(timetable_el, 'run_days').strip()
        run_period_id = get_sub_tag_text(timetable_el, 'run_period_id').strip()

        fixed_days = {
            u"1": u"нечетные",
            u"2": u"четные",
            u"10": u"ежедневно",
        }
        # Explicit ids rather than a substring test against u"3456789": the
        # substring test also accepts u"" and multi-digit values such as u"34".
        interval_ids = (u"3", u"4", u"5", u"6", u"7", u"8", u"9")

        if run_days:
            days = run_days.replace(u",", u"")
        elif run_period_id in fixed_days:
            days = fixed_days[run_period_id]
        elif run_period_id in interval_ids:
            days = u"через {}".format(int(run_period_id) - 2)
        else:
            raise RaspImportError(
                u"Ошибка преобразования маски run_period_id={run_period_id} run_days='{run_days}'"
                .format(run_period_id=run_period_id, run_days=run_days)
            )

        schedule_el.attrib['days'] = days

        return schedule_el

    def get_route_el(self, number):
        """Return the cached <route> element for number or raise RaspImportError."""
        routes = self.cache_routes
        if number not in routes:
            raise RaspImportError(u"Нет общей информации о маршруте {}".format(number))

        return routes[number]

    def get_route_stations(self, number):
        """Return the cached station list for number or raise RaspImportError."""
        stations_by_number = self.cache_cities_on_route
        if number not in stations_by_number:
            raise RaspImportError(u"Нет станций маршрута {}".format(number))

        return stations_by_number[number]

    def get_timetable_els(self, number):
        """Return the cached timetable runs for number or raise RaspImportError."""
        runs_by_number = self.cache_timetables
        if number not in runs_by_number:
            raise RaspImportError(u"Нет расписания маршрута {}".format(number))

        return runs_by_number[number]

    def get_stops_info(self, number, day, departure_time):
        """Return the per-stop info of the run (number, departure_time) on day, or None."""
        day_infos = self.get_parsed_stops_infos(day)
        lookup_key = (number, departure_time)

        return day_infos.get(lookup_key)

    @cache_method_result_with_exception
    def get_parsed_stops_infos(self, day):
        """Parse the 'day_timetable' file of one day into per-stop details.

        :param day: departure day passed to get_olven_file
        :return: dict keyed both by the Olven number (first run seen for that
            number) and by (number, start_time). Each value maps a station's
            'srt' order to a dict with keys distance, departure, arrival,
            arrival_day_shift, departure_day_shift, title, tariff. Times are
            minutes from the first station's calculated departure and stay
            None when that departure differs from the run's departure_time.
        """
        filename = get_olven_file(self.package, 'day_timetable', departure_day=day, calc_time=1)

        route_infos = {}

        for run_el in etree.parse(filename).findall('.//run'):
            number = get_olven_number(run_el)
            start_time = get_sub_tag_text(run_el, 'departure_time').strip()

            stations = run_el.findall('.//station')

            # findall() returns an empty list (never None) when the run has no
            # station info; such a run is skipped.
            if not stations:
                continue

            first_station = stations[0]
            start_dt = None
            try:
                station_start_time = get_sub_tag_text(first_station, 'departure_time_calc').strip()

                # Calculated times are trusted only when the first station's
                # departure agrees with the run's own departure time.
                has_bad_times = station_start_time != start_time
                if not has_bad_times:
                    start_dt = self._parse_calc_dt(first_station, 'departure')
            except (AttributeError, TypeError, ParseError):
                continue

            distance_base = 0
            try:
                distance_base = float(get_sub_tag_text(first_station, 'dist'))
            except (AttributeError, ValueError):
                pass

            result = {}

            for station_el in stations:
                departure = None
                arrival = None
                arrival_day_shift = None
                departure_day_shift = None

                try:
                    distance = float(get_sub_tag_text(station_el, 'dist')) - distance_base
                except (AttributeError, ValueError):
                    distance = None

                if not has_bad_times:
                    try:
                        departure_dt = self._parse_calc_dt(station_el, 'departure')

                        departure = timedelta2minutes(departure_dt - start_dt)
                        departure_day_shift = (departure_dt.date() - start_dt.date()).days
                    except (AttributeError, ValueError, ParseError):
                        pass

                    try:
                        arrival_dt = self._parse_calc_dt(station_el, 'arrival')

                        arrival = timedelta2minutes(arrival_dt - start_dt)
                        arrival_day_shift = (arrival_dt.date() - start_dt.date()).days
                    except (AttributeError, ValueError, ParseError):
                        pass

                    # Equal non-zero offsets mean no stop duration: keep the arrival only.
                    if arrival and departure and arrival == departure:
                        departure = None

                tariff = self._parse_tariff(station_el)

                olven_order = int(station_el.find('srt').text)

                result[olven_order] = {
                    'distance': distance,
                    'departure': departure,
                    'arrival': arrival,
                    'arrival_day_shift': arrival_day_shift,
                    'departure_day_shift': departure_day_shift,
                    'title': get_sub_tag_text(station_el, 'name').strip(),
                    'tariff': tariff
                }

            # The bare number keeps the first run seen; the pair key is per run.
            if number not in route_infos:
                route_infos[number] = result

            route_infos[number, start_time] = result

        return route_infos

    def _parse_calc_dt(self, station_el, kind):
        """Combine the '<kind>_day_calc' and '<kind>_time_calc' tags of a station into a datetime."""
        return datetime.combine(
            get_date(get_sub_tag_text(station_el, kind + '_day_calc'), fmt="%d.%m.%Y"),
            get_time(get_sub_tag_text(station_el, kind + '_time_calc'))
        )

    def _parse_tariff(self, station_el):
        """Return the station's price as a float, or None when absent, unparsable or below 10."""
        price_el = station_el.find('price')
        price_text = price_el.text if price_el is not None else None
        if not price_text:
            return None

        try:
            tariff = float(price_text)
        except (ValueError, TypeError):
            log.error(u"Не смогли разобрать тариф %s", price_text)
            return None

        # Tariffs below 10 roubles are ignored; this includes a zero price.
        if tariff < 10:
            return None

        return tariff

    def add_stops_and_fares(self, thread_el, route_stations, stops_info, departure_time, arrival_time):
        """Append <stoppoints> to thread_el and register a fare when prices are known.

        Each (order, code) in route_stations becomes a stoppoint; matching
        details from stops_info add shifts in seconds, the distance, and a
        price from the first station to that stop. The first stoppoint always
        gets departure_time; the last one gets arrival_time only when it has
        no arrival_shift.
        """
        first_code = route_stations[0][1]
        first_title = self.get_station_title(first_code)

        stop_elements = []
        price_elements = []

        for order, code in route_stations:
            title = self.get_station_title(code)
            info = self.get_info(order, title, stops_info)

            stop_el = etree.Element('stoppoint', {
                'station_title': title,
                'station_code': code
            })
            stop_elements.append(stop_el)

            if not info:
                continue

            # Offsets are stored in minutes; the attributes are written in seconds.
            for info_key in ('departure', 'arrival'):
                minutes = info[info_key]
                if minutes:
                    stop_el.set(info_key + '_shift', unicode(minutes * 60))

            if info['distance']:
                stop_el.set('distance', unicode(info['distance']))

            if info['tariff'] is None:
                continue

            price_el = etree.Element('price', {
                'currency': 'RUR',
                'price': unicode(info['tariff'])
            })
            etree.SubElement(price_el, 'stop_from', {
                'station_code': first_code,
                'station_title': first_title
            })
            etree.SubElement(price_el, 'stop_to', {
                'station_code': code,
                'station_title': title
            })
            price_elements.append(price_el)

        stop_elements[0].set('departure_time', departure_time)

        last_stop = stop_elements[-1]
        if 'arrival_shift' not in last_stop.attrib:
            last_stop.set('arrival_time', arrival_time)

        if price_elements:
            # The fare code is the fare's position in self.fares.
            fare_code = unicode(len(self.fares))
            fare_el = etree.Element('fare', {"code": fare_code})
            fare_el.extend(price_elements)

            self.fares.append(fare_el)

            thread_el.set('fare_code', fare_code)

        stoppoints_el = etree.Element('stoppoints')
        stoppoints_el.extend(stop_elements)

        thread_el.append(stoppoints_el)

    @cache_method_result
    def get_station_title(self, code):
        """Return the city title for code (empty when unknown) and record a <station> element.

        The element, with a nested raw <legacy_station>, is appended to
        self.stations each time the body runs.
        """
        title = self.cache_cities.get(code, u"")

        station_el = etree.Element('station', {
            'code': code,
            'title': title
        })
        etree.SubElement(station_el, 'legacy_station', {
            'type': 'raw',
            'code': code,
            'title': title
        })
        self.stations.append(station_el)

        return title

    def get_info(self, order, title, stops_info):
        """Return the stop details for a station order, or None when unusable.

        :param order: station order used as the lookup key in stops_info
        :param title: station title from the route's station list
        :param stops_info: mapping of order to a details dict, or None
        :return: the details dict; None when stops_info is None, has no entry
            for order, or the entry's title differs (case-insensitively)
        """
        if stops_info is None:
            return None

        try:
            info = stops_info[order]
        except (KeyError, IndexError):
            # stops_info is a dict keyed by order, so a missing stop shows up
            # as KeyError; IndexError is kept for sequence-like inputs.
            return None

        if info['title'].lower() != title.lower():
            log.error(u"Название в списке станции маршрута и название в информации о рейсе не совпадают, %s: %s != %s",
                      order, title, info['title'])
            return None

        return info

    def add_carrier(self, thread_el, timetable_el):
        carrier_title = get_sub_tag_text(timetable_el, 'atp', silent=True).strip().replace(ur'\"', u'"')
        if carrier_title:
            code = self.create_carrier(carrier_title)
            thread_el.set('carrier_code', code)

    @cache_method_result
    def create_carrier(self, carrier_title):
        """Append a <carrier> element for carrier_title and return its code.

        The code is the element's position in self.carriers as text.
        """
        code = unicode(len(self.carriers))
        carrier_el = etree.Element('carrier', {
            'title': carrier_title,
            'code': code
        })
        self.carriers.append(carrier_el)

        return code

    def add_vehicle(self, thread_el, timetable_el):
        vehicle_title = get_sub_tag_text(timetable_el, 'marka', silent=True).strip().replace(ur'\"', u'"')
        if vehicle_title:
            code = self.create_vehicle(vehicle_title)
            thread_el.set('vehicle_code', code)

    @cache_method_result
    def create_vehicle(self, vehicle_title):
        """Append a <vehicle> element for vehicle_title and return its code.

        The code is the element's position in self.vehicles as text.
        """
        code = unicode(len(self.vehicles))
        vehicle_el = etree.Element('vehicle', {
            'title': vehicle_title,
            'code': code
        })
        self.vehicles.append(vehicle_el)

        return code


def get_olven_number(element):
    """Return '<route_id>x<server_id>' built from the element's stripped sub-tag texts."""
    route_id = get_sub_tag_text(element, 'route_id').strip()
    server_id = get_sub_tag_text(element, 'server_id').strip()

    return u"x".join([route_id, server_id])
