"""DNS record fixer."""

import logging
import random
import re
from ipaddress import ip_address

import gevent
from apscheduler.schedulers.gevent import GeventScheduler as Scheduler
from gevent.event import Event
from gevent.pool import Pool
from mongoengine import Q, DoesNotExist

from sepelib.core import config, constants
from sepelib.core.exceptions import Error
from sepelib.yandex.dns_api import DnsApiError
from walle import audit_log, constants as walle_constants, network, restrictions
from walle.constants import NETWORK_SOURCE_LLDP, MAC_SOURCE_AGENT, NETWORK_SOURCE_RACKTABLES, MAC_SOURCE_RACKTABLES
from walle.dns.dns_lib import get_operations_for_dns_records
from walle.errors import InvalidHostConfiguration
from walle.expert.automation import (
    GLOBAL_DNS_AUTOMATION,
    PROJECT_DNS_AUTOMATION,
    AutomationDisabledError,
    dns_automation,
)
from walle.expert.decision import Decision
from walle.expert.types import WalleAction
from walle.host_network import HostNetwork
from walle.hosts import Host, HostState, HostStatus, HostMessage
from walle.locks import HostInterruptableLock
from walle.models import timestamp
from walle.projects import Project
from walle.statbox.contexts import host_context
from walle.statbox.loggers import dns_logger
from walle.stats import stats_manager as stats, IntegerLinearHistogram
from walle.util import mongo
from walle.util.gevent_tools import gevent_idle_iter
from walle.util.misc import add_interval_job, StopWatch
from walle.util.mongo import SECONDARY_LOCAL_DC_PREFERRED

log = logging.getLogger(__name__)

# Hosts whose DNS state (dns.switch and dns.mac) is known are re-selected for a check this often
# (see the dns.check_time condition in _get_host_query).
_CHECK_PERIOD = 10 * constants.MINUTE_SECONDS
"""Period with which we check for host configuration changes."""

# Hosts with dns.error_time set (a failed fix attempt) are not selected again until this much time has passed.
_ERROR_CHECK_PERIOD = _CHECK_PERIOD
"""Time for which we delay the next check after an error."""

# Even when nothing seems changed, records are re-verified this often, in case DNS was edited manually.
_FULL_CHECK_PERIOD = constants.DAY_SECONDS
"""
Period with which we process the full check: current host's configuration is compared to current DNS records for the
host.
"""

# Applies to hosts without dns.switch/dns.mac, e.g. after their configuration was found invalid.
_INVALID_CONFIGURATION_CHECK_PERIOD = constants.HOUR_SECONDS
"""Period with which we check misconfigured hosts."""

# Network info (network_timestamp) and the active MAC (active_mac_time) older than this are not used for fixing.
_MIN_ACTUALIZATION_TIME = constants.DAY_SECONDS + constants.HOUR_SECONDS
"""Minimum network map data actualization time to use it in DNS fixing."""


class DnsFixerTimeout(Error):
    """Raised when fixing DNS records of a single host takes longer than ``timeout`` seconds."""

    def __init__(self, timeout):
        message_template = "DNS fix took too long ({} sec)"
        super().__init__(message_template, timeout)


class IPsMismatch(Error):
    """Raised when the IP addresses from the agent differ from the addresses found in DNS records."""

    def __init__(self, ips_from_agent, ips_from_dns):
        dns_ips_text = self.ips_to_str(ips_from_dns)
        agent_ips_text = self.ips_to_str(ips_from_agent)
        super().__init__("IP addrs from DNS ({}) don't match addrs from agent ({})", dns_ips_text, agent_ips_text)

    @staticmethod
    def ips_to_str(ips):
        """Render ``ips`` as a sorted, ", "-separated string of their ``str()`` forms."""
        return ", ".join(sorted(map(str, ips)))


def start(scheduler: Scheduler, partitioner: mongo.MongoPartitionerService):
    """Schedule the DNS fixer to run every minute.

    Nothing is scheduled unless both ``automation.enabled`` and ``dns_fixer.enabled`` are set in the config.
    """
    fixer_enabled = config.get_value("automation.enabled") and config.get_value("dns_fixer.enabled")
    if fixer_enabled:
        add_interval_job(
            scheduler,
            lambda: _DNS_FIXER.run(partitioner),
            name="DNS fixer",
            interval=constants.MINUTE_SECONDS,
        )


