# -*- coding: utf-8 -*-
import logging
import ujson
from collections import namedtuple
from contextlib import closing
from functools import partial
from logging import Logger
from multiprocessing import Pool
from typing import Tuple, Dict, Iterator

import sqlalchemy
import zlib
from sqlalchemy import and_, union_all
from sqlalchemy.orm import Session, Query as SaQuery

from travel.avia.price_index.db_models.result import Result
from travel.avia.price_index.lib.constants import NULL_DATE
from travel.avia.price_index.models.filters import TimeFilter, AirportsFilter, Filters, TransferFilter
from travel.avia.price_index.models.query import Query

# Number of worker processes used to decompress and filter variants in parallel.
MAX_FILTER_APPLYING_WORKERS = 3

# Work item for a filtering worker: the query's position in the batch and the
# raw zlib-compressed variants payload fetched for it.
QueryVariants = namedtuple('QueryVariants', 'query_idx variants')
# Output of a filtering worker: the query's position in the batch and its result.
QueryResult = namedtuple('QueryResult', 'query_idx result')


class FilterQuerySearcher(object):
    """Fetch stored results for a batch of queries and filter their variants.

    All queries are looked up with a single ``UNION ALL`` statement; the
    decompression and filtering of each payload is then spread over a small
    process pool via ``process_variants``.
    """

    def __init__(self, logger):
        # type: (Logger) -> None
        self._logger = logger

    def batch_find(self, session, queries, filters):
        # type: (Session, Tuple[Query, ...], Filters) -> Dict[Query, int]
        """Return the filtered ``base_value`` for every query that has a stored result.

        :param session: SQLAlchemy session used to execute the lookup.
        :param queries: queries to look up; each query's position in this
            sequence tags its rows in the combined SQL statement.
        :param filters: filters applied to the variants of every query.
        :return: mapping ``query -> base_value`` of the first variant (in stored
            order) that passes the filters, or ``None`` when no variant passes.
            Queries with no stored row, or with a NULL payload, are absent.
        :raises Exception: any failure while executing or filtering is logged
            and re-raised unchanged.
        """
        if not queries:
            return {}

        q = union_all(*[
            self._find(session=session, query=query, index=index)
            for index, query in enumerate(queries)
        ])

        try:
            data = [
                QueryVariants(idx, raw_variants)
                for idx, raw_variants in session.execute(q).fetchall()
                if raw_variants is not None
            ]
            if not data:
                # Nothing to filter: don't spin up worker processes for no work.
                return {}

            with closing(Pool(processes=MAX_FILTER_APPLYING_WORKERS)) as pool:
                results = pool.map(partial(process_variants, filters=filters), data)

            # Unpack positionally so that both ``QueryResult`` instances and plain
            # ``(query_idx, result)`` tuples coming back from workers are accepted.
            return {queries[query_idx]: result for query_idx, result in results}
        except Exception:
            self._logger.exception('Can not filter variants')
            raise

    def _find(self, session, query, index):
        # type: (Session, Query, int) -> SaQuery
        """Build the lookup for a single query, tagged with its batch position.

        Selects the literal ``index`` (labelled ``idx``) together with
        ``Result.gzip_data`` for rows matching every field of ``query``. A
        falsy ``backward_date`` is matched against the ``NULL_DATE`` sentinel.
        """
        return session.query(sqlalchemy.literal(index).label('idx'), Result.gzip_data,).filter(
            and_(
                Result.national_version_id == query.national_version_id,
                Result.from_id == query.from_id,
                Result.to_id == query.to_id,
                Result.forward_date == query.forward_date,
                Result.backward_date == (query.backward_date or NULL_DATE),
                Result.adults_count == query.adults_count,
                Result.children_count == query.children_count,
                Result.infants_count == query.infants_count,
            )
        )


