# coding=utf-8
import functools
import os
import time as _time
from copy import deepcopy
from datetime import datetime
from itertools import product, dropwhile, chain
from logging import getLogger

import gevent
from django.conf import settings
from django.utils.functional import cached_property
from flask import request, copy_current_request_context
from werkzeug.exceptions import BadRequest

from travel.avia.library.python.common.utils.date import parse_date, MSK_TZ
from travel.avia.library.python.common.utils.marketstat import JSONLog
from travel.avia.library.python.ticket_daemon.caches.services import get_service_by_code
from travel.library.python.tvm_ticket_provider import provider_fabric

from travel.avia.ticket_daemon_api.jsonrpc.lib import feature_flags
from travel.avia.ticket_daemon_api.jsonrpc.lib.date import unixtime
from travel.avia.ticket_daemon_api.jsonrpc.lib.enabled_partner_provider import enabled_partner_provider
from travel.avia.ticket_daemon_api.jsonrpc.lib.internal_daemon_client import (
    internal_daemon_client, InternalDaemonClient
)
from travel.avia.library.python.ticket_daemon.memo import memoize
from travel.avia.library.python.shared_flights_client.client import SharedFlightsClient
from travel.avia.ticket_daemon_api.jsonrpc.limits import InitSearchMCRLimits
from travel.avia.ticket_daemon_api.jsonrpc.models_utils.geo import get_point_tuple_by_key, get_settlement_by_id
from travel.avia.ticket_daemon_api.jsonrpc.models_utils.search_codes import get_codes_for_search

# JSON log written under settings.LOG_PATH; one record is appended by
# log_query_all_pure() for every query_all_pure() call.
http_api_search_log = JSONLog(
    os.path.join(settings.LOG_PATH, 'yt/http_api_search.log')
)
# Module logger.
log = getLogger(__name__)


class QueryIsNotValid(Exception):
    """Raised by Query() when a required query parameter is empty (e.g. a point)."""


class QueryAllException(Exception):
    """Error that carries a status code (500 unless given; presumably HTTP).

    ``status_code`` is a plain method, not a property: call ``exc.status_code()``.
    """

    def __init__(self, message, status_code=500):
        super(QueryAllException, self).__init__(message)
        # Kept private; read it through the status_code() accessor.
        self._status_code = status_code

    def status_code(self):
        """Return the status code given to the constructor."""
        code = self._status_code
        return code


@memoize(lambda point: point.point_key)
def _get_point_related_settlement(point):
    """Return ``point.get_related_settlement()``, cached by ``memoize``.

    The cache key function is ``point.point_key``.

    TODO: assert that ``point`` implements PointInterface (it must provide
    ``point_key`` and ``get_related_settlement()``).
    """
    related = point.get_related_settlement()
    return related


def format_qid(created, service, t_code, qkey, lang):
    """Build a qid string ``YYMMDD-HHMMSS-mmm.<service>.<t_code>.<qkey>.<lang>``.

    The date/time part is ``created`` (a POSIX timestamp) rendered as wall-clock
    time in MSK_TZ. ``mmm`` is the millisecond part of ``created``, truncated.
    """
    whole_seconds = int(created)
    # Milliseconds of the creation time.
    millis = int(1000 * created) % 1000
    created_msk = datetime.fromtimestamp(whole_seconds, tz=MSK_TZ)
    template = '{:%y%m%d-%H%M%S}-{:0>3d}.{}.{}.{}.{}'
    return template.format(created_msk, millis, service, t_code, qkey, lang)


def timestamp(dt):
    """Return POSIX timestamp as float.

    ``dt`` must be naive. It is interpreted in the process's local time zone
    via ``time.mktime`` (DST flag left to mktime). Aware datetimes raise
    NotImplementedError.
    """
    if dt.tzinfo is not None:
        raise NotImplementedError('Only for naive datetimes')

    # wday/yday are ignored by mktime; isdst=-1 lets it decide DST itself.
    struct = (dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second, -1, -1, -1)
    fraction = dt.microsecond / 1e6
    return _time.mktime(struct) + fraction


