# coding: utf-8

import os.path
import datetime
import re
import json
from urllib2 import urlopen

from django.conf import settings
from lxml import etree

from common.cysix.builder import ChannelBlock, GroupBlock, ThreadBlock, ScheduleBlock, StoppointBlock
from cysix.two_stage_import.factory import CysixTSIFactory
from common.utils.caching import cache_method_result_with_exception, cache_method_result
from travel.rasp.admin.lib.logs import get_current_file_logger
from travel.rasp.admin.scripts.schedule.utils import RaspImportError
from travel.rasp.admin.scripts.schedule.utils.file_providers import PackageFileProvider


# Module-level logger obtained from the project logging helper.
log = get_current_file_logger()

# Base URL of the TKVC schedule directory: the file listing is fetched from it
# and individual file names are appended to it to build download URLs.
TKVC_URL = u"http://www.tkvc.ru/rasp/"

# Currency code attached to every price block produced by the converter.
CURRENCY = u'RUR'

# Number of days (counted from the start date) that parse_mask() scans when it
# expands a "week of month + weekdays" mask into explicit dates.
MASK_MAX_EXTEND_DAYS = 60


class TkvcCysixFactory(CysixTSIFactory):
    """Two-stage import factory wiring in the TKVC file providers and settings."""

    def get_download_file_provider(self):
        """Return the provider that downloads the raw data and converts it."""
        log.info(u"Качаем данные %s", self.package.url)

        provider = TkvcCysixFileProvider(self.package)
        return provider

    def get_package_file_provider(self):
        """Return the provider that converts XML files stored in the package archive."""
        provider = TkvcCysixPackageFileProvider(self.package)
        return provider

    @cache_method_result
    def get_settings(self):
        """Return the parent settings with the two TKVC-specific flags switched on."""
        # Named import_settings so the module-level django ``settings`` is not shadowed.
        import_settings = super(TkvcCysixFactory, self).get_settings()

        enabled_flags = (
            'calculate_not_filled_last_stations_times',
            'calculate_bad_distances_from_geo',
        )
        for flag_name in enabled_flags:
            setattr(import_settings, flag_name, True)

        return import_settings


class TkvcCysixPackageFileProvider(PackageFileProvider):
    """Builds the common ("cysix") XML from the XML files found in the package archive."""

    @cache_method_result_with_exception
    def get_cysix_file(self):
        """Return the path to ``cysix.xml``, converting the raw data first if needed.

        An already existing file is reused unless ``settings.DEBUG`` is on, in
        which case the conversion is always repeated.
        """
        filepath = self.get_package_filepath('cysix.xml')

        if os.path.exists(filepath) and not settings.DEBUG:
            log.info(u"Данные уже были сконвертированы в общий xml %s", filepath)
            return filepath

        log.info(u"Конвертируем в общий xml %s", filepath)

        self.unzip_and_convert_data(filepath)

        log.info(u"Данные сконвертированы в общий xml %s", filepath)

        return filepath

    def unzip_and_convert_data(self, filepath):
        """Convert this provider's schedule files and write the result to ``filepath``.

        :param filepath: path of the cysix XML file to produce.
        :return: ``filepath``.
        """
        Converter(self).convert(filepath)
        # Return the path explicitly instead of relying on the converter's
        # return value, so the caller always gets the written file's path.
        return filepath

    def get_schedule_files(self):
        """Return the XML files contained in the package archive."""
        return self.get_files_from_archive_with_ext('xml')


class TkvcCysixFileProvider(PackageFileProvider):
    """Builds the common ("cysix") XML from raw files downloaded from TKVC."""

    @cache_method_result_with_exception
    def get_cysix_file(self):
        """Return the path to ``cysix.xml``, downloading and converting data if needed.

        An already existing file is reused unless ``settings.DEBUG`` is on, in
        which case the download and conversion are always repeated.
        """
        filepath = self.get_package_filepath('cysix.xml')

        if os.path.exists(filepath) and not settings.DEBUG:
            log.info(u"Данные уже были сконвертированы в общий xml %s", filepath)
            return filepath

        log.info(u"Конвертируем в общий xml %s", filepath)

        self.download_and_convert_data(filepath)

        log.info(u"Данные сконвертированы в общий xml %s", filepath)

        return filepath

    def download_and_convert_data(self, filepath):
        """Download the raw schedule files and convert them into ``filepath``.

        :param filepath: path of the cysix XML file to produce.
        :return: ``filepath``.
        """
        raw_provider = TkvcRawFileProvider(self.package)
        Converter(raw_provider).convert(filepath)
        # Return the path explicitly instead of relying on the converter's
        # return value, so the caller always gets the written file's path.
        return filepath


