# -*- coding: utf-8 -*-
from itertools import count, izip

import os.path
import re
from datetime import datetime, timedelta

from bs4 import BeautifulSoup
from lxml import etree

from common.cysix.builder import ChannelBlock, GroupBlock, ThreadBlock, StoppointBlock, ScheduleBlock
from travel.rasp.library.python.common23.date import environment
from common.utils.caching import cache_method_result
from common.utils.iterrecipes import chunker
from common.utils.unicode_csv import unicode_reader
from cysix.tsi_converter import CysixTSIConverterFactory, CysixTSIConverterFileProvider
from travel.rasp.admin.lib.logs import get_current_file_logger
from travel.rasp.admin.scripts.schedule.utils import RaspImportError
from travel.rasp.admin.scripts.schedule.utils.file_providers import PackageFileProvider


# Module-level logger obtained from the project's logging helper.
log = get_current_file_logger()


# Base URL of the HTTP listing: DycHTTPFileProvider downloads the index page
# from it and appends each linked file name to it to download the file.
URL = 'http://178.69.65.138:8080/'

# Currency code passed along with every price block added by Converter.
CURRENCY = 'RUR'

# Titles for known group codes; Converter falls back to the group code itself
# when the code is not listed here.
GROUP_TITLES = {
    'vol02': u'Вологда',
    'vol18': u'Сокол',
}

# Largest gap, in minutes, allowed between the two times of a pair in
# check_and_filter_times_for_every_other_day_mask.
MAX_DIFF_IN_MINUTES_FOR_EVERY_OTHER_DAY = 7


class DycCysixFactory(CysixTSIConverterFactory):
    """Factory that plugs the DYC file providers of this module into the Cysix TSI conversion."""

    def get_raw_download_file_provider(self):
        """Return the provider that fetches the raw files over HTTP for ``self.package``."""
        http_provider = DycHTTPFileProvider(self.package)
        return http_provider

    def get_raw_package_file_provider(self):
        """Return the provider that takes the raw files from ``self.package`` itself."""
        package_provider = DycPackageFileProvider(self.package)
        return package_provider

    def get_converter_file_provider(self, raw_file_provider):
        """Return the converting provider built on top of ``raw_file_provider``."""
        converter_provider = DycCysixFileProvider(self.package, raw_file_provider)
        return converter_provider


class DycCysixFileProvider(CysixTSIConverterFileProvider):
    """Converter file provider that delegates the actual conversion to ``Converter``."""

    def convert_data(self, filepath):
        """Convert the files of ``self.raw_file_provider`` and write the result to ``filepath``."""
        Converter(self.raw_file_provider).convert(filepath)


class ConvertError(RaspImportError):
    """Import error type declared by this converter module.

    It adds nothing to ``RaspImportError``; no code visible in this module raises it.
    """


