#!/usr/bin/env python3

import sys
import os
import re
import ssl
import urllib.request
import urllib.parse
import urllib.error
import socket
import fcntl
import errno
import datetime
import time
import subprocess
import json
import base64
import binascii
import argparse


# ==================================================================================================================


def Log(String, File=sys.stdout):
    """Write one line "<local timestamp with microseconds> - <String>" to File.

    File defaults to the sys.stdout object captured when this function was defined.
    """
    Now = datetime.datetime.now()
    Line = "{} - {}".format(Now.strftime("%Y-%m-%d %H:%M:%S.%f"), String)
    print(Line, file=File)


def Fatal(String, ExitCode=1):
    """Log String to stderr and terminate the process with ExitCode.

    Uses sys.exit() rather than the site-provided exit() helper. That helper exists
    for the interactive shell and is missing when Python runs with -S. Both raise
    SystemExit, so `finally` blocks and callers that catch SystemExit behave as before.
    """
    Log(String, File=sys.stderr)
    sys.exit(ExitCode)


class Exec:
    """Run Command synchronously and capture its result.

    Attributes set on the instance:
        pid    -- child PID (-1 if the child was never started)
        out    -- stdout, stripped and decoded as UTF-8 (includes stderr when ErrToOut)
        err    -- stderr, stripped and decoded as UTF-8 ("" when ErrToOut)
        status -- child exit code (Popen.returncode)

    Args:
        Command: argv list passed to subprocess.Popen (no shell).
        Env: the child's complete environment. It replaces os.environ rather than
            extending it. The default dict is only read here, never mutated.
        ErrToOut: merge stderr into stdout.
        Stdin: text fed to the child's stdin (UTF-8). When falsy, the child inherits
            this process's stdin.

    The child runs in its own process group. On Ctrl-C, SIGINT is sent to that whole
    group. If the child cannot be started, the process exits through Fatal().
    """

    def __init__(self, Command, Env={"LANG": "C"}, ErrToOut=False, Stdin=None):
        self.pid = -1
        self.out = ""
        self.err = ""
        self.status = 1
        try:
            self.p = subprocess.Popen(
                Command,
                env=Env,
                stdin=subprocess.PIPE if Stdin else None,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT if ErrToOut else subprocess.PIPE,
                close_fds=True,
                preexec_fn=lambda: os.setpgid(os.getpid(), os.getpid())
            )
            self.pid = self.p.pid
            try:
                # communicate() drains stdout and stderr concurrently while feeding stdin.
                # Reading one pipe to EOF before the other can deadlock once the child
                # fills the unread pipe's buffer.
                Out, Err = self.p.communicate(Stdin.encode("utf-8") if Stdin else None)
                self.out = Out.strip().decode("utf-8")
                self.err = "" if ErrToOut or Err is None else Err.strip().decode("utf-8")
            except KeyboardInterrupt:
                try:
                    os.killpg(self.pid, 2)  # 2 == SIGINT, sent to the child's whole group
                except ProcessLookupError:
                    pass  # the group already exited; nothing left to interrupt
            self.p.wait()
            self.status = self.p.returncode
        except OSError as e:
            Fatal("Got exception running {}: {}".format(Command, e))