# Matches links to numbered XML files (e.g. <a href="123.xml">) in the directory
# listing page and captures the file name. A raw string is used so the regex
# escapes are not interpreted as (invalid) string escapes.
raw_file_pattern = re.compile(r'<a\s+href\s*=\s*"(\d+\.xml)"')


class TkvcRawFileProvider(PackageFileProvider):
    """Downloads the raw XML schedule files published under ``TKVC_URL``."""

    def get_av_filelist(self):
        """Return the XML file names linked from the ``TKVC_URL`` listing page."""
        response = urlopen(TKVC_URL, timeout=settings.SCHEDULE_IMPORT_TIMEOUT)
        try:
            listing = response.read()
        finally:
            # Always release the connection, even when reading fails.
            response.close()

        return raw_file_pattern.findall(listing)

    def get_schedule_files(self):
        """Download every listed file into the package directory.

        :return: list of the values returned by ``download_file`` for each file,
                 in listing order.
        """
        schedule_files = []
        for filename in self.get_av_filelist():
            download_url = TKVC_URL + filename
            target_path = self.get_package_filepath(filename)
            schedule_files.append(self.download_file(download_url, target_path))
        return schedule_files


class ConvertError(RaspImportError):
    """Raised when raw TKVC data cannot be converted.

    Within this module it is caught and logged, and the affected file or
    thread is skipped.
    """
    pass


