# coding: utf8

from __future__ import absolute_import, unicode_literals

import logging
from time import strptime
from datetime import datetime, time, date

from common.models.schedule import (Supplier, RThread, RTStation, TrainSchedulePlan, Route, ExpressTypeLite,
                                    RThreadType, Company)
from common.models.geo import Station, StationCode, CodeSystem, StationMajority
from common.models.tariffs import TariffType
from common.models.transport import TransportType, TransportModel, TransportSubtype
from travel.rasp.library.python.common23.date import environment
from common.utils.caching import cache_method_result
from common.utils.date import MSK_TIMEZONE
from travel.rasp.admin.lib.exceptions import SimpleUnicodeException
from travel.rasp.admin.lib.mask_builder.ycal_builders import CountryRequiredError
from travel.rasp.admin.scripts.support_methods import esr_leading_zeros
from travel.rasp.admin.scripts.schedule.af_processors.parse_utils import get_yesno_param, get_date_param
from travel.rasp.admin.scripts.schedule.af_processors.af_multiparam_mask_parser import make_runmask, get_calendar_country, compute_bounds


log = logging.getLogger(__name__)


class BadESRCodeError(SimpleUnicodeException):
    """Error raised when no station could be found for one or more codes.

    The codes are passed as positional arguments and are joined into a single
    human-readable message that is handed to the base exception.
    """

    def __init__(self, *args):
        # NOTE(review): self.args is read before the base initializer runs;
        # this presumably relies on the built-in exception machinery having
        # already stored the constructor arguments there - confirm against
        # SimpleUnicodeException.
        joined_codes = ', '.join(self.args)
        super(BadESRCodeError, self).__init__('Не нашли станции с кодами: ' + joined_codes)


class AFThreadParseError(SimpleUnicodeException):
    """Error raised when a thread element refers to data that cannot be resolved."""


def get_af_minutes(param):
    """Convert a raw attribute value into an integer.

    An empty/missing value and the placeholder ``'-'`` both yield 0; any
    other value is passed to ``int`` (which raises ``ValueError`` for
    non-numeric text).
    """
    if param and param != '-':
        return int(param)
    return 0


def get_station(thread_el, st_params, default_region_id, bad_esr_codes):
    """Resolve the Station described by the attributes of a station element.

    Lookup order:
      1. ``station_id`` attribute - fetched by primary key;
      2. ``vendor_id`` attribute together with the thread's ``vendor``
         attribute - fetched through ``Station.get_by_code``;
      3. ``esrcode`` attribute - normalized with ``esr_leading_zeros`` (the
         normalized value is written back into ``st_params``) and looked up
         in the ``'esr'`` code system.

    When the ESR lookup fails and ``default_region_id`` is set, the code is
    attached to a station found by title via the filter below, or a brand
    new station is created and saved in the default region.
    Without ``default_region_id`` the code is appended to ``bad_esr_codes``
    and ``None`` is returned.
    """
    if st_params.get('station_id'):
        return Station.objects.get(pk=st_params['station_id'])

    if st_params.get('vendor_id') and thread_el.get('vendor'):
        return Station.get_by_code(thread_el.get('vendor'), st_params.get('vendor_id'))

    st_params['esrcode'] = esr_leading_zeros(st_params['esrcode'])
    esr_code = st_params['esrcode']

    try:
        return Station.get_by_code('esr', esr_code)
    except Station.DoesNotExist:
        # Fall through to the recovery logic below.
        pass

    if not default_region_id:
        # Nothing to fall back on: report the code to the caller's collector.
        bad_esr_codes.append(esr_code)
        return None

    same_title_stations = Station.objects.filter(title=st_params['stname'],
                                                 code_set__code=None,
                                                 code_set__system__code='esr')
    if same_title_stations:
        existing = same_title_stations[0]
        existing.set_code('esr', esr_code)
        return existing

    # No suitable station exists: create one in the default region.
    created = Station()
    created.title = st_params['stname']
    created.t_type_id = 1
    created.region_id = default_region_id
    created.time_zone = created.region.time_zone
    created.majority = StationMajority.objects.get(code='station')
    created.save()

    esr_system = CodeSystem.objects.get(code='esr')
    StationCode(system=esr_system, code=esr_code, station=created).save()

    log.info('Добавили станцию %s с кодом %s в область %s',
             created.title, esr_code, created.region)
    return created


