# -*- encoding: utf-8 -*-

"""
Peculiarities of the format and of the data:

1. For some routes the departure shift of the first station is not zero.
   The times and the time shifts have to be corrected.

2. The time zone is specified in the schedule.
   Each schedule has to be put into a separate thread.

3. Price sets come with their own dates. Sometimes one thread has overlapping fares.
   The running schedule has to be split into intervals with different prices and without prices.
"""

from bisect import bisect_right, bisect_left
from datetime import date, datetime, timedelta
from itertools import izip, islice

from lxml import etree
import pytz

from common.cysix.builder import ChannelBlock, GroupBlock, StationBlock, StoppointBlock, ThreadBlock, ScheduleBlock
from common.utils.caching import cache_method_result
from cysix.base import safe_parse_xml
from cysix.tsi_converter import CysixTSIConverterFactory, CysixTSIConverterFileProvider
from travel.rasp.admin.lib.logs import get_current_file_logger
from travel.rasp.admin.scripts.schedule.utils import RaspImportError
from travel.rasp.admin.scripts.schedule.utils.file_providers import XmlPackageFileProvider, PackageFileProvider
from travel.rasp.admin.scripts.utils.to_python_parsers import parse_time, parse_date


# Module-level logger used by every class in this file.
log = get_current_file_logger()


# URL the raw schedule XML is downloaded from (see MrtransRawDownloadFileProvider).
SCHEDULE_FILE_URL = 'http://www.mrtrans.ru/scripts/xml/datas/yandex_xml.xml'

# Currency code passed along with every price block.
CURRENCY = 'RUR'

JUST_DATE = date(2014, 8, 8)  # Used for shifting the start time of a run (see RawThread.shift_schedule)


class MrtransCysixFactory(CysixTSIConverterFactory):
    """Factory that hands out the Mrtrans-specific file providers for ``self.package``."""

    def get_raw_download_file_provider(self):
        """Return the provider that downloads the raw schedule file."""
        provider = MrtransRawDownloadFileProvider(self.package)
        return provider

    def get_raw_package_file_provider(self):
        """Return the provider that reads the raw XML stored with the package."""
        provider = XmlPackageFileProvider(self.package)
        return provider

    def get_converter_file_provider(self, raw_file_provider):
        """Return the converting provider built on top of *raw_file_provider*."""
        provider = MrtransCysixFileProvider(self.package, raw_file_provider)
        return provider


class MrtransCysixFileProvider(CysixTSIConverterFileProvider):
    """File provider that produces the converted file by running ``Converter``."""

    def convert_data(self, filepath):
        """Convert the raw provider's schedule file and write the result to *filepath*."""
        Converter(self.raw_file_provider).convert(filepath)


class ConvertError(RaspImportError):
    """Raised when the raw schedule (or a part of it) cannot be converted."""


