# -*- coding: utf-8 -*-
from __future__ import division

import os
import random
import time as _time
from datetime import datetime, timedelta
from importlib import import_module
from logging import getLogger
from urllib import quote_plus

from django.conf import settings
from django.utils.functional import cached_property

from travel.avia.library.python.common.models.geo import Point
from travel.avia.library.python.common.utils.date import parse_date
from travel.avia.library.python.common.utils.marketstat import JSONLog

from travel.avia.library.python.ticket_daemon.memo import CacheWithKeyTTL, memoize
from travel.avia.ticket_daemon.ticket_daemon.api.models_utils.country import get_country_by_id
from travel.avia.ticket_daemon.ticket_daemon.api.models_utils.interfaces import PointInterface
from travel.avia.ticket_daemon.ticket_daemon.api.models_utils.point import get_point_tuple_by_key
from travel.avia.ticket_daemon.ticket_daemon.api.models_utils.search_codes import (
    get_iata_code_for_search, get_sirena_code_for_search,
    get_iata_codes_for_search, get_sirena_codes_for_search
)
from travel.avia.library.python.ticket_daemon.caches.services import get_service_by_code
from travel.avia.ticket_daemon.ticket_daemon.api.models_utils.station import get_airport_by_id

# Module logger, plus a JSONLog that writes to 'yt/http_api_search.log'
# under settings.LOG_PATH.
log = getLogger(__name__)
http_api_search_log = JSONLog(os.path.join(
    settings.LOG_PATH, 'yt/http_api_search.log'
))


def _format_qid(created, service, t_code, qkey, lang):
    """Build a query id string.

    Layout: ``YYMMDD-HHMMSS-mmm.<service>.<t_code>.<qkey>.<lang>``, where the
    date/time is the local time of ``created`` (POSIX seconds) and ``mmm`` is
    its zero-padded millisecond part.
    """
    template = '{when:%y%m%d-%H%M%S}-{millis:0>3d}.{service}.{t_code}.{qkey}.{lang}'
    when = datetime.fromtimestamp(int(created))
    # Millisecond part of the creation time
    millis = int(1000 * created) % 1000
    return template.format(
        when=when,
        millis=millis,
        service=service,
        t_code=t_code,
        qkey=qkey,
        lang=lang,
    )


def _parse_qid(qid):
    """Split a qid built by ``_format_qid`` into its five components.

    Returns ``(created_timestamp, service, t_code, qkey, lang)``. Raises
    ValueError if the qid does not contain exactly five '.'-separated fields.
    """
    fields = qid.split('.')
    raw_created, service, t_code, qkey, lang = fields
    created = _parse_when_created(raw_created)
    return created, service, t_code, qkey, lang


def _parse_when_created(raw):
    """Convert the ``YYMMDD-HHMMSS-mmm`` qid prefix back to a POSIX timestamp.

    The date/time part is interpreted as naive local time (see ``timestamp``).
    """
    date_part, time_part, millis = raw.split('-')
    log.debug('parse_when_created: %r %r %r', date_part, time_part, millis)
    moment = datetime.strptime(date_part + 'T' + time_part, '%y%m%dT%H%M%S')
    return timestamp(moment + timedelta(milliseconds=int(millis)))


def timestamp(dt):
    """Return the POSIX timestamp of a naive local ``datetime`` as a float.

    Uses ``time.mktime`` with DST left for the C library to decide (-1).
    Raises NotImplementedError for timezone-aware datetimes.
    """
    if dt.tzinfo is not None:
        raise NotImplementedError('Only for naive datetimes')
    time_tuple = (dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second, -1, -1, -1)
    return _time.mktime(time_tuple) + dt.microsecond / 1e6


