from __future__ import absolute_import

import re
import abc
import copy
import json
import uuid
import time
import errno
import random
import select
import socket
import httplib
import logging
import warnings
import urlparse
import datetime
import threading
import functools
import contextlib
import collections

import requests
import requests.auth
import requests.structures
import requests.packages.urllib3.exceptions

from . import proxy
from .types import misc as ctm


def patch_session_timeout(session):
    """
    Remount `session` with an adapter whose connections use :attr:`Client.CONNECT_TIMEOUT` while connecting.

    Each new socket is opened with the connect timeout, after which the socket timeout is switched
    back to the connection's own (per-request) timeout value.
    Both "http://" and "https://" prefixes of `session` get the patched adapter.

    :param session: `requests.Session` instance, patched in place
    """
    import requests.adapters
    import requests.packages.urllib3.connection
    import requests.packages.urllib3.poolmanager

    urllib3_connection = requests.packages.urllib3.connection
    urllib3_poolmanager = requests.packages.urllib3.poolmanager

    def connect_with_short_timeout(conn, owner):
        """ Open `conn`'s socket via `owner`'s parent `_new_conn`, using the connect timeout meanwhile. """
        request_timeout = conn.timeout
        conn.timeout = Client.CONNECT_TIMEOUT
        try:
            sock = super(owner, conn)._new_conn()
            sock.settimeout(request_timeout)
            return sock
        finally:
            conn.timeout = request_timeout

    class HTTPConnection(urllib3_connection.HTTPConnection):
        def _new_conn(self):
            return connect_with_short_timeout(self, HTTPConnection)

    class HTTPSConnection(urllib3_connection.HTTPSConnection):
        def _new_conn(self):
            return connect_with_short_timeout(self, HTTPSConnection)

    class HTTPConnectionPool(requests.packages.urllib3.HTTPConnectionPool):
        ConnectionCls = HTTPConnection

    class HTTPSConnectionPool(requests.packages.urllib3.HTTPSConnectionPool):
        ConnectionCls = HTTPSConnection

    class PoolManager(urllib3_poolmanager.PoolManager):
        SCHEME2CLASS = {"http": HTTPConnectionPool, "https": HTTPSConnectionPool}

        def _new_pool(self, scheme, host, port):
            pool_cls = self.SCHEME2CLASS[scheme]
            pool_kwargs = self.connection_pool_kw
            if scheme == 'http':
                # SSL-related keywords are dropped for plain HTTP pools
                ssl_keywords = urllib3_poolmanager.SSL_KEYWORDS
                pool_kwargs = {key: value for key, value in pool_kwargs.items() if key not in ssl_keywords}
            return pool_cls(host, port, **pool_kwargs)

    class HTTPAdapter(requests.adapters.HTTPAdapter):
        def init_poolmanager(self, connections, maxsize, block=requests.adapters.DEFAULT_POOLBLOCK, **pool_kwargs):
            self._pool_connections = connections
            self._pool_maxsize = maxsize
            self._pool_block = block
            self.poolmanager = PoolManager(num_pools=connections, maxsize=maxsize, block=block, **pool_kwargs)

    for prefix in ('http://', 'https://'):
        session.mount(prefix, HTTPAdapter())