def parse_qid(qid):
    """Split a qid into ``(created, service, t_code, qkey, lang)``.

    ``created`` is converted back to a POSIX timestamp by _parse_when_created().
    A qid that does not split into exactly five dot-separated fields raises
    ValueError.
    """
    raw_created, service, t_code, qkey, lang = qid.split('.')
    return _parse_when_created(raw_created), service, t_code, qkey, lang


def _parse_when_created(raw):
    """Parse the ``YYMMDD-HHMMSS-mmm`` part of a qid into a POSIX timestamp (float).

    format_qid() renders the creation time as wall-clock time in MSK_TZ, so the
    parsed naive datetime is interpreted in MSK_TZ here too. Using timestamp()
    (time.mktime) would interpret it in the process-local zone and break the
    format_qid()/parse_qid() round trip on hosts not running in MSK.
    """
    import calendar  # local import: keeps the module import block unchanged

    wall_clock = datetime.strptime(raw, '%y%m%d-%H%M%S-%f')
    # tzinfo.utcoffset() accepts a naive datetime for pytz, dateutil, zoneinfo
    # and fixed-offset tzinfo implementations alike.
    utc_wall_clock = wall_clock - MSK_TZ.utcoffset(wall_clock)
    return calendar.timegm(utc_wall_clock.timetuple()) + utc_wall_clock.microsecond / 1e6


def check_limits(_query_all_pure):
    """Decorate a query_all_pure-like function with InitSearchMCRLimits checks.

    For services without limits, the wrapped function is called as is.
    Otherwise:

    - The partners (``q.partner_codes``, or the enabled init-search partners)
      are checked against their limits.
    - BadRequest is raised when limits exist and none of them is positive.
    - ``q.partner_codes`` is narrowed to partners with a positive limit.
    - After the call, limits are updated for the queried partners and
      attached to the result as ``limits`` and ``bound``.
    """
    @functools.wraps(_query_all_pure)
    def wrapper(q, ignore_cache=False, test_id=None, base_qid=None, user_info=None):
        forwarded = dict(
            ignore_cache=ignore_cache,
            base_qid=base_qid,
            test_id=test_id,
            user_info=user_info,
        )

        if not InitSearchMCRLimits.has_limits(q.service):
            return _query_all_pure(q, **forwarded)

        if getattr(q, 'partner_codes', None):
            candidates = q.partner_codes.split(',')
        else:
            candidates = q.get_enabled_partner_codes(for_init_search=True)

        limits = InitSearchMCRLimits.get(q.service, candidates)
        if limits and not any(limits.values()):
            raise BadRequest({
                'error': 'Limits exceeded',
                'limits': limits,
                'bound': InitSearchMCRLimits.bound(q.service, limits),
            })

        # Only partners with remaining quota are actually queried.
        q.partner_codes = ','.join(code for code in candidates if limits[code] > 0)
        result = _query_all_pure(q, **forwarded)

        limits.update(InitSearchMCRLimits.set(q.service, result.get('partners', [])))
        result['limits'] = limits
        result['bound'] = InitSearchMCRLimits.bound(q.service, limits)
        return result

    return wrapper


