#!/usr/bin/env python

"""Wall-E agent"""

# Attention: The script must be written to work on both Python 2 and Python 3

from __future__ import unicode_literals
from __future__ import print_function

import argparse
import contextlib
import errno
import fcntl
import functools
import inspect
import json
import logging
import os
import pprint
import signal
import socket
import subprocess
import sys
import threading
import time
import ipaddress

from xml.etree import ElementTree
from pyroute2 import IPRoute

try:
    from agent.constants import *
except ImportError:
    # The constants module is not importable (e.g. the script is run straight from a source checkout):
    # provide fallbacks for the module-level names the rest of this file takes from it.
    version = "development"

    # detect_hostname() consults `not_found_errs` when socket.gethostbyaddr() raises socket.gaierror.
    # Without a fallback that code path dies with NameError when the constants module is missing.
    # NOTE(review): the exact set of codes provided by agent.constants is not visible from here --
    # "name not found"/"no data" is assumed; confirm against the real constants.
    not_found_errs = tuple(
        getattr(socket, name) for name in ("EAI_NONAME", "EAI_NODATA") if hasattr(socket, name)
    )
# True when running under a Python 2 interpreter
PY2 = sys.version_info < (3,)

if PY2:
    # Make `str` refer to the unicode type everywhere in this module, as it does on Python 3
    str = unicode

    # Expose the Python 2 HTTP client modules under the names the Python 3 branch uses
    import urllib2 as urllib
    from httplib import BadStatusLine

    # Wrap the standard streams into UTF-8 encoding writers
    import codecs
    sys.stdout = codecs.getwriter("utf-8")(sys.stdout)
    sys.stderr = codecs.getwriter("utf-8")(sys.stderr)
else:
    import urllib.request as urllib
    from http.client import BadStatusLine

# Module-wide logger; main() attaches a LoggingHandler to it
log = logging.getLogger("agent.py")


class Error(Exception):
    """Base class for the errors raised by the agent.

    The first positional argument is the message. Any further arguments are substituted into it with
    `str.format()`; a message passed alone is used verbatim, so it may contain braces.
    """

    def __init__(self, *args):
        template = args[0]
        params = args[1:]

        if params:
            text = template.format(*params)
        else:
            text = template

        super(Error, self).__init__(text)


class CommandExecutionError(Error):
    """Raised when an external command couldn't be executed or didn't complete successfully."""


class CommandExecutionTimeoutError(CommandExecutionError):
    """Raised when a command has been terminated because its execution timeout has expired."""

    def __init__(self, command):
        message_template = "Command `{}` failed: The execution timeout has exceeded."
        super(CommandExecutionTimeoutError, self).__init__(message_template, command)


class CommandFailedError(CommandExecutionError):
    """Raised when a command has finished with a non-zero return code.

    The return code is available as the `return_code` attribute.
    """

    def __init__(self, command, return_code, error):
        self.return_code = return_code

        message_template = "Command `{}` failed with {} return code: {}."
        super(CommandFailedError, self).__init__(message_template, command, return_code, error)


class LldpDaemonFailure(CommandExecutionError):
    """Raised when lldpctl exits with return code 1."""


class LldpDaemonMalfunction(Error):
    """Raised when lldpctl output contains no usable neighbor information."""


class LoggingHandler(logging.Handler):
    """Logging handler that prints every formatted record to stderr and keeps it in `messages`."""

    def __init__(self):
        # Formatted messages in the order they have been emitted
        self.messages = []
        super(LoggingHandler, self).__init__()

    def emit(self, record):
        """Formats the record, remembers the resulting message and prints it to stderr."""

        self.acquire()
        try:
            try:
                text = self.format(record)
                self.messages.append(text)
                print(text, file=sys.stderr)
            except Exception:
                # Delegate to the standard logging error handling instead of failing the caller
                self.handleError(record)
        finally:
            self.release()


def drop_none(d):
    """Returns a new dict with the items of `d` whose value is not None."""

    filtered = {}

    for key, value in d.items():
        if value is None:
            continue

        filtered[key] = value

    return filtered


