"""Hosts"""  # Will be used as category name in API reference

import ipaddress
import logging
from collections import namedtuple

import mongoengine
from cachetools.func import ttl_cache

import object_validator
import walle.host_macs
import walle.hosts
from object_validator import Dict, String, Bool
from sepelib.core import config, constants
from sepelib.core.exceptions import Error
from walle import audit_log, constants as walle_constants
from walle.authorization import iam
from walle.clients import racktables
from walle.clients.network.racktables_client import RacktablesClient
from walle.errors import RequestValidationError
from walle.host_network import HostNetwork
from walle.hosts import Host, HostStatus, HostMessage
from walle.models import timestamp
from walle.util import db_cache, net
from walle.util.api import api_handler, api_response
from walle.util.misc import drop_none, fix_mongo_set_kwargs
from walle.util.mongo import SECONDARY_LOCAL_DC_PREFERRED

log = logging.getLogger(__name__)


# Seconds. Reports are often slow during deploys; even then, ~94% of them finish within 5 seconds.
# NOTE(review): not referenced anywhere in this file — confirm whether it is still used elsewhere.
_REPORT_PROCESSING_TIMEOUT = 5


class ReportError(Error):
    """Base class for errors that stop report processing; the message is sent back to the agent."""


class HostAuthenticationFailed(ReportError):
    """Raised when none of the reported active MACs belongs to the host's known MACs."""

    def __init__(self, hostname):
        message = "Reject report: host {} authentication failed."
        super(HostAuthenticationFailed, self).__init__(message, hostname)


class UnknownHost(ReportError):
    """Raised when the reporting host name does not resolve to a single registered (non-INVALID) host."""

    def __init__(self, hostname):
        message = "Ignore report: host {} is not registered in Wall-E."
        super(UnknownHost, self).__init__(message, hostname)


@api_handler(
    "/hosts/<host_name>/agent-report",
    "PUT",
    {
        "properties": {
            "macs": {"description": "Host's MACs (mac -> active)"},
            "ips": {"type": "array", "items": {"type": "string"}, "description": "Host's IP addresses"},
            "switches": {
                "type": "array",
                "items": {
                    "properties": {
                        "switch": {"type": "string", "description": "Switch name"},
                        "port": {"type": "string", "description": "Port name"},
                        "time": {"type": "integer", "description": "Actualization time"},
                    },
                    "required": ["switch", "port"],
                    "additionalProperties": False,
                },
                "description": "Host's switches",
            },
            "errors": {"type": "array", "items": {"type": "string"}, "description": "A list of errors"},
            "version": {"type": "string", "description": "Agent version"},
        },
        "required": ["version"],
        "additionalProperties": False,
    },
    include_to_api_reference=False,
    # NOTE(rocco66): waiting for IAM in juggler and walle agent, https://st.yandex-team.ru/WALLE-4385
    # iam_permissions=iam.AgentApiIamPermission(),
    iam_permissions=iam.AnyoneApiIamPermission(),
)
def agent_report(host_name, request):
    """Host info reporting API used by Wall-E agent.

    The report is run through ``_process_request``. A ``ReportError`` raised while processing is not
    turned into an HTTP error; its message is returned in the ``result`` field instead of the
    success message.
    """

    try:
        _process_request(host_name, InputData(request))
    except ReportError as e:
        result = str(e)
    else:
        result = "The report has been processed."
    return _make_response(host_name, result)


