#include "utils.h"

#include <balancer/serval/core/tls.h>
#include <balancer/serval/mod/proxy/proxy.ev.pb.h>

#include <util/stream/file.h>

namespace NSv {
    // TODO(velavokr): support websockets properly
    // Removes every `connection`, `keep-alive` and `proxy-connection` header from `head`.
    // If an `upgrade` header is present afterwards, a `connection: upgrade` header is
    // added back. Returns `head` itself so the call can be used inline.
    THead& RemoveConnectionHeaders(THead& head) noexcept {
        for (auto&& name : {"connection", "keep-alive", "proxy-connection"}) {
            // A header may occur several times; look it up again after each erase.
            auto pos = head.find(name);
            while (pos != head.end()) {
                head.erase(pos);
                pos = head.find(name);
            }
        }
        if (head.has("upgrade")) {
            head.emplace("connection", "upgrade");
        }
        return head;
    }

    // Takes a reusable connection off the keepalive stack and stores it in `out`.
    //
    // Returns false only when the lock guard could not be obtained. Otherwise returns
    // true; `out` is left default-assigned if the `!*this` check fires (presumably
    // meaning the stack is empty — confirm against the class definition).
    bool TKeepAliveStack::Pop(TConnEntry& out) noexcept {
        auto tid = TThread::CurrentThreadId();
        do {
            // Drop whatever `out` held (from the caller or from the previous, rejected
            // iteration). This happens before the lock is requested.
            // XXX this clears `out` correctly, but damn, that's totally not obvious.
            out.Client.Reset();
            out.IO.Reset();
            out = {};
            // NOTE(review): the guard is requested as `interruptible` and tested for
            // failure; presumably it is falsy when the wait gets cancelled — confirm.
            auto g = Lock.guard(cone::mutex::interruptible);
            if (!g) {
                return false;
            }
            if (!*this) {
                return true;
            }
            // Pick among the newest ones, so that unnecessary old ones are cleaned up;
            // but prefer the ones created on this thread to make migration unnecessary.
            // Only the last (up to) 16 entries are scanned for a same-thread match.
            auto end = rbegin() + Min<size_t>(size(), 16);
            auto cur = std::find_if(rbegin(), end, [&](auto& e) {
                return e.Thread == tid;
            });
            if (cur == end) {
                // No entry from this thread among the candidates: take the newest one.
                cur = rbegin();
            }
            out = std::move(*cur++); // move to previous element before calling base()...
            erase(cur.base()); // ...because it points to the element after `*cur`.
            // `g` goes out of scope at the end of this block, so the usability check in
            // the loop condition below runs after the guard has been destroyed.
        // Retry while the popped entry is unusable: one pushed by this thread must still
        // report `IsOpen()`, one pushed by another thread must successfully `Migrate()`.
        } while (out.Thread == tid ? !out.Client->IsOpen() : !out.Client->Migrate());
        return true;
    }

    // Stores `in` on the keepalive stack, stamped with the current thread id and an
    // expiry time of now + `timeout`. A zero `timeout` makes this a no-op: `in` is
    // neither modified nor stored.
    void TKeepAliveStack::Push(TConnEntry&& in, cone::timedelta timeout) noexcept {
        const bool keepaliveDisabled = (timeout == cone::timedelta::zero());
        if (keepaliveDisabled) {
            return;
        }
        // Fill in the bookkeeping fields before taking the lock.
        in.CloseAt = cone::time::clock::now() + timeout;
        in.Thread = TThread::CurrentThreadId();
        auto held = Lock.guard();
        push_back(std::move(in));
    }

