import socket
import logging

import gevent.lock
import gevent.queue
from gevent import Timeout
from thrift.transport import TSocket
from thrift.transport import TSSLSocket
from thrift.protocol.TBinaryProtocol import TBinaryProtocol, TBinaryProtocolAccelerated
from thrift.transport import TTransport
from sepelib.core.constants import MEGABYTE
from sepelib.util.retry import RetrySleeper

from infra.deploy_export_stats.src.libs.thrift.transport import TGeventIdleTransport


# Connection-level failures: after one of these the transport's state is
# unknown, so the caller drops the connection and may retry on a fresh one.
_TRANSPORT_ERRORS = (
    TTransport.TTransportException,
    socket.error,
    socket.gaierror,
)

# Exception types ThriftBinClient treats as retryable.
# NOTE(review): TypeError is included too; presumably some transport/protocol
# code raises it on a dead connection -- confirm, since it also matches
# genuine argument-type mistakes, which retrying cannot fix.
RetryableErrors = _TRANSPORT_ERRORS + (
    Timeout,
    TypeError,
)


class TransportFactory(object):
    """Build new, not-yet-opened framed Thrift transports for one endpoint.

    Each ``create()`` call returns an independent transport:
    a ``TSocket`` (or ``TSSLSocket`` when ``use_ssl`` is true) wrapped in
    ``TFramedTransport``, additionally wrapped in ``TGeventIdleTransport``
    when ``idle_method`` is given.
    """

    def __init__(self, host, port, idle_method=None, idle_byte_period=None,
                 use_ssl=False, ca_certs=None, certfile=None):
        # Remote endpoint every created transport connects to.
        self.host, self.port = host, port
        # Passed through to TGeventIdleTransport; only used when idle_method
        # is not None.
        self.idle_method = idle_method
        self.idle_byte_period = idle_byte_period
        # TLS settings; ca_certs / certfile are only consulted when use_ssl.
        self.use_ssl = use_ssl
        self.ca_certs = ca_certs
        self.certfile = certfile

    def _make_socket(self):
        """Return the raw plain or TLS socket transport for host:port."""
        if not self.use_ssl:
            return TSocket.TSocket(self.host, self.port)
        return TSSLSocket.TSSLSocket(self.host, self.port,
                                     ca_certs=self.ca_certs,
                                     certfile=self.certfile)

    def create(self):
        """Return a new (unopened) framed transport, idle-wrapped if configured."""
        framed = TTransport.TFramedTransport(self._make_socket())
        if self.idle_method is None:
            return framed
        # NOTE(review): presumably TGeventIdleTransport invokes idle_method
        # roughly every idle_byte_period bytes transferred -- see that class.
        return TGeventIdleTransport(framed,
                                    idle_byte_period=self.idle_byte_period,
                                    idle_method=self.idle_method)


class TransportPool(object):
    """Bounded free-list of idle Thrift transports.

    Neither operation blocks:

    * ``get`` reuses a queued idle transport, or asks the factory for a new one
      when the pool is empty.
    * ``put_back`` queues a transport for reuse, or closes it when the pool is
      already at capacity.
    """
    DEFAULT_MAXSIZE = 1000

    def __init__(self, client_factory, maxsize=None):
        """
        :param client_factory: object whose ``create()`` returns a new transport
            (e.g. a TransportFactory).
        :param maxsize: maximum number of idle transports kept; a falsy value
            (None or 0) falls back to DEFAULT_MAXSIZE.
        """
        self._pool = gevent.queue.Queue(maxsize=maxsize or self.DEFAULT_MAXSIZE)
        self._client_factory = client_factory

    def get(self):
        """Return an idle pooled transport, or a freshly created one."""
        try:
            return self._pool.get_nowait()
        except gevent.queue.Empty:
            return self._client_factory.create()

    def put_back(self, transport):
        """Return *transport* to the pool; close it if the pool is full."""
        # put_nowait, not put: a blocking put on a full queue would suspend the
        # caller until some other greenlet calls get(), and Full would never be
        # raised, so the surplus transport would never be closed.
        try:
            self._pool.put_nowait(transport)
        except gevent.queue.Full:
            transport.close()


