# -*- coding: utf-8 -*-
import ujson as json
import zlib
from abc import ABCMeta
from datetime import timedelta
from logging import getLogger
from operator import attrgetter

import opentracing
from django.conf import settings
from travel.avia.library.python.ticket_daemon.date import unixtime, get_utc_now

from travel.avia.ticket_daemon.ticket_daemon.api.result.fare_families import FareFamiliesProvider
from travel.avia.ticket_daemon.ticket_daemon.api.result import cache_backends
from travel.avia.ticket_daemon.ticket_daemon.api.result.abstract_result import AbstractResult
from travel.avia.ticket_daemon.ticket_daemon.api.result.serializer import fares_serializer, flights_serializer
from travel.avia.ticket_daemon.ticket_daemon.lib.min_prices import PartnerMinPrices
from travel.avia.ticket_daemon.ticket_daemon.lib.partner_store_time_provider import partner_store_time_provider
from travel.avia.ticket_daemon.ticket_daemon.lib.utils import group_by

logger = getLogger(__name__)


class Statuses(object):
    """String constants describing the state of a partner query.

    ``in_progress`` tells whether a given status value belongs to the
    group this class treats as still running (``QUERYING`` / ``OUTDATED``).
    """

    # Statuses that count as "in progress".
    QUERYING = 'querying'
    OUTDATED = 'outdated'

    # Remaining statuses.
    DONE = 'done'
    SKIP = 'skip'
    EMPTY = 'empty'
    NONE = 'none'
    FAIL = 'fail'

    _in_progress = set([QUERYING, OUTDATED])

    @classmethod
    def in_progress(cls, status):
        """Return True if ``status`` is ``QUERYING`` or ``OUTDATED``."""
        running = cls._in_progress
        return status in running


def pack(unpacked):
    """Serialize ``unpacked`` to JSON and compress it with zlib.

    Uses the module-level ``json`` name (an alias of ``ujson`` in this file)
    and zlib's default compression level.
    """
    serialized = json.dumps(unpacked)
    return zlib.compress(serialized)


def unpack(packed):
    """Decompress a zlib blob and parse the result as JSON.

    :param packed: zlib-compressed JSON document
    :return: the decoded Python object
    """
    raw_json = zlib.decompress(packed)
    return json.loads(raw_json)


class BaseResult(AbstractResult):
    """State shared by the cached result objects of one query/partner pair.

    The constructor copies the query identity onto the instance and stamps
    ``created`` / ``expire`` from the current unixtime and ``store_time``.
    """

    __metaclass__ = ABCMeta

    # TODO: split these attributes across the subclasses according to what
    # each of them actually needs
    def __init__(
        self, query, partner, variants,
        store_time=None, query_time=.0, status=None, all_variants_count=0,
        saved_variants=()
    ):
        """
        :param query: query object; its ``id`` and ``service`` are copied onto the result
        :param partner: partner object; ``__repr__`` reads its ``code``
        :param variants: variants held by this result
        :param store_time: value added to ``created`` to obtain ``expire``
        :param query_time: stored as-is
        :param status: stored as-is
        :param all_variants_count: stored as-is
        :param saved_variants: stored as-is
        """
        # Query identity.
        self.query = query
        self.qid = query.id
        self.service = query.service
        self.partner = partner

        # Bookkeeping passed through unchanged.
        self.query_time = query_time
        self.status = status
        self.all_variants_count = all_variants_count

        # Lifetime of the result.
        # NOTE(review): ``store_time`` defaults to None, but it is added to
        # ``created`` below, so omitting it raises TypeError — confirm that
        # every caller passes an explicit value.
        self.store_time = store_time
        now = unixtime()
        self.created = now
        self.expire = now + store_time

        # Payload.
        self.variants = variants
        self.saved_variants = saved_variants

    def __repr__(self):
        """Short debug representation: class, partner code, variant count, timestamps."""
        # ``partner`` may be missing if __init__ did not run to that point.
        partner_code = self.partner.code if hasattr(self, 'partner') else ''
        variants_count = len(self.variants)
        return '<{} {}[{}] created:{} expire:{}>'.format(
            self.__class__.__name__, partner_code, variants_count, self.created, self.expire
        )