class Converter(object):
    """Builds a Cysix channel element from the raw DYC files and writes it out as XML.

    While a group is being processed the instance keeps two pieces of state:
    ``self.group_block`` (the group currently being filled) and ``self.stations``
    (station title -> station block, reset for every group).
    """

    def __init__(self, provider):
        """Keep ``provider`` and wrap its file map into a ``RawData`` object."""
        self.provider = provider

        self.raw_data = RawData(provider.get_filemap())

    def convert(self, filepath):
        """Build the channel element and write it to ``filepath`` as pretty-printed UTF-8 XML."""
        channel_el = self.build_channel_element()

        with open(filepath, 'w') as out:
            xml_text = etree.tostring(
                channel_el,
                xml_declaration=True,
                encoding='utf-8',
                pretty_print=True
            )
            out.write(xml_text)

    def build_channel_element(self):
        """Create the channel block, add one group block per raw group and return its element.

        Threads without any schedule are logged and left out.
        """
        channel_block = ChannelBlock(
            'bus',
            station_code_system='vendor',
            carrier_code_system='local',
            vehicle_code_system='local',
            timezone='start_station'
        )

        for raw_group in self.raw_data.get_groups():
            group_code = raw_group.code
            group_title = GROUP_TITLES.get(group_code, group_code)

            self.group_block = GroupBlock(channel_block, title=group_title, code=group_code)

            # Station blocks belong to a group, so the cache starts empty for each group.
            self.stations = {}

            for raw_thread in raw_group.thread_iter():
                if raw_thread.schedules:
                    self.build_thread(raw_thread)
                else:
                    log.error(u'Нет расписания для %s %s', raw_thread.title, raw_thread.number)

            channel_block.add_group_block(self.group_block)

        return channel_block.get_element()

    def build_thread(self, raw_thread):
        """Create a thread block for ``raw_thread``, fill it and attach it to the current group."""
        thread_block = ThreadBlock(self.group_block, raw_thread.title, raw_thread.number)

        # Same order as always: schedules, then stop points, then fares.
        fillers = (
            self.add_schedules_to_thread,
            self.add_stoppoints_to_thread,
            self.add_fares_to_thread,
        )
        for fill in fillers:
            fill(raw_thread, thread_block)

        thread_block.set_raw_data(raw_thread.raw_data)

        self.group_block.add_thread_block(thread_block)

    def add_schedules_to_thread(self, raw_thread, thread_block):
        """Add one schedule block per (days, times) pair of ``raw_thread``."""
        for days, times in raw_thread.schedules:
            schedule_block = ScheduleBlock(thread_block, days, times=times)
            thread_block.add_schedule_block(schedule_block)

    def add_stoppoints_to_thread(self, raw_thread, thread_block):
        """Add one stop point block per stop of ``raw_thread``, creating station blocks on demand."""
        for stop in raw_thread.stoppoints:
            station_block = self.get_station_block(stop.title)
            stoppoint_block = StoppointBlock(
                thread_block,
                station_block,
                arrival_shift=stop.arrival_shift,
                distance=stop.distance
            )

            thread_block.add_stoppoint_block(stoppoint_block)

    def add_fares_to_thread(self, raw_thread, thread_block):
        """Attach a local fare to the thread and fill it with prices.

        Prices come from two sources: the per-stop price (from the first stop to that stop)
        and, when present, the fare matrix of the thread.
        """
        fare_block = self.group_block.add_local_fare()
        thread_block.set_fare_block(fare_block)

        stops = raw_thread.stoppoints
        origin_block = self.get_station_block(stops[0].title)

        for stop in stops[1:]:
            if not stop.price:
                continue

            destination_block = self.get_station_block(stop.title)
            fare_block.add_price_block(stop.price, CURRENCY, origin_block, destination_block, data={})

        if not raw_thread.matrix:
            return

        log.info(u'Есть матрица для %s %s', raw_thread.title, raw_thread.number)

        for titles, price in raw_thread.matrix.iteritems():
            title_from, title_to = titles
            station_from = self.get_station_block(title_from)
            station_to = self.get_station_block(title_to)

            fare_block.add_price_block('%.2f' % price, CURRENCY, station_from, station_to, data={})

    def get_station_block(self, title):
        """Return the station block for ``title`` in the current group, creating it on first use."""
        if title in self.stations:
            return self.stations[title]

        raw_station = RawStation(title)

        new_block = self.group_block.add_station(raw_station.title, raw_station.code)
        new_block.add_legacy_station(raw_station.legacy_title, '')

        self.stations[title] = new_block

        return new_block