def _process_request(host_name, request):
    """Run an agent report through the chain of processors for ``host_name``.

    Agent-reported errors are always logged first. If the name resolves to exactly one non-INVALID
    host, the reported MACs are saved, the report is authenticated, and then it is used to update
    the active MAC, IPs, switch/port, agent version and agent error messages. Otherwise the
    reported MACs are saved for the name and the report is rejected with ``UnknownHost``.

    :type host_name: str
    :type request: InputData
    :raises ReportError: when a processor rejects or ignores the report.
    """
    request_processors = [ErrorReporter(host_name)]

    # ``.only()`` leaves every other field unloaded, so this list must include each Host attribute
    # the processors below read. ``agent_version`` is compared by AgentVersionSaver: without it the
    # stored version reads as the field default and the host is rewritten on every report.
    fields = ("uuid", "inv", "name", "project", "macs", "location", "ips", "agent_version")
    try:
        host = (
            Host.objects(name=host_name, status__ne=HostStatus.INVALID, read_preference=SECONDARY_LOCAL_DC_PREFERRED)
            .only(*fields)
            .get()
        )
        host_network = HostNetwork.get_or_create(host.uuid)

    # ``.get()`` raises MultipleObjectsReturned (not NotUniqueError) when several hosts share the name.
    # NotUniqueError is a write-time error — presumably from HostNetwork.get_or_create — and is
    # treated the same way.
    except (mongoengine.DoesNotExist, mongoengine.MultipleObjectsReturned, mongoengine.NotUniqueError):
        request_processors.extend(
            [
                UnknownHostMacsSaver(host_name),
                UnknownHostAuthenticator(host_name),
            ]
        )
    else:
        request_processors.extend(
            [
                HostMacsSaver(host.name),
                HostAuthenticator(host),
                ActiveMacSaver(host, host_network),
                IPsSaver(host, host_network),
                SwitchesSaver(host, host_network),
                AgentVersionSaver(host),
                ErrorSaver(host),
            ]
        )

    for processor in request_processors:
        processor.process(request)


def _make_response(host_name, message):
    """Build the API response: ``result`` plus, when non-empty, the development stands serving the host."""
    stands = _get_host_development_stands(host_name)
    payload = {"result": message, "other_stands": stands if stands else None}
    return api_response(drop_none(payload))


def _get_host_development_stands(host_name):
    """Return API URLs of the development stands whose cached host list contains ``host_name``."""
    matching_urls = []
    for stand_url, hosts in _get_development_stands_hosts().items():
        if host_name in hosts:
            matching_urls.append(stand_url)
    return matching_urls


@ttl_cache(maxsize=1, ttl=constants.MINUTE_SECONDS)
def _get_development_stands_hosts():
    """Map each configured development stand's URL to the set of host names cached for it.

    Host lists come from ``db_cache`` under ``stand_hosts:<stand name>``. Stands without a cached
    list are skipped. Lists cached an hour or more ago are still used, but an error is logged.
    The result is memoized by ``ttl_cache``.
    """
    result = {}
    stands = config.get_value("development_stands")

    for name, settings in stands.items():
        cached_at, hosts = db_cache.get_value("stand_hosts:" + name)
        if cached_at is None:
            log.error("Host list for '%s' stand is not cached.", name)
            continue

        if timestamp() - cached_at >= constants.HOUR_SECONDS:
            log.error("Using an outdated list of hosts for '%s' stand.", name)

        result[settings["url"]] = set(hosts)

    return result


class InputData:
    """Validated, normalized view of an agent report request body.

    The constructor normalizes ``macs`` and validates ``ips``, then stamps the report with the time
    it was received. Processors read the report through the properties below.
    """

    SwitchInfo = namedtuple("SwitchInfo", ["switch", "port", "time"])

    def __init__(self, request):
        """
        :param request: decoded request body (presumably already checked against the handler's JSON
            schema). It is updated in place with normalized values and kept as ``self.request``.
        :raises RequestValidationError: on malformed MACs or IP addresses.
        """
        request["macs"] = self._validate_macs(request.get("macs", {}))
        request["ips"] = self._validate_ips(request.get("ips", []))
        request["timestamp"] = timestamp()
        request["version"] = request.get("version", None)

        self.request = request

    @staticmethod
    def _validate_macs(macs):
        """Validate a ``{mac: active}`` mapping and return it with MACs normalized by ``net.format_mac``."""
        try:
            macs = object_validator.validate("macs", macs, Dict(String(), Bool()))
            return {net.format_mac(mac): active for mac, active in macs.items()}
        except (object_validator.ValidationError, Error) as e:
            raise RequestValidationError("{}", e)

    @staticmethod
    def _validate_ips(ips):
        """Return ``ips`` unchanged if every item parses as an IPv4/IPv6 address."""
        try:
            for ip in ips:
                ipaddress.ip_address(ip)
        except ValueError as e:
            raise RequestValidationError("{}", e)
        return ips

    @property
    def agent_version(self):
        return self.request["version"]

    @property
    def switches(self):
        """Reported switch/port pairs as ``SwitchInfo`` tuples.

        The API schema does not require ``time``. Building the namedtuple straight from the dict
        would raise ``TypeError`` for such reports, so a missing ``time`` falls back to the moment
        the report was received. NOTE(review): assumes both values use the same unit as
        ``walle.models.timestamp()``.
        """
        return [
            self.SwitchInfo(switch=info["switch"], port=info["port"], time=info.get("time", self.timestamp))
            for info in self.request.get("switches", [])
        ]

    @property
    def macs(self):
        """Normalized ``{mac: active}`` mapping."""
        return self.request["macs"]

    @property
    def active_macs(self):
        """MACs the agent reported as active, in report order."""
        return [mac for mac, active in self.macs.items() if active]

    @property
    def ips(self):
        return self.request["ips"]

    @property
    def errors(self):
        return self.request.get("errors", [])

    @property
    def timestamp(self):
        """Time the report was received (set by the constructor)."""
        return self.request["timestamp"]


