from datetime import datetime
from logging import Logger, getLogger
from typing import Dict, Optional

from sqlalchemy.orm import Session
import ujson
from marshmallow import ValidationError

from travel.avia.price_index.lib.price_precision_logger import price_precision_logger, PricePrecisionLogger
from travel.avia.price_index.lib.query_searcher.fast_min_query_searcher import (
    FastMinQuerySearcher,
    searcher as fast_min_query_searcher,
)
from travel.avia.price_index.lib.search_query_finder import QueryFinder, History, search_query_finder
from travel.avia.price_index.models.batch_prices_form import BatchPricesForm
from travel.avia.price_index.models.query import Query
from travel.avia.price_index.schemas.batch_min_requests_query import BatchMinRequestsQuery
from travel.avia.price_index.lib.rates_provider import rates_provider
from travel.avia.price_index.lib.currency_provider import currency_provider
from travel.avia.price_index.views.helpers import BadRequest


class BatchPreciseLogger(object):
    """Best-effort writer of per-query price-precision records for a batch search.

    Every failure is reported through ``logger`` and swallowed, so callers are
    never interrupted by precision logging.
    """

    def __init__(self, price_precision_logger, logger):
        # type: (PricePrecisionLogger, Logger) -> None
        self._price_precision_logger = price_precision_logger
        self._logger = logger

    def log(self, history_by_query, base_value_by_query, query_source):
        # type: (Dict[Query, Optional[History]], Dict[Query, int], Optional[str]) -> None
        """Write one precision record per query in ``history_by_query``.

        :param history_by_query: queries to record; the values are not read here.
        :param base_value_by_query: queries that have a price; only key
            membership is used (it becomes the ``has_price`` flag).
        :param query_source: tag forwarded to ``PricePrecisionLogger.start_log``;
            may be None (assumed to be a string -- confirm against callers).
        """
        try:
            self._unsafe_log(history_by_query, base_value_by_query, query_source)
        except Exception as e:
            # Diagnostics only: report the failure and carry on.
            self._logger.exception('Can not log a precision records %r', e)

    @staticmethod
    def _force_iso(date):
        # type: (Optional[datetime]) -> Optional[str]
        """Return ``date.isoformat()``, or None when ``date`` is falsy."""
        if not date:
            return None

        return date.isoformat()

    def _unsafe_log(self, history_by_query, base_value_by_query, query_source):
        # type: (Dict[Query, Optional[History]], Dict[Query, int], Optional[str]) -> None
        """Open a log context for ``query_source`` and write a record per query.

        A failure to open the context propagates to the caller; a failure for a
        single record is logged and the remaining records are still written.
        """
        log_context = self._price_precision_logger.start_log(query_source)

        for query in history_by_query:
            try:
                log_context.log(
                    from_id=query.from_id,
                    to_id=query.to_id,
                    forward_date=self._force_iso(query.forward_date),
                    backward_date=self._force_iso(query.backward_date),
                    adults_count=query.adults_count,
                    children_count=query.children_count,
                    infants_count=query.infants_count,
                    national_version_id=query.national_version_id,
                    has_price=query in base_value_by_query,
                )
            except Exception as e:
                # Isolate per-record failures so one bad query does not drop the rest.
                self._logger.exception('Can not log a precision record %r', e)


