# cython: language_level=3str, language=c++

from libcpp cimport bool, pair
from cpython.ref cimport PyObject
from cpython.exc cimport PyErr_SetFromErrno
from cpython.bytes cimport PyBytes_AS_STRING, PyBytes_AsString
from cpython.version cimport PY_MAJOR_VERSION
from libc.stdint cimport uint64_t, uint32_t
from util.generic.deque cimport TDeque
from util.generic.string cimport TString
import os
import socket
import logging
import threading


# getaddrinfo() error-code -> message lookup. Used by MsgBus.send_ex to build a
# socket.gaierror when none of the candidate addresses could be resolved.
cdef extern from "<netdb.h>":
    cdef const char * gai_strerror(int)


# Raw CPython C-API entry points not covered by the cpython.* .pxd files imported above.
cdef extern from "Python.h":
    # Called once at module import (bottom of this file).
    # NOTE(review): newer CPython versions reportedly deprecate this call; confirm
    # it is still required for the interpreter versions this module targets.
    void PyEval_InitThreads()
    # Used by basestr_to_TString to encode non-bytes input with "utf-8"/"strict".
    object PyCodec_Encode(object o, char *encoding, char *errors)


# Custom C++ -> Python exception translator used as ``except +RaisePyError`` on
# TMsgBus::Send/Receive. Presumably it raises the classes exported through the
# ``public`` TimeoutError/ResolveError pointers defined later in this file
# (TODO confirm in errors.h).
cdef extern from "infra/netlibus/pylib/errors.h":
    void RaisePyError()


cdef extern from "library/cpp/netliba/v12/udp_address.h" namespace "NNetliba_v12":
    # netliba UDP endpoint. Of its fields, MsgBus.send_ex only inspects
    # Interface (Interface == 0 is treated as "address could not be resolved").
    cdef cppclass TUdpAddress nogil:
        TUdpAddress() except +
        bool IsIPv4() except +
        uint32_t GetIpV4() except +
        uint64_t Network
        uint64_t Interface
        int Scope
        int Port

    # Address-family selector for CreateAddress; send_ex always passes UAT_ANY.
    cdef enum EUdpAddressType:
        UAT_ANY,
        UAT_IPV4,
        UAT_IPV6

    # Presumably resolves ``server`` into a TUdpAddress, using ``defaultPort``
    # when the string carries no port (TODO confirm).
    # NOTE(review): declared without ``except +`` -- a C++ exception thrown here
    # would not be translated into a Python exception; confirm it cannot throw.
    cdef TUdpAddress CreateAddress(TString server, int defaultPort, EUdpAddressType type_) nogil


cdef extern from "library/cpp/netliba/v12/udp_host.h" namespace "NNetliba_v12":
    # Free functions (not tied to a TMsgBus instance).
    # Called by set_max_bandwidth_per_ip after validating the argument.
    void SetUdpMaxBandwidthPerIP (float f) nogil except +
    # Called at the start of every MsgBus.__cinit__.
    void DisableIBDetection() nogil except +


cdef extern from "infra/netlibus/msgbus.h":
    # Per-send result callback. ``self`` is the opaque pointer registered via
    # TMsgBus::SetSendResultCallback (MsgBus registers itself).
    ctypedef void(*TSendResultCallback)(void *self, TString host, TString addr, SendResult res)

    # Snapshot returned by TMsgBus::GetTransferState (see MsgBus.get_transfer_state).
    cdef cppclass TransferState "TMsgBus::TransferState" nogil:
        bool HasOut
        size_t In
        size_t Transfers

    # One received message: payload plus two address strings, presumably the
    # sender's address and the local address it arrived on (TODO confirm).
    cdef cppclass RecvMsg "TMsgBus::RecvMsg" nogil:
        TString Data
        TString Address
        TString MyAddress

    # The C++ message bus wrapped by MsgBus. Send/Receive translate C++ errors
    # through RaisePyError; the other methods use Cython's default ``except +``.
    cdef cppclass TMsgBus nogil:
        TMsgBus(int, float) except +
        void Start() except +
        void Stop() except +
        void Send(TString msg, TString addr, TString dest_host) except +RaisePyError
        void Send(TString msg, TDeque[TUdpAddress] addrs, TString dest_host) except +RaisePyError
        RecvMsg Receive(bool block, float timeout) except +RaisePyError
        int GetListenPort() except +
        # NOTE(review): no ``except +`` -- assumed not to throw.
        void SetSendResultCallback(void *obj, TSendResultCallback cb)
        TransferState GetTransferState() except +

    # Outcome passed to TSendResultCallback; ``cpdef`` also exposes the values
    # to Python as OK / RETRIED / FAILED / UNKNOWN.
    cpdef enum SendResult "TMsgBus::ESendResult":
        OK "TMsgBus::Ok"
        RETRIED "TMsgBus::Retried"
        FAILED "TMsgBus::Failed"
        UNKNOWN "TMsgBus::Unknown"


