# -*- coding: utf-8 -*-

from collections import namedtuple
from functools import partial
import logging
from socket import error as SocketError
from socket import timeout as SocketTimeout

import six
import urllib3
import urllib3.connection
from urllib3.connectionpool import port_by_scheme
from urllib3.exceptions import LocationValueError
from urllib3.poolmanager import (
    _default_key_normalizer,
    PoolKey,
)


logger = logging.getLogger('passport.useragent')


# urllib3's pool key extended with one extra field. urllib3's _default_key_normalizer prefixes
# every request-context key with ``key_``, so the ``host_ip`` pool kwarg ends up in
# ``key_host_ip``; pools for the same host but different pinned IPs therefore get distinct keys.
PassportPoolKey = namedtuple('PassportPoolKey', tuple(PoolKey._fields) + ('key_host_ip',))


def _create_socket_connection(http_conn):
    """Open the TCP socket for ``http_conn``, dialling its pinned IP address when one is set.

    Used as ``_new_conn`` by ``_HTTPConnection`` / ``_HTTPSConnection``. If ``http_conn.host_ip``
    is truthy, the socket is opened to that address. Otherwise an error is logged and
    ``http_conn.host`` is dialled instead. ``http_conn.host`` itself is never modified.

    :param http_conn: urllib3 HTTP(S) connection that additionally exposes ``host_ip`` and
        ``format_hostname()``.
    :return: the connected socket.
    :raises urllib3.exceptions.ConnectTimeoutError: if the connect attempt times out.
    :raises urllib3.exceptions.NewConnectionError: on any other socket error while connecting.
    """
    extra_kw = {}
    if http_conn.source_address:
        extra_kw['source_address'] = http_conn.source_address

    if http_conn.socket_options:
        extra_kw['socket_options'] = http_conn.socket_options

    if http_conn.host_ip:
        host = http_conn.host_ip
    else:
        logger.error('Host IP not provided')
        host = http_conn.host

    try:
        conn = urllib3.connection.connection.create_connection((host, http_conn.port), http_conn.timeout, **extra_kw)

    except SocketTimeout:
        raise urllib3.exceptions.ConnectTimeoutError(http_conn, 'Connection to %s timed out. (connect timeout=%s)' % (http_conn.format_hostname(), http_conn.timeout))

    except SocketError as e:
        # Same translation urllib3's own HTTPConnection._new_conn performs. Without it the raw
        # socket error reaches urlopen, which (in urllib3 1.x) wraps it in ProtocolError, so Retry
        # counts e.g. "connection refused" as a read error instead of a connection error.
        raise urllib3.exceptions.NewConnectionError(http_conn, 'Failed to establish a new connection: %s' % e)

    return conn


def _format_http_connection_hostname(http_conn):
    if http_conn.host_ip is not None:
        return '%s (ip_address=%s)' % (http_conn.host, http_conn.host_ip)
    else:
        return http_conn.host


class _HTTPConnection(urllib3.connection.HTTPConnection):
    """urllib3 plain-HTTP connection that can open its socket to a pinned IP address.

    ``host`` is handed to the parent class unchanged. ``host_ip`` is stored on the instance and
    used by ``_create_socket_connection`` as the address to dial.
    """

    def __init__(self, host, host_ip=None, *args, **kwargs):
        # NOTE(review): host_ip is the second positional parameter, so a positional port would
        # bind to it; urllib3's connection pools pass everything by keyword.
        parent = super(_HTTPConnection, self)
        parent.__init__(host, *args, **kwargs)
        self.host_ip = host_ip

    def _new_conn(self):
        """Open the underlying socket via ``_create_socket_connection``."""
        sock = _create_socket_connection(self)
        return sock

    def format_hostname(self):
        """Return the host (plus pinned IP, if any) for use in error messages."""
        label = _format_http_connection_hostname(self)
        return label


class _HTTPSConnection(urllib3.connection.HTTPSConnection):
    """urllib3 HTTPS connection that can open its socket to a pinned IP address.

    ``host`` is handed to the parent class unchanged. ``host_ip`` is stored on the instance and
    used by ``_create_socket_connection`` as the address to dial.
    """

    def __init__(self, host, host_ip=None, *args, **kwargs):
        # NOTE(review): host_ip is the second positional parameter, so a positional port would
        # bind to it; urllib3's connection pools pass everything by keyword.
        parent = super(_HTTPSConnection, self)
        parent.__init__(host, *args, **kwargs)
        self.host_ip = host_ip

    def _new_conn(self):
        """Open the underlying (pre-TLS) socket via ``_create_socket_connection``."""
        sock = _create_socket_connection(self)
        return sock

    def format_hostname(self):
        """Return the host (plus pinned IP, if any) for use in error messages."""
        label = _format_http_connection_hostname(self)
        return label


class _HTTPConnectionPool(urllib3.HTTPConnectionPool):
    """HTTP connection pool whose connections are ``_HTTPConnection`` instances."""

    ConnectionCls = _HTTPConnection


class _HTTPSConnectionPool(urllib3.HTTPSConnectionPool):
    """HTTPS connection pool whose connections are ``_HTTPSConnection`` instances."""

    ConnectionCls = _HTTPSConnection


