# coding: utf-8

import re
import string
import socket
import time

from multiprocessing.dummy import Pool as ThreadPool
from urllib.parse import urlparse
from collections import Iterable, deque
from itertools import chain

from django.conf import LazySettings
from django.utils.encoding import force_text
from functools import reduce

# Separator between nested keys in flattened setting paths,
# e.g. 'databases :: default :: host'.
splitter = ' :: '
# Port used whenever no port can be inferred from a value or from the settings.
DEFAULT_PORT = 80

# Parameter names / values (stored lower-cased) that are never treated as wormholes.
# NOTE: string.lower does not exist on Python 3; use str.lower via a comprehension.
EXCLUDED_PARAMS = {name.lower() for name in ('admins', 'managers', 'DJANGO_INTRANET_USERIP_DEFAULT')}
EXCLUDED_VALUES = {value.lower() for value in ()}

# Default port per URL scheme, used when a URL carries no explicit port.
KNOWN_SCHEMES = {
    'http': 80,
    'https': 443,
    'mongodb': 27017,
    'ldap': 389,
    'ldaps': 636,
    'ftp': 21
}

# Setting names containing 'host' as a whole '_'-delimited token ('host', 'db_host', 'host_ro').
host_pattern = re.compile(r'(^|_)host(_|$)')
# Values ending in one of these TLDs, optionally followed by ':port' and a non-word char + tail.
allowed_domains = re.compile(r'.*\.(ru|com|net|org|tr|ua|by|kz)(\:\d+)?(\W.*)?$')


def _walk(key, value):

    if isinstance(value, str) or (isinstance(value, int) and not isinstance(value, bool)):
        yield key, value

    elif isinstance(value, dict):
        for sub_key, sub_val in chain(*[_walk(splitter.join([key, s_key]), s_val)
                                        for s_key, s_val in value.items()]):
            yield sub_key, sub_val

    elif isinstance(value, Iterable):
        for sub_key, sub_val in chain(*[_walk(key, elem) for elem in value]):
            yield sub_key, sub_val


def walk_settings(conf):
    """Yield ``(lower-cased path, value)`` for every scalar leaf in *conf*.

    *conf* is either a plain dict or a Django ``LazySettings`` object, whose
    wrapped settings' ``__dict__`` is walked. Anything else raises ValueError
    when the generator is first advanced.
    """
    if isinstance(conf, LazySettings):
        source = conf._wrapped.__dict__
    elif isinstance(conf, dict):
        source = conf
    else:
        raise ValueError('conf must be a dict or LazySettings instance')

    for top_key, top_value in source.items():
        for leaf_path, leaf_value in _walk(top_key, top_value):
            yield leaf_path.lower(), leaf_value


def wormholes_filter(params):
    """Return True when a flattened ``(param_name, value)`` setting looks like a remote address.

    A setting is rejected when:
      - its name is empty or in EXCLUDED_PARAMS;
      - its value is empty, is not a string, or is in EXCLUDED_VALUES;
      - it is an e-mail setting containing '@';
      - its value mentions localhost/127.0.0.1 or contains spaces/newlines.

    Otherwise it is accepted when:
      - the name has a 'host' token;
      - or the value ends in an allowed TLD;
      - or the value is a literal IPv4/IPv6 address.
    """
    param_name, value = params
    if not param_name or param_name in EXCLUDED_PARAMS:
        return False

    # EXCLUDED_VALUES is stored lower-cased, so compare case-insensitively.
    if not value or not isinstance(value, str) or value.lower() in EXCLUDED_VALUES:
        return False

    if '@' in value and 'email' in param_name:
        return False

    if any(token in value for token in ('localhost', '127.0.0.1', ' ', '\n')):
        return False

    if host_pattern.search(param_name):
        return True

    if allowed_domains.search(value):
        return True

    for family in (socket.AF_INET, socket.AF_INET6):
        try:
            socket.inet_pton(family, value)
        except (OSError, ValueError):
            # OSError: not an address of this family; ValueError: e.g. embedded NUL.
            continue
        return True
    return False


def _get_port_from_config(key, config_dict):
    """Look up the port setting that pairs with host setting *key*.

    The candidate port key is *key* with 'host' replaced by 'port' in its last
    path segment only (e.g. 'databases :: default :: host' ->
    'databases :: default :: port').

    Returns DEFAULT_PORT in three cases:
      - the last segment has no 'host';
      - no such setting exists;
      - the stored value cannot be converted to int.
    """
    parent, sep, last = key.rpartition(splitter)
    port_param_name = parent + sep + last.replace('host', 'port')
    if port_param_name == key:
        return DEFAULT_PORT

    try:
        return int(config_dict.get(port_param_name, DEFAULT_PORT))
    except (TypeError, ValueError):
        # A non-numeric port setting must not abort the whole scan.
        return DEFAULT_PORT


def _is_ipv6_literal(value):
    """Return True when *value* is a bare IPv6 address (no brackets, no port)."""
    try:
        socket.inet_pton(socket.AF_INET6, value)
    except (OSError, ValueError):
        return False
    return True