class Path(object):
    """
    Pythonish representation of the URL path.

    Attribute access and indexing produce child paths (``path.resource[123]`` -> ``/resource/123``).
    Indexing with a slice, a dict, a ``(dict or ..., slice)`` tuple or ``...`` performs a read request instead.
    See :class:`Client` for usage examples.
    """

    DELIMITER = "/"

    def __init__(self, rest, path="", input_mode=None, output_mode=None):
        """
        :param rest:        object performing requests via its `create`/`read`/`update`/`delete` methods
        :param path:        URL path represented by this object
        :param input_mode:  request data modifier instance, `Client.JSON()` if not provided
        :param output_mode: response data modifier instance, `Client.JSON()` if not provided
        """
        self.__rest = rest
        self.__path = path
        self.__input = input_mode or Client.JSON()
        self.__output = output_mode or Client.JSON()

    def __str__(self):
        return self.__path or self.DELIMITER

    @staticmethod
    def __merge_modifier(own, other):
        """
        Combine the current modifier `own` with `other` (a modifier instance or a class to instantiate).

        Afterwards both modifiers share one headers jar. If `other` has no content type (e.g. `Client.HEADERS`),
        `own` is kept and switched to the jar of `other` extended with its previous headers;
        otherwise `other` replaces `own` and takes over its headers jar.

        :return: the modifier to be used by the new path
        :raises ValueError: `other` is not a :class:`Client.Modifiers` instance
        """
        if callable(other):
            other = other()
        if not isinstance(other, Client.Modifiers):
            raise ValueError("Modifier is not a subclass of `Client.Modifiers`")
        hdrs_jar = own.headers
        if not other.type:
            own.headers = other.headers
            own.headers.update(hdrs_jar)
            return own
        other.headers = own.headers
        return other

    def __lshift__(self, other):
        """ Return a copy of the path with the request modifier `other` applied. """
        return self.__class__(self.__rest, self.__path, self.__merge_modifier(self.__input, other), self.__output)

    def __rshift__(self, other):
        """ Return a copy of the path with the response modifier `other` applied. """
        return self.__class__(self.__rest, self.__path, self.__input, self.__merge_modifier(self.__output, other))

    @staticmethod
    def __slice2params(slice_):
        """ Convert `start:stop:step` slice into `offset`, `limit` and `order` query parameters. """
        params = {}
        if slice_.start is not None:
            params["offset"] = slice_.start
        if slice_.stop is not None:
            params["limit"] = slice_.stop
        if slice_.step is not None:
            params["order"] = slice_.step
        return params

    @staticmethod
    def __split_payload(data, params, kws):
        """
        Resolve request body and query from `create`/`update`/`delete` arguments: keyword arguments
        become the query if `data` is given, or the body if neither `data` nor `params` is given.
        """
        if data is not None:
            if params is None:
                params = kws
        elif params is None:
            data = kws
        return data, params

    def __child(self, item):
        """ Return the sub-path `<this path>/<item>` sharing this path's modifiers. """
        return Path(self.__rest, self.DELIMITER.join((self.__path, str(item))), self.__input, self.__output)

    def __getitem__(self, item):
        """
        Read the resource for query-like indexes, otherwise return a child path.

        :raises ValueError: a tuple index is not of form `(dict or ..., slice)`
        """
        params = None
        if isinstance(item, tuple):
            extra, slice_ = item
            if not (isinstance(extra, (dict, type(Ellipsis))) and isinstance(slice_, slice)):
                raise ValueError("Parameters must be of type (dict or ..., slice)")
            # Copy to avoid modifying the caller's dict
            params = {} if extra is Ellipsis else dict(extra)
            params.update(self.__slice2params(slice_))
        elif isinstance(item, dict):
            params = item
        elif isinstance(item, slice):
            params = self.__slice2params(item)
        if params is not None or item is Ellipsis:
            return self.read(params)
        return self.__child(item)

    def __getattr__(self, item):
        # Upper-case names existing in `Client` (e.g. `HEADERS`, `JSON`) are proxied to it
        if item.isupper() and hasattr(Client, item):
            return getattr(Client, item)
        return self.__child(item)

    def __setattr__(self, item, value):
        # Private attributes are regular ones, anything else means an update request
        if item.startswith("_"):
            return super(Path, self).__setattr__(item, value)
        self[item].update(value)

    def __setitem__(self, item, value):
        self[item].update(value)

    def create(self, data=None, params=None, **kws):
        """ Send a create (POST) request for this path and return the processed response. """
        data, params = self.__split_payload(data, params, kws)
        return self.__rest.create(str(self), data, params, self.__input, self.__output)

    def __call__(self, params=None, **kws):
        # The positional argument is the request body, see `create`
        return self.create(params, **kws)

    def read(self, params=None, **kws):
        """ Send a read (GET) request with query `params` (or keyword arguments) and return the processed response. """
        if params is None:
            params = kws
        return self.__rest.read(str(self), params, self.__input, self.__output)

    def delete(self, data=None, params=None, **kws):
        """ Send a delete (DELETE) request for this path and return the processed response. """
        data, params = self.__split_payload(data, params, kws)
        return self.__rest.delete(str(self), data, params, self.__input, self.__output)

    def __delitem__(self, item):
        self[item].delete()

    __delattr__ = __delitem__

    def update(self, data=None, params=None, **kws):
        """ Send an update (PUT) request for this path and return the processed response. """
        data, params = self.__split_payload(data, params, kws)
        return self.__rest.update(str(self), data, params, self.__input, self.__output)