def mk_record(message, host_name, request, *args):
    """Build positional arguments for ``log.info(...)`` describing a per-host report event.

    ``message`` is embedded into the ``"%s [%s]: <message>."`` logging format string. The host
    name and agent version fill the first two placeholders, and any ``%s`` inside ``message`` is
    filled from ``args``. Literal '%' characters in ``message`` must be escaped as '%%' by the caller.
    """
    template = "%s [%s]: {}.".format(message)
    return _make_tuple(template, host_name, request.agent_version, *args)


def _make_tuple(*args):
    return tuple(args)


class Processor:
    """Base class for one step of agent report processing.

    Subclasses override ``process``. A step may raise a ``ReportError`` subclass to stop processing
    the report.
    """

    def process(self, request):
        """Handle a single agent report.

        :type request: InputData
        """
        raise NotImplementedError


class HostMacsSaver(Processor):
    """Pass every reported MAC (active or not) to ``walle.host_macs.save_macs_info`` for the host name."""

    def __init__(self, host_name):
        self.host_name = host_name

    def process(self, request):
        reported_macs = request.macs
        if not reported_macs:
            return
        walle.host_macs.save_macs_info(self.host_name, reported_macs.keys())


class UnknownHostMacsSaver(HostMacsSaver):
    """MAC saver for unregistered host names: saves only when the name looks like a valid FQDN."""

    def process(self, request):
        if not net.is_valid_fqdn(self.host_name):
            log.info(*mk_record("Got a suspicious hostname", self.host_name, request))
            return
        super().process(request)


class UnknownHostAuthenticator(Processor):
    """Processor for unregistered host names: always rejects the report with ``UnknownHost``."""

    def __init__(self, host_name):
        self.host_name = host_name

    def process(self, request):
        error = UnknownHost(self.host_name)
        raise error


class AgentVersionSaver(Processor):
    """Write the reported agent version to the host document when it differs from the loaded one."""

    def __init__(self, host):
        self.host = host

    def process(self, request):
        reported = request.agent_version
        # An empty or missing version is never saved; the stored one is compared only otherwise.
        if reported and reported != self.host.agent_version:
            update = fix_mongo_set_kwargs(set__agent_version=reported)
            Host.objects(uuid=self.host.uuid).modify(**update)


class ErrorReporter(Processor):
    """Log every error string reported by the agent at INFO level."""

    def __init__(self, host_name):
        self.host_name = host_name

    def process(self, request):
        for reported_error in request.errors:
            # The text is embedded into a %-style logging format string by mk_record, so escape '%'.
            escaped = reported_error.replace("%", "%%")
            log.info(*mk_record(escaped, self.host_name, request))


class ErrorSaver(Processor):
    """Store agent-reported errors as host messages and keep ``walle_agent_errors_flag`` in sync."""

    def __init__(self, host):
        self.host = host

    def process(self, request):
        agent_messages = [HostMessage.error(text) for text in request.errors]
        self.host.set_messages(agent=agent_messages)

        # The flag is set while the agent reports errors and removed from the document otherwise.
        if not agent_messages:
            self.host.modify(unset__walle_agent_errors_flag=True)
        else:
            self.host.modify(set__walle_agent_errors_flag=True)