IF UNAME_SYSNAME != "Windows":
    from posix.time cimport CLOCK_MONOTONIC, timespec, clock_gettime

    cpdef public inline double monotime():
        """Return the current CLOCK_MONOTONIC reading in seconds (float)."""
        cdef timespec now
        cdef double whole
        cdef double frac
        clock_gettime(CLOCK_MONOTONIC, &now)
        whole = <double>now.tv_sec
        frac = <double>now.tv_nsec / 1e9
        return whole + frac
ELSE:
    cdef extern from "util/system/datetime.h":
        uint64_t MicroSeconds()

    cpdef public inline double monotime():
        """Return MicroSeconds() converted to seconds (float).

        NOTE(review): confirm MicroSeconds() is monotonic; if it is wall-clock
        time, this fallback is not a monotonic clock.
        """
        cdef double micros = <double>MicroSeconds()
        return micros / 1e6


IF UNAME_SYSNAME == "Linux":
    # C helper, presumably a thin wrapper over setns(2): returns 0 on success
    # and leaves errno set on failure (move_to_ns relies on errno) -- TODO confirm.
    cdef extern from "infra/netlibus/pylib/nstools.h" nogil:
        int setns_wrapper(int fd, int mode)

    # Namespace kinds exported as module-level constants, mapped onto the
    # CLONE_NEW* flags from <sched.h>; pass one of them as ``mode`` to move_to_ns.
    cdef extern from "<sched.h>" nogil:
        cpdef enum:
            # Cgroup "CLONE_NEWCGROUP"  -- NOTE(review): intentionally disabled,
            # possibly because older <sched.h> headers lack it; confirm before enabling.
            Ipc "CLONE_NEWIPC"
            Network "CLONE_NEWNET"
            Mount "CLONE_NEWNS"
            Pid "CLONE_NEWPID"
            User "CLONE_NEWUSER"
            Uts "CLONE_NEWUTS"


    cpdef public move_to_ns(object fileobject, int mode) with gil:
        """Move into the namespace referred to by ``fileobject``.

        :param fileobject: an int file descriptor, or any object with a
            ``fileno()`` method (e.g. an open ``/proc/<pid>/ns/*`` file)
        :param int mode: one of the namespace constants (Ipc, Network, Mount,
            Pid, User, Uts)
        :raises OSError: built from errno when setns_wrapper returns non-zero
        """
        if not isinstance(fileobject, int):
            fileobject = fileobject.fileno()

        ret = setns_wrapper(fileobject, mode)
        if ret != 0:
            PyErr_SetFromErrno(OSError)
ELSE:
    # Sentinels: -1 means "namespaces unsupported on this OS"; MsgBus.__cinit__
    # checks ``Network == -1`` before honouring netns_pid.
    Ipc = -1
    Network = -1
    Mount = -1
    Pid = -1
    User = -1
    Uts = -1

    cpdef public move_to_ns(object fileobject, int mode) with gil:
        """Stub for non-Linux builds: always raises RuntimeError."""
        raise RuntimeError("Namespaces aren't available in your OS")


class TimeoutException(RuntimeError):
    """Timeout error of this module, also reachable as ``MsgBus.Timeout``.

    A borrowed pointer to this class is exported to C++ as ``TimeoutError``;
    presumably RaisePyError raises it when TMsgBus::Receive times out (TODO confirm).
    """


class ResolveException(socket.gaierror):
    """Address-resolution error; a ``socket.gaierror`` subclass.

    A borrowed pointer to this class is exported to C++ as ``ResolveError``;
    presumably RaisePyError raises it from TMsgBus::Send (TODO confirm).
    """


