#include "sockops.h"

#include <balancer/kernel/helpers/syscalls.h>

#include <util/network/socket.h>
#include <util/system/compat.h>

#ifdef _win_
#   include <winsock.h>
#else
#   include <netinet/tcp.h>
#   include <sys/ioctl.h>
#   include <sys/socket.h>
#   include <sys/types.h>
#   define ioctlsocket ioctl
#   if defined(_linux_) && !defined(SO_REUSEPORT)
// TODO(velavokr): have to define this, since we are still using Ubuntu 12.04 glibc in our builds (sbt:102903136)
#       define SO_REUSEPORT 15
#   endif
#   ifdef _darwin_
#       define tcp_info tcp_connection_info
#       define TCP_INFO TCP_CONNECTION_INFO
#       define tcpi_rtt tcpi_rttcur
#       define tcpi_total_retrans tcpi_txretransmitpackets
#       define TCP_KEEPIDLE TCP_KEEPALIVE
#   endif
#endif


namespace NSrvKernel {

    namespace {
        // Typed wrapper over setsockopt(): passes the address and size of `val`, so the
        // option's width is whatever T is at the call site (int literals for boolean
        // flags, struct linger, timeval, ... in this file). The raw call is wrapped by
        // Y_SYSCALL, whose result is returned unchanged.
        // NOTE(review): integer options are conventionally read by the kernel as int —
        // callers passing other integer types should keep T int-sized; confirm.
        template <class T>
        auto SetSockOptImpl(SOCKET sock, int level, int optname, T val) {
            return Y_SYSCALL(setsockopt(sock, level, optname, &val, sizeof(val)));
        }
    }

    TErrorOr<TSocketHolder> TcpSocket(int af, bool makeNonblock) noexcept {
        // Creates a TCP stream socket in address family `af`; when `makeNonblock`
        // is set the returned socket is in non-blocking mode.
        TSocketHolder res;
#ifndef SOCK_NONBLOCK
        // No SOCK_NONBLOCK type flag on this platform: create a blocking socket,
        // then switch it to non-blocking mode with a separate call if requested.
        Y_PROPAGATE_ERROR(Y_R_SYSCALL(socket(af, SOCK_STREAM, IPPROTO_TCP)).AssignTo(res));
        if (!makeNonblock) {
            return res;
        }
        Y_PROPAGATE_ERROR(EnableNonBlocking(res));
#else
        // The non-blocking flag is passed to socket() itself, at creation time.
        Y_PROPAGATE_ERROR(Y_R_SYSCALL(socket(af, (SOCK_STREAM | (makeNonblock ? SOCK_NONBLOCK : 0)), IPPROTO_TCP)).AssignTo(res));
#endif
        return res;
    }

    // Binds `sock` to the local address `addr` via bind(2); a failing call is
    // turned into an error by Y_SYSCALL.
    TError Bind(SOCKET sock, const TSockAddr& addr) noexcept {
        return Y_SYSCALL(bind(sock, addr.Addr(), addr.Len()));
    }

    // Puts `sock` into listening mode via listen(2) with the given backlog.
    // `backlog` is unsigned here and is implicitly converted to listen()'s int parameter.
    TError Listen(SOCKET sock, unsigned backlog) noexcept {
        return Y_SYSCALL(listen(sock, backlog));
    }

    // Half-closes `sock` for writing (SHUT_WR): no further data is sent, while the
    // receive direction is left open.
    TError Shutdown(SOCKET sock) noexcept {
        return Y_SYSCALL(shutdown(sock, SHUT_WR));
    }

    // Closes `sock` through closesocket().
    // NOTE(review): on POSIX closesocket is presumably mapped to close() by the util
    // network headers — confirm.
    TError Close(SOCKET sock) noexcept {
        return Y_SYSCALL(closesocket(sock));
    }

    TError EnableNonBlocking(SOCKET sock) noexcept {
        // Switches `sock` to non-blocking mode with the FIONBIO request.
        // ioctlsocket() is the portable spelling: it is the native Winsock call, and on
        // non-Windows platforms the top of this file #defines it to ioctl().
        // The argument type differs: Winsock expects u_long*, POSIX FIONBIO reads an int.
#ifdef _win_
        u_long nb = 1;
#else
        int nb = 1;
#endif
        return Y_SYSCALL(ioctlsocket(sock, FIONBIO, &nb));
    }

