#include "ut_sock.h"

#include <util/stream/length.h>

namespace NSrvKernel::NNetUt {

    namespace {
//        TArrayRef<ui8> Chomp(TArrayRef<ui8> data, size_t off) {
//            return data.Slice(0, data.size() - std::min(off, data.size()));
//        }

        // Consumes up to `bytes` leading bytes of `data` and returns them decoded as a
        // big-endian unsigned number (the first byte is the most significant one).
        // If fewer than `bytes` bytes are available, only the available ones are used,
        // so an empty `data` yields 0. `data` is advanced past the consumed bytes.
        // The callers visible in this file request at most 3 bytes, which fits into ui32.
        ui32 ConsumeNumber(TArrayRef<ui8>& data, size_t bytes) {
            bytes = std::min(data.size(), bytes);
            ui32 res = 0;
            for (auto i : xrange(bytes)) {
                // Plain assignment is intentional: `|=` would also OR the previous
                // accumulator into the shifted value, so earlier bytes would leak into
                // the lower positions instead of the next byte simply being appended.
                res = (res << 8) | data[i];
            }
            data = data.Slice(bytes);
            return res;
        }
    }

    // Returns everything that can be read from this stream through a wrapper
    // limited to `len` bytes.
    TString TStableRandomInput::GenString(size_t len) {
        TLengthLimitedInput bounded(this, len);
        return bounded.ReadAll();
    }

    size_t TStableRandomInput::DoRead(void* buf, size_t len) {
        for (auto i : xrange(len)) {
            ((ui8*)buf)[i] = Data_[Rng_.GenRand() % 64];
        }
        return len;
    }

    // Decodes socket buffer sizes from the head of `data`.
    // With fewer than 4 bytes available nothing is consumed and a value-initialized
    // TSockBufSize is returned. Otherwise two 2-byte numbers are consumed, the first one
    // for Rcv and the second one for Snd, each capped at 65535.
    TSockBufSize ConsumeSockBufSize(TArrayRef<ui8>& data) {
        if (data.size() < 4) {
            return {};
        }
        // The order of these two reads matters: Rcv is decoded before Snd.
        const ui32 rcvSize = std::min(ConsumeNumber(data, 2), 65535u);
        const ui32 sndSize = std::min(ConsumeNumber(data, 2), 65535u);
        return TSockBufSize{
            .Rcv = rcvSize,
            .Snd = sndSize,
        };
    }

    // Turns the rest of `data` into a list of actions, consuming it completely.
    // Every action starts with a 1-byte selector:
    //   * zero     - a delay, encoded by the next 2 bytes;
    //   * non-zero - a generated string whose length is encoded by the next 3 bytes.
    // String contents come from a single TStableRandomInput shared by all actions.
    TActions ConsumeActions(TArrayRef<ui8>& data) {
        TStableRandomInput generator;
        TActions actions;
        while (data) {
            const bool isString = ConsumeNumber(data, 1) != 0;
            if (isString) {
                // len: 1 - 256kb
                const auto len = 1 + std::min(ConsumeNumber(data, 3), 256 * 1024u - 1);
                actions.emplace_back(generator.GenString(len));
            } else {
                // duration: 0 - 4095 us
                const auto micros = std::min(ConsumeNumber(data, 2), 4095u);
                actions.emplace_back(TDuration::MicroSeconds(micros));
            }
        }
        return actions;
    }

    // Consumes one byte of `data` and maps it to the way the connection is finished:
    // a non-zero value selects Fin, zero (or no data left) selects Rst.
    ETcpFinish ConsumeTcpFinish(TArrayRef<ui8>& data) {
        // TODO(velavokr): a case of an invalid socket. The code 1) must report an appropriate error, 2) must not crash.
        if (ConsumeNumber(data, 1) != 0) {
            return ETcpFinish::Fin;
        }
        return ETcpFinish::Rst;
    }

    // Summarizes a list of actions.
    // Returns a pair of:
    //   * all string actions concatenated in the order they appear in `acts`;
    //   * the sum of all duration actions.
    std::pair<TString, TDuration> MeasureActions(const TActions& acts) {
        // Start from a default-constructed (empty) string. Initializing it from the
        // literal `0` would construct the string from a null pointer constant.
        TString st;
        TDuration dur;
        for (const auto& a : acts) {
            std::visit(TOverloaded{
                [&](const TString& s) { st += s; },
                [&](const TDuration& d) { dur += d; }
            }, a);
        }
        // The accumulated string is not needed here anymore, so hand it over without a copy.
        return {std::move(st), dur};
    }

    // Stores the socket and the actions to perform on it.
    // When `f` is ETcpFinish::Rst, EnableRstOnClose is applied to the socket and its
    // result is passed to TryRethrowError (presumably throwing on failure).
    TWriteThread::TWriteThread(TSocket s, TActions a, ETcpFinish f)
        : Sock_(std::move(s))
        , Acts_(std::move(a))
    {
        const bool resetOnClose = (f == ETcpFinish::Rst);
        if (resetOnClose) {
            TryRethrowError(EnableRstOnClose(Sock_));
        }
    }