cdef inline object fromvoidptr(void *obj):
    """Recover the Python object behind an opaque ``void*`` (NULL -> None).

    No extra Py_XINCREF is issued on the pointer itself: the caller must make
    sure the object is still alive when this is called.
    """
    if obj == NULL:
        return None
    return <object><PyObject*>obj


cpdef public set_max_bandwidth_per_ip(float bandwidth):
    """Set netliba's per-IP UDP bandwidth limit via SetUdpMaxBandwidthPerIP.

    :param float bandwidth: new limit; must be >= 0 (units are defined by
        netliba -- TODO confirm)
    :raises ValueError: if ``bandwidth`` is negative or NaN
    """
    # Written as "not >=" so that NaN, which compares false against everything,
    # is rejected as well instead of being forwarded to netliba.
    if not bandwidth >= 0.:
        raise ValueError("bandwidth must be >=0")

    SetUdpMaxBandwidthPerIP(bandwidth)


cdef class MsgBus(object):
    """
    Python wrapper around the C++ ``TMsgBus`` message bus.

    :param int port: port to bind on. By default random ephemeral port is used
    :param float timeout: transfer attempt timeout. Default: 15.0
    :param object logger: logger object to use. If not specified, 'ya.skynet.netlibus' will be used
    :type transfer_result_callback: function(str host, str addr, SendResult result)
    :param transfer_result_callback: callback for transfer results. Default: None
    :param netns_pid: if given, create the bus inside the network and UTS
        namespaces of this pid (Linux only; RuntimeError elsewhere). Default: None
    """

    # Owned C++ bus: created in __cinit__ (directly, or on a helper thread for
    # netns_pid) and deleted in __dealloc__.
    cdef TMsgBus* thisptr
    # Statistics keys: in_flight, sent, resent, expired, received.
    cdef public dict counters
    cdef public object log
    # monotime() of the last '[statistics]' log line.
    cdef public double log_time
    cdef public object transfer_result_callback
    # st_ino of /proc/<netns_pid>/ns/net when netns_pid was given, else None.
    cdef public object inode
    # socket.getfqdn(), evaluated inside the bus's UTS namespace when netns_pid is given.
    cdef public object fqdn

    # Alias so callers can refer to the timeout error as MsgBus.Timeout.
    Timeout = TimeoutException

    def __cinit__(self, int port=0, float timeout=15.0, object logger=None, object transfer_result_callback=None, object netns_pid=None):
        self.inode = None
        self.fqdn = None
        with nogil:
            DisableIBDetection()

        if netns_pid is not None:
            if Network == -1:
                raise RuntimeError("Namespaces are not available in this OS")

            # The bus is built on a throw-away thread that enters the target
            # namespaces (presumably setns_wrapper only affects the calling
            # thread -- TODO confirm). Exceptions from that thread are collected
            # here, because threading.Thread would otherwise just print them and
            # leave thisptr NULL.
            errors = []
            with open('/proc/%d/ns/net' % (netns_pid,)) as net_pid_file:
                with open('/proc/%d/ns/uts' % (netns_pid,)) as uts_pid_file:
                    t = threading.Thread(target=self._create_ns_bus, args=(net_pid_file, uts_pid_file, port, timeout, errors))
                    t.start()
                    t.join()
            if errors:
                raise errors[0]
            if self.thisptr == NULL:
                raise RuntimeError("Failed to create message bus in namespaces of pid %s" % (netns_pid,))
        else:
            with nogil:
                self.thisptr = new TMsgBus(port, timeout)
            self.fqdn = socket.getfqdn()

        self.counters = {'in_flight': 0, 'sent': 0, 'resent': 0, 'expired': 0, 'received': 0}

        self.transfer_result_callback = transfer_result_callback
        self.log = logger or logging.getLogger('ya.skynet.netlibus')
        self.log_time = monotime()

        # Register last, once every attribute process_send_result reads exists.
        # The bus keeps only a borrowed pointer to self (see fromvoidptr).
        self.thisptr.SetSendResultCallback(<void*>self, &MsgBus.process_send_result)

    def _create_ns_bus(self, object net_pid_file, object uts_pid_file, int port, float timeout, list errors=None):
        """Enter the net/UTS namespaces of the given files and create the bus there.

        Runs on the helper thread started by __cinit__. When ``errors`` is a
        list, any exception is appended to it instead of propagating, so the
        constructing thread can re-raise it.
        """
        try:
            move_to_ns(net_pid_file, Network)
            move_to_ns(uts_pid_file, Uts)
            self.thisptr = new TMsgBus(port, timeout)
            self.fqdn = socket.getfqdn()
            self.inode = os.fstat(net_pid_file.fileno()).st_ino
        except BaseException as e:
            if errors is None:
                raise
            errors.append(e)

    def __dealloc__(self):
        # thisptr may still be NULL if __cinit__ failed; deleting NULL is a no-op in C++.
        del self.thisptr

    def start(self):
        """Call TMsgBus::Start with the GIL released."""
        with nogil:
            self.thisptr.Start()

    def stop(self):
        """Call TMsgBus::Stop with the GIL released."""
        with nogil:
            self.thisptr.Stop()

    def port(self):
        """Return the port the bus is listening on (TMsgBus::GetListenPort)."""
        return self.thisptr.GetListenPort()

    @staticmethod
    cdef public void process_send_result(void *s, TString host, TString addr, SendResult result) with gil:
        # Send-result callback registered in __cinit__; ``s`` is the MsgBus that
        # registered it. Updates counters, periodically logs statistics, and
        # forwards the result to transfer_result_callback.
        cdef object self = <MsgBus>fromvoidptr(s)

        if result == OK:
            self.counters['in_flight'] -= 1
        elif result == RETRIED:
            self.counters['resent'] += 1
            self.log.debug("delivery to %s (%s) failed, retry to next addr", _TString_to_str(host), _TString_to_str(addr))
        elif result == FAILED:
            self.counters['in_flight'] -= 1
            self.counters['expired'] += 1
            self.log.debug("delivery to %s (%s) failed", _TString_to_str(host), _TString_to_str(addr))

        # At most one statistics line per 30 seconds.
        cur_time = monotime()
        if cur_time - self.log_time > 30:
            self.log.info('[statistics] %s: counters %s', self.port(), self.counters)
            self.log_time = cur_time

        callback = self.transfer_result_callback
        if callback is not None:
            try:
                callback(_TString_to_str(host), _TString_to_str(addr), result)
            except Exception:
                # This is a void C callback: an exception cannot reach any Python
                # caller, so report user-callback failures via the bus logger.
                self.log.exception("transfer_result_callback failed for %s (%s)", _TString_to_str(host), _TString_to_str(addr))

    def send(self, bytes msg not None, object addr not None, object dest=None):
        """Send ``msg`` to ``addr`` through TMsgBus::Send.

        :param bytes msg: payload
        :param addr: destination address, bytes or text (text is UTF-8 encoded)
        :param dest: optional destination host name, bytes or text
        Errors translated by RaisePyError (e.g. resolution failures) propagate.
        """
        cdef TString msg_ = TString(PyBytes_AS_STRING(msg), len(msg))
        cdef TString addr_ = basestr_to_TString(addr)
        cdef TString dest_ = basestr_to_TString(dest)
        # Counted before Send: the result callback may run while the GIL is released.
        self.counters['in_flight'] += 1
        self.counters['sent'] += 1
        try:
            with nogil:
                self.thisptr.Send(msg_, addr_, dest_)
        except BaseException:
            # The message was not accepted, so presumably no result callback will
            # balance these counters; undo them (send_ex likewise does not count
            # messages that fail to resolve).
            self.counters['in_flight'] -= 1
            self.counters['sent'] -= 1
            raise

    def send_ex(self, bytes msg not None, list addrs not None, object dest=None):
        """Send ``msg`` to a list of candidate addresses.

        :param bytes msg: payload
        :param list addrs: ``(host, port)`` pairs; entries that CreateAddress
            cannot resolve (Interface == 0) are skipped. The bus presumably
            moves on to the next candidate on failure (see RETRIED handling).
        :param dest: optional destination host name, bytes or text
        :raises socket.gaierror: with EAI_NONAME if no candidate resolved
        """
        cdef TDeque[TUdpAddress] candidates = TDeque[TUdpAddress]()
        cdef TUdpAddress candidate
        cdef TString msg_ = TString(PyBytes_AS_STRING(msg), len(msg))
        cdef TString dest_ = basestr_to_TString(dest)
        for addr in addrs:
            host = addr[0]
            candidate = CreateAddress(basestr_to_TString(host), addr[1], UAT_ANY)
            if candidate.Interface == 0:
                continue

            candidates.push_back(candidate)
        if candidates.size() == 0:
            # gai_strerror() returns a C string, which Cython turns into bytes;
            # decode it on Python 3 so the error message is text, as the
            # socket module's own gaierror messages are.
            reason = <bytes>gai_strerror(socket.EAI_NONAME)
            if PY_MAJOR_VERSION >= 3:
                raise socket.gaierror(socket.EAI_NONAME, reason.decode('utf-8', 'replace'))
            raise socket.gaierror(socket.EAI_NONAME, reason)
        self.counters['in_flight'] += 1
        self.counters['sent'] += 1
        try:
            with nogil:
                self.thisptr.Send(msg_, candidates, dest_)
        except BaseException:
            # Not accepted by the bus: undo the accounting (see send()).
            self.counters['in_flight'] -= 1
            self.counters['sent'] -= 1
            raise

    def receive(self, bool block=True, timeout=None):
        """Receive one message via TMsgBus::Receive.

        :param bool block: forwarded to TMsgBus::Receive
        :param timeout: seconds as a float; None is passed as -1.0
        :returns: ``(data: bytes, address: str, my_address: str)``
        Errors translated by RaisePyError propagate (presumably including
        TimeoutException -- TODO confirm).
        """
        tmt = -1. if timeout is None else float(timeout)
        with nogil:
            m = self.thisptr.Receive(block, tmt)
            msg = m.Data
            addr = m.Address
            my_addr = m.MyAddress

        self.counters['received'] += 1
        return _TString_to_bytes(msg), _TString_to_str(addr), _TString_to_str(my_addr)

    def get_transfer_state(self):
        """Return ``(HasOut, In, Transfers)`` from TMsgBus::GetTransferState."""
        cdef TransferState c = self.thisptr.GetTransferState()
        return c.HasOut, c.In, c.Transfers