    TError EnableNoDelay(SOCKET sock) noexcept {
        // Sets TCP_NODELAY, i.e. turns off Nagle-style coalescing of small writes.
        const int enabled = 1;
        return SetSockOptImpl(sock, IPPROTO_TCP, TCP_NODELAY, enabled);
    }

    TError EnableReuseAddr(SOCKET sock) noexcept {
        // Sets the SO_REUSEADDR flag on `sock` (see socket(7) for its exact semantics).
        const int enabled = 1;
        return SetSockOptImpl(sock, SOL_SOCKET, SO_REUSEADDR, enabled);
    }

    TError EnableReusePort(SOCKET sock) noexcept {
#ifndef _win_
        // Sets SO_REUSEPORT (defined by hand at the top of this file for old Linux headers).
        const int enabled = 1;
        return SetSockOptImpl(sock, SOL_SOCKET, SO_REUSEPORT, enabled);
#else
        // Windows build: fall back to SO_REUSEADDR.
        return EnableReuseAddr(sock);
#endif
    }

    TError EnableRstOnClose(SOCKET sock) noexcept {
        // Linger switched on with a zero timeout: the conventional way to make close()
        // abort the connection (RST) rather than shut it down gracefully.
        struct linger rst;
        Zero(rst);
        rst.l_onoff = 1;
        rst.l_linger = 0;
        return SetSockOptImpl(sock, SOL_SOCKET, SO_LINGER, rst);
    }

    TError EnableV6Only(SOCKET sock) noexcept {
#ifndef IPV6_V6ONLY
        // Option unavailable: nothing to set. Per the original note, OSX's default
        // behaviour already matches what is wanted here.
        Y_UNUSED(sock);
        return {};
#else
        // Restrict an AF_INET6 socket to IPv6 traffic only.
        const int enabled = 1;
        return SetSockOptImpl(sock, IPPROTO_IPV6, IPV6_V6ONLY, enabled);
#endif
    }

    TErrorOr<TSockAddr> GetSockName(SOCKET sock) noexcept {
        // Returns the local address `sock` is bound to, as reported by getsockname(2).
        // sockaddr_storage is sized for any address family, so the result is never truncated.
        sockaddr_storage stor = {};
        socklen_t len = sizeof(stor);
        // getsockname() takes a generic sockaddr*; view the storage through one named cast.
        sockaddr* aPtr = reinterpret_cast<sockaddr*>(&stor);
        Y_PROPAGATE_ERROR(Y_SYSCALL(getsockname(sock, aPtr, &len)));
        TSockAddr res;
        Y_PROPAGATE_ERROR(TSockAddr::FromSockAddr(stor).AssignTo(res));
        return res;
    }

    TError ValidateSockBufSize(TSockBufSize sz) noexcept {
        // Dry run: apply the requested buffer sizes to a local, blocking AF_INET6 probe
        // socket so that a value the kernel rejects surfaces as an error up front.
        TSocketHolder probe;
        Y_PROPAGATE_ERROR(TcpSocket(AF_INET6, false).AssignTo(probe));
        return SetSockBufSize(probe, sz);
    }

    TError SetSockBufSize(SOCKET sock, TSockBufSize sz) noexcept {
        // Applies whichever of the receive/send buffer sizes is present, receive first.
        // The first failure is returned and any remaining size is left unapplied.
        const auto applyBuf = [sock](int optname, const auto& size) -> TError {
            if (!size) {
                return {};
            }
            return SetSockOptImpl(sock, SOL_SOCKET, optname, *size);
        };
        Y_PROPAGATE_ERROR(applyBuf(SO_RCVBUF, sz.Rcv));
        Y_PROPAGATE_ERROR(applyBuf(SO_SNDBUF, sz.Snd));
        return {};
    }

    namespace {
        // Splits a duration into whole seconds plus the remaining microseconds, the
        // timeval layout passed to SO_RCVTIMEO / SO_SNDTIMEO below.
        timeval TimevalFromDuration(TDuration d) {
            timeval tv{};
            tv.tv_sec = time_t(d.Seconds());
            tv.tv_usec = suseconds_t(d.MicroSeconds() % 1'000'000);
            return tv;
        }
    }