def path_exists(path):
    """Checks whether the path exists (a symlink itself counts, its target isn't examined).

    :returns True if lstat() succeeds, False if it fails with ENOENT.
    :raises Error on any other lstat() failure.
    """

    try:
        os.lstat(path)
    except EnvironmentError as e:
        if e.errno != errno.ENOENT:
            raise Error("Failed to lstat() '{}': {}.", path, e)

        return False

    return True


def read_file(path, check_existence=False):
    """Reads the file and returns its contents with surrounding whitespace stripped.

    :returns the stripped contents or None if `check_existence` is set and the file doesn't exist.
    :raises Error if the file can't be opened or read for any other reason.
    """

    try:
        with open(path) as stream:
            contents = stream.read()
    except EnvironmentError as e:
        missing = e.errno == errno.ENOENT
        if missing and check_existence:
            return None

        raise Error("Error while reading '{}': {}.", path, e)

    return contents.strip()


def read_dict(path, check_existence=False):
    """Reads a file consisting of `KEY=VALUE` lines into a dict.

    Each line is split on its first "=" only, so values may contain "=".

    :returns the parsed dict; an empty dict if the file is empty or (with `check_existence` set) missing.
    """

    data = read_file(path, check_existence=check_existence)
    if not data:
        return {}

    pairs = [line.split("=", 1) for line in data.splitlines()]
    return dict(pairs)


def eintr_retry(func):
    """Wraps the function so that it is called again each time it fails with EINTR.

    Any other exception is propagated to the caller.
    """

    @functools.wraps(func)
    def retrying(*args, **kwargs):
        while True:
            try:
                result = func(*args, **kwargs)
            except EnvironmentError as error:
                if error.errno == errno.EINTR:
                    continue

                raise
            else:
                return result

    return retrying


def get_agent_executable_path(follow_symlinks=True):
    """Returns the absolute path of the agent: the executable for frozen builds, this source file otherwise.

    :param follow_symlinks: resolve symlinks in the resulting path
    """

    frozen = getattr(sys, 'frozen', False)  # arcadia binary, py2exe, PyInstaller, cx_Freeze

    if frozen:
        location = os.path.abspath(sys.executable)
    else:
        location = inspect.getabsfile(get_agent_executable_path)

    return os.path.realpath(location) if follow_symlinks else location


def kill_process(pid, sig):
    """Sends the signal to the specified process.

    :returns True if the process doesn't exist (is already terminated), False if the signal has been sent.
    """

    try:
        os.kill(pid, sig)
    except EnvironmentError as e:
        if e.errno != errno.ESRCH:
            raise

        return True

    return False


def terminate_process(pid, kill_after=5):
    """Terminates the process gracefully: SIGTERM is resent every 0.1 seconds, SIGKILL follows after `kill_after`
    seconds.

    :returns True if process has been terminated or False if it is already terminated.
    """

    deadline = time.time() + kill_after
    signalled = False

    # kill_process() returns True as soon as the process is gone
    while not kill_process(pid, signal.SIGTERM):
        signalled = True

        if time.time() >= deadline:
            kill_process(pid, signal.SIGKILL)
            break

        time.sleep(0.1)

    return signalled


