# -*- coding: utf-8 -*-

import logging
import os.path

import pytz
from django.conf import settings
from django.utils.translation import ugettext_lazy as _, gettext_noop as N_
from lxml import etree

import library.python.resource
from common.models.geo import CodeSystem, StationCode, Station
from common.models.schedule import Company
from common.models.transport import TransportModel, TransportSubtype, TransportType
from cysix.base import CriticalImportError, safe_parse_xml
from travel.rasp.admin.lib.types import CaseInsensitiveSet
from travel.rasp.admin.lib.xmlutils import copy_lxml_element


log = logging.getLogger(__name__)


class CysixCheckError(CriticalImportError):
    """Critical import error raised when a cysix XML file fails a check.

    Besides the message handed to the base class, the instance remembers the
    untouched message (``raw_msg``), the source line and an extra detail
    string describing the offending place.
    """

    def __init__(self, msg, line=None, extra=None):
        """Store the message and the optional location details.

        :param msg: description of the problem.
        :param line: line number in the checked file, if known.
        :param extra: additional detail text (e.g. the offending XML
            fragment), if available.
        """
        super(CysixCheckError, self).__init__(msg)

        self.raw_msg = msg
        self.line, self.extra = line, extra

        # The extended template is assigned only when both location details
        # are truthy; otherwise msg_template / msg_args are not touched here.
        if not (line and extra):
            return

        self.msg_template = N_(u'%s\nНомер строки: %s\n%s')
        self.msg_args = (msg, line, extra)


def validate_cysix_xml(filepath):
    """Run the whole validation pipeline for one cysix XML file.

    The file is parsed, validated against the XSD and then passed through the
    consistency, timezone, currency and transport checks, in that order.

    :param filepath: path handed to the XML parser.
    :raises CysixCheckError: propagated from the parsing, XSD, timezone and
        currency steps when they fail.
    """
    parsed_tree = get_tree_or_exit_if_bad_xml(filepath)
    validate_xml_against_xsd(parsed_tree)

    content_checks = (
        check_consistency,
        check_timezones,
        check_currencies,
        check_transports,
    )
    for run_check in content_checks:
        run_check(parsed_tree)


def get_tree_or_exit_if_bad_xml(xml_filename):
    """Parse ``xml_filename`` and return what ``safe_parse_xml`` produced.

    :param xml_filename: path of the file to parse.
    :raises CysixCheckError: when parsing fails with ``IOError`` or with
        ``etree.XMLSyntaxError``.
    """
    try:
        parsed = safe_parse_xml(xml_filename)

    except IOError as io_error:
        not_found_text = _(u'Не нашли файл "{filepath}": "{error}"').format(
            error=unicode(io_error), filepath=xml_filename
        )
        raise CysixCheckError(not_found_text)

    except etree.XMLSyntaxError as syntax_error:
        bad_xml_text = _(u'Файл не является валидным xml-файлом: "{error}"').format(
            error=unicode(syntax_error)
        )
        raise CysixCheckError(bad_xml_text)

    return parsed


def validate_xml_against_xsd(tree):
    """Validate ``tree`` against the XSD returned by ``_get_xsd_string``.

    :param tree: parsed document to validate.
    :raises CysixCheckError: when the document is invalid; the line and the
        message of the last entry of the validator's error log (if the log is
        not empty) are attached to the error.
    """
    schema = etree.XMLSchema(etree.fromstring(_get_xsd_string()))

    try:
        schema.assertValid(tree)

    except etree.DocumentInvalid as invalid:
        logged_errors = invalid.error_log
        # Only the last logged entry is reported.
        last_entry = logged_errors[-1] if logged_errors else None

        raise CysixCheckError(
            msg=_(u'Ошибка при проверке на соответствие общему xml-формату.'),
            line=None if last_entry is None else last_entry.line,
            extra=None if last_entry is None else last_entry.message
        )


def _get_xsd_string():
    """Return the public XSD looked up through ``library.python.resource``."""
    xsd_resource_key = 'cysix/data/public.xsd'
    return library.python.resource.find(xsd_resource_key)