    bool Exchange(const TBackendOptions& opts, IStream& req, IConnection& conn, TLogFrame log) {
        auto rqh = req.Head();
        if (!rqh MUN_RETHROW) {
            return false;
        }

        // TInstant backendStartTime = TInstant::Now();
        // req.Stats.StreamStartTime = backendStartTime;
        // Y_DEFER {
        //     req.Stats.FullBackendTime = TInstant::Now() - backendStartTime;
        // };

        auto t = Timer(opts.ExchangeTime);
        // TODO on a HTTP version mismatch, convert websocket handshakes. Here are the server sides:
        //   * h1: check `:method: GET`, `upgrade: websocket`, `connection: upgrade`,
        //     and `sec-websocket-version: 13`; append '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
        //     to `sec-websocket-key` (client-generated nonce), hash with SHA-1, encode in base64,
        //     send as `sec-websocket-accept` in a 101 response.
        //   * h2: set enable_connect setting to 1, check `:method: CONNECT` and
        //     `:protocol: websocket`, respond with a normal 200.
        if (conn.IsH2()) {
            rqh->erase("upgrade");
        }
        // XXX this is kind of hacky and may or may not discard zero-length payload
        //     depending on whether it arrives before the reader coroutine goes to sleep.
        bool headOnly = req.AtEnd() && !*req.Tail();
        bool sendDone = req.AtEnd() || opts.Sync;
        auto sendOp = [&](auto&& f) {
            auto ok = ::cone->timeout(opts.SendTimeout, f);
            if (!ok && mun_errno == ETIMEDOUT) {
                mun_errno = EWTIMEDOUT;
            }
            return ok;
        };
        auto up = sendOp([&] {
            return conn.Request(RemoveConnectionHeaders(*rqh), !headOnly, std::move(log));
        });
        if (!up MUN_RETHROW) {
            return false;
        }
        auto sendFn = [&] {
            while (auto chunk = req.Read()) {
                if (!sendOp([&] {
                    return *chunk ? up->Write(*chunk) : up->Close(*req.Tail());
                }) MUN_RETHROW) {
                    // Ignore upstream disconnect: next upstream read will get ECONNRESET, unless
                    // the response has already been received, in which case discard the payload.
                    return mun_errno == EPIPE;
                }
                if (!*chunk) {
                    return true;
                }
                if (!req.Consume(chunk->size()) MUN_RETHROW) {
                    return false;
                }
            }
            return false;
        };
        bool recvActive = false;
        auto recvDeadline = ::cone->deadline(cone::time::max());
        auto recvOp = [&](auto&& f) {
            recvActive = true;
            Y_DEFER {
                recvActive = false;
                recvDeadline = ::cone->deadline(cone::time::max());
            };
            // Do not set recv timeout until send is complete. This prevents us from bouncing
            // a working backend that is simply stuck waiting for the complete payload.
            auto ok = sendDone ? ::cone->timeout(opts.RecvTimeout, f) : f();
            if (!ok && mun_errno == ETIMEDOUT) {
                mun_errno = ERTIMEDOUT;
            }
            return ok;
        };
        auto recvFn = [&] {
            while (auto rsh = recvOp([&]{
                return up->Head();
            })) {
                // if (100 <= rsh->Code && rsh->Code < 600) {
                //     ++req.Stats.BackendStatusCodeCategs[rsh->Code / 100 - 1];
                // }

                for (auto [code, fast] : opts.RetryCodes) {
                    if (code > 0 ? rsh->Code == code : -code <= rsh->Code && rsh->Code < -code + 100) {
                        return fast ? !mun_error(ECONNREFUSED, "backend is unavailable")
                                    : !mun_error(ECONNRESET, "backend responded with an error");
                    }
                }
                int code = rsh->Code;
                if (!req.WriteHead(RemoveConnectionHeaders(*rsh)) MUN_RETHROW) {
                    return false;
                }
                (*opts.StatusCodes[(std::min(code, 599) / 100) - 1])++;
                if (rsh->IsInformational()) {
                    continue;
                }
                // `Close` never succeeds, so this does not return an empty chunk more than once.
                while (auto chunk = recvOp([&]{
                    return up->Read();
                })) {
                    // FIXME `up->Consume` is a send op, needs timeout.
                    if (*chunk ? !req.Write(*chunk) || !up->Consume(chunk->size()) : !req.Close(*up->Tail()) MUN_RETHROW) {
                        return false;
                    }
                }
                return false;
            }
            return false;
        };
        if (sendDone) {
            return (headOnly || sendFn()) && recvFn();
        }
        cone::guard send = [&, recv = ::cone] {
            Y_DEFER {
                sendDone = true;
            };
            // Client errors and cancellation are ignored because then the recv coroutine
            // is already terminating. We do not want the `send->wait`, if it happens,
            // to overwrite the recv error.
            bool ok = sendFn() || EqualToOneOf(mun_errno, ECANCELED, EPIPE, ECONNRESET);
            // On a send error, should abort `recv`. The easiest way is to set a zero
            // timeout; ignoring `recvActive` then aborts any client write ops as well.
            recvDeadline = recv->timeout(!ok ? cone::timedelta::zero() : recvActive ? opts.RecvTimeout : cone::timedelta::max());
            return ok;
        };
        // Yes, that's a single `&`; RHS is evaluated unconditionally to consume any send error.
        return recvFn() & (sendDone && send->wait(cone::rethrow));
    }