class Converter(object):
    """Builds a cysix channel XML document from the raw Mrtrans schedule XML.

    Stations from the raw file are registered once per vendor code; every raw
    thread is then emitted as one thread block per (split) schedule.
    """

    def __init__(self, provider):
        """Remember the raw file provider and start with no known stations.

        :param provider: object exposing ``get_schedule_files()``.
        """
        self.provider = provider

        # Station blocks keyed by vendor station code; filled by build_stations().
        self.stations = dict()

    def convert(self, filepath):
        """Parse the first raw schedule file and write the converted XML to *filepath*.

        :raises ConvertError: if the provider returns no files or the raw XML
            cannot be parsed.
        """
        files = self.provider.get_schedule_files()

        if not files:
            raise ConvertError(u'Не нашли файл с расписанием.')

        # Only the first file returned by the provider is used.
        raw_filepath = files[0]

        log.info(u"Парсим файл %s", raw_filepath)

        root_el = self.get_root_element(raw_filepath)

        channel_el = self.build_channel_element(root_el)

        with open(filepath, 'w') as f:
            f.write(etree.tostring(channel_el, xml_declaration=True, encoding='utf-8', pretty_print=True))

    def build_channel_element(self, root_el):
        """Build the channel with a single group named 'all' and return its element.

        :param root_el: root element of the raw schedule XML.
        """
        channel_block = ChannelBlock(
            'bus',
            station_code_system='vendor',
            carrier_code_system='vendor',
            vehicle_code_system='local',
            timezone='start_station',
            country_code_system='yandex',
            region_code_system='yandex',
            settlement_code_system='yandex',
        )

        self.group_block = GroupBlock(channel_block, title='all', code='all')

        # Stations must be registered first: build_thread() looks them up by code.
        self.build_stations(root_el)

        self.build_threads(root_el)

        channel_block.add_group_block(self.group_block)

        return channel_block.get_element()

    def build_threads(self, root_el):
        """Convert every ``./thread`` element of *root_el*.

        A thread that cannot be constructed is logged and skipped; a schedule
        that cannot be converted is logged and the remaining schedules of the
        same thread are still processed.
        """
        for thread_el in root_el.findall('./thread'):
            try:
                raw_thread = RawThread(thread_el)

            except RawThreadConstructError as e:
                log.error(u'%s', e)

                # Without this the loop below would run against an unbound name
                # (first thread) or against the previous thread's data.
                continue

            for raw_schedule in raw_thread.schedules:
                try:
                    self.build_thread(raw_thread, raw_schedule)

                except ConvertError as e:
                    log.error(u'Не получилось сконвертировать часть расписания для %s %s. %s',
                              raw_thread.number, raw_thread.title, e)

    def build_thread(self, raw_thread, raw_schedule):
        """Add one thread block for the pair (*raw_thread*, *raw_schedule*).

        :raises ConvertError: if a stoppoint or a price refers to a station
            code that was not registered by build_stations().
        """
        thread_block = ThreadBlock(self.group_block, raw_thread.title, raw_thread.number,
                                   timezone=raw_schedule.timezone)

        for stop in raw_thread.stoppoints:
            stop_block = self.get_station_block(stop.code)

            thread_block.add_stoppoint_block(StoppointBlock(
                thread_block,
                stop_block,
                arrival_shift=stop.arrival_shift,
                departure_shift=stop.departure_shift,
                distance=stop.distance
            ))

        thread_block.add_schedule_block(ScheduleBlock(
            thread_block,
            raw_schedule.days,
            times=raw_schedule.times,
            canceled=raw_schedule.canceled,
            period_start_date=raw_schedule.start_date.strftime('%Y-%m-%d'),
            period_end_date=raw_schedule.end_date.strftime('%Y-%m-%d')
        ))

        # The first fare whose period contains the schedule start date, if any.
        raw_fare = next((f for f in raw_thread.fares if f.is_in(raw_schedule.start_date)), None)

        if raw_fare:
            fare_block = self.group_block.add_local_fare()
            thread_block.set_fare_block(fare_block)

            for price in raw_fare.prices:
                from_block = self.get_station_block(price.from_code)
                to_block = self.get_station_block(price.to_code)
                value = '%.2f' % price.price

                fare_block.add_price_block(value, CURRENCY, from_block, to_block, data=None)

        # Keep the original thread XML alongside the converted thread.
        thread_block.set_raw_data(etree.tostring(raw_thread.thread_el, pretty_print=True,
                                  xml_declaration=False, encoding=unicode))

        self.group_block.add_thread_block(thread_block)

    def build_stations(self, root_el):
        """Register a station block for every ``./stopslist/stoppoint`` element.

        Only the first station seen for a given vendor code is stored in
        ``self.stations`` and added to the group.
        """
        for station_el in root_el.findall('./stopslist/stoppoint'):
            code = station_el.get('vendor_id', u'').strip()
            title = station_el.get('name', u'').strip()

            station_block = StationBlock(self.group_block, title, code)

            station_block.country_code = station_el.get('country_id', u'').strip()
            station_block.region_code = station_el.get('region_id', u'').strip()
            station_block.settlement_code = station_el.get('settlement_id', u'').strip()

            station_block.lon = station_el.get('lon', u'').strip()
            station_block.lat = station_el.get('lat', u'').strip()

            station_block.add_legacy_station(title, code)

            if code not in self.stations:
                self.stations[code] = station_block

                self.group_block.add_station_block(station_block)

    def get_station_block(self, code):
        """Return the registered station block for *code*.

        :raises ConvertError: if no station with this code was registered.
        """
        try:
            return self.stations[code]

        except KeyError:
            raise ConvertError(u'Нет станции с кодом {}.'.format(code))

    def get_root_element(self, filepath):
        """Parse the XML file at *filepath* and return its root element.

        :raises ConvertError: if lxml fails to parse the file.
        """
        try:
            return safe_parse_xml(filepath).getroot()

        except etree.LxmlError as e:
            raise ConvertError(u'Ошибка разбора xml-файла {}. {}'.format(filepath, e))