def check_timezones(tree):
    """Check the ``timezone`` attributes of the document, starting at its root.

    :param tree: parsed document.
    :raises CysixCheckError: propagated from the per-element check.
    """
    _check_timezone(tree.getroot())


def _check_timezone(channel_el):
    """Check the timezone of the channel element and of selected descendants.

    The channel itself is checked first, followed by every ``group``,
    ``thread``, ``stoppoint`` and ``schedule`` element found below it.

    :param channel_el: root element of the document.
    """
    _check_timezone_in_element(channel_el)

    # 'schedule' is a legacy tag; it is still checked, after the current ones.
    for tag_name in ('group', 'thread', 'stoppoint', 'schedule'):
        for descendant_el in channel_el.findall('.//' + tag_name):
            _check_timezone_in_element(descendant_el)


def _check_timezone_in_element(element):
    """Verify the ``timezone`` attribute of a single element, if present.

    A value is accepted when, lower-cased, it is one of the special keywords
    ``local`` / ``start_station`` / ``end_station``, or when it is a member of
    ``pytz.all_timezones_set``.

    :param element: lxml element to inspect.
    :raises CysixCheckError: for any other value; the error carries the
        element's source line and a serialized copy of the element made with
        ``top_element_only=True``.
    """
    tz_value = element.attrib.get('timezone')
    if tz_value is None:
        return

    if tz_value.lower() in ('local', 'start_station', 'end_station'):
        return

    if tz_value in pytz.all_timezones_set:
        return

    element_copy = copy_lxml_element(element, top_element_only=True)
    error_text = _(
        u'Файл общего xml-формата содержит неправильную временную зону "{timezone}".'
    ).format(timezone=tz_value)

    raise CysixCheckError(
        error_text,
        line=element.sourceline,
        extra=etree.tostring(element_copy, encoding=unicode)
    )


def check_currencies(tree):
    """Ensure every ``<price>`` element uses a known currency code.

    Known codes are the ``code`` values of all ``Currency`` objects plus
    ``RUR`` and ``RUB``.

    :param tree: parsed document.
    :raises CysixCheckError: on the first ``<price>`` whose ``currency``
        attribute is not a known code.
    """
    from common.models.currency import Currency

    known_codes = set(['RUR', 'RUB'])
    known_codes.update(currency.code for currency in Currency.objects.all())

    for price_el in tree.getroot().findall('.//price'):
        currency_code = price_el.attrib['currency']
        if currency_code in known_codes:
            continue

        error_text = _(
            u'Файл общего xml-формата содержит неизвестную валюту "{currency}"'
        ).format(currency=currency_code)

        raise CysixCheckError(error_text,
                              line=price_el.sourceline,
                              extra=etree.tostring(price_el, encoding=unicode))


def check_transports(tree):
    """Check transport type / subtype attributes of the root and its threads.

    The known type codes and the subtype-code -> type-code mapping are loaded
    once and reused for every inspected element.

    :param tree: parsed document.
    """
    root_el = tree.getroot()
    known_type_codes = set(TransportType.objects.values_list('code', flat=True))
    subtype_to_type = dict(TransportSubtype.objects.values_list('code', 't_type__code'))

    # The root element goes first, then every <thread> found below it.
    inspected = [root_el]
    inspected.extend(root_el.findall('.//thread'))

    for inspected_el in inspected:
        _check_transport_in_element(inspected_el, known_type_codes, subtype_to_type)


def _check_transport_in_element(element, transport_types, transport_subtypes):
    """Log warnings about the ``t_type`` / ``subtype`` attributes of an element.

    Nothing is raised; every problem is reported with ``log.warning``.

    :param element: lxml element to inspect.
    :param transport_types: collection of known transport type codes.
    :param transport_subtypes: mapping of subtype code to the code of the
        transport type it belongs to.
    """
    attributes = element.attrib
    declared_type = attributes.get('t_type')
    declared_subtype = attributes.get('subtype')

    if not (declared_type is None or declared_type in transport_types):
        log.warning(
            _(u'Строка {line}. Неизвестный тип транспорта "{transport_type}".')
            .format(line=element.sourceline, transport_type=declared_type)
        )

    if declared_subtype is None:
        return

    owning_type = transport_subtypes.get(declared_subtype)

    if owning_type is None:
        log.warning(
            _(u'Строка {line}. Неизвестный подтип транспорта "{transport_subtype}".')
            .format(line=element.sourceline, transport_subtype=declared_subtype)
        )
        return

    # The subtype is known: it must belong to the declared type, when one is given.
    if declared_type is not None and owning_type != declared_type:
        log.warning(
            _(u'Строка {line}. Подтип транспорта "{transport_subtype}" не относится к типу "{transport_type}".')
            .format(line=element.sourceline, transport_subtype=declared_subtype, transport_type=declared_type)
        )