    // Builds a backend for an already-parsed `url` and resolved `ip`.
    //
    //   url               - scheme selects the transport (see the flags below).
    //   ip                - address that new connections are made to.
    //   emptyWritesSignal - counter pointer handed to `H2Client` for non-tcp schemes.
    //
    // For `udp` a backend around `TestUDPClient(ip)` is returned directly. For every
    // other scheme the backend is obtained through `StaticData` under the key
    // (scheme, host, port); presumably `StaticData` builds it once per key and reuses
    // it afterwards — confirm against its definition.
    TBackend Backend(URL url, IP ip, NSv::TNumber<ui64>* emptyWritesSignal) {
        if (url.Scheme == "udp") {
            // The client is created once here and shared by every call of the backend.
            return [url, c = TestUDPClient(ip)](IStreamPtr& req, const TBackendOptions& opts) {
                return Exchange(opts, *req, *c, req->Log().Fork<NEv::TProxyBackend>(url.ToString()));
            };
        }
        // FIXME `opts.Conn` should be in the key
        return StaticData(std::make_tuple(TString(url.Scheme), TString(url.Host), url.Port), [&]{
            bool h2 = EqualToOneOf(url.Scheme, "h2", "h2c");
            bool tcp = EqualToOneOf(url.Scheme, "tcp", "tcps");
            // TLS is used for `h2` and `https` only. The two client contexts are
            // function-local statics; index 1 (chosen when `h2` is set) additionally
            // lists the "h2" protocol. `tls` is null for all other schemes.
            auto tls = EqualToOneOf(url.Scheme, "h2", "https") ? [](bool h2) {
                static const TTLSContext ctx[] = {
                    TTLSContext({.Client = true, .Protocols = {"http/1.1"}}),
                    TTLSContext({.Client = true, .Protocols = {"h2", "http/1.1"}}),
                };
                return &ctx[h2];
            }(h2) : nullptr;
            // The returned backend captures everything above by value and owns its
            // keepalive stack `ka`.
            return [=, ka = MakeHolder<TKeepAliveStack>()](IStreamPtr& req, const TBackendOptions& opts) {
                auto log = req->Log().Fork<NEv::TProxyBackend>(url.ToString());
                TConnEntry conn;
                // Don't bother locking the mutex if we don't want (or can't use) keepalive.
                if (!tcp && opts.Keepalive != cone::timedelta::zero() && !ka->Pop(conn) MUN_RETHROW) {
                    return false;
                }
                // No reusable connection was obtained: establish a new one, with the
                // whole setup bounded by `opts.ConnTimeout`.
                if (!conn.Client && !::cone->timeout(opts.ConnTimeout, [&]{
                    auto t = Timer(opts.ConnTime);
                    // TInstant connStartTime = TInstant::Now();
                    TConnectionOptions co = opts.Conn;
                    // Without TLS there is no protocol negotiation to rely on, so HTTP/2
                    // is forced for the `h2c` scheme (`h2` set, `tls` null).
                    co.ForceH2 = h2 && !tls;
                    // XXX `IConnection` holds a reference to the `IO`, so passing fd
                    //     directly is bad, as `TConnEntry` will become immovable.
                    // Each step runs only if the previous one produced a truthy result.
                    bool ok = (conn.File = TFile::Connect(ip))
                           && (conn.IO = tls ? tls->Wrap(conn.File, TString(url.Host)) : THolder<IO>(new TFile(std::move(conn.File))))
                           && (conn.Client = tcp ? RawTCPClient(*conn.IO) : H2Client(*conn.IO, emptyWritesSignal, co));

                    // req->Stats.ConnectTime = TInstant::Now() - connStartTime;
                    // Logged whether or not the connection attempt succeeded.
                    log.Push<NEv::TProxyNewConnection>();
                    return ok;
                }) MUN_RETHROW) {
                    return false;
                }
                // Runs when this scope is left, i.e. after `Exchange` below has returned:
                // an idle connection is offered back to the keepalive stack.
                Y_DEFER {
                    if (conn.Client->IsIdle()) {
                        ka->Push(std::move(conn), opts.Keepalive);
                    }
                };
                return Exchange(opts, *req, *conn.Client, std::move(log));
            };
        });
    }

