# -*- encoding: utf-8 -*-
import logging
from datetime import datetime, timedelta

import pytz

from travel.avia.avia_api.avia.cache.companies import company_cache
from travel.avia.avia_api.avia.cache.partners import partner_cache
from travel.avia.avia_api.avia.cache.stations import station_cache

# Module logger: records cache misses (stations, partners) and per-stage parse
# counts emitted while building search results below.
log = logging.getLogger(__name__)


class Tariff(object):
    """A price amount expressed in a given currency."""

    def __init__(self, currency, value):
        self.currency = currency  # currency code, used as a key into rate maps
        self.value = value  # anything float() accepts (number or numeric string)

    def converted_to(self, currency, rates, markup=1.03):
        """Return a new Tariff converted into ``currency``.

        The new value is ``markup * value * rates[self.currency] / rates[currency]``.
        The markup is applied even when ``currency == self.currency``.

        :param currency: target currency code; must be a key of ``rates``.
        :param rates: mapping of currency code -> rate. Presumably all rates
            are relative to one common base currency -- TODO confirm.
        :param markup: multiplier applied on top of the plain conversion.
            Defaults to 1.03, the factor previously hard-coded here. Its
            business meaning is not visible in this file; presumably a
            conversion margin.
        :raises KeyError: if either currency is missing from ``rates``.
        :raises ZeroDivisionError: if ``rates[currency]`` is zero.
        """
        return Tariff(
            currency,
            markup * float(self.value) * rates[self.currency] / rates[currency],
        )


class Flight(object):
    """A single flight leg between two stations.

    Note the constructor's positional order lists arrival before departure;
    callers in this file pass everything by keyword.
    """

    __slots__ = (
        'number',
        'company',
        'departs_at',
        'arrives_at',
        'departure_station',
        'arrival_station',
    )

    def __init__(self, number, company,
                 arrival_time, departure_time,
                 arrival_station, departure_station):
        self.number = number
        self.company = company  # may be None

        # Departure side.
        self.departure_station = departure_station
        self.departs_at = departure_time  # aware, tzinfo == station.pytz

        # Arrival side.
        self.arrival_station = arrival_station
        self.arrives_at = arrival_time  # aware, tzinfo == station.pytz

    def key(self):
        """Return a '|'-separated identity string for this flight.

        Format: number|departure point_key|arrival point_key|departure time.
        """
        parts = (
            self.number,
            self.departure_station.point_key,
            self.arrival_station.point_key,
            self.departs_at,
        )
        return u'|'.join(u'{}'.format(part) for part in parts)


class Trip(object):
    """An ordered, immutable sequence of flight segments (possibly empty)."""

    __slots__ = ('segments', )

    def __init__(self, segments):
        # Freeze whatever iterable we were given.
        self.segments = tuple(segments)

    def start_station(self):
        """Departure station of the first segment, or None for an empty trip."""
        if not self.segments:
            return None
        first = self.segments[0]
        return first.departure_station

    def end_station(self):
        """Arrival station of the last segment, or None for an empty trip."""
        if not self.segments:
            return None
        last = self.segments[-1]
        return last.arrival_station

    def starts_at(self):
        """Departure time of the first segment (aware datetime), or None."""
        if not self.segments:
            return None
        first = self.segments[0]
        return first.departs_at

    def duration(self):
        """Time from first departure to last arrival; zero for an empty trip."""
        if not self.segments:
            return timedelta(0)
        first, last = self.segments[0], self.segments[-1]
        return last.arrives_at - first.departs_at

    def key(self):
        """Flight numbers of all segments joined with '-'."""
        numbers = [segment.number for segment in self.segments]
        return '-'.join(numbers)


class Variant(object):
    """One offer from a partner: outbound trip, optional return trip, price.

    Variants compare equal (and hash equal) when their ``tag`` values are
    equal.
    """

    __slots__ = (
        'partner', 'tag', 'forward', 'backward', 'original_tariff', 'tariff',
        'query_time', 'is_charter', 'expires_at', 'order_data', 'ttl',
        'deep_link', 'order_link',
        'show_id',  # Set in the order handler for money (billing) logs.
        # TODO: get rid of hacks like show_id stored on the variant.
    )

    def __init__(
        self, partner, tag, forward, backward, original_tariff,
        tariff, query_time, is_charter, expires_at, deep_link, order_link, order_data=None,
    ):
        """
        :param partner: partner object; ``sort_order`` reads ``is_aviacompany``.
        :param tag: variant identifier; defines equality and hashing.
        :param forward: outbound Trip.
        :param backward: return Trip (an empty Trip for one-way offers).
        :param original_tariff: price payload before any conversion
            (SearchResults.from_json passes the raw JSON 'base_tariff').
        :param tariff: price payload (SearchResults.from_json passes raw
            JSON and reads ``tariff['value']``).
        :param query_time: partner query duration, in seconds.
        :param is_charter: charter flag.
        :param expires_at: aware datetime; a naive one makes the ttl
            subtraction below raise TypeError.
        :param order_data: optional opaque order payload.
        """
        self.partner = partner
        self.tag = tag
        self.forward = forward
        self.backward = backward
        self.original_tariff = original_tariff
        self.tariff = tariff
        self.query_time = query_time  # in seconds
        self.is_charter = is_charter
        self.expires_at = expires_at
        self.deep_link = deep_link
        self.order_link = order_link
        self.order_data = order_data

        # Whole seconds until expiry, computed once at construction time;
        # negative if expires_at is already in the past.
        self.ttl = int(
            (
                self.expires_at - pytz.UTC.localize(datetime.utcnow())
            ).total_seconds()
        )

    @property
    def sort_order(self):
        """Integer sort key packing two criteria.

        Bit 16 is set for partners that are not airlines. The low 16 bits hold
        the query time in milliseconds, capped at 65535. Presumably sorted
        ascending (airline offers and faster responses first); confirm
        against callers.
        """
        return (
            (0 if self.partner.is_aviacompany else 1 << 16) +  # 1 bit
            min(int(self.query_time * 1000), 2 ** 16 - 1)  # 16 bits
        )

    def __eq__(self, other):
        if isinstance(other, Variant):
            return self.tag == other.tag

        return self is other

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, two variants
        # with the same tag would compare both == and !=.
        return not self == other

    def __hash__(self):
        # Must agree with __eq__: equal tags => equal hashes. Without this,
        # Python 2 falls back to identity hashing (breaking sets/dicts) and
        # Python 3 makes the class unhashable.
        return hash(self.tag)