def check_consistency(tree):
    """Run the station, carrier, vehicle and fare checks on the document root.

    :param tree: parsed document.
    """
    root_el = tree.getroot()

    for checker in (_check_stations, _check_carriers, _check_vehicles, _check_fares):
        checker(root_el)


def _check_stations(channel_el):
    """Cross-check station codes inside every ``<group>`` of the channel.

    For each group the stations listed under ``stations`` are registered in a
    ``Stations`` instance; the codes referenced from fare prices, thread
    stoppoints and geometry points are then checked against it. The code
    system is inherited channel -> group -> thread -> element unless an
    element overrides it. Empty codes (after stripping) are skipped.

    :param channel_el: root element of the document.
    """
    default_system = channel_el.attrib.get('station_code_system', u'vendor')

    for group_el in channel_el.findall('group'):
        group_system = group_el.attrib.get('station_code_system', default_system)

        registry = Stations(
            group_el.get('code', u'').strip(),
            group_el.attrib.get('title', u'').strip()
        )

        # Declarations: note the attribute names differ from the references below.
        for station_el in group_el.findall('./stations/station'):
            declared_system = station_el.attrib.get('code_system', group_system)
            declared_code = station_el.attrib['code'].strip()

            if not declared_code:
                continue

            registry.add(declared_code, declared_system, station_el.sourceline)

        # Fare endpoints: every stop_from first, then every stop_to.
        for endpoint_path in ('./fares/fare/price/stop_from', './fares/fare/price/stop_to'):
            for endpoint_el in group_el.findall(endpoint_path):
                endpoint_system = endpoint_el.attrib.get('station_code_system', group_system)
                endpoint_code = endpoint_el.attrib['station_code'].strip()

                if not endpoint_code:
                    continue

                registry.check(endpoint_code, endpoint_system, endpoint_el.sourceline)

        for thread_el in group_el.findall('./threads/thread'):
            thread_system = thread_el.attrib.get('station_code_system', group_system)

            for stoppoint_el in thread_el.findall('./stoppoints/stoppoint'):
                stop_system = stoppoint_el.attrib.get('station_code_system', thread_system)
                stop_code = stoppoint_el.attrib['station_code'].strip()

                if not stop_code:
                    continue

                registry.check(stop_code, stop_system, stoppoint_el.sourceline)

            for point_el in thread_el.findall('./geometry/point'):
                point_system = point_el.attrib.get('station_code_system', thread_system)
                # Unlike the other references, a point may lack station_code entirely.
                point_code = point_el.get('station_code', u'').strip()

                if not point_code:
                    continue

                registry.check(point_code, point_system, point_el.sourceline)


def _check_carriers(channel_el):
    """Cross-check carrier codes inside every ``<group>`` of the channel.

    Carriers listed under ``carriers`` are registered in a ``Carriers``
    instance; the ``carrier_code`` of every thread that has one is then
    checked against it. Empty codes (after stripping) are skipped.

    :param channel_el: root element of the document.
    """
    default_system = channel_el.attrib.get('carrier_code_system', u'vendor')

    for group_el in channel_el.findall('group'):
        group_attrs = group_el.attrib
        group_system = group_attrs.get('carrier_code_system', default_system)

        registry = Carriers(group_attrs['code'], group_attrs.get('title', None))

        for carrier_el in group_el.findall('./carriers/carrier'):
            declared_system = carrier_el.attrib.get('code_system', group_system)
            declared_code = carrier_el.attrib['code'].strip()

            if not declared_code:
                continue

            registry.add(declared_code, declared_system, carrier_el.sourceline)

        for thread_el in group_el.findall('./threads/thread'):
            thread_attrs = thread_el.attrib

            if 'carrier_code' not in thread_attrs:
                continue

            used_system = thread_attrs.get('carrier_code_system', group_system)
            used_code = thread_attrs['carrier_code'].strip()

            if used_code:
                registry.check(used_code, used_system, thread_el.sourceline)