def stop():
    """Stop the module-level DNS fixer instance."""
    fixer = _DNS_FIXER
    fixer.stop()


class _HostDnsConfig:
    def __init__(self, project_id, switch, switch_source, mac, mac_source, vlans, vlan_scheme, ips):
        self.project = project_id
        self.switch = switch
        self.switch_source = switch_source
        self.mac = mac
        self.mac_source = mac_source
        self.vlans = vlans
        self.vlan_scheme = vlan_scheme
        self.ips = ips

    def get_update_host_kwargs(self):
        keys = ("switch", "mac", "vlans", "vlan_scheme", "project", "ips")
        return {"set__dns__{}".format(k): getattr(self, k) for k in keys}

    def get_audit_log_kwargs(self):
        keys = ("switch", "switch_source", "mac", "mac_source", "ips")
        return {k: getattr(self, k) for k in keys}


def _get_host_dns_config(host, host_network, project):
    """Collect the data needed to build ``host``'s DNS records into a ``_HostDnsConfig``.

    Network data (switch, active MAC, IPs) is taken from ``host_network``, the VLAN scheme from ``project``.

    :raises InvalidHostConfiguration: when no expected VLANs are known for the host (``vlans`` is None).
    """
    expected_vlans = network.get_host_expected_vlans(host, project).vlans
    if expected_vlans is None:
        raise InvalidHostConfiguration("Host's project is not configured for DNS auto-configuration.")

    return _HostDnsConfig(
        project_id=host.project,
        switch=host_network.network_switch,
        switch_source=host_network.network_source,
        mac=host_network.active_mac,
        mac_source=host_network.active_mac_source,
        vlans=expected_vlans,
        vlan_scheme=project.vlan_scheme,
        ips=host_network.ips,
    )


class DnsFixer:
    """Finds hosts with possibly outdated DNS records and fixes them.

    Hosts are processed shard by shard. The per-host fixes run concurrently in a small gevent pool.
    """

    def __init__(self) -> None:
        # At most 4 hosts are fixed at the same time.
        self.pool = Pool(4)
        # Set by ``run``; remembered so that ``stop`` can stop it too.
        self.partitioner = None
        # Once set, ``run`` is a no-op and the running loops stop picking up new work.
        self.stopped_event = Event()

    def stop(self):
        """Stop the fixer: no new work is started, the partitioner is stopped and in-flight fixes are killed."""
        self.stopped_event.set()
        if self.partitioner:
            self.partitioner.stop()
        self.pool.kill()

    def run(self, partitioner: mongo.MongoPartitionerService, only_l3_search_hosts=True):
        """Check and fix DNS records of hosts in every shard the partitioner hands out.

        Does nothing when global DNS automation is disabled, the fixer is stopped, or no project has DNS
        automation enabled. Errors are logged per shard, so one failing shard does not block the others.

        :param partitioner: shard source; shards for which ``get_shard`` returns a falsy value are skipped.
        :param only_l3_search_hosts: restrict the host query to names within the enabled projects' DNS domains.
        """
        if not GLOBAL_DNS_AUTOMATION.is_enabled() or self.stopped_event.is_set():
            return
        self.partitioner = partitioner
        projects_query = dict(
            vlan_scheme__in=walle_constants.DNS_VLAN_SCHEMES, dns_automation__enabled=True, dns_domain__exists=True
        )

        # Optional whitelist of projects to process.
        fix_dns_projects = config.get_value("dns_fixer.projects", None)
        if fix_dns_projects is not None:
            projects_query.update(id__in=fix_dns_projects)

        enabled_projects = Project.objects(**projects_query).only(*network.DNS_REQUIRED_PROJECT_FIELDS)
        enabled_projects = {p.id: p for p in gevent_idle_iter(enabled_projects)}
        if not enabled_projects:
            return

        all_shards = [str(s) for s in range(config.get_value("dns_fixer.shards_num"))]
        for shard_id in all_shards:
            # ``stop`` may be called while we iterate: don't start new shards after it.
            if self.stopped_event.is_set():
                break

            shard = self.partitioner.get_shard(shard_id)
            if not shard:
                continue
            log.info("Checking hosts' DNS records for #%s shard...", shard)

            try:
                with shard.lock:
                    checked_hosts = self._fix_dns_records(shard, enabled_projects, only_l3_search_hosts)
            except Exception as e:
                # Every placeholder has a matching argument; log.exception also appends the traceback.
                log.exception("Failed to fix DNS records for #%s shard: %s", shard, e)
            else:
                log.info("DNS records has been checked for %s hosts from #%s shard.", checked_hosts, shard)

    def _fix_dns_records(self, shard, enabled_projects, only_l3_search_hosts):
        """Find hosts of ``shard`` with outdated DNS records and fix them in the pool.

        Waits for all spawned fixes to finish before returning.

        :returns: the number of examined hosts. Hosts of projects with DNS automation disabled are not counted.
        """

        free_host_name_template = network.get_free_host_name_template()
        hosts_query = _get_host_query(shard, enabled_projects, only_l3_search_hosts)
        hosts_networks = _get_hosts_to_fix(hosts_query)

        # Use randomization here to avoid the situation when one daemon blocks other daemons when a few daemons process the
        # same shard.

        random.shuffle(hosts_networks)
        checked_hosts = 0

        try:
            for host, host_network in gevent_idle_iter(hosts_networks):
                # Stop spawning new fixes once the fixer is stopped or global DNS automation gets disabled.
                if self.stopped_event.is_set() or not GLOBAL_DNS_AUTOMATION.is_enabled_cached():
                    break

                if not PROJECT_DNS_AUTOMATION.enabled_for_project_cached(host.project):
                    continue

                checked_hosts += 1

                project = enabled_projects[host.project]
                try:
                    dns_config = _get_host_dns_config(host, host_network, project)
                except InvalidHostConfiguration as e:
                    # Pass the error as a lazy logging argument: a '%' in its text must not break formatting.
                    log.error("Host %s: Failed to get dns configuration: %s", host.human_name(), e)
                    continue

                if not _is_outdated_dns_records(host, project, dns_config):
                    _clear_dns_fixer_errors(host)
                    continue

                # A simple workaround for partially freed hosts: when host has been removed from DNS, renamed in BOT, but
                # the operation has failed and we got "almost free" host that hasn't got HostState.FREE state.
                if free_host_name_template.fqdn_matches(host.name):
                    continue

                # The fix re-reads the host with this query, so it's skipped if the host no longer matches.
                host_query = hosts_query & Q(inv=host.inv)
                self.pool.spawn(_try_fix_host_dns_records, host, host_query, enabled_projects)
        finally:
            self.pool.join()

        return checked_hosts