class PoolManager(urllib3.PoolManager):
    """urllib3 ``PoolManager`` that can route connections for a host to a fixed IP address.

    Mappings are registered with :meth:`set_ip_for_host`. The pinned IP is passed to the
    connection pools as the ``host_ip`` keyword, and it is part of the pool key
    (``PassportPoolKey``). Pools for the same host but different IPs are therefore kept apart.
    """

    def __init__(self, *args, **kwargs):
        super(PoolManager, self).__init__(*args, **kwargs)
        self.pool_classes_by_scheme = {
            'http': _HTTPConnectionPool,
            'https': _HTTPSConnectionPool,
        }
        # Build pool keys with PassportPoolKey so that ``host_ip`` takes part in them.
        for scheme in self.key_fn_by_scheme:
            self.key_fn_by_scheme[scheme] = partial(_default_key_normalizer, PassportPoolKey)

        # Normalised host name (see _normalize_host_key) -> IP address.
        self._host_to_ip_table = dict()

    @staticmethod
    def _normalize_host_key(host):
        """Return the key under which ``host`` is stored in the host-to-IP table.

        On Python 3, bytes are decoded as UTF-8. IPv6 brackets are stripped, and the name is
        lower-cased: host names are case-insensitive, and urllib3 lower-cases them as well.
        Non-string values are returned unchanged.
        """
        if six.PY3 and isinstance(host, bytes):
            host = host.decode('utf-8')
        if not isinstance(host, six.string_types):
            return host
        if host.startswith('[') and host.endswith(']'):
            host = host.strip('[]')
        return host.lower()

    def get_ip_for_host(self, host):
        """Return the IP pinned for ``host`` via :meth:`set_ip_for_host`, or ``None``.

        ``host`` may be text or UTF-8 bytes, with or without IPv6 brackets. The lookup is
        case-insensitive.
        """
        return self._host_to_ip_table.get(self._normalize_host_key(host))

    def set_ip_for_host(self, host, ip):
        """Pin ``host`` to ``ip`` for pools looked up after this call.

        Pools already created for a previous IP of this host are not closed here.

        :param host: host name, text or UTF-8 bytes, with or without IPv6 brackets.
        :param ip: IP address, text or UTF-8 bytes.
        """
        # urllib3 works with text internally, so convert every argument to text here;
        # get_ip_for_host then returns text as well.
        key = self._normalize_host_key(host)
        if six.PY3 and isinstance(ip, bytes):
            ip = ip.decode('utf-8')
        self._host_to_ip_table[key] = ip

    def connection_from_host(self, host, port=None, scheme='http', pool_kwargs=None):
        """Return a connection pool for the given host, port and scheme.

        Same as urllib3's ``connection_from_host``, except that, unless ``pool_kwargs`` already
        contains ``host_ip``, the IP pinned for ``host`` (or ``None``) is passed as ``host_ip``.

        :raises LocationValueError: if ``host`` is empty.
        """
        if not host:
            raise LocationValueError('No host specified.')

        if pool_kwargs is None:
            pool_kwargs = dict()
        else:
            pool_kwargs = pool_kwargs.copy()

        pool_kwargs.setdefault('host_ip', self.get_ip_for_host(host))

        return super(PoolManager, self).connection_from_host(host, port=port, scheme=scheme, pool_kwargs=pool_kwargs)

    def invalidate_host_connection(self, host, port=None, scheme='http', pool_kwargs=None):
        """Remove the pool that :meth:`connection_from_host` would return for these arguments.

        The pool key is rebuilt the way urllib3's ``connection_from_host`` builds it, using the
        IP currently pinned for ``host``. Nothing happens if no such pool exists.

        :param host: host name, text or UTF-8 bytes.
        :param port: port. If falsy, the scheme's default port is used (80 if the scheme is unknown).
        :param scheme: URL scheme, text or UTF-8 bytes, case-insensitive. An empty value means ``'http'``.
        :param pool_kwargs: the pool keyword overrides the pool was created with.
        :raises KeyError: if no key function is registered for ``scheme``.
        """
        if pool_kwargs is None:
            pool_kwargs = dict()
        else:
            pool_kwargs = pool_kwargs.copy()
        pool_kwargs.setdefault('host_ip', self.get_ip_for_host(host))

        # port_by_scheme and key_fn_by_scheme are keyed by lower-case text, so decode and
        # lower-case the scheme before either lookup. Previously a bytes b'https' fell back to
        # port 80, and 'HTTPS' or '' raised KeyError. Like urllib3, an empty scheme means 'http'.
        if isinstance(scheme, bytes):
            scheme = scheme.decode('utf-8')
        scheme = (scheme or 'http').lower()

        request_context = self._merge_pool_kwargs(pool_kwargs)
        request_context['scheme'] = scheme
        if not port:
            port = port_by_scheme.get(scheme, 80)
        request_context['port'] = port
        request_context['host'] = host

        # Otherwise the pool key would contain bytes and would not be found in self.pools.
        request_context = {k: v.decode('utf-8') if isinstance(v, bytes) else v for k, v in request_context.items()}

        pool_key_constructor = self.key_fn_by_scheme[scheme]
        pool_key = pool_key_constructor(request_context)

        try:
            del self.pools[pool_key]
        except KeyError:
            pass
