#include <balancer/serval/contrib/cone/cone.hh>
#include <balancer/serval/contrib/cone/cold.h>
#include <balancer/serval/core/address.h>
#include <balancer/serval/core/function.h>
#include <balancer/serval/core/http.h>
#include <balancer/serval/core/io.h>
#include <balancer/serval/core/tls.h>
#include <balancer/serval/core/storage.h>
#include <balancer/serval/tools/loadgen/opts.pb.h>

#include <library/cpp/getoptpb/getoptpb.h>

#include <util/datetime/base.h>
#include <util/generic/hash_set.h>
#include <util/generic/scope.h>
#include <util/generic/vector.h>
#include <util/stream/file.h>
#include <util/string/ascii.h>
#include <util/string/strip.h>

// Entry point of the HTTP load generator.
//
// Parses options into TLoadConfig, resolves every `-b/--base` URL, parses the `-r/--rps`
// schedule, reads request paths from stdin (one per line) and then starts `-t/--threads`
// workers that issue GET requests according to the schedule. Every attempt is logged as one
// tab-separated line to /dev/stdout.
//
// Returns 0 on success (or right after printing the errno table when PrintErrno is set) and
// 1 if waiting on any worker fails. Invalid options are reported through Y_ENSURE.
int main(int argc, const char** argv) {
    NSv::TThreadLocalRoot root;
    NGetoptPb::TGetoptPbSettings getopt;
    getopt.ConfPathShort = 0;
    TLoadConfig config = NGetoptPb::GetoptPbOrAbort(argc, argv, getopt);

    // Reference table: errno value, its symbolic name and what it means in the log output.
    if (config.GetPrintErrno()) {
#define PRINT_ERRNO(n, msg) Cout << n << "\t" #n "\t" msg << Endl
        PRINT_ERRNO(EMFILE, "Per-process file descriptor limit reached");
        PRINT_ERRNO(ENFILE, "System-wide file descriptor limit reached");
        PRINT_ERRNO(ECANCELED, "Request cancelled");
        PRINT_ERRNO(EADDRNOTAVAIL, "Not enough ephemeral ports; try tcp_tw_reuse=1");
        PRINT_ERRNO(ENETUNREACH, "No route to target");
        PRINT_ERRNO(ECONNREFUSED, "Target is offline");
        PRINT_ERRNO(ETIMEDOUT, "Connection timeout");
        PRINT_ERRNO(ECONNRESET, "Connection closed before the response has been received");
        PRINT_ERRNO(EPIPE, "Connection closed before the request has been sent");
#undef PRINT_ERRNO
        return 0;
    }

    Y_ENSURE(config.BaseSize() > 0, "-b/--base: at least one base URL is required");
    Y_ENSURE(config.GetThreads() > 0, "-t/--threads: must be > 0");
    Y_ENSURE(config.GetConnections() > 0, "-c/--connections: must be > 0");
    Y_ENSURE(config.GetStreams() > 0, "-s/--streams: must be > 0");
    Y_ENSURE(config.RPSSize() > 0, "-r/--rps: at least one item is required");
    Y_ENSURE(!config.GetInOrder() || config.BaseSize() == 1, "-i/--in-order: can only be used with one target");

    // One entry per resolved IP of every base URL; connection slot `j` uses entry `j % size`.
    TVector<std::tuple<bool /*TLS*/, bool /*H2*/, bool /*Force H2*/, NSv::URL, NSv::IP>> addrs;
    for (const auto& base : config.GetBase()) {
        auto addr = NSv::URL::Parse(base);
        Y_ENSURE(addr, "-b/--base: invalid address " << base);
        Y_ENSURE(!addr->Path, "-b/--base: paths must be specified on stdin, one per line");
        Y_ENSURE(EqualToOneOf(addr->Scheme, "http", "https", "h2", "h2c", "h2d"),
                 "-b/--base: unsupported scheme " << addr->Scheme);
        // Only schemes accepted by the check above can reach this point.
        bool tls = EqualToOneOf(addr->Scheme, "https", "h2", "h2d");
        bool h2 = EqualToOneOf(addr->Scheme, "h2", "h2d", "h2c");
        bool forceh2 = EqualToOneOf(addr->Scheme, "h2", "h2c");
        auto ips = NSv::IP::Resolve(TString(addr->Host), addr->Port);
        Y_ENSURE(ips, "-b/--base: could not resolve " << addr->Host);
        for (auto& ip : ips)
            addrs.emplace_back(tls, h2, forceh2, *addr, ip);
    }

    // Schedule segments. Each `-r` item has the form `a[-b][/time]`: the rate goes from `a` to
    // `b` requests per second over `time`; start/end are nanosecond offsets from the run start,
    // and every segment begins where the previous one ended.
    TVector<std::tuple<ui64 /*start*/, ui64 /*end*/, ui64 /*a*/, ui64 /*b*/>> segments;
    for (const auto& rps : config.GetRPS()) {
        TStringBuf args = rps;
        TStringBuf time;
        ui64 start = segments ? std::get<1>(segments.back()) : 0;
        // Without `/time` the segment is open-ended (end == Max<ui64>()).
        ui64 end = args.TryRSplit('/', args, time) ? start + FromString<TDuration>(time).NanoSeconds() : Max<ui64>();
        // Rates are clamped to 10^9 so that the interval `10^9 / rate` below never becomes zero.
        ui64 a = Min<ui64>(1000000000u, FromString<ui64>(args.Before('-')));
        ui64 b = Min<ui64>(1000000000u, FromString<ui64>(args.After('-')));
        Y_ENSURE(end != Max<ui64>() || (a == b && segments.size() == config.RPSSize() - 1),
                 "-r/--rps: `/time` can only be omitted on the last item, and only if it is constant");
        Y_ENSURE(a > 0 && b > 0, "-r/--rps: should be > 0");
        segments.emplace_back(start, end, a, b);
    }

    // Request paths are shared by all workers and consumed round-robin through `currentPath`.
    std::atomic<size_t> currentPath = 0;
    TVector<TString> paths;
    for (TString line; Cin.ReadLine(line); ) {
        if (line) {
            Y_ENSURE(line.StartsWith("/"), "path " << line.Quote() << " does not start with /");
            paths.emplace_back(std::move(line));
        }
    }
    Y_ENSURE(paths, "need at least one path on stdin");

    // Extra request headers. The stored keys/values are views into the strings owned by
    // `config`, which stays alive until the end of main.
    NSv::THeaderVector headers;
    for (auto& h : *config.MutableHeader()) {
        // Strip through a TStringBuf so that `k` is a view into `h` itself; stripping the
        // owning string directly may produce a temporary that `k` would outlive.
        TStringBuf k = StripString(TStringBuf(h));
        TStringBuf v;
        // Ignore leading : in pseudo-headers.
        Y_ENSURE(k.TrySplitOn(k.find(':', 1), k, v), "-H/--header: must be `name: value`");
        k = StripString(k);
        v = StripString(v);
        Y_ENSURE(k, "-H/--header: name must be non-empty");
        Y_ENSURE(v, "-H/--header: value must be non-empty");
        // Lowercase the name in place: `k` aliases the mutable config string `h`.
        std::transform(k.begin(), k.end(), const_cast<char*>(k.data()), [](char c) { return AsciiToLower(c); });
        headers.insert({k, v});
    }

    // Returns the offset (relative to the run start) at which the next request must be sent,
    // or cone::timedelta::max() once the schedule is exhausted. The counter `c` is shared by
    // every worker and holds the offset of the next unclaimed request in nanoseconds.
    auto getNextT = [&, c = std::atomic<ui64>(0)]() mutable -> cone::timedelta {
        ui64 value = c.load(std::memory_order_acquire);
        ui64 target;
        do {
            // First segment whose end lies strictly after `value`.
            auto it = UpperBoundBy(segments.begin(), segments.end(), value, [](auto& s) { return std::get<1>(s); });
            if (it == segments.end())
                return cone::timedelta::max();
            auto [start, end, a, b] = *it;
            // Using `P / (Q / R)` instead of `P * (R / Q)` because of integral division
            // and the fact that here `Q >= R` always holds due to `a, b >= 1`.
            // The divisor is the current rate: `a` for a constant segment, otherwise a linear
            // interpolation between `a` and `b` by position inside the segment.
            // NOTE(review): the interpolation products are computed in ui64 and can wrap when
            // (segment length in ns) * rate exceeds 2^64 — confirm this is out of the supported range.
            target = value + 1000000000u / (a == b ? a : ((value - start) * b + (end - value) * a) / (end - start));
            // On failure `value` is reloaded with the current counter and the step is recomputed.
        } while (!c.compare_exchange_strong(value, target));
        return std::chrono::nanoseconds(value);
    };

    // State of one connection slot.
    struct TConnEntry {
        NSv::TFile FD;
        THolder<NSv::IO> IO;
        THolder<NSv::IConnection> Client;
        // Number of requests in flight. The initial Max<unsigned>() marks a slot that is still
        // connecting, so the slot selection in sendRequest treats it as full.
        unsigned Load = Max<unsigned>();
    };

    // Establishes the connection for slot `j` into `conn`: socket, optional TLS wrapper, then
    // the HTTP client. The result is truthy only if all three steps succeed.
    auto connect = [&](unsigned j, TConnEntry& conn) {
        // TLS contexts differing only in the ALPN list; indexed by `id` below.
        auto clientTLS = [](int id) {
            static const NSv::TTLSContext ctxs[] = {
                NSv::TTLSContext({.Client = true, .VerifyPeer = false, .Protocols = {"http/1.1"}}),
                NSv::TTLSContext({.Client = true, .VerifyPeer = false, .Protocols = {"h2", "http/1.1"}}),
                NSv::TTLSContext({.Client = true, .VerifyPeer = false, .Protocols = {"h2"}}),
            };
            return &ctxs[id];
        };
        auto& [tls, h2, forceh2, url, ip] = addrs[j % addrs.size()];
        // 2: h2 only, 1: h2 with http/1.1 fallback, 0: http/1.1 only; no context without TLS.
        auto* ctx = tls ? clientTLS(forceh2 ? 2 : h2 ? 1 : 0) : nullptr;
        return (conn.FD = NSv::TFile::Connect(ip))
            && (conn.IO = ctx ? ctx->Wrap(conn.FD, TString(url.Host)) : THolder<NSv::IO>(new NSv::TFile(std::move(conn.FD))))
            && (conn.Client = NSv::H2Client(*conn.IO, nullptr, {.ForceH2 = forceh2}));
    };

    // Sends a GET for the next path (round-robin) over `conn` and returns whatever
    // conn.Request() returns. `:authority` and `:scheme` are filled in from the target of
    // slot `j` unless the user supplied them as headers.
    auto send = [&](unsigned j, NSv::IConnection& conn) {
        auto& [tls, h2, forceh2, url, ip] = addrs[j % addrs.size()];
        NSv::THead rqh{"GET", paths[currentPath++ % paths.size()], headers};
        if (rqh.find(":authority") == rqh.end())
            rqh.emplace(":authority", url.Host);
        if (rqh.find(":scheme") == rqh.end())
            rqh.emplace(":scheme", tls ? "https" : "http");
        return conn.Request(rqh);
    };

    // Returns the first non-informational response head, or nullptr if Head() yields nothing.
    auto readHead = [&](NSv::IStream& stream) -> NSv::THead* {
        while (auto head = stream.Head())
            if (!head->IsInformational())
                return head;
        return nullptr;
    };

    // Reads and discards the body. Returns the number of bytes seen once an empty chunk is
    // read, or nothing if Read() or Consume() fails first.
    auto readBody = [&](NSv::IStream& stream) -> TMaybe<size_t> {
        size_t total = 0;
        while (auto chunk = stream.Read()) {
            total += chunk->size();
            if (!*chunk)
                return total;
            if (!stream.Consume(chunk->size()))
                break;
        }
        return {};
    };

    // Runs `f`, stores its duration (monotonic clock, microsecond resolution) into `into` and
    // returns the result of `f`. The duration is recorded on every exit path.
    auto measure = [](TDuration& into, auto&& f) {
        auto start = mun_usec_monotonic();
        Y_DEFER { into = TDuration::MicroSeconds(mun_usec_monotonic() - start); };
        return f();
    };

    cone::barrier barrier{config.GetThreads()};
    // Body of one worker: owns its connection slots and log stream, and starts one task per
    // scheduled request until getNextT() reports the end of the schedule.
    auto runThread = [&] {
        TVector<std::shared_ptr<TConnEntry>> conns(config.GetConnections());
        TFileOutput log(TFile("/dev/stdout", OpenAlways | WrOnly | ForAppend));
        // Woken whenever a slot's load drops or a connection attempt finishes.
        cone::event loadDrop;
        unsigned logCount = 0;
        unsigned nextConn = 0;
        // Performs one request attempt and logs it; returns whether it succeeded.
        auto sendRequest = [&] {
            // Pick a slot that is either empty or still has spare stream capacity.
            unsigned j;
            for (size_t k = 0;;) {
                // In-order mode scans slots 0, 1, 2, ...; `k` keeps growing across full passes,
                // so it must be wrapped before being used as an index.
                j = config.GetInOrder() ? k % conns.size() : nextConn++ % conns.size();
                if (!conns[j] || conns[j]->Load < config.GetStreams())
                    break;
                // After a full pass without a free slot, wait until some load drops.
                if (++k % conns.size() == 0 && !loadDrop.wait() MUN_RETHROW)
                    return false;
            }
            int code = 0;
            TInstant start = TInstant::Now();
            TDuration connT = TDuration::Zero();
            TDuration sendT = TDuration::Zero();
            TDuration delay = TDuration::Zero();
            TDuration total = TDuration::Zero();
            size_t recvSize = 0;
            auto ok = measure(total, [&] {
                // Keep a reference of our own: the slot may be replaced by a fresh connection
                // while this request is still in flight on the old one.
                auto conn = conns[j];
                if (conn && conn->Client->IsOpen()) {
                    conn->Load++;
                } else {
                    Y_DEFER { loadDrop.wake(); };
                    conn = conns[j] = std::make_shared<TConnEntry>();
                    // TODO conn timeout
                    if (!measure(connT, [&]{ return connect(j, *conn); })) {
                        conns[j].reset();
                        return false;
                    }
                    conn->Load = 1;
                }
                Y_DEFER {
                    conn->Load--;
                    loadDrop.wake();
                };
                // TODO send timeout
                auto rsp = measure(sendT, [&]{ return send(j, *conn->Client); });
                if (!rsp)
                    return false;
                // TODO recv timeout
                auto rsh = measure(delay, [&]{ return readHead(*rsp); });
                if (!rsh)
                    return false;
                code = rsh->Code;
                auto bodySize = readBody(*rsp);
                if (!bodySize)
                    return false;
                recvSize = *bodySize;
                return true;
            });
            // Columns: start time (s), empty tag, total, connect, send, time to response head,
            // remainder (all in us), two constant zeroes, received bytes, errno (0 on success),
            // HTTP status code.
            // NOTE(review): `Endl` appears to flush on every line, which would make the
            // periodic Flush() below redundant — confirm which of the two is intended.
            log << start.MicroSeconds() / 1000000.0 << '\t' /* << tag */ << '\t'
                << total.MicroSeconds() << '\t' << connT.MicroSeconds() << '\t'
                << sendT.MicroSeconds() << '\t' << delay.MicroSeconds() << '\t'
                << (total - connT - sendT - delay).MicroSeconds() << '\t'
                << /* interval_event??? */ "0\t" << /* size_out */ "0\t"
                << recvSize << '\t' << (ok ? 0 : mun_errno) << '\t' << code << Endl;
            if (++logCount % 16 == 0)
                log.Flush();
            return ok;
        };

        // Task body: one attempt plus up to GetRetries() retries, stopping at the first success.
        auto spawnAndRetry = [&] {
            for (size_t i = config.GetRetries() + 1; i--;)
                if (sendRequest())
                    break;
            return true; // ignore errors
        };

        // Rendezvous with the other workers before taking the common start time.
        if (!barrier.join() MUN_RETHROW)
            return false;
        cone::time start = cone::time::clock::now();
        cone::mguard tasks;
        // Sleep until each scheduled offset, then hand one request task to `tasks`.
        for (cone::timedelta d; (d = getNextT()) != cone::timedelta::max(); tasks.add(spawnAndRetry))
            if (!cone::sleep(start + d) MUN_RETHROW)
                return false;
        // Schedule exhausted: wait until no slot has requests in flight.
        while (std::any_of(conns.begin(), conns.end(), [](auto& c) { return c && c->Load; }))
            if (!loadDrop.wait() MUN_RETHROW)
                return false;
        return true;
    };

    TVector<cone::guard> threads(config.GetThreads());
    for (auto& thread : threads)
        thread = cone::thread(runThread);
    for (auto& thread : threads)
        if (!thread->wait(cone::rethrow))
            return 1;
    return 0;
}