# Module-wide fixer instance: ``start`` schedules its ``run`` and ``stop`` stops it.
_DNS_FIXER = DnsFixer()


def _get_hosts_network_query(uuids):
    """Build a ``HostNetwork`` query for ``uuids`` that only matches network info fresh enough for DNS fixing.

    Requirements:

    * the network data must come from RackTables or LLDP;
    * the active MAC must exist and come from RackTables or the agent;
    * both must have been updated within the last ``_MIN_ACTUALIZATION_TIME`` seconds.
    """
    # A single cutoff, so both freshness checks use the same reference time.
    actual_since = timestamp() - _MIN_ACTUALIZATION_TIME
    return Q(
        network_source__in=[NETWORK_SOURCE_RACKTABLES, NETWORK_SOURCE_LLDP],
        network_timestamp__gt=actual_since,
        active_mac__exists=True,
        active_mac_source__in=[MAC_SOURCE_RACKTABLES, MAC_SOURCE_AGENT],
        active_mac_time__gt=actual_since,
        uuid__in=uuids,
    )


def _get_hosts_to_fix(hosts_query):
    """Return a list of ``(host, host_network)`` pairs for hosts matching ``hosts_query``.

    Hosts without sufficiently fresh network info (see ``_get_hosts_network_query``) are left out.
    Both queries read with ``SECONDARY_LOCAL_DC_PREFERRED`` and fetch only the fields needed for the check.
    """
    host_fields = ("uuid", "inv", "name", "project", "dns", "location", "extra_vlans", "messages", "tier")
    hosts = Host.objects(hosts_query, read_preference=SECONDARY_LOCAL_DC_PREFERRED).only(*host_fields)
    hosts_by_uuid = {host.uuid: host for host in gevent_idle_iter(hosts)}
    if not hosts_by_uuid:
        # No candidates: an empty ``uuid__in`` can't match anything, so skip the HostNetwork round-trip.
        return []

    network_fields = (
        "uuid",
        "active_mac_time",
        "active_mac",
        "active_mac_source",
        "network_source",
        "network_switch",
        "network_timestamp",
        "ips",
    )
    host_networks = HostNetwork.objects(
        _get_hosts_network_query(list(hosts_by_uuid)), read_preference=SECONDARY_LOCAL_DC_PREFERRED
    ).only(*network_fields)
    return [(hosts_by_uuid[host_network.uuid], host_network) for host_network in gevent_idle_iter(host_networks)]