def _check_vehicles(channel_el):
    """Cross-check vehicle codes inside every ``<group>`` of the channel.

    Vehicles listed under ``vehicles`` are registered in a ``Vehicles``
    instance; the ``vehicle_code`` of every thread that has one is then
    checked against it. Empty codes (after stripping) are skipped.

    :param channel_el: root element of the document.
    """
    default_system = channel_el.attrib.get('vehicle_code_system', u'vendor')

    for group_el in channel_el.findall('group'):
        group_attrs = group_el.attrib
        group_system = group_attrs.get('vehicle_code_system', default_system)

        registry = Vehicles(group_attrs['code'], group_attrs.get('title', None))

        for vehicle_el in group_el.findall('./vehicles/vehicle'):
            declared_system = vehicle_el.attrib.get('code_system', group_system)
            declared_code = vehicle_el.attrib['code'].strip()

            if not declared_code:
                continue

            registry.add(declared_code, declared_system, vehicle_el.sourceline)

        for thread_el in group_el.findall('./threads/thread'):
            thread_attrs = thread_el.attrib

            if 'vehicle_code' not in thread_attrs:
                continue

            used_system = thread_attrs.get('vehicle_code_system', group_system)
            used_code = thread_attrs['vehicle_code'].strip()

            if used_code:
                registry.check(used_code, used_system, thread_el.sourceline)


def _check_fares(channel_el):
    """Cross-check fare codes inside every ``<group>`` of the channel.

    A warning is logged for every fare code declared more than once in a
    group, and an error for every thread ``fare_code`` that was not declared
    in the same group. Empty codes (after stripping) are skipped.

    :param channel_el: root element of the document.
    """
    for group_el in channel_el.findall('group'):
        group_label = get_group_name_for_log(
            group_el.attrib['code'],
            group_el.attrib.get('title', None)
        )

        declared_codes = set()

        for fare_el in group_el.findall('./fares/fare'):
            declared = fare_el.attrib['code'].strip()

            if not declared:
                continue

            if declared in declared_codes:
                log.warning(
                    _(u'Строка {line}. Тариф с кодом "{code}" встречается повторно в группе "{group}".')
                    .format(line=fare_el.sourceline, code=declared, group=group_label)
                )

            declared_codes.add(declared)

        for thread_el in group_el.findall('./threads/thread'):
            if 'fare_code' not in thread_el.attrib:
                continue

            referenced = thread_el.attrib['fare_code'].strip()

            if not referenced or referenced in declared_codes:
                continue

            # NOTE(review): the message below contains a typo ("гуппе"). It is
            # kept as is because the text is a gettext msgid; fix it together
            # with the translation catalog.
            log.error(
                _(u'Строка {line}. Файл общего xml-формата содержит неизвестный код тарифа "{code}" '
                  u'в гуппе "{group}"')
                .format(line=thread_el.sourceline, code=referenced, group=group_label)
            )


def get_group_name_for_log(group_code, group_title):
    """Build the textual group label used in log messages.

    Falsy values (``None``, empty string) are rendered as an empty string.

    :param group_code: value of the group's ``code``.
    :param group_title: value of the group's ``title``.
    :return: string of the form ``<Group title="..." code="...">``.
    """
    shown_title = group_title if group_title else u""
    shown_code = group_code if group_code else u""

    label_template = u'<Group title="{title}" code="{code}">'
    return label_template.format(title=shown_title, code=shown_code)


