# -*- coding: utf-8 -*-
from copy import deepcopy
from datetime import datetime, date, timedelta
from typing import Optional, Dict, Any, List, Iterable
from itertools import product
from operator import itemgetter
from logging import getLogger

import dateutil.parser
from django.conf import settings

from travel.avia.library.python.common.lib.timer import Timeline, TskvTimeline
from travel.avia.ticket_daemon_processing.pretty_fares.errors_handling.safe_run import safe_run
from travel.avia.ticket_daemon_processing.pretty_fares.internal_logic.helpers import (
    get_possible_flight_points,
    PointKeyType, get_settlement_point_key
)
from travel.avia.ticket_daemon_processing.pretty_fares.internal_logic.direction_key import DirectionKey
from travel.avia.ticket_daemon_processing.pretty_fares.internal_logic.point_key import PointKey
from travel.avia.ticket_daemon_processing.pretty_fares.saas import retrying_json_saas as json_saas
from travel.avia.ticket_daemon_processing.pretty_fares.kazoo_lock import lock

log = getLogger(__name__)


def _merge_old_and_new_infos(old_info, new_info, departure_date, return_date):
    # type: (List[Dict[str, Any]], List[Dict[str, Any]], date, Optional[date]) -> List[Dict[str, Any]]

    new_bunch = [
        price
        for price in old_info
        if price['forward'] != departure_date
        or price['backward'] != return_date
    ] + new_info

    new_bunch.sort(key=itemgetter('expires_at'), reverse=True)
    new_bunch.sort(key=itemgetter('price'))

    result = []
    for info in new_bunch:
        if result:
            if info['price'] == result[-1]['price']:
                continue
            if info['expires_at'] <= result[-1]['expires_at']:
                continue
        result.append(info)

    return result


def _deserialize_date(raw_date):
    # type: (Optional[str]) -> Optional[date]

    return dateutil.parser.parse(raw_date).date() if raw_date else None


def _deserialize_direction_prices(direction_prices):
    # type: (Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]

    return [
        {
            'price': date_price['price'],
            'forward': _deserialize_date(date_price['forward']),
            'backward': _deserialize_date(date_price['backward']),
            'expires_at': dateutil.parser.parse(date_price['expires_at'])
        }
        for date_price in direction_prices
    ]


