import ujson
from datetime import datetime
from logging import Logger, getLogger

from sqlalchemy.orm import Session
from typing import Dict, Optional
from marshmallow import ValidationError

from travel.avia.price_index.lib.base_currency_provider import BaseCurrencyProvider, base_currency_provider
from travel.avia.price_index.lib.passengers_multiplier import PassengersMultiplier, passengers_multiplier
from travel.avia.price_index.lib.price_precision_logger import price_precision_logger, PricePrecisionLogger
from travel.avia.price_index.lib.query_searcher import QuerySearcher, query_searcher

from travel.avia.price_index.lib.search_query_finder import QueryFinder, History, search_query_finder
from travel.avia.price_index.models.query import Query
from travel.avia.price_index.models.search_form import SearchForm
from travel.avia.price_index.schemas.search import SearchFormSchema
from travel.avia.price_index.views.helpers import BadRequest

# Values of the per-date ``status`` field that SearchView emits for format version "1.1.0".
(
    VARIANT_STATUS_NO_UNKNOWN,  # no search history exists for the date
    VARIANT_STATUS_NO_SEARCH_DATA,  # history has no variants, or no price was found without filters
    VARIANT_STATUS_NO_FILTERED_DATA,  # history has variants, but none is priced under the active filters
    VARIANT_STATUS_HAS_DATA,  # a price is available for the date
) = ('unknown', 'no-data', 'no-filter-data', 'has-data')


class SearchPreciseLogger(object):
    """Writes one price-precision record per query of a search.

    Logging is best-effort: any failure is reported through ``logger`` and
    swallowed, so a broken precision log never fails the search itself.
    """

    def __init__(self, price_precision_logger, logger):
        # type: (PricePrecisionLogger, Logger) -> None
        self._price_precision_logger = price_precision_logger
        self._logger = logger

    def log(self, history_by_query, base_value_by_query, search_form):
        # type: (Dict[Query, Optional[History]], Dict[Query, int], SearchForm) -> None
        """Log a precision record for every query in ``history_by_query``.

        :param history_by_query: the searched queries mapped to their history (or None).
        :param base_value_by_query: the queries that received a price, mapped to that price;
            membership here decides each record's ``has_price`` flag.
        :param search_form: the parsed search form; supplies the query source and filters.
        """
        try:
            self._unsafe_log(history_by_query, base_value_by_query, search_form)
        except Exception as e:
            # Top-level best-effort boundary: report the failure, never propagate it.
            self._logger.exception('Can not log a precision records %r', e)

    @staticmethod
    def _force_iso(date):
        # type: (Optional[datetime]) -> Optional[str]
        """Return ``date.isoformat()``, or None when ``date`` is falsy (e.g. None)."""
        if not date:
            return None

        return date.isoformat()

    def _unsafe_log(self, history_by_query, base_value_by_query, search_form):
        # type: (Dict[Query, Optional[History]], Dict[Query, int], SearchForm) -> None
        """Write the records; may raise if the log context cannot be started.

        A failure while writing an individual record is logged and skipped,
        so the remaining queries are still recorded.
        """
        log_context = self._price_precision_logger.start_log(query_source=search_form.query_source)
        log_context.fill_filter_data(filter_model=search_form.filters)

        for query in history_by_query:
            try:
                log_context.log(
                    from_id=query.from_id,
                    to_id=query.to_id,
                    forward_date=self._force_iso(query.forward_date),
                    backward_date=self._force_iso(query.backward_date),
                    adults_count=query.adults_count,
                    children_count=query.children_count,
                    infants_count=query.infants_count,
                    national_version_id=query.national_version_id,
                    has_price=query in base_value_by_query,
                )
            except Exception as e:
                # One bad record must not prevent logging the others.
                self._logger.exception('Can not log a precision record %r', e)