class GetURL:
    """Perform a single URL request and record the outcome on the instance.

    Transport and HTTP failures never raise. They are recorded in these attributes:
        code -- HTTP status code from the response; stays 0 if no HTTP response arrived
        text -- response body decoded as UTF-8 ("" if unavailable)
        err  -- error description as a string ("" on success)

    Args:
        URL: request URL.
        Headers: request headers dict; None means no extra headers.
        Data: request body. A dict is form-urlencoded, a str is ASCII-encoded, and
            bytes are sent as-is.
        Timeout: socket timeout in seconds.
        Verify: when False, TLS certificate and hostname checks are disabled.
        Method: explicit HTTP method; None lets urllib pick GET or POST from Data.
        CACert: CA file trusted in addition to the default store (only when Verify).
        ClientCert: client certificate as a single PEM path, or a (cert, key) tuple.
    """

    def __init__(self, URL, Headers=None, Data=None, Timeout=10, Verify=False, Method=None, CACert=None, ClientCert=None):
        if Headers is None:
            Headers = {}
        # create_default_context() already requires a valid certificate and hostname.
        Ctx = ssl.create_default_context()
        if not Verify:
            Ctx.check_hostname = False
            Ctx.verify_mode = ssl.CERT_NONE
        elif CACert:
            Ctx.load_verify_locations(cafile=CACert)
        if ClientCert:
            if isinstance(ClientCert, tuple):
                Ctx.load_cert_chain(ClientCert[0], keyfile=ClientCert[1])
            elif isinstance(ClientCert, str):
                Ctx.load_cert_chain(ClientCert)
        self.code = 0
        self.err  = ""
        self.text = ""
        try:
            if isinstance(Data, dict):
                Data = urllib.parse.urlencode(Data)
            if isinstance(Data, str):
                Data = Data.encode("ascii")
            Req = urllib.request.Request(URL, data=Data, headers=Headers, method=Method)
            # Closing the response releases the connection immediately. This matters
            # for callers that poll in a loop.
            with urllib.request.urlopen(Req, timeout=Timeout, context=Ctx) as Resp:
                self.text = Resp.read().decode("utf-8")
                self.code = Resp.getcode()
        except urllib.error.HTTPError as e:
            self.code = e.code
            self.err  = str(e.reason)
            # The error body is informational only. Reading or decoding it must not
            # escape this handler.
            try:
                self.text = e.read().decode("utf-8")
            except Exception:
                pass
            try:
                e.close()
            except Exception:
                pass
        except urllib.error.URLError as e:
            self.err = str(e.reason)
        except Exception as e:
            self.err = str(e)


def GetHostname():
    """Return the OS-reported hostname of this machine.

    This script uses it as the node name and the CSR name.
    """
    Name = socket.gethostname()
    return Name


# ==============================================================================================================


def GetFileData(FileName):
    """Return the text content of FileName, or "" if the file does not exist.

    The file is read as UTF-8, which is the encoding PutFileData writes. errno ESRCH
    is also treated as "no data" to keep the original behavior.
    NOTE(review): ESRCH looks unintended for a file read; confirm before removing it.
    Any other OSError is re-raised.
    """
    try:
        with open(FileName, 'r', encoding="utf-8") as File:
            return File.read()
    except OSError as e:
        if e.errno in (errno.ENOENT, errno.ESRCH):
            return ""
        raise


def PutFileData(FileName, Data, Mode=0o644, Atomic=False):
    """Write the str Data to FileName, encoded as UTF-8, replacing any previous content.

    Args:
        FileName: destination path.
        Data: text to write.
        Mode: permission bits for a newly created file (subject to the umask).
        Atomic: write to "<FileName>.tmp", fsync it, then os.replace() it over
            FileName, so readers never see a half-written file.

    errno ESRCH is ignored to keep the original behavior; every other OSError is raised.
    """
    TmpFile = "{}.tmp".format(FileName)
    Target = TmpFile if Atomic else FileName
    try:
        if Atomic:
            # A temp file left by an interrupted run could hold longer content and a
            # different mode. Remove it so the new one is created fresh with Mode.
            try:
                os.unlink(TmpFile)
            except FileNotFoundError:
                pass
        # O_TRUNC: without it, shorter data would leave the old file's tail behind.
        Fd = os.open(Target, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, Mode)
        try:
            Buf = memoryview(Data.encode("utf-8"))
            while Buf:  # os.write may perform a short write
                Buf = Buf[os.write(Fd, Buf):]
            if Atomic:
                os.fsync(Fd)  # content must be on disk before the rename makes it visible
        finally:
            os.close(Fd)
        if Atomic:
            os.replace(TmpFile, FileName)
    except OSError as e:
        if e.errno != errno.ESRCH:
            raise