class RawThreadConstructError(ConvertError):
    """Raised by ``RawThread`` when a thread element cannot be turned into a thread."""


class RawThread(object):
    """One ``thread`` element of the raw XML: stoppoints, fares and schedules.

    On construction the stoppoint shifts are normalised so that the first
    departure shift becomes zero, overlapping fares are split into
    non-overlapping intervals, and every schedule is split along the fare
    interval boundaries.
    """

    def __init__(self, thread_el):
        """Parse *thread_el*.

        :raises RawThreadConstructError: if the thread has no stations.
        """
        self.thread_el = thread_el

        self.title = thread_el.get('title', u'').strip()
        self.number = thread_el.get('number', u'').strip()

        self.stoppoints = self.get_stoppoints(thread_el)

        if not self.stoppoints:
            raise RawThreadConstructError(
                u'Нет станций следования для {} {}.'.format(self.number, self.title))

        # Departure shift of the first station (a string from the XML; may be empty).
        shift = self.stoppoints[0].departure_shift

        has_shift = shift and int(shift) != 0

        if has_shift:
            shift = int(shift)

            self.fix_stoppoints_shift(shift)

            log.warning(u'Не нулевой сдвиг времени для отправления с первой станции %s %s, %s минут.',
                        self.number, self.title, shift / 60)

        self.fares = self.get_fares(thread_el)

        self.split_fares()

        self.schedules = self.get_schedules(thread_el, has_shift, shift)

    def get_schedules(self, thread_el, has_shift, shift):
        """Return all non-empty schedules of the thread, split by fare intervals."""
        schedules = []

        for schedule_el in thread_el.findall('./schedules/schedule'):
            temp_schedules = self.get_sub_schedules(schedule_el, has_shift, shift)

            for schedule in temp_schedules:
                schedules += self.split_schedule_by_fares(schedule)

        return schedules

    @classmethod
    def get_sub_schedules(cls, schedule_el, has_shift, thread_shift):
        """Build the RawSchedule objects for one ``schedule`` element.

        Without a shift a single schedule is produced, with the commas removed
        from ``days``. With a shift one schedule per start time is produced,
        each with the start time and the days moved by *thread_shift*.
        Schedules flagged ``empty`` are dropped.
        """
        schedules = []

        timezone = schedule_el.get('timezone', u'').strip()

        times = schedule_el.get('times', u'').strip()
        days = schedule_el.get('days', u'').strip()

        period_end_date = get_date(schedule_el.get('period_end_date', u'').strip())
        period_start_date = get_date(schedule_el.get('period_start_date', u'').strip())

        canceled = schedule_el.get('canceled', u'').strip()

        if not has_shift:
            schedules.append(RawSchedule(
                timezone, days.replace(',', ''), times,
                period_start_date, period_end_date, canceled
            ))

        else:
            days = map(int, days.split(','))
            times = map(cls.get_local_start_time, times.split(';'))

            for start_time in times:
                new_days, new_start_time = cls.shift_schedule(thread_shift, timezone, days, start_time)

                log.info(u'%s, %s + %s минут -> %s, %s', days, start_time.strftime('%H:%M'),
                         thread_shift / 60, new_days, new_start_time)

                schedules.append(RawSchedule(
                    timezone, new_days, new_start_time,
                    period_start_date, period_end_date, canceled
                ))

        return [s for s in schedules if not s.empty]

    @classmethod
    def shift_schedule(cls, thread_shift, timezone_str, days, start_time):
        """Move *start_time* by *thread_shift* seconds and adjust *days* accordingly.

        :param days: iterable of integer day numbers.
        :return: ``(days_string, 'HH:MM')``; day 0 is written as ``7``.
        """
        timezone = pytz.timezone(timezone_str)

        # NOTE(review): the zone is attached with replace(), but only wall-clock
        # arithmetic follows (strftime('%H:%M') and date()), so the offset
        # chosen by pytz does not influence the result.
        dt = datetime.combine(JUST_DATE, start_time).replace(tzinfo=timezone)

        new_dt = dt + timedelta(seconds=thread_shift)

        new_time = new_dt.strftime('%H:%M')

        # Number of calendar days the start moved by.
        day_shift = (new_dt.date() - dt.date()).days

        new_days = [u'{}'.format((d + day_shift) % 7) for d in days]

        return u''.join(new_days).replace(u'0', u'7'), new_time

    def get_fares(self, thread_el):
        """Return the non-empty fares declared under ``./fares/fare``."""
        fares = [RawFare.from_xml(f_el) for f_el in thread_el.findall('./fares/fare')]

        return [f for f in fares if not f.empty]

    def get_stoppoints(self, thread_el):
        """Return the stoppoints declared under ``./stations/station``."""
        return [RawStoppoint(s_el) for s_el in thread_el.findall('./stations/station')]

    def fix_stoppoints_shift(self, shift):
        """Subtract *shift* from every non-empty arrival/departure shift in place."""
        for stoppoint in self.stoppoints:
            if stoppoint.arrival_shift:
                stoppoint.arrival_shift = str(int(stoppoint.arrival_shift) - shift)

            if stoppoint.departure_shift:
                stoppoint.departure_shift = str(int(stoppoint.departure_shift) - shift)

    @classmethod
    def get_local_start_time(cls, time_str):
        """Parse an ``HH:MM`` string with the project's ``parse_time`` helper."""
        return parse_time(time_str, fmt='%H:%M')

    def split_fares(self):
        """Replace ``self.fares`` with non-overlapping fares.

        The boundary dates of all fares cut the timeline into consecutive
        intervals. Each interval receives the united prices of the fares that
        cover it; intervals covered by no fare are dropped.
        Nothing is done when there are fewer than two fares.
        """
        if len(self.fares) < 2:
            return

        dates = self.calculate_fare_split_dates()

        new_fares = []

        for start_date, end_date in izip(dates, islice(dates, 1, None)):
            # An interval between two consecutive boundaries lies either fully
            # inside or fully outside a fare, so testing its start is enough.
            current_fares = [f for f in self.fares if f.is_in(start_date)]

            prices = unite_prices(current_fares)

            new_fare = RawFare(start_date, end_date, prices)

            # A gap between fares yields a fare without prices; skip it,
            # the same way get_fares() skips empty fares.
            if not new_fare.empty:
                new_fares.append(new_fare)

        self.fares = new_fares

    @cache_method_result
    def calculate_fare_split_dates(self):
        """Return the sorted, de-duplicated start and end dates of ``self.fares``."""
        dates = set()

        for fare in self.fares:
            dates.add(fare.start_date)
            dates.add(fare.end_date)

        return sorted(dates)

    def split_schedule_by_fares(self, schedule):
        """Cut *schedule* at every fare boundary that falls strictly inside it.

        :return: list of schedules; ``[schedule]`` itself when there are no
            fare boundaries or the schedule lies entirely before or after them.
        """
        dates = self.calculate_fare_split_dates()

        if not dates:
            return [schedule]

        if schedule.end_date <= dates[0]:
            return [schedule]

        if schedule.start_date >= dates[-1]:
            return [schedule]

        # Boundaries strictly after the start and strictly before the end.
        start_index = bisect_right(dates, schedule.start_date)
        end_index = bisect_left(dates, schedule.end_date)

        dates = [schedule.start_date] + dates[start_index:end_index] + [schedule.end_date]

        new_schedules = []

        for start_date, end_date in izip(dates, islice(dates, 1, None)):
            new_schedules.append(RawSchedule.from_schedule(schedule, start_date, end_date))

        return new_schedules