def _urllib3_warning_suppress(wrn_cls_name):
    """
    Build a context manager factory which ignores one urllib3 warning category within its block.

    The factory is stored as a :class:`Client` class attribute and called through an instance
    (``self.__insecure_platform_ok()``), so it accepts and ignores any positional arguments
    (the bound instance among them).

    :param wrn_cls_name: name of the warning class in `urllib3.exceptions`; when the installed urllib3
                         has no such class, the produced context manager does nothing
    :return: callable returning a context manager
    """
    wrn_cls = getattr(requests.packages.urllib3.exceptions, wrn_cls_name, None)

    @contextlib.contextmanager
    def suppressor(*_):
        if wrn_cls is None:
            yield None
        else:
            # Only the requested category is ignored; filters are restored on exit
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", wrn_cls)
                yield None

    return suppressor


def _urllib3_logging_suppress(called=[]):
    """
    Silence urllib3's logging which floods debug.log, as it's unnecessary (common.rest.Client does the thing).

    All handlers of the blacklisted loggers are removed and the loggers are disabled. The loggers are
    obtained via :func:`logging.getLogger`, so they get disabled even if urllib3 has not created them yet
    (disabling a `logging.PlaceHolder` would be lost when the real logger replaces it).

    :param called: fake parameter to only execute meaningful body once; the default list is
                   intentionally shared between calls and acts as the "already done" flag
    """

    if called:
        return
    blacklist = {"urllib3.connectionpool"}
    for name in blacklist:
        logger = logging.getLogger(name)
        # Explicit loop: handlers must actually be removed (a lazy `map` would do nothing)
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
        logger.disabled = True
    called.append(True)