def get_raw_stations(thread_el, default_region_id):
    """Pair every ``./stations/station`` child of the thread with its Station.

    Returns a list of ``(station, station_element)`` tuples in document
    order. Elements whose station could not be resolved are skipped; if any
    unresolved codes were collected along the way, ``BadESRCodeError`` is
    raised with all of them after the whole list has been processed.
    """
    unresolved_codes = []
    resolved_pairs = []

    for element in thread_el.findall('./stations/station'):
        found = get_station(thread_el, element.attrib, default_region_id, unresolved_codes)
        if found:
            resolved_pairs.append((found, element))

    if unresolved_codes:
        raise BadESRCodeError(*unresolved_codes)

    return resolved_pairs


def set_rtstation_flag(rts, station_el, flag_name):
    """Set a boolean flag on ``rts`` from the element or from its station.

    If the element carries the ``flag_name`` attribute, the flag is True
    exactly when the attribute equals ``'1'``; otherwise the value is copied
    from the attribute of the same name on ``rts.station``.
    """
    if flag_name not in station_el.attrib:
        flag_value = getattr(rts.station, flag_name)
    else:
        flag_value = station_el.get(flag_name) == '1'
    setattr(rts, flag_name, flag_value)


def parse_stations(thread_el, start_pytz, default_region_id, mask):
    """Build the RTStation list and the start time for a thread element.

    :param thread_el: thread element whose ``./stations/station`` children
        describe the stops.
    :param start_pytz: object whose ``zone`` attribute is stored as the
        ``time_zone`` of every created RTStation.
    :param default_region_id: forwarded to ``get_raw_stations``.
    :param mask: run mask; when truthy its first date is taken as the
        reference day, otherwise today's date is used.
    :return: ``(rtstations, tz_start_time)`` on success, or ``None`` when the
        first station has a missing/invalid ``departure_time`` or fewer than
        two stations were produced.
    :raises BadESRCodeError: propagated from ``get_raw_stations``.
    """
    if mask:
        # Use the next() builtin: the iterator method .next() exists only in
        # Python 2.
        first_day = next(mask.iter_dates())
        log.info('Берем первый день хождения %s для вычисления времени в пути', first_day)
    else:
        first_day = date.today()
        log.info('Маска пустая, берем сегодняшнюю дату %s для вычисления времени в пути', first_day)

    rtstations = []
    tz_start_time = None

    # Running offset from the start of the thread, taken from the
    # minutes_from_start / stop_time attributes.
    duration = 0

    for station, station_el in get_raw_stations(thread_el, default_region_id):
        if tz_start_time is None:
            # First station: it defines the departure time of the thread.
            if station_el.get('departure_time'):
                try:
                    tz_start_time = time(*strptime(station_el.get('departure_time'), '%H:%M')[3:6])
                except ValueError:
                    log.exception('Неверный формат даты в нитке с id %s',
                                  thread_el.get('number'))
                    return

                # Combining with the reference day and taking the time back
                # leaves the parsed time value unchanged.
                naive_start_dt = datetime.combine(first_day, tz_start_time)
                tz_start_time = naive_start_dt.time()

                stop_time = 0

            else:
                log.error('У нитки с id %s не указано время отправления',
                          thread_el.get('id'))
                return
        else:
            try:
                stop_time = get_af_minutes(station_el.get('stop_time', '0'))
                mfs = get_af_minutes(station_el.get('minutes_from_start'))
                if mfs == 0:
                    # No offset given: advance by one from the previous
                    # departure so that the offsets keep increasing.
                    duration += 1
                else:
                    duration = mfs
            except (ValueError, KeyError):
                log.error('Не верно указано время в минутах со старта'
                          ' или стоянка для станции %s в нитке с id %s',
                          station_el.get('esrcode'), thread_el.get('id'))
                # The station with unparsable timings is skipped entirely.
                continue

        rtstation = RTStation()
        rtstation.time_zone = start_pytz.zone
        rtstation.tz_arrival = duration
        rtstation.tz_departure = duration + stop_time
        rtstation.station = station

        rtstation.platform = station_el.get('platform')
        rtstation.is_combined = station_el.get('is_combined') == '1'

        set_rtstation_flag(rtstation, station_el, 'is_searchable_from')
        set_rtstation_flag(rtstation, station_el, 'is_searchable_to')
        set_rtstation_flag(rtstation, station_el, 'in_station_schedule')

        rtstation.is_virtual_end = bool(int(station_el.get('important') or 0))

        # The next station's offset is counted from this station's departure.
        duration += stop_time

        rtstations.append(rtstation)

    if len(rtstations) >= 2:
        # The first stop has no arrival and the last stop has no departure.
        rtstations[0].tz_arrival = None
        rtstations[-1].tz_departure = None
    else:
        # .get() instead of attrib[...]: a missing "number" attribute must
        # not turn this error report into a KeyError.
        log.error('Меньше 2 станций в нитке %s', thread_el.get('number'))
        return

    return rtstations, tz_start_time


