#include "sample.h"
#include "utils.h"

#include <balancer/serval/contrib/cone/cold.h>
#include <balancer/serval/contrib/cone/cone.hh>
#include <balancer/serval/core/address.h>
#include <balancer/serval/core/buffer.h>
#include <balancer/serval/core/config.h>
#include <balancer/serval/core/io.h>
#include <balancer/serval/core/tls.h>
#include <balancer/serval/core/http.h>
#include <balancer/serval/mod/proxy/proxy.ev.pb.h>

#include <library/cpp/threading/hot_swap/hot_swap.h>

#include <util/digest/numeric.h>
#include <util/generic/algorithm.h>
#include <util/generic/deque.h>
#include <util/generic/scope.h>
#include <util/generic/vector.h>
#include <util/generic/ymath.h>
#include <util/random/fast.h>
#include <util/random/random.h>
#include <util/stream/file.h>
#include <util/string/split.h>
#include <util/string/strip.h>
#include <util/system/thread.h>

// Builds the `proxy` action from its YAML configuration.
//
// `args` must be a map. The value of its first entry is the backend spec (a single
// backend, or a sequence of branches). The other keys read here are:
//   attempts, attempts-fast, attempt-for, attempt-rate, attempt-rate-decay, attempt-delay,
//   conn-timeout, recv-timeout, send-timeout, keepalive, send-buffer, recv-buffer, sync,
//   retry-codes, hash-by, override, weights-file, else.
// `aux` supplies the signals (metrics) and is used to build nested backends/actions.
//
// The returned action picks a backend for the request and forwards it, retrying on
// failure either sequentially or, when `attempt-delay` is configured, by spawning
// concurrent ("hedged") attempts. It returns true as soon as an attempt succeeds.
// If all attempts fail and the request can still be replayed, the result of the `else`
// action (if configured) is returned; otherwise the action returns false.
static NSv::TAction Proxy(const YAML::Node& args, NSv::TAuxData& aux) {
    CHECK_NODE(args, args.IsMap(), "`proxy` requires an argument");
    // The backend spec is whatever the first entry of the map points to.
    auto arg = args.begin()->second;
    // Default number of attempts: one per branch for a sequence, otherwise 1. Must be > 0.
    auto attempts = NSv::Optional<size_t>(args["attempts"], arg.IsSequence() ? arg.size() : 1, [](size_t n) {
        return n > 0;
    });
    // Extra attempts that may be spent on "fast" errors (see `sendOne` below).
    auto attemptsFast = NSv::Optional<size_t>(args["attempts-fast"], 0);
    // Time budget for attempts; `max()` means "no deadline".
    auto attemptFor = NSv::Optional(args["attempt-for"], cone::timedelta::max());
    // Upper bound for the `overload` signal (out-requests / in-requests) above which
    // no further attempts are started. Must be >= 1; unlimited by default.
    auto attemptRate = NSv::Optional(args["attempt-rate"], std::numeric_limits<double>::infinity(), [](double v) {
        return v >= 1;
    });
    // Decay factor used when updating the `overload` signal; must be in (0, 1).
    auto attemptRateDecay = NSv::Optional(args["attempt-rate-decay"], 0.01, [](double v) {
        return 0 < v && v < 1;
    });
    // Delay before spawning the next concurrent attempt in hedged mode.
    auto attemptDelay = NSv::Optional(args["attempt-delay"], cone::timedelta::max());
    // Hedged mode is enabled by the mere presence of `attempt-delay`.
    bool hedgeds = !!args["attempt-delay"];

    NSv::TBackendOptions opts;
    opts.ConnTime = &aux.Signal<NSv::THistogram>("conn-time_dhhh");
    opts.ExchangeTime = &aux.Signal<NSv::THistogram>("exchange-time_dhhh");
    opts.StatusCodes = {
        &aux.Signal("status-1xx_dmmm"),
        &aux.Signal("status-2xx_dmmm"),
        &aux.Signal("status-3xx_dmmm"),
        &aux.Signal("status-4xx_dmmm"),
        &aux.Signal("status-5xx_dmmm"),
    };
    opts.ConnTimeout = NSv::Optional(args["conn-timeout"], opts.ConnTimeout);
    // In hedged mode with an explicit `attempt-for`, the receive timeout defaults to
    // "infinite": `attempt-for` then limits the whole exchange instead (see below).
    opts.RecvTimeout = NSv::Optional(args["recv-timeout"], hedgeds && args["attempt-for"] ? cone::timedelta::max() : opts.RecvTimeout);
    // The send timeout defaults to the (already resolved) receive timeout.
    opts.SendTimeout = NSv::Optional(args["send-timeout"], opts.RecvTimeout);
    opts.Keepalive = NSv::Optional(args["keepalive"], opts.Keepalive);
    opts.Conn.WriteBuffer = NSv::Optional(args["send-buffer"], opts.Conn.WriteBuffer);
    opts.Conn.StreamWindow = NSv::Optional(args["recv-buffer"], opts.Conn.StreamWindow);
    opts.Sync = NSv::Optional(args["sync"], opts.Sync);
    // `retry-codes` is a single code, a sequence of codes, or a map `code: normal|fast`.
    // The boolean stored next to each code is true only for map entries marked "fast".
    if (const auto& spec = args["retry-codes"]) {
        if (spec.IsScalar()) {
            opts.RetryCodes.emplace_back(NSv::ParseCode(spec), false);
        } else {
            for (const auto& item : spec) {
                CHECK_NODE(item.second, !spec.IsMap() || EqualToOneOf(item.second.Scalar(), "normal", "fast"),
                           "retries are either 'normal' or 'fast'");
                opts.RetryCodes.emplace_back(NSv::ParseCode(spec.IsMap() ? item.first : item),
                                             spec.IsMap() && item.second.Scalar() == "fast");
            }
        }
    }

    // `hash-by` is a single hash spec or a sequence of them; each one maps a request head
    // to an optional hash value. The values are combined into the RNG seed per request.
    TVector<NSv::TFunction<TMaybe<ui64>(const NSv::THead&)>> hashBy;
    if (const auto& arg = args["hash-by"]) {
        if (arg.IsScalar()) {
            hashBy.emplace_back(NSv::ParseHashSpec(arg));
        } else {
            for (const auto& item : arg) {
                hashBy.emplace_back(NSv::ParseHashSpec(item));
            }
        }
    }

    // Name of the request header whose values (URLs) replace the configured backends.
    auto overrideHeader = NSv::Optional(args["override"], TString{}); // TODO document
    TVector<NSv::TBackend> items;
    // Stays null when backends are to be chosen by load (see `selectBackend`).
    THolder<NSv::TWeightsHolder> weightsHolder;
    if (const auto& file = args["weights-file"]) {
        // Weights come from an external file: every branch must be named (`name: action`)
        // and the names are handed to the weights holder together with the file path.
        TVector<TString> names;
        CHECK_NODE(file, file.IsScalar(), "must be a path");
        CHECK_NODE(arg, arg.IsSequence(), "with a weights file, there must be several branches");
        for (const auto& item : arg) {
            CHECK_NODE(item, item.IsMap() && item.size() == 1, "with a weights file, branches must be `name: action`");
            CHECK_NODE(item, item.begin()->first.IsScalar(), "branch name must be a scalar");
            items.emplace_back(Backend(item.begin()->second, aux));
            names.emplace_back(item.begin()->first.Scalar());
        }
        weightsHolder = MakeHolder<NSv::TWeightsHolder>(TString(file.Scalar()), std::move(names));
    } else {
        // Hashing is supposed to force equal requests to be sent to the same backends; with
        // autobalancing enabled, they would be distributed between two backends instead.
        bool autobalance = !hashBy;
        TIntrusivePtr<NSv::TWeightsVector> values = new NSv::TWeightsVector;
        if (!arg.IsSequence()) {
            items.emplace_back(Backend(arg, aux));
            values->emplace_back(1.f);
        } else for (const auto& item : arg) {
            // A branch is "weighted" if it is a one-entry map whose key parses as a float.
            auto weight = 1.f;
            auto weighted = item.IsMap() && item.size() == 1 && YAML::convert<float>::decode(item.begin()->first, weight);
            // IIRC, `decode` does not guarantee that the value remains the same on failure.
            CHECK_NODE(item, !weighted || weight >= 0.f, "weights must be >= 0");
            items.emplace_back(Backend(weighted ? item.begin()->second : item, aux));
            // Stored values are reciprocals of the weights; a zero weight becomes +infinity.
            values->emplace_back(!weighted ? 1.f : weight ? 1.f / weight : std::numeric_limits<float>::infinity());
            // Any explicit weight disables load-based balancing.
            autobalance &= !weighted;
        }
        CHECK_NODE(arg, FindIf(values->begin(), values->end(), IsFinite) != values->end(),
                   "at least one backend with nonzero weight is required");
        if (!autobalance) {
            weightsHolder = MakeHolder<NSv::TWeightsHolder>(values);
        }
    }

    // Optional action to run when every attempt has failed.
    NSv::TAction fallback;
    if (const auto& arg = args["else"]) {
        fallback = aux.Action(arg);
    }

    // The closure copies the parsed configuration, owns the backends and the weights, and
    // keeps references to the signals obtained from `aux`. `loads` holds one counter per
    // configured backend, used by the load-based selection mode.
    // NOTE(review): `loads` is sized from `items.size()` while the next capture moves
    // `items`; this relies on the captures being initialized in the written order -- confirm.
    return [=,
        loads         = std::make_unique<std::atomic<ui32>[]>(items.size()),
        items         = std::move(items),
        weightsHolder = std::move(weightsHolder),
        &overload      = aux.Signal<NSv::TNumber<double>>("overload_avvv", 1.0), // out-requests / in-requests
        &time         = aux.Signal<NSv::THistogram>("time_dhhh"),
        &emptyWrites = aux.CustomSignal("http-empty-writes_summ")
    ](NSv::IStreamPtr& req) mutable {
        // Nothing can be done without the request head.
        auto rqh = req->Head();
        if (!rqh MUN_RETHROW) {
            return false;
        }
        // Backends taken from the override header: one per resolved address of each
        // header value that parses as a URL. If non-empty, they replace `items` below.
        TVector<NSv::TBackend> overrides;
        // TODO 1. actually handle the timeout
        //      2. if (!url->Scheme) replace with...first? most common? from the original list
        //      3. if all headers cannot be parsed/resolved, fail instead of falling back to config?
        //         (maybe also allow empty backend lists as a config argument then?)
        for (auto [a, b] = rqh->equal_range(overrideHeader); a != b; a++) {
            if (auto url = NSv::URL::Parse(TString(a->second), /*withTimeout=*/true)) {
                for (auto ip : NSv::IP::Resolve(TString(url->Host), url->Port)) {
                    overrides.emplace_back(Backend(*url, ip, &emptyWrites));
                }
            }
        }
        // Kept so that `req` can be restored after a failed attempt.
        NSv::IStreamPtr original = req;
        // `order[k]` is the index of the backend used by the k-th selection.
        TVector<size_t> order;
        // Load-based mode only: backends already picked in the current round, stored in
        // "compressed" coordinates (see `selectBackend`).
        TVector<size_t> unshifted;
        // Null `weights` selects the load-based mode; otherwise backends are sampled by weight.
        TIntrusivePtr<NSv::TWeightsVector> weights =
            // Use random with overridden backends -- there is no loads vector to share between requests.
            overrides ? new NSv::TWeightsVector(overrides.size(), 1.) : weightsHolder ? weightsHolder->AtomicLoad() : nullptr;
        auto deadline = attemptFor == cone::timedelta::max() ? cone::time::max() : cone::time::clock::now() + attemptFor;
        // Per-request attempt budgets, decremented by `sendOne`.
        auto remaining = attempts;
        auto remainingFast = attemptsFast;
        // Combine all hashes that could be computed for this request; if there are none,
        // the RNG is seeded randomly.
        TMaybe<ui64> seed;
        for (const auto& hasher : hashBy) {
            if (auto h = hasher(*rqh)) {
                seed = seed ? CombineHashes(*seed, *h) : *h;
            }
        }
        TReallyFastRng32 rng(seed ? *seed : RandomNumber<ui64>());

        // On exit, release the load that `selectBackend` added in load-based mode.
        Y_DEFER {
            if (!weights) {
                for (size_t i : order) {
                    loads[i] -= 1;
                }
            }
        };
        // Chooses the backend for the next attempt and returns the position in `order`
        // at which its index is stored.
        auto selectBackend = [&, i = size_t(0)]() mutable {
            // TODO DC-aware balancing? either pick a weighted random DC for each attempt or pick
            //      the current DC and fall back to random if it fails.
            // TODO an option to send additional requests periodically to check whether the item works.
            if (!weights) {
                // In this mode, we pick the least loaded of two random backends. To offset
                // fast failures, the load generated by a failed attempt is considered
                // to linger until the end of this `proxy` action. (TODO: a timer?)
                // Number of backends not yet picked in the current round.
                auto have = items.size() - unshifted.size();
                if (have > 1) {
                    size_t a = rng.Uniform(have * 2);
                    size_t b = rng.Uniform(have - 1);
                    // One bit is used as a tiebreaker in case both backends turn out
                    // to have the exact same load; choose a random one in that case.
                    bool tiebreaker = a % 2;
                    a /= 2;        // [0, have * 2) -> [0, have)
                    b += (b >= a); // [0, have - 1) -> [0, a) | [a + 1, have)
                    // Remember the compressed coordinates before mapping to real indices.
                    size_t oa = a;
                    size_t ob = b;
                    // An extension of the above trick to an arbitrary number of skipped points.
                    // The order is important: i-th element of `unshifted` is in [0, N - i).
                    for (size_t j = unshifted.size(); j--;) {
                        a += (a >= unshifted[j]);
                        b += (b >= unshifted[j]);
                    }
                    // Branchless version of `tiebreaker ? X <= Y : X < Y` for integers.
                    bool useA = loads[a] < loads[b] + tiebreaker;
                    order.push_back(useA ? a : b);
                    unshifted.push_back(useA ? oa : ob);
                } else {
                    // No choice, pick the only remaining one.
                    size_t result = 0;
                    for (size_t j = unshifted.size(); j--;) {
                        result += (result >= unshifted[j]);
                    }
                    order.push_back(result);
                    // Next attempt will start going over all backends again.
                    unshifted.clear();
                }
                // Account for the attempt; undone by the Y_DEFER above.
                loads[order.back()]++;
            } else if (i == order.size() || (*weights)[order[i]] == std::numeric_limits<float>::infinity()) {
                // Only weight-0 backends remain, need a new sample.
                // (Also taken on the very first call, when `order` is still empty.)
                i = 0;
                order = NSrv::RandomSample(rng, *weights, Min(remaining + remainingFast + 1, weights->size()));
            }
            return i++;
        };

        // Performs one attempt against the selected backend. Returns true on success;
        // otherwise returns false with the failure reason left in `mun_errno`.
        // `wasFastError` remembers whether the previous attempt ended with a "fast" error.
        auto sendOne = [&, wasFastError = false](NSv::IStreamPtr& req) mutable {
            // ++req->Stats()->BackendAttempts;

            remaining--;
            auto t = NSv::Timer(time);
            auto i = selectBackend();
            if (!wasFastError) {
                // XXX should we count fast errors against rate limit or not?
                // Position 0 is treated as a new incoming request, anything else as an extra
                // outgoing one; the update keeps `overload` as a decayed out/in ratio.
                // NOTE(review): in weighted mode `selectBackend` resets the position to 0 when
                // it resamples, so such a retry looks like a new incoming request -- confirm
                // this is intended.
                overload.Update([=](double v) {
                    return v / (1 - attemptRateDecay + (i ? 0 : attemptRateDecay * v));
                });
            }
            wasFastError = false;
            // Override backends, when present, take precedence over the configured ones.
            if ((overrides ? overrides : items)[order[i]](req, opts)) {
                return true;
            }
            // Save the error code before anything else can overwrite it.
            int err = mun_errno;
            switch (err) {
            // These two end the attempt without being logged as a proxy error.
            case ECANCELED: return false;
            case EREQDONE: return false;
            case EADDRNOTAVAIL: break;
            case ENETUNREACH: break;
            case ECONNREFUSED:
                // ++req->Stats.ConnRefused;
                break;
            case ETIMEDOUT:
                // ++req->Stats.ConnTimeout;
                // req->Stats.ConnectTime = TDuration::Zero();
                // req->Stats.FullBackendTime = TDuration::Zero();
                // req->Stats.StreamStartTime = TInstant::Zero();
                break;
            case ERTIMEDOUT:
                // ++req->Stats()->BackendTimeout;
                // req->Stats()->FullBackendTime = TDuration::Zero();
                // req->Stats()->StreamStartTime = TInstant::Zero();
                break;
            case EWTIMEDOUT:
                // ++req->Stats()->BackendTimeout;
                // req->Stats()->FullBackendTime = TDuration::Zero();
                // req->Stats()->StreamStartTime = TInstant::Zero();
                break;
            case ECONNRESET:
                // ++req->Stats()->ConnReset;
                break;
            }
            // ++req->Stats.BackendError;
            // XXX when is connection timeout a fast error?
            // A fast error is paid for from the `attempts-fast` budget while it lasts, so the
            // regular attempt counter is restored.
            if ((wasFastError = EqualToOneOf(err, ENETUNREACH, ECONNREFUSED, ETIMEDOUT) && remainingFast)) {
                remainingFast--;
                remaining++;
            }
            req->Log().Push<NSv::NEv::TProxyError>(err);
            return false;
        };

        // Whether the request can be replayed, i.e. whether a retry or the fallback is possible.
        bool canRetry = true;
        if (!hedgeds) {
            // Sequential mode: try backends one after another until one succeeds, the request
            // is cancelled/completed, or a retry is no longer possible or allowed.
            do {
                if (sendOne(req)) {
                    return true;
                }
                // NOTE(review): this reads `mun_errno` after `sendOne` has logged the error;
                // presumably logging does not modify it -- confirm.
                if (EqualToOneOf(mun_errno, ECANCELED, EREQDONE)) {
                    return false;
                }
                req = original;
                canRetry = original->Replay() == 0;
            } while (canRetry && remaining && overload < attemptRate && cone::time::clock::now() < deadline);
        } else {
            // XXX maybe some sort of tee stream, or rely on buffering and wait until task terminates
            //     or consumes the entire payload before replaying and retrying, etc.
            // Hedged mode: the payload is read up front so that it can be handed to
            // every concurrent attempt.
            auto payload = NSv::ReadFrom(*req);
            if (!payload MUN_RETHROW) {
                return false;
            }
            // The task that has written the response head, if any.
            struct cone* sentHead = nullptr;
            // Can't use `cone::mguard` here because we also need "cancel all tasks except X".
            TVector<cone::guard> tasks;
            // With hedgeds, `attempt-for` is used as a limit on everything; if there
            // is no response head until that point, the request is cancelled.
            bool allDone = ::cone->deadline(deadline, [&, i = size_t(0)]() mutable noexcept {
                do {
                    // Spawn one more attempt. It gets a replayable copy of the request whose
                    // response callbacks forward to the real `req`.
                    tasks.emplace_back([&] {
                        NSv::IStreamPtr fake = NSv::ConstRequestStream(*rqh, *payload,
                            [&, current = ::cone](NSv::THead& rsh) {
                                // The first attempt to produce a response head wins:
                                // every other attempt is cancelled.
                                sentHead = current;
                                for (auto& task : tasks) {
                                    if (task.get() != current) {
                                        task->cancel();
                                    }
                                }
                                return req->WriteHead(rsh);
                            },
                            [&](TStringBuf chunk) {
                                return req->Write(chunk);
                            },
                            [&](NSv::THeaderVector& tail) {
                                return req->Close(tail);
                            },
                            req->Log().Fork<NSv::NEv::TProxyAsyncAttempt>(++i)
                        );
                        // A task always ends with an error: either the one reported by
                        // `sendOne`, or EINVAL if `sendOne` claimed success.
                        if (!sendOne(fake) MUN_RETHROW) {
                            return false;
                        }
                        return !mun_error(EINVAL, "task neither failed nor wrote a response; attempt-delay cannot be used");
                    });
                    // Give the newest attempt `attempt-delay` to finish. A timeout before the
                    // overall deadline just means "start another attempt".
                    if (!::cone->timeout(attemptDelay, [&]{ return tasks.back()->wait(cone::norethrow); }) MUN_RETHROW) {
                        if (mun_errno != ETIMEDOUT || cone::time::clock::now() >= deadline) {
                            return false; // Request cancelled by client or by timeout.
                        }
                    }
                } while (!sentHead && remaining && overload < attemptRate);
                // No more attempts will be spawned; wait for the ones in flight.
                for (auto& task : tasks) {
                    if (!task->wait(cone::norethrow) MUN_RETHROW) {
                        return false;
                    }
                }
                return true;
            });
            if (allDone) {
                // The tasks *always* fail; pick the last error here.
                for (const auto& task : tasks) {
                    if (task.get() != sentHead) {
                        task->wait(cone::rethrow);
                    }
                }
                // If one task wrote the head, take its error specifically instead.
                if (sentHead && !sentHead->wait(cone::rethrow) && mun_errno == EREQDONE) {
                    return false;
                }
            } else {
                // Interrupted: remember why before the cleanup below overwrites `mun_errno`.
                bool cancelled = mun_errno != ETIMEDOUT;
                cone::uninterruptible([&] {
                    for (const auto& task : tasks) {
                        task->cancel();
                    }
                    for (const auto& task : tasks) {
                        task->wait(cone::rethrow); // consume all errors
                    }
                });
                if (cancelled) {
                    return !mun_error(ECANCELED, "request cancelled");
                }
                mun_error(ETIMEDOUT, "all attempts timed out");
            }
            canRetry = req->Replay() == 0;
        }
        // NOTE(review): this branch contains no statements and has no effect.
        if (!hedgeds && !canRetry) {
            // with hedgeds, we always can spawn another attempt
        }
        req->Log().Push<NSv::NEv::TProxyFailure>(canRetry);
        // The fallback only runs if the request can still be replayed.
        return canRetry && fallback && fallback(req);
    };
}

// Presumably registers `Proxy` as the factory for the action named "proxy"; the macro is
// defined elsewhere -- confirm against its definition.
SV_DEFINE_ACTION("proxy", Proxy);