def acquire_agent_lock():
    """Takes an exclusive non-blocking flock() on the agent's own executable file.

    :returns the file descriptor that holds the lock.
    :raises Error if the lock is already held or can't be acquired for any other reason.
    """

    lock_fd = -1

    try:
        open_retrying = eintr_retry(os.open)
        lock_fd = open_retrying(get_agent_executable_path(), os.O_RDONLY)

        try:
            flock_retrying = eintr_retry(fcntl.flock)
            flock_retrying(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
        except EnvironmentError as lock_error:
            if lock_error.errno != errno.EWOULDBLOCK:
                raise

            raise Error("the agent is already running")
    except Exception as failure:
        # Don't leak the descriptor if the file has been opened but not locked
        if lock_fd != -1:
            os.close(lock_fd)

        raise Error("Failed to acquire the agent lock: {}.", failure)

    return lock_fd


def release_agent_lock(fd):
    """Releases the agent lock by closing the descriptor returned by acquire_agent_lock()."""

    os.close(fd)


def run_command(cmd, timeout):
    """Runs the command and returns its stdout decoded as UTF-8.

    A helper daemon thread terminates the process if it is still running `timeout` seconds after the start.

    :param cmd: the command as a list of arguments (passed to subprocess.Popen as is)
    :param timeout: execution timeout in seconds
    :raises CommandExecutionError if the process can't be started or communication with it fails
    :raises CommandExecutionTimeoutError if the process has been terminated on timeout
    :raises CommandFailedError if the process finishes with a non-zero return code
    """

    # start_event: the process has been started (or will never be) - the timeout thread may proceed
    # stop_event: the command has finished (or failed to start) - the timeout thread must give up
    # killed_event: the timeout thread has terminated the process
    start_event, stop_event, killed_event = threading.Event(), threading.Event(), threading.Event()

    def terminate_on_timeout():
        try:
            start_event.wait()
            # Returns True if the command has finished before the timeout
            if stop_event.wait(timeout):
                return

            # `process` is the variable assigned by the enclosing function below
            if terminate_process(process.pid):
                killed_event.set()
        except Exception as e:
            log.error("Process timeout thread has crashed: {}.".format(e))

    command = " ".join(cmd)

    timeout_thread = threading.Thread(target=terminate_on_timeout, name="Kill `{}` on timeout".format(command))
    timeout_thread.daemon = True
    timeout_thread.start()

    try:
        process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                   close_fds=True)
        start_event.set()

        try:
            stdout, stderr = process.communicate()
        except:
            # Don't leave the child running if we've been interrupted, whatever the reason is
            terminate_process(process.pid)
            raise
    except Exception as e:
        raise CommandExecutionError("Failed to execute `{}` command: {}.", command, e)
    finally:
        # The order matters: stop_event must be set before start_event, so that the timeout thread, when
        # unblocked after a failed Popen(), returns without touching the unassigned `process` variable.
        stop_event.set()
        start_event.set()
        timeout_thread.join()

    if killed_event.is_set():
        raise CommandExecutionTimeoutError(command)

    stdout = stdout.decode("utf-8")
    stderr = stderr.decode("utf-8")

    if process.returncode:
        # Prefer stderr as the error description, fall back to stdout if stderr is empty
        raise CommandFailedError(command, process.returncode, (stderr or stdout).strip())

    return stdout


def run_lldpctl():
    """Runs `lldpctl -f xml` and returns its output.

    :raises LldpDaemonFailure if lldpctl exits with return code 1; other command errors are propagated as is.
    """

    lldpctl_command = ["lldpctl", "-f", "xml"]

    try:
        output = run_command(lldpctl_command, timeout=20)
    except CommandFailedError as e:
        if e.return_code != 1:
            raise

        raise LldpDaemonFailure(str(e))

    return output


def recover_lldpd(reason):
    """
    Restarts the lldpd service and waits for its socket to appear.

    lldpd is very unreliable:
    * It fails to start if it hasn't deleted the socket on previous termination.
    * It doesn't survive SIGSTOP->SIGCONT.
    * It doesn't survive EMFILE.

    So it's reasonable to try to recover it on lldpctl errors.

    :param reason: description of the detected malfunction (used for logging only)
    :raises Error if the restart fails or the socket doesn't appear during the initialization timeout.
    """

    socket_path = "/var/run/lldpd.socket"
    initialize_timeout = 15

    log.info("lldpd malfunction detected: %s. Restarting lldpd...", reason)

    try:
        # Stopping takes ~1 minute if daemon doesn't respond to SIGTERM
        run_command(["service", "lldpd", "stop"], timeout=90)

        # A stale socket prevents the daemon from starting, so delete it. A failure to delete is not fatal.
        try:
            os.unlink(socket_path)
        except EnvironmentError as e:
            if e.errno != errno.ENOENT:
                # Logger.warn() is a deprecated alias of Logger.warning()
                log.warning("Failed to delete '%s': %s.", socket_path, e)

        run_command(["service", "lldpd", "start"], timeout=60)
    except Exception as e:
        raise Error("lldpd malfunction detected and restart failed: {}".format(e))

    # have to wait for initialization after service started
    for _ in range(initialize_timeout):
        if os.path.exists(socket_path):
            break
        time.sleep(1)
    else:
        raise Error("lldpd failed to initialize during given timeout ({} sec)", initialize_timeout)