class SearchResults(object):
    """Statuses and priced variants parsed from a search daemon response."""

    def __init__(self, statuses, variants):
        self.statuses = statuses  # taken verbatim from daemon_json['status']
        self.variants = variants  # list of Variant

    @classmethod
    def from_json(cls, daemon_json):
        """Build SearchResults from a decoded daemon JSON response.

        Parsing is best-effort:
        - flights whose stations are missing from ``station_cache`` are dropped;
        - trips that reference a dropped flight are dropped;
        - variants with an unknown partner, an unparsed trip, or a missing
          required field are dropped.

        Per-stage counts are logged. Only variants whose ``tariff['value']`` is
        truthy are returned.

        :param daemon_json: dict with 'flights', 'trips', 'variants' and
            'status' keys (shape inferred from the lookups below).
        :raises KeyError: if a top-level key is missing, if a flight entry
            lacks a field, or if a variant lacks 'partner'.
        """
        json_flights = daemon_json['flights']
        flights = cls._parse_flights(json_flights)
        log.info(
            'Parsed %d/%d flights', len(flights), len(json_flights)
        )

        json_trips = daemon_json['trips']
        trips = cls._parse_trips(json_trips, flights)
        log.info('Parsed %d/%d trips', len(trips), len(json_trips))

        json_variants = daemon_json['variants']
        variants = cls._parse_variants(json_variants, trips)
        log.info(
            'Parsed %d/%d variants',
            len(variants), len(json_variants)
        )

        priced_variants = [v for v in variants if v.tariff['value']]
        log.info(
            'Left %d/%d variants with tariffs',
            len(priced_variants), len(variants)
        )

        return cls(daemon_json['status'], priced_variants)

    @staticmethod
    def _aware_dt(time_dict):
        """Localize {'local': 'YYYY-mm-ddTHH:MM:SS', 'tz': name} to an aware datetime."""
        return pytz.timezone(time_dict['tz']).localize(
            datetime.strptime(time_dict['local'], '%Y-%m-%dT%H:%M:%S')
        )

    @staticmethod
    def _station(station_id):
        """Look up a cached station; log and return the falsy result on a miss."""
        result = station_cache.by_id(station_id)
        if not result:
            log.error('Station %s is not found', station_id)
        return result

    @classmethod
    def _parse_flights(cls, json_flights):
        """Map flight key -> Flight, skipping flights with an uncached station."""
        flights = {}
        for key, flight in json_flights.items():
            dep_station = cls._station(flight['from'])
            arr_station = cls._station(flight['to'])
            if not (dep_station and arr_station):
                continue  # some station isn't cached (already logged)
            flights[key] = Flight(
                number=flight['number'],
                company=company_cache.by_id(flight['company']),
                departure_time=cls._aware_dt(flight['departs_at']),
                arrival_time=cls._aware_dt(flight['arrives_at']),
                departure_station=dep_station,
                arrival_station=arr_station,
            )
        return flights

    @staticmethod
    def _parse_trips(json_trips, flights):
        """Map trip key -> Trip, skipping trips that reference a dropped flight."""
        trips = {}
        for trip_key, flight_keys in json_trips.items():
            try:
                segments = [flights[key] for key in flight_keys]
            except KeyError:  # some flight isn't parsed
                continue
            trips[trip_key] = Trip(segments)
        return trips

    @classmethod
    def _parse_variants(cls, json_variants, trips):
        """Build Variant objects, skipping unknown partners, unparsed trips,
        and entries with missing fields."""
        variants = []
        for json_variant in json_variants:
            partner_code = json_variant['partner']
            partner = partner_cache.by_code(partner_code)
            if not partner:
                log.error('Partner %s is not found', partner_code)
                continue

            # Missing trips are expected (their flights or stations were
            # dropped earlier), so skip them quietly; counts are logged by
            # the caller.
            try:
                forward = trips[json_variant['forward']]
                backward_key = json_variant['backward']
                backward = trips[backward_key] if backward_key else Trip([])
            except KeyError:  # some trip isn't parsed
                continue

            # Any other KeyError here means a malformed variant payload (or an
            # unknown tz: pytz's UnknownTimeZoneError subclasses KeyError).
            # Skip the variant but leave a trace instead of swallowing it.
            try:
                variant = Variant(
                    partner=partner,
                    tag=json_variant['tag'],
                    forward=forward,
                    backward=backward,
                    original_tariff=json_variant['base_tariff'],
                    tariff=json_variant['tariff'],
                    query_time=json_variant['query_time'],
                    is_charter=json_variant['charter'],
                    expires_at=cls._aware_dt(json_variant['expires_at']),
                    deep_link=json_variant['deep_link'],
                    order_link=json_variant['order_link'],
                    order_data=json_variant['order_data'],
                )
            except KeyError as exc:
                log.warning(
                    'Variant of partner %s is skipped: %r', partner_code, exc
                )
                continue

            variants.append(variant)
        return variants
