# -*- coding: utf-8 -*-
from __future__ import absolute_import

import sys
import heapq
import logging
from collections import defaultdict
from datetime import datetime, timedelta

from django.db.models import F, Q
from django.conf import settings

from travel.avia.library.python.avia_data.models import MinPrice, AviaDirectionNational
from travel.avia.library.python.common.models.geo import Settlement

from travel.avia.backend.main.lib.prices import DummyDirectionPrice, key_by_params, get_min_prices_by_keys, AviaPrice
from travel.avia.backend.repository.currency import currency_repository

log = logging.getLogger(__name__)


def _get_left_boarder(city_from, recipe, when):
    today = city_from.get_local_datetime(datetime.now()).date()

    if when:
        left_date_border = when
    elif recipe.date_start:
        left_date_border = recipe.date_start
    else:
        left_date_border = today + timedelta(days=1)

    if left_date_border < today:
        left_date_border = today

    return left_date_border


def _get_right_boarder(city_from, recipe, return_date, left_date_border):
    if return_date:
        right_date_border = return_date
    elif recipe.date_end:
        right_date_border = recipe.date_end
    else:
        right_date_border = (left_date_border + timedelta(days=30))

    return right_date_border


def _get_boarders(city_from, recipe, when, return_date):
    """Return the ``(left, right)`` date borders of the search window.

    The right border may default relative to the left one, so the left border
    is computed first.
    """
    left = _get_left_boarder(city_from, recipe, when)
    return left, _get_right_boarder(city_from, recipe, return_date, left)


def _get_popular_relative_objects(city_from, national_version):
    """Return destination ids for a 'popular' recipe.

    The settlement slot holds the set of ``arrival_settlement_id`` values of
    ``AviaDirectionNational`` rows departing from ``city_from`` for
    ``national_version``; the region and country slots are ``None``.
    """
    directions = AviaDirectionNational.objects.filter(
        departure_settlement_id=city_from.id,
        national_version=national_version,
    )
    arrival_ids = set(directions.values_list('arrival_settlement_id', flat=True))
    return arrival_ids, None, None


def _get_common_relative_objects(recipe, city_from):
    country_ids = set(recipe.countries.all().values_list('id', flat=True))
    region_ids = set(recipe.regions.all().values_list('id', flat=True))
    settlement_ids = set(recipe.settlements.all().values_list('id', flat=True))

    settlement_ids.discard(city_from.id)
    region_ids.discard(city_from.region_id)

    return settlement_ids, region_ids, country_ids


def _filter_geo(city_from, recipe, national_version):
    """Build the geographic part of the MinPrice filter.

    Destinations come from popular directions for 'popular' recipes, otherwise
    from the recipe's settlements/regions/countries.  Each non-empty id set is
    OR-ed in; the result always requires departure from ``city_from`` and
    excludes ``city_from`` as an arrival.  If every id set is empty, only the
    departure/arrival constraints on ``city_from`` remain.
    """
    if recipe.recipe_type == 'popular':
        relative = _get_popular_relative_objects(city_from, national_version)
    else:
        relative = _get_common_relative_objects(recipe, city_from)
    settlement_ids, region_ids, country_ids = relative

    lookups = (
        ('arrival_settlement_id__in', settlement_ids),
        ('arrival_settlement__region_id__in', region_ids),
        ('arrival_settlement__country_id__in', country_ids),
    )
    geo_q = Q()
    for lookup, ids in lookups:
        if ids:
            geo_q |= Q(**{lookup: ids})

    geo_q &= ~Q(arrival_settlement_id=city_from.id)
    geo_q &= Q(departure_settlement_id=city_from.id)
    return geo_q


def _filter_oneway():
    """Match MinPrice rows that have no return date (one-way fares)."""
    no_return_date = Q(date_backward=None)
    return no_return_date


def _filter_twoway(right_boarder, return_date):
    """Match round-trip MinPrice rows.

    With an explicit ``return_date`` the backward date must equal
    ``right_boarder``; otherwise the trip length must lie between
    TWOWAY_RANGE_MIN and TWOWAY_RANGE_MAX days (inclusive) after
    ``date_forward``.
    """
    TWOWAY_RANGE_MIN = 5
    TWOWAY_RANGE_MAX = 31

    if not return_date:
        earliest = F('date_forward') + timedelta(days=TWOWAY_RANGE_MIN)
        latest = F('date_forward') + timedelta(days=TWOWAY_RANGE_MAX)
        return Q(date_backward__gte=earliest) & Q(date_backward__lte=latest)
    return Q(date_backward=right_boarder)


def _filter_date_backward(when, return_date, left_border, right_boarder):
    """Accept either one-way fares or round trips matching ``_filter_twoway``.

    ``when`` and ``left_border`` are accepted for signature symmetry with
    ``_filter_date_forward`` but are unused.
    """
    one_way = _filter_oneway()
    round_trip = _filter_twoway(right_boarder, return_date)
    return one_way | round_trip