class Query(object):
    """A flight search query.

    Holds the route (``point_from``/``point_to``), travel dates, passengers,
    service class (``klass``), national version, language, transport code
    (``t_code``) and the requesting service. Its ``id`` is a qid built by
    format_qid() from the creation time, service, t_code, key() and lang.

    When constructed without ``base_qid``, ``queries`` is filled by
    init_multiple_queries() with the queries to run for this search. Queries
    constructed with ``base_qid`` (sub-queries) get an empty ``queries`` list.

    Raises:
        Exception: if neither ``service`` nor settings.PROJECT_CODE is set.
        QueryIsNotValid: if ``point_from`` or ``point_to`` is empty.
    """

    def __init__(
        self,
        point_from,
        point_to,
        passengers,
        date_forward,
        date_backward=None,
        klass='economy',
        national_version='ru',
        lang='ru',
        t_code='plane',
        service=None,
        created=None,
        base_qid=None,
        **kwargs
    ):
        # RASPTICKETS-9217: the Rasp (schedules) backend got HTTP 500 when
        # searching prices in the ticket daemon for the Belarusian version,
        # so its 'by' searches are run as 'ru'.
        if service == 'rasp_morda_backend' and national_version == 'by':
            national_version = 'ru'

        self.point_from = point_from
        self.point_to = point_to
        self.passengers = passengers
        self.klass = klass
        self.date_forward = date_forward
        self.date_backward = date_backward
        self.national_version = national_version
        # For the kz version, 'ru' replaces the requested language when the
        # replace_kk_with_ru_in_kz feature flag is on.
        self.lang = 'ru' if self.national_version == 'kz' and feature_flags.replace_kk_with_ru_in_kz() else lang
        self.t_code = t_code
        self.service = service or getattr(settings, 'PROJECT_CODE', None)
        self.meta = {
            'custom_store_time': kwargs.get('custom_store_time', 0),
            'wizard_caches': kwargs.get('wizard_caches'),
        }

        self.country_code_from = None
        self.country_code_to = None
        self.search_url = None
        self.partners = ()  # Immutable list. Attr for backward compatability
        self.iata_from = None
        self.iata_to = None
        self.iata_real_from = None
        self.iata_real_to = None
        self.partner_codes = None
        self.test_context = None

        if not self.service:
            raise Exception('Either "service" should be provided to Query() or PROJECT_CODE filled')

        if not self.point_from:
            raise QueryIsNotValid('Empty point_from')

        if not self.point_to:
            raise QueryIsNotValid('Empty point_to')

        # Extra keyword arguments become plain attributes; they may override
        # the defaults assigned above (e.g. partner_codes).
        for k, v in kwargs.items():
            setattr(self, k, v)

        self.created = created or _time.time()

        self.base_qid = base_qid
        self.base_query = None

        self.queries = []
        if base_qid is None:
            self.queries = self.init_multiple_queries()

    def __eq__(self, other):
        """Queries are equal when they are of the same class and share a qid."""
        return (
            isinstance(other, self.__class__) and
            self.id == other.id
        )

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, != compares identity.
        return not self.__eq__(other)

    def __hash__(self):
        # Keep hashing consistent with __eq__: equal qids give equal hashes.
        return hash(self.id)

    @cached_property
    def id(self):
        """The qid of this query, computed once and cached on the instance."""
        return format_qid(
            self.created, self.service, self.t_code, self.key(), self.lang
        )

    @property
    def qkey(self):
        """Alias for key()."""
        return self.key()

    @property
    def qid_msg(self):
        """``'(qid:<id>)'``, handy for log messages."""
        return '(qid:%s)' % self.id

    @property
    def is_sub_query(self):
        """True when this query was created with a base qid."""
        return bool(self.base_qid)

    def key(self):
        """Return the '_'-joined query key.

        Fields, in order: point_from key, point_to key, date_forward,
        date_backward ('None' if absent), klass, adults, children, infants and
        national_version. from_key() parses this format back.
        """
        params = (
            self.point_from.point_key,
            self.point_to.point_key,
            str(self.date_forward),
            str(self.date_backward),
            self.klass,
            str(self.passengers.get('adults', 0)),
            str(self.passengers.get('children', 0)),
            str(self.passengers.get('infants', 0)),
            self.national_version,
        )
        key = '_'.join([str(p) for p in params])

        return key

    @classmethod
    def _get_point_tuple_by_key(cls, point_key):
        """Resolve a point key into a point object via get_point_tuple_by_key()."""
        return get_point_tuple_by_key(point_key)

    @classmethod
    def from_key(cls, key, service=None, lang=None, t_code=None, created=None, base_qid=None):
        """Build a Query from a string produced by key().

        Only the first nine '_'-separated fields are used. Point keys are
        resolved through _get_point_tuple_by_key(), and dates through
        parse_date() ('None' for date_backward means no return date).
        """
        (
            point_from_key,
            point_to_key,
            date_forward,
            date_backward,
            klass,
            adults,
            children,
            infants,
            national_version
        ) = key.split('_')[:9]

        if date_backward == 'None':
            date_backward = None
        else:
            date_backward = parse_date(date_backward)

        return cls(
            point_from=cls._get_point_tuple_by_key(point_from_key),
            point_to=cls._get_point_tuple_by_key(point_to_key),
            date_forward=parse_date(date_forward),
            date_backward=date_backward,
            klass=klass,
            passengers={
                'adults': int(adults),
                'children': int(children),
                'infants': int(infants),
            },
            national_version=national_version,
            lang=lang,
            t_code=t_code,
            service=service,
            created=created,
            base_qid=base_qid,
        )

    @classmethod
    def from_qid(cls, qid, base_qid=None):
        """Rebuild a Query from a qid produced by format_qid()."""
        created, service, t_code, qkey, lang = parse_qid(qid)
        q = cls.from_key(
            qkey, service=service, lang=lang, t_code=t_code, created=created, base_qid=base_qid
        )
        return q

    @property
    def adults(self):
        return self.passengers['adults']

    @property
    def children(self):
        return self.passengers['children']

    @property
    def infants(self):
        return self.passengers['infants']

    @property
    def only_adults(self):
        """True when there is at least one adult and no children or infants."""
        return (
            self.adults > 0
            and self.children == 0
            and self.infants == 0
        )

    def create_subquery(self, point_from_key, point_to_key):
        # type: (str, str) -> Query
        """Return a deep copy of this query routed between the given point keys.

        The copy gets ``base_qid = self.id`` and ``base_query = self``.
        """
        subquery = deepcopy(self)
        # cached_property stores the computed qid in the instance __dict__, and
        # deepcopy copies it. Drop it so the subquery's id reflects its own
        # points instead of repeating the parent's qid.
        subquery.__dict__.pop('id', None)

        subquery.point_from = self._get_point_tuple_by_key(point_from_key)
        subquery.point_to = self._get_point_tuple_by_key(point_to_key)
        subquery.base_qid = self.id
        subquery.base_query = self
        return subquery

    def init_multiple_queries(self):
        """Return the list of queries that should be run for this search.

        - With the replace_search_to_station_with_search_to_city flag on:
          - If either point has no settlement, only ``[self]`` is returned.
          - If either point differs from its settlement, only a settlement to
            settlement subquery is returned; ``self`` is not included.
        - Otherwise, with the init_multiple_queries flag off, the result is
          ``[self]``.
        - With that flag on, ``[self]`` is followed by one subquery per pair
          of codes from get_codes_for_search(). The original pair is skipped,
          as are pairs where a new endpoint equals the opposite original
          endpoint.
        """
        queries = [self]

        if feature_flags.replace_search_to_station_with_search_to_city():
            city_from = get_settlement_by_id(self.point_from.settlement_id)
            city_to = get_settlement_by_id(self.point_to.settlement_id)
            if not city_from or not city_to:
                log.warning('No settlement for points %r - %r', self.point_from, self.point_to)
                return queries

            if self.point_from.point_key != city_from.point_key or self.point_to.point_key != city_to.point_key:
                subquery = self.create_subquery(city_from.point_key, city_to.point_key)
                return [subquery]

        if not feature_flags.init_multiple_queries():
            return queries

        search_from_codes = get_codes_for_search(self.point_from.point_key)
        search_to_codes = get_codes_for_search(self.point_to.point_key)
        # The first two key() fields are the original point keys; loop-invariant.
        point_from_key, point_to_key = self.qkey.split('_')[:2]
        for (from_key_new, _from_reasons), (to_key_new, _to_reasons) in product(search_from_codes, search_to_codes):
            is_original_pair = point_from_key == from_key_new and point_to_key == to_key_new
            touches_opposite_end = point_from_key == to_key_new or point_to_key == from_key_new
            if is_original_pair or touches_opposite_end:
                continue

            subquery = self.create_subquery(from_key_new, to_key_new)
            queries.append(subquery)

        return queries

    def get_enabled_partner_codes(self, for_init_search=False):
        """Return enabled partner codes for this query's version/platform.

        Raises BadRequest when enabled_partner_provider rejects the combination
        of national version, mobile flag, init-search flag and rasp origin.
        """
        if not enabled_partner_provider.validate_enabled_field_name(
            national_version=self.national_version,
            mobile=self.is_mobile,
            for_init_search=for_init_search,
            is_from_rasp=self.is_from_rasp,
        ):
            raise BadRequest('National version "%s" does not support' % self.national_version)

        return enabled_partner_provider.get_codes(
            national_version=self.national_version,
            mobile=self.is_mobile,
            for_init_search=for_init_search,
            is_from_rasp=self.is_from_rasp,
        )

    @property
    def project_group(self):
        """Project group of the requesting service (looked up by service code)."""
        return get_service_by_code(self.service).project_group

    @property
    def is_from_rasp(self):
        """True when the requesting service belongs to the 'rasp' project group."""
        return self.project_group == 'rasp'

    @property
    def is_mobile(self):
        """Mobile flag of the requesting service (looked up by service code)."""
        return get_service_by_code(self.service).is_mobile

    def city_from(self):
        """Settlement related to point_from (memoized by point key)."""
        return _get_point_related_settlement(self.point_from)

    def city_to(self):
        """Settlement related to point_to (memoized by point key)."""
        return _get_point_related_settlement(self.point_to)

    def related_settlement_ids_from(self):
        return self.point_from.get_related_settlement_ids()

    def related_settlement_ids_to(self):
        return self.point_to.get_related_settlement_ids()

    def __repr__(self):
        return '<Query: %s>' % self.id