class Query(object):
    """A single ticket search request: route, dates, passengers and class.

    Every instance gets a string ``id`` (qid) built by ``_format_qid`` from the
    creation time, service, transport code, the search key (see ``key()``)
    and the language. ``from_qid`` / ``from_key`` rebuild a query from those
    strings.
    """

    # Filled by fill_codes() during prepare_attrs_for_import().
    country_code_from = None
    country_code_to = None
    t_code = None
    partners = ()  # Immutable list. Attr for backward compatibility
    service = None
    lang = None
    # Search codes filled by fill_iata_codes(): an IATA code, or a Sirena
    # code when no IATA code is available for the point.
    iata_from = None
    iata_to = None
    iatas_from = None
    iatas_to = None
    # Set only when a real IATA code (not a Sirena fallback) was found.
    iata_real_from = None
    iata_real_to = None
    partner_codes = None
    importer = None  # type: travel.avia.ticket_daemon.ticket_daemon.daemon.importer.Importer

    def __init__(
        self,
        point_from,
        point_to,
        passengers,
        date_forward,
        date_backward=None,
        klass='economy',
        national_version='ru',
        lang='ru',
        t_code='plane',
        service=None,
        created=None,
        **kwargs
    ):
        """Build a query and compute its qid.

        :param point_from: departure point (must expose ``point_key``).
        :param point_to: arrival point (must expose ``point_key``).
        :param passengers: dict with optional 'adults', 'children', 'infants'.
        :param date_forward: outbound date.
        :param date_backward: return date, or None for one-way.
        :param created: creation time in POSIX seconds; defaults to now.
        :param kwargs: extra attributes set verbatim on the instance;
            ``custom_store_time`` is also copied into ``self.meta``.
        :raises Exception: if neither ``service`` nor settings.PROJECT_CODE is set.
        :raises QueryIsNotValid: if point_from or point_to is empty.
        """
        # RASPTICKETS-9217: Rasp (schedules) got HTTP 500 when searching prices
        # in the ticket daemon for the Belarusian version, so fall back to 'ru'.
        if service == 'rasp_morda_backend' and national_version == 'by':
            national_version = 'ru'

        self.point_from = point_from
        self.point_to = point_to
        self.passengers = passengers
        self.klass = klass
        self.date_forward = date_forward
        self.date_backward = date_backward
        self.national_version = national_version
        self.lang = lang
        self.t_code = t_code
        self.service = service or getattr(settings, 'PROJECT_CODE', None)

        self.meta = {
            'custom_store_time': kwargs.get('custom_store_time', 0),
        }

        if not self.service:
            raise Exception('Either "service" should be provided to Query() '
                            'or PROJECT_CODE filled')

        if not self.point_from:
            raise QueryIsNotValid('Empty point_from')

        if not self.point_to:
            raise QueryIsNotValid('Empty point_to')

        for k, v in kwargs.iteritems():
            setattr(self, k, v)

        # Store partner tracks for a random STORE_QUERY_PARTNERS_TRACKS_FACTOR
        # percent of queries.
        self.need_store_tracks = (
            bool(settings.STORE_QUERY_PARTNERS_TRACKS_FACTOR) and
            random.random() < settings.STORE_QUERY_PARTNERS_TRACKS_FACTOR / 100.0
        )

        self.created = created or _time.time()

        self.id = _format_qid(
            self.created, self.service, self.t_code, self.key(), self.lang
        )

    def __eq__(self, other):
        """Queries are equal when they are of the same class and share a qid."""
        return (
            isinstance(other, self.__class__) and
            self.id == other.id
        )

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, two queries
        # with the same id would compare both equal and not-equal.
        return not self.__eq__(other)

    def __hash__(self):
        # Keep hashing consistent with __eq__ (equal queries share an id).
        return hash(self.id)

    def city_from(self):
        """Settlement related to point_from (memoized by point_key)."""
        return _get_point_related_settlement(self.point_from)

    def city_to(self):
        """Settlement related to point_to (memoized by point_key)."""
        return _get_point_related_settlement(self.point_to)

    def related_settlement_ids_from(self):
        return self.point_from.get_related_settlement_ids()

    def related_settlement_ids_to(self):
        return self.point_to.get_related_settlement_ids()

    def prepare_attrs_for_import(self):
        """Fill search codes and allowed airports needed by partner importers.

        Sets country/IATA codes, ``allowed_airports_ids_from/_to`` and
        ``station_iatas_from/_to``.

        :raises QueryIsNotValid: on same origin/destination city, missing or
            equal search codes, or when no airports relate to a point.
        """
        city_from = self.city_from()
        city_to = self.city_to()
        if city_from and city_to and city_from.id == city_to.id:
            raise QueryIsNotValid('Same city from and city to')

        self.fill_codes('from')
        self.fill_codes('to')
        self.validate_iata_codes()

        self.allowed_airports_ids_from = _get_allowed_airports_ids_for_point(self.point_from)
        self.allowed_airports_ids_to = _get_allowed_airports_ids_for_point(self.point_to)

        if not self.allowed_airports_ids_from:
            raise QueryIsNotValid('No airports related to point_from')

        if not self.allowed_airports_ids_to:
            raise QueryIsNotValid('No airports related to point_to')

        # For dohop: IATA codes of the allowed, non-hidden airports.
        self.station_iatas_from = [
            s.iata for s in (get_airport_by_id(sid)
                             for sid in self.allowed_airports_ids_from)
            if s and s.iata and not s.hidden
        ]
        self.station_iatas_to = [
            s.iata for s in (get_airport_by_id(sid)
                             for sid in self.allowed_airports_ids_to)
            if s and s.iata and not s.hidden
        ]

    @property
    def qid_msg(self):
        return '(qid:%s)' % self.id

    @cached_property
    def qkey(self):
        """``key()`` computed once per instance."""
        return self.key()

    def key(self):
        """Return the '_'-joined search key understood by ``from_key``.

        Fields: point_from key, point_to key, date_forward, date_backward,
        klass, adults, children, infants, national_version.
        """
        params = (
            self.point_from.point_key,
            self.point_to.point_key,
            str(self.date_forward),
            str(self.date_backward),
            self.klass,
            str(self.passengers.get('adults', 0)),
            str(self.passengers.get('children', 0)),
            str(self.passengers.get('infants', 0)),
            self.national_version,
        )
        key = '_'.join(map(str, params))

        return key

    @classmethod
    def _get_point_tuple_by_key(cls, point_key):
        return get_point_tuple_by_key(point_key)

    @classmethod
    def from_key(cls, key, service=None, lang=None, t_code=None, created=None):
        """Rebuild a Query from a string produced by ``key()``.

        Only the first nine '_'-separated fields are used; a literal 'None'
        return date means one-way.
        """
        (
            point_from_key,
            point_to_key,
            date_forward,
            date_backward,
            klass,
            adults,
            children,
            infants,
            national_version
        ) = key.split('_')[:9]

        # Convert up front so malformed counts fail before any point lookup.
        adults = int(adults)
        children = int(children)
        infants = int(infants)

        return cls(
            point_from=cls._get_point_tuple_by_key(point_from_key),
            point_to=cls._get_point_tuple_by_key(point_to_key),
            date_forward=parse_date(date_forward),
            date_backward=None if date_backward == 'None' else parse_date(date_backward),
            klass=klass,
            passengers={
                'adults': adults,
                'children': children,
                'infants': infants,
            },
            national_version=national_version,
            lang=lang,
            t_code=t_code,
            service=service,
            created=created
        )

    @classmethod
    def from_qid(cls, qid):
        """Rebuild a Query (including its creation time) from a qid."""
        created, service, t_code, qkey, lang = _parse_qid(qid)
        return cls.from_key(
            qkey, service=service, lang=lang, t_code=t_code, created=created
        )

    @property
    def adults(self):
        # Missing counts default to 0, matching key() and passengers_count.
        return self.passengers.get('adults', 0)

    @property
    def children(self):
        return self.passengers.get('children', 0)

    @property
    def infants(self):
        return self.passengers.get('infants', 0)

    @property
    def passengers_count(self):
        """Total number of adults, children and infants."""
        passengers = getattr(self, 'passengers', None) or {}

        return sum([
            int(passengers.get('adults', 0)),
            int(passengers.get('children', 0)),
            int(passengers.get('infants', 0)),
        ])

    @property
    def project_group(self):
        return get_service_by_code(self.service).project_group

    @property
    def is_from_rasp(self):
        return self.project_group == 'rasp'

    @property
    def is_mobile(self):
        return get_service_by_code(self.service).is_mobile

    def as_header(self):
        """URL-quoted (quote_plus) ASCII form of ``key()``."""
        return quote_plus(self.key().encode('ASCII', 'replace'))

    def station_postfilter(self, v):
        """
        Keep only variants whose endpoint stations match the query: stations
        related to the searched city directly or via station2settlement.

        :param v: flight variant (uses ``v.forward.segments``,
            ``v.backward.segments`` and ``v.partner.code``).
        :returns: True if the variant passes the filter.
        """

        stations_from = [v.forward.segments[0].station_from]
        stations_to = [v.forward.segments[-1].station_to]

        if v.backward.segments:
            stations_from.append(v.backward.segments[-1].station_to)
            stations_to.append(v.backward.segments[0].station_from)

        # Variants without departure or arrival stations are invalid right away
        if not all(stations_from) or not all(stations_to):
            return False

        if self.allowed_airports_ids_from:
            for st in stations_from:
                if st.id not in self.allowed_airports_ids_from:
                    log.info(
                        '[%s] Station from %r not in allowed_airports_ids_from %r. point_from: %r',
                        v.partner.code, st.id, self.allowed_airports_ids_from, self.point_from
                    )
                    return False

        if self.allowed_airports_ids_to:
            for st in stations_to:
                if st.id not in self.allowed_airports_ids_to:
                    log.info(
                        '[%s] Station to %r not in allowed_airports_ids_to %r. point_to: %r',
                        v.partner.code, st.id, self.allowed_airports_ids_to, self.point_to
                    )
                    return False

        return True

    def fill_codes(self, direction):
        """Fill IATA codes and ``country_code_<direction>`` ('from' or 'to')."""
        point = getattr(self, 'point_{}'.format(direction))
        self.fill_iata_codes(point, direction)

        country = get_country_by_id(point.country_id)
        setattr(self, 'country_code_{}'.format(direction),
                country.code if country else None)

    def fill_iata_codes(self, point, direction):
        """Set ``iata_*`` / ``iatas_*`` for direction, falling back to Sirena codes."""
        iata = get_iata_code_for_search(point.point_key)
        if iata:
            setattr(self, 'iata_{}'.format(direction), iata)
            setattr(self, 'iata_real_{}'.format(direction), iata)
        else:
            sirena_code = get_sirena_code_for_search(point.point_key)
            setattr(self, 'iata_{}'.format(direction), sirena_code or None)

        iatas = get_iata_codes_for_search(point.point_key)
        if iatas:
            setattr(self, 'iatas_{}'.format(direction), iatas)
        else:
            sirenas = get_sirena_codes_for_search(point.point_key)
            setattr(self, 'iatas_{}'.format(direction), sirenas or None)

    def track_path(self):
        return os.path.join('/tmp/ticket_daemon/query', self.key(), self.id)

    def validate(self, query_module):
        """Run default validation (unless the partner module sets
        NO_DEFAULT_VALIDATION), then the module's ``validate_query`` if any."""
        if not getattr(query_module, 'NO_DEFAULT_VALIDATION', False):
            self._validate_default()

        if hasattr(query_module, 'validate_query'):
            query_module.validate_query(self)

    def _validate_default(self):
        self.validate_iata_codes()

    def validate_iata_codes(self):
        """:raises QueryIsNotValid: if a search code is missing or both are equal."""
        if not self.iata_from:
            raise QueryIsNotValid('No iata_from')

        if not self.iata_to:
            raise QueryIsNotValid('No iata_to')

        if self.iata_from == self.iata_to:
            raise QueryIsNotValid('Same iata_from and iata_to search codes (%s - %s)' % (self.iata_from, self.iata_to))

    def validate_passengers_type(self, who, max_count):
        """:raises QueryIsNotValid: if passengers[who] exceeds max_count."""
        passengers = getattr(self, 'passengers', None) or {}
        count = passengers.get(who, 0)

        if count > max_count:
            raise QueryIsNotValid('More %s than allowed: %s' % (who, count))

    def validate_country_codes(self):
        if not self.country_code_from:
            raise QueryIsNotValid('No query.country_code_from')

        if not self.country_code_to:
            raise QueryIsNotValid('No query.country_code_to')

    def validate_not_adults(self, max_count):
        """:raises QueryIsNotValid: if children + infants exceeds max_count."""
        passengers = getattr(self, 'passengers', None) or {}

        not_adults_sum = sum([
            passengers.get('children', 0),
            passengers.get('infants', 0),
        ])

        if not_adults_sum > max_count:
            message = 'More not adults passengers than allowed: %s > %s' % (
                not_adults_sum,
                max_count
            )
            raise QueryIsNotValid(message)

    def validate_passengers(self, adults=None, children=None, infants=None, count=None, not_adults=None):
        """Check passenger counts against the given maximums (None = no check).

        When ``infants`` is given, the query must also not have more infants
        than adults.

        :raises QueryIsNotValid: on any violated limit.
        """
        if adults is not None:
            self.validate_passengers_type('adults', adults)

        if children is not None:
            self.validate_passengers_type('children', children)

        if infants is not None:
            self.validate_passengers_type('infants', infants)

            # Compare the query's actual counts, not the configured limits:
            # each infant needs an accompanying adult.
            passengers = getattr(self, 'passengers', None) or {}
            actual_infants = passengers.get('infants', 0)
            actual_adults = passengers.get('adults', 0)
            if actual_infants > actual_adults:
                raise QueryIsNotValid(
                    'More infants than adults is not allowed: %r > %r' % (actual_infants, actual_adults))

        if not_adults is not None:
            self.validate_not_adults(not_adults)

        if count is not None and self.passengers_count > count:
            raise QueryIsNotValid(
                'More passengers than allowed: %s' % self.passengers_count)

    def validate_klass(self, allowed_klasses):
        if self.klass not in allowed_klasses:
            raise QueryIsNotValid('Service class not allowed: %s' % self.klass)

    def __repr__(self):
        return '<Query: %s>' % self.id


