import re
import threading
import logging
import socket
from typing import List, Dict, Iterable, Optional

from ids.configuration import get_config

from django.conf import settings
from django.http import JsonResponse
from django.utils.functional import SimpleLazyObject

from staff.racktables.objects import YaIP
from staff.racktables.tasks import get_nets_by_macro
from staff.monitorings.utils import check_host


logger = logging.getLogger(__name__)
# Extracts "host[:port]" fragments from a setting value (plain host or URL).
# host_re.findall() returns 3-tuples (host, tld, port):
#   host -- the matched hostname, including its ".ru"/".net" part;
#   tld  -- the ".ru"/".net" alternation group (captured only as a side effect);
#   port -- digits following an optional ':' ('' when no port is present).
# NOTE(review): only hostnames containing ".ru" or ".net" are matched; hosts
# under other TLDs (e.g. ".com") produce no match at all.
host_re = re.compile(r'[^\-\.a-zA-Z0-9]*([\-\.a-zA-Z0-9]+(\.ru|\.net)):?([0-9]*)')


def find_host_port(host_raw):
    """Yield a ``(host, port)`` pair for every ``host_re`` match in *host_raw*.

    ``port`` is the digit string that followed ``host:`` in the input, or
    ``''`` when the fragment carried no port.
    """
    for host, _tld, port in host_re.findall(host_raw):
        yield host, port


def get_hosts_from_setting(host_setting, port_setting, _type='settings'):
    """Yield ``(host, port)`` pairs for one firewall-hole definition.

    :param host_setting: name of the source to read; its meaning depends on
        ``_type`` (see below).
    :param port_setting: used only for ``_type == 'settings'``: an explicit
        ``int`` port, the name of a Django setting holding the port, or a
        falsy value meaning "use the port parsed from the host string"
        (which is ``''`` when the string has none).
    :param _type: where the hosts come from:
        ``'settings'``  -- Django setting ``host_setting`` (a string, or a
                           list/tuple of strings) parsed with ``find_host_port``;
        ``'ids'``       -- ``get_config(host_setting)['host']``, always port 443;
        ``'databases'`` -- every comma-separated ``HOST`` of every entry of the
                           Django setting ``host_setting``, with that entry's
                           ``PORT``.
        Any other value yields nothing.
    :raises AttributeError: if a named Django setting does not exist.
    """
    if _type == 'settings':
        hosts = getattr(settings, host_setting)
        if isinstance(port_setting, int):
            port = port_setting
        elif port_setting:
            port = getattr(settings, port_setting)
        else:
            port = None

        # A single string and a list/tuple of strings are handled the same way.
        host_rows = hosts if isinstance(hosts, (list, tuple)) else [hosts]
        for host_row in host_rows:
            for host, host_port in find_host_port(host_row):
                # An explicitly configured port wins over one parsed from the URL.
                yield host, port or host_port
    elif _type == 'ids':
        yield get_config(host_setting)['host'], 443
    elif _type == 'databases':
        conf = getattr(settings, host_setting)
        for db_conf in conf.values():
            for one_host in db_conf['HOST'].split(','):
                one_host = one_host.strip()
                # "a, b" would otherwise yield " b" (unresolvable); an empty HOST
                # or a trailing comma leaves no remote host to check at all.
                if one_host:
                    yield one_host, db_conf['PORT']