class LockFile:
    """Exclusive inter-process lock held through fcntl.lockf() on a lock file.

    Lock() creates or opens the file and locks it. UnLock() removes the file and
    releases the lock, but only if this instance actually holds the lock. Calling
    UnLock() again, or on an instance that never acquired the lock, does nothing and
    returns False. This way a process that failed to get the lock, or that already
    released it, can never delete another process's lock file.
    """

    def __init__(self, FileName):
        self.FileName = FileName
        self.File = None  # open file object while the lock is held, otherwise None

    def Lock(self, Block=False):
        """Acquire the lock; wait for it if Block, else fail immediately. Returns True on success."""
        if self.File is not None:
            return True  # already held; reopening and closing would drop the POSIX lock
        while True:
            try:
                File = open(self.FileName, 'a')
            except Exception as e:
                Log("Cannot lock {}: {}".format(self.FileName, e))
                return False
            try:
                fcntl.lockf(File, fcntl.LOCK_EX | (0 if Block else fcntl.LOCK_NB))
                # The previous holder may have unlinked the file between our open() and
                # lockf(). Accept the lock only if the path still names the locked inode.
                try:
                    Current = os.stat(self.FileName)
                except FileNotFoundError:
                    Current = None
                if Current is not None and os.path.samestat(os.fstat(File.fileno()), Current):
                    self.File = File
                    return True
            except Exception as e:
                File.close()
                Log("Cannot lock {}: {}".format(self.FileName, e))
                return False
            File.close()  # stale inode: retry on the current file

    def UnLock(self):
        """Remove the lock file and release the lock. Returns False if not held or the unlink failed."""
        File = getattr(self, "File", None)
        if File is None:
            return False
        self.File = None
        try:
            # Unlink while the lock is still held, so no other process can lock this
            # inode and then lose it to our unlink.
            os.unlink(self.FileName)
        except Exception:
            return False
        finally:
            File.close()
        return True

    def __del__(self):
        self.UnLock()


# ==============================================================================================================


def SslGetModulus(Data, IsCert=False):
    """Return the openssl "Modulus=..." line for a PEM certificate or RSA private key.

    Args:
        Data: PEM text, fed to openssl on stdin.
        IsCert: True for a certificate (`openssl x509`); False for an RSA key (`openssl rsa`).

    Exits through Fatal() if openssl fails. The error message never includes a
    private key: for keys, only the length is reported.
    """
    R = Exec(["/usr/bin/openssl", "x509" if IsCert else "rsa", "-noout", "-modulus"], Stdin=Data)
    if R.status != 0:
        What = repr(Data) if IsCert else "private key (len = {})".format(len(Data or ""))
        Fatal("Failed to get modulus for {} ({}): {}, {}".format(What, R.status, R.out, R.err))
    return R.out


def SslCertExpirationTime(Path):
    """Return the notAfter time of the PEM certificate at Path as a POSIX timestamp (float).

    openssl prints the date in GMT, e.g. "notAfter=Jun  1 12:00:00 2030 GMT".
    It is therefore converted as UTC. time.mktime() would treat it as local time and
    be off by the UTC offset. Exits through Fatal() if openssl fails or the date
    cannot be parsed.
    """
    R = Exec(["/usr/bin/openssl", "x509", "-in", Path, "-noout", "-enddate"])
    if R.status != 0:
        Fatal("Failed to get expiration date for {} ({}): {}, {}".format(Path, R.status, R.out, R.err))
    try:
        EndDate = datetime.datetime.strptime(R.out.split("=", 1)[1], "%b %d %H:%M:%S %Y %Z")
        return EndDate.replace(tzinfo=datetime.timezone.utc).timestamp()
    except (IndexError, ValueError):
        Fatal("Bad time for expiration date format (cert={}): {}".format(Path, R.out))