class Client(object):
    """
    REST API client, which can retry network and database errors.
    Instances of the class have no multi-threading support, so
    each thread should use its own instance of the class.

    Usage examples:

    .. code-block:: python

        response_headers = Client.HEADERS()
        sandbox = Client() << Client.HEADERS({"X-Header": "Value"}) >> response_headers

        # GET /<base path>/resource:
        sandbox.resource[:]
        sandbox.resource[...]
        sandbox.resource.read()

        # GET /<base path>/resource?limit=10:
        sandbox.resource[:10]
        sandbox.resource.read(limit=10)
        sandbox.resource.read({"limit": 10})

        # GET /<base path>/resource?limit=10&offset=20:
        sandbox.resource[20:10]
        sandbox.resource.read(offset=20, limit=10)
        sandbox.resource.read({"offset": 20, "limit": 10})

        # GET /<base path>/resource?limit=10&offset=20&order=-id:
        sandbox.resource[20:10:"-id"]
        sandbox.resource.read(offset=20, limit=10, order="-id")
        sandbox.resource.read({"offset": 20, "limit": 10, "order": "-id"})

        # GET /<base path>/resource?type=OTHER_RESOURCE&limit=10&order=state:
        sandbox.resource[{"type": "OTHER_RESOURCE"}, : 10: "state"]
        sandbox.resource.read(type="OTHER_RESOURCE", limit=10, order="state")
        sandbox.resource.read({"type": "OTHER_RESOURCE", "limit": 10, "order": "state"})

        # GET /<base path>/resource/12345:
        sandbox.resource[12345][:]
        sandbox.resource[12345][...]
        sandbox.resource[12345].read()

        # GET /<base path>/resource/12345/attribute:
        sandbox.resource[12345].attribute[:]
        sandbox.resource[12345].attribute[...]
        sandbox.resource[12345].attribute.read()

        # POST /<base path>/resource:
        sandbox.resource(**fields)
        sandbox.resource({<fields>})
        sandbox.resource.create(**fields)
        sandbox.resource.create({<fields>})

        # PUT /<base path>/resource/12345:
        sandbox.resource[12345] = {<fields>}
        sandbox.resource[12345].update(**fields)
        sandbox.resource[12345].update({<fields>})

        # DELETE /<base path>/resource/12345/attribute:
        del sandbox.resource[12345].attribute
        sandbox.resource[12345].attribute.delete()

        # DELETE /<base path>/resource/12345/attribute/attr1:
        del sandbox.resource[12345].attribute.attr1
        del sandbox.resource[12345].attribute["attr1"]
    """

    DEFAULT_INTERVAL = 5  # Initial wait period between retries in seconds.
    DEFAULT_TIMEOUT = 60  # Initial call timeout in seconds.
    DEFAULT_BASE_URL = "https://sandbox.yandex-team.ru/api/v1.0"
    CONNECT_TIMEOUT = 5  # Maximum wait time in seconds to establish a new connection.
    MAX_INTERVAL = 300  # Maximum wait period between retries in seconds.
    MAX_TIMEOUT = 180  # Maximum call timeout in seconds.

    # status code for invalid user response
    USER_DISMISSED = 451
    # The user has sent too many requests in a given amount of time
    TOO_MANY_REQUESTS = 429
    # tuple of HTTP response codes to retry requests for
    RETRYABLE_CODES = (
        TOO_MANY_REQUESTS, httplib.REQUEST_TIMEOUT,
        httplib.BAD_GATEWAY, httplib.SERVICE_UNAVAILABLE, httplib.GATEWAY_TIMEOUT
    )

    # The object will be returned by `__call__` operator in case of server respond HTTP 205 Reset Content
    RESET = type("RESET", (object,), {})
    # The object will be returned by `__call__` operator in case of server respond HTTP 204 No Content
    NO_CONTENT = type("NO_CONTENT", (object,), {})

    # Global authentication object case, used for external authentication information providers.
    _external_auth = None
    # Globally defined Sandbox component name to identify internal requests.
    _default_component = None

    # A callback, which is called with arguments (method, request duration) after every request is made.
    # This is done mostly for SANDBOX-4622 and is set up in executor
    _accounter = None

    # urllib3 warning suppressor context managers with backward compatibility
    __insecure_platform_ok = _urllib3_warning_suppress("InsecurePlatformWarning")
    __sni_missing_ok = _urllib3_warning_suppress("SNIMissingWarning")

    class CustomEncoder(json.JSONEncoder):
        """ JSON encoder which serializes objects providing `__getstate__` as their state. """
        def default(self, o):
            if hasattr(o, "__getstate__"):
                return o.__getstate__()
            return super(Client.CustomEncoder, self).default(o)

    class Auth(requests.auth.AuthBase):
        """ Adds headers produced by iterating the wrapped authentication object as (name, value) pairs. """
        def __init__(self, auth):
            self.auth = auth
            super(Client.Auth, self).__init__()

        def __call__(self, r):
            r.headers.update({k: v for k, v in self.auth})
            return r

    class Modifiers(object):
        """ Request and/or response data representation modifiers. """
        __metaclass__ = abc.ABCMeta

        class HeadersJar(object):
            """ Request headers to send and response headers received. """
            def __init__(self, default=None, custom=None):
                self.request = requests.structures.CaseInsensitiveDict(default)
                if custom:
                    self.request.update(custom)
                self.response = requests.structures.CaseInsensitiveDict()

            def update(self, other):
                self.request.update(other.request)
                self.response.update(other.response)
                return self

        def __init__(self, headers=None):
            self.headers = self.HeadersJar({"Content-Type": self.type}, headers)

        @abc.abstractproperty
        def type(self):
            """ Content type MIME name. """
            pass

        @staticmethod
        def request(data):
            """ Request data preprocessor. """
            return data

        def response(self, data):
            """ Response data getter. """
            self.headers.response = data.headers
            if data.status_code == httplib.NO_CONTENT:
                return Client.NO_CONTENT
            if data.status_code == httplib.RESET_CONTENT:
                return Client.RESET
            return data.content

    class JSON(Modifiers):
        type = "application/json"

        @staticmethod
        def request(data):
            return json.dumps(data, cls=Client.CustomEncoder)

        def response(self, data):
            self.headers.response = data.headers
            if data.status_code == httplib.NO_CONTENT:
                return Client.NO_CONTENT
            if data.status_code == httplib.RESET_CONTENT:
                return Client.RESET
            return data.json() if data.content else None

    class PLAINTEXT(Modifiers):
        type = "text/plain"

    class BINARY(Modifiers):
        type = "application/octet-stream"

    class HEADERS(Modifiers):
        """ The class is designed to act as request and response headers collector. """
        type = None

        def __setitem__(self, key, value):
            self.headers.request[key] = value

        def __getitem__(self, key):
            return self.headers.response[key]

        def __contains__(self, item):
            return item in self.headers.response

        def __repr__(self):
            return repr(self.headers.response)

    class TimeoutExceeded(Exception):
        """ An error class to be raised in case of maximum amount of wait time exceeded """

    class SessionExpired(BaseException):
        """ Raises if REST API session is expired """

    class HTTPError(requests.HTTPError):
        """ HTTP error while requesting server """
        def __new__(cls, ex):
            """ Creates a new object based on object of :class:`request.HTTPError` instance. """
            ex = copy.copy(ex)
            ex.__class__ = cls
            return ex

        def __init__(self, *args, **kwargs):
            # The object is already initialized by `__new__`; only append the response body to arguments
            self.args = tuple(self.args + (self.response.text,))

        def __str__(self):
            return "{}: {}".format(super(Client.HTTPError, self).__str__(), self.response.text.encode("utf8"))

        @property
        def status(self):
            return self.response.status_code

    def __init__(self, base_url=None, auth=None, logger=None, total_wait=None, ua=None, component=None, debug=False):
        """
        Constructor.

        :param base_url:    Base URL of REST API server to connect.
                            Supports space-separated servers list and bash brackets.
        :param auth:        Authorization object, an instance of :class:`common.proxy.Authorization`
        :param logger:      Logger object to use. Uses the module logger if not provided.
        :param total_wait:  Maximum total wait time in seconds. Restricts maximum time spend to single RPC call,
                            including the time of call itself taken and also call re-tries.
        :param ua:          User Agent identification string to be passed with the request
        :param component:   Sandbox component name. For internal use only.
        :param debug:       Log occurred exceptions' types and arguments
        """

        try:
            import py
            from . import config
            settings = config.Registry()
            settings.root()
        except ImportError:  # py module could not exist
            settings = None
        except (py.error.ENOENT, py.error.ENOTDIR, OSError, IOError, AttributeError):
            settings = None

        if base_url is None and settings:
            base_url = settings.client.rest_url

        self.__interval = self.DEFAULT_INTERVAL
        self.__timeout = self.DEFAULT_TIMEOUT
        self.__parsed_base_url = urlparse.urlparse(base_url or self.DEFAULT_BASE_URL)
        self.__hosts = collections.deque(
            proxy.brace_expansion(map(str.strip, self.__parsed_base_url.netloc.split(" ")))
        )
        random.shuffle(self.__hosts)

        if auth and not isinstance(auth, proxy.Authentication):
            auth = proxy.OAuth(auth)
        else:
            auth = auth or self.__class__._external_auth or proxy.NoAuth()

        self.__root = Path(self)
        self.__session = requests.Session()
        self.__session.auth = self.Auth(auth)

        # `requests` older than 2.10 does not support separate (connect, read) timeouts
        v = [_.isdigit() and int(_) or 0 for _ in requests.__version__.split(".")]
        self.__legacy_requests = len(v) < 3 or tuple(v[:2]) < (2, 10)
        if self.__legacy_requests and v[0] == 2:
            patch_session_timeout(self.__session)

        _urllib3_logging_suppress()
        self.logger = logger or logging.getLogger(__name__)
        self.logger.debug(
            "REST API client instance created in thread '%s' with %s authorization%s.",
            threading.current_thread().ident, auth, " with legacy requests module" if self.__legacy_requests else ""
        )
        self.debug = debug  # KORUM: FIXME: Is only to debug "Interrupted system call" in AgentR

        self.total_wait = total_wait
        self.ua = ua or (settings.this.id if settings else socket.getfqdn())
        self.component = component or self._default_component

    def __getitem__(self, item):
        return getattr(self.__root, item)

    __getattr__ = __getitem__

    def __lshift__(self, other):
        return self.__root.__lshift__(other)

    def __rshift__(self, other):
        return self.__root.__rshift__(other)

    def copy(self):
        """ Creates a new client instance with the same base URL, authorization and settings as self. """
        return self.__class__(
            urlparse.urlunparse(self.__parsed_base_url),
            self.__session.auth.auth, self.logger, self.total_wait,
            ua=self.ua, component=self.component, debug=self.debug,
        )

    @classmethod
    def __modifier(cls, mode):
        """ Return a modifier instance: JSON for `None`, a new instance for a modifier class, `mode` otherwise. """
        if mode is None:
            return cls.JSON()
        if isinstance(mode, type):
            return mode()
        return mode

    def create(self, path, data, params=None, input_mode=JSON, output_mode=JSON):
        """ Send POST request to `path` and return the response processed by `output_mode`. """
        input_mode, output_mode = self.__modifier(input_mode), self.__modifier(output_mode)
        return output_mode.response(self._request(
            self.__session.post, path,
            {"data": input_mode.request(data), "params": params},
            input_mode.headers.request,
        ))

    def read(self, path, params=None, input_mode=JSON, output_mode=JSON):
        """ Send GET request to `path` and return the response processed by `output_mode`. """
        input_mode, output_mode = self.__modifier(input_mode), self.__modifier(output_mode)
        return output_mode.response(self._request(
            self.__session.get, path,
            {"params": params},
            input_mode.headers.request
        ))

    def update(self, path, data, params=None, input_mode=JSON, output_mode=JSON):
        """ Send PUT request to `path` and return the response processed by `output_mode`. """
        input_mode, output_mode = self.__modifier(input_mode), self.__modifier(output_mode)
        return output_mode.response(self._request(
            self.__session.put, path,
            {"data": input_mode.request(data), "params": params},
            input_mode.headers.request,
        ))

    def delete(self, path, data=None, params=None, input_mode=JSON, output_mode=JSON):
        """ Send DELETE request to `path` and return the response processed by `output_mode`. """
        input_mode, output_mode = self.__modifier(input_mode), self.__modifier(output_mode)
        return output_mode.response(self._request(
            self.__session.delete, path,
            {"data": input_mode.request(data), "params": params},
            input_mode.headers.request,
        ))

    @property
    def interval(self):
        """ Grow the wait interval between retries by a factor of 55/34, limited by `MAX_INTERVAL`, and return it. """
        self.__interval = min(self.__interval * 55 / 34, self.MAX_INTERVAL)
        return self.__interval

    @property
    def timeout(self):
        """ Grow the call timeout by a factor of 55/34, limited by `MAX_TIMEOUT`, and return it. """
        self.__timeout = min(self.__timeout * 55 / 34, self.MAX_TIMEOUT)
        return self.__timeout

    @property
    def host(self):
        """ Host (network location) to be used for the next request attempt. """
        return self.__hosts[0]

    def reset(self):
        """ Restore initial wait interval and call timeout. """
        self.__interval = self.DEFAULT_INTERVAL
        self.__timeout = self.DEFAULT_TIMEOUT

    @staticmethod
    def __encode_query(params):
        """
        Return a copy of request keyword arguments `params` with query values (the "params" item) encoded:
        dicts are serialized to JSON and other iterables are converted to lists. The input is not modified.
        """
        kwargs = dict(params or {})
        query = kwargs.get("params")
        if query:
            encoded = {}
            for name, value in query.items():
                if isinstance(value, dict):
                    value = json.dumps(value)
                elif hasattr(value, "__iter__"):
                    value = list(value)
                encoded[name] = value
            kwargs["params"] = encoded
        return kwargs

    def _request(self, method, path, params=None, headers=None):
        """
        Perform an HTTP request, retrying on network errors and retryable HTTP codes.

        Hosts from the base URL are tried in round-robin order; the wait interval and the call timeout
        grow after failed attempts (see :attr:`interval` and :attr:`timeout`).

        :param method:  bound `requests.Session` method (`get`, `post`, `put` or `delete`)
        :param path:    URL path relative to the base URL path
        :param params:  extra keyword arguments for `method` (e.g. `data` and `params`); not modified
        :param headers: request headers; not modified, a copy with service headers added is sent
        :return: `requests.Response` with a successful status code
        :raises Client.SessionExpired: server responded HTTP 410 Gone
        :raises Client.HTTPError: server responded a non-retryable HTTP error code
        :raises Client.TimeoutExceeded: no successful response within `total_wait` seconds
        """
        self.reset()
        request_id = uuid.uuid4().hex
        timeout = min(self.__timeout, self.total_wait if self.total_wait else self.__timeout)
        if not self.__legacy_requests:
            timeout = (self.CONNECT_TIMEOUT, timeout)
        method_name = method.__name__.upper()
        params = self.__encode_query(params)
        query = params.get("params") or {}

        # Work on a copy: the caller's headers jar is reused by subsequent requests
        headers = requests.structures.CaseInsensitiveDict(headers)
        headers[ctm.HTTPHeader.REQUEST_ID] = request_id
        headers[ctm.HTTPHeader.USER_AGENT] = self.ua
        if self.component:
            headers[ctm.HTTPHeader.COMPONENT] = self.component

        spent = 0
        started = time.time()
        while not self.total_wait or spent < self.total_wait:
            try:
                url = urlparse.urlunparse(
                    self.__parsed_base_url._replace(
                        netloc=self.host, path="{}{}".format(self.__parsed_base_url.path, re.sub("/+", "/", path))
                    )
                )
                self.logger.debug(
                    "REST request [%s] %s %s, timeout %s, query %s",
                    request_id, method_name, url, self.__timeout, query
                )
                with self.__insecure_platform_ok(), self.__sni_missing_ok():
                    ret = method(url, timeout=timeout, headers=headers, allow_redirects=True, **params)
                self.logger.debug(
                    "REST request [%s] finished at %r after %.3fs (HTTP code %d)",
                    request_id, ret.headers.get(ctm.HTTPHeader.BACKEND_NODE), time.time() - started, ret.status_code
                )
                ret.raise_for_status()
                return ret

            except select.error as ex:
                if ex.args[0] == errno.EINTR:
                    self.logger.warning("REST request [%s]: error: '%s'", request_id, ex)
                else:
                    raise

            except EnvironmentError as ex:
                self.logger.warning(
                    "REST request [%s]: failed after %.3fs: '%s'", request_id, time.time() - started, ex
                )
                do_sleep = True
                if self.debug:
                    self.logger.warning("Handle exception on REST API call: %r%r", type(ex), ex.args)
                if ex.errno == errno.EINTR:
                    do_sleep = False
                elif isinstance(ex, requests.ConnectionError):
                    if ex.args and isinstance(ex.args[0], requests.packages.urllib3.exceptions.MaxRetryError):
                        # Do not sleep on connection timeout
                        do_sleep = False
                    else:
                        # Do not sleep if the host is definitely unreachable - try the next one immediately
                        try:
                            do_sleep = not (
                                isinstance(ex.args[0].args[1], socket.error) and
                                ex.args[0].args[1].errno in (errno.ECONNREFUSED, errno.EHOSTDOWN, errno.ENETUNREACH)
                            )
                        except (IndexError, AttributeError):
                            pass
                elif isinstance(ex, requests.HTTPError):
                    if ex.response.status_code == httplib.GONE:
                        raise self.SessionExpired(ex)
                    elif ex.response.status_code not in self.RETRYABLE_CODES:
                        raise self.HTTPError(ex)
                    do_sleep = ex.response.status_code != httplib.BAD_GATEWAY

                self.__hosts.rotate(-1)

                if do_sleep or len(self.__hosts) == 1:
                    tick = self.interval
                    if self.total_wait:
                        # Do not sleep longer than the remaining time budget
                        tick = max(0, min(tick, self.total_wait - (time.time() - started)))
                    time.sleep(tick)
                    timeout = self.timeout if self.__legacy_requests else (self.CONNECT_TIMEOUT, self.timeout)

            except BaseException as ex:
                if self.debug:
                    self.logger.warning("Unhandled exception on REST API call: %r%r", type(ex), ex.args)
                raise

            # record call duration
            finally:
                spent = time.time() - started
                if type(self)._accounter:
                    type(self)._accounter(method_name, spent)

        raise self.TimeoutExceeded(
            "Error requesting method '{}' for path '{}' - no response given after {!s}.".format(
                method_name, path, datetime.timedelta(seconds=spent)
            )
        )


