# -*- coding: utf-8 -*-


from typing import Dict, Iterable, Tuple

import sqlalchemy
from sqlalchemy import and_, union_all
from sqlalchemy.orm import Session, Query as SaQuery

from travel.avia.price_index.db_models.result import Result
from travel.avia.price_index.lib.constants import NULL_DATE
from travel.avia.price_index.models.query import Query


class MinQuerySearcher(object):
    """Look up ``Result.base_value`` for many queries in a single round trip."""

    def batch_find(self, session, queries):
        # type: (Session, Iterable[Query]) -> Dict[Query, int]
        """Fetch ``Result.base_value`` for every query with one UNION ALL statement.

        :param session: SQLAlchemy session used to build and execute the statement.
        :param queries: queries to look up. Any iterable is accepted; it is
            materialised once, so generators and sets work as well as tuples.
            Duplicate queries are collapsed and looked up only once.
        :return: mapping from query to the matching ``Result.base_value``.
            Queries with no matching row are absent from the mapping.
        """
        # Materialise once (the code below needs both emptiness check and
        # positional indexing) and drop duplicates while keeping first-seen
        # order: the result is keyed by query anyway, so duplicates would only
        # add redundant UNION ALL branches.
        unique_queries = tuple(dict.fromkeys(queries))
        if not unique_queries:
            return {}

        statement = union_all(
            *[
                self._find(
                    session=session,
                    query=query,
                    index=index,
                )
                for index, query in enumerate(unique_queries)
            ]
        )

        # Each row is (idx, base_value); idx is the position in unique_queries.
        # NOTE(review): if several Result rows match one query, whichever row the
        # database returns last wins — confirm rows are unique per query key.
        return {
            unique_queries[index]: base_value
            for index, base_value in session.execute(statement).fetchall()
        }

    def _find(self, session, query, index):
        # type: (Session, Query, int) -> SaQuery
        """Build a query selecting ``(idx, base_value)`` for rows matching ``query``.

        :param session: SQLAlchemy session used to build the query.
        :param query: search parameters whose fields must all match the row.
        :param index: tag emitted as the ``idx`` column so rows coming out of the
            UNION ALL can be mapped back to the originating query.
        :return: an un-executed ORM query.
        """
        return session.query(sqlalchemy.literal(index).label('idx'), Result.base_value).filter(
            and_(
                Result.national_version_id == query.national_version_id,
                Result.adults_count == query.adults_count,
                Result.children_count == query.children_count,
                Result.infants_count == query.infants_count,
                Result.from_id == query.from_id,
                Result.to_id == query.to_id,
                Result.forward_date == query.forward_date,
                # A falsy backward_date is compared against the NULL_DATE
                # sentinel; presumably the column stores NULL_DATE instead of
                # SQL NULL when there is no backward date — confirm.
                Result.backward_date == (query.backward_date or NULL_DATE),
            )
        )


# Module-level shared instance. MinQuerySearcher keeps no per-instance state
# (it has no __init__ and sets no attributes), so one instance can be reused.
searcher = MinQuerySearcher()