class MinPriceBatchSearchView(object):
    """Serve minimum prices for a batch of search queries."""

    def __init__(self, batch_precision_logger, fast_min_query_searcher, search_query_finder, logger):
        # type: (BatchPreciseLogger, FastMinQuerySearcher, QueryFinder, Logger) -> None
        self._batch_precision_logger = batch_precision_logger
        self._fast_min_query_searcher = fast_min_query_searcher
        # NOTE(review): not read anywhere in this class; kept for interface compatibility.
        self._search_query_finder = search_query_finder
        self._logger = logger

    def parse_form(self, params, body, query_source):
        """Decode the request body and validate it as a batch min-price form.

        :param params: URL query parameters; must provide ``national_version``,
            which takes precedence over any same-named key in the body.
        :param body: raw JSON request body; must decode to a JSON object.
        :param query_source: optional origin tag; when not None it overrides any
            ``query_source`` key in the body.
        :returns: the result of ``BatchMinRequestsQuery().load``.
        :raises BadRequest: if the body is not valid JSON, is not a JSON object,
            or fails schema validation.
        """
        try:
            post_data = ujson.loads(body)
        except (TypeError, ValueError) as e:
            # Malformed client input must be a 400, not an unhandled 500.
            raise BadRequest('Can not parse search form') from e
        if not isinstance(post_data, dict):
            raise BadRequest('Can not parse search form')

        # Copy the body first and then apply the explicit arguments; the previous
        # ``dict(national_version=..., **post_data)`` raised TypeError whenever
        # the body also contained ``national_version``.
        form = dict(post_data)
        form['national_version'] = params['national_version']
        if query_source is not None:
            form['query_source'] = query_source
        try:
            return BatchMinRequestsQuery().load(form)
        except ValidationError as e:
            raise BadRequest('Can not parse search form') from e

    def process(self, session, batch_prices_form):
        # type: (Session, BatchPricesForm) -> list
        """Look up and serialize minimum prices for every query in the form.

        Precision records are written (best effort) via the batch precision
        logger. Prices are labelled with the base currency of the form's
        national version.

        :returns: one dict per query in the ``history_by_query`` mapping returned
            by ``batch_find``, in that mapping's iteration order.
        """
        self._logger.info('start')
        national_version_id = batch_prices_form.national_version_id

        base_value_by_query, history_by_query = self._fast_min_query_searcher.batch_find(
            session=session,
            queries=batch_prices_form.queries,
        )
        self._logger.info('fetch db objects')

        self._batch_precision_logger.log(history_by_query, base_value_by_query, batch_prices_form.query_source)
        self._logger.info('log precise')

        currency_id = rates_provider.get_base_currency_id(nv_id=national_version_id)
        currency_code = currency_provider.get_by_id(currency_id, national_version_id).code

        result = [
            self._serialize_variant(
                history=history, value=base_value_by_query.get(q), currency_code=currency_code, query=q
            )
            for q, history in history_by_query.items()
        ]

        self._logger.info('serialize')

        return result

    def _serialize_price(self, value, currency_code):
        # type: (Optional[int], str) -> Optional[dict]
        """Return ``{'value', 'currency'}``, or None when ``value`` is falsy.

        NOTE(review): a value of 0 is also reported as "no price" here, while
        BatchPreciseLogger's ``has_price`` only checks key presence -- confirm
        that 0 is never a real price.
        """
        if not value:
            return None
        return {
            'value': value,
            'currency': currency_code,
        }

    def _serialize_variant(self, history, value, currency_code, query):
        # type: (Optional[History], Optional[int], str, Query) -> dict
        """Serialize one query with its minimum price and freshness info.

        ``backward_date`` is included only for queries that have one;
        ``updatedAt``/``expired`` stay None unless a history with ``updated_at``
        is available.
        """
        data = {
            'from_id': query.from_id,
            'to_id': query.to_id,
            'adults_count': query.adults_count,
            'children_count': query.children_count,
            'infants_count': query.infants_count,
            'forward_date': query.forward_date.strftime('%Y-%m-%d'),
            'min_price': self._serialize_price(value, currency_code),
            'updatedAt': None,
            'expired': None,
        }

        if query.backward_date is not None:
            data['backward_date'] = query.backward_date.strftime('%Y-%m-%d')

        if history is not None and history.updated_at is not None:
            data['updatedAt'] = history.updated_at.isoformat()
            data['expired'] = history.expired()

        return data


# Module-level instances wired from the objects imported at the top of this file.
# Both report through this module's logger (getLogger returns one object per name).
_batch_precise_logger = BatchPreciseLogger(
    logger=getLogger(__name__),
    price_precision_logger=price_precision_logger,
)

min_price_batch_search_view = MinPriceBatchSearchView(
    logger=getLogger(__name__),
    search_query_finder=search_query_finder,
    fast_min_query_searcher=fast_min_query_searcher,
    batch_precision_logger=_batch_precise_logger,
)