class Result(BaseResult):
    """Partner search result that is serialized and written to the YDB cache.

    ``__repr__`` is inherited from ``BaseResult``.
    """

    STORE_TIME = 7 * 24 * 60 * 60  # 1 week
    REDIRECT_DATA_STORE_TIME = 6 * 60 * 60  # 6 hours

    @staticmethod
    def _pack(data):
        """Serialize ``data`` to JSON and compress it with zlib at level 9."""
        return zlib.compress(Result._serialize(data), 9)

    @staticmethod
    def _serialize(data):
        """Serialize ``data`` with the module-level ``json`` (``ujson``)."""
        return json.dumps(data)

    def _redirect_data_variant_to_dict(self, variant):
        """
        Build the ``(tag, redirect data)`` pair for a single variant.

        :type variant: ticket_daemon.api.flights.Variant
        :return: tuple of ``variant.tag`` and a dict with the variant's
            ``order_data`` and the service of the current query
        """
        return variant.tag, {
            'order_data': variant.order_data,
            'query_source': self.query.service,
        }

    def _store_time(self):
        """Return the cache lifetime of the result, in seconds.

        It is the time from today's UTC date until two days after
        ``query.date_forward``, capped at ``STORE_TIME``.
        """
        # NOTE(review): the value is negative when date_forward lies more than
        # two days in the past — confirm the cache backend tolerates that.
        return min(
            int((self.query.date_forward - get_utc_now().date() + timedelta(days=2)).total_seconds()),
            self.STORE_TIME
        )

    def _redirect_data_store_time(self):
        """Return the cache lifetime of the redirect data, in seconds."""
        return self.REDIRECT_DATA_STORE_TIME

    def cache_key(self):
        """Return the cache key for this query/partner pair."""
        return self.make_result_key(
            self.query.qkey,
            self.partner.code
        )

    def serialize_redirect_data(self, variants):
        """Return ``{'variants': {tag: redirect data}}`` for ``variants``."""
        return {
            'variants': dict(map(self._redirect_data_variant_to_dict, variants)),
        }

    def serialize_variants(self, variants):
        """Serialize ``variants`` into the dict that is stored as the result.

        Note that ``all_variants_count`` is the number of serialized fares,
        not ``self.all_variants_count``.
        """
        fares = fares_serializer.serialize(variants, self.created, self.expire)
        flights = flights_serializer.serialize(variants)
        return {
            'created': self.created,
            'expire': self.expire,
            'query_time': self.query_time,
            'all_variants_count': len(fares),
            'flights': flights,
            'fares': fares,
            'qid': self.qid,
        }

    def meta(self, store_time):
        """Return the metadata dict stored next to the packed result.

        :param store_time: lifetime in seconds used for
            ``instant_search_expiration_time``
        """
        return {
            'created': self.created,
            'expire': self.expire,
            'instant_search_expiration_time': self.created + store_time,
            'qid': self.query.id
        }

    @staticmethod
    def make_result_key(qkey, partner_code):
        """Build the result key from the query key and the partner code."""
        # Deprecated
        return '{}/any/{}'.format(qkey, partner_code)

    def serialize_variants_and_fare_families(self, variants):
        """Serialize ``variants`` and attach fare families data when available.

        Fare families are looked up only if ``settings.FARE_FAMILIES_ENABLED``
        is set, and are added only when the provider returns something truthy.
        """
        serialized_variants = self.serialize_variants(variants)
        if settings.FARE_FAMILIES_ENABLED:
            fare_families = FareFamiliesProvider.get_fare_families_for_variants(serialized_variants)
            if fare_families:
                serialized_variants['fare_families_data'] = fare_families
        return serialized_variants

    def store(self):
        """Pack the variants, redirect data and meta and write them to the YDB cache."""
        variants = self.variants
        if self.saved_variants:
            # ``saved_variants`` defaults to a tuple while ``variants`` may be
            # a list; tuple + list raises TypeError, so both sides are
            # normalized to lists before concatenating.
            variants = list(self.saved_variants) + list(variants)

        serialized_variants = self.serialize_variants_and_fare_families(variants)
        packed_result = self._pack(serialized_variants)
        redirect_data = self._pack(self.serialize_redirect_data(variants))
        store_time = self._store_time()
        meta = self._serialize(self.meta(store_time))

        cache_backends.ydb_cache.set(
            self.query, self.partner.code, packed_result, redirect_data, meta, store_time,
            self._redirect_data_store_time(),
        )

    def to_dict(self):
        """Intentionally a no-op for ``Result``; returns None."""
        pass