    TError SetSockTimeout(SOCKET sock, TSockTimeout tout) noexcept {
        // Each direction is optional; an absent one leaves that socket timeout untouched.
        // Receive is applied before send and the first failure is returned.
        if (tout.Rcv) {
            const timeval rcvTv = TimevalFromDuration(*tout.Rcv);
            Y_PROPAGATE_ERROR(SetSockOptImpl(sock, SOL_SOCKET, SO_RCVTIMEO, rcvTv));
        }
        if (tout.Snd) {
            const timeval sndTv = TimevalFromDuration(*tout.Snd);
            Y_PROPAGATE_ERROR(SetSockOptImpl(sock, SOL_SOCKET, SO_SNDTIMEO, sndTv));
        }
        return {};
    }

    TError ValidateKeepalive(TTcpKeepalive k) noexcept {
        // Dry run: apply the keepalive settings to a local, blocking AF_INET6 probe
        // socket so that unsupported values are reported before real use.
        TSocketHolder probe;
        Y_PROPAGATE_ERROR(TcpSocket(AF_INET6, false).AssignTo(probe));
        return EnableKeepalive(probe, k);
    }

    TError EnableKeepalive(SOCKET sock, TTcpKeepalive k) noexcept {
        // Keepalive itself is always switched on; the TCP-level tuning knobs are applied
        // only when present, in the order count, idle, interval, stopping at the first error.
        Y_PROPAGATE_ERROR(SetSockOptImpl(sock, SOL_SOCKET, SO_KEEPALIVE, 1));

        const auto tune = [sock](int optname, const auto& value) -> TError {
            if (value) {
                return SetSockOptImpl(sock, IPPROTO_TCP, optname, *value);
            }
            return {};
        };
        Y_PROPAGATE_ERROR(tune(TCP_KEEPCNT, k.Cnt));
        Y_PROPAGATE_ERROR(tune(TCP_KEEPIDLE, k.Idle));
        Y_PROPAGATE_ERROR(tune(TCP_KEEPINTVL, k.Intvl));
        return {};
    }

    TError ValidateCongestionControl(const TString& algo) noexcept {
        // Dry run: try selecting `algo` on a local, blocking AF_INET6 probe socket so an
        // unknown or unavailable algorithm name is reported up front.
        TSocketHolder probe;
        Y_PROPAGATE_ERROR(TcpSocket(AF_INET6, false).AssignTo(probe));
        return SetCongestionControl(probe, algo);
    }

    TError SetCongestionControl(SOCKET sock, const TString& algo) noexcept {
#ifndef TCP_CONGESTION
        // No per-socket congestion control selection on this platform.
        Y_UNUSED(sock);
        Y_UNUSED(algo);
        return Y_MAKE_ERROR(TSystemError(ENOPROTOOPT));
#else
        // The algorithm name is passed as raw bytes of length algo.size().
        return Y_SYSCALL(setsockopt(sock, IPPROTO_TCP, TCP_CONGESTION, algo.data(), algo.size()));
#endif
    }

    TError SetNotsentLowat(SOCKET sock, int size) noexcept {
#ifndef TCP_NOTSENT_LOWAT
        // TCP_NOTSENT_LOWAT is not available on this platform.
        Y_UNUSED(sock);
        Y_UNUSED(size);
        return Y_MAKE_ERROR(TSystemError(ENOPROTOOPT));
#else
        // Sets the TCP_NOTSENT_LOWAT threshold for `sock` to `size`.
        return Y_SYSCALL(setsockopt(sock, IPPROTO_TCP, TCP_NOTSENT_LOWAT, &size, sizeof(int)));
#endif
    }

    // Probes whether TCP statistics can be read on this platform: opens a blocking
    // AF_INET6 TCP socket and calls GetTcpInfo() on it. Returns false when either step
    // fails (the Y_CATCH branch, presumably taken whenever the Y_TRY body yields an
    // error), true otherwise.
    // NOTE(review): if AF_INET6 sockets cannot be created on the host, this reports
    // false even if TCP_INFO itself would work — confirm that is acceptable.
    bool CanGetTcpInfo() noexcept {
        Y_TRY(TError, err) {
            TSocketHolder sock;
            Y_PROPAGATE_ERROR(TcpSocket(AF_INET6, false).AssignTo(sock));
            // Only the error part matters here; the TTcpInfo value is discarded.
            return GetTcpInfo(sock).ReleaseError();
        } Y_CATCH {
            return false;
        };
        return true;
    }