# Extracts the base64 CA bundle from the kubeconfig YAML stored in the cluster-info
# ConfigMap. It is a raw string, so "\s" is not an invalid escape. The character class
# is the full base64 alphabet, including "+" and "/", so the match is not cut short.
ReGetCAData = re.compile(r"\s+certificate-authority-data:\s*([A-Za-z0-9+/=]+)")


def KuberGetCA(Server):
    """Fetch the cluster CA certificate (PEM text, stripped) from the master's cluster-info ConfigMap.

    Reads kube-public/cluster-info through the API of Server, finds
    certificate-authority-data in its kubeconfig and base64-decodes it. Exits through
    Fatal() on any request or decoding failure.

    NOTE(review): the request uses GetURL's default Verify=False, and the ConfigMap's
    signature is not checked, so the returned CA is unauthenticated (trust on first use).
    """
    URL = "{}/api/v1/namespaces/kube-public/configmaps/cluster-info?timeout=8s".format(Server)
    R = GetURL(URL, Headers={"Accept": "application/json, */*"})
    if R.code != 200:
        Fatal("Failed to get CA data from master: {}".format(R.err))
    try:
        Data = json.loads(R.text)
        M = ReGetCAData.search(Data["data"]["kubeconfig"])
        if not M:
            Fatal("Failed to get CA data from master: {}".format(Data["data"]["kubeconfig"]))
        return base64.b64decode(M.group(1), validate=True).decode("utf-8").strip()
    except (json.decoder.JSONDecodeError, KeyError, TypeError, binascii.Error, UnicodeDecodeError) as e:
        Fatal("Failed to decode CA data ({}): {}".format(e, R.text))


def KuberSignCert(Csr, Hostname, Server, JoinToken, CAFile=None, Timeout=60):
    """Submit a kubelet client CSR named Hostname to the master and return the signed certificate.

    Args:
        Csr: PEM certificate signing request text.
        Hostname: CSR object name (metadata.name).
        Server: API server base URL.
        JoinToken: bearer token used for authentication.
        CAFile: CA to verify the server with; None disables TLS verification.
        Timeout: seconds to wait for the certificate to be issued.

    Returns the issued certificate as stripped PEM text. An existing CSR with the same
    name (HTTP 409) is reused as-is. Exits through Fatal() if the request cannot be
    created, is Denied or Failed, or no certificate appears within Timeout.
    """
    URL = "{}/apis/certificates.k8s.io/v1/certificatesigningrequests?fieldManager=kubectl-create".format(Server)
    Csr64 = base64.b64encode(Csr.strip().encode("utf-8")).decode("utf-8")
    Data = json.dumps({
        "apiVersion": "certificates.k8s.io/v1",
        "kind": "CertificateSigningRequest",
        "metadata": {
            "name": Hostname
        },
        "spec": {
            "groups": ["system:authenticated"],
            "request": Csr64,
            "signerName": "kubernetes.io/kube-apiserver-client-kubelet",
            "usages": ["key encipherment", "digital signature", "client auth"]
        }
    }, separators=(",", ":"))
    Headers = {
        "Accept": "application/json, */*",
        "Content-Type": "application/json",
        "Authorization": "Bearer {}".format(JoinToken)
    }
    R = GetURL(URL, Headers=Headers, Data=Data, Verify=(CAFile is not None), CACert=CAFile, Method="POST")
    if R.code == 201:
        Log("Created certificate request at kuber master {}".format(Server))
        time.sleep(0.5)
    elif R.code == 409:
        # NOTE(review): the existing CSR may have been built from a different key. The
        # caller's modulus check is what catches that case.
        Log("Certificate request {} already exists at kuber master {}".format(Hostname, Server))
    else:
        Fatal("Failed to send certificate request to {} ({}): {}".format(Server, R.code, R.err))

    URL = "{}/apis/certificates.k8s.io/v1/certificatesigningrequests/{}".format(Server, Hostname)
    MaxTime = time.time() + Timeout
    while time.time() < MaxTime:
        R = GetURL(URL, Headers=Headers, Verify=(CAFile is not None), CACert=CAFile)
        if R.code == 200:
            try:
                Data = json.loads(R.text)
                Status = Data.get("status") or {}
                # A Denied or Failed CSR will never be issued; stop waiting right away.
                for Condition in Status.get("conditions") or []:
                    if Condition.get("type") in ("Denied", "Failed"):
                        Fatal("Certificate request {} at {} is {}: {}".format(
                            Hostname, Server, Condition.get("type"), Condition.get("message", "")))
                Cert64 = Status.get("certificate")
                if Cert64:
                    return base64.b64decode(Cert64, validate=True).decode("utf-8").strip()
                Log("Waiting for certificate request {} to be signed by {}".format(Hostname, Server))
            except (json.decoder.JSONDecodeError, KeyError, TypeError, AttributeError,
                    binascii.Error, UnicodeDecodeError) as e:
                Log("Failed to decode kuber master at {} response ({}): {}".format(Server, e, R.text))
        else:
            Log("Failed to get certificate request {} from {} ({}): {}".format(Hostname, Server, R.code, R.err))
        time.sleep(2)
    Fatal("Failed to wait for certificate from the master {}".format(Server))