class Status(Result):
    """Per-partner query status that is stored in the shared cache."""

    @staticmethod
    def make_result_key(qkey, partner_code):
        """Build the shared-cache key under which a partner status is stored."""
        key_parts = (
            settings.TICKET_DAEMON_CACHEROOT,
            qkey,
            'any',
            partner_code,
        )
        return '%s%s_%s_%s_status' % key_parts

    def cache_key(self):
        """Return the status key for this query/partner pair."""
        query_key = self.query.qkey
        partner_code = self.partner.code
        return self.make_result_key(query_key, partner_code)

    def store(self):
        """Write the packed status to the shared cache for ``store_time``."""
        # ``pack`` is not defined on the classes visible in this file;
        # presumably it is inherited from AbstractResult.
        payload = self.pack()
        key = self.cache_key()
        cache_backends.shared_cache.set(key, payload, self.store_time)

    def to_dict(self):
        """Return the status as a plain dict; ``variants`` is always empty."""
        as_dict = {
            'created': self.created,
            'expires_at': self.expire,
            'query_time': self.query_time,
            'status': self.status,
            'all_variants_count': self.all_variants_count,
            'qid': self.qid,
            'service': self.service,
            'variants': [],
        }
        return as_dict


def _set_statuses(statuses):
    """Write ``statuses`` to the shared cache, one ``set_many`` call per store time.

    Statuses are grouped by their ``store_time`` so that every batch can be
    written with a single lifetime.
    """
    by_store_time = group_by(statuses, attrgetter('store_time'))
    for ttl, group in by_store_time.iteritems():
        payload = {item.cache_key(): item.pack() for item in group}

        # Only the first few keys are logged to keep the message short.
        logger.info('store %d new keys: %s', len(payload), _abbr_no_more_than(5, payload.keys()))
        cache_backends.shared_cache.set_many(payload, ttl)


def set_partners_statuses(query, partners, status_code, custom_store_time=None):
    """Store the same status code for every partner of ``query``.

    :param query: query the statuses belong to
    :param partners: iterable of partners to create a status for
    :param status_code: status value assigned to every partner
    :param custom_store_time: forwarded to the store time provider
    """

    def _status_for(partner):
        # The lifetime is resolved per partner by the store time provider.
        lifetime = partner_store_time_provider.get_status_time(
            partner=partner,
            custom_store_time=custom_store_time
        )
        return Status(query, partner, [], status=status_code, store_time=lifetime)

    tracer = opentracing.global_tracer()
    with tracer.start_active_span(operation_name='Setting partners statuses',
                                  finish_on_close=True):
        _set_statuses([_status_for(partner) for partner in partners])


def _abbr_no_more_than(num, items):
    return '{}{}'.format(', '.join(items[:num]), ' [...]' if len(items) > num else '')


class MinPrice(object):
    """Stores per-partner minimum prices in the shared cache and aggregates them."""

    @classmethod
    def make_result_key(cls, query_key, partner_code):
        """Build the shared-cache key of a partner's minimum prices."""
        key_template = '{}{}/any/{}/min_price'
        return key_template.format(settings.TICKET_DAEMON_CACHEROOT, query_key, partner_code)

    @classmethod
    def store(cls, query, partner, store_time, variants, rates):
        """Compute the partner's minimum prices from ``variants`` and cache them.

        Nothing is written when no minimum prices could be built or when
        packing them fails; both cases are only logged.
        """
        min_prices = PartnerMinPrices.create_from_variants(variants, query.national_version, rates)
        cache_key = cls.make_result_key(query.qkey, partner.code)

        if not min_prices:
            logger.info('store_partner_min_price: no PartnerMinPrices %s', cache_key)
            return

        try:
            packed_min_prices = pack(min_prices.to_dict())
        except Exception:
            # Best effort: a packing failure must not break the caller.
            logger.exception('Packing PartnerMinPrices error')
            return

        logger.info('store_partner_min_price new: %r', cache_key)
        cache_backends.shared_cache.set(cache_key, packed_min_prices, store_time)

    @classmethod
    def aggregate_query_min_prices(cls, query, partners):
        """Load the cached minimum prices of ``partners`` and aggregate them for ``query``."""
        logger.info('aggregate_query_min_prices: %r', query)

        partner_keys = [cls.make_result_key(query.qkey, partner.code) for partner in partners]
        restored = list(cls.get_partners_min_prices(partner_keys))

        logger.info('min_prices_of_partners: %d %r', len(restored), query)

        PartnerMinPrices.aggregate(restored, query)

    @classmethod
    def get_partners_min_prices(cls, mp_keys):
        """Yield ``PartnerMinPrices`` restored from the shared cache.

        Empty cache values are skipped; entries that fail to restore are
        logged and skipped as well.
        """
        cached = cache_backends.shared_cache.get_many(mp_keys)
        for blob in cached.itervalues():
            if blob:
                try:
                    serialized = unpack(blob)
                    yield PartnerMinPrices.create_from_serialized(serialized)
                except Exception:
                    logger.exception('Restoring PartnerMinPrices from cache error')