class ThriftBinClient(object):
    """Pooled, retrying proxy around a generated Thrift ``Client`` class.

    Any attribute that exists on ``client_mod.Client`` is exposed on this
    object as a callable (see ``method_handle``). Each call:

    * holds a semaphore bounding the number of concurrent requests,
    * borrows a framed binary transport from a pool,
    * runs under a per-call gevent ``Timeout`` of ``request_timeout`` seconds,
    * on an error from ``RetryableErrors`` closes the connection and retries
      while the ``RetrySleeper`` allows it,
    * returns the connection to the pool when it is still usable.
    """
    DEFAULT_CALL_TIMEOUT = 120

    DEFAULT_MAX_TRIES = 4

    NAME = 'bin_thrift'

    DEFAULT_IDLE_BYTE_PERIOD = MEGABYTE
    DEFAULT_REQUEST_COUNT = 100

    def __init__(self, client_mod, host, port, request_timeout=None, max_tries=None,
                 idle_method=None, idle_byte_period=None,
                 concurrent_request_count=None, conn_maxsize=None,
                 use_ssl=False, ca_certs=None, certfile=None):
        """
        :param client_mod: generated Thrift service module; its ``Client``
            class defines the callable RPC methods.
        :param host, port: server endpoint.
        :param request_timeout: per-attempt timeout in seconds
            (falsy -> DEFAULT_CALL_TIMEOUT).
        :param max_tries: passed to RetrySleeper (falsy -> DEFAULT_MAX_TRIES).
        :param idle_method, idle_byte_period: forwarded to TransportFactory;
            idle_byte_period falsy -> DEFAULT_IDLE_BYTE_PERIOD.
        :param concurrent_request_count: semaphore size
            (falsy -> DEFAULT_REQUEST_COUNT).
        :param conn_maxsize: max idle connections kept by the pool.
        :param use_ssl, ca_certs, certfile: TLS settings for TransportFactory.
        """
        self._client_cls = client_mod.Client
        self.host = host
        self.port = port
        self.request_timeout = request_timeout or self.DEFAULT_CALL_TIMEOUT
        self.log = logging.getLogger('{}_client({}:{})'.format(self.NAME, host, port))
        transport_factory = TransportFactory(
            host=host,
            port=port,
            idle_method=idle_method,
            idle_byte_period=idle_byte_period or self.DEFAULT_IDLE_BYTE_PERIOD,
            use_ssl=use_ssl,
            ca_certs=ca_certs,
            certfile=certfile
        )
        self._transport_pool = TransportPool(transport_factory, conn_maxsize)
        self._mutex = gevent.lock.Semaphore(concurrent_request_count or self.DEFAULT_REQUEST_COUNT)
        # NOTE(review): max_delay is presumably in seconds -- confirm in sepelib.
        self._retry_sleeper = RetrySleeper(max_tries or self.DEFAULT_MAX_TRIES, max_delay=30)

    def __getattr__(self, item):
        """Expose the Thrift client's attributes as retrying call handles."""
        # __getattr__ only runs when normal lookup fails. Read _client_cls via
        # __dict__ so that an instance whose __init__ has not run (e.g. one being
        # copied/unpickled) raises AttributeError instead of recursing into
        # __getattr__('_client_cls') forever.
        client_cls = self.__dict__.get('_client_cls')
        if client_cls is not None and hasattr(client_cls, item):
            return self.method_handle(method_name=item)
        raise AttributeError("'{}' object has no attribute '{}'".format(
            self.__class__.__name__, item
        ))

    def method_handle(self, method_name):
        """Return a callable that invokes ``method_name`` with pooling/retries.

        The callable re-raises the last retryable error once the retry budget is
        spent, and re-raises any other exception immediately.
        """
        def run_func(*args, **kwargs):
            with self._mutex:
                transport = self._transport_pool.get()
                client = self._client_cls(TBinaryProtocolAccelerated(transport))
                func = getattr(client, method_name)

                sleeper = self._retry_sleeper.copy()
                while True:
                    # A distinct Timeout object per attempt lets us tell our own
                    # deadline apart from a Timeout armed by the caller.
                    call_timeout = Timeout(seconds=self.request_timeout)
                    try:
                        with call_timeout:
                            if not transport.isOpen():
                                transport.open()
                            result = func(*args, **kwargs)
                    except RetryableErrors as err:
                        # The connection may be mid-frame; never reuse it. Close it
                        # before sleeping/raising so the socket is not leaked.
                        transport.close()
                        if isinstance(err, Timeout) and err is not call_timeout:
                            # An outer (caller-owned) timeout fired: propagate it.
                            raise
                        # A falsy result from increment() means no attempts are left.
                        if not sleeper.increment(exception=False):
                            raise
                    except Exception:
                        # Non-transport error (e.g. a declared Thrift exception);
                        # the connection is still considered reusable.
                        self._transport_pool.put_back(transport)
                        raise
                    except BaseException:
                        # E.g. greenlet kill mid-call: connection state is unknown.
                        transport.close()
                        raise
                    else:
                        self._transport_pool.put_back(transport)
                        return result
        return run_func