class RawData(object):
    """Splits a file map (file name -> file path) into per-group sets of raw DYC files.

    File names are split on ``_``: the first component is the file kind
    (``pac``, ``sma`` or ``pts``) and the second one is the group code.
    NOTE(review): names presumably look like ``<kind>_<group>_<date>_<time>.txt`` -
    confirm against real data.
    """

    def __init__(self, filemap):
        """Keep ``filemap``; it is indexed by file name and its values are passed to ``RawGroup``."""
        self.filemap = filemap

    def get_groups(self):
        """Return a list with one ``RawGroup`` per group found in the file map."""
        groups = []

        for code, (pac_file, sma_files, pts_files) in self.get_filename_groups_iter(self.filemap.keys()):
            groups.append(RawGroup(
                code,
                self.filemap[pac_file],
                [self.filemap[name] for name in sma_files],
                [self.filemap[name] for name in pts_files]
            ))

        return groups

    @staticmethod
    def _split_kind_and_group(filename):
        """Return ``(kind, group)`` taken from ``filename``, or None when it has no ``_``."""
        parts = filename.split('_')

        if len(parts) < 2:
            return None

        return parts[0], parts[1]

    def _files_of_group(self, filenames, kind, groupname):
        """Return the files of exactly this kind and group, sorted by their date part."""
        wanted = (kind, groupname)

        # Compare whole name components: a plain prefix test would let group
        # "vol1" pick up the files of group "vol18" as well.
        matching = [fn for fn in filenames if self._split_kind_and_group(fn) == wanted]

        return sorted(matching, key=self.get_dyc_datepart)

    def get_filename_groups_iter(self, filenames):
        """Yield ``(groupname, (pac_file, sma_files, pts_files))`` for every group.

        Groups are discovered from the pac files. For each group the latest pac file
        (by date part) is taken; ``sma_files`` and ``pts_files`` are lists holding at
        most the latest file of that kind. Groups without an sma file are logged and skipped.
        """
        groupnames = set()

        for filename in filenames:
            if not filename.startswith('pac'):
                continue

            kind_and_group = self._split_kind_and_group(filename)

            if kind_and_group is None or kind_and_group[0] != 'pac':
                # Such a name cannot be matched back to a group later on, so it is
                # skipped here instead of failing with IndexError.
                log.error(u"Не удалось определить группу по имени pac-файла %s", filename)
                continue

            groupnames.add(kind_and_group[1])

        log.info(u"Обнаружили групы файлов %s", groupnames)

        for groupname in groupnames:
            log.info(u"Собираем файлы для группы %s", groupname)

            # Never empty: the group name was taken from a file matched by the same rule.
            pac_file = self._files_of_group(filenames, 'pac', groupname)[-1]

            sma_files = self._files_of_group(filenames, 'sma', groupname)[-1:]

            pts_files = self._files_of_group(filenames, 'pts', groupname)[-1:]

            if not sma_files:
                log.error(u"Не нашли sma файла для группы %s", groupname)
                continue

            log.info(u"Собрали группу %s %r", groupname, (pac_file, sma_files, pts_files))

            yield groupname, (pac_file, sma_files, pts_files)

    @cache_method_result
    def get_dyc_datepart(self, filename):
        """Return the last two ``_``-separated components of ``filename`` without ``.txt``.

        Used as the sort key that orders files of one kind and group.
        """
        return "_".join(filename.replace(".txt", "").split('_')[-2:])