def _get_host_query(shard, enabled_projects, only_l3_search_hosts, inv=None):
    """Build the query selecting hosts of ``shard`` whose DNS records are due for a check.

    Only assigned hosts are selected, and only if they:

    * belong to one of ``enabled_projects``;
    * are not invalid or being renamed;
    * have no automated-DNS restriction.

    A host is due when it has never been checked, or when its last check is older than the period for its
    state. Misconfigured hosts (no ``dns.switch``/``dns.mac``) use ``_INVALID_CONFIGURATION_CHECK_PERIOD``;
    the rest use ``_CHECK_PERIOD``. Hosts whose last fix failed are skipped for ``_ERROR_CHECK_PERIOD``.

    :param inv: optionally restrict the query to a single host by inventory number.
    """
    # One reference time for all cutoffs below, so the windows are mutually consistent.
    now = timestamp()

    query = dict(
        project__in=list(enabled_projects.keys()),
        state__in=HostState.ALL_ASSIGNED,
        status__nin=[HostStatus.INVALID] + HostStatus.ALL_RENAMING,
        restrictions__nin=restrictions.expand_restrictions([restrictions.AUTOMATED_DNS]),
    )

    if inv is not None:
        query.update(inv=inv)

    # TODO: At this time DNS fixing works only for Search project and we have too much *.yandex.ru hosts that
    # raise InvalidHostConfiguration after obtaining the lock. Skip them on query stage to not spend CPU and time on
    # hosts that definitely won't be fixed.
    if only_l3_search_hosts:
        # Sorted so the generated regex (and therefore the query) is the same from run to run.
        domains = sorted({"." + project.dns_domain for project in enabled_projects.values()})
        domains_re = "|".join(map(re.escape, domains))
        query.update(name=re.compile("(?:" + domains_re + ")$"))

    shard_query = mongo.get_host_mongo_shard_query(shard, config.get_value("dns_fixer.shards_num"))

    never_checked = Q(dns__check_time__exists=False)
    misconfigured_due = Q(
        dns__switch__exists=False,
        dns__mac__exists=False,
        dns__check_time__lte=now - _INVALID_CONFIGURATION_CHECK_PERIOD,
    )
    configured_due = Q(dns__switch__exists=True, dns__mac__exists=True, dns__check_time__lte=now - _CHECK_PERIOD)
    not_in_error_backoff = Q(dns__error_time__exists=False) | Q(dns__error_time__lte=now - _ERROR_CHECK_PERIOD)

    return shard_query & Q(**query) & (never_checked | misconfigured_due | configured_due) & not_in_error_backoff