class HostAuthenticator(Processor):
    """Authenticate a report by the host name + MAC pair.

    The report is accepted only if at least one reported active MAC is among the MACs stored for the
    host. Otherwise ``HostAuthenticationFailed`` is raised.
    """

    def __init__(self, host):
        self.host = host

    def process(self, request):
        reported_active = request.active_macs

        if not reported_active:
            log.info(*mk_record("No active MACs detected", self.host.name, request))

        if not set(reported_active).isdisjoint(self.host.macs or []):
            return

        # The mismatch is only worth logging when both sides actually have MACs to compare.
        if reported_active and self.host.macs:
            log.info(
                *mk_record(
                    "reported invalid active MACs: %s instead of one of %s",
                    self.host.name,
                    request,
                    ", ".join(reported_active),
                    ", ".join(self.host.macs),
                )
            )

        raise HostAuthenticationFailed(self.host.name)


class ActiveMacSaver(Processor):
    """Record the host's active MAC, but only when the agent reports exactly one active MAC."""

    def __init__(self, host, host_network):
        self.host = host
        self.host_network = host_network

    def process(self, request):
        active_macs = request.active_macs

        if not active_macs:
            # An empty list must not be logged as "more than one" active MAC.
            log.info(*mk_record("reported no active MACs", self.host.name, request))
        elif len(active_macs) > 1:
            # Ambiguous report: nothing is saved.
            log.info(
                *mk_record("reported more than one active MAC: %s", self.host.name, request, ", ".join(active_macs))
            )
        else:
            walle.hosts.update_agent_active_mac(self.host, self.host_network, active_macs[0], request.timestamp)


class IPsSaver(Processor):
    """Save the IP addresses reported by the agent to the host and to its HostNetwork document."""

    def __init__(self, host, host_network):
        self.host = host
        self.host_network = host_network

    def process(self, request):
        if not request.ips:
            return
        reported_ips = sorted(request.ips)

        if reported_ips != self.host.ips:
            # Conditional update: matches only while the stored IPs are still the ones loaded earlier.
            Host.objects(uuid=self.host.uuid, ips=self.host.ips).modify(**fix_mongo_set_kwargs(set__ips=reported_ips))
            audit_log.on_ips_changed(self.host, reported_ips).complete()

        # Attempted for every report with IPs, so ips_time is refreshed even when the list is unchanged.
        network_update = fix_mongo_set_kwargs(set__ips=reported_ips, set__ips_time=request.timestamp)
        HostNetwork.objects(uuid=self.host_network.uuid, ips=self.host_network.ips).modify(**network_update)


class SwitchesSaver(Processor):
    """Update the host's network location from the switch/port reported by the agent (LLDP source)."""

    def __init__(self, host, host_network):
        self.host = host
        self.host_network = host_network

    def process(self, request):
        reported = request.switches
        if not reported:
            return

        if len(reported) == 1:
            self._save_switch(request, *reported[0])
        else:
            # Ambiguous: several switch/port pairs were reported, so nothing is saved.
            self._log_many_switches(reported, request)

    def _save_switch(self, request, switch, port, actualization_time):
        """Shorten the switch/port names via Racktables rules and store them as the host's location."""
        try:
            short_switch, short_port = RacktablesClient.shorten_switch_port_name(switch, port, trust_switch_name=True)
        except racktables.InvalidSwitchPortError as error:
            log.info(*mk_record("reported a possibly invalid switch/port: %s", self.host.name, request, str(error)))
            return

        walle.hosts.update_network_location(
            self.host,
            self.host_network,
            short_switch,
            short_port,
            actualization_time,
            walle_constants.NETWORK_SOURCE_LLDP,
        )

    def _log_many_switches(self, switches, request):
        described = ", ".join(item.switch + "/" + item.port for item in switches)
        log.info(*mk_record("reported more than one switch/port: %s", self.host.name, request, described))