class Converter(object):
    def __init__(self, provider):
        """Remember the file provider.

        :param provider: object whose ``get_schedule_files()`` yields the paths
                         of the raw XML files that ``convert()`` processes.
        """
        self.provider = provider

    def convert(self, filepath):
        """Convert all raw schedule files of the provider into one cysix XML file.

        Each raw file becomes one group of the channel. A file that raises
        ``ConvertError`` is logged and skipped. Before writing, stoppoints are
        re-sorted by distance where needed and legacy station elements are added.

        :param filepath: path of the XML file to write (overwritten if present).
        :return: ``filepath``.
        """
        channel_block = ChannelBlock('bus', station_code_system='vendor', carrier_code_system='local',
                                     vehicle_code_system='local', timezone='start_station')

        for av_filepath in self.provider.get_schedule_files():
            log.info(u"Парсим файл %s", av_filepath)

            root_element = self._get_etree_from_file(av_filepath)

            try:
                group_block = self.build_group(channel_block, root_element)
                channel_block.add_group_block(group_block)

            except ConvertError as e:
                log.error(unicode(e))

        channel_el = channel_block.get_element()

        self.sort_by_distance(channel_el)

        channel_el = self.add_legacy_stations(channel_el)

        serialized = etree.tostring(channel_el, xml_declaration=True, encoding='utf-8', pretty_print=True)

        # tostring() with an explicit encoding yields encoded bytes, so the
        # file is opened in binary mode to write them untouched.
        with open(filepath, 'wb') as f:
            f.write(serialized)

        # Callers propagate the result of convert(); hand back the written path.
        return filepath

    def _get_etree_from_file(self, filepath):
        with open(filepath) as f:
            data = f.read()
            data = re.sub('[\x00\x01\x02\x03\x04\x05]', "", data)
            return etree.fromstring(data)

    def build_group(self, channel_block, root_element):
        """Build the group block for the bus station described by ``root_element``.

        The group and its start station share the code ``<AV_Code>_<Region_Code>``
        and the station title. Every ``routes`` child of the station element
        becomes a thread; a thread that raises ``ConvertError`` is logged and
        skipped.

        :param channel_block: channel the group belongs to.
        :param root_element: root element of one raw XML file.
        :return: the populated group block.
        :raises ConvertError: if the region or station element is missing, or the
                              station has no code or no title.
        """
        region_code = self.get_region_code(root_element)

        av_element = self.get_av_element(root_element)
        av_code = av_element.attrib.get('AV_Code', '').strip()
        av_title = av_element.attrib.get('AV_Name', u'').strip()

        if not av_code:
            raise ConvertError(u"Не указан код автовокзала.")

        if not av_title:
            raise ConvertError(u"Не указано название автовокзала.")

        # The same code and title identify both the group and its start station.
        group_code = '%s_%s' % (av_code, region_code)

        group_block = GroupBlock(channel_block, av_title, group_code)
        start_station_block = group_block.add_station(av_title, group_code)

        for raw_routes in av_element.findall('./routes'):
            try:
                route_el = self.get_route_el(raw_routes)

                thread_block = self.build_thread_block(group_block, start_station_block, raw_routes, route_el, av_code)

                if thread_block is not None:
                    # Keep the raw source XML alongside the thread.
                    raw_data = etree.tostring(raw_routes, pretty_print=True, xml_declaration=False, encoding=unicode)
                    thread_block.set_raw_data(raw_data)

                    group_block.add_thread_block(thread_block)
            except ConvertError as e:
                log.error(u"Пропускаем нитку. %s", unicode(e))

        return group_block

    def build_thread_block(self, group_block, start_station_block, raw_routes, route_el, av_code):
        route_title = route_el.get('route_title', u'')

        route_start_time = route_el.get('route_time_start', u'').strip()

        raw_days = route_el.get('route_reg', u'').strip()

        carrier_title = route_el.get('route_atp', u'').strip()
        carrier_title = re.sub(ur"\s+", u" ", carrier_title)

        vehicle_title = route_el.get('route_awt', u'').strip()
        vehicle_title = re.sub(ur"\s+", u" ", vehicle_title)

        has_sales = route_el.get('route_inet_sale', u'0').strip() != u'0'

        days, period_start_date, exclude = parse_mask(route_title, raw_days, datetime.date.today())
        if not days:
            log.error(u"Нитка '%s'. Маска дней хождении '%s' не распознана." % (route_title, raw_days))
            return

        thread_block = ThreadBlock(group_block, route_title, sales=has_sales)

        thread_block.add_schedule_block(ScheduleBlock(thread_block, days, times=route_start_time,
                                        period_start_date=period_start_date, exclude_days=exclude))

        fare_block = group_block.add_local_fare()

        thread_block.set_fare_block(fare_block)

        if carrier_title:
            carrier_block = group_block.add_local_carrier(carrier_title)
            thread_block.carrier = carrier_block

        if vehicle_title:
            vehicle_block = group_block.add_local_vehicle(vehicle_title)
            thread_block.vehicle = vehicle_block

        self.build_stoppoints(group_block, thread_block, fare_block, start_station_block, raw_routes, route_el, av_code)

        return thread_block

    def build_stoppoints(self, group_block, thread_block, fare_block, start_station_block, raw_routes, route_el,
                         av_code):
        """Add the start stop and one stop per ``station`` element to ``thread_block``.

        For every station with a non-empty ``station_fullpla`` attribute a price
        from the start station to that station is added to ``fare_block``.

        :raises ConvertError: if ``raw_routes`` has no ``stations`` element.
        """
        route_nap_code = route_el.get('route_nap_cod', u'')

        # The thread always starts at the group's own station.
        first_stop = StoppointBlock(thread_block, start_station_block)
        first_stop.departure_shift = 0
        first_stop.distance = 0
        thread_block.add_stoppoint_block(first_stop)

        station_els = self.get_stations_el(raw_routes).findall('./station')
        last_index = len(station_els) - 1

        for index, station_el in enumerate(station_els):
            title = station_el.get('station_title', u'').strip()
            code = '%s_%s' % (station_el.get('station_code', u"").strip(), route_nap_code)

            station_block = group_block.add_station(title, code)
            stoppoint_block = StoppointBlock(thread_block, station_block)

            stoppoint_block.arrival_shift = self.get_arrival_shift(station_el)

            # A departure shift is only derivable when both the arrival shift
            # and the stop duration are known.
            if stoppoint_block.arrival_shift is not None:
                stop_time = self.get_stop_time(station_el)
                if stop_time is not None:
                    stoppoint_block.departure_shift = stoppoint_block.arrival_shift + stop_time

            stoppoint_block.distance = station_el.get('station_rasst', u"").strip()

            # Last station without an arrival shift: fall back to the route's
            # absolute end time, unless it is empty or starts with 00:00.
            if index == last_index and stoppoint_block.arrival_shift is None:
                route_time_stop = route_el.get('route_time_stop', u'').strip()
                if route_time_stop and not route_time_stop.startswith(u'00:00'):
                    stoppoint_block.arrival_time = route_time_stop

            thread_block.add_stoppoint_block(stoppoint_block)

            price_value = station_el.get('station_fullpla', u"").strip()
            if not price_value:
                continue

            order_data = {
                "av_code": av_code,
                "station_title": title,
                "station_code": code,
            }
            fare_block.add_price_block(
                price_value,
                CURRENCY,
                start_station_block,
                station_block,
                data=json.dumps(order_data, ensure_ascii=False, encoding='utf8')
            )

    @classmethod
    def get_arrival_shift(cls, station_el):
        shift = station_el.get('station_time_start', '').strip()

        try:
            shift = int(shift) * 60
        except ValueError:
            shift = None

        return shift

    @classmethod
    def get_stop_time(cls, station_el):
        stop_time = station_el.get('station_time_stop', '').strip()

        try:
            stop_time = int(stop_time) * 60
        except ValueError:
            stop_time = None

        return stop_time

    def get_region_code(self, root_element):
        region_elements = root_element.findall('./Region')

        if len(region_elements) == 0:
            raise ConvertError(u"Не найден ни один регион.")

        if len(region_elements) > 1:
            log.warning(u"В файле больше одного региона, импортируем только первый.")

        code = region_elements[0].get('Region_Code', '').strip()

        return code

    def get_av_element(self, root_element):
        av_elements = root_element.findall('./Region/A_Vokzal')

        if len(av_elements) == 0:
            raise ConvertError(u"Не найден ни один вокзал.")

        if len(av_elements) > 1:
            log.warning(u"В файле больше одного вокзала, импортируем только первый.")

        return av_elements[0]

    def get_route_el(self, routes_el):
        route_els = routes_el.findall('./route')

        if len(route_els) == 0:
            raise ConvertError(u"Не найден маршрут, пропускаем.")

        if len(route_els) > 1:
            log.warning(u"В маршруте больше одного маршрута, импортируем только первый.")

        return route_els[0]

    def get_stations_el(self, route_el):
        stations_els = route_el.findall('./stations')

        if len(stations_els) == 0:
            raise ConvertError(u"Не найдены станции маршрута.")

        if len(stations_els) > 1:
            log.warning(u"В маршруте больше одного набора станций, импортируем только первый.")

        return stations_els[0]

    def add_legacy_stations(self, channel_el):
        for stations_el in channel_el.findall('.//stations'):
            first_station_code = None
            for station_el in stations_el.findall('./station'):
                legacy_station_el = etree.Element('legacy_station')
                code = station_el.attrib.get('code', '')
                if first_station_code:
                    code = '%s_%s' % (code, first_station_code)
                else:
                    first_station_code = code
                legacy_station_el.set('code', code)
                legacy_station_el.set('title', station_el.attrib.get('title', ''))
                legacy_station_el.set('type', 'raw')
                station_el.append(legacy_station_el)
        return channel_el

    def sort_by_distance(self, channel_el):
        stoppoints_to_sort = self._find_stoppoints_to_sort(channel_el)
        self._sort_stoppoints_by_distance(stoppoints_to_sort)

    def _find_stoppoints_to_sort(self, channel_el):
        all_threads = 0
        not_sorted = 0

        stoppoints_to_sort = []
        for thread_el in channel_el.findall('.//thread'):
            thread_name = thread_el.attrib.get('title', u'')

            for stoppoints_el in thread_el.findall('.//stoppoints'):
                if self._need_to_sort(stoppoints_el, thread_name):
                    not_sorted += 1
                    stoppoints_to_sort.append(stoppoints_el)

            all_threads += 1

        log.info(u'Всего ниток {}, нужна сортировка {}'.format(all_threads, not_sorted))
        return stoppoints_to_sort

    def _need_to_sort(self, stoppoints_el, thread_name):
        distance = -1
        prev_title = None
        prev_code = None

        for stoppoint_el in stoppoints_el.findall('./stoppoint'):

            if 'distance' not in stoppoint_el.attrib:
                log.warning(u"Нитка '{}'. Нет расстояния!".format(thread_name))
                return False

            new_distance = float(stoppoint_el.attrib.get('distance', 0.0))
            if new_distance < distance:
                log.warning(u"Нитка '{}'. нужна сортировка по расстоянию {} < {}"
                            .format(thread_name, new_distance, distance))
                return True

            if new_distance == distance:
                if prev_code is not None:
                    log.warning(u"Нитка '{}'. Расстояние {} не изменилось '{}', '{}' - '{}', '{}'".format(
                        thread_name, new_distance,
                        prev_title, prev_code,
                        stoppoint_el.attrib.get('station_title', u''), stoppoint_el.attrib.get('station_code', u'')
                    ))

            distance = new_distance
            prev_title = stoppoint_el.attrib.get('station_title', u'')
            prev_code = stoppoint_el.attrib.get('station_code', u'')
        return False

    def _sort_stoppoints_by_distance(self, stoppoints_to_sort):
        for stoppoints_el in stoppoints_to_sort:
            stops = []
            for i, stoppoint_el in enumerate(stoppoints_el.findall('./stoppoint')):
                distance = float(stoppoint_el.attrib.get('distance', 0.0))
                stop = (distance, i, stoppoint_el)
                stops.append(stop)

            stops.sort()

            stoppoints_el.clear()
            for stop in stops:
                stoppoints_el.append(stop[2])