def collect_fw_rules(skip=None):
    """Probe every configured firewall hole concurrently and report reachability.

    Each ``holes`` entry maps a system name to the positional arguments of
    ``get_hosts_from_setting``; every ``(host, port)`` pair it yields is checked
    with ``check_host`` in its own thread.

    :param skip: system names (e.g. ``'LDAP'``) to leave out; ``None`` means none.
    :returns: ``defaultdict`` mapping system name to a list of
        ``{'state': int, 'host': ..., 'port': ...}`` dicts, where ``state`` is
        ``int(check_host(...))``. A check that raises (including a port that
        cannot be converted to ``int``) is logged and recorded with ``state`` 0
        rather than dropped. Within a system, entries are in completion order.
    :raises: whatever ``get_hosts_from_setting`` raises while reading settings.
    """
    from collections import defaultdict

    skip = skip or []
    holes = [
        ('LDAP', ('LDAP_HOST', 'LDAP_PORT')),
        ('AVATAR_STORAGE', ('AVATAR_STORAGE_HOST', 'AVATAR_STORAGE_PORT')),
        ('BLACKBOX', ('BLACKBOX_URL', 80)),
        ('BOT', ('BOT_HOST', 443)),
        ('CALENDAR_API', ('CALENDAR_API_HOST', 80)),
        ('CRM', ('CRM_HOST', 'CRM_PORT')),
        ('DIALER', ('DIALER_HOST', None)),
        ('YA_DISK', ('YA_DISK_HOST', 80)),
        ('EXT_BLACKBOX', ('EXT_BLACKBOX_URL', 80)),
        ('ML', ('ML_HOST', 443)),
        ('OAUTH', ('OAUTH', None, 'ids')),
        ('HBF', ('HBF_URL', 80)),
        ('RACKTABLES', ('RACKTABLES_URL', 443)),
        ('STARTREK_API', ('STARTREK_API', None, 'ids')),
        ('WIKIFORMATTER', ('WIKIFORMATTER_URL', 443)),
        ('GOALS', ('GOALS_HOST', 443)),
        ('DATABASE', ('DATABASES', None, 'databases')),
        ('MONGO', ('MONGO_HOSTS', None)),
        ('EMAIL', ('EMAIL_HOST', 'EMAIL_PORT')),
        ('OEBS_HOST', ('OEBS_HOST', 443)),
        ('REVIEW', ('REVIEW_URL', 443)),
        ('ABC', ('ABC_URL', 443)),
    ]
    result = defaultdict(list)
    # All worker threads append to ``result``; serialize the writes explicitly.
    result_lock = threading.Lock()

    def add_check_host_results(system, host, port):
        try:
            state = int(check_host(host, int(port)))
        except Exception:
            # An exception escaping a Thread target is only printed by
            # threading.excepthook: the hole would vanish from the report and
            # look healthy. Report it as unreachable instead.
            logger.exception('Failed to check %s host %s:%s', system, host, port)
            state = 0
        with result_lock:
            result[system].append({
                'state': state,
                'host': host,
                'port': port,
            })

    threads = []
    for hole_name, host_port_settings in holes:
        if hole_name in skip:
            continue

        for host, port in get_hosts_from_setting(*host_port_settings):
            threads.append(
                threading.Thread(
                    target=add_check_host_results,
                    kwargs={'system': hole_name, 'host': host, 'port': port},
                )
            )

    # Start everything first so the checks overlap, then wait for all of them.
    for thread in threads:
        thread.start()

    for thread in threads:
        thread.join()

    return result


# Data-center codes whose network lists are fetched via get_nets_by_macro.
KNOWN_DATA_CENTERS = ['MYT', 'SAS', 'VLA', 'IVA', 'MAN']
# Lazily built {dc: nets} mapping: get_nets_by_macro('_DC_<DC>_NETS_') runs for
# each DC on first access (not at import time) and the result is kept by the
# SimpleLazyObject afterwards. If building it raises, the error surfaces at the
# access site.
DATA_CENTER_NETS = SimpleLazyObject(lambda: {dc: get_nets_by_macro(f'_DC_{dc}_NETS_') for dc in KNOWN_DATA_CENTERS})


def get_data_centers_by_hosts(hosts: Iterable[str]) -> Iterable[Optional[str]]:
    """Yield the data center of each host in *hosts*, in order.

    ``None`` is yielded for a host that maps to no known data center. If a
    lookup raises (resolution error, failure loading the DC networks, ...), the
    error is logged, a single ``None`` is yielded and iteration stops. Consumers
    such as ``get_location_tags`` then see "DC unknown" instead of a silently
    truncated list of only the hosts resolved before the failure.
    """
    for host in hosts:
        try:
            data_center = get_data_center(host)
        except Exception:
            logger.info('Failed to get DC, assuming no DC.', exc_info=True)
            # Stop here: the usual cause (DNS / racktables unavailable) would
            # most likely fail again for every remaining host.
            yield None
            return
        yield data_center


