# -*- coding: utf-8 -*-

import travel.rasp.admin.scripts.load_project  # noqa

import argparse
import logging
import sys

from django.db import transaction

from common.models.geo import Station
from common.models.tariffs import AeroexTariff, TariffType
from travel.rasp.library.python.common23.date import environment
from common.utils.safe_xml_parser import safe_parse_xml

from travel.rasp.admin.importinfo.models import TariffFile
from travel.rasp.admin.lib.logs import create_current_file_run_log, get_collector_context, print_log_to_stdout, get_script_log_context, ylog_context
from travel.rasp.admin.scripts.support_methods import esr_leading_zeros


# Module-level logger shared by every function in this script.
log = logging.getLogger(__name__)


def get_esr_station(esr_code):
    """Look up a Station by its ESR code.

    The raw code is normalized with ``esr_leading_zeros`` before the lookup.

    :param esr_code: ESR code as read from the tariff XML attribute.
    :return: the matching Station, or None when no station has that code
        (an error is logged in that case, so callers can simply skip the row).
    """
    try:
        return Station.get_by_code('esr', esr_leading_zeros(esr_code))
    except Station.DoesNotExist:
        # "Station with ESR code %s not found" (typo "станию" fixed).
        log.error(u'Не нашли станцию с esr кодом %s', esr_code)
        return None


def get_tariff_type(tariff_el, tariff_field='type'):
    """Resolve the TariffType referenced by an XML tariff element.

    :param tariff_el: XML element whose attribute holds the tariff type code.
    :param tariff_field: name of the attribute to read (``'type'`` by default).
    :return: the TariffType whose ``code`` equals the attribute value.
    :raises TariffType.DoesNotExist: when no such type exists; the error is
        logged before being re-raised.
    """
    type_code = tariff_el.get(tariff_field)
    try:
        return TariffType.objects.get(code=type_code)
    except TariffType.DoesNotExist:
        # "Tariff type with code %s does not exist"
        log.error(u'Типа тарифа с кодом %s не существует', type_code)
        raise


def import_tariff_xml(xml_root):
    """Create, update or delete non-precalculated AeroexTariff rows from tariff XML.

    Each ``<tarif>`` element under ``xml_root`` is processed independently:

    * ``station1`` / ``station2`` -- ESR codes of the endpoints; the element is
      skipped when either station cannot be found.
    * ``price`` -- tariff value; a non-positive price means "no tariff": an
      existing row is deleted and no new row is created.
    * ``reverse`` -- ``'1'`` marks the tariff as reversed.
    * ``insearch`` -- ``'1'`` (the default when absent) enables suburban search.
    * ``currency`` -- optional; an empty value is stored as None.
    * ``type`` / ``replace_tariff_type`` -- TariffType codes.

    :param xml_root: parsed XML (Element or ElementTree) containing ``<tarif>`` elements.
    :raises TariffType.DoesNotExist: when a referenced tariff type is unknown.
    """
    # iter() replaces getiterator(), which is deprecated and removed from
    # xml.etree in Python 3.9; both ElementTree and lxml provide iter().
    for tariff_el in xml_root.iter('tarif'):
        station_from = get_esr_station(tariff_el.get('station1'))
        station_to = get_esr_station(tariff_el.get('station2'))
        if station_from is None or station_to is None:
            continue

        price = float(tariff_el.get('price'))
        reverse = tariff_el.get('reverse') == u'1'
        insearch = tariff_el.get('insearch', u'1') == u'1'
        currency = tariff_el.get('currency') or None

        tariff_type = get_tariff_type(tariff_el)

        # Shared arguments for the create/update/delete log lines below.
        log_args = (station_from.title, station_to.title, tariff_type.title, price,
                    u' reversed' if reverse else u'')

        try:
            tariff = AeroexTariff.objects.get(station_from=station_from, station_to=station_to,
                                              type=tariff_type, precalc=False)
        except AeroexTariff.DoesNotExist:
            tariff = None

        # A non-positive price is never stored. New tariffs with such a price were
        # already skipped; existing ones are now deleted for any price <= 0 instead
        # of being overwritten with a negative value.
        if price <= 0:
            if tariff is not None:
                tariff.delete()
                log.info(u'Удалили тариф %s - %s %s %s%s', *log_args)
            continue

        is_new = tariff is None
        if is_new:
            tariff = AeroexTariff(station_from=station_from, station_to=station_to, precalc=False)

        tariff.type = tariff_type
        tariff.tariff = price
        tariff.reverse = reverse
        tariff.suburban_search = insearch
        tariff.currency = currency

        # Only overwritten when the attribute is present; otherwise an existing
        # row keeps its previous replace_tariff_type.
        if tariff_el.get('replace_tariff_type'):
            tariff.replace_tariff_type = get_tariff_type(tariff_el, 'replace_tariff_type')

        tariff.save()

        if is_new:
            log.info(u'Создали тариф %s - %s %s %s%s', *log_args)
        else:
            log.info(u'Обновили тариф %s - %s %s %s%s', *log_args)


@transaction.atomic
def import_tariff(tarifffile):
    """Import one TariffFile and record the outcome on the TariffFile row.

    Tariff changes are applied inside a nested atomic block (a savepoint). If
    the import raises, only those changes are rolled back and the error is
    logged. The TariffFile itself is still saved with the collected log, and
    ``imported`` is set to the current time only on success.

    :param tarifffile: TariffFile instance whose XML should be imported.
    """
    with get_collector_context() as collector:
        log.info(u'Импортируем файл тарифа %s %s', tarifffile.id,
                 tarifffile.tariff_file_name)
        try:
            # A nested atomic() creates a savepoint inside the outer transaction and
            # rolls back to it on error, replacing manual savepoint bookkeeping.
            with transaction.atomic():
                import_tariff_xml(tarifffile.get_xml_root())
        except Exception:
            # Deliberately broad: any failure of a single file is logged into
            # the file's load_log rather than aborting the caller.
            log.exception(u"Ошибка при импортировании тарифа %s %s", tarifffile.id,
                          tarifffile.tariff_file_name)
        else:
            tarifffile.imported = environment.now()

        tarifffile.load_log = collector.get_collected()
        tarifffile.save()


def main(command_args):
    """Command-line entry point for the tariff import.

    Positional arguments are XML file paths, or TariffFile ids when
    ``--tarifffile-ids`` is given.

    :param command_args: argument list without the program name.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('-v', '--verbose', action='store_true')
    arg_parser.add_argument('--tarifffile-ids', action='store_true')
    arg_parser.add_argument('--no-run-log', action='store_true')
    arg_parser.add_argument('files', metavar='FILE_OR_ID', nargs='+')
    options = arg_parser.parse_args(command_args)

    if options.verbose:
        print_log_to_stdout()
    if not options.no_run_log:
        create_current_file_run_log()

    if not options.tarifffile_ids:
        # Plain XML files: parsed and imported directly.
        for xml_path in options.files:
            import_tariff_xml(safe_parse_xml(xml_path))
        return

    # TariffFile ids: each file goes through import_tariff, which records the result.
    for file_id in options.files:
        import_tariff(TariffFile.objects.get(id=file_id))


if __name__ == '__main__':
    # Run the script inside the standard script logging context.
    script_log_context = get_script_log_context()
    with ylog_context(**script_log_context):
        main(sys.argv[1:])