def _is_outdated_dns_records(host, project, dns_config, statbox_logger=None):
    # We don't have any info about host's DNS records
    if (
        host.dns is None
        or host.dns.check_time is None
        or host.dns.switch is None
        or host.dns.mac is None
        or host.dns.project is None
        or host.dns.vlans is None
        or host.dns.vlan_scheme is None
    ):
        if statbox_logger is not None:
            log.warning("%s: DNS records are outdated: there is no info about host's DNS records.", host.human_id())
            statbox_logger.log(decision="no-dns-records-info")

        return True

    # Host's network configuration has been changed
    if host.dns.switch != dns_config.switch or host.dns.mac != dns_config.mac:
        if statbox_logger is not None:
            log.warning(
                "%s: DNS records are outdated: %s/%s has changed to %s/%s.",
                host.human_id(),
                host.dns.switch,
                host.dns.mac,
                dns_config.switch,
                dns_config.mac,
            )
            statbox_logger.log(
                decision="network-changed",
                old_switch=host.dns.switch,
                old_mac=host.dns.mac,
                cur_switch=dns_config.switch,
                cur_mac=dns_config.mac,
            )

        return True

    # Host has switched project
    if host.dns.project != host.project:
        if statbox_logger is not None:
            log.warning(
                "%s: DNS records are outdated: host switched project from %s to %s.",
                host.human_id(),
                host.dns.project,
                host.project,
            )
            statbox_logger.log(decision="project-switched", old_project=host.dns.project, cur_project=host.project)

        return True

    # Project has its VLAN scheme changed. Very rare case but it is still possible.
    if host.dns.vlan_scheme != project.vlan_scheme:
        if statbox_logger is not None:
            log.warning(
                "%s: DNS records are outdated: project's vlan_scheme changed from %s to %s.",
                host.human_id(),
                host.dns.vlan_scheme,
                project.vlan_scheme,
            )
            statbox_logger.log(
                decision="project-vlans-changed", old_vlan_scheme=host.dns.vlans, cur_vlan_scheme=project.vlan_scheme
            )

        return True

    # Project has its VLAN changed. Very rare case but it is still possible.
    if host.dns.vlans != dns_config.vlans:
        if statbox_logger is not None:
            old_vlans = ",".join(str(vlan) for vlan in host.dns.vlans)
            cur_vlans = ",".join(str(vlan) for vlan in dns_config.vlans)
            log.warning(
                "%s: DNS records are outdated: project's VLANs changed from %s to %s.",
                host.human_id(),
                old_vlans,
                cur_vlans,
            )
            statbox_logger.log(decision="host-vlans-changed", old_vlans=old_vlans, cur_vlans=cur_vlans)

        return True

    # Host's IP addrs changed
    if host.dns.ips != dns_config.ips:
        if statbox_logger is not None:
            old_ips = ", ".join(host.dns.ips)
            cur_ips = ",".join(dns_config.ips)
            log.warning(
                "%s: DNS records are outdated: host's IP addrs changed from %s to %s.",
                host.human_id(),
                old_ips,
                cur_ips,
            )
            statbox_logger.log(decision="host-ips-changed", old_ips=old_ips, cur_ips=cur_ips)

        return True

    # All is OK but we should force recheck periodically in case someone has changed DNS manually
    if timestamp() - host.dns.check_time >= _FULL_CHECK_PERIOD:
        if statbox_logger is not None:
            log.warning("%s: Initiate DNS records recheck.", host.human_id())
            statbox_logger.log(decision="time-for-recheck")

        return True

    return False


def _try_fix_host_dns_records(host, host_query, enabled_projects):
    """Fix ``host``'s DNS records under the host lock (the pool job spawned by ``DnsFixer``).

    The host and its network info are re-read under the lock using ``host_query``, because the host may have
    changed since it was selected. If either no longer matches, nothing is done. If the records are up to
    date, only the ``dns_fixer`` message is cleared. Otherwise the fix runs within the
    ``dns_fixer.host_processing_timeout`` limit (raising ``DnsFixerTimeout`` when exceeded).

    On a failed attempt ``dns.error_time`` is set, which keeps the host out of the next selections for
    ``_ERROR_CHECK_PERIOD`` (see ``_get_host_query``). ``AutomationDisabledError`` is silently ignored.
    ``Exception`` subclasses are logged rather than propagated.
    """
    try:
        with HostInterruptableLock(host.uuid, host.tier):
            # Re-read under the lock: the cached host/network info may be stale by now.
            try:
                host = Host.objects(host_query).get()
                host_network = HostNetwork.objects(_get_hosts_network_query([host.uuid])).get()
                project = enabled_projects[host.project]
            except DoesNotExist:
                return

            dns_config = _get_host_dns_config(host, host_network, project)
            statbox_logger = dns_logger(walle_action="fix-dns", **host_context(host))
            if not _is_outdated_dns_records(host, project, dns_config, statbox_logger):
                _clear_dns_fixer_errors(host)
                return

            try:
                fixer_timeout = config.get_value("dns_fixer.host_processing_timeout")
                with gevent.Timeout(fixer_timeout, exception=DnsFixerTimeout(fixer_timeout)):
                    _fix_host_dns_records(host, host_network, project, dns_config, statbox_logger)

            except AutomationDisabledError:
                # NOTE(review): presumably raised when DNS automation gets turned off during the fix; not an error.
                pass
            except Exception:
                # Remember the failure so the host is retried only after _ERROR_CHECK_PERIOD.
                Host.objects(host_query).update(set__dns__error_time=timestamp(), multi=False)
                raise
    except Exception as e:
        log.exception("%s: Failed to check/fix DNS records: %s", host.human_id(), e)


