# -*- coding: utf-8 -*-
import logging

from django.utils.functional import cached_property

from travel.avia.library.python.common.models.geo import Point
from travel.avia.library.python.common.models.tariffs import AeroexTariff
from travel.avia.library.python.common.utils.date import UTC_TZ

from travel.avia.avia_api.avia.v1.model.db import db

log = logging.getLogger(__name__)


class Aeroexpress(db.EmbeddedDocument):
    """Embedded document describing a single aeroexpress trip.

    ``from_segment`` stores ``datetime`` and ``arrival_datetime`` as naive
    UTC values; the ``local_*`` properties convert them back to the
    timezone of the departure station.
    """

    # Point key of the departure station (see ``from_segment``).
    station = db.StringField()
    # Departure and arrival moments, naive UTC when built by ``from_segment``.
    datetime = db.DateTimeField()
    arrival_datetime = db.DateTimeField()

    # Tariffs copied from AeroexTariff by ``from_segment``; either may be
    # None when no matching tariff was found.
    price_standart = db.IntField()
    price_business = db.IntField()
    price_currency = db.StringField(default='RUR')

    NOTIFICATION_WAITING = ''
    NOTIFICATION_SENDING = 'sending'
    NOTIFICATION_SENT = 'sent'
    NOTIFICATION_CANCELLED = 'cancelled'

    # deprecated
    notification_status = db.StringField(choices=(
        (NOTIFICATION_WAITING, u"Не отправлялось"),
        (NOTIFICATION_SENDING, u"Отправляется"),
        (NOTIFICATION_SENT, u"Отправлено"),
        (NOTIFICATION_CANCELLED, u'Нотификация отменена'),
    ), default=NOTIFICATION_WAITING)

    @classmethod
    def get_tariffs(cls, segment):
        """Look up the aeroexpress tariffs for the segment's station pair.

        :param segment: object exposing ``station_from`` and ``station_to``.
        :return: tuple ``(price_standart, price_business)``; an element is
            None when no tariff of the corresponding type was found.
        """
        tarifftypes = ['aeroexpress_i', 'aeroexpress_b']

        price_standart = None
        price_business = None

        aeroex_tarif = AeroexTariff.objects.filter(
            station_from=segment.station_from,
            station_to=segment.station_to,
            type__code__in=tarifftypes
        )

        # If several tariffs share a type code, the last one iterated wins.
        for a in aeroex_tarif:
            if a.type.code == 'aeroexpress_i':
                price_standart = a.tariff

            elif a.type.code == 'aeroexpress_b':
                price_business = a.tariff

        return price_standart, price_business

    @classmethod
    def from_segment(cls, segment):
        """Build an instance from a segment.

        :param segment: object exposing ``station_from``, ``station_to``
            and timezone-aware ``departure`` / ``arrival`` datetimes.
        :return: new ``Aeroexpress`` with times stored as naive UTC and
            prices taken from ``get_tariffs``.
        """
        price_standart, price_business = cls.get_tariffs(segment)

        def naive_utc(dt):
            # Convert to UTC first, then drop tzinfo so a naive value is
            # stored.
            return dt.astimezone(UTC_TZ).replace(tzinfo=None)

        return cls(
            station=segment.station_from.point_key,
            datetime=naive_utc(segment.departure),
            arrival_datetime=naive_utc(segment.arrival),
            price_standart=price_standart,
            price_business=price_business,
        )

    @cached_property
    def station_point(self):
        """Return the Point for ``self.station``, or None if lookup fails.

        Any lookup error is logged and swallowed. The result, including
        a None from a failed lookup, is cached on the instance.
        """
        try:
            return Point.get_by_key(self.station)

        except Exception:
            log.exception('Get station_point by key %r error', self.station)
            return None

    @property
    def local_departure_datetime(self):
        """Departure time in the station's timezone.

        Returns None when the departure time is not set or the station
        point cannot be resolved.
        """
        # The field is optional; without this guard UTC_TZ.localize(None)
        # raises. Mirrors the check in local_arrival_datetime.
        if not self.datetime:
            return None

        if not self.station_point:
            return None

        return (
            UTC_TZ.localize(self.datetime).astimezone(self.station_point.pytz)
        )

    @property
    def local_arrival_datetime(self):
        """Arrival time in the departure station's timezone.

        Returns None when the arrival time is not set or the station
        point cannot be resolved.
        """
        if not self.arrival_datetime:
            return None

        if not self.station_point:
            return None

        return (
            UTC_TZ.localize(self.arrival_datetime)
                  .astimezone(self.station_point.pytz)
        )

    def __unicode__(self):
        """Debug representation: station, stored times, notification status."""
        return '<Aeroexpress %s [%s]-[%s] (%s)>' % (
            self.station,

            self.datetime.strftime('%Y-%m-%d %H:%M')
            if self.datetime else None,

            self.arrival_datetime.strftime('%Y-%m-%d %H:%M')
            if self.arrival_datetime else None,

            self.notification_status,
        )