def _filter_date_forward(when, return_date, left_border, right_border):
    """Constrain the outbound date.

    If either ``when`` or ``return_date`` is given, the outbound date is pinned
    to ``left_border``; otherwise any date in ``[left_border, right_border]``
    is accepted.
    """
    pinned = bool(when or return_date)
    if not pinned:
        return Q(date_forward__range=(left_border, right_border))
    return Q(date_forward=left_border)


def _filter_weekdays(recipe):
    """Build a Q restricting MinPrice rows to the recipe's weekdays.

    ``recipe.week_days`` is a comma-separated string of day numbers
    (e.g. ``"1,3,5"``).  Each number is shifted down by one and matched
    against ``day_of_week``.  Blank tokens are ignored.  A missing or empty
    value yields an empty ``Q()``, meaning no restriction.

    NOTE(review): assumes 1-based input maps onto a 0-based ``day_of_week``
    column on MinPrice -- confirm against the model.
    """
    raw = recipe.week_days or ''
    tokens = (token.strip() for token in raw.split(','))
    week_days = set(int(token) - 1 for token in tokens if token)

    if not week_days:
        return Q()
    # The filter has to be returned (not merely constructed) for the
    # recipe's weekday restriction to take effect.
    return Q(day_of_week__in=sorted(week_days))


def _price_compare(d):
    price = d['price']

    if not d['date_backward']:
        price = price * 2

    return price, d['date_forward'], d['id']


def _prefetch_related_settlements(min_prices):
    """Attach Settlement objects to the price dicts in ``min_prices`` in place.

    ``min_prices`` is a sequence of ``(direct, indirect)`` pairs whose elements
    are price dicts or falsy.  All referenced departure/arrival settlements are
    loaded with one query and stored under ``departure_settlement`` and
    ``arrival_settlement``.  A settlement id missing from the DB raises
    ``KeyError``.
    """
    def present_prices(pairs):
        for pair in pairs:
            for price in (pair[0], pair[1]):
                if price:
                    yield price

    wanted_ids = set()
    for price in present_prices(min_prices):
        wanted_ids.update((price['arrival_settlement_id'], price['departure_settlement_id']))

    by_id = dict((s.id, s) for s in Settlement.objects.filter(id__in=wanted_ids))

    for price in present_prices(min_prices):
        price['departure_settlement'] = by_id[price['departure_settlement_id']]
        price['arrival_settlement'] = by_id[price['arrival_settlement_id']]


def _find_cheapest_direction_min_price(prices, limit):
    """Pick the cheapest price row per destination and pair it with its counterpart.

    Rows are ranked with ``_price_compare`` (one-way fares count double).  The
    ``limit`` cheapest destinations are returned in ascending order, each as a
    ``(direct_price, transfer_price)`` tuple.  One element is the winning row.
    The other is the cheapest row with the same arrival, forward and backward
    dates but the opposite ``direct_flight`` flag, or ``None`` if no such row
    exists.

    ``prices`` is an iterable of dicts with at least ``id``, ``price``,
    ``arrival_settlement_id``, ``date_forward``, ``date_backward`` and
    ``direct_flight`` keys.
    """
    def min_price_key(p, is_direct):
        return (p['arrival_settlement_id'], p['date_forward'], p['date_backward'], is_direct)

    def keep_cheapest(best, key, order, p):
        # Keeps the existing entry on ties, like ``min``.
        current = best.get(key)
        if current is None or order < current[0]:
            best[key] = (order, p)

    best_by_arrival_id = {}
    # Cheapest row per exact route/dates/directness.  Keeping the cheapest
    # (rather than the last row seen) makes the counterpart choice
    # deterministic for an unordered queryset.
    best_by_key = {}
    for p in prices:
        order = _price_compare(p)
        keep_cheapest(best_by_arrival_id, p['arrival_settlement_id'], order, p)
        keep_cheapest(best_by_key, min_price_key(p, bool(p['direct_flight'])), order, p)

    winners = heapq.nsmallest(limit, best_by_arrival_id.values(), key=lambda item: item[0])

    result = []
    for _, p in winners:
        is_direct = bool(p['direct_flight'])
        counterpart = best_by_key.get(min_price_key(p, not is_direct))
        counterpart = counterpart[1] if counterpart is not None else None
        result.append((p, counterpart) if is_direct else (counterpart, p))

    return result


def _get_memcache_keys(prices, is_direct, national_version):
    """Return one min-price cache key per ``(direct, indirect)`` pair.

    Route and dates are read from whichever element of the pair is present,
    preferring the direct one.  Each element must already carry
    ``departure_settlement``/``arrival_settlement`` objects, which
    ``_prefetch_related_settlements`` attaches.  ``is_direct`` is passed to
    ``key_by_params`` to select the cache variant.  The passengers key is
    fixed to ``'1_0_0'``.
    """
    keys = []
    for direct_price, indirect_price in prices:
        p = direct_price or indirect_price
        keys.append(key_by_params(
            city_from_key=p['departure_settlement'].point_key,
            city_to_key=p['arrival_settlement'].point_key,
            national_version=national_version,
            passengers_key='1_0_0',
            date_forward=p['date_forward'],
            date_backward=p['date_backward'],
            is_direct=is_direct
        ))

    return keys