def _fix_host_dns_records(host, host_network, project, dns_config, statbox_logger):
    """Check ``host``'s DNS records and apply whatever DNS operations are needed.

    There are three outcomes, each of which updates the host document and reports stats:

    * the host configuration is invalid: the stored DNS info is reset and an error message is attached;
    * no operations are needed: the check time is refreshed;
    * otherwise: the operations are applied through the host's DNS client inside an audit log entry.

    :raises IPsMismatch: when the host's IPs don't match the IPs in the computed DNS records.
    :raises Error: when the DNS API fails to apply the operations.
    """
    timer = StopWatch()
    log.debug("%s: Checking DNS records...", host.human_id())
    statbox_logger.log(decision="check-records")

    # Guards host updates: only touch the host while it is still the same assigned, valid host.
    modify_query = dict(
        project=host.project,
        state__in=HostState.ALL_ASSIGNED,
        status__ne=HostStatus.INVALID,
        name=host.name,
    )

    try:
        host_records = network.get_host_dns_records(host, project, host.name, dns_config.switch, dns_config.mac)
    except (InvalidHostConfiguration, network.NoNetworkOnSwitch) as config_error:
        _handle_invalid_host_configuration(host, config_error, modify_query, statbox_logger, timer)
        return

    operations = get_operations_for_dns_records(host_records, statbox_logger)
    if not operations:
        _handle_no_dns_operations(host, modify_query, dns_config, statbox_logger, timer)
        return

    log.warning("%s: DNS records are broken. Fixing...", host.human_id())
    statbox_logger.log(decision="fix-records")

    # N.B. no test checks that this call actually happen.
    failure_decision = Decision(WalleAction.FIX_DNS, reason="DNS records are broken.")
    dns_automation(host.project).register_automated_failure(host, failure_decision)

    client = host.get_dns()
    try:
        _assert_ips_match(host.ips, host_records)
    except IPsMismatch as mismatch:
        _report_ips_mismatch_and_reraise(host, mismatch, timer)

    records_dump = [record._asdict() for record in host_records]
    operations_dump = [operation.to_dict() for operation in operations]

    with audit_log.on_fix_dns_records(
        host.project,
        host.inv,
        host.name,
        host.uuid,
        scenario_id=host.scenario_id,
        records=records_dump,
        operations=operations_dump,
        **dns_config.get_audit_log_kwargs()
    ):
        timer.split()

        try:
            client.apply_operations(operations)
        except DnsApiError as api_error:
            _log_dns_api_error_and_reraise(host, api_error, operations, timer)

    _handle_successful_dns_fix(host, dns_config, operations, modify_query, statbox_logger, timer)


def _handle_successful_dns_fix(host, dns_config, dns_operations, host_query, statbox_logger, stopwatch):
    """Persist the host state after ``dns_operations`` were applied, and report the ``fixed`` stats.

    The following updates are made:

    * the configuration the records were built from is saved;
    * ``dns.check_time`` and ``dns.update_time`` are stamped;
    * ``dns.error_time`` is cleared;
    * the ``dns_fixer`` message is removed.
    """
    # One timestamp for both fields: the check and the update are the same event.
    now = timestamp()
    update_kwargs = dict(
        dns_config.get_update_host_kwargs(),
        set__dns__check_time=now,
        set__dns__update_time=now,
        unset__dns__error_time=True,
        **host.set_messages_kwargs(dns_fixer=None)
    )
    host.modify(host_query, **update_kwargs)

    log.info("%s: DNS records has been fixed.", host.human_id())
    statbox_logger.log(decision="records-fixed")

    stats.increment_counter(("dns_fixer", "fixed", "count"))
    stats.add_sample(("dns_fixer", "fixed", "operations_number"), len(dns_operations), IntegerLinearHistogram)
    stats.add_sample(("dns_fixer", "fixed", "processing_time"), stopwatch.reset())


def _clear_dns_fixer_errors(host):
    if "dns_fixer" in host.messages:
        host.set_messages(dns_fixer=None)