def is_container_cgroup():
    """Check cgroups of the agent's own process. We are in a container if any of them is a porto, docker or lxc
    cgroup.

    Each line of /proc/<pid>/cgroup has the `hierarchy-ID:controller-list:cgroup-path` form.

    :returns True if a container cgroup is found, False otherwise (including the case the file can't be read).
    """

    container_markers = ("porto", "docker", "lxc")

    try:
        cgroups = read_file("/proc/{}/cgroup".format(os.getpid()))
    except Error:
        return False

    for line in cgroups.splitlines():
        # Only the first two colons are field separators: the cgroup path itself may contain colons, so an
        # unlimited split would yield extra fields and break the parsing.
        fields = line.split(":", 2)
        if len(fields) != 3:
            # Not a cgroup entry - ignore it instead of failing the whole check
            continue

        path = fields[2]
        if path != "/" and any(marker in path for marker in container_markers):
            return True

    return False


def is_container_environment():
    """Check the `container` environment variable: some lxc wrappers set it to `lxc` to indicate a container
    environment."""

    container_kind = os.getenv("container")
    return container_kind == "lxc"


def is_hypervisor_guest():
    """Check CPU flags of the first CPU listed in /proc/cpuinfo.
    We are in a guest os if the flags line contains `hypervisor`.

    :returns False if the file can't be read or has no flags line.
    """

    try:
        cpuinfo = read_file("/proc/cpuinfo")
    except Error:
        return False

    # Only the first flags line is examined
    flags_line = next((line for line in cpuinfo.splitlines() if line.startswith("flags")), None)
    if flags_line is None:
        return False

    return "hypervisor" in flags_line


def is_guest_os():
    """Tell whether the agent is run in a virtual environment: a porto/docker/lxc container or a guest os of
    a hypervisor. Wall-E.agent shouldn't be run there."""

    # Every check is always performed, regardless of the results of the others
    container_environment = is_container_environment()
    container_cgroup = is_container_cgroup()
    hypervisor_guest = is_hypervisor_guest()

    return container_environment or container_cgroup or hypervisor_guest


def is_physical_device(iface_path):
    """Tells whether the sysfs network interface directory describes a physical Ethernet device.

    :param iface_path: path of the interface directory (like /sys/class/net/<name>)
    """

    device_path = os.path.join(iface_path, "device")

    # Tunnels and other virtual interfaces have no pci devices
    if not path_exists(device_path):
        return False

    # Virtual devices that are actually aliases to real pci devices have a parent device
    if path_exists(os.path.join(device_path, "physfn")):
        return False

    # cdc devices are most probably BMC interfaces, not actual host network
    # see https://st.yandex-team.ru/ITDC-100510 for more
    uevent = read_dict(os.path.join(device_path, "uevent"), check_existence=True)
    if uevent.get("DRIVER") == "cdc_ether":
        return False

    # See http://lxr.free-electrons.com/source/include/uapi/linux/if_arp.h:
    # #define ARPHRD_ETHER  1  /* Ethernet 10Mbps */
    return read_file(os.path.join(iface_path, "type")) == "1"


def rewrite_for_bonding(iface_path, mac, iface_up):
    """Adjusts interface info for bonding slaves.

    If the interface has a `bonding_slave/perm_hwaddr` file with a non-empty value, that address is returned
    instead of `mac`, and the interface is reported as up only if it is up and `mac` equals that address.
    Otherwise the arguments are returned unchanged.

    :returns (mac, iface_up) tuple.
    """

    perm_hwaddr_path = os.path.join(iface_path, "bonding_slave", "perm_hwaddr")
    permanent_mac = read_file(perm_hwaddr_path, check_existence=True)

    if not permanent_mac:
        return mac, iface_up

    return permanent_mac, iface_up and permanent_mac == mac


