from __future__ import absolute_import

import os
import time
import math
import errno
import socket
import logging
import threading as th

import requests
import requests.packages.urllib3.util.connection

import xml.parsers.expat
import xml.etree.ElementTree

from . import utils
from . import config


class MDS(object):
    """
    Class encapsulating methods for working with MDS.

    Besides :meth:`upload`, the class carries a (currently disabled) workaround
    which replaces urllib3's ``create_connection`` with a variant that retries
    on ``errno.EINPROGRESS`` / ``errno.EALREADY``.
    """
    __metaclass__ = utils.SingletonMeta

    class Exception(Exception):
        """ Raised when MDS rejects an upload or returns a response which cannot be parsed. """
        pass

    # Wrapped into `staticmethod` so that access via an instance (`self.orig_create_connection`)
    # yields the plain function instead of a method bound to the `MDS` instance.
    orig_create_connection = staticmethod(requests.packages.urllib3.util.connection.create_connection)

    # Per-thread storage for the workaround: `create_connection` override and `logger`.
    __local = th.local()
    # Number of uploads which currently have the `create_connection` patch installed.
    __patch_counter = 0

    # noinspection All
    def __create_connection(
        self,
        address,
        timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
        source_address=None,
        socket_options=None
    ):
        """
        This function is copied from "requests.packages.urllib3.utils.connection.py" and
        modified to process `errno.EINPROGRESS` error raised on `socket.connect` call
        which is mostly common case on connecting to MDS.

        :param address: `(host, port)` pair to connect to
        :param timeout: socket timeout in seconds; also used as the retry budget - each retry
                        sleeps one second and consumes one second of it. With the global default
                        sentinel the number of retries is not limited
        :param source_address: optional address to bind the socket to before connecting
        :param socket_options: optional socket options, applied before connecting
        :return: connected socket
        :raise socket.error: last connection error, or a generic one if `getaddrinfo` returned nothing
        """

        self.__local.logger.debug("Create connection to %s", address)
        host, port = address
        err = None
        for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
            af, socktype, proto, canonname, sa = res
            sock = None
            while True:
                try:
                    if not sock:
                        sock = socket.socket(af, socktype, proto)

                        # If provided, set socket level options before connecting.
                        # This is the only addition urllib3 makes to this function.
                        requests.packages.urllib3.util.connection._set_socket_options(sock, socket_options)

                        if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
                            sock.settimeout(timeout)
                        if source_address:
                            sock.bind(source_address)
                    sock.connect(sa)
                    return sock
                except socket.error as ex:
                    if ex.errno in (errno.EINPROGRESS, errno.EALREADY):
                        unlimited = timeout is socket._GLOBAL_DEFAULT_TIMEOUT
                        # `None` (blocking mode) is not comparable with a number on Python 3,
                        # so it is excluded explicitly - no retries in that case.
                        if unlimited or (timeout is not None and timeout > 1):
                            time.sleep(1)
                            # Only a numeric timeout can be decremented - the default
                            # sentinel is an opaque object.
                            if not unlimited:
                                timeout -= 1
                            continue
                    err = ex
                    if sock is not None:
                        sock.close()
                # Reached only on a non-retriable error - go to the next resolved address.
                break
        if err is not None:
            raise err
        else:
            raise socket.error("getaddrinfo returns an empty list")

    def __create_connection_dispatcher(self, *args, **kws):
        """
        Call the thread-specific `create_connection` override if the current thread has one,
        otherwise fall back to the original urllib3 implementation.
        """
        # The attribute exists only in threads which have set it, hence the default.
        override = getattr(self.__local, "create_connection", None)
        return (override or self.orig_create_connection)(*args, **kws)

    def upload(self, path, mds_name, token, namespace=None, ttl=None, size=None, timeout=None, logger=None):
        """
        Upload file to MDS

        :param path: path to file to upload
        :param mds_name: name used to form MDS url
        :param token: token used to authenticate on MDS site
        :param namespace: MDS namespace, if not defined, Sandbox's namespace will be used
        :param ttl: TTL in days
        :param size: file size in bytes, if not defined, will be got from file system
        :param timeout: timeout in seconds, if not defined, will be calculated from file size
        :param logger: custom logger, if not defined, logging will be used
        :return: key of the uploaded data as reported by MDS; if the service answers FORBIDDEN,
                 the key of the already uploaded data taken from that response
        :raise MDS.Exception: on unexpected response code or on a response which cannot be parsed
        """
        mds_settings = config.Registry().client.mds
        logger = logger or logging
        namespace = namespace or mds_settings.namespace
        # TODO: remove this if the bug is no longer reproduced
        # self.__patch_counter += 1
        # requests.packages.urllib3.util.connection.create_connection = self.__create_connection_dispatcher
        # self.__local.create_connection = self.__create_connection
        # self.__local.logger = logger
        try:
            # The size is needed for the log message below even if the timeout is given explicitly.
            if size is None:
                size = os.path.getsize(path)
            if timeout is None:
                # At least a minute, otherwise the time needed at the minimal allowed upload speed.
                timeout = max(60, int(math.ceil(float(size) / mds_settings.up.min_speed)))
            url = "{}/upload-{}/{}".format(mds_settings.up.url, namespace, mds_name)
            logger.debug("Uploading data of size %s via URL %s", utils.size2str(size), url)
            with open(path, "rb") as fh:
                params = {}
                if ttl:
                    params["expire"] = "{}d".format(ttl)

                resp = requests.post(
                    url,
                    data=fh,
                    params=params,
                    headers={"Authorization": "Basic " + token},
                    timeout=timeout
                )
            if resp.status_code == requests.codes.FORBIDDEN:
                # FORBIDDEN is treated as "already uploaded" - the existing key is taken
                # from the `key` child element of the response.
                try:
                    key = xml.etree.ElementTree.fromstring(resp.content).find("key").text
                    logger.info("Data already uploaded with key %s", key)
                except (AttributeError, KeyError, xml.etree.ElementTree.ParseError, xml.parsers.expat.ExpatError) as ex:
                    raise self.Exception(
                        "Data upload failed - unable to parse FORBIDDEN response: {}. Data is: {}".format(
                            ex, resp.content
                        )
                    )
            elif resp.status_code != requests.codes.OK:
                raise self.Exception(
                    "Data upload failed - service respond code {}: {}".format(resp.status_code, resp.reason)
                )
            else:
                # On success the key is an attribute of the root element.
                try:
                    key = xml.etree.ElementTree.fromstring(resp.content).attrib["key"]
                except (AttributeError, KeyError, xml.etree.ElementTree.ParseError, xml.parsers.expat.ExpatError) as ex:
                    raise self.Exception(
                        "Data upload failed - unable to parse response: {}. Data is: {}".format(ex, resp.content)
                    )
            return key
        finally:
            pass
            # TODO: remove this if the bug is no longer reproduced
            # self.__local.create_connection = None
            # self.__patch_counter -= 1
            # if self.__patch_counter < 1:
            #     requests.packages.urllib3.util.connection.create_connection = self.orig_create_connection


# Module-level shortcut: `upload(...)` is the `upload` method bound to an `MDS` instance
# created at import time (presumably the shared singleton, given `utils.SingletonMeta` - TODO confirm).
upload = MDS().upload