class SearchView(object):
    """Answers price-index searches over a range of dates.

    All collaborators are injected through the constructor; the module-level
    ``search_view`` instance wires the shared library instances.
    """

    # Response format version in which every date carries an explicit ``status``,
    # including dates without a price.
    _STATUS_FORMAT_VERSION = '1.1.0'

    def __init__(
        self,
        search_precise_logger,
        passengers_multiplier,
        query_searcher,
        search_query_finder,
        base_currency_provider,
        logger,
    ):
        # type: (SearchPreciseLogger, PassengersMultiplier, QuerySearcher, QueryFinder, BaseCurrencyProvider, Logger) -> None
        self._search_precise_logger = search_precise_logger
        self._passengers_multiplier = passengers_multiplier
        self._query_searcher = query_searcher
        self._query_finder = search_query_finder
        self._base_currency_provider = base_currency_provider

        self._logger = logger

    def parse_form(self, params, body, query_source=None):
        """Parse request parameters and a JSON body into a validated search form.

        :param params: request parameters, used for both the ``query`` and ``range`` parts.
        :param body: raw JSON request body; its optional ``filters`` and
            ``formatVersion`` keys are copied into the form.
        :param query_source: optional value stored as ``querySource``.
        :raises BadRequest: if the body is not a JSON object or the form fails validation.
        """
        return self._parse_form(params, body, query_source)

    def _parse_form(self, params, body, query_source):
        search_data = {'query': params, 'range': params}

        # The body comes straight from the client: malformed JSON or a non-object
        # payload is a client error, not a server failure.
        try:
            post_data = ujson.loads(body)
        except ValueError as e:
            raise BadRequest('Can not parse search form body') from e
        if not isinstance(post_data, dict):
            raise BadRequest('Can not parse search form body')

        for key in ('filters', 'formatVersion'):
            if key in post_data:
                search_data[key] = post_data[key]

        if query_source is not None:
            search_data['querySource'] = query_source

        try:
            result = SearchFormSchema().load(search_data)
        except ValidationError as e:
            raise BadRequest('Can not parse search form') from e

        return result

    def process(self, session, search_form):
        # type: (Session, SearchForm) -> Dict[str, dict]
        """Build the per-date price response for ``search_form``.

        :returns: mapping from each query's ISO forward date to its serialized
            variant; dates whose variant serializes to None are omitted.
        :raises BadRequest: if no currency is known for the form's national version.
        """
        national_version_id = search_form.query.national_version_id
        currency_code = self._base_currency_provider.get_code(national_version_id)
        if currency_code is None:
            raise BadRequest(
                'Can not find currency for [{}] national_version_id'.format(national_version_id)
            )

        history_by_query = self._query_finder.find_by_date_range(session=session, search_form=search_form)

        # Only queries whose history has variants can be priced. Several queries may
        # resolve to the same requested query, so keep every one of them; a 1:1 map
        # would silently drop the price of all but the last.
        queries_by_requested_query = {}
        for query, history in history_by_query.items():
            if history is not None and history.has_variants:
                queries_by_requested_query.setdefault(history.requested_query, []).append(query)

        raw_base_value_by_requested_query = self._query_searcher.batch_find(
            session=session, queries=tuple(queries_by_requested_query), filters=search_form.filters
        )

        base_value_by_query = {}
        for requested_query, raw_base_value in raw_base_value_by_requested_query.items():
            for query in queries_by_requested_query[requested_query]:
                # The raw value belongs to the requested query; rescale it for this query.
                base_value_by_query[query] = self._passengers_multiplier.normalize(
                    raw_base_value=raw_base_value,
                    requested_query=requested_query,
                    query=query,
                )

        self._search_precise_logger.log(
            history_by_query=history_by_query, base_value_by_query=base_value_by_query, search_form=search_form
        )

        has_filters = not search_form.filters.is_empty()
        result = {}
        for query, history in history_by_query.items():
            item = self._serialize_variant(
                history=history,
                value=base_value_by_query.get(query),
                currency_code=currency_code,
                query=query,
                has_filters=has_filters,
                format_version=search_form.format_version,
            )

            if item:
                result[self._serialize_route(query)] = item

        return result

    def _serialize_route(self, query):
        # type: (Query) -> str
        """Return the response key for ``query``: its forward date in ISO format."""
        return query.forward_date.isoformat()

    def _serialize_variant(self, history, value, currency_code, query, has_filters, format_version):
        # type: (Optional[History], Optional[int], str, Query, bool, Optional[str]) -> Optional[dict]
        """Serialize one date's result.

        In the status format (``_STATUS_FORMAT_VERSION``) every case yields a dict
        with a ``status``. In other formats, only priced dates yield a dict;
        the rest return None and are dropped from the response.
        """
        with_status = format_version == self._STATUS_FORMAT_VERSION

        if history is None:
            return {'status': VARIANT_STATUS_NO_UNKNOWN} if with_status else None

        serialized_updated_at = history.updated_at.isoformat() if history.updated_at else None
        expired = history.expired()

        if not history.has_variants:
            if not with_status:
                return None
            return {'status': VARIANT_STATUS_NO_SEARCH_DATA, 'updatedAt': serialized_updated_at, 'expired': expired}

        # From here on, the history has variants.
        if value is None:
            if not with_status:
                return None
            return {
                'status': VARIANT_STATUS_NO_FILTERED_DATA if has_filters else VARIANT_STATUS_NO_SEARCH_DATA,
                'updatedAt': serialized_updated_at,
                'expired': expired,
            }

        if with_status:
            return {
                'status': VARIANT_STATUS_HAS_DATA,
                'value': value,
                'baseValue': value,
                # The price is approximate when a multi-passenger query was priced
                # from a single-passenger search.
                'roughly': not query.is_single() and history.requested_query.is_single(),
                'currency': currency_code,
                'updatedAt': serialized_updated_at,
                'expired': expired,
            }

        return {
            'value': value,
            'baseValue': value,
            'currency': currency_code,
            'updatedAt': serialized_updated_at,
            'expired': expired,
        }


# Module-level instances wired from the shared library objects imported above.
# getLogger(__name__) returns the same logger on every call, so both share it.
_search_precise_logger = SearchPreciseLogger(
    logger=getLogger(__name__),
    price_precision_logger=price_precision_logger,
)

search_view = SearchView(
    logger=getLogger(__name__),
    base_currency_provider=base_currency_provider,
    search_query_finder=search_query_finder,
    query_searcher=query_searcher,
    passengers_multiplier=passengers_multiplier,
    search_precise_logger=_search_precise_logger,
)