class AfThreadParser(object):
    """Parser that turns a thread element into an ``RThread`` object.

    The element is accessed through ``get`` / ``findall`` / ``attrib``.
    Defaults supplied to the constructor are used when the element does not
    carry the corresponding attribute. The reference-data lookups are
    wrapped in ``cache_method_result`` (presumably memoizing per argument
    set - confirm against the decorator's implementation).
    """

    def __init__(self, default_region_id=None, default_t_type_code=None, default_company_id=None,
                 default_supplier=None):
        """Store the fallback values used by the lookup helpers.

        :param default_region_id: forwarded to the station helpers.
        :param default_t_type_code: used when the element has no ``t_type``.
        :param default_company_id: used when the element has no ``career``.
        :param default_supplier: used when the element has no ``supplier``.
        """
        self._default_t_type_code = default_t_type_code
        self._default_region_id = default_region_id
        self._default_company_id = default_company_id
        self._default_supplier = default_supplier

    def parse_thread(self, thread_el):
        """Build an ``RThread`` from ``thread_el``.

        :return: the populated thread, or ``None`` when the thread is marked
            as cancelled, when the run mask cannot be built because the
            calendar country is missing, or when station parsing failed.
        :raises AFThreadParseError: from the schedule plan / transport type
            lookups.
        :raises BadESRCodeError: propagated from the station helpers.
        """
        # Skip cancelled threads.
        if get_yesno_param(thread_el, 'cancel'):
            log.info('Нитка %s отменена', thread_el.get('number'))
            return

        thread = RThread()
        thread.thread_el = thread_el

        thread.template_text = thread_el.get('daystr', '').strip()
        thread.time_zone = thread_el.get('time_zone', MSK_TIMEZONE)
        # Default; overridden below when the week template carries a
        # "<timezone>#<code>" prefix.
        thread.template_timezone = thread.time_zone

        thread.template_code = thread_el.get('weektemplate', '').strip()
        if '#' in thread.template_code:
            tz, template_code = thread.template_code.split('#')
            thread.template_timezone = tz or thread.time_zone
            thread.template_code = template_code

        thread.template_start = get_date_param(thread_el, 'weektemplate_start')
        thread.template_end = get_date_param(thread_el, 'weektemplate_end')

        thread.type = self.get_basic_type()
        thread.t_type = self.get_t_type(thread_el)
        thread.t_subtype = self.get_t_subtype(thread_el, thread.t_type)
        thread.t_model = self.get_t_model(thread_el)

        thread.number = thread_el.get('number')
        thread.changemode = thread_el.get('changemode')

        thread.company = self.get_company(thread_el)
        thread.express_type = self.get_express_type(thread_el)
        thread.express_lite = self.get_express_lite(thread_el)

        thread.period_start = get_date_param(thread_el, 'period_start')
        thread.period_end = get_date_param(thread_el, 'period_end')

        thread.schedule_plan = self.get_schedule_plan(thread_el)

        # A missing or empty attribute yields None (not an empty string).
        tariff_code = thread_el.get('tariff_type')
        thread.tariff_type = TariffType.objects.get(code=tariff_code) if tariff_code else None

        supplier_code = thread_el.get('supplier')
        thread.supplier = (Supplier.objects.get(code=supplier_code)
                           if supplier_code
                           else self._default_supplier)
        thread.is_circular = thread_el.get('routemode') == 'circular'
        thread.is_combined = thread_el.get('is_combined') == '1'

        country = get_calendar_country(thread_el)
        bounds = compute_bounds(thread_el, thread.schedule_plan)
        try:
            thread.mask = make_runmask(thread_el, bounds, country, environment.today())
        except CountryRequiredError:
            log.error('Для построения календаря нужно обязательно указать calendar_geobase_country_id')
            return
        thread.year_days = str(thread.mask)

        # thread.route.supplier is required by gen_route_uid.
        # TODO: drop the route from here once the supplier is always set on the thread.
        route = Route()
        route.supplier = thread.supplier
        thread.route = route

        if thread.changemode == 'delay':
            # Only the resolved stations are attached in this mode.
            thread.raw_stations = get_raw_stations(thread_el, self._default_region_id)

        elif thread_el.findall('./stations/station'):
            result = parse_stations(thread_el, thread.pytz, self._default_region_id, thread.mask)
            if result is None:
                log.error('Ошибка при разборе нитки поезда %s', thread.number)
                return
            rtstations, tz_start_time = result

            thread.tz_start_time = tz_start_time
            thread.rtstations = rtstations

            if not (thread.template_text or thread.template_code):
                if thread.changemode in ('insert', 'add'):
                    log.warning('Шаблон дней хождений пустой')
                else:
                    log.info('Шаблон дней хождений пустой')

            self.gen_title(thread, thread_el)

        if thread_el.get('thread'):
            thread.uid = thread_el.get('thread')
        elif thread_el.get('canonical') and thread_el.get('threaddate'):
            # Take the uid of the first thread with this canonical uid that
            # runs on the given date.
            threads = RThread.objects.filter(canonical_uid=thread_el.get('canonical'))
            for t in threads:
                if t.runs_at(datetime.strptime(thread_el.get('threaddate'), "%Y-%m-%d").date()):
                    thread.uid = t.uid
                    break

        return thread

    # Title attributes that may be overridden from the element when
    # is_manual_title="1".
    TITLE_ATTRS = ['title', 'title_tr', 'title_uk', 'title_short']

    def gen_title(self, thread, thread_el):
        """Generate the thread title, then apply manual overrides if requested.

        When the element has ``is_manual_title="1"`` every attribute listed in
        ``TITLE_ATTRS`` that is present on the element replaces the generated
        value; absent attributes keep the generated value.
        """
        thread.gen_title()

        if thread_el.get('is_manual_title') == '1':
            log.info('Используем название нитки из файла')
            thread.is_manual_title = True

            for attr in self.TITLE_ATTRS:
                attr_value = thread_el.get(attr, getattr(thread, attr))
                setattr(thread, attr, attr_value)

    def get_schedule_plan(self, thread_el):
        """Return the schedule plan named by the ``graph`` attribute, if any."""
        graph_code = thread_el.get('graph', '').strip()
        return self._get_schedule_plan(graph_code)

    @cache_method_result
    def _get_schedule_plan(self, graph_code):
        """Look up a ``TrainSchedulePlan`` by code; ``None`` for an empty code.

        :raises AFThreadParseError: when no plan has the given code.
        """
        if graph_code:
            try:
                return TrainSchedulePlan.objects.get(code=graph_code)
            except TrainSchedulePlan.DoesNotExist:
                raise AFThreadParseError('Не нашли график с кодом {}'.format(graph_code))

    def get_t_type(self, thread_el):
        """Return the transport type named by the ``t_type`` attribute."""
        return self._get_t_type(thread_el.get('t_type', '').strip())

    @cache_method_result
    def _get_t_type(self, t_type_code):
        """Look up a ``TransportType`` by code, falling back to the default code.

        The obsolete ``local_train`` code is replaced with ``suburban``.

        :raises AFThreadParseError: when the type does not exist.
        """
        t_type_code = t_type_code or self._default_t_type_code

        if t_type_code == 'local_train':
            log.warning(
                'Тип траспорта local_train больше не существует, нужно заменить его на suburban')
            log.info('Автоматически подменяем local_train на suburban')
            t_type_code = 'suburban'

        try:
            return TransportType.objects.get(code=t_type_code)
        except TransportType.DoesNotExist:
            raise AFThreadParseError('Не указан тип транспорта, или указан не верно t_type="{}"'
                                     .format(t_type_code))

    def get_t_model(self, thread_el):
        """Return the transport model named by the ``t_model`` attribute, if any."""
        return self._get_t_model(thread_el.get('t_model', ''))

    @cache_method_result
    def _get_t_model(self, t_model_code):
        """Fetch a ``TransportModel`` by primary key; ``None`` for an empty value."""
        return TransportModel.objects.get(pk=t_model_code) if t_model_code else None

    def get_company(self, thread_el):
        """Return the company from the ``career`` attribute or the default id."""
        company_id = thread_el.get('career', None)
        if company_id is None:
            company_id = self._default_company_id
        else:
            company_id = company_id.strip()

        return self._get_company(company_id)

    @cache_method_result
    def _get_company(self, company_id):
        """Fetch a non-hidden ``Company`` by primary key.

        Returns ``None`` for an empty id, and also (after logging an error)
        when the company does not exist or is hidden.
        """
        if not company_id:
            return
        try:
            return Company.objects.get(pk=company_id, hidden=False)
        except Company.DoesNotExist:
            # Pass the id so that the %s placeholder is actually filled in.
            log.error('Компания с кодом %s отсутствует в базе, или скрыта', company_id)
            return

    def get_express_lite(self, thread_el):
        """Return the ``ExpressTypeLite`` named by ``express_subtype``, if any."""
        return self._get_express_lite(thread_el.get('express_subtype', '').strip())

    @cache_method_result
    def _get_express_lite(self, code):
        """Look up an ``ExpressTypeLite`` by code.

        Returns ``None`` for an empty code or (after a warning) when the code
        is unknown.
        """
        if not code:
            return None

        try:
            return ExpressTypeLite.objects.get(code=code)
        except ExpressTypeLite.DoesNotExist:
            log.warning('Не нашли подтип экспресса %s', code)

    def get_t_subtype(self, thread_el, t_type):
        """Return the transport subtype for ``express_subtype`` or the type default."""
        code = thread_el.get('express_subtype', '').strip()
        return self._get_t_subtype(code, t_type)

    @cache_method_result
    def _get_t_subtype(self, code, t_type):
        """Look up a ``TransportSubtype``.

        With a non-empty ``code`` the subtype is fetched by code; an unknown
        code is logged and yields ``None`` (the default is *not* used then).
        With an empty code the default subtype of ``t_type`` is returned, or
        ``None`` when the type has no default.
        """
        if code:
            try:
                return TransportSubtype.objects.get(code=code)
            except TransportSubtype.DoesNotExist:
                log.error('Не нашли подтипа для express_subtype="%s"', code)
                return

        t_subtype_id = TransportSubtype.get_default_subtype_id(t_type.id)
        if t_subtype_id:
            return TransportSubtype.objects.get(pk=t_subtype_id)

    def get_express_type(self, thread_el):
        """Map the ``express_type`` attribute to the stored express type.

        ``'aeroexpress'`` maps to ``'aeroexpress'``, any value starting with
        ``'express'`` maps to ``'express'``; everything else (including
        ``'etrain'`` and a missing attribute) maps to ``None``.
        """
        raw_type = thread_el.get('express_type')
        if raw_type == 'aeroexpress':
            return 'aeroexpress'
        if raw_type and raw_type.startswith('express'):
            return 'express'
        return None

    @cache_method_result
    def get_basic_type(self):
        """Return the ``RThreadType`` with primary key ``RThreadType.BASIC_ID``."""
        return RThreadType.objects.get(pk=RThreadType.BASIC_ID)