class ThreadLocalCachableClient(Client):
    """
    REST API client intended to have a single instance per thread.

    To eliminate a problem of thread safety and also provide a single method to obtain a server proxy object,
    the class uses the special metaclass `proxy.ThreadLocalMeta`, which is meant to ensure one instance per thread
    (the caching itself is implemented in the `proxy` module, not here).
    """
    __metaclass__ = proxy.ThreadLocalMeta


class DispatchedClient(object):
    """
    Used to dynamically switch actual REST API client class.
    Usage:

    .. code-block:: python

        with DispatchedClient as dispatch:
            dispatch(RealClientClass)
            # some code using DispatchedClient as REST API client
            ...

    Calling `DispatchedClient(...)` instantiates the most recently dispatched class of the current thread,
    or :class:`ThreadLocalCachableClient` if nothing is dispatched. Leaving a `with` block restores
    the dispatched classes of the current thread to their state at the block entry.

    @DynamicAttrs
    """

    # noinspection PyPep8Naming
    class __metaclass__(type):
        __default_client = ThreadLocalCachableClient
        __local = threading.local()

        @property
        def __clients(cls):
            """ Per-thread stack of dispatched client classes. """
            try:
                clients = cls.__local.clients
            except AttributeError:
                clients = cls.__local.clients = []
            return clients

        @property
        def __marks(cls):
            """ Per-thread stack of `__clients` sizes recorded at each `with` block entry. """
            try:
                marks = cls.__local.marks
            except AttributeError:
                marks = cls.__local.marks = []
            return marks

        def __call__(cls, *args, **kwargs):
            return (
                cls.__clients[-1](*args, **kwargs)
                if cls.__clients else
                cls.__default_client(*args, **kwargs)
            )

        def __enter__(cls):
            cls.__marks.append(len(cls.__clients))
            return lambda client: cls.__clients.append(client)

        def __exit__(cls, *_):
            marks = cls.__marks
            if marks:
                # Drop exactly what was dispatched inside the block (nothing, once or several times)
                del cls.__clients[marks.pop():]
                return
            # Unbalanced exit: fall back to dropping the most recently dispatched class, if any
            try:
                cls.__clients.pop()
            except IndexError:
                pass
