import socket
import logging
from cStringIO import StringIO
from uuid import uuid4

import gevent.lock
import gevent.queue
import requests
from gevent import Timeout
from thrift.transport import TSocket
from thrift.transport import TSSLSocket
from thrift.protocol.TBinaryProtocol import TBinaryProtocol, TBinaryProtocolAccelerated
from thrift.transport import TTransport

from sepelib.core.constants import MEGABYTE
from infra.swatlib import metrics as swatlib_metrics
from sepelib.util.retry import RetrySleeper

from infra.swatlib.thrift.transport import TGeventIdleTransport
from infra.swatlib.thrift.sock import TConnectionTimeoutAwareSocket
import six

COMMON_MAX_RETRIES = 4

# Exception types that ThriftBinClient treats as connection-level failures
# (the transport is closed and the call may be retried).
CONNECTION_ERRORS = (
    TTransport.TTransportException,
    socket.error,
    socket.gaierror,
    Timeout,
    requests.ConnectionError,
)


class ThriftConnectionError(Exception):
    """
    Connection-level failure of a Thrift call, e.g. raised by ThriftBinClient
    once its retries are exhausted.

    :param host: remote host the call was made to
    :param port: remote port the call was made to
    :param method: name of the Thrift method that was called
    :param message: text description of the last error
    :param error: the last underlying exception instance
    """

    def __init__(self, host, port, method, message, error):
        # Forward the arguments to the base class so ``args`` is populated:
        # BaseException pickling/copying re-creates the instance as cls(*args),
        # which fails with TypeError when ``args`` is empty.
        super(ThriftConnectionError, self).__init__(host, port, method, message, error)
        self.host = host
        self.port = port
        self.method = method
        self.message = message
        self.error = error

    def __str__(self):
        return 'ThriftConnectionError(host={}, port={}, method={}, message={}, error={})'.format(
            self.host,
            self.port,
            self.method,
            self.message,
            self.error,
        )


class RequestsHttpTransport(TTransport.TTransportBase, object):
    """
    Requests based http transport for thrift RPC.
    It is better than default one:
        * less AttributeErrors here and there
        * by using provided session we can reuse connections
        * by using provided session we can issue several calls in parallel for free (no code from us)

    Bytes passed to :meth:`write` are buffered and sent as a single HTTP POST
    by :meth:`flush`. The response body is then kept in memory and served by
    :meth:`read` / :meth:`readAll`.
    """
    REQ_ID_HEADER = 'X-Req-Id'
    _rbuf = StringIO('')  # empty readonly buffer until we get real one after request
    _headers = {"Content-Type": "application/x-thrift"}

    @staticmethod
    def _gen_reqid(uuid_func=uuid4):
        """Return ``str(uuid_func())`` for use as a request id."""
        return str(uuid_func())

    def __init__(self, url, session=None, reqid=None):
        """
        :param str url: URL every request is POSTed to
        :param session: requests session to send with; a new ``requests.Session`` if not given
        :param str reqid: value of the ``X-Req-Id`` header; generated if not given (or falsy)
        """
        self._url = url
        self._reqid = reqid or self._gen_reqid()
        self._session = session or requests.Session()
        self._wbuf = StringIO()

    def read(self, size):
        """Return up to ``size`` bytes of the last response (possibly fewer)."""
        return self._rbuf.read(size)

    def readAll(self, size):
        """
        Return exactly ``size`` bytes of the last response.

        :raises EOFError: if fewer than ``size`` bytes are left
        """
        buf = self._rbuf.read(size)
        if len(buf) < size:
            raise EOFError
        return buf

    def write(self, buf):
        """Append ``buf`` to the body sent by the next :meth:`flush`."""
        self._wbuf.write(buf)

    def flush(self):
        """
        POST everything written since the previous flush and buffer the response.

        :raises requests.HTTPError: if the server answers with an error status
        """
        data = self._wbuf.getvalue()
        self._wbuf.close()
        # Replace the closed buffer immediately. Without a fresh one, every
        # write() after the first RPC on this transport would fail with
        # "I/O operation on closed file". Doing it before the POST keeps the
        # transport usable even if the request below raises.
        self._wbuf = StringIO()
        if self._reqid is not None:
            headers = self._headers.copy()
            headers[self.REQ_ID_HEADER] = self._reqid
        else:
            headers = self._headers
        response = self._session.post(self._url, data=data, headers=headers)
        response.raise_for_status()
        # read everything into StringIO
        # this can bring speed ups because all small reads from thrift code won't
        # result in extra system calls
        # StringIO does not create copy - so we simply provide
        # file like API for thrift protocol
        self._rbuf = StringIO(response.content)