def _handle_invalid_host_configuration(host, e, host_query, statbox_logger, stopwatch):
    """Record that ``host``'s DNS records couldn't be checked because of its configuration error ``e``.

    The following updates are made:

    * the stored DNS info is dropped, so the host is later selected with the invalid-configuration period;
    * the check time is refreshed;
    * the error time is cleared;
    * ``e`` is attached as a ``dns_fixer`` error message.
    """
    log.warning("%s: Failed to check DNS records: %s", host.human_id(), e)
    statbox_logger.log(decision="invalid-host-configuration", reason=str(e))

    dropped_fields = ("switch", "mac", "project", "vlans", "vlan_scheme", "ips")
    reset_kwargs = {"unset__dns__" + field: True for field in dropped_fields}
    error_message = HostMessage.error("Failed to check DNS records: {}.".format(e))
    host.modify(
        host_query,
        **dict(
            reset_kwargs,
            set__dns__check_time=timestamp(),
            unset__dns__error_time=True,
            **host.set_messages_kwargs(dns_fixer=[error_message])
        )
    )

    metric = ("dns_fixer", "invalid")
    stats.increment_counter(metric + ("count",))
    stats.add_sample(metric + ("processing_time",), stopwatch.reset())


def _handle_no_dns_operations(host, host_query, dns_config, statbox_logger, stopwatch):
    """Record that ``host``'s DNS records are already correct.

    The following updates are made:

    * the configuration the records were checked against is saved;
    * the check time is refreshed;
    * the error time is cleared;
    * the ``dns_fixer`` message is removed.

    The ``records-ok`` stats are then reported.
    """
    log.debug("%s: DNS records are OK.", host.human_id())
    statbox_logger.log(decision="records-ok")

    fresh_state = dict(
        dns_config.get_update_host_kwargs(),
        set__dns__check_time=timestamp(),
        unset__dns__error_time=True,
        **host.set_messages_kwargs(dns_fixer=None)
    )
    host.modify(host_query, **fresh_state)

    metric = ("dns_fixer", "records-ok")
    stats.increment_counter(metric + ("count",))
    stats.add_sample(metric + ("processing_time",), stopwatch.reset())


def _log_dns_api_error_and_reraise(host, e, dns_operations, stopwatch):
    """Report a failed DNS API call for ``host`` and raise an ``Error`` describing it.

    This function does the following before raising:

    * records the ``failed`` stats;
    * logs the attempted primitives and, when available, the API response body;
    * attaches a ``dns_fixer`` error message to the host.

    :raises Error: always, chained to ``e``.
    """
    stats.add_sample(("dns_fixer", "failed", "processing_time"), stopwatch.split())
    stats.increment_counter(("dns_fixer", "failed", "count"))
    log.debug(
        "%s: DNS API primitives failed to apply: %s",
        host.human_id(),
        repr([op.to_slayer_api_primitive() for op in dns_operations]),
    )
    # Compare with None explicitly. NOTE(review): ``e.response`` presumably is a ``requests.Response`` (it has
    # ``.content``). Its truth value is ``response.ok``, so a 4xx/5xx response would be falsy and its body would
    # never be logged.
    if e.response is not None:
        log.debug("%s: DNS API returned an error: %s", host.human_id(), repr(e.response.content))

    host.set_messages(dns_fixer=[HostMessage.error("DNS API error: {}".format(e))])
    raise Error("DNS API error: {}", e) from e


def _report_ips_mismatch_and_reraise(host, e, stopwatch):
    """Report an IP mismatch ``e`` for ``host`` and raise it again.

    Records the ``failed`` processing time and the ``ips_mismatch`` counter, logs the error and attaches it to
    the host as a ``dns_fixer`` error message.
    """
    stats.add_sample(("dns_fixer", "failed", "processing_time"), stopwatch.split())
    stats.increment_counter(("dns_fixer", "ips_mismatch", "count"))

    host_name = host.human_name()
    log.error("Cannot fix DNS records of %s: %s", host_name, e)
    host.set_messages(dns_fixer=[HostMessage.error("Cannot fix DNS records of {}: {}".format(host_name, e))])
    raise e


def _assert_ips_match(ips_from_agent, dns_records):
    if not ips_from_agent:
        return
    ips_from_agent = {ip_address(addr) for addr in ips_from_agent}

    ips_from_dns = set()
    for record in dns_records:
        for val in record.value:
            ips_from_dns.add(ip_address(val))

    if ips_from_agent != ips_from_dns:
        raise IPsMismatch(ips_from_agent, ips_from_dns)