@check_limits
def query_all_pure(query, ignore_cache=False, test_id=None, base_qid=None, user_info=None):
    """Log the search, send it to the ticket daemon and return its response.

    Partners disabled for ``query`` are filtered out of the response. The
    shared ``internal_daemon_client`` is used unless the current request has a
    ``force-ticket-daemon-host`` argument. In that case a dedicated
    InternalDaemonClient is built for that host.

    NOTE(review): the override host comes directly from request arguments, and
    the client gets a TVM provider. Confirm that untrusted callers cannot reach
    this code path; otherwise credentials could presumably be sent to an
    arbitrary host.
    """
    log_query_all_pure(query, ignore_cache, user_info, base_qid)

    forced_host = request.args.get('force-ticket-daemon-host')
    if forced_host:
        tvm_provider = provider_fabric.create(settings, timeout=settings.TVM_TIMEOUT)
        client = InternalDaemonClient(
            tvm_provider=tvm_provider,
            ticket_daemon_host=forced_host,
        )
    else:
        client = internal_daemon_client

    response = client.query(
        query=query,
        ignore_cache=ignore_cache,
        test_id=test_id,
        base_qid=base_qid,
    )
    return filter_disabled_partners(query, response)


def filter_disabled_partners(query, data):
    """Restrict ``data['partners']`` to allowed partners; mutate and return ``data``.

    Allowed partners are those enabled for ``query`` plus those in
    settings.BANNED_PARTNERS_FOR_RESULTS. A missing 'partners' key becomes an
    empty list.

    NOTE(review): adding the *banned* set to the allowed set looks surprising.
    Confirm this is intended and not meant to be a set difference.
    """
    allowed = set(query.get_enabled_partner_codes()) | settings.BANNED_PARTNERS_FOR_RESULTS
    kept = []
    for partner in data.get('partners', []):
        if partner in allowed:
            kept.append(partner)
    data['partners'] = kept
    return data


