# coding=utf-8
import zlib
from datetime import date, timedelta
from logging import Logger, getLogger

import ujson
from sqlalchemy import and_, false, func, or_

from travel.avia.price_index.db_models.result import Result
from travel.avia.price_index.lib.constants import NULL_DATE
from travel.avia.price_index.lib.db.storage import Storage, slave_storage
from travel.avia.price_index.models.next_days_min_prices_request import NextDaysMinPricesRequest


class NextDaysMinPricesSearcher(object):
    """Find, per requested direction, the cheapest stored result in a date window.

    Only ``Result`` rows for one adult, no children and no infants, with a
    non-NULL ``base_value``, ``backward_date == NULL_DATE`` and a
    ``forward_date`` within ``[today, today + window_size]`` are considered.
    For each ``(from_id, to_id)`` pair the row with the lowest ``base_value``
    is returned (ties broken by the earliest ``forward_date``).
    """

    def __init__(self, storage, logger):
        # type: (Storage, Logger) -> None
        self._storage = storage
        self._logger = logger

    def search(self, national_version_id, request):
        # type: (int, NextDaysMinPricesRequest) -> list
        """Run the min-price search in a fresh session from the storage.

        :param national_version_id: value matched against
            ``Result.national_version_id``.
        :param request: provides ``directions`` (items with ``orig_id`` and
            ``dest_id``) and ``window_size`` (days after today).
        :return: list of dicts with keys ``min_price``, ``from_id``,
            ``to_id``, ``forward_date``, ``backward_date`` and
            ``transfers_count``; at most one dict per direction.
        :raises: any error raised while searching is logged and re-raised.
        """
        session = self._storage.get_session()

        self._logger.info('Batch search %r', request)

        try:
            result = self._search(session, national_version_id, request)
            # Nothing is written; roll back just to end the transaction.
            session.rollback()
            return result
        except Exception as e:
            self._logger.exception('An error occurred while searching: %r %r', request, e)
            session.rollback()
            raise
        finally:
            session.close()

    def _search(self, session, national_version_id, request):
        """Build and execute the query; return one dict per direction found."""
        directions = list(request.directions)
        if not directions:
            # Nothing requested: avoid the database round trip entirely.
            return []

        rows_filter = and_(
            self._date_part(request.window_size),
            Result.national_version_id == national_version_id,
            Result.adults_count == 1,
            Result.children_count == 0,
            Result.infants_count == 0,
            self._directions_part(directions),
        )

        # Rank the candidate rows of each direction by price, earliest forward
        # date first on ties. The row ranked 1 therefore already holds the
        # minimum price of its direction, so no separate MIN() aggregate is
        # needed.
        row_number_column = (
            func.row_number()
            .over(
                partition_by=(Result.from_id, Result.to_id),
                order_by=(Result.base_value, Result.forward_date),
            )
            .label('row_number')
        )

        ranked_rows = (
            session.query(
                Result.base_value.label('base_value'),
                Result.from_id.label('from_id'),
                Result.to_id.label('to_id'),
                Result.forward_date,
                Result.backward_date,
                Result.gzip_data,
                row_number_column,
            )
            .filter(rows_filter)
            .subquery()
        )

        # A window function's result cannot be filtered in the WHERE clause of
        # the same SELECT, hence the filter on the outer query.
        query = session.query(
            ranked_rows.c.base_value,
            ranked_rows.c.from_id,
            ranked_rows.c.to_id,
            ranked_rows.c.forward_date,
            ranked_rows.c.backward_date,
            ranked_rows.c.gzip_data,
        ).filter(ranked_rows.c.row_number == 1)

        rows = session.execute(query).fetchall()
        return [
            {
                'min_price': min_price,
                'from_id': from_id,
                'to_id': to_id,
                'forward_date': forward_date,
                'backward_date': backward_date,
                'transfers_count': self._get_transfers_count(gzip_data),
            }
            for min_price, from_id, to_id, forward_date, backward_date, gzip_data in rows
        ]

    @staticmethod
    def _get_transfers_count(gzip_data):
        """Return ``count_transfer`` of the first stored variant, or None.

        The payload is inflated with ``zlib.decompress`` (default settings
        expect the zlib format) and parsed as JSON. The first element of the
        parsed list is used; presumably the writer stores the best variant
        first -- TODO confirm. None is returned when the payload, the
        decompressed text or the parsed list is empty.
        """
        if not gzip_data:
            return None
        variants_json = zlib.decompress(gzip_data)
        if not variants_json:
            return None

        variants = ujson.loads(variants_json)
        if not variants:
            return None
        best_variant = variants[0]
        return best_variant['count_transfer']

    @staticmethod
    def _directions_part(directions):
        """OR together one ``from_id``/``to_id`` equality pair per direction.

        ``false()`` seeds the disjunction so that an empty ``directions``
        produces a predicate matching nothing. A bare ``or_()`` would render
        as an empty clause that the enclosing ``and_`` drops, which removes
        the direction restriction altogether. For a non-empty input
        SQLAlchemy elides the ``false()`` term, so the emitted SQL is
        unchanged.
        """
        return or_(
            false(),
            *[
                and_(
                    Result.from_id == d.orig_id,
                    Result.to_id == d.dest_id,
                )
                for d in directions
            ]
        )

    @staticmethod
    def _date_part(window_size):
        """Restrict to priced rows whose forward date is within the window.

        Both bounds are inclusive: ``[today, today + window_size]``, where
        "today" is the process-local date. ``date.today()`` is read once so
        that the two bounds cannot straddle midnight.
        """
        from_date = date.today()
        to_date = from_date + timedelta(days=window_size)
        return and_(
            Result.base_value.isnot(None),
            # NOTE(review): NULL_DATE presumably marks results without a
            # return leg (one-way); confirm against the writer.
            Result.backward_date == NULL_DATE,
            Result.forward_date >= from_date,
            Result.forward_date <= to_date,
        )


# Shared module-level searcher: sessions come from `slave_storage`, and log
# records go to the logger named after this module.
next_days_min_prices_searcher = NextDaysMinPricesSearcher(
    storage=slave_storage, logger=getLogger(__name__)
)