    TErrorOr<TTcpInfo> GetTcpInfo(SOCKET sock) noexcept {
#ifndef TCP_INFO
        // No TCP_INFO (nor a remapped equivalent) on this platform.
        Y_UNUSED(sock);
        return Y_MAKE_ERROR(TSystemError(ENOPROTOOPT));
#else
        // Read the kernel's TCP statistics for `sock` and copy out the fields exposed by
        // TTcpInfo. On Darwin, tcp_info / TCP_INFO / tcpi_rtt / tcpi_total_retrans are
        // macro-aliased at the top of this file.
        // NOTE(review): RTT fields are interpreted as microseconds on every platform; the
        // Darwin counterparts may use other units (and cwnd may be bytes, not segments) —
        // confirm before relying on these values there.
        tcp_info info;
        Zero(info);
        socklen_t len = sizeof(info);
        Y_PROPAGATE_ERROR(Y_SYSCALL(getsockopt(sock, IPPROTO_TCP, TCP_INFO, &info, &len)));

        TTcpInfo out;
        out.Rtt = TDuration::MicroSeconds(info.tcpi_rtt);
        out.RttVar = TDuration::MicroSeconds(info.tcpi_rttvar);
        out.SndCwnd = info.tcpi_snd_cwnd;
        out.TotalRetrans = info.tcpi_total_retrans;
#   ifdef _linux_
        // tcpi_unacked is only read on Linux.
        out.Unacked = info.tcpi_unacked;
#   endif
        return out;
#endif
    }

    TMaybe<int> ErrNo(const TError& err) {
        // An empty (success) error carries no errno.
        if (!err) {
            return Nothing();
        }
        // Only TSystemError carries an errno; any other error kind yields nothing.
        if (auto* sysErr = err.GetAs<TSystemError>()) {
            return MakeMaybe(sysErr->Status());
        }
        return Nothing();
    }

    bool HasBlocked(const TError& err) {
        // POSIX allows EAGAIN and EWOULDBLOCK to be distinct values, so accept either.
        const auto code = ErrNo(err);
        return EAGAIN == code || EWOULDBLOCK == code;
    }

    bool HasTimedOut(const TError& err) {
        // True only for a system error whose errno is ETIMEDOUT; ErrNo() yields an
        // empty value for success and for non-system errors.
        const auto code = ErrNo(err);
        return ETIMEDOUT == code;
    }

    // Initiates a connection from `sock` to `addr` via connect(2); a failing call is
    // turned into an error by Y_SYSCALL.
    // NOTE(review): unlike the neighbouring wrappers this definition is not noexcept;
    // presumably it must match the declaration in sockops.h — change both together if intended.
    TError Connect(SOCKET sock, const TSockAddr& addr) {
        return Y_SYSCALL(connect(sock, addr.Addr(), addr.Len()));
    }

    TErrorOr<TAcceptResult> Accept(SOCKET sock, bool makeNonblock) noexcept {
        // Accepts one connection on `sock` and returns the new socket together with the
        // peer address filled in by the kernel; optionally makes the new socket non-blocking.
        sockaddr_storage stor = {};
        socklen_t len = sizeof(stor);
        sockaddr* aPtr = reinterpret_cast<sockaddr*>(&stor);
        TAcceptResult accepted;
#ifndef SOCK_NONBLOCK
        Y_PROPAGATE_ERROR(Y_R_SYSCALL(accept(sock, aPtr, &len)).AssignTo(accepted.Conn));
        if (makeNonblock) {
            // No SOCK_NONBLOCK / accept4() here: switch the accepted socket afterwards.
            Y_PROPAGATE_ERROR(EnableNonBlocking(accepted.Conn));
        }
#else
        // accept4() applies the non-blocking flag as part of the accept call.
        Y_PROPAGATE_ERROR(Y_R_SYSCALL(accept4(sock, aPtr, &len, (makeNonblock ? SOCK_NONBLOCK : 0))).AssignTo(accepted.Conn));
#endif
        Y_PROPAGATE_ERROR(TSockAddr::FromSockAddr(*aPtr).AssignTo(accepted.Addr));
        return accepted;
    }

}