# Short Russian weekday names mapped to weekday digits:
# '1' is Monday ... '7' is Sunday.
SHORT_DAYS = {
    u'пн': u'1',
    u'вт': u'2',
    u'ср': u'3',
    u'чт': u'4',
    u'пт': u'5',
    u'сб': u'6',
    u'вс': u'7',
}
# "every other day, starting DD.MM" -- groups: day, month.
ch_d = re.compile(ur'^через день, с (\d+)\.(\d+)$')
# Same, followed by ", кроме <text>" ("except") -- groups: day, month, trailing text.
ch_d_exclude = re.compile(ur'^через день, с (\d+)\.(\d+), кроме\s+(.*?)\s*$')
# Same, followed by a bare ", <text>" -- groups: day, month, trailing text.
# This pattern also matches the "кроме" form, so ch_d_exclude has to be tried first.
ch_d_exclude_2 = re.compile(ur'^через день, с (\d+)\.(\d+),\s*(.*?)\s*$')
# "недели:N, <text>" ("weeks:N") -- groups: single-digit week number, trailing text.
week_and_days = re.compile(ur'^недели:(\d),\s*(.*?)\s*$')


def parse_mask(route_title, raw_days, start_date):
    period_start_date = None
    exclude = None

    if raw_days == u'ежедневно':
        days = '1234567'
        return days, period_start_date, exclude

    if raw_days == u'по нечетным числам':
        days = u'нечетные'
        return days, period_start_date, exclude

    if raw_days == u'по четным числам':
        days = u'четные'
        return days, period_start_date, exclude

    if raw_days.startswith(tuple(SHORT_DAYS.keys())):
        days = short_days_only_to_days(raw_days, route_title)
        return days, period_start_date, exclude

    if raw_days.startswith(u'через день'):

        rez = ch_d.match(raw_days)
        if rez:
            day = int(rez.group(1))
            month = int(rez.group(2))
            closest_date = find_closest_date(month, day, start_date)
            period_start_date = closest_date.strftime('%Y-%m-%d')
            days = u'через день'
            return days, period_start_date, exclude

        rez = ch_d_exclude.match(raw_days)
        if rez:
            day = int(rez.group(1))
            month = int(rez.group(2))
            closest_date = find_closest_date(month, day, start_date)
            period_start_date = closest_date.strftime('%Y-%m-%d')

            exclude = short_days_only_to_days(rez.group(3), route_title)

            days = u'через день'
            return days, period_start_date, exclude

        rez = ch_d_exclude_2.match(raw_days)
        if rez:
            day = int(rez.group(1))
            month = int(rez.group(2))
            closest_date = find_closest_date(month, day, start_date)
            period_start_date = closest_date.strftime('%Y-%m-%d')

            include = short_days_only_to_days(rez.group(3), route_title)

            exclude = u''.join(set(u'1234567') - set(include))

            days = u'через день'
            return days, period_start_date, exclude

        return None, None, None

    rez = week_and_days.match(raw_days)
    if rez:
        week_number = int(rez.group(1))
        week_days = set(short_days_only_to_days(rez.group(2), route_title))

        our_dates = []

        now_date = start_date
        end_date = now_date + datetime.timedelta(days=MASK_MAX_EXTEND_DAYS)

        while now_date < end_date:
            if unicode(now_date.weekday() + 1) in week_days:
                day_of_month = now_date.day
                week_number_tmp = (day_of_month - 1) / 7 + 1
                if week_number == week_number_tmp:
                    our_dates.append(now_date)

            now_date += datetime.timedelta(days=1)

        if our_dates:
            days = u';'.join([ds.strftime('%Y-%m-%d') for ds in our_dates])
            return days, None, None
        else:
            return None, None, None

    if re.match(ur'^\d+\.\d+', raw_days):
        dates, tail = parse_dates(raw_days, start_date)

        if tail:
            log.error(u"Нитка '%s'. Маска дней хождения '%s' распознана не полностью."
                      u" Не удалось распознать хвост '%s'." %
                      (route_title, raw_days, tail))

        days = u';'.join(dates)
        return days, period_start_date, exclude

    return None, None, None


