# coding=utf-8
import heapq
from logging import Logger, getLogger

from sqlalchemy import and_, or_, func

from travel.avia.price_index.db_models.result import Result
from travel.avia.price_index.lib.adjusted_date_window import AdjustedDateWindow
from travel.avia.price_index.lib.constants import NULL_DATE
from travel.avia.price_index.lib.db.storage import Storage, slave_storage
from travel.avia.price_index.lib.iterrecipes import group_by
from travel.avia.price_index.models.top_directions_by_date_window_request import TopDirectionsByDateWindowRequest


class TopDirectionsByDateWindowSearcher(object):
    """Find the cheapest stored prices per direction inside a window of travel dates.

    For every requested (origin, destination) pair the searcher looks at the
    ``Result`` rows whose dates fall inside the window built around the requested
    dates and keeps the ``results_per_direction`` cheapest date combinations.
    Only rows for a single adult without children or infants are considered.
    """

    def __init__(self, storage, logger):
        # type: (Storage, Logger) -> None
        """Remember the storage that hands out sessions and the logger to report to."""
        self._storage = storage
        self._logger = logger

    def search(self, national_version_id, request):
        # type: (int, TopDirectionsByDateWindowRequest) -> list
        """Run the search inside a fresh session and return the found prices.

        :param national_version_id: national version the prices must belong to.
        :param request: carries ``forward_date``, ``backward_date``, ``window_size``,
            ``directions`` and ``results_per_direction``.
        :return: list of dicts with the keys ``min_price``, ``from_id``, ``to_id``,
            ``forward_date`` and ``backward_date``.
        :raises Exception: any error raised while searching is logged and re-raised.
        """
        session = self._storage.get_session()

        self._logger.info('Batch search %r', request)

        try:
            result = self._search(session, national_version_id, request)
            # The search only reads, so there is nothing to commit: end the transaction.
            session.rollback()
            return result
        except Exception as e:
            self._logger.exception('An error occurred while searching: %r %r', request, e)
            session.rollback()
            raise
        finally:
            session.close()

    def _search(self, session, national_version_id, request):
        """Query the minimal prices and keep the cheapest rows of every direction.

        Returns an empty list without touching the database when the request
        contains no directions.
        """
        # An `or_()` without clauses renders as blank SQL, which would silently drop
        # the direction filter and select every direction in the date window.
        # No directions were asked for, so there is nothing to return.
        if not request.directions:
            return []

        query = (
            session.query(
                func.min(Result.base_value), Result.from_id, Result.to_id, Result.forward_date, Result.backward_date
            )
            .filter(
                and_(
                    self._date_part(request.forward_date, request.backward_date, request.window_size),
                    self._directions_part(request.directions),
                    Result.national_version_id == national_version_id,
                    Result.adults_count == 1,
                    Result.children_count == 0,
                    Result.infants_count == 0,
                )
            )
            .group_by(Result.from_id, Result.to_id, Result.forward_date, Result.backward_date)
        )
        # Row layout: (min price, from_id, to_id, forward_date, backward_date).
        result = session.execute(query).fetchall()
        return [
            {
                'min_price': min_row[0],
                'from_id': min_row[1],
                'to_id': min_row[2],
                'forward_date': min_row[3],
                'backward_date': min_row[4],
            }
            # `group_by` is consumed as (key, rows) pairs, one per (from_id, to_id) direction.
            for _, rows in group_by(result, key=lambda item: (item[1], item[2]))
            # Within a direction keep only the cheapest date combinations.
            for min_row in heapq.nsmallest(request.results_per_direction, rows, key=lambda item: item[0])
        ]

    def _directions_part(self, directions):
        """Build the condition matching any of the given (orig_id, dest_id) pairs.

        ``directions`` is expected to be non-empty; ``_search`` returns early otherwise.
        """
        return or_(*[and_(Result.from_id == d.orig_id, Result.to_id == d.dest_id) for d in directions])

    def _date_part(self, forward_date, backward_date, window_size):
        """Build the condition restricting the dates to the requested window.

        With a ``backward_date`` both dates must lie inside their own windows and the
        forward date must precede the backward one. Without it only rows whose
        backward date equals ``NULL_DATE`` are matched.
        Rows without a price (``base_value`` is NULL) are excluded in both cases.
        """
        # The forward window is needed by both branches.
        adjusted_forward_date = AdjustedDateWindow(forward_date, window_size)

        if backward_date:
            adjusted_backward_date = AdjustedDateWindow(backward_date, window_size)
            return and_(
                Result.base_value.isnot(None),
                Result.backward_date >= adjusted_backward_date.left_boundary,
                Result.backward_date <= adjusted_backward_date.right_boundary,
                Result.forward_date >= adjusted_forward_date.left_boundary,
                Result.forward_date <= adjusted_forward_date.right_boundary,
                Result.forward_date < Result.backward_date,
            )

        return and_(
            Result.base_value.isnot(None),
            Result.backward_date == NULL_DATE,
            Result.forward_date >= adjusted_forward_date.left_boundary,
            Result.forward_date <= adjusted_forward_date.right_boundary,
        )


# Module-level searcher instance wired to `slave_storage`; it logs under this module's name.
top_directions_by_date_windows_searcher = TopDirectionsByDateWindowSearcher(
    logger=getLogger(__name__),
    storage=slave_storage,
)