class Stations(object):
    """Registry of the station codes seen inside one ``<group>``.

    Codes in a vendor system (see ``VENDOR_CODES``) are validated only against
    the stations declared in the same group. Codes in any other system are
    looked up in the database.
    """

    # Code systems (compared lower-cased) whose codes are checked against the
    # group's own declarations rather than the database.
    VENDOR_CODES = (u'vendor', u'temporary_vendor', u'local')

    def __init__(self, group_code, group_title):
        """Create an empty registry for the given group.

        :param group_code: group code, used only for log messages.
        :param group_title: group title, used only for log messages.
        """
        self.vendor_codes = CaseInsensitiveSet()
        self.our_codes = CaseInsensitiveSet()
        self.group_code = group_code
        self.group_title = group_title

        self.group_name_for_log = get_group_name_for_log(group_code, group_title)

    def add(self, code, system, line):
        """Register a station declared in the group's ``stations`` tag.

        Logs a warning when the same (code, system) pair was already
        registered and, for non-vendor systems, an error when the station is
        not found in the database.

        :param code: station code.
        :param system: code system of ``code``.
        :param line: source line, used in log messages.
        """
        ss = (code, system)

        if system.lower() in self.VENDOR_CODES:
            if ss in self.vendor_codes:
                log.warning(
                    _(u'Строка {line}. В теге stations код станции "{code}" '
                      u'в системе кодирования "{system}" '
                      u'встречается повторно в группе "{group}"')
                    .format(line=line, code=code, system=system, group=self.group_name_for_log)
                )
            self.vendor_codes.add(ss)
            return

        self._report_if_missing_in_db(code, system, line)

        if ss in self.our_codes:
            log.warning(
                _(u'Строка {line}. В теге stations код станции "{code}" '
                  u'в системе кодирования "{system}" '
                  u'встречается повторно в группе "{group}".')
                .format(line=line, code=code, system=system, group=self.group_name_for_log)
            )

        self.our_codes.add(ss)

    def check(self, code, system, line):
        """Check a station reference and log an error when it is unknown.

        Vendor-system codes must have been registered with ``add``; codes in
        any other system must exist in the database.

        :param code: referenced station code.
        :param system: code system of ``code``.
        :param line: source line, used in log messages.
        """
        ss = (code, system)

        if system.lower() in self.VENDOR_CODES:
            if ss not in self.vendor_codes:
                log.error(
                    _(u'Строка {line}. Неизвестная станция с кодом "{code}" '
                      u'в системе кодирования "{system}" '
                      u'в группе "{group}".')
                    .format(line=line, code=code, system=system, group=self.group_name_for_log)
                )
            return

        self._report_if_missing_in_db(code, system, line)

    def _report_if_missing_in_db(self, code, system, line):
        """Log an error when the station cannot be found in the database."""
        if not self._is_station_in_db(code, system, line):
            log.error(
                _(u'Строка {line}. В базе отсутствует станция с кодом "{code}" '
                  u'в системе кодирования "{system}".')
                .format(line=line, code=code, system=system)
            )

    def _is_station_in_db(self, code, system, line):
        """Tell whether at least one database record matches the code.

        For the ``yandex`` system the code is matched against ``Station.id``;
        for any other system it is matched against ``StationCode`` within the
        ``CodeSystem`` named ``system``. Ambiguous matches are reported with a
        warning but still count as found.

        :param code: station code.
        :param system: code system of ``code``.
        :param line: source line, used in log messages.
        :return: ``True`` when a matching record exists, ``False`` otherwise.
        """
        if system.lower() == 'yandex':
            try:
                # A code that cannot be converted to the id field's type makes
                # the lookup raise ValueError (presumably id is an integer
                # primary key); such a code cannot match any station. Both the
                # filter construction and its evaluation stay inside the try
                # because the conversion point depends on the Django version.
                found = len(Station.objects.filter(id=code))
            except ValueError:
                return False

        else:
            code_systems = CodeSystem.objects.filter(code=system)
            systems_found = len(code_systems)

            if systems_found < 1:
                return False

            if systems_found > 1:
                log.warning(
                    _(u'Строка {line}. В базе больше одной системы кодирования "{system}".')
                    .format(line=line, system=system)
                )

            # With several matching systems the first one is used.
            found = len(StationCode.objects.filter(system=code_systems[0], code=code))

        if found < 1:
            return False

        if found > 1:
            log.warning(
                _(u'Строка {line}. В базе больше одной станции с кодом "{code}" '
                  u'в системе кодирования "{system}".')
                .format(line=line, code=code, system=system)
            )

        return True