    // Thread body: performs the stored actions one by one, in order.
    //   * a string action is written to the socket;
    //   * a duration action puts the thread to sleep for that long.
    // The first action that produces an error stops the loop; the error is kept in Err_.
    // Always returns nullptr.
    // NOTE(review): the exact semantics of Y_TRY / Y_CATCH / Y_TRY_STORE come from project
    // macros not visible here; the description above assumes that Y_CATCH runs when the
    // Y_TRY body yields an error and that Y_TRY_STORE converts the listed exception types
    // into a returned TError - confirm against the macro definitions.
    void* TWriteThread::ThreadProc() noexcept {
        for (const auto& a : Acts_) {
            Y_TRY(TError, err) {
                // Dispatch on the kind of the action; both handlers report a TError.
                return std::visit(TOverloaded{
                    [&](const TString& s) -> TError {
                        try {
                            // A fresh output stream over the socket for every write.
                            TSocketOutput sout(Sock_);
                            sout.Write(s);
                        } Y_TRY_STORE(TSystemError, yexception);
                        return {};
                    },
                    [&](const TDuration& d) -> TError {
                        Sleep(d);
                        return {};
                    },
                }, a);
            } Y_CATCH {
                // Remember the failure and skip the remaining actions.
                Err_ = std::move(err);
                break;
            };
        }
        return nullptr;
    }

    // Stores the socket and the size of the initial read buffer, and sets the socket
    // timeout to `tot` plus a small slack.
    TReadThread::TReadThread(TSocket s, size_t buf, TDuration tot)
        : Sock_(std::move(s))
        , Buf_(buf)
    {
        // Keep the 4999us of slack/rounding, then split the total into whole seconds and
        // the sub-second remainder in milliseconds. Passing the total number of
        // milliseconds as the second argument would count the whole seconds twice and
        // produce an out-of-range sub-second part for durations of about a second or more.
        // For shorter durations the resulting arguments are the same as before.
        const auto totalUs = tot.MicroSeconds() + 4999;
        Sock_.SetSocketTimeout(totalUs / 1000000, (totalUs % 1000000) / 1000);
    }

    // Thread body: reads the data arriving on the socket and stores it in Res_.
    // First up to Buf_ bytes are read into a pre-sized buffer; if the stream ends before
    // the buffer is full, only the bytes actually received are stored. Once the buffer is
    // full, whatever else can be read from the socket is appended.
    // An error is kept in Err_. Always returns nullptr.
    // NOTE(review): Y_TRY / Y_CATCH / Y_TRY_STORE are project macros not visible here; this
    // assumes Y_TRY_STORE turns the listed exception types into a returned TError and that
    // Y_CATCH runs when the Y_TRY body yields an error - confirm against their definitions.
    void* TReadThread::ThreadProc() noexcept {
        Y_TRY(TError, err) {
            try {
                TSocketInput sin(Sock_);
                size_t off = 0;
                TString res;
                res.resize(Buf_);
                while (off < Buf_) {
                    if (size_t r = sin.Read(res.Detach() + off, Buf_ - off); r > 0) {
                        off += r;
                    } else {
                        // Nothing more to read before the buffer was filled: drop the tail
                        // that was never written to, so that Res_ holds received data only.
                        res.resize(off);
                        Res_ = res;
                        return {};
                    }
                }
                // EOF or some unexpected data?
                Res_ = res + sin.ReadAll();
            } Y_TRY_STORE(TSystemError, yexception)
            return {};
        } Y_CATCH {
            Err_ = std::move(err);
        };
        return nullptr;
    }

    // Builds a pair of connected TCP sockets over the IPv6 loopback.
    // Returns {connecting side, accepted side}. The connecting side gets
    // EnableNonBlocking applied; both sides get the buffer sizes from `sz`.
    // The result of every step is passed through TryRethrowError.
    std::pair<TSocketHolder, TSocket> CreateTcpSocketPair(TSockBufSize sz) {
        // Listener bound to the loopback address with port 0.
        TSocketHolder listener;
        TryRethrowError(TcpSocket(AF_INET6, false).AssignTo(listener));
        TryRethrowError(Bind(listener, {Loopback6(), 0}));

        // Find out which address the listener actually got, then start listening.
        TSockAddr boundAddr;
        TryRethrowError(GetSockName(listener).AssignTo(boundAddr));
        TryRethrowError(Listen(listener, 1));

        // Connecting side.
        TSocketHolder client;
        TryRethrowError(TcpSocket(AF_INET6, false).AssignTo(client));
        TryRethrowError(Connect(client, boundAddr));
        TryRethrowError(EnableNonBlocking(client));
        TryRethrowError(SetSockBufSize(client, sz));

        // Accepted side.
        TAcceptResult accepted;
        TryRethrowError(Accept(listener, false).AssignTo(accepted));
        TryRethrowError(SetSockBufSize(accepted.Conn, sz));

        return {std::move(client), TSocket(accepted.Conn.Release())};
    }
}