class RawGroup(object):
    dyc_id_template = u'{group_code}x{low_number}x{high_number}'

    dyc_number_re = re.compile(ur'^\s*(?P<low_number>\d+)\s*[(](?P<high_number>\d+)[)].*', re.U + re.I)

    sma_dyc_number_re = re.compile(ur'^\s*(?P<high_number>\d+)\s*[(](?P<low_number>\d+)[)].*', re.U + re.I)

    def __init__(self, code, pac_filepath, sma_filepaths, pts_filepaths):
        self.code = code
        self.pac_filepath = pac_filepath
        self.sma_filepaths = sma_filepaths
        self.pts_filepaths = pts_filepaths

        self.paths = self._get_paths()

        self.matrixes = self._get_matrixes()

    def thread_iter(self):
        for thread_info in self._thread_info_iter():
            rtstations = self.paths.get(thread_info['dyc_id'], None)

            if not rtstations:
                log.error(u"Маршрут %s не имеет информации о списке станций", thread_info['dyc_id'])

                continue

            matrix = self.matrixes.get(thread_info['dyc_id'], None)

            yield RawThread(thread_info, rtstations, matrix)

    def _thread_info_iter(self):
        with open(self.pac_filepath) as f:
            reader = unicode_reader(f, delimiter='\t', encoding="cp1251", strip_values=True)

            # skip first row
            reader.next()

            thread_info = {}

            for row in reader:
                if len(row) < 6:
                    log.error(u"Не верная строчка в pac-файле группы %s: %r", self.code, row)

                    continue

                (number, title_from, title_to, distance, duration, mask_template), times = row[:6], row[6:]

                if number:
                    match = self.sma_dyc_number_re.match(number)

                    if not match:
                        thread_info = {}

                        log.error(u"%s не является номером рейса ДЮК")

                        continue

                    if thread_info:
                        yield thread_info

                    params = {'group_code': self.code}
                    params.update(match.groupdict())

                    dyc_id = self.dyc_id_template.format(**params)

                    thread_info = {
                        'dyc_id': dyc_id,
                        'title_from': title_from,
                        'title_to': title_to,
                        'distance': distance,
                        'duration': duration,
                        'masks_and_times': [(mask_template, times)],
                        'raw_data': [u'\t'.join(row)]
                    }

                elif mask_template:
                    if not thread_info:
                        log.error(u"Не верная строчка в pac-файле группы %s: %r", self.code, row)

                    thread_info['masks_and_times'].append((mask_template, times))
                    thread_info['raw_data'].append(u'\t'.join(row))

            if thread_info:
                yield thread_info

    def _get_paths(self):
        paths = dict()

        for sma_filepath in self.sma_filepaths:
            self._fill_paths_from_sma(sma_filepath, paths)

        return paths

    def _fill_paths_from_sma(self, sma_filepath, paths):
        filename = os.path.basename(sma_filepath)

        with open(sma_filepath) as f:
            reader = unicode_reader(f, delimiter='\t', encoding="cp1251", strip_values=True)
            line_iter = izip(count(1), reader)

            for route_header_line_number, route_header_row in line_iter:
                log.debug(u'try route %s %s', route_header_line_number, repr(route_header_row))
                if route_header_row and route_header_row[0]:
                    match = self.dyc_number_re.match(route_header_row[0])
                    if not match:
                        log.error(u'Неожиданный порядок данных в sma-файле'
                                  u' при разборе заголовка маршрута, файл %s: строка %s:\n%s',
                                  filename, route_header_line_number, u'\t'.join(map(unicode, route_header_row)))
                        continue

                    params = {'group_code': self.code}
                    params.update(match.groupdict())

                    dyc_id = self.dyc_id_template.format(**params)

                    stoppoints = []

                    line_iter.next()  # class row
                    line_iter.next()  # fields row

                    has_error = False
                    for stop_line_number, stop_row in line_iter:
                        log.debug(u'stop %s %s', stop_line_number, repr(stop_row))
                        if not stop_row[0]:
                            break

                        elif len(stop_row) < 4:
                            log.error(u'Неожиданный порядок данных в sma-файле'
                                      u' при разборе остановки, файл %s: строка %s:\n%s',
                                      filename, stop_line_number, u'\t'.join(map(unicode, stop_row)))
                            has_error = True
                            break

                        else:
                            stoppoints.append(RawStoppoint(
                                title=stop_row[0],
                                distance=stop_row[1],
                                duration_time_str=stop_row[2],
                                price=stop_row[3],
                                raw_data=u'\t'.join(stop_row)
                            ))

                    if not has_error:
                        paths[dyc_id] = stoppoints

    def _get_matrixes(self):
        matrixes = dict()

        for pts_filepath in self.pts_filepaths:
            log.info(u'Разбираем pts %s файл', os.path.basename(pts_filepath))

            for dyc_id, matix, raw_data in self._matix_iter(pts_filepath):
                matrixes[dyc_id] = (matix, raw_data)

        return matrixes

    def _matix_iter(self, pts_filepath):
        for matrix_rows in self._matrix_rows_iter(pts_filepath):
            number_row = matrix_rows[1]
            number = number_row[0] + number_row[1]

            params = {'group_code': self.code}
            params.update(self.dyc_number_re.match(number).groupdict())

            dyc_id = self.dyc_id_template.format(**params)

            try:
                matrix_map = self._parse_matrix_map(matrix_rows)

            except Exception:
                log.error(u"Ошибка разбора тарифов для %s", dyc_id)

                continue

            raw_data = u'\n'.join((u'\t'.join(r) for r in matrix_rows))

            yield dyc_id, matrix_map, raw_data

    def _parse_matrix_map(self, matrix_rows):
        matrix_map = {}

        titles = []

        for row in matrix_rows[3:]:
            if not row[0]:
                continue

            title = row[1]

            titles.append(title)

            prices = row[3:]

            for from_index, price in enumerate(prices[:len(titles)]):
                try:
                    price = float(price.replace(',', '.'))
                except ValueError:
                    continue

                if price < 10:
                    continue

                matrix_map[titles[from_index], title] = price

        return matrix_map

    def _matrix_rows_iter(self, pts_filepath):
        with open(pts_filepath) as f:
            reader = unicode_reader(f, delimiter='\t', encoding="cp1251", strip_values=True)

            matrix_rows = []

            prev_row = []

            for row in reader:
                if not row or len(row) < 2:
                    continue

                number = row[0] + row[1]

                match = self.dyc_number_re.match(number)

                if match:
                    if matrix_rows:
                        del matrix_rows[-1]

                        yield matrix_rows

                    matrix_rows = [prev_row]
                    matrix_rows.append(row)

                elif matrix_rows:
                    matrix_rows.append(row)

                prev_row = row

            if matrix_rows:
                yield matrix_rows