def parse_dates(string, start_date):
    dates = []
    tail = string
    while re.match(ur'^(\d+)\.(\d+)\s*(.*?)\s*$', tail):
        rez = re.match(ur'^(\d+)\.(\d+)\s*(.*?)\s*$', tail)

        day = int(rez.group(1))
        month = int(rez.group(2))
        tail = rez.group(3)

        closest_date = find_closest_date(month, day, start_date)
        date = closest_date.strftime('%Y-%m-%d')
        dates.append(date)

    return dates, tail


def short_days_only_to_days(string, route_title):
    """Translate a dot-separated list of short weekday names into weekday digits.

    Empty tokens are ignored. Tokens that are not in ``SHORT_DAYS`` are skipped
    with a warning.

    :param route_title: thread title, used in the warning message only.
    :return: string of weekday digits in input order.
    """
    digits = []

    for token in string.split(u'.'):
        if not token:
            continue

        digit = SHORT_DAYS.get(token)
        if digit is None:
            log.warning(u"Нитка '%s'. Маска дней хождения '%s' распознана не полностью. Не удалось распознать '%s'." %
                        (route_title, string, token))
        else:
            digits.append(digit)

    return u''.join(digits)


def find_closest_date(month, day, start_date):
    """Return the date with the given month and day that is closest to ``start_date``.

    The candidates are the current, next and previous year of ``start_date``.
    On a tie the current year wins, then the next year.

    When the day does not exist in one of those three years (February 29),
    a year is picked from ``start_date.year % 4`` instead: the year itself,
    one or two years back, or one year ahead.

    :raises ValueError: if the fallback year does not have that date either.
    """
    this_year = start_date.year

    try:
        # Candidate order defines the tie-breaking: current, next, previous year.
        candidates = [
            datetime.date(this_year + year_offset, month, day)
            for year_offset in (0, 1, -1)
        ]
    except ValueError:
        # February 29
        leap_year_offset = {0: 0, 1: -1, 2: -2, 3: 1}[this_year % 4]
        return datetime.date(this_year + leap_year_offset, month, day)

    # min() returns the first of several equally close candidates.
    return min(candidates, key=lambda candidate: abs((candidate - start_date).total_seconds()))

