# coding=utf-8
from typing import List

import sqlalchemy
from sqlalchemy import and_, union_all

from travel.avia.price_index.db_models.result import Result
from travel.avia.price_index.lib.db.storage import Storage, slave_storage


class HistoryFetcher(object):
    """Look up ``Result.updated_at`` for a batch of search descriptions.

    All queries are issued through a session obtained from the storage passed
    to the constructor.
    """

    def __init__(self, storage):
        # type: (Storage) -> None
        self._storage = storage

    def fetch(self, items):
        # type: (List[dict]) -> dict
        """Return a mapping of ``item['idx']`` to the matching ``updated_at``.

        A session is taken from the storage for the duration of the call and is
        always closed afterwards. If the lookup fails, the session is rolled
        back and the original exception is re-raised.

        :param items: dicts providing the keys ``idx``, ``national_version_id``,
            ``from_id``, ``to_id``, ``forward_date`` and ``backward_date``.
        :return: ``{idx: updated_at}``; items without a matching row are absent.
            An empty ``items`` list yields ``{}``.
        """
        session = self._storage.get_session()
        try:
            return self._fetch_history(session, items)
        except Exception:
            session.rollback()
            raise
        finally:
            # Release the session on success and on failure alike.
            session.close()

    def _fetch_history(self, session, items):
        """Run one UNION ALL query covering every item and collect the rows."""
        if not items:
            # Nothing to look up: skip the database round trip entirely.
            return {}

        # One SELECT per item, combined with UNION ALL so all items are resolved
        # in a single round trip. ``idx`` is echoed back as a literal column so
        # each row can be attributed to the item that produced it.
        query = union_all(*[self._item_query(session, item) for item in items])

        # Rows are (idx, updated_at) pairs. If several rows share an idx, the
        # one fetched last wins; the query does not specify a row order.
        return dict(session.execute(query).fetchall())

    @staticmethod
    def _item_query(session, item):
        """Build the ``SELECT idx, updated_at`` query for a single item."""
        return session.query(
            sqlalchemy.literal(item['idx']).label('idx'),
            Result.updated_at,
        ).filter(
            and_(
                Result.national_version_id == item['national_version_id'],
                Result.from_id == item['from_id'],
                Result.to_id == item['to_id'],
                Result.forward_date == item['forward_date'],
                Result.backward_date == item['backward_date'],
                # Only results for exactly 1 adult, 0 children, 0 infants.
                # NOTE(review): presumably the price index is built for
                # single-adult searches only — confirm.
                Result.adults_count == 1,
                Result.children_count == 0,
                Result.infants_count == 0,
            )
        )


# Shared module-level fetcher, bound to ``slave_storage``.
history_fetcher = HistoryFetcher(slave_storage)