@memoize(lambda point: point.point_key)
def _get_allowed_airports_ids_for_point(point):
    """Return ``point.get_allowed_airports_ids()``, memoized by ``point.point_key``.

    Only accepts PointInterface implementations (namedtuple-style point data).
    """
    assert isinstance(point, PointInterface)
    airport_ids = point.get_allowed_airports_ids()
    return airport_ids


@memoize(lambda point: point.point_key)
def _get_point_related_settlement(point):
    """Return ``point.get_related_settlement()``, memoized by ``point.point_key``."""
    assert isinstance(point, PointInterface)
    settlement = point.get_related_settlement()
    return settlement


@memoize(lambda k: k, CacheWithKeyTTL(300, maxsize=4096))
def get_point_by_key(key):
    """Look up a Point model via ``Point.get_any_by_key``.

    Memoized per key in ``CacheWithKeyTTL(300, maxsize=4096)`` (presumably a
    300-second TTL; confirm against the memo module).
    """
    point = Point.get_any_by_key(key)
    return point


def get_service_by_qid(qid):
    """Return the service object for the service code embedded in ``qid``."""
    _created, service_code, _t_code, _qkey, _lang = _parse_qid(qid)
    return get_service_by_code(service_code)


class QueryIsNotValid(Exception):
    """Raised when a search query fails validation."""