def infer_host_port(key, value, config_dict):
    """Resolve a setting value to the ``(host, port)`` it points at.

    :param key: flattened setting name; returned unchanged.
    :param value: a URL, a ``host:port`` string or a bare host.
    :param config_dict: all flattened settings. It is searched for a matching
        port setting when *value* carries no port of its own.
    :return: ``(key, (host, port))``; always a tuple, never None.
    """
    try:
        parsed = urlparse(value)
        hostname = parsed.hostname
    except ValueError:
        # urlparse rejects malformed netlocs (e.g. an unclosed '[');
        # treat the value as a bare host instead.
        parsed, hostname = None, None

    if not hostname:
        # A bare IPv6 address contains ':' but has no port to split off.
        if ':' in value and not _is_ipv6_literal(value):
            # for values like 'cs-zk03gt.yandex.ru:2181'
            host, _, port_text = value.rpartition(':')
            try:
                port = int(port_text)
            except ValueError:
                port = DEFAULT_PORT
            return key, (host, port)

        # trying to find matching for key PORT setting in config
        return key, (value, _get_port_from_config(key, config_dict))

    try:
        port = parsed.port
    except ValueError:
        # Non-numeric / out-of-range port, or a multi-host netloc such as
        # 'mongodb://h1:27017,h2:27017/db': fall back to the scheme default.
        port = None

    if port:
        return key, (hostname, port)
    if parsed.scheme:
        return key, (hostname, KNOWN_SCHEMES.get(parsed.scheme, DEFAULT_PORT))
    # Scheme-relative URLs ('//host/path') have a hostname but no scheme.
    return key, (hostname, DEFAULT_PORT)


def wormhole_checker(queue, result_q, timeout=1):
    """Drain *queue*, probing each address with a TCP connect, and record the outcome.

    Items are popped until ``queue.pop()`` raises IndexError. This lets several
    worker threads share one deque.

    :param queue: deque of ``((host, port), name)`` items; consumed in place.
    :param result_q: receives ``((host, port), name, status)`` for every item.
        *status* is the connect latency in milliseconds formatted with
        ``'%0.2f'``, or one of 'timeout', 'not known', 'no route' or
        'invalid address'.
    :param timeout: connect timeout in seconds for ``socket.create_connection``.
    """
    while True:
        try:
            item = queue.pop()
        except IndexError:
            break

        address, name = item
        start = time.perf_counter()
        try:
            # Only reachability matters: close the probe socket right away
            # instead of leaking one file descriptor per address.
            with socket.create_connection(address, timeout=timeout):
                elapsed_ms = (time.perf_counter() - start) * 1000
        except socket.timeout:
            status = 'timeout'
        except socket.gaierror:
            status = 'not known'
        except socket.error:
            status = 'no route'
        except (ValueError, OverflowError, TypeError):
            # Malformed address or out-of-range port: record it rather than
            # letting the exception kill this worker thread and drop the item.
            status = 'invalid address'
        else:
            status = '%0.2f' % elapsed_ms
        result_q.append(item + (status, ))


def spawner(target=None, args=None, maxthreads=0):
    """
    Spawn number of threads with arguments and wait until all of them finish

    Every thread calls C{target(*args)} exactly once.

    @type target: callable
    @param target: callable target for each thread; C{None} starts idle threads
    @param args: positional arguments for each thread; C{None} means no arguments
    @param maxthreads: number of threads to start; C{None} means C{os.cpu_count()},
        values below 1 start no threads (ThreadPool used to raise ValueError here)
    """
    import os
    import threading

    if maxthreads is None:
        # Same default worker count multiprocessing.pool used for processes=None.
        maxthreads = os.cpu_count() or 1
    if maxthreads < 1:
        return

    call_args = tuple(args) if args is not None else ()
    # Daemon threads, like ThreadPool workers, so they never block interpreter exit.
    threads = [threading.Thread(target=target, args=call_args, daemon=True)
               for _ in range(maxthreads)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()


def check_wormholes(conf, include=None):
    """Probe every remote address referenced by *conf* and return the results.

    :param conf: settings dict or Django ``LazySettings``.
    :param include: extra ``(param_name, (host, port))`` entries to probe in
        addition to the ones found in *conf*.
    :return: deque of ``((host, port), param_names, status)`` tuples, where
        *param_names* joins (with ', ') every setting pointing at that address.
    :raises ValueError: if *conf* is neither a dict nor ``LazySettings``.
    """
    if not isinstance(conf, (dict, LazySettings)):
        # Explicit check instead of assert so it survives `python -O`.
        raise ValueError('conf must be a dict or LazySettings instance')

    # Flatten once; the same items feed both the filter and the port lookup.
    settings_items = list(walk_settings(conf=conf))
    config_dict = dict(settings_items)

    resolved = [infer_host_port(param_name, value, config_dict)
                for param_name, value in filter(wormholes_filter, settings_items)]
    resolved.extend(include or [])

    # Group setting names by the (host, port) they point at so each address
    # is probed once. (The Py2 tuple-parameter lambda used here before is a
    # SyntaxError on Python 3.)
    names_by_address = {}
    for entry in resolved:
        if entry is None:
            continue
        param_name, hostport = entry
        if hostport in names_by_address:
            names_by_address[hostport] = ', '.join([names_by_address[hostport], param_name])
        else:
            names_by_address[hostport] = param_name

    result_q = deque()
    if not names_by_address:
        # Nothing to probe; also avoids asking spawner for zero threads.
        return result_q

    queue = deque(names_by_address.items())
    # NOTE(review): one thread per distinct address; consider capping for large configs.
    spawner(target=wormhole_checker, args=(queue, result_q), maxthreads=len(names_by_address))

    return result_q

