# cython: language_level=3

import logging
import threading

from libc.stdint cimport uint64_t

from infra.skylib.intervals.intervals cimport IntervalController
from infra.skylib.openssh_krl.impl cimport loads as krl_loads, KRL
from infra.skylib.sysutils.time import sleep
from library.auth.key import CryptoKey, Certificate


cpdef enum CertKind:
    # Bit flags combined with bitwise OR into a mask. CAStorage.cert_valid tests
    # the Insecure, Secure and Static bits of its allowed_cas mask.
    Insecure = 0x1
    Secure = 0x2
    Static = 0x4
    SshKey = 0x8

    # Union of every flag above (0x1 | 0x2 | 0x4 | 0x8).
    Any = 0xF


cdef class CAStorage:
    """Trusted SSH CA keys, a key revocation list (KRL) and a serveradmins list.

    The CA sets, the KRL and the serveradmins list are (re)read from files by
    ``load()``; ``update_loop()`` calls it forever, sleeping between attempts.
    ``cert_valid()`` checks a certificate's signing key against the CA sets
    selected by a ``CertKind`` mask and, if one is loaded, against the KRL.
    """
    # Paths of files holding CA public keys, split by trust level.
    cdef readonly list insecure_ca_files
    cdef readonly list secure_ca_files
    # Path of the KRL file, parsed with openssh_krl's loads().
    cdef readonly bytes krl_file
    # Optional path of a passwd-style file; may be None.
    cdef readonly bytes serveradmins_file
    # First ':'-separated field of every non-comment line of serveradmins_file.
    cdef readonly frozenset serveradmins

    cdef object log
    # Fingerprints of statically configured CA keys (Certificate entries are
    # stored as the objects themselves).
    cdef set static_cas
    # Fingerprints of keys loaded from insecure_ca_files / secure_ca_files.
    cdef set insecure_cas
    cdef set secure_cas
    # Last successfully loaded KRL, or None if none has been loaded yet.
    cdef KRL krl
    # Sleep schedule after a failed load() / after a successful one.
    cdef IntervalController fail_interval
    cdef IntervalController success_interval

    def __init__(self, list insecure_ca_files, list secure_ca_files, bytes krl_file, bytes serveradmins_file = None, list static_cas = None, object log = None):
        """Configure file locations and parse the static CAs.

        No files are read here; call ``load()`` (or run ``update_loop()``) to
        populate the file-backed CA sets, the KRL and serveradmins.

        :param insecure_ca_files: paths of files with "insecure" CA keys.
        :param secure_ca_files: paths of files with "secure" CA keys.
        :param krl_file: path of the KRL file.
        :param serveradmins_file: optional path of the serveradmins file.
        :param static_cas: optional list of Certificate objects or key data
            accepted by ``CryptoKey.loads``.
        :param log: logger; defaults to this module's logger.
        """
        self.insecure_ca_files = insecure_ca_files
        self.secure_ca_files = secure_ca_files
        self.krl_file = krl_file
        self.serveradmins_file = serveradmins_file
        self.krl = None
        self.fail_interval = IntervalController(
            initial=60.,
            multiplier=1.2,
            variance=.1,
            maximum=600.,
        )
        self.success_interval = IntervalController(
            initial=600.,
            multiplier=1.,
            variance=.2,
            maximum=600.,
        )

        self.log = log or logging.getLogger(__name__)

        cdef set cas = set()
        static_cas = static_cas or []
        for ca in static_cas:
            if isinstance(ca, Certificate):
                self.log.debug("loaded static key: %s", ca)
                # NOTE(review): stored as the object, but cert_valid looks up
                # signing-key fingerprints; confirm how Certificate entries are
                # meant to match.
                cas.add(ca)
            else:
                for key in CryptoKey.loads(ca, log=self.log):
                    self.log.debug("loaded static key: %s", key)
                    # cert_valid looks up signing-key *fingerprints*, so store the
                    # fingerprint (as load_cas does) rather than the key object;
                    # otherwise static CAs would never match.
                    cas.add(key.fingerprint())
        self.static_cas = cas
        self.insecure_cas = set()
        self.secure_cas = set()
        self.serveradmins = frozenset()

    cdef set load_cas(self, list ca_files):
        """Return the set of key fingerprints found in ``ca_files``.

        Files that cannot be opened or read (EnvironmentError) are logged and
        skipped; other exceptions raised while parsing propagate to the caller.
        """
        cdef set cas = set()

        for filename in ca_files:
            try:
                with open(filename, 'r') as f:
                    for key in CryptoKey.load(f, comment=filename, log=self.log):
                        self.log.debug("loaded key: %s", key)
                        cas.add(key.fingerprint())
            except EnvironmentError as e:
                self.log.warning("failed to load keys from %r: %s", filename, e)

        return cas

    cdef void load_krl(self):
        """Reload the KRL from ``krl_file``.

        On any failure the previously loaded KRL (if any) is kept. Dropping it
        would make cert_valid skip revocation checks altogether, i.e. a broken
        or unreadable KRL file would re-validate revoked certificates.
        """
        cdef KRL krl
        try:
            with open(self.krl_file, 'rb') as f:
                krl = krl_loads(f.read())
        except Exception as e:
            self.log.warning("failed to load KRL from %r: %s", self.krl_file, e)
            if self.krl is not None:
                self.log.warning("keeping previously loaded KRL")
        else:
            self.krl = krl

    cdef void load_serveradmins(self):
        """Reload ``serveradmins`` from ``serveradmins_file``.

        Blank lines and lines starting with '#' are skipped; for every other
        line the part before the first ':' is collected (as bytes). Without a
        configured file the set is empty. On a read error the entries read so
        far are kept and a warning is logged.
        """
        if not self.serveradmins_file:
            self.serveradmins = frozenset()
            return

        serveradmins = set()
        try:
            with open(self.serveradmins_file, 'rb') as f:
                for line in f:
                    line = line.strip()
                    if not line or line.startswith(b'#'):
                        continue
                    # Lines are pwd.struct_passwd-like records; only the first
                    # field is needed.
                    serveradmins.add(line.partition(b':')[0])
        except EnvironmentError as e:
            self.log.warning("failed to load serveradmins from %r: %s", self.serveradmins_file, e)

        self.serveradmins = frozenset(serveradmins)

    def load(self):
        """Reload the KRL, serveradmins and both file-backed CA sets."""
        self.load_krl()
        self.load_serveradmins()
        self.insecure_cas = self.load_cas(self.insecure_ca_files)
        self.secure_cas = self.load_cas(self.secure_ca_files)

    def cert_valid(
        self,
        object cert: Certificate,
        int allowed_cas = CertKind.Any,
    ):
        """Return True if ``cert`` is signed by an allowed CA and not revoked.

        :param cert: certificate whose ``signing_key`` fingerprint is looked up.
        :param allowed_cas: CertKind bit mask; only the Insecure, Secure and
            Static bits are consulted here.
        :returns: False if the signing key is in none of the selected CA sets,
            or if a KRL is loaded and rejects the certificate; True otherwise.
        """
        fp = cert.signing_key.fingerprint()

        found = False
        if (allowed_cas & CertKind.Insecure) and fp in self.insecure_cas:
            found = True
        elif (allowed_cas & CertKind.Secure) and fp in self.secure_cas:
            found = True
        elif (allowed_cas & CertKind.Static) and fp in self.static_cas:
            found = True

        if not found:
            return False

        # Read once into a local so a concurrent load_krl() swap can't change
        # the object between the None check and the call.
        krl = self.krl
        if krl is not None and not krl.cert_valid(cert):
            return False

        return True

    def update_loop(self):
        """Call ``load()`` forever; never returns.

        After a success, sleep per ``success_interval`` and reset
        ``fail_interval``. After an exception, log it and sleep per
        ``fail_interval``.
        """
        while True:
            try:
                self.load()
            except Exception as e:
                sleep_time = self.fail_interval.schedule_next()
                self.log.exception("update failed, will sleep %.2fs before next attempt: %s", sleep_time, e)
                sleep(sleep_time)
            else:
                self.fail_interval.reset()
                sleep_time = self.success_interval.schedule_next()
                self.log.info("CAs and KRL updated, will sleep %.2fs before next attempt", sleep_time)
                sleep(sleep_time)