def get_query_module(partner):
    """Import and return the partner-specific query module.

    The module name is ``partner.query_module_name`` when set. Otherwise it is
    ``'dohop'`` for ``DohopVendor`` partners and ``partner.code`` for all
    others. The module is imported from the
    ``travel.avia.ticket_daemon.ticket_daemon.partners`` package.

    :returns: the imported module, or None (after logging) when no module
        name can be determined or the import raises ImportError.
    """
    if partner.__class__.__name__ == 'DohopVendor':
        # DohopVendor partners fall back to the shared 'dohop' module.
        query_module_name = partner.query_module_name or 'dohop'
    else:
        query_module_name = partner.query_module_name or partner.code

    if not query_module_name:
        log.warning('No query_module_name of %r', partner)
        return None

    module_path = 'travel.avia.ticket_daemon.ticket_daemon.partners.%s' % query_module_name
    try:
        return import_module(module_path)
    except ImportError:
        # Include the module path so the failing partner can be identified.
        log.exception(u'Can\'t import partner query module %s', module_path)
        return None


class QueryAllException(Exception):
    """Error carrying a status code (500 unless specified)."""

    def __init__(self, message, status_code=500):
        self._status_code = status_code
        super(QueryAllException, self).__init__(message)

    def status_code(self):
        """Return the status code given at construction."""
        return self._status_code