# ==============================================================================================================


def SetCerts(Server, JoinToken, CertDir, Force=False):
    """Ensure CertDir holds kubelet.key, the cluster ca.crt and a valid kubelet.crt.

    Steps:
      1. Generate a 2048-bit RSA key into kubelet.key if it is missing or Force is set.
      2. Fetch the CA from the master. Replace ca.crt if it differs, which forces a
         new certificate.
      3. Force a new certificate if kubelet.crt is missing, expires within 7 days, or
         does not match the key.
      4. When forced, build a CSR for CN=system:node:<hostname>, O=system:nodes, have
         the master sign it, verify it matches the key, and write kubelet.crt.

    Private key material is never written to the log. Exits through Fatal() on any
    unrecoverable error.
    """
    if not Server:
        Fatal("Bad server: {}".format(Server))
    if not os.path.isdir(CertDir):
        Fatal("No such directory: {}".format(CertDir))
    Hostname = GetHostname()
    CAFile = os.path.join(CertDir, "ca.crt")
    KeyFile = os.path.join(CertDir, "kubelet.key")
    CertFile = os.path.join(CertDir, "kubelet.crt")
    BorderExpirationTime = time.time() + 7*86400

    if not os.path.exists(KeyFile):
        Force = True

    if Force:
        Log("Generating RSA key {}".format(KeyFile))
        R = Exec(["/usr/bin/openssl", "genrsa", "2048"])
        if R.status != 0:
            Fatal("Failed to generate private key {} ({}): {}, {}".format(KeyFile, R.status, R.out, R.err))
        Key = R.out.strip()
        PutFileData(KeyFile, Key, Mode=0o600, Atomic=True)
    else:
        Log("Found RSA key {}".format(KeyFile))
        Key = GetFileData(KeyFile)

    CA = KuberGetCA(Server)
    CALocal = GetFileData(CAFile).strip()
    if CA == CALocal:
        Log("Local CA ({}) == CA from master (len = {})".format(CAFile, len(CA)))
    else:
        Log("Local CA (len = {}) != CA from master (len = {}). Replacing it".format(len(CALocal), len(CA)))
        PutFileData(CAFile, CA, Atomic=True)
        Force = True

    if os.path.exists(CertFile):
        CertTime = SslCertExpirationTime(CertFile)
        Log("Certificate {} will expire in {} days (expiration time = {})".format(
            CertFile,
            int((CertTime - time.time()) / 86400),
            time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(CertTime))
        ))
        if CertTime < BorderExpirationTime:
            Force = True
    else:
        Log("No certificate file at {}".format(CertFile))

    KeyModulus = SslGetModulus(Key, IsCert=False)
    if not os.path.exists(CertFile):
        Force = True
    elif KeyModulus != SslGetModulus(GetFileData(CertFile), IsCert=True):
        Log("Key {} does not correspond to certificate {}".format(KeyFile, CertFile))
        Force = True

    if Force:
        Log("Creating certificate request")
        # No "-days": without -x509, openssl req produces only a CSR and ignores it.
        # The certificate lifetime is chosen by the signer.
        R = Exec([
            "/usr/bin/openssl", "req", "-new",
            "-key", KeyFile,
            "-subj", "/CN=system:node:{}/O=system:nodes".format(Hostname)
        ])
        if R.status != 0:
            Fatal("Failed to generate csr, key file={} ({}): {}, {}".format(KeyFile, R.status, R.out, R.err))
        Cert = KuberSignCert(R.out, Hostname, Server, JoinToken, CAFile=CAFile)
        if KeyModulus != SslGetModulus(Cert, IsCert=True):
            # Name the key file only; never log the private key itself.
            Fatal("Got certificate from kuber master {} which does not correspond to my private key {}: {}".format(
                Server,
                KeyFile,
                repr(Cert)
            ))
        Log("Writing certificate to {}".format(CertFile))
        PutFileData(CertFile, Cert, Atomic=True)