def unite_prices(fares):
    """Merge the prices of *fares* into one list, one price per station pair.

    When several fares price the same (from_code, to_code) pair, the highest
    price is kept.
    """
    highest = {}

    for fare in fares:
        for item in fare.prices:
            pair = (item.from_code, item.to_code)

            try:
                known = highest[pair]

            except KeyError:
                highest[pair] = item.price

            else:
                highest[pair] = max(known, item.price)

    united = []

    for (from_code, to_code), value in highest.iteritems():
        united.append(RawPrice(from_code, to_code, value))

    return united


class RawSchedule(object):
    """Plain holder for one schedule: timezone, days, times, period and cancel flag.

    ``empty`` is True when either period date is missing or when the period
    start is not strictly before the period end.
    """

    def __init__(self, timezone, days, times, start_date, end_date, canceled):
        self.timezone = timezone
        self.days = days
        self.times = times
        self.start_date = start_date
        self.end_date = end_date
        self.canceled = canceled

        period_missing = start_date is None or end_date is None

        # The dates are compared only when both are present.
        self.empty = period_missing or start_date >= end_date

    @classmethod
    def from_schedule(cls, schedule, start_date, end_date):
        """Copy *schedule*, replacing its period with [*start_date*, *end_date*)."""
        inherited = (schedule.timezone, schedule.days, schedule.times)
        period = (start_date, end_date)

        return cls(*(inherited + period + (schedule.canceled,)))