def get_mac_address(iface_path):
    """Reads MAC address and state of the interface from its sysfs directory.

    :returns (mac, iface_up) tuple adjusted by rewrite_for_bonding(); mac is None if the `address` file doesn't
        exist.
    """

    current_mac = read_file(os.path.join(iface_path, "address"), check_existence=True)
    operstate = read_file(os.path.join(iface_path, "operstate"))

    # See https://www.kernel.org/doc/Documentation/networking/operstates.txt for Linux kernel state definitions
    # We mustn't check for "up" state here, because due to some issues a working interface may be in "unknown"
    # state.
    not_down = operstate != "down"

    return rewrite_for_bonding(iface_path, current_mac, not_down)


def get_macs():
    """Collects MAC addresses of the physical network interfaces of the host.

    :returns a dict that maps each MAC address to True if any interface with this address is up, or None
        (with the error logged) if the addresses can't be determined.
    """

    # See https://www.kernel.org/doc/Documentation/ABI/testing/sysfs-class-net
    net_path = "/sys/class/net"

    try:
        try:
            iface_names = os.listdir(net_path)
        except Exception as e:
            raise Error("Error while reading '{}': {}.", net_path, e)

        macs = {}

        for iface_name in iface_names:
            iface_path = os.path.join(net_path, iface_name)

            # Skip any miscellaneous files as well as tunnels and other virtual interfaces that have no pci
            # devices
            if not (os.path.isdir(iface_path) and is_physical_device(iface_path)):
                continue

            mac, iface_up = get_mac_address(iface_path)

            if mac:
                # Several interfaces may share the same address: it's up if any of them is up
                macs[mac] = iface_up or macs.get(mac, False)

        if not macs:
            raise Error("Unable to find any MAC addresses among the following network interfaces: {}.",
                        ", ".join(iface_names))

        return macs
    except Exception as e:
        log.error("Failed to determine available MAC addresses: %s", e)
        return None


def get_switch_info(try_recover=True):
    """Returns switch/port info of the host.

    :param try_recover: on an lldpctl timeout or lldpd failure/malfunction restart lldpd and retry once
    :returns the value of get_lldp_info(), or None (with the error logged) if it can't be obtained.
    """

    recoverable_errors = (CommandExecutionTimeoutError, LldpDaemonFailure, LldpDaemonMalfunction)

    try:
        try:
            return get_lldp_info()
        except recoverable_errors as e:
            if not try_recover:
                raise

            recover_lldpd(reason=str(e))
            return get_switch_info(try_recover=False)
    except Exception as e:
        log.error("Failed to determine host's switch/port: %s", e)

    return None


def get_lldp_from_server_info(atime, open_func=open):
    """Reads switch/port records from the `lldp` list of /etc/server_info.json.

    :param atime: the value to put into the `time` field of every record
    :param open_func: the function to open the file with (`open()` by default)
    :returns a list of `{"switch": ..., "port": ..., "time": atime}` dicts.
    """

    with open_func('/etc/server_info.json', 'r') as f:
        server_info = json.load(f)

    switches = []
    for record in server_info['lldp']:
        switches.append({"switch": record['switch'], "port": record['port'], "time": atime})

    return switches


def get_lldp_info():
    """Returns the list of switch/port records of the host or None if there are none.

    /etc/server_info.json is tried first when HOSTMAN=1 is set in the environment and the file exists. If it
    gives nothing, lldpctl output is used.
    """

    # lldpd receives LLDP info from switch every 30 seconds. lldpctl retrieves the cached info from lldpd.
    # Cache TTL - 120 seconds. lldpctl doesn't return info actualization time so we may consider it as current time
    # minus cache TTL to be 100% sure.
    #
    # See https://github.com/vincentbernat/lldpd/blob/4e15b7d1cc9a5a5208c87b6341649d889f21d9d1/src/daemon/lldpd.h#L73
    #
    # #define LLDPD_TX_INTERVAL 30
    # #define LLDPD_TX_HOLD     4
    # #define LLDPD_TTL         LLDPD_TX_INTERVAL * LLDPD_TX_HOLD
    lldpd_ttl = 4 * 30
    server_info_ttl = 900   # /etc/server_info.json could be updated once in 15 minutes
    server_info_path = '/etc/server_info.json'

    switches = None

    # try to get lldp info from server_info.json on rtc hosts
    hostman_enabled = os.getenv("HOSTMAN", "0") == "1"
    if hostman_enabled and os.path.exists(server_info_path):
        server_info_time = int(time.time()) - server_info_ttl
        try:
            switches = get_lldp_from_server_info(server_info_time)
        except Exception:
            log.exception("Cannot get lldp info from /etc/server_info.json")

    if not switches:
        lldpd_time = int(time.time()) - lldpd_ttl
        switches = parse_lldpctl_output(run_lldpctl(), lldpd_time)

    return switches if switches else None