def log_query_all_pure(q, ignore_cache, user_info, base_qid):
    """Append one record describing search query ``q`` to http_api_search_log.

    ``user_info`` may be None, which is query_all_pure's default. It is then
    treated as an empty dict, so yandexuid/userip/passportuid are logged as
    None instead of failing with AttributeError.
    """
    user_info = user_info or {}
    date_backward = q.date_backward.strftime('%Y-%m-%d') if q.date_backward else None

    data = {
        'qid': q.id,
        'base_qid': base_qid,
        'adults': q.adults,
        'children': q.children,
        'unixtime': unixtime(),
        'eventtime': datetime.now().strftime('%Y%m%d%H%M%S'),
        'from_id': q.point_from.point_key,
        'infants': q.infants,
        'class': q.klass,
        'national_version': q.national_version,
        'return_date': date_backward,
        'service': q.service,
        'to_id': q.point_to.point_key,
        'when': q.date_forward.strftime('%Y-%m-%d'),
        'yandexuid': user_info.get('yandexuid'),
        'userip': user_info.get('userip'),
        'passportuid': user_info.get('passportuid'),
        'request_ignore_cache': ignore_cache,
        'request_partners': q.partner_codes,
        'request_custom_store_time': q.meta.get('custom_store_time'),
    }

    http_api_search_log.log(data)