    // Stream proxy that bumps the per-status-class counter in `opts.StatusCodes`
    // each time a response head is written, then forwards the head to the wrapped stream.
    class TMetricsStream : public TStreamProxy {
    public:
        // `opts` is stored by reference, not copied.
        TMetricsStream(IStreamPtr& req, const TBackendOptions& opts) noexcept
            : TStreamProxy(req)
            , Opts_(opts)
        {}

        // Counts `head.Code` and passes `head` on; returns whatever the wrapped stream returns.
        bool WriteHead(THead& head) noexcept override {
            // Codes of 600 and above share the slot used for 5xx. Codes below 100 have no
            // slot: the computed index would be negative and fall outside `StatusCodes`,
            // so such heads are forwarded without being counted.
            if (head.Code >= 100) {
                (*Opts_.StatusCodes[std::min(head.Code, 599) / 100 - 1])++;
            }
            // if (100 <= head.Code && head.Code < 600) {
            //     ++Stats.BackendStatusCodeCategs[head.Code / 100 - 1];
            // }
            return S_->WriteHead(head);
        }

    private:
        // NOTE(review): held by reference, so the options object must outlive this
        // stream — confirm at the construction site.
        const TBackendOptions& Opts_;
    };

    // Builds a backend from a YAML node.
    //
    // If the node's scalar does not parse as a URL with a scheme, the node is turned
    // into an action via `aux.Action`; when `aux.CollectingSignals()` is truthy the
    // request stream is first wrapped in a `TMetricsStream`. Otherwise the URL must
    // have an IP-literal host, an empty path and one of the supported schemes, and the
    // work is delegated to the (URL, IP, signal) overload.
    TBackend Backend(const YAML::Node& arg, TAuxData& aux) {
        auto url = URL::Parse(TString(arg.Scalar()));
        if (!url || !url->Scheme) {
            return [action = aux.Action(arg), collect = aux.CollectingSignals()](IStreamPtr& stream, const TBackendOptions& options) {
                if (collect) {
                    stream = std::make_shared<TMetricsStream>(stream, options);
                }
                return action(stream);
            };
        }

        auto ip = IP::Parse(url->Host, url->Port);
        CHECK_NODE(arg, ip, "not an IP address (use tools/resolver for name resolution)");
        CHECK_NODE(arg, url->Path == "", "path must not be specified");
        CHECK_NODE(arg, EqualToOneOf(url->Scheme, "http", "https", "h2", "h2c", "udp", "tcp", "tcps"), "unsupported protocol");

        // Counter handed down to the connection-level backend.
        NSv::TNumber<ui64>* const emptyWrites = &aux.CustomSignal("http-empty-writes_summ");
        return Backend(*url, ip, emptyWrites);
    }

    // Loads the weights from `path` once (throwing via `Y_ENSURE` if that fails) and
    // then installs `Updater_`, which re-runs `Update()` after every 5-second sleep
    // for as long as the sleep keeps succeeding.
    TWeightsHolder::TWeightsHolder(TString path, TVector<TString> names)
        : Path_(std::move(path))
        , Names_(std::move(names))
    {
        // TODO `CHECK_NODE`
        Y_ENSURE(Update(), "could not load any non-zero weights from " << Path_);
        Updater_ = [this]() {
            for (;;) {
                if (!cone::sleep_for(std::chrono::seconds(5))) {
                    // The sleep failed: stop refreshing.
                    return false;
                }
                // A failed refresh is ignored; the next attempt follows the next sleep.
                Update();
            }
        };
    }