class TransportFactory(object):
    """
    Builds framed Thrift client transports for one host/port pair.

    Each :meth:`create` call returns a new transport stack: a plain or SSL
    Thrift socket wrapped in TConnectionTimeoutAwareSocket, then in
    TFramedTransport. When an idle method is configured, the result is also
    wrapped in TGeventIdleTransport. ``create`` itself never opens the transport.
    """

    def __init__(self, host, port, idle_method=None, idle_byte_period=None, use_ssl=False, ca_certs=None,
                 certfile=None, connection_timeout_ms=None):
        # endpoint
        self.host, self.port = host, port
        # optional idle-transport wrapping
        self.idle_method, self.idle_byte_period = idle_method, idle_byte_period
        # TLS settings (used only when use_ssl is true)
        self.use_ssl, self.ca_certs, self.certfile = use_ssl, ca_certs, certfile
        self.connection_timeout_ms = connection_timeout_ms

    def create(self):
        """Return a new, unopened transport for ``self.host:self.port``."""
        if not self.use_ssl:
            raw_socket = TSocket.TSocket(self.host, self.port)
        else:
            raw_socket = TSSLSocket.TSSLSocket(self.host, self.port,
                                               ca_certs=self.ca_certs, certfile=self.certfile)
        framed = TTransport.TFramedTransport(
            TConnectionTimeoutAwareSocket(socket=raw_socket,
                                          connection_timeout_ms=self.connection_timeout_ms)
        )
        if self.idle_method is None:
            return framed
        return TGeventIdleTransport(framed,
                                    idle_byte_period=self.idle_byte_period,
                                    idle_method=self.idle_method)


class TransportPool(object):
    """
    Bounded pool of idle transports; new ones are created on demand by a factory.
    """
    DEFAULT_MAXSIZE = 1000

    def __init__(self, client_factory, maxsize=None):
        """
        :param client_factory: object whose ``create()`` returns a new transport
        :param int maxsize: maximum number of idle transports kept; DEFAULT_MAXSIZE if not given
        """
        self._pool = gevent.queue.Queue(maxsize=maxsize or self.DEFAULT_MAXSIZE)
        self._client_factory = client_factory

    def get(self):
        """Return a pooled transport, or a freshly created one if the pool is empty."""
        try:
            return self._pool.get_nowait()
        except gevent.queue.Empty:
            return self._client_factory.create()

    def put_back(self, transport):
        """Return ``transport`` to the pool, or close it if the pool is already full."""
        try:
            # Must not block: a blocking put() waits (possibly forever) for a
            # free slot and never raises Full, so the close() branch was dead.
            self._pool.put_nowait(transport)
        except gevent.queue.Full:
            transport.close()