# New logic for updating direction cheapest dates (settlement- and airport-level directions)
class DirectionCheapestUpdater(object):
    """Keep the cheapest known prices per direction, including airport-level directions.

    One SaaS document, keyed by the settlement-to-settlement direction, stores
    ``{'cheapest': {'<from>_<to>': [price info, ...]}}`` for every
    settlement/airport pair that a search may update. Each price info is a
    dict with ``price``, ``forward``, ``backward`` and ``expires_at`` keys.
    """

    _TIMELINE_PREFIX = 'DirectionCheapestUpdaterWithAirports. '

    def _parse_date(self, raw_date):
        # type: (str) -> Optional[date]
        """Parse a ``YYYY-MM-DD`` date taken from a qkey; the literal ``'None'`` means no date."""
        if raw_date == 'None':
            return None
        return datetime.strptime(raw_date, '%Y-%m-%d').date()

    def _concatenate(self, from_point_key, to_point_key):
        # type: (PointKey, PointKey) -> str
        """Build the per-direction key used inside the cached document."""
        return '{}_{}'.format(from_point_key.to_string(), to_point_key.to_string())

    def _add_variant_between_points(self, update_infos, from_point, to_point, variant):
        # type: (Dict[str, List[Dict[str, Any]]], PointKey, PointKey, Dict[str, Any]) -> None
        """Append ``variant`` to the bucket of the given direction.

        Directions without a pre-created bucket may not be updated by the
        current search and are silently skipped.
        """
        key = self._concatenate(from_point, to_point)
        if key in update_infos:
            update_infos[key].append(variant)

    def _prepare_update_infos(self, from_point_key, to_point_key, variants, departure_date, return_date, timeline):
        # type: (PointKey, PointKey, List[Dict[str, Any]], date, Optional[date], Timeline) -> Dict[str, List[Dict[str, Any]]]
        """Build new price infos for every direction the search results may update.

        :param from_point_key: departure point of the search (settlement or station)
        :param to_point_key: arrival point of the search (settlement or station)
        :param variants: search results; each is read for ``forward_segments``,
            ``backward_segments``, ``national_tariff_price`` and ``expires_at``
        :param departure_date: forward date of the search
        :param return_date: backward date of the search, ``None`` for one-way
        :param timeline: unused; kept for interface compatibility
        :return: mapping ``direction key -> [price info]``. Every direction the
            search is allowed to update is present, possibly with an empty list,
            so that stale prices for these dates get replaced on merge.
        """
        from_settlement_point_key = get_settlement_point_key(from_point_key)
        to_settlement_point_key = get_settlement_point_key(to_point_key)

        def _may_change_direction(new_from_point_key, new_to_point_key):
            # type: (Optional[PointKey], Optional[PointKey]) -> bool
            # A search from/to a station never updates the settlement-level
            # direction on that side; missing points cannot be updated at all.
            if not new_from_point_key or not new_to_point_key:
                return False
            if new_from_point_key.is_settlement() and from_point_key.is_station():
                return False
            if new_to_point_key.is_settlement() and to_point_key.is_station():
                return False
            return True

        def _is_round_trip(variant):
            return len(variant['backward_segments']) > 0

        def _starting_point_is_the_same(variant):
            # One-way trips qualify trivially; round trips must come back to the departure airport.
            return (
                not _is_round_trip(variant)
                or variant['forward_segments'][0]['depRaspId'] == variant['backward_segments'][-1]['arrRaspId']
            )

        def _final_point_is_the_same(variant):
            # One-way trips qualify trivially; round trips must leave from the arrival airport.
            return (
                not _is_round_trip(variant)
                or variant['forward_segments'][-1]['arrRaspId'] == variant['backward_segments'][0]['depRaspId']
            )

        # Pre-create a bucket for every direction this search may update. The
        # loop variables deliberately differ from the parameter names:
        # _may_change_direction compares candidates against the original
        # search points, not against the loop variables.
        grouped_variants = {
            self._concatenate(candidate_from, candidate_to): []
            for candidate_from, candidate_to in product(
                get_possible_flight_points(from_point_key),
                get_possible_flight_points(to_point_key),
            )
            if _may_change_direction(candidate_from, candidate_to)
        }

        # Loop-invariant: depends only on the search points.
        settlements_may_change = _may_change_direction(from_settlement_point_key, to_settlement_point_key)

        for variant in variants:
            if settlements_may_change:
                self._add_variant_between_points(
                    grouped_variants, from_settlement_point_key, to_settlement_point_key, variant)

            departure_airport_point_key = PointKey(PointKeyType.Station, variant['forward_segments'][0]['depRaspId'])
            arrival_airport_point_key = PointKey(PointKeyType.Station, variant['forward_segments'][-1]['arrRaspId'])
            same_start = _starting_point_is_the_same(variant)
            same_finish = _final_point_is_the_same(variant)

            if same_start and _may_change_direction(departure_airport_point_key, to_settlement_point_key):
                self._add_variant_between_points(
                    grouped_variants, departure_airport_point_key, to_settlement_point_key, variant)

            if same_finish and _may_change_direction(from_settlement_point_key, arrival_airport_point_key):
                self._add_variant_between_points(
                    grouped_variants, from_settlement_point_key, arrival_airport_point_key, variant)

            if same_start and same_finish:
                # Airport-to-airport buckets exist only if allowed; otherwise this is a no-op.
                self._add_variant_between_points(
                    grouped_variants, departure_airport_point_key, arrival_airport_point_key, variant)

        # Build a fresh dict instead of reassigning values while iterating.
        return {
            direction: [
                {
                    'price': variant['national_tariff_price'],
                    'forward': departure_date,
                    'backward': return_date,
                    'expires_at': variant['expires_at'],
                }
                for variant in direction_variants
            ]
            for direction, direction_variants in grouped_variants.iteritems()
        }

    def _current(self, direction_cache_key):
        # type: (str) -> Dict[str, List[Dict[str, Any]]]
        """Load and deserialize the prices currently stored under ``direction_cache_key``.

        Returns an empty dict when nothing is stored.
        """
        stored_document = json_saas.search([direction_cache_key], low_priority=True).get(direction_cache_key, {})
        stored_cheapest = stored_document.get('cheapest', {})
        return {
            direction: _deserialize_direction_prices(direction_prices)
            for direction, direction_prices in stored_cheapest.iteritems()
        }

    def _updated(self, old_infos, new_infos, departure_date, return_date, utcnow):
        # type: (Dict[str, List[Dict[str, Any]]], Dict[str, List[Dict[str, Any]]], date, Optional[date], datetime) -> Dict[str, List[Dict[str, Any]]]
        """Merge ``new_infos`` into ``old_infos`` and drop prices with ``expires_at <= utcnow``.

        Directions present only in ``old_infos`` are kept (minus expired
        prices); directions left without any price are omitted. Neither
        argument is mutated.
        """
        # A shallow copy suffices: per-direction lists are replaced or
        # re-filtered below, never modified in place.
        merged = dict(old_infos)
        for direction, direction_infos in new_infos.iteritems():
            merged[direction] = _merge_old_and_new_infos(
                old_infos.get(direction, []), direction_infos, departure_date, return_date)

        result = {}
        for direction, direction_infos in merged.iteritems():
            alive = [info for info in direction_infos if info['expires_at'] > utcnow]
            if alive:
                result[direction] = alive
        return result

    def _max_expires_at(self, infos):
        # type: (Dict[str, List[Dict[str, Any]]]) -> Optional[datetime]
        """Return the latest ``expires_at`` across all directions, or None when there are no prices."""
        expirations = [
            info['expires_at']
            for direction_infos in infos.itervalues()
            for info in direction_infos
        ]
        return max(expirations) if expirations else None

    def _prices_count(self, infos):
        # type: (Dict[str, List[Dict[str, Any]]]) -> int
        """Return the total number of price infos across all directions."""
        return sum(len(direction_infos) for direction_infos in infos.itervalues())

    def _create_direction_key_with_settlements(self, initial_direction_key):
        # type: (DirectionKey) -> DirectionKey
        """Return ``initial_direction_key`` with both endpoints lifted to their settlements.

        The key is returned unchanged when both endpoints already are settlements.
        """
        if initial_direction_key.from_point_key.is_settlement() and initial_direction_key.to_point_key.is_settlement():
            return initial_direction_key

        return DirectionKey(
            get_settlement_point_key(initial_direction_key.from_point_key),
            get_settlement_point_key(initial_direction_key.to_point_key),
            initial_direction_key.tariff,
            initial_direction_key.adults,
            initial_direction_key.children,
            initial_direction_key.infants,
            initial_direction_key.national_version
        )

    @safe_run(timeline=TskvTimeline(log))
    def update_direction_cheapest_with_airports(self, direction_key, variants, qkey, utcnow, timeline):
        # type: (DirectionKey, List[Dict[str, Any]], str, datetime, Timeline) -> None
        """Merge the prices from one search into the airport-aware direction cache.

        Nothing is written when the departure is more than 60 days after
        ``utcnow`` or the trip lasts more than 21 days. When no unexpired price
        remains after merging, the cached document is deleted.

        :param direction_key: direction of the search (settlement or station endpoints)
        :param variants: search results, all found for the dates encoded in ``qkey``
        :param qkey: search key, e.g. ``'c213_c99_2018-12-14_None_economy_1_0_0_ru.ru'``
        :param utcnow: current time; used for the horizon check and to drop expired prices
        :param timeline: receives diagnostic events
        """
        timeline.event(self._TIMELINE_PREFIX + 'Updating direction prices', direction_key=direction_key.to_string(), qkey=qkey)
        settlement_direction_key = self._create_direction_key_with_settlements(direction_key)
        direction_cache_key = '{}/{}'.format(
            settings.WIZARD_CACHE_PREFIX_WITH_AIRPORTS, settlement_direction_key.to_string())

        timeline.event('Direction cache key', direction_cache_key=direction_cache_key)

        # All variants in the list were found by a search for the same dates.
        _, _, raw_departure_date, raw_return_date, _ = qkey.split('_', 4)  # 'c213_c99_2018-12-14_None_economy_1_0_0_ru.ru'

        departure_date = self._parse_date(raw_departure_date)
        return_date = self._parse_date(raw_return_date)
        timeline.event(self._TIMELINE_PREFIX + 'Computed dates', departure_date=departure_date, return_date=return_date)

        if departure_date - utcnow.date() > timedelta(days=60):
            timeline.event(self._TIMELINE_PREFIX + 'Departure date is too far')
            return

        if return_date and return_date - departure_date > timedelta(days=21):
            timeline.event(self._TIMELINE_PREFIX + 'The trip is too long')
            return

        update_infos = self._prepare_update_infos(
            direction_key.from_point_key, direction_key.to_point_key, variants, departure_date, return_date, timeline)

        with lock(direction_cache_key, timeline=timeline):
            current = self._current(direction_cache_key)
            new_infos = self._updated(current, update_infos, departure_date, return_date, utcnow)

            timeline.event('Current saved count', saved_count=self._prices_count(current))
            timeline.event('New infos count', new_infos_count=self._prices_count(new_infos))

            if current == new_infos:
                timeline.event(self._TIMELINE_PREFIX + 'Double-checked locking: no need to update')
                return

            if not new_infos:
                timeline.event(self._TIMELINE_PREFIX + 'Nothing to save. Delete direction cheapest key', cache_key=direction_cache_key)
                json_saas.delete(direction_cache_key)
                return

            json_saas.index(
                direction_cache_key,
                doc={'cheapest': new_infos},
                expires_at=self._max_expires_at(new_infos),
            )

        timeline.event(self._TIMELINE_PREFIX + 'Computed and saved new infos', count=self._prices_count(new_infos))