class RawFare(object):
    """A set of prices valid for the period [start_date, end_date).

    Prices that are ``None`` or not greater than zero are discarded.
    ``empty`` is True when no price is left, when either date is missing, or
    when the start is not strictly before the end.
    """

    def __init__(self, start_date, end_date, prices):
        self.start_date = start_date
        self.end_date = end_date

        usable = []

        for item in prices:
            if item.price is None or not item.price > 0:
                continue

            usable.append(item)

        self.prices = usable

        self.empty = (
            not usable
            or start_date is None
            or end_date is None
            or start_date >= end_date
        )

    @classmethod
    def from_xml(cls, fare_el):
        """Build a fare from a ``fare`` element and its ``./price`` children."""
        def read_date(attribute):
            return get_date(fare_el.get(attribute, u'').strip())

        period_start = read_date('period_start_date')
        period_end = read_date('period_end_date')

        parsed_prices = []

        for price_el in fare_el.findall('./price'):
            parsed_prices.append(RawPrice.from_xml(price_el))

        return cls(period_start, period_end, parsed_prices)

    def is_in(self, date):
        """Tell whether *date* falls inside [start_date, end_date)."""
        return self.start_date <= date and date < self.end_date


def get_date(date_str):
    """Parse *date_str* with the project's ``parse_date``, passing ``default=None``."""
    parsed = parse_date(date_str, default=None)

    return parsed


class RawPrice(object):
    """Price of a trip between two stations identified by their codes."""

    def __init__(self, from_code, to_code, price):
        self.from_code, self.to_code, self.price = from_code, to_code, price

    @classmethod
    def from_xml(cls, price_el):
        """Build a price from the attributes of a ``price`` element."""
        def read(attribute):
            # A missing attribute is treated as an empty string.
            return price_el.get(attribute, u'').strip()

        departure_code = read('station_from_id')
        arrival_code = read('station_to_id')
        amount = cls.get_price(read('price'))

        return cls(departure_code, arrival_code, amount)

    @classmethod
    def get_price(cls, price_str):
        """Return *price_str* as a float, or ``None`` if it cannot be converted."""
        try:
            amount = float(price_str)

        except (ValueError, TypeError):
            amount = None

        return amount


class RawStoppoint(object):
    """Stop of a thread: station code, arrival/departure shifts and distance.

    All values are kept as stripped strings read from the element attributes;
    a missing attribute becomes an empty string.
    """

    def __init__(self, stoppoint_el):
        def read(attribute):
            return stoppoint_el.get(attribute, u'').strip()

        self.code = read('vendor_id')

        self.departure_shift = read('departure_time')
        self.arrival_shift = read('arrival_time')

        self.distance = read('distance')


class MrtransRawDownloadFileProvider(PackageFileProvider):
    """Provider that downloads the raw schedule from ``SCHEDULE_FILE_URL``."""

    filename = u'yandex_xml.xml'

    def get_schedule_files(self):
        """Download the schedule into the package file and return its path in a list."""
        target_path = self.get_package_filepath(self.filename)
        downloaded_path = self.download_file(SCHEDULE_FILE_URL, target_path)

        return [downloaded_path]