def WaitClusterJoin(Server, JoinToken, CertDir, Timeout=600):
    """Poll the master until this host is registered as a Node, for at most Timeout seconds.

    Authentication:
      * With JoinToken: bearer token. The server is verified against CertDir/ca.crt if
        that file exists; otherwise TLS verification is off.
      * Without a token: client certificate kubelet.crt/kubelet.key from CertDir,
        verified against ca.crt. All three files must exist.

    Returns True once GET /api/v1/nodes/<hostname> answers 200. Returns False on
    timeout or when the needed credentials are missing. Every unsuccessful attempt,
    including connection errors and non-404 HTTP errors, waits 2 seconds before
    retrying so the API server is not hammered.
    """
    Hostname = GetHostname()
    CAFile = os.path.join(CertDir, "ca.crt")
    KeyFile = os.path.join(CertDir, "kubelet.key")
    CertFile = os.path.join(CertDir, "kubelet.crt")

    URL = "{}/api/v1/nodes/{}?timeout=10s".format(Server, Hostname)
    Headers = {"Accept": "application/json, */*"}
    if JoinToken:
        Headers["Authorization"] = "Bearer {}".format(JoinToken)
        HaveCA = os.path.exists(CAFile)
        Options = {"Verify": HaveCA, "CACert": CAFile if HaveCA else None}
    elif os.path.exists(CAFile) and os.path.exists(KeyFile) and os.path.exists(CertFile):
        Options = {"Verify": True, "CACert": CAFile, "ClientCert": (CertFile, KeyFile)}
    else:
        Log("Failed to find needed certs and keys at {}".format(CertDir))
        return False

    MaxTime = time.time() + Timeout
    while time.time() < MaxTime:
        R = GetURL(URL, Headers=Headers, **Options)
        if R.code == 200:
            Log("Kuber master ({}): {} found!".format(Server, Hostname))
            return True
        if R.code == 404:
            Log("Kuber master ({}): {} not found".format(Server, Hostname))
        else:
            Log("Kuber master ({}): failed to query node {} ({}): {}".format(Server, Hostname, R.code, R.err))
        time.sleep(2)
    Log("Waiting for {} to join the cluster (master = {}) has failed: timed out ({}s)".format(Hostname, Server, Timeout))
    return False


# ==============================================================================================================


def AddProtocol(Server):
    """Return Server unchanged if it starts with http:// or https://; otherwise prefix https://."""
    if not Server.startswith(("http://", "https://")):
        Server = "https://" + Server
    return Server