    bool TWeightsHolder::Update() {
        try {
            THolder<TWeightsVector> res = MakeHolder<TWeightsVector>(Names_.size(), std::numeric_limits<double>::quiet_NaN());
            TFileInput in(Path_);
            for (TString line; in.ReadLine(line); ) {
                TStringBuf value = line;
                TStringBuf name = value.NextTok(',');
                // This mode is not intended for long lists.
                auto it = Find(Names_.begin(), Names_.end(), StripString(name));
                if (it == Names_.end()) {
                    continue;
                }
                float tmp;
                if (!TryFromString(StripString(value), tmp) || tmp < 0) {
                    return false; // TODO log
                }
                (*res)[it - Names_.begin()] = 1. / tmp;
            }
            if (FindIf(res->begin(), res->end(), IsFinite) == res->end()
             || FindIf(res->begin(), res->end(), IsNan) != res->end()) {
                return false;
            }
            AtomicStore(res.Release());
            return true;
        } catch (const TFileError&) {
            return false; // ¯\_(ツ)_/¯ TODO log
        }
    }

    // Parses a status-code spec from a YAML scalar.
    //
    // A three-character class pattern `Nxx` (N a decimal digit) yields `-N * 100`;
    // anything else must be an integer strictly between 0 and 1000 and is returned as is.
    int ParseCode(const YAML::Node& node) {
        const auto& text = node.Scalar();
        const bool isClassPattern = text.size() == 3
                                 && text[0] >= '0' && text[0] <= '9'
                                 && text[1] == 'x'
                                 && text[2] == 'x';
        if (isClassPattern) {
            const int hundreds = text[0] - '0';
            return -hundreds * 100;
        }
        return Required<int>(node, [](int value) {
            return value > 0 && value < 1000;
        });
    }

    // Builds a function that derives a hash from the header named by `node`'s scalar.
    //
    // With no `!net/` tag, the hashes of all values of that header are combined; the
    // result is empty when the header is absent. With the tag `!net/<v4mask>/<v6mask>`,
    // the first value that parses as an IP address is masked (v4 or v6 mask as
    // appropriate) and hashed; the result is empty when no value parses.
    // NOTE(review): the lambdas produce `TMaybe<ui64>` while the declared result type
    // is `TMaybe<ui32>` — presumably the value is narrowed on conversion; confirm.
    TFunction<TMaybe<ui32>(const THead&)> ParseHashSpec(const YAML::Node& node) {
        // TODO other modes, e.g. !query?
        TStringBuf tag = node.Tag();
        if (!tag.SkipPrefix("!net/")) {
            return [name = node.Scalar()](const THead& head) -> TMaybe<ui64> {
                TMaybe<ui64> acc;
                auto [cur, last] = head.equal_range(name);
                for (; cur != last; ++cur) {
                    // The first value is combined with a zero seed.
                    const ui64 seed = acc ? *acc : 0;
                    acc = CombineHashes<ui64>(seed, THash<TStringBuf>{}(cur->second));
                }
                return acc;
            };
        }

        TStringBuf v4, v6;
        CHECK_NODE(node, tag.TrySplit('/', v4, v6), "syntax: !net/v4mask/v6mask");
        auto m4 = IP::ParseMask(v4, AF_INET);
        auto m6 = IP::ParseMask(v6, AF_INET6);
        CHECK_NODE(node, m4 && m6, "invalid netmasks");
        return [name = node.Scalar(), mask4 = *m4, mask6 = *m6](const THead& head) -> TMaybe<ui64> {
            auto [cur, last] = head.equal_range(name);
            for (; cur != last; ++cur) {
                auto addr = IP::Parse(cur->second);
                if (!addr) {
                    // Not an address: try the next value of the header.
                    continue;
                }
                return THash<IP::TRaw>{}(addr.Raw(addr.IsV4() ? mask4 : mask6));
            }
            return {};
        };
    }
}