class Carriers(object):
    """Registry of the carrier codes seen inside one ``<group>``.

    Codes in one of ``KNOWN_SYSTEMS`` are looked up in the database; codes in
    any other system are validated against the carriers declared in the group.
    """

    VENDOR_SYSTEMS = (u'vendor', u'temporary_vendor', u'local')
    # Systems for which a matching ``_find_by_<system>`` lookup method exists.
    KNOWN_SYSTEMS = (u'iata', u'sirena', u'icao')

    def __init__(self, group_code, group_title):
        """Create an empty registry for the given group.

        :param group_code: group code, used only for log messages.
        :param group_title: group title, used only for log messages.
        """
        # NOTE(review): Stations and Vehicles keep their codes in
        # CaseInsensitiveSet, plain sets are used here - confirm it is intended.
        self.vendor_codes = set()
        self.our_codes = set()
        self.group_code = group_code
        self.group_title = group_title

        self.group_name_for_log = get_group_name_for_log(group_code, group_title)

    def add(self, code, system, line):
        """Register a carrier declared in the group's ``carriers`` tag.

        Logs a warning for a repeated (code, system) pair and, for known
        systems, an error when the carrier is not found in the database.

        :param code: carrier code.
        :param system: code system of ``code``.
        :param line: source line, used in log messages.
        """
        pair = (code, system)

        if system.lower() in self.KNOWN_SYSTEMS:
            self._report_if_missing_in_db(code, system, line)
            seen_pairs = self.our_codes
        else:
            # Vendor systems (VENDOR_SYSTEMS) land here.
            # TODO: for now unknown systems are handled the same way as vendor
            # ones. They really should be in the database, and this is the
            # place for an exception.
            seen_pairs = self.vendor_codes

        if pair in seen_pairs:
            log.warning(
                _(u'Строка {line}. В теге carriers код перевозчика "{code}" '
                  u'в системе кодирования "{system}" '
                  u'встречается повторно в группе "{group}".')
                .format(line=line, code=code, system=system, group=self.group_name_for_log)
            )

        seen_pairs.add(pair)

    def check(self, code, system, line):
        """Check a carrier reference and log an error when it is unknown.

        :param code: referenced carrier code.
        :param system: code system of ``code``.
        :param line: source line, used in log messages.
        """
        if system.lower() in self.KNOWN_SYSTEMS:
            self._report_if_missing_in_db(code, system, line)
            return

        if (code, system) in self.vendor_codes:
            return

        log.error(
            _(u'Строка {line}. Неизвестный перевозчик с кодом "{code}" '
              u'в системе кодирования "{system}" '
              u'в группе "{group}".')
            .format(line=line, code=code, system=system, group=self.group_name_for_log)
        )

    def _report_if_missing_in_db(self, code, system, line):
        """Log an error when the carrier cannot be found in the database."""
        if self._is_carrier_in_db(code, system, line):
            return

        log.error(
            _(u'Строка {line}. В базе отсутствует перевозчик с кодом "{code}" '
              u'в системе кодирования "{system}".')
            .format(line=line, code=code, system=system)
        )

    def _is_carrier_in_db(self, code, system, line):
        """Tell whether at least one ``Company`` matches the code.

        The lookup is dispatched to the ``_find_by_<system>`` method; more
        than one match is reported with a warning but still counts as found.

        :return: ``True`` when a matching company exists, ``False`` otherwise.
        """
        finder = getattr(self, '_find_by_%s' % system.lower())
        match_count = len(finder(code))

        if match_count < 1:
            return False

        if match_count > 1:
            log.warning(
                _(u'Строка {line}. В базе больше одного перевозчика с кодом "{code}" '
                  u'в системе кодирования "{system}".')
                .format(line=line, code=code, system=system)
            )

        return True

    def _find_by_iata(self, code):
        """Return the companies whose ``iata`` equals ``code``."""
        return Company.objects.filter(iata=code)

    def _find_by_sirena(self, code):
        """Return the companies whose ``sirena_id`` equals ``code``."""
        return Company.objects.filter(sirena_id=code)

    def _find_by_icao(self, code):
        """Return the companies whose ``icao`` equals ``code``."""
        return Company.objects.filter(icao=code)


