# -*- coding: utf-8 -*-


from copy import copy
from datetime import datetime, timedelta
import logging
from typing import Dict, Tuple, Optional, List

import sqlalchemy
from sqlalchemy import union_all
from sqlalchemy.orm import Session

from travel.avia.price_index.db_models.result import Result
from travel.avia.price_index.lib.constants import NULL_DATE
from travel.avia.price_index.lib.date_range_iterator import date_range_iterator, DateRangeIterator
from travel.avia.price_index.models.query import Query
from travel.avia.price_index.models.search_form import SearchForm


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class HistoryFetcher(object):
    """Reads stored ``Result`` rows for a whole batch of queries in one statement."""

    @staticmethod
    def _select_for(session, position, query):
        """Build the SELECT matching ``query``, tagged with its ``position``.

        The literal ``idx`` column lets the caller map a returned row back to
        the query it belongs to after all SELECTs are merged.
        """
        return session.query(
            sqlalchemy.literal(position).label('idx'),
            Result.base_value,
            Result.updated_at,
        ).filter(
            Result.national_version_id == query.national_version_id,
            Result.from_id == query.from_id,
            Result.to_id == query.to_id,
            Result.adults_count == query.adults_count,
            Result.children_count == query.children_count,
            Result.infants_count == query.infants_count,
            Result.forward_date == query.forward_date,
            # A missing backward date is compared against the NULL_DATE constant.
            Result.backward_date == (query.backward_date or NULL_DATE),
        )

    def butch_fetch(self, session, queries):
        # type: (Session, Tuple[Query, ...]) -> Dict[Query, Tuple[datetime, bool]]
        """Fetch ``(updated_at, has_value)`` for every query that has a stored row.

        One SELECT is built per query and all of them are combined with
        UNION ALL, so the batch is executed as a single statement.

        :param session: SQLAlchemy session used to build and execute the statement.
        :param queries: queries to look up; must be indexable by position.
        :return: mapping from query to ``(updated_at, base_value is not None)``.
            Queries without a matching row are absent from the mapping; when
            several rows map to the same query, the last returned row wins.
        """
        if not queries:
            # Nothing to look up: skip the database entirely.
            return {}

        statement = union_all(
            *(self._select_for(session, position, query) for position, query in enumerate(queries))
        )
        rows = session.execute(statement).fetchall()

        found = {}
        for position, base_value, updated_at in rows:
            found[queries[position]] = (updated_at, base_value is not None)
        return found


class History(object):
    """Snapshot of what is stored for a query.

    Holds the time the stored record was last updated, whether it has
    variants, and the query the record was actually looked up with.
    Instances compare and hash by all three fields.
    """

    def __init__(self, updated_at, has_variants, requested_query):
        # type: (datetime, bool, Query) -> None
        self.requested_query = requested_query
        self.has_variants = has_variants
        self.updated_at = updated_at

    def __repr__(self):
        template = "<History updated_at={} has_variants={} requested_query={}>"
        return template.format(self.updated_at, self.has_variants, self.requested_query)

    def __eq__(self, o):
        # Anything that is not a History is never equal; otherwise compare field by field.
        return (
            isinstance(o, History)
            and self.updated_at == o.updated_at
            and self.has_variants == o.has_variants
            and self.requested_query == o.requested_query
        )

    def __hash__(self):
        # Sum of the three field hashes, consistent with __eq__.
        return sum(
            (
                hash(self.updated_at),
                hash(self.has_variants),
                self.requested_query.__hash__(),
            )
        )

    def expired(self) -> bool:
        """Tell whether the record is at least six hours old.

        The age is measured against ``datetime.now()``. A record whose
        ``updated_at`` lies in the future is logged as a warning and
        reported as not expired.
        """
        current_time = datetime.now()
        is_from_future = self.updated_at > current_time
        if not is_from_future:
            age = current_time - self.updated_at
            return age >= timedelta(hours=6)

        logger.warning(f'updated_at > current_time: history: {repr(self)}, current_time: {current_time}')
        return False


def _make_multi_query(query, forward_date, backward_date):
    q = copy(query)
    q.forward_date = forward_date
    q.backward_date = backward_date

    return q


def _make_single_query(query, forward_date, backward_date):
    """Return a copy of ``query`` for the given dates with one adult and no children or infants."""
    single = _make_multi_query(query, forward_date, backward_date)
    # Normalise the passenger counts: one adult, nobody else.
    single.adults_count, single.children_count, single.infants_count = 1, 0, 0
    return single


