# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import threading

import requests
import zeep

import config
from yabus.common.exceptions import PartnerError
from yabus.util import encode_utf8, pairwise, safedict, monitoring
from yabus.util.connector_context import connector_context
from yabus.util.connector_response import log_response


class _LoggingTransport(zeep.Transport):
    """Proxy around a zeep transport that observes every HTTP response.

    ``get`` and ``post`` are replaced by wrappers that pass each response to
    ``log_response``, remember it for the calling thread (see
    ``last_response``) and, when the connector context has an explainer,
    hand it to ``explainer.collect``. Every other attribute lookup that is
    not found on this object is delegated to the wrapped transport.

    NOTE(review): ``zeep.Transport.__init__`` is not called here; transport
    state is reached through ``__getattr__`` delegation instead.
    """

    def __init__(self, transport):
        self.__transport = transport
        # Per-thread storage so threads sharing one client do not see each
        # other's responses.
        self._local = threading.local()
        self.get = self.__wrap(transport.get)
        self.post = self.__wrap(transport.post)

    @property
    def last_response(self):
        """Return the last response received by the current thread, or None.

        None is returned (instead of raising AttributeError) when this thread
        has not received any response through this transport yet.
        """
        return getattr(self._local, 'last_response', None)

    def __wrap(self, func):
        """Wrap a transport method so its response is logged and recorded."""
        def wrapper(*args, **kwargs):
            response = func(*args, **kwargs)
            log_response(response)
            self._local.last_response = response
            if connector_context and connector_context.explainer:
                connector_context.explainer.collect(response)
            return response
        return wrapper

    def __getattr__(self, item):
        # Only invoked for attributes not found normally: forward them to the
        # wrapped transport (session, cache, timeouts, ...).
        return getattr(self.__transport, item)


class Session(requests.Session):
    """``requests`` session with retrying adapters and a default timeout.

    Each of ``https://`` and ``http://`` is served by its own ``HTTPAdapter``
    built with ``MAX_RETRIES`` and ``config.MAX_POOL_SIZE``. Requests sent
    without a timeout (or with ``timeout=None``) get the
    ``(CONNECTION_TIMEOUT, READ_TIMEOUT)`` connect/read pair.
    """

    CONNECTION_TIMEOUT = 3.05  # slightly larger than a multiple of 3
    READ_TIMEOUT = config.REQUEST_READ_TIMEOUT
    MAX_RETRIES = 3

    def __init__(self, *args, **kwargs):
        super(Session, self).__init__(*args, **kwargs)
        # Replace whatever adapters requests.Session installed with our own.
        self.adapters = requests.compat.OrderedDict()
        for prefix in ('https://', 'http://'):
            # A separate adapter (and connection pool) per scheme.
            adapter = requests.adapters.HTTPAdapter(
                max_retries=self.MAX_RETRIES,
                pool_maxsize=config.MAX_POOL_SIZE,
            )
            self.mount(prefix, adapter)

    def request(self, *args, **kwargs):
        """Send a request, filling in the default timeout when none is given."""
        timeout = kwargs.get('timeout')
        if timeout is None:
            timeout = (self.CONNECTION_TIMEOUT, self.READ_TIMEOUT)
        kwargs['timeout'] = timeout
        return super(Session, self).request(*args, **kwargs)


class Transport(zeep.Transport):
    """zeep transport that defaults to this module's ``Session``.

    A truthy ``session`` argument is used as-is; otherwise a new ``Session``
    is created. All other keyword arguments go to ``zeep.Transport``.
    """

    def __init__(self, session=None, **kwargs):
        if not session:
            session = Session()
        super(Transport, self).__init__(session=session, **kwargs)


class HttpAuthenticated(zeep.Transport):
    """zeep transport whose ``Session`` sends HTTP Basic credentials.

    Keyword arguments other than the credentials are forwarded to
    ``zeep.Transport``.
    """

    def __init__(self, username, password, **kwargs):
        basic_auth = requests.auth.HTTPBasicAuth(username, password)
        authenticated_session = Session()
        authenticated_session.auth = basic_auth
        super(HttpAuthenticated, self).__init__(
            session=authenticated_session,
            **kwargs
        )


class SoapClient(zeep.Client):
    """Base SOAP client for soap-based APIs.

    The transport (this module's ``Transport`` by default) is wrapped in
    ``_LoggingTransport`` so every HTTP response is logged and remembered per
    thread. ``call`` invokes an operation and converts zeep errors into
    ``PartnerError``.
    """

    def __init__(self, url, transport=None, **kwargs):
        transport = _LoggingTransport(transport or Transport())
        super(SoapClient, self).__init__(url, transport=transport, **kwargs)

    def call(self, method, args):
        """Invoke the SOAP operation ``method`` with positional ``args``.

        :param method: name of the operation on ``self.service``.
        :param args: positional arguments for the operation.
        :returns: the value zeep returns for the operation.
        :raises PartnerError: when zeep raises ``zeep.exceptions.Error``. It
            carries the UTF-8 encoded error text, the HTTP response received
            during this call (None if no response was received), and the
            fault detail converted to a dict (``{}`` when absent or not
            convertible).
        """
        monitoring_labels = {"client": "soap", "method": method}
        # Response left over from an earlier call in this thread; used to tell
        # whether the failing call actually received a response of its own.
        previous_response = getattr(self.transport, 'last_response', None)
        try:
            monitoring.count_request(monitoring_labels)
            result = getattr(self.service, method)(*args)
        except zeep.exceptions.Error as e:
            message = unicode(e).encode('utf8')
            fault = self._fault_to_dict(e)
            response = getattr(self.transport, 'last_response', None)
            if response is previous_response:
                # No new response was recorded for this call, so don't attach a
                # stale one from a previous call.
                response = None
            raise PartnerError(message, response, fault)
        else:
            response = self.transport.last_response
            monitoring.count_response(response, monitoring_labels)
            return result

    @staticmethod
    def _fault_to_dict(error):
        """Best-effort conversion of a zeep error's ``detail`` to a dict.

        Returns ``{}`` when the error has no ``detail`` attribute, or when
        the detail cannot be represented as a dict. In the second case
        ``_etree_to_dict`` raises ValueError, and it is swallowed so it
        cannot replace the PartnerError being raised.
        """
        if not hasattr(error, 'detail'):
            return {}
        try:
            detail = SoapClient._etree_to_dict(error.detail)
        except ValueError:
            return {}
        return safedict(encode_utf8(detail))

    @staticmethod
    def _local_name(tag):
        """Strip a leading ``{namespace}`` from an element tag, if present."""
        return tag[tag.find('}') + 1:]

    @staticmethod
    def _etree_to_dict(t):
        """Convert an XML element into nested dicts keyed by local tag names.

        ``None`` converts to ``{}``. An element without child elements maps
        to its ``text``. Otherwise it maps to a dict merging the conversions
        of its children. Namespaces are stripped from tags. Attributes and
        tail text are ignored.

        :raises ValueError: if an element has two child elements with the
            same local name (anywhere among its children), since they would
            collide as dict keys.
        """
        if t is None:
            return {}
        # Comments and processing instructions have a factory function as
        # their ``tag`` rather than a string; they carry no data here.
        elements = [c for c in t if not callable(c.tag)]
        names = [SoapClient._local_name(c.tag) for c in elements]
        if len(set(names)) != len(names):
            raise ValueError("Arrays are not allowed")
        tag = SoapClient._local_name(t.tag)
        if not elements:
            return {tag: t.text}
        children = [SoapClient._etree_to_dict(c) for c in elements]
        return {
            tag: {k: v for c in children for k, v in c.items()}
        }