def init_search(query, ignore_cache=False, user_info=None, test_id=None):
    """Start every query in ``query.queries`` concurrently and collect the results.

    Each subquery goes through query_all_pure in its own greenlet, wrapped by
    copy_current_request_context. If the
    use_shared_flights_for_close_dates_direct_flights flag is on,
    get_close_date_flights(query) runs alongside.

    Returns a dict with the parent ``qid``, the union of ``partners`` (a set),
    the per-subquery results under ``queries`` and ``isDirectFlightsAvailable``.
    The latter is False when the check is disabled or its greenlet failed.

    Raises the exception of the first failed subquery, in ``query.queries`` order.
    """
    # Wrap per greenlet, so each greenlet pushes its own copy of the request context.
    search_threads = [
        gevent.spawn(
            copy_current_request_context(query_all_pure),
            subquery,
            ignore_cache=ignore_cache,
            test_id=test_id,
            base_qid=subquery.base_qid,
            user_info=user_info or {},
        )
        for subquery in query.queries
    ]

    extra_threads = []
    direct_flights_thread = None
    if feature_flags.use_shared_flights_for_close_dates_direct_flights():
        direct_flights_thread = gevent.spawn(get_close_date_flights, query)
        extra_threads.append(direct_flights_thread)

    gevent.joinall(search_threads + extra_threads)
    for thread in search_threads:
        if not thread.successful():
            raise thread.exception

    partners = set(chain.from_iterable(thread.value['partners'] for thread in search_threads))

    is_direct_flights_available = False
    # A failed greenlet has value None; report False rather than null.
    if direct_flights_thread is not None and direct_flights_thread.successful():
        is_direct_flights_available = direct_flights_thread.value

    return {
        'qid': query.id,
        'partners': partners,
        'queries': [t.value for t in search_threads],
        'isDirectFlightsAvailable': is_direct_flights_available,
    }


def get_close_date_flights(query):
    """Best-effort: ask shared-flights whether direct flights exist on close dates.

    Returns True when SharedFlightsClient.close_dates_with_direct_flight()
    reports at least one date in ``result['forward']['dates']``. Returns False
    in every other case:

    - either point has no non-hidden airports;
    - any step fails, including the airport lookups; the failure is logged
      with its traceback.
    """
    try:
        airports_from = query.point_from.get_airports_except_hidden()
        airports_to = query.point_to.get_airports_except_hidden()
        # Don't bother asking shared-flights for direct flights to or from nowhere
        if not airports_from:
            log.info('Unable to find airports for the \'from\' point %s', query.point_from.point_key)
            return False
        if not airports_to:
            log.info('Unable to find airports for the \'to\' point %s', query.point_to.point_key)
            return False

        result = SharedFlightsClient().close_dates_with_direct_flight(
            query.date_forward,
            query.date_backward,
            airports_from,
            airports_to,
        )
        return len(result['forward']['dates']) > 0
    except Exception:
        log.exception('Fail figuring out direct flights on close dates')
        return False