class AbstractQueryFinder(object):
    """Base interface for objects that look up history for the queries of a search form."""

    def find_by_date_range(self, session, search_form):
        # type: (Session, SearchForm) -> Dict[Query, Optional[History]]
        """Map every query derived from ``search_form`` to its history (or ``None``).

        Subclasses must override this; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError()


class QueryFinderForSingle(AbstractQueryFinder):
    """Looks up history using queries normalised to one adult, no children, no infants."""

    def __init__(self, history_fetcher, date_range_iterator):
        # type: (HistoryFetcher, DateRangeIterator) -> None
        self._date_range_iterator = date_range_iterator
        self._history_fetcher = history_fetcher

    def find_by_date_range(self, session, search_form):
        # type: (Session, SearchForm) -> Dict[Query, Optional[History]]
        """Look up history for every date pair produced from ``search_form``.

        Each ``(forward_date, backward_date)`` pair yielded by the date range
        iterator is turned into a single-adult query, and the resulting list
        is passed to :meth:`find_by_queries`.
        """
        base_query = search_form.query
        date_pairs = self._date_range_iterator.iterate(base_query, search_form.date_range)
        queries = [
            _make_single_query(base_query, forward_date, backward_date)
            for forward_date, backward_date in date_pairs
        ]

        return self.find_by_queries(session, queries)

    def find_by_queries(self, session, queries):
        # type: (Session, List[Query]) -> Dict[Query, Optional[History]]
        """Fetch history for ``queries`` in one batch.

        :return: a mapping with one entry per given query: a ``History`` built
            from the fetched data, or ``None`` when nothing was fetched for it.
        """
        fetched = self._history_fetcher.butch_fetch(session=session, queries=tuple(queries))

        def to_history(query):
            found = fetched.get(query)
            if not found:
                return None
            return History(updated_at=found[0], has_variants=found[1], requested_query=query)

        return {query: to_history(query) for query in queries}


class QueryFinderForMulti(AbstractQueryFinder):
    """Looks up history for queries that keep their own passenger counts.

    For every date pair both the exact query and its single-adult
    approximation are fetched, and the record that was updated most recently
    is the one reported for the exact query.
    """

    def __init__(self, history_fetcher, date_range_iterator):
        # type: (HistoryFetcher, DateRangeIterator) -> None
        self._history_fetcher = history_fetcher
        self._date_range_iterator = date_range_iterator

    def find_by_date_range(self, session, search_form):
        # type: (Session, SearchForm) -> Dict[Query, Optional[History]]
        """Look up history for every date pair produced from ``search_form``.

        :param session: SQLAlchemy session handed to the history fetcher.
        :param search_form: provides the base ``query`` and the ``date_range``.
        :return: mapping keyed by the exact query of each date pair. The value
            is a ``History`` whose ``requested_query`` tells which of the two
            candidate queries supplied the data, or ``None`` when neither has
            stored history.
        """
        base_query = search_form.query
        # Materialise the date pairs once so that the exact and approximate
        # lists are guaranteed to line up element by element.
        date_pairs = list(self._date_range_iterator.iterate(base_query, search_form.date_range))

        exact_queries = [
            _make_multi_query(base_query, forward_date, backward_date)
            for forward_date, backward_date in date_pairs
        ]
        approximate_queries = [
            _make_single_query(base_query, forward_date, backward_date)
            for forward_date, backward_date in date_pairs
        ]

        # A single batch fetch covers both candidate sets.
        history_data_by_queries = self._history_fetcher.butch_fetch(
            session=session, queries=tuple(exact_queries + approximate_queries)
        )

        result = {}
        for exact_query, approximate_query in zip(exact_queries, approximate_queries):
            requested_query = self._take_more_fresh_query(
                exact_query,
                approximate_query,
                exac_data=history_data_by_queries.get(exact_query),
                approximate_data=history_data_by_queries.get(approximate_query),
            )

            # Compare against None explicitly: "no history" must not depend on
            # the truthiness of a Query object.
            if requested_query is None:
                result[exact_query] = None
                continue

            updated_at, has_variants = history_data_by_queries[requested_query]
            result[exact_query] = History(
                updated_at=updated_at, has_variants=has_variants, requested_query=requested_query
            )

        return result

    def find_by_queries(self, session, queries):
        """Not implemented for multi queries yet; always raises ``NotImplementedError``."""
        raise NotImplementedError('todo')

    def _take_more_fresh_query(self, exac_query, approximate_query, exac_data, approximate_data):
        """Choose which of the two queries should supply the history.

        ``exac_data`` and ``approximate_data`` are the fetched
        ``(updated_at, has_variants)`` tuples, or ``None`` when nothing is
        stored. Returns ``None`` if both are missing, the only available
        query if one is missing, and otherwise the approximate query only
        when it is strictly newer than the exact one.
        """
        if exac_data is None:
            return None if approximate_data is None else approximate_query
        if approximate_data is None:
            return exac_query

        exact_updated_at, _ = exac_data
        approximate_updated_at, _ = approximate_data
        if approximate_updated_at > exact_updated_at:
            return approximate_query
        return exac_query


class QueryFinder(AbstractQueryFinder):
    """Facade that forwards each lookup to the finder for single or for multi queries."""

    def __init__(self, query_finder_for_single, query_finder_for_multi):
        # type: (QueryFinderForSingle, QueryFinderForMulti) -> None
        self._query_finder_for_multi = query_finder_for_multi
        self._query_finder_for_single = query_finder_for_single

    def _pick_finder(self, is_single):
        """Return the single finder when ``is_single`` is truthy, otherwise the multi finder."""
        if is_single:
            return self._query_finder_for_single
        return self._query_finder_for_multi

    def find_by_date_range(self, session, search_form):
        # type: (Session, SearchForm) -> Dict[Query, Optional[History]]
        """Delegate to the finder selected by ``search_form.query.is_single()``."""
        chosen = self._pick_finder(search_form.query.is_single())
        return chosen.find_by_date_range(session=session, search_form=search_form)

    def find_by_queries(self, session, queries, is_single):
        # type: (Session, List[Query], bool) -> Dict[Query, Optional[History]]
        """Delegate to the finder selected by the ``is_single`` flag."""
        chosen = self._pick_finder(is_single)
        return chosen.find_by_queries(session=session, queries=queries)


# Module-level wiring: a single HistoryFetcher instance is shared by both
# finders, and both use the date_range_iterator imported by this module.
_history_fetcher = HistoryFetcher()
_query_finder_for_single = QueryFinderForSingle(
    history_fetcher=_history_fetcher, date_range_iterator=date_range_iterator
)
_query_finder_for_multi = QueryFinderForMulti(history_fetcher=_history_fetcher, date_range_iterator=date_range_iterator)

# Ready-to-use finder that dispatches between the single and multi finders above.
search_query_finder = QueryFinder(
    query_finder_for_single=_query_finder_for_single, query_finder_for_multi=_query_finder_for_multi
)