def _cached_price_to_avia_price(cached_price):
    """Convert a raw cache entry into a rough ``AviaPrice``.

    A falsy entry (nothing cached) is returned unchanged.
    """
    if not cached_price:
        return cached_price
    currency = cached_price['currency']
    iso_currency = currency_repository.get_by_code(currency).iso_code
    return AviaPrice(value=cached_price['price'],
                     currency=currency,
                     iso_currency=iso_currency,
                     roughly=True)


def _get_memcached_prices(prices, national_version):
    """Look up cached direct and indirect prices for each ``(direct, indirect)`` pair.

    Returns a list parallel to ``prices`` of ``(direct, indirect)`` tuples.
    Each element is an ``AviaPrice`` built from the cache entry, or the falsy
    cache result when nothing was cached.
    """
    if not prices:
        # Nothing to look up; skip the cache round trip.
        return []

    direct_keys = _get_memcache_keys(prices, True, national_version)
    indirect_keys = _get_memcache_keys(prices, False, national_version)
    cached_prices = get_min_prices_by_keys(direct_keys + indirect_keys)

    return [
        (_cached_price_to_avia_price(cached_prices.get(direct_key)),
         _cached_price_to_avia_price(cached_prices.get(indirect_key)))
        for direct_key, indirect_key in zip(direct_keys, indirect_keys)
    ]


def _build_direction_prices(min_prices, cached_prices, national_version):
    """Wrap each ``(direct, indirect)`` DB price pair into a ``DummyDirectionPrice``.

    ``min_prices`` and ``cached_prices`` are parallel lists of pairs.  A
    cached price, when present, takes precedence over the DB price.  DB prices
    are converted to rough ``AviaPrice`` objects in the national currency of
    ``national_version``.  Route and dates come from the direct price when
    present, otherwise from the indirect one.

    NOTE(review): the DB row's own ``currency`` value is not consulted here;
    confirm rows are always stored in the national currency.
    """
    currency = settings.AVIA_NATIONAL_CURRENCIES[national_version]
    iso_currency = currency_repository.get_by_code(currency).iso_code

    def to_avia_price(db_price):
        # Falsy (missing) prices pass through unchanged.
        if not db_price:
            return db_price
        return AviaPrice(value=db_price['price'],
                         currency=currency,
                         iso_currency=iso_currency,
                         roughly=True)

    result = []
    for (direct_price, indirect_price), (cached_direct, cached_indirect) in zip(min_prices, cached_prices):
        min_price = direct_price or indirect_price
        result.append(DummyDirectionPrice(
            city_from=min_price['departure_settlement'],
            city_to=min_price['arrival_settlement'],
            direct_price=cached_direct or to_avia_price(direct_price),
            indirect_price=cached_indirect or to_avia_price(indirect_price),
            national_version=national_version,
            date_forward=min_price['date_forward'],
            date_backward=min_price['date_backward'],
            allow_roughly=True
        ))

    return result


def find(city_from, recipe, when, return_date, national_version, limit):
    """Return up to ``limit`` cheapest destination prices for a recipe.

    The MinPrice filter is built from:

    - the recipe's weekdays,
    - geography relative to ``city_from``,
    - forward and backward date constraints,
    - ``national_version``,
    - ``passengers='1_0_0'``.

    The cheapest row is kept per arrival settlement.  Settlement objects are
    then attached, and cached prices override DB prices where available.
    Returns a list of ``DummyDirectionPrice`` objects.

    Elapsed time for each stage is logged at DEBUG level.
    """
    start = datetime.now()

    def elapsed():
        return (datetime.now() - start).total_seconds()

    log.debug('start')
    left, right = _get_boarders(city_from, recipe, when, return_date)

    query = (
        _filter_weekdays(recipe) &
        _filter_geo(city_from, recipe, national_version) &
        _filter_date_forward(when, return_date, left, right) &
        _filter_date_backward(when, return_date, left, right) &
        Q(national_version=national_version) &
        Q(passengers='1_0_0')
    )
    log.debug('build filter %r', elapsed())

    # A .values() queryset is lazy; evaluate it here so the timing below
    # covers the actual DB round trip instead of leaking into the next stage.
    prices = list(MinPrice.objects.filter(query).values(
        'id',
        'arrival_settlement_id',
        'departure_settlement_id',
        'price', 'currency',
        'date_forward',
        'date_backward',
        'direct_flight'
    ))
    log.debug('fetch min prices %r', elapsed())

    min_prices = _find_cheapest_direction_min_price(prices, limit)
    log.debug('find best prices %r', elapsed())

    _prefetch_related_settlements(min_prices)
    log.debug('prefetch_related_settlemets %r', elapsed())

    cached_prices = _get_memcached_prices(min_prices, national_version)
    log.debug('fetch cached prices %r', elapsed())

    directions = _build_direction_prices(min_prices, cached_prices, national_version)
    log.debug('build directions prices %r', elapsed())
    log.debug('end')

    return directions