class ThriftBinClient(object):
    """
    Pooled, retrying client for a Thrift service over framed binary transports.

    Any attribute that exists on ``client_mod.Client`` is exposed on this
    object as a callable (see :meth:`__getattr__`). Each call borrows a
    transport from a pool and runs the RPC under a gevent Timeout of
    ``request_timeout`` seconds. Connection errors are retried for as long
    as the configured RetrySleeper allows.
    """
    DEFAULT_CALL_TIMEOUT = 25

    DEFAULT_MAX_TRIES = 4

    NAME = 'bin_thrift'

    DEFAULT_IDLE_BYTE_PERIOD = MEGABYTE
    DEFAULT_REQUEST_COUNT = 100

    @classmethod
    def from_config(cls, config):
        """
        Build a client from a config dict.

        :param dict config: must contain 'thrift_service', 'host' and 'port';
            'request_timeout' is optional
        :rtype: ThriftBinClient
        """
        return cls(config['thrift_service'], config['host'], config['port'], config.get('request_timeout'))

    def __init__(self, client_mod, host, port, request_timeout=None, max_tries=None, metrics=None,
                 idle_method=None, idle_byte_period=None,
                 concurrent_request_count=None, conn_maxsize=None,
                 use_ssl=False, ca_certs=None, certfile=None, connection_timeout_ms=None):
        """
        :param client_mod: module containing Thrift RPC (its ``Client`` class is used)
        :param request_timeout: per-attempt timeout in seconds (DEFAULT_CALL_TIMEOUT if not given)
        :param int max_tries: retry limit passed to RetrySleeper (DEFAULT_MAX_TRIES if not given)
        :param metrics: metrics registry; defaults to ROOT_REGISTRY.path('clients', NAME)
        :type idle_method: collections.Callable
        :param idle_method: idle method for calling at least each :param idle_byte_period: read bytes
        :type idle_byte_period: int
        :param int concurrent_request_count: number of concurrent api requests
        :param int conn_maxsize: maxsize of thrift connection pool
        :param bool use_ssl: use ssl or not
        :param str ca_certs: filename to the Certificate Authority pem file
        :param str certfile: filename to certificate pem file
        :param connection_timeout_ms: passed through to TConnectionTimeoutAwareSocket
        """
        self._client_cls = client_mod.Client
        self.host = host
        self.port = port
        self.request_timeout = request_timeout or self.DEFAULT_CALL_TIMEOUT
        self.log = logging.getLogger('{}_client({}:{})'.format(self.NAME, host, port))
        self._metrics = metrics or swatlib_metrics.ROOT_REGISTRY.path('clients', self.NAME)
        transport_factory = TransportFactory(
            host=host,
            port=port,
            idle_method=idle_method,
            idle_byte_period=idle_byte_period or self.DEFAULT_IDLE_BYTE_PERIOD,
            use_ssl=use_ssl,
            ca_certs=ca_certs,
            certfile=certfile,
            connection_timeout_ms=connection_timeout_ms,
        )
        self._transport_pool = TransportPool(transport_factory, conn_maxsize)
        self._mutex = gevent.lock.Semaphore(concurrent_request_count or self.DEFAULT_REQUEST_COUNT)
        self._retry_sleeper = RetrySleeper(max_tries or self.DEFAULT_MAX_TRIES, max_delay=1)

    def __getattr__(self, item):
        """Expose attributes of the generated Thrift ``Client`` class as remote calls."""
        # Read _client_cls via __dict__. On a half-constructed instance (e.g.
        # one created by copy/pickle without __init__), a plain
        # self._client_cls lookup re-enters __getattr__ and recurses until
        # RecursionError.
        client_cls = self.__dict__.get('_client_cls')
        if client_cls is not None and hasattr(client_cls, item):
            return self.method_handle(method_name=item)
        raise AttributeError("'{}' object has no attribute '{}'".format(
            self.__class__.__name__, item
        ))

    def method_handle(self, method_name):
        """
        Build a callable that performs the ``method_name`` RPC.

        The callable holds ``self._mutex`` for the whole call, which limits
        the number of concurrent requests. It borrows a transport from the
        pool and retries connection errors while the retry sleeper allows.
        Non-connection exceptions are re-raised after returning the
        transport to the pool.

        :raises ThriftConnectionError: if connection errors persist past the retry limit
        """
        def run_func(*args, **kwargs):
            with self._mutex:
                t = self._transport_pool.get()
                c = self._client_cls(TBinaryProtocolAccelerated(t))
                func = getattr(c, method_name)

                new_sleeper = self._retry_sleeper.copy()
                while True:
                    # Keep a reference to this attempt's timeout so it can be
                    # told apart from a Timeout owned by an enclosing scope.
                    call_timeout = Timeout(seconds=self.request_timeout)
                    try:
                        func_timer = self._metrics.get_histogram(func.__name__).timer()
                        total_timer = self._metrics.get_histogram('total_timings').timer()
                        self._metrics.get_counter('total_calls').inc()
                        with call_timeout, func_timer, total_timer:
                            if not t.isOpen():
                                self._metrics.get_counter('tcp_reopen').inc()
                                t.open()
                            result = func(*args, **kwargs)
                    except CONNECTION_ERRORS as e:
                        # The request may have been interrupted mid-flight, so
                        # the connection cannot be reused as is. Closing it
                        # first also ensures it is released (not leaked) when
                        # we give up below or an outer timeout hits the sleep.
                        t.close()
                        if isinstance(e, Timeout) and e is not call_timeout:
                            # Someone else's timeout: propagate, don't retry.
                            raise
                        self._metrics.get_counter('conn_error_calls').inc()
                        if not new_sleeper.increment(exception=False):
                            raise ThriftConnectionError(self.host, self.port, func.__name__, str(e), e)
                    except Exception:
                        self._metrics.get_counter('other_error_calls').inc()
                        self._transport_pool.put_back(t)
                        raise
                    except BaseException:
                        # e.g. GreenletExit when the greenlet is killed mid-call:
                        # transport state is unknown, close it instead of leaking it.
                        t.close()
                        raise
                    else:
                        self._metrics.get_counter('successful_calls').inc()
                        self._transport_pool.put_back(t)
                        return result
        return run_func


class TUnicodeBinaryProtocol(TBinaryProtocol):
    """
    Hack for unicode support.

    TBinaryProtocol variant that UTF-8 encodes text strings before writing
    them and UTF-8 decodes every string it reads back.
    """

    def writeString(self, msg):
        payload = msg.encode('utf-8') if isinstance(msg, six.text_type) else msg
        TBinaryProtocol.writeString(self, payload)

    def readString(self):
        return TBinaryProtocol.readString(self).decode('utf-8')