class RawThread(object):
    def __init__(self, thread_info, stoppoints, matrix):
        self.number = thread_info['dyc_id']
        self.title = thread_info['title_from'] + u' - ' + thread_info['title_to']

        self.stoppoints = stoppoints

        self.matrix = None
        matrix_raw_data = u''

        if matrix:
            self.matrix, matrix_raw_data = matrix

        self.distance = thread_info['distance'].replace(u',', u'.')
        self.duration = thread_info['duration'].replace(u',', u'.')
        self.schedules = [s for s in self.schedule_iter(thread_info['masks_and_times'])]

        self.raw_data = u'\n'.join(thread_info['raw_data'])
        self.raw_data += u'\n---------------\n'
        self.raw_data += u'\n'.join((s.raw_data for s in self.stoppoints))
        self.raw_data += u'\n---------------\n'
        self.raw_data += matrix_raw_data or u''

    def schedule_iter(self, masks_and_times):
        for days_template, times in masks_and_times:
            times = [t for t in times if t]
            days, times = self.get_days_and_times(days_template, times)

            if not days or not times:
                continue

            yield days, u';'.join(times)

    week_days = {
        u'пн': u'1',
        u'вт': u'2',
        u'ср': u'3',
        u'чт': u'4',
        u'пт': u'5',
        u'сб': u'6',
        u'вс': u'7',
    }

    week_day_re = re.compile(ur'(пн|вт|ср|чт|пт|сб|вс)(,(пн|вт|ср|чт|пт|сб|вс))*', re.U)

    def get_days_and_times(self, days_template, times):
        days_template = days_template.lower()

        if days_template == u'ежедневно':
            return days_template, times

        elif self.week_day_re.match(days_template):
            return u''.join([self.week_days[d] for d in days_template.split(u',')]), times

        elif days_template == u'через день':
            times = check_and_filter_times_for_every_other_day_mask(times)
            return u'ежедневно', times

        else:
            log.error(u"Неожиданные данные в днях хождения '%s'", days_template)

            return None, None


def check_and_filter_times_for_every_other_day_mask(times):
    """
    RASPADMIN-880
    For the "every other day" day template: when the times come as pairs of close
    values, keep the first time of every pair so that the "daily" template can be used.

    Returns the list of kept times, or None when a ValueError occurs: a time does not
    parse as HH:MM, a chunk does not unpack into a pair, or the two times of a pair
    differ by more than MAX_DIFF_IN_MINUTES_FOR_EVERY_OTHER_DAY minutes.
    """

    today = environment.today()
    max_gap = timedelta(minutes=MAX_DIFF_IN_MINUTES_FOR_EVERY_OTHER_DAY)
    kept_times = []

    try:
        for first, second in chunker(times, 2):
            earlier, later = sorted(
                datetime.combine(today, datetime.strptime(value, '%H:%M').time())
                for value in (first, second)
            )

            if later - earlier > max_gap:
                # The pair may still be close across midnight (e.g. 23:59 and 00:01):
                # move the earlier time to the next day and compare again.
                if earlier + timedelta(days=1) - later > max_gap:
                    raise ValueError('Incorrect times')

            kept_times.append(first)

    except ValueError:
        return None

    return kept_times