class Vehicles(object):
    """Registry of the vehicle codes seen inside one ``<group>``.

    Codes in one of ``KNOWN_SYSTEMS`` are looked up in the database; codes in
    any other system are validated against the vehicles declared in the group.
    """

    VENDOR_SYSTEMS = (u'vendor', u'temporary_vendor', u'local')
    # Systems for which a matching ``_find_by_<system>`` lookup method exists.
    KNOWN_SYSTEMS = (u'sirena', u'oag')

    def __init__(self, group_code, group_title):
        """Create an empty registry for the given group.

        :param group_code: group code, used only for log messages.
        :param group_title: group title, used only for log messages.
        """
        self.vendor_codes = CaseInsensitiveSet()
        self.our_codes = CaseInsensitiveSet()
        self.group_code = group_code
        self.group_title = group_title

        self.group_name_for_log = get_group_name_for_log(group_code, group_title)

    def add(self, code, system, line):
        """Register a vehicle declared in the group's ``vehicles`` tag.

        Logs a warning for a repeated (code, system) pair and, for known
        systems, an error when the vehicle is not found in the database.

        :param code: vehicle code.
        :param system: code system of ``code``.
        :param line: source line, used in log messages.
        """
        pair = (code, system)

        if system.lower() in self.KNOWN_SYSTEMS:
            self._report_if_missing_in_db(code, system, line)
            seen_pairs = self.our_codes
        else:
            # Vendor systems (VENDOR_SYSTEMS) land here.
            # TODO: for now unknown systems are handled the same way as vendor
            # ones. They really should be in the database, and this is the
            # place for an exception.
            seen_pairs = self.vendor_codes

        if pair in seen_pairs:
            log.warning(
                _(u'Строка {line}. В теге vehicles код транспорта "{code}" '
                  u'в системе кодирования "{system}" '
                  u'встречается повторно в группе "{group}".')
                .format(line=line, code=code, system=system, group=self.group_name_for_log)
            )

        seen_pairs.add(pair)

    def check(self, code, system, line):
        """Check a vehicle reference and log an error when it is unknown.

        :param code: referenced vehicle code.
        :param system: code system of ``code``.
        :param line: source line, used in log messages.
        """
        if system.lower() in self.KNOWN_SYSTEMS:
            self._report_if_missing_in_db(code, system, line)
            return

        if (code, system) in self.vendor_codes:
            return

        log.error(
            _(u'Строка {line}. Неизвестный транспорт с кодом "{code}" '
              u'в системе кодирования "{system}" '
              u'в группе "{group}".')
            .format(line=line, code=code, system=system, group=self.group_name_for_log)
        )

    def _report_if_missing_in_db(self, code, system, line):
        """Log an error when the vehicle cannot be found in the database."""
        if self._is_vehicle_in_db(code, system, line):
            return

        log.error(
            _(u'Строка {line}. В базе отсутствует транспорт с кодом "{code}" '
              u'в системе кодирования "{system}".')
            .format(line=line, code=code, system=system)
        )

    def _is_vehicle_in_db(self, code, system, line):
        """Tell whether at least one ``TransportModel`` matches the code.

        The lookup is dispatched to the ``_find_by_<system>`` method; more
        than one match is reported with a warning but still counts as found.

        :return: ``True`` when a matching model exists, ``False`` otherwise.
        """
        finder = getattr(self, '_find_by_%s' % system.lower())
        match_count = len(finder(code))

        if match_count < 1:
            return False

        if match_count > 1:
            log.warning(
                _(u'Строка {line}. В базе больше одного транспорта с кодом "{code}" '
                  u'в системе кодирования "{system}".')
                .format(line=line, code=code, system=system)
            )

        return True

    def _find_by_oag(self, code):
        """Return the transport models whose ``code_en`` equals ``code``."""
        return TransportModel.objects.filter(code_en=code)

    def _find_by_sirena(self, code):
        """Return the transport models whose ``code`` equals ``code``."""
        return TransportModel.objects.filter(code=code)