class VariantsFilterApplier(object):
    """Narrow decoded variant dicts down to those matching a ``Filters`` set.

    Filter groups are applied one after another. A filter left unset (``None``,
    or empty for airport/airline lists) does not touch the variants. When at
    least one filter is applied the result is a new ``list``; otherwise the
    input object is returned unchanged.
    """

    def apply_filters(self, variants, filters):
        # type: (Iterator[dict], Filters) -> Iterator[dict]
        """Apply transfer, time, airport, baggage and airline filters in turn.

        :param variants: variant dicts, e.g. as decoded from the stored JSON payload.
        :param filters: the filter set to apply.
        :return: the variants satisfying every set filter, in input order.
        """
        variants = self._apply_transfer_filters(variants, filters.transfer_filters)
        variants = self._apply_time_filters(variants, filters.time_filters)
        variants = self._apply_airport_filters(variants, filters.airports_filters)

        if filters.with_baggage is not None:
            variants = [v for v in variants if v['has_baggage'] == filters.with_baggage]

        if filters.airlines:
            airlines = set(filters.airlines)
            # Decoded JSON stores ``airlines`` as a list, and a list never compares
            # equal to a set. Normalise it so the set comparison can actually match.
            variants = [v for v in variants if set(v['airlines']) == airlines]

        return variants

    def _apply_transfer_filters(self, variants, filters):
        # type: (Iterator[dict], TransferFilter) -> Iterator[dict]
        """Filter by transfer count, transfer duration bounds and transfer flags.

        ``count`` is an inclusive upper bound; ``min_duration``/``max_duration``
        are inclusive bounds; ``has_night``/``has_airport_change`` must match
        exactly. Each criterion is skipped when it is ``None``.
        """
        if filters.count is not None:
            variants = [v for v in variants if v['count_transfer'] <= filters.count]

        if filters.min_duration is not None:
            variants = [v for v in variants if v['duration_transfer'] >= filters.min_duration]
        if filters.max_duration is not None:
            variants = [v for v in variants if v['duration_transfer'] <= filters.max_duration]

        if filters.has_night is not None:
            variants = [v for v in variants if v['has_night_transfer'] == filters.has_night]

        if filters.has_airport_change is not None:
            variants = [v for v in variants if v['has_airport_change'] == filters.has_airport_change]
        return variants

    def _apply_airport_filters(self, variants, filters):
        # type: (Iterator[dict], AirportsFilter) -> Iterator[dict]
        """Filter by departure/arrival/transfer airports for both directions.

        Departure/arrival filters keep variants whose airport is in the given
        collection. Transfer filters keep variants sharing at least one
        transfer airport with the given collection. Empty/``None`` filters are
        skipped.
        """
        if filters.forward_departure:
            variants = [v for v in variants if v['forward_departure_airport'] in filters.forward_departure]
        if filters.forward_arrival:
            variants = [v for v in variants if v['forward_arrival_airport'] in filters.forward_arrival]
        if filters.forward_transfers:
            forward_transfers = set(filters.forward_transfers)
            variants = [v for v in variants if set(v['forward_transfer_airports']) & forward_transfers]

        if filters.backward_departure:
            variants = [v for v in variants if v['backward_departure_airport'] in filters.backward_departure]
        if filters.backward_arrival:
            variants = [v for v in variants if v['backward_arrival_airport'] in filters.backward_arrival]
        if filters.backward_transfers:
            backward_transfers = set(filters.backward_transfers)
            variants = [v for v in variants if set(v['backward_transfer_airports']) & backward_transfers]
        return variants

    def _apply_time_filters(self, variants, filters):
        # type: (Iterator[dict], TimeFilter) -> Iterator[dict]
        """Filter by departure/arrival time type for both directions.

        Unlike the airport filters, a constraint is skipped only when it is
        ``None``; an empty (non-``None``) constraint therefore removes every variant.
        """
        checks = [
            (filters.forward_arrival, 'forward_arrival_time_type'),
            (filters.forward_departure, 'forward_departure_time_type'),
            (filters.backward_arrival, 'backward_arrival_time_type'),
            (filters.backward_departure, 'backward_departure_time_type'),
        ]

        for constraint, key in checks:
            if constraint is not None:
                variants = [v for v in variants if v[key] in constraint]

        return variants


def process_variants(raw_variants, filters):
    # type: (QueryVariants, Filters) -> QueryResult
    """Decompress, decode and filter the variants stored for one query.

    Executed in a ``multiprocessing`` worker via ``Pool.map``, so it must stay a
    module-level (picklable) function.

    :param raw_variants: ``QueryVariants`` with the query's batch index and the
        zlib-compressed JSON list of variants.
    :param filters: filters to apply to the decoded variants.
    :return: ``QueryResult`` holding the query index and the ``base_value`` of
        the first variant (in stored order) that passes the filters, or
        ``None`` when the payload is empty or nothing passes.
    """
    variants = ujson.loads(zlib.decompress(raw_variants.variants))
    if not variants:
        # Always return a QueryResult: the caller reads ``.query_idx``/``.result``,
        # which a bare tuple does not provide.
        return QueryResult(raw_variants.query_idx, None)
    filtered = list(VariantsFilterApplier().apply_filters(variants, filters))
    return QueryResult(raw_variants.query_idx, filtered[0]['base_value'] if filtered else None)


# Shared module-level searcher instance, logging through this module's logger.
searcher = FilterQuerySearcher(logger=logging.getLogger(__name__))