def get_data_center(host: str) -> Optional[str]:
    """Return the known data center whose networks contain *host*, else ``None``.

    *host* is resolved with ``socket.getaddrinfo`` and only the first returned
    address is used. Data centers are tried in ``DATA_CENTER_NETS`` iteration
    order and the first one with a network containing the address wins.
    Resolution errors and errors raised while loading ``DATA_CENTER_NETS``
    propagate to the caller.
    """
    first_address = socket.getaddrinfo(host, None)[0][4][0]
    host_rule = YaIP(first_address).as_rule()

    matching_dcs = (
        dc
        for dc, nets in DATA_CENTER_NETS.items()
        if any(host_rule.is_inside(net.as_rule()) for net in nets)
    )
    return next(matching_dcs, None)


def get_location_tags(data_centers: Iterable[Optional[str]]) -> Iterable[str]:
    """Yield ``a_geo_<dc>`` and ``a_dc_<dc>`` tags for the given data centers.

    If any data center is unknown (``None`` or empty), nothing is yielded at
    all, so the result is no tags rather than a partial set. Each data center
    (compared case-insensitively) is tagged once, in first-seen order, even if
    several failed hosts live in it.
    """
    all_dcs = list(data_centers)
    if not all(all_dcs):
        return

    for lowered_dc in dict.fromkeys(dc.lower() for dc in all_dcs):
        yield 'a_geo_' + lowered_dc
        yield 'a_dc_' + lowered_dc


def check_fw_rules_impl(skip_holes: List, show_all: bool) -> Dict:
    """Run all firewall-hole checks and summarize them.

    :param skip_holes: system names to leave out (passed to ``collect_fw_rules``).
    :param show_all: if true, return the raw ``collect_fw_rules`` result
        (every checked hole together with its ``state``).
    :returns: when ``show_all`` is false, a dict containing only systems with
        at least one failed hole (``state == 0``), each as
        ``{'failed': n, 'total': m, 'details': [{'host', 'port'}, ...]}``. When
        anything failed, a ``'tags'`` key is added with the location tags of
        the failed hosts. An empty dict means every hole is reachable. Note
        that ``'state'`` is popped from every collected entry.
    """
    rules_result = collect_fw_rules(skip_holes)
    if show_all:
        return rules_result

    lost_holes = {}
    failed_hosts = []

    for system_name, holes_state in rules_result.items():
        system_errors = []
        for hole_state in holes_state:
            # In the filtered report every listed detail is a failure, so the
            # 'state' key is dropped from all entries.
            if hole_state.pop('state') == 0:
                system_errors.append(hole_state)

        if not system_errors:
            continue

        failed_hosts.extend(error_info['host'] for error_info in system_errors)
        lost_holes[system_name] = {
            'failed': len(system_errors),
            'total': len(holes_state),
            'details': system_errors,
        }

    if failed_hosts:
        # The same host can back several systems (or several DB aliases);
        # resolve each one only once, keeping first-seen order.
        unique_hosts = list(dict.fromkeys(failed_hosts))
        lost_holes['tags'] = list(get_location_tags(get_data_centers_by_hosts(unique_hosts)))

    return lost_holes


def check_fw_rules(request):
    """HTTP view reporting firewall holes that are currently unreachable.

    Query parameters:
        skip     -- comma-separated system names to leave out; matched
                    case-insensitively with surrounding whitespace ignored,
                    e.g. ``?skip=ldap, crm``.
        show_all -- when true, return every checked hole instead of only the
                    failed ones. ``''``, ``0``, ``false``, ``no`` and ``off``
                    (case-insensitive) mean false; any other value means true.

    Returns a ``JsonResponse`` with the result of ``check_fw_rules_impl``.
    """
    skip = []
    for raw_name in request.GET.get('skip', '').split(','):
        name = raw_name.strip()
        if name:
            skip.append(name.upper())

    # bool() of the raw string would make "?show_all=0" / "?show_all=false" true.
    show_all_value = request.GET.get('show_all', '')
    show_all = show_all_value.strip().lower() not in ('', '0', 'false', 'no', 'off')

    holes = check_fw_rules_impl(skip, show_all)

    if not show_all:
        # 'tags' is location metadata added to the report, not a lost hole.
        lost_systems = [name for name in holes if name != 'tags']
        if lost_systems:
            logger.info('Lost holes: %s', lost_systems)

    return JsonResponse(data=holes)