# Shared module-level instance; DirectionCheapestUpdater keeps no per-instance state.
direction_cheapest_updater = DirectionCheapestUpdater()


# Old logic for storing direction cheapest dates
def update_direction_cheapest_old(direction_key, variants, qkey, utcnow, timeline):
    # type: (DirectionKey, List[Dict[str, Any]], str, datetime, Timeline) -> None
    """Legacy updater: merge one search's prices into a flat per-direction price list.

    The document is stored under ``settings.WIZARD_CACHE_PREFIX`` and keyed by
    ``direction_key`` itself (no settlement/airport fan-out). Searches departing
    more than 60 days after ``utcnow`` or lasting more than 21 days are ignored.

    NOTE(review): unlike DirectionCheapestUpdater._updated, expired prices are
    not filtered out here; confirm whether that difference is intended.
    """
    timeline.event('Updating direction prices', direction_key=direction_key.to_string(), qkey=qkey)
    cache_key = '{}/{}'.format(settings.WIZARD_CACHE_PREFIX, direction_key.to_string())

    def to_date(raw_value):
        # qkey encodes a missing return date as the literal string 'None'.
        if raw_value == 'None':
            return None
        return datetime.strptime(raw_value, '%Y-%m-%d').date()

    # All variants in the list were found by a search for the same dates,
    # e.g. qkey 'c213_c99_2018-12-14_None_economy_1_0_0_ru.ru'.
    _, _, raw_departure, raw_return, _ = qkey.split('_', 4)

    departure_date = to_date(raw_departure)
    return_date = to_date(raw_return)
    timeline.event('Computed dates', departure_date=departure_date, return_date=return_date)

    if departure_date - utcnow.date() > timedelta(days=60):
        timeline.event('Departure date is too far')
        return

    if return_date and return_date - departure_date > timedelta(days=21):
        timeline.event('The trip is too long')
        return

    fresh_prices = []
    for variant in variants:
        fresh_prices.append({
            'price': variant['national_tariff_price'],
            'forward': departure_date,
            'backward': return_date,
            'expires_at': variant['expires_at'],
        })

    with lock(cache_key, timeline=timeline):
        stored_document = json_saas.search([cache_key], low_priority=True).get(cache_key, {})
        stored_prices = _deserialize_direction_prices(stored_document.get('cheapest', []))
        merged_prices = _merge_old_and_new_infos(stored_prices, fresh_prices, departure_date, return_date)

        if stored_prices == merged_prices:
            timeline.event('Double-checked locking: no need to update')
            return

        if not merged_prices:
            timeline.event('Nothing to save. Delete direction cheapest key', cache_key=cache_key)
            json_saas.delete(cache_key)
            return

        latest_expiry = max(price_info['expires_at'] for price_info in merged_prices)
        json_saas.index(
            cache_key,
            doc={'cheapest': merged_prices},
            expires_at=latest_expiry,
        )

    timeline.event('Computed and saved new infos', count=len(merged_prices))