def parse_lldpctl_output(output, actualization_time):
    """Parses the XML output of lldpctl into switch/port records.

    Interfaces with a non-LLDP `via`, with a `CloudSvm` chassis description, or with a port id of `mac` or
    unknown type are skipped.

    :param output: lldpctl output in XML format
    :param actualization_time: the value to put into the `time` field of every record
    :returns a list of `{"switch": ..., "port": ..., "time": actualization_time}` dicts without duplicates.
    :raises Error if the document has an unexpected structure.
    :raises LldpDaemonMalfunction if no interface with neighbor info has been found.
    """

    def the_only(elements, error_message):
        # Returns the single element of the list or fails with the message formatted with the actual count
        if len(elements) != 1:
            raise Error(error_message, len(elements))
        return elements[0]

    root = ElementTree.fromstring(output.strip())
    if root.tag != "lldp":
        raise Error("Invalid root element: {}.", root.tag)

    seen_ifaces = set()
    switches = set()

    for iface_element in root.iter("interface"):
        iface_attrs = iface_element.attrib
        iface = iface_attrs.get("name")
        via = iface_attrs.get("via")
        if not iface or not via:
            raise Error("Invalid <interface> element: {}.", iface_attrs)

        if via != "LLDP":
            log.error("lldpctl returned an <interface> with unknown 'via' value: %s.", via)
            continue

        chassis = the_only(iface_element.findall("chassis"), "Invalid number of <chassis> elements: {}.")

        is_cloud_svm = any(
            descr.attrib.get('label') == "SysDescr" and descr.text == "CloudSvm"
            for descr in chassis.findall("descr")
        )
        if is_cloud_svm:
            continue

        if iface in seen_ifaces:
            raise Error("lldpctl returned a duplicated info for {} interface.", iface)
        seen_ifaces.add(iface)

        name_element = the_only(chassis.findall("name"),
                                "Invalid number of <name> elements in <chassis> element: {}.")
        switch = name_element.text
        if not switch:
            raise Error("Got an empty chassis name.")

        port_element = the_only(iface_element.findall("port"), "Invalid number of <port> elements: {}.")
        port_id_element = the_only(port_element.findall("id"),
                                   "Invalid number of <id> elements in <port> element: {}.")

        try:
            port_type = port_id_element.attrib["type"]
        except KeyError:
            raise Error("Invalid <port><id> element: {}.", port_id_element.attrib)

        if port_type == "mac":
            log.debug("ignoring local vm port %s", port_type)
            continue

        if port_type not in ("ifname", "local"):
            log.error("lldpctl returned an <interface> with unknown port type: %s.", port_type)
            continue

        port = port_id_element.text
        if not port:
            raise Error("Got an empty port name.")

        switches.add((switch, port))

    if not seen_ifaces:
        raise LldpDaemonMalfunction("No neighbors reported by lldpd.")

    return [{"switch": switch, "port": port, "time": actualization_time} for switch, port in switches]