def Cert(Args):
    """Handle the `cert` subcommand: set up the kubelet key, CA and client certificate.

    A join token is mandatory. It comes from --join-token, or else is read from
    --join-token-file with surrounding whitespace stripped.
    """
    if Args.join_token:
        JoinToken = Args.join_token
    elif Args.join_token_file:
        JoinToken = GetFileData(Args.join_token_file).strip()
    else:
        JoinToken = None

    if not JoinToken:
        Fatal("Please set correct join token or join token file by -j TOKEN or -i FILE")
    SetCerts(AddProtocol(Args.server), JoinToken, Args.cert_dir, Force=Args.force)


def WaitJoin(Args):
    """Handle the `waitjoin` subcommand: wait until this node is registered with the master.

    The join token (-j or -i) is optional. Without one, the client certificate in the
    cert directory is used. If the node does not appear within --timeout seconds, the
    process exits with status 1 through Fatal(), so callers can detect the failure
    from the exit code.
    """
    JoinToken = Args.join_token if Args.join_token else \
        GetFileData(Args.join_token_file).strip() if Args.join_token_file else \
        None
    if not WaitClusterJoin(AddProtocol(Args.server), JoinToken, Args.cert_dir, Args.timeout):
        Fatal("Node has not joined the cluster (master = {})".format(Args.server))


def Main():
    """Parse the command line, take the lock file and run the chosen subcommand.

    Without a subcommand, prints help and exits 1. If the lock is already held,
    exits 0 through Fatal(). The lock is released in `finally`, so it is also
    dropped when the subcommand exits through Fatal()/SystemExit or raises.
    """
    Formatter = lambda prog: argparse.RawDescriptionHelpFormatter(prog, max_help_position=50)
    Parser = argparse.ArgumentParser(formatter_class=Formatter, description="Configure kubelet")
    Parser.add_argument("--lock-file", type=str, default="/var/lock/kube_config.lock", metavar="LOCK", help="lock file path (default=%(default)s)")
    Parser.add_argument("-d", "--cert-dir", type=str, default="/etc/kubernetes/pki", metavar="DIR", help="directory for certificates (default=%(default)s)")
    Parser.add_argument("-s", "--server", type=str, required=True, metavar="SERVER", help="kubernetes master server")
    SubParsers = Parser.add_subparsers(title="Subcommands")

    CertParser = SubParsers.add_parser("cert", formatter_class=Formatter, description="""
    Set and check certificates and keys.
    """)
    CertParser.add_argument("-j", "--join-token", type=str, metavar="TOKEN", help="token to join kubernetes cluster")
    CertParser.add_argument("-i", "--join-token-file", type=str, metavar="FILE", help="read join token from this file")
    CertParser.add_argument("-f", "--force", action="store_true", help="force reset private key and certificate")
    CertParser.set_defaults(func=Cert)

    WaitJoinParser = SubParsers.add_parser("waitjoin", formatter_class=Formatter, description="""
    Wait for the node to join the kubernetes cluster.
    """)
    WaitJoinParser.add_argument("-j", "--join-token", type=str, metavar="TOKEN", help="token to join kubernetes cluster")
    WaitJoinParser.add_argument("-i", "--join-token-file", type=str, metavar="FILE", help="read join token from this file")
    WaitJoinParser.add_argument("-t", "--timeout", type=int, default=120, metavar="SEC", help="time to wait for join (default=%(default)d)")
    WaitJoinParser.set_defaults(func=WaitJoin)

    Args = Parser.parse_args()
    if not hasattr(Args, "func"):
        Parser.print_help()
        sys.exit(1)

    Lock = LockFile(Args.lock_file)
    if not Lock.Lock():
        # Another instance is already running; this is not treated as an error.
        Fatal("Failed to lock", ExitCode=0)
    try:
        Args.func(Args)
    finally:
        Lock.UnLock()


# Script entry point: run the CLI only when executed directly, not when imported.
if __name__ == "__main__":
    Main()