# ``public`` C globals holding borrowed pointers to the exception classes,
# presumably read by errors.h (RaisePyError) to raise them from C++ (TODO confirm).
# The classes stay alive through their module-level names above.
# NOTE(review): inside this module the C name TimeoutError shadows the Python
# builtin of the same name.
cdef public PyObject* TimeoutError = <PyObject*>TimeoutException
cdef public PyObject* ResolveError = <PyObject*>ResolveException

cdef bytes _TString_to_bytes(TString stroka):
    """Copy exactly Size() bytes of ``stroka`` into a new bytes object.

    Slicing by explicit length (rather than relying on a terminator) keeps any
    embedded NUL bytes.
    """
    cdef const char *start = stroka.Data()
    return start[:stroka.Size()]

cdef str _TString_to_str(TString s):
    """Convert a TString to the native ``str`` type.

    Python 3: strict UTF-8 decode (the ``bytes.decode()`` default).
    Python 2: the raw bytes, since ``str`` is bytes there.
    """
    cdef bytes raw = _TString_to_bytes(s)
    if PY_MAJOR_VERSION != 2:
        return raw.decode()
    return str(raw)

cdef TString basestr_to_TString(object s):
    """Convert None / bytes / text into a TString.

    None becomes an empty TString, bytes are copied as-is, and anything else
    is encoded with the "utf-8" codec in "strict" mode.
    """
    cdef bytes utf8

    if s is None:
        return TString()
    if isinstance(s, bytes):
        return TString(PyBytes_AS_STRING(s), len(s))

    utf8 = PyCodec_Encode(s, "utf-8", "strict")
    return TString(PyBytes_AS_STRING(utf8), len(utf8))

# Import-time side effect: initialise CPython's thread/GIL machinery. This module
# releases the GIL around C++ calls (``with nogil``) and takes C++ callbacks
# ``with gil``, which presumably is why this is done eagerly.
# NOTE(review): newer CPython versions reportedly deprecate this call; confirm it
# is still needed for the targeted interpreter versions.
PyEval_InitThreads()