def detect_ips():
    """Return the list of global IP addresses assigned to the host's network interfaces that are up.

    Interfaces whose name starts with one of the denied prefixes (loopback, tunnels, bridges and the like) and
    addresses that end with the denied suffix are ignored. The result has no duplicates, its order is
    unspecified.
    """

    DENY_IFACE_NAMES = (
        '--all', 'lo', 'tun', 'tunl', 'L3',
        'lxcbr', 'mc', 'tap', 'vif', 'pflog',
        'virbr', 'plip', 'ip6tnl', 'ip4tnl', 'dummy'
    )
    IFNAME = 'IFLA_IFNAME'
    IFAADDRESS = 'IFA_ADDRESS'
    # NOTE(review): the origin of this suffix isn't visible from here -- presumably it marks some
    # service addresses that must not be reported; confirm.
    DENY_IPADDR_SUFFIX = 'badc:ab1e'

    def is_valid_iface_name(iface_name):
        # str.startswith() accepts a tuple of prefixes
        return not iface_name.startswith(DENY_IFACE_NAMES)

    ips = set()

    ipr = IPRoute()
    try:
        # get indexes of interfaces in 'up' state
        up_ifaces_indexes = []
        for link in ipr.get_links():
            if link['state'] != 'up':
                continue

            if any(attr[0] == IFNAME and is_valid_iface_name(attr[1]) for attr in link['attrs']):
                up_ifaces_indexes.append(link['index'])

        # collect all global addresses from up interfaces
        for iface_index in up_ifaces_indexes:
            for addr_message in ipr.get_addr(index=iface_index):
                for addr_attr in addr_message['attrs']:
                    if addr_attr[0] != IFAADDRESS:
                        continue

                    addr = ipaddress.ip_address(addr_attr[1])
                    addr_str = str(addr)
                    if addr.is_global and not addr_str.endswith(DENY_IPADDR_SUFFIX):
                        ips.add(addr_str)
    finally:
        # Release the netlink socket even if the enumeration fails
        ipr.close()

    return list(ips)


def detect_hostname(ips):
    """Returns the name of the host.

    The result is `socket.getfqdn()` unless reverse resolution of `ips` gives exactly one name, that name
    differs from it and starts with it - then the resolved name is returned.

    :param ips: IP addresses of the host to resolve
    :raises socket.gaierror if the resolution fails with a code that is not in `not_found_errs`
        (a module-level name that isn't defined in this function; presumably provided by agent.constants).
    """

    hostname = socket.getfqdn()
    on_cygwin = sys.platform == 'cygwin'

    resolved_names = set()

    for ip in ips:
        try:
            fqdn = socket.gethostbyaddr(ip)[0]
        except socket.herror:
            continue
        except socket.gaierror as ex:
            if ex.errno in not_found_errs:
                continue
            raise

        if on_cygwin:
            # on cygwin gethostbyaddr can return all DNS prefixes as an address, so we cut it here
            fqdn = fqdn.split()[0]

        resolved_names.add(fqdn)

    if len(resolved_names) == 1:
        # Got only 1 fqdn
        (only_fqdn,) = resolved_names
        if only_fqdn != hostname and only_fqdn.startswith(hostname):
            hostname = only_fqdn

    return hostname



def send_report(api_url, host_name, report):
    """Sends the report to Wall-E with a PUT request to `<api_url>/v1/hosts/<host_name>/agent-report`.

    A reply may list other stands in its `other_stands` field: the report is sent to each of them too, but to
    no more than 5 URLs in total. Failures are logged and don't interrupt the processing of other URLs.

    :param api_url: URL of the Wall-E API to start with
    :param host_name: name of the host the report is about
    :param report: JSON-serializable report data
    """

    max_stands_num = 5

    processed_api_urls = set()
    pending_api_urls = {strip_api_url(api_url)}

    opener = urllib.build_opener(urllib.HTTPHandler)
    report_data = json.dumps(report).encode("utf-8")

    while pending_api_urls:
        if len(processed_api_urls) >= max_stands_num:
            log.error("Got too many stand URLs. Ignore the following: %s.", ", ".join(pending_api_urls))
            break

        api_url = pending_api_urls.pop()
        processed_api_urls.add(api_url)

        url = "{}/v1/hosts/{}/agent-report".format(api_url, host_name)

        headers = {
            "User-Agent": "Wall-E.Agent/" + version,
            "Content-Type": "application/json",
        }
        request = urllib.Request(url, data=report_data, headers=headers)
        request.get_method = lambda: "PUT"

        try:
            with contextlib.closing(opener.open(request, timeout=30)) as response:
                reply = response.read().decode("utf-8")

            try:
                reply = json.loads(reply)

                # The reply must be an object with a string `result` and an optional `other_stands` list of
                # strings
                if not isinstance(reply, dict):
                    raise ValueError

                if not isinstance(reply.get("result"), str):
                    raise ValueError

                if "other_stands" in reply:
                    other_stands = reply["other_stands"]

                    if not isinstance(other_stands, list):
                        raise ValueError

                    if not all(isinstance(stand_url, str) for stand_url in other_stands):
                        raise ValueError
            except ValueError:
                raise Error("The server returned an invalid response")

            print("The report has been successfully sent to {}. Wall-E replied: {}".format(api_url, reply["result"]))

            stand_api_urls = set(strip_api_url(stand_url) for stand_url in reply.get("other_stands", []))
            pending_api_urls.update(stand_api_urls - processed_api_urls)
        except Exception as e:
            reason = e
            if isinstance(e, BadStatusLine):
                reason = "The server has closed connection before sending a valid response or refuses connections"

            log.error("Failed to report host status to %s: %s.", url, reason)