class RawStoppoint(object):
    """One stop of a route: title, distance, price and arrival shift in seconds."""

    def __init__(self, title, distance, duration_time_str, price, raw_data):
        """Store the stop fields.

        Decimal commas in ``distance`` and ``price`` are replaced with dots.
        ``duration_time_str`` ("hours:minutes") is converted to seconds and stored as
        ``arrival_shift``; an empty value leaves ``arrival_shift`` as None.
        """
        self.title = title
        self.raw_data = raw_data

        self.distance = distance.replace(u',', u'.')
        self.price = price.replace(u',', u'.')

        if duration_time_str:
            hours, minutes = [int(part) for part in duration_time_str.split(':')]
            self.arrival_shift = 60 * (60 * hours + minutes)
        else:
            self.arrival_shift = None


class RawStation(object):
    space_re = re.compile(ur'\s+', re.U)

    def __init__(self, title):
        self.title = title

        self.code = title

        self.legacy_title = self._get_legacy_title()

    def _get_legacy_title(self):
        legacy_title = self.space_re.sub(u" ", self.title).strip()

        return legacy_title.replace(u" -", u"-").replace(u"- ", u"-").lower().capitalize()


def flattened(*args):
    """Return a flat list with all non-sequence items of ``args`` in left-to-right order.

    Tuples and lists are expanded recursively at any nesting depth; every other
    value (strings included) is kept as a single item.

    The previous implementation popped items from the tail but re-inserted nested
    items at the head, so the result came out scrambled:
    ``flattened(1, [2, 3], 4)`` returned ``[4, 1, 3, 2]``.
    """
    result = []

    # Explicit stack with the next item to visit at the end of the list, so that
    # both taking an item and expanding a nested sequence are O(1) per element.
    pending = list(reversed(args))

    while pending:
        item = pending.pop()

        if isinstance(item, (tuple, list)):
            pending.extend(reversed(item))

        else:
            result.append(item)

    return result


class DycHTTPFileProvider(PackageFileProvider):
    """File provider that downloads the raw ``.txt`` files listed on the index page at ``URL``."""

    def get_filemap(self):
        """Download every listed file and return a dict file name -> local file path."""
        filemap = {}

        for filename in self.get_http_filenames():
            filepath = self.get_package_filepath(filename)
            self.download_file(URL + filename, filepath)

            filemap[filename] = filepath

        return filemap

    def get_http_filenames(self):
        """Download the index page and return the ``href`` values of links ending in ``.txt``."""
        filepath = self.get_package_filepath('index.html')

        self.download_file(URL, filepath)

        filenames = []

        with open(filepath) as f:
            tree = BeautifulSoup(f)

            for a in tree.findAll('a'):
                # Anchors without an href attribute (e.g. named anchors) are skipped;
                # indexing a['href'] would raise KeyError for them.
                filename = a.get('href')

                if filename and filename.endswith('.txt'):
                    filenames.append(filename)

        return filenames


class DycPackageFileProvider(PackageFileProvider):
    """File provider that takes the raw files from the package archive."""

    def __init__(self, package):
        """Pass ``package`` on to ``PackageFileProvider`` unchanged."""
        super(DycPackageFileProvider, self).__init__(package)

    def get_filemap(self):
        """Return the result of ``get_unpack_map_with_trimmed_dirs_from_package_archive()``.

        NOTE(review): presumably a dict file name -> unpacked file path, since
        ``Converter`` hands the result to ``RawData`` - confirm in the base class.
        """
        unpacked_map = self.get_unpack_map_with_trimmed_dirs_from_package_archive()

        return unpacked_map