def strip_api_url(url):
    """Returns the URL without trailing slashes."""

    stripped = url.rstrip("/")
    return stripped


def parse_args():
    """Parses command line arguments of the agent.

    With -V/--version prints the version and exits the process.

    :returns the argparse namespace with `wall_e`, `dry_run` and `version` attributes.
    """

    parser = argparse.ArgumentParser(description="Wall-E agent")
    parser.add_argument("-u", "--wall-e", metavar="URL", default="https://api.wall-e.yandex-team.ru/",
                        help="use the specified Wall-E installation")
    parser.add_argument("--dry-run", action="store_true",
                        help="don't send anything to Wall-E - just print the report to stdout")
    parser.add_argument("-V", "--version", action="store_true", help="print the version and exit")

    argv = sys.argv[1:]
    if PY2:
        # Fix UnicodeDecodeError on accidental non-ASCII symbol typing in command line
        argv = [arg.decode("utf-8") for arg in argv]

    args = parser.parse_args(argv)

    if PY2:
        # Decode the values that are still byte strings after parsing
        for field, value in list(vars(args).items()):
            if isinstance(value, bytes):
                setattr(args, field, value.decode("utf-8"))

    if args.version:
        print(version)
        sys.exit(os.EX_OK)

    return args


def main():
    """Entry point of the agent: collects the host report and sends it to Wall-E (or prints it on --dry-run).

    Exits with EX_USAGE if run in a virtual environment and with EX_SOFTWARE on an agent Error.
    """

    args = parse_args()

    if is_guest_os():
        # Exit with error to make host admin know this is not an expected configuration.
        print("Wall-E.agent being run in a virtual environment. Wall-E do not need reports from here.",
              file=sys.stderr)
        sys.exit(os.EX_USAGE)

    # The handler collects all logged errors to attach them to the report
    handler = LoggingHandler()
    handler.setFormatter(logging.Formatter("%(message)s"))

    log.setLevel(logging.ERROR)
    log.addHandler(handler)

    try:
        lock_fd = acquire_agent_lock()

        try:
            ips = detect_ips()
            hostname = detect_hostname(ips)

            macs = get_macs()
            # Don't restart lldpd when nothing is going to be reported
            switches = get_switch_info(try_recover=not args.dry_run)

            report = drop_none({
                "macs": macs,
                "ips": ips,
                "switches": switches,
                "version": version,
            })

            if handler.messages:
                report["errors"] = handler.messages

            if not args.dry_run:
                send_report(args.wall_e, hostname, report)
            else:
                print("API url: {}".format(args.wall_e))
                print("Hostname: {}".format(hostname))
                print("Report data:")
                pprint.pprint(report)
        finally:
            release_agent_lock(lock_fd)
    except Error as e:
        print(str(e), file=sys.stderr)
        sys.exit(os.EX_SOFTWARE)


# Run the agent only when the file is executed as a script, not when it is imported
if __name__ == "__main__":
    main()
