#include <balancer/serval/contrib/cone/cold.h>
#include <balancer/serval/contrib/cone/cone.hh>
#include <balancer/serval/core/address.h>
#include <balancer/serval/core/config.h>
#include <balancer/serval/core/io.h>
#include <balancer/serval/core/log.h>
#include <balancer/serval/core/tls.h>
#include <balancer/serval/core/http.h>

#include <library/cpp/cgiparam/cgiparam.h>
#include <library/cpp/getopt/small/last_getopt.h>
#include <library/cpp/threading/hot_swap/hot_swap.h>

#include <util/datetime/base.h>
#include <util/generic/hash_set.h>
#include <util/stream/file.h>
#include <util/string/split.h>
#include <util/system/info.h>
#include <util/system/mlock.h>

#include <errno.h>
#include <signal.h>
#include <sys/socket.h>

// Render a YAML error as "In <file> line <L> column <C>: <msg>", followed by up to
// three lines before and two lines after the offending line (marked with " -> ").
//
// `path` is the top-level config path; NSv::MapYAMLLine translates the error mark
// into the file and line the error actually came from. If no line is known, only
// the file and message are reported.
//
// This is called from inside `catch` handlers, so it must not throw just because
// the source file can no longer be read (removed, permissions changed, ...): in
// that case the context snippet is omitted and only the header line is returned.
static TString FormatYamlError(const YAML::Exception& e, const TString& path) {
    auto [file, line] = NSv::MapYAMLLine(path, e.mark.line);
    if (line < 0) {
        return TStringBuilder() << "In " << file << ": " << e.msg << Endl;
    }
    TStringBuilder out;
    out << "In " << file << " line " << line + 1 << " column " << e.mark.column + 1 << ": " << e.msg << Endl;
    TString contents;
    try {
        contents = TFileInput(file).ReadAll();
    } catch (...) {
        // Best effort: the context snippet is optional, the error itself is not.
        return out;
    }
    TVector<TString> lines = StringSplitter(contents).Split('\n');
    for (int j = Max<int>(line, 3) - 3; j < Min<int>(line + 3, lines.size()); j++) {
        out << (line == j ? " -> " : "    ") << lines[j] << Endl;
    }
    return out;
}

// A reference-counted list of TLS contexts. Being a TAtomicRefCount lets the
// whole list be published and replaced as one unit (it is held in a THotSwap),
// while connections that already took a snapshot keep using the old list.
struct TTLSContextList
    : public TAtomicRefCount<TTLSContextList>
    , public TVector<NSv::TTLSContext>
{
    // Expose every TVector constructor unchanged.
    using TVector<NSv::TTLSContext>::TVector;
};

// Build the list of TLS contexts described by the `tls` config section.
//
// Expected shape:
//   tls:
//     certs:          # required, one entry per certificate
//       - <path to a PEM file>
//       - {<cert path>: <key path>}   # or `<cert path>: -` if both are in one file
//     ticket-keys:    # optional, list of files read with NSv::ReadTicketKeys
//       - <path>
//     ticket-ttl: ... # optional, defaults to 28800 seconds
//
// An absent `tls` node yields an empty list. Every context gets the protocol list
// {"h2", "http/1.1"} and the same ticket keys and TTL; consecutive contexts are
// linked with SetNext() in config order.
// Throws (via CHECK_NODE/FAIL_NODE) on malformed input. Errors from reading key or
// certificate files (yexception) are re-reported at the offending YAML node.
static TTLSContextList ParseTLS(const YAML::Node& node) {
    TTLSContextList out;
    if (node) {
        // FIXME this should point to the key ("tls"), else the location might be confusing.
        CHECK_NODE(node, node.IsMap() && node["certs"], "`tls` must be a map with `certs`");
        // Subscript only once `node` is known to be a map: subscripting e.g. a scalar
        // may make yaml-cpp throw its own, far less helpful, exception instead of the
        // diagnostic above.
        const auto& certs = node["certs"];
        CHECK_NODE(certs, certs.IsSequence(), "must be a sequence, one element per certificate pair");
        TVector<std::array<ui8, 48>> keys;
        if (const auto& keyFiles = node["ticket-keys"]) {
            CHECK_NODE(keyFiles, keyFiles.IsSequence(), "must be a sequence");
            for (const auto& keyFile : keyFiles) try {
                CHECK_NODE(keyFile, keyFile.IsScalar(), "must be a string");
                TFileInput in(TString(keyFile.Scalar()));
                auto current = NSv::ReadTicketKeys(in);
                CHECK_NODE(keyFile, current, "could not read any keys from this file");
                keys.insert(keys.end(), current.begin(), current.end());
            } catch (const yexception& e) {
                FAIL_NODE(keyFile, e.what());
            }
        }
        auto ttl = NSv::Optional<TDuration>(node["ticket-ttl"], TDuration::Seconds(28800));
        for (const auto& ctx : certs) try {
            TVector<NSv::TCertDescr> parsed;
            CHECK_NODE(ctx, ctx.IsScalar() || ctx.IsMap(), "should be a path to a PEM file or "
                       "a map of `cert: key` (or `cert: -` if they are in one file)");
            if (ctx.IsScalar()) {
                parsed.emplace_back(TString(ctx.Scalar()));
            } else {
                for (const auto& cert : ctx) {
                    CHECK_NODE(cert.first, cert.first.IsScalar(), "should be a path");
                    CHECK_NODE(cert.second, cert.second.IsScalar(), "should be a path");
                    if (cert.second.Scalar() == "-") {
                        // Certificate and key live in the same file.
                        parsed.emplace_back(TString(cert.first.Scalar()));
                    } else {
                        parsed.emplace_back(TString(cert.second.Scalar()), TString(cert.first.Scalar()));
                    }
                }
            }
            // TODO ciphers
            out.push_back(NSv::TTLSContext({
                .Certs = parsed,
                .Protocols = {"h2", "http/1.1"},
                .TicketKeys = keys,
                .TicketTTL = ttl,
            }));
        } catch (const yexception& e) {
            FAIL_NODE(ctx, e.what());
        }
        for (size_t i = 1; i < out.size(); i++) {
            out[i - 1].SetNext(out[i]);
        }
    }
    return out;
}

// The result of parsing the `actions` section. Reference-counted so that a
// reload can swap in a new instance while in-flight requests finish on the old one.
struct TActionConfig : public TAtomicRefCount<TActionConfig> {
    // Entry point for every request: the action named "main".
    NSv::TAction Main;
    // Registry holding the named actions added by ParseActions; also the source of
    // the signals reported by the admin /stat handle.
    NSv::TAuxData Aux;
};

// Parse the `actions` section: a sequence of single-key maps `name: definition`.
// Each definition is registered in the config's Aux under its name (a duplicate
// name is an error). Afterwards the action called "main" is resolved into `Main`
// and Aux is frozen. An absent section still yields a config; "main" is looked up
// regardless. Throws via CHECK_NODE on malformed input.
static THolder<TActionConfig> ParseActions(const YAML::Node& nodes) {
    auto config = MakeHolder<TActionConfig>();
    if (nodes) {
        CHECK_NODE(nodes, nodes.IsSequence(), "`actions` must be a sequence of action definitions");
        for (const auto& item : nodes) {
            CHECK_NODE(item, item.IsMap() && item.size() == 1, "`actions` items must be `name: definition`");
            auto name = item.begin()->first;
            auto definition = item.begin()->second;
            CHECK_NODE(name, name.IsScalar(), "action name must be a string");
            CHECK_NODE(name, config->Aux.AddAction(TString(name.Scalar()), definition),
                       "action `" << name.Scalar() << "` already exists");
        }
    }
    config->Main = config->Aux.Action("main", YAML::Node{});
    config->Aux.Freeze();
    return config;
}

using TBindPoints = TVector<std::pair<NSv::TFile, TStringBuf /*scheme*/>>;

// Parse a non-empty list of bind point URLs (the `bind` or `admin` section) and,
// unless `validation` is set, bind a socket for each one.
//
// Accepted schemes: http, https, udp, tcp, tcps. https/tcps require `haveSSL`
// (at least one TLS context); udp requires `allowUDP`. URLs must not have a path.
// udp binds a SOCK_DGRAM socket, everything else uses the default
// NSv::TFile::Bind; `reuseAddr` is forwarded to both.
// In validation mode every address is checked but nothing is bound, and the
// returned list is empty.
// Throws (via CHECK_NODE/FAIL_NODE) on the first invalid entry or failed bind.
static TBindPoints ParseBind(const YAML::Node& bind, bool haveSSL, bool allowUDP, bool reuseAddr, bool validation = false) {
    TBindPoints out;
    CHECK_NODE(bind, bind.IsSequence() && bind.size(), "bind point list must be a non-empty sequence");
    for (const auto& addr : bind) {
        CHECK_NODE(addr, addr.IsScalar(), "addresses should be strings");
        auto url = NSv::URL::Parse(TString(addr.Scalar()));
        CHECK_NODE(addr, url, mun_last_error()->text);
        CHECK_NODE(addr, url->Path == "", "bind points do not have paths");
        auto scheme = NSv::Interned(url->Scheme, "http", "https", "udp", "tcp", "tcps");
        CHECK_NODE(addr, scheme, "bind points must be http:, https:, udp:, tcp: or tcps:");
        CHECK_NODE(addr, !EqualToOneOf(url->Scheme, "https", "tcps") || haveSSL, "https: and tcps: bind points require at least one TLS context");
        CHECK_NODE(addr, !EqualToOneOf(url->Scheme, "udp") || allowUDP, "udp: is not allowed here");
        auto net = NSv::IP::Parse(url->Host, url->Port);
        CHECK_NODE(addr, net, "bind point does not specify a valid address");
        if (!validation) {
            auto sk = scheme == "udp" ? NSv::TFile::Bind<SOCK_DGRAM>(net, reuseAddr) : NSv::TFile::Bind(net, reuseAddr);
            if (!sk MUN_RETHROW) {
                FAIL_NODE(addr, "binding failed: (" << mun_errno << ") " << mun_last_error()->text);
            }
            out.emplace_back(std::move(sk), scheme);
        }
    }
    return out;
}

// Send a complete, uncacheable plain-text (UTF-8) response with status `code`
// and body `data`, then close the stream. Returns false as soon as writing the
// head, writing the body or closing fails.
static bool Respond(NSv::IStreamPtr req, int code, TStringBuf data) {
    const TString contentLength = ToString(data.size());
    if (!req->WriteHead({code, {
        {"cache-control", "no-cache"},
        {"content-length", contentLength},
        {"content-type", "text/plain; charset=utf-8"}
    }})) {
        return false;
    }
    return req->Write(data) && req->Close();
}

static auto SignalPipe = []() -> std::pair<NSv::TFile, NSv::TFile> {
    int p[2];
    Y_ENSURE(::pipe(p) >= 0 && cold_unblock(p[0]) >= 0, "could not create a signal pipe");
    return {p[0], p[1]};
}();

// Handler for SIGTERM/SIGINT. It resets the signal to its default disposition,
// so sending the same signal again terminates the process immediately. It then
// writes the signal number into SignalPipe, where `main` is waiting for it.
//
// `errno` is saved and restored because the handler can interrupt code between a
// failing system call and the point where that code inspects `errno`; the calls
// made here could otherwise overwrite it.
static void SignalHandler(int num) noexcept {
    const int savedErrno = errno;
    signal(num, SIG_DFL);
    SignalPipe.second.Write({(const char*)&num, sizeof(int)});
    errno = savedErrno;
}

// Entry point of the serval balancer.
//
// Command line: `-c/--config <path>` (required), `-l/--log <path>`,
// `-K/--validate` (only validate the config), `-P/--print` (only dump the loaded
// YAML), `-L/--mlock` (LockAllMemory(LockCurrentMemory)).
//
// In print/validate mode, errors go to stderr and the exit code is 1; on success
// the exit code is 0. Otherwise:
//   1. the log is opened;
//   2. the config is applied, which spawns worker threads and admin listeners;
//   3. the main thread blocks until SignalHandler writes a signal number into
//      SignalPipe.
// Startup failures exit with 1. After a signal, the result is
// `!SignalPipe.first.ReadInto(...)`, presumably 0 when the read succeeded.
int main(int argc, const char** argv) {
    // Write errors on closed sockets are reported through return values rather than
    // SIGPIPE; SIGTERM/SIGINT are forwarded to the end of `main` via SignalPipe.
    signal(SIGPIPE, SIG_IGN);
    signal(SIGCHLD, SIG_IGN);
    signal(SIGTERM, &SignalHandler);
    signal(SIGINT, &SignalHandler);

    TString configPath;
    TString logPath;
    bool validateOnly = false;
    bool printOnly = false;
    bool mLock = false;
    NLastGetopt::TOpts opts = NLastGetopt::TOpts::Default();
    opts.SetFreeArgsNum(0);
    opts.AddLongOption('c', "config", "Path to the YAML configuration file.")
        .Required().RequiredArgument("<path>").StoreResult(&configPath);
    opts.AddLongOption('l', "log", "Where to write the log file.")
        .RequiredArgument("<path>").StoreResult(&logPath);
    opts.AddLongOption('K', "validate", "do not start serval, just validate config")
            .StoreValue(&validateOnly, true).NoArgument();
    opts.AddLongOption('P', "print", "do not start serval, just print the raw config")
            .StoreValue(&printOnly, true).NoArgument();
    opts.AddLongOption('L', "mlock", "lock the binary in memory")
            .StoreValue(&mLock, true).NoArgument();
    NLastGetopt::TOptsParseResult optParse(&opts, argc, argv);

    if (mLock) {
        LockAllMemory(LockCurrentMemory);
    }

    // The destruction order:
    //  1. stop listening for admin connections (doing this before worker shutdown allows
    //     assuming that if the admin API is accessible, then so is the actual service,
    //     and also that nobody will attempt to reload the config while we're waiting);
    //  2. close existing admin connections;
    //  3. shutdown all workers;
    //  4. drop the rest of the stuff, nothing uses it anymore.
    // Locals are destroyed in reverse declaration order, so the declarations below must
    // stay in this order.
    cone::mutex coneStatsMut;
    // Per-worker (coroutine count, scheduling delay) counters, summed by /stat.
    THashSet<std::pair<const std::atomic<unsigned>*, const std::atomic<mun_usec>*>> coneStats;
    THashMap<TStringBuf, NSv::TAction> adminActions;
    THotSwap<TActionConfig> actions;
    THotSwap<TTLSContextList> tls(new TTLSContextList);
    NSv::TNumber<ui64> configReloads{"config-reloads_dmmm"};
    NSv::TNumber<ui64> connections{"conn_dmmm"};
    NSv::TNumber<ui64> connectionsNow{"conn-active_ammv"};
    NSv::THistogram connectionTime{"conn-time_dhhh"};
    NSv::TLog log{4096};
    TBindPoints workerSockets;
    TBindPoints adminSockets;
    TVector<cone::guard> workers;
    // Wait for termination simultaneously. This actually cancels each coroutine twice,
    // the second time being in cone::guard::~guard, but whatever. `cone::mguard` does
    // the same thing correctly, but it can't store references to threads.
    Y_DEFER { for (auto& c : workers) c->cancel(); };
    cone::mguard adminConnections;
    cone::mguard adminListeners;

    // Returns a coroutine body that accepts connections on `sk` until it is cancelled
    // and runs `f` on each connection in its own coroutine inside `conns`. When
    // `hasTLS` is set, the TLS list current at accept time is captured, the connection
    // is wrapped with its first context, and `f` is skipped if wrapping fails.
    auto accept = [&](NSv::TFile& sk, bool hasTLS, cone::mguard& conns, auto&& f) noexcept {
        return [&, hasTLS, f]() mutable {
            // If coroutine scheduling delay is high due to overload, this limits the rate
            // at which connections are accepted.
            while (cone::yield()) {
                auto child = NSv::TFile::Accept(sk);
                if (!child MUN_RETHROW_OS) {
                    if (mun_errno == ECANCELED) {
                        return false;
                    }
                    // Everything else is either EAGAIN-like (e.g. errors on the new socket)
                    // or should never happen (e.g. EBADF), except for EMFILE and ENFILE.
                    // TODO on E{M,N}FILE, close some connections or something.
                    continue;
                }
                conns.add([&, fd = std::move(child), tls = hasTLS ? tls.AtomicLoad() : nullptr, f]() mutable {
                    if (!tls) {
                        return f(fd);
                    }
                    auto io = (*tls)[0].Wrap(fd);
                    return io && f(*io);
                });
            }
            return false;
        };
    };

    // Body of a worker thread. It registers this thread's coroutine counters in
    // `coneStats` (for /stat) and starts one listener per bind point in `sockets`
    // (UDP, raw TCP or HTTP depending on the scheme). It then waits on an event
    // that nothing signals, i.e. until the thread is cancelled. Each request's
    // ":scheme" of "unknown" is replaced with the bind point's scheme before it is
    // dispatched to the current `main` action.
    auto runWorker = [&](TBindPoints& sockets, NSv::TNumber<ui64>* emptyWritesSignal) {
        NSv::TThreadLocalRoot tlsRoot;
        if (auto g = coneStatsMut.guard()) {
            coneStats.insert({cone::count(), cone::delay()});
        }
        Y_DEFER {
            auto g = coneStatsMut.guard();
            coneStats.erase({cone::count(), cone::delay()});
        };
        // As above, destruction order is important.
        cone::mguard accepted;
        cone::mguard listeners;
        for (auto& [socket, proto] : sockets) {
            NSv::THandler handler = [&, proto = proto](NSv::IStreamPtr s) {
                auto rqh = s->Head();
                if (!rqh MUN_RETHROW) {
                    return false;
                }
                auto it = rqh->find(":scheme");
                if (it != rqh->end() && it->second == "unknown") {
                    it->second = proto;
                }
                return actions.AtomicLoad()->Main(s);
            };
            if (proto == "udp") {
                listeners.add([&, &socket = socket, handler]() mutable {
                    return NSv::TestUDPServer(socket, handler, NSv::TLogFrame(log))->Wait();
                });
                continue;
            }
            bool raw = EqualToOneOf(proto, "tcp", "tcps");
            bool encrypt = EqualToOneOf(proto, "https", "tcps");
            listeners.add(accept(socket, encrypt, accepted, [&, raw, handler](NSv::IO& io) {
                connections++;
                connectionsNow++;
                Y_DEFER { connectionsNow--; };
                auto timer = NSv::Timer(connectionTime);
                auto frame = NSv::TLogFrame(log).Fork<NSv::NEv::TAcceptConnection>(io.Peer().FormatFull());
                return raw ? NSv::RawTCPServer(io, handler, std::move(frame))->Wait()
                           : NSv::Serve(io, handler, emptyWritesSignal, std::move(frame));
            }));
        }
        return cone::event{}.wait();
    };

    // Admin HTTP dispatcher: "/<name>" or "/admin?action=<name>" selects an entry of
    // `adminActions`; anything else gets a 404 pointing at /help.
    auto adminHandle = [&](NSv::IO& io) {
        return NSv::Serve(io, [&](NSv::IStreamPtr req) {
            auto rqh = req->Head();
            if (!rqh) {
                return false;
            }
            for (const auto& [path, action] : adminActions) {
                if (rqh->Path() == TString::Join("/", path) || rqh->PathWithQuery == TString::Join("/admin?action=", path)) {
                    return action(req);
                }
            }
            return Respond(req, 404, TStringBuilder() << "Unknown handle " << rqh->PathWithQuery << "; try /help\n");
        }, nullptr);
    };

    // Ok, this lambda is super long, but there's no other way: every step is basically
    // split in two to ensure atomicity (w.r.t. errors, not worker threads) of the process.
    // Returns an empty TMaybe on success, or a human-readable error message.
    auto reloadConfig = [&]() -> TMaybe<TString /* error */> {
        try {
            auto yaml = NSv::LoadYAML(configPath);
            CHECK_NODE(yaml, yaml.IsMap(), "config should be a map with at least `actions`, `bind`, and `admin`");
            CHECK_NODE(yaml, yaml["admin"], "no admin interfaces defined");
            CHECK_NODE(yaml, yaml["bind"], "no bind points defined");

            // TLS setup:
            auto newTLS = MakeHolder<TTLSContextList>(ParseTLS(yaml["tls"]));
            // Might be active bind points using TLS -- these need to be shut down.
            CHECK_NODE(yaml, *newTLS || !*tls.AtomicLoad(), "removing all TLS contexts requires a restart");
            // Actions setup:
            auto newActions = ParseActions(yaml["actions"]);
            auto n = NSv::Optional(yaml["workers"], NSystemInfo::CachedNumberOfCpus(), [](size_t n) {
                return n > 0;
            });

            // Bind with SO_REUSEADDR
            bool reuseAddr = NSv::Optional(yaml["reuseaddr"], false);

#if SV_CAN_MULTIBIND
            // Workers setup:
            // TODO reconstruct bind points in existing workers?
            TVector<TBindPoints> newSockets;
            // Bind in the admin thread, then move the file descriptors to the worker. This
            // allows reporting bind errors as exceptions, and also linearizes the calls,
            // avoiding contention on a spinlock (see BALANCER-1425).
            for (size_t i = workers.size(); i < n; i++) {
                newSockets.emplace_back(ParseBind(yaml["bind"], !!*newTLS, true, reuseAddr));
            }
            // Admin setup:
            // A udp listener is also a udp connection; which `cone::mguard` to put it into?
            auto newAdminBind = ParseBind(yaml["admin"], !!*newTLS, false, reuseAddr);
#else
            // Workers setup:
            // Bind points likely haven't changed, so an attempt to rebind will fail.
            // Could close the old sockets and try, but then how to rollback on error?
            if (!workerSockets) {
                workerSockets = ParseBind(yaml["bind"], !!*newTLS, true, reuseAddr);
            }
            // Admin setup:
            if (!adminSockets) {
                adminSockets = ParseBind(yaml["admin"], !!*newTLS, false, reuseAddr);
            }
            auto& newAdminBind = adminSockets;
#endif

            configReloads++;
            // A-a-a-and TLS commit:
            tls.AtomicStore(newTLS.Release());
            // Actions commit:
            actions.AtomicStore(newActions.Release());
            // Workers commit; must be after TLS:
            while (workers.size() > n)
                workers.pop_back();
#if SV_CAN_MULTIBIND
            for (auto& sks : newSockets) {
                workers.emplace_back(cone::thread([&, sks = std::move(sks)]() mutable {
                    // NOTE(review): the TIntrusivePtr returned by AtomicLoad() is a temporary,
                    // so after this line the pointed-to config is kept alive only by `actions`.
                    // If CustomSignal returns storage owned by this TActionConfig, a later
                    // reload that replaces `actions` may leave this worker holding a dangling
                    // pointer -- confirm ownership of custom signals.
                    NSv::TNumber<ui64>* emptyWritesSignal = &actions.AtomicLoad()->Aux.CustomSignal("http-empty-writes_summ");
                    return runWorker(sks, emptyWritesSignal);
                }));
            }
#else
            while (workers.size() < n) {
                workers.emplace_back(cone::thread([&] {
                    // NOTE(review): same lifetime concern as in the multibind branch above
                    // this pointer may outlive the TActionConfig it points into.
                    NSv::TNumber<ui64>* emptyWritesSignal = &actions.AtomicLoad()->Aux.CustomSignal("http-empty-writes_summ");
                    return runWorker(workerSockets, emptyWritesSignal);
                }));
            }
#endif
            // Admin commit; must be after TLS and workers (like destruction order, but reverse):
            cone::mguard newAdmin;
            for (auto& [sk, proto] : newAdminBind) {
                // NOTE(review): only "https" admin bind points get TLS here; a "tcps" admin
                // bind point (accepted by ParseBind) would be served as plaintext HTTP.
                newAdmin.add(accept(sk, proto == "https", adminConnections, adminHandle));
            }
            adminListeners = std::move(newAdmin);
#if SV_CAN_MULTIBIND
            adminSockets = std::move(newAdminBind);
#endif
            return {};
        } catch (const YAML::Exception& e) {
            return FormatYamlError(e, configPath);
        } catch (...) {
            return CurrentExceptionMessage() + "\n";
        }
    };

    // Performs the same parsing checks as reloadConfig, except that `workers` is not
    // read and nothing is bound or committed. In `--validate` mode `tls` still holds
    // the initial empty list, so the "removing all TLS contexts" check cannot fail.
    // Returns an empty TMaybe if the config is valid, otherwise the error message.
    auto validateConfig = [&]() -> TMaybe<TString> {
        try {
            auto yaml = NSv::LoadYAML(configPath);
            CHECK_NODE(yaml, yaml.IsMap(), "config should be a map with at least `actions`, `bind`, and `admin`");
            CHECK_NODE(yaml, yaml["admin"], "no admin interfaces defined");
            CHECK_NODE(yaml, yaml["bind"], "no bind points defined");

            auto newTLS = MakeHolder<TTLSContextList>(ParseTLS(yaml["tls"]));
            CHECK_NODE(yaml, *newTLS || !*tls.AtomicLoad(), "removing all TLS contexts requires a restart");
            auto newActions = ParseActions(yaml["actions"]);

            bool reuseAddr = NSv::Optional(yaml["reuseaddr"], false);

            Y_UNUSED(ParseBind(yaml["bind"], !!*newTLS, true, reuseAddr, true));
            Y_UNUSED(ParseBind(yaml["admin"], !!*newTLS, false, reuseAddr, true));

            return {};
        } catch (const YAML::Exception& e) {
            return FormatYamlError(e, configPath);
        } catch (...) {
            return CurrentExceptionMessage() + "\n";
        }
    };

    // Loads the config (as NSv::LoadYAML sees it) and dumps it to stdout.
    auto printConfig = [&]() -> TMaybe<TString> {
        try {
            auto yaml = NSv::LoadYAML(configPath);
            Cout << YAML::Dump(yaml) << Endl;
            return {};
        } catch (const YAML::Exception& e) {
            return FormatYamlError(e, configPath);
        } catch (...) {
            return CurrentExceptionMessage() + "\n";
        }
    };

    // Points the log at `logPath` (opened for appending); no-op without `--log`.
    auto reopenLog = [&]() -> TMaybe<TString> {
        if (logPath) try {
            log.Reopen(MakeHolder<TUnbufferedFileOutput>(TFile(logPath, OpenAlways | WrOnly | ForAppend)));
        } catch (...) {
            return CurrentExceptionMessage() + "\n";
        }
        return {};
    };

    adminActions[""] = adminActions["help"] = [&](NSv::IStreamPtr req) {
        return Respond(req, 200, TStringBuilder() <<
            "https://a.yandex-team.ru/arc/trunk/arcadia/balancer/serval\n\n"
            "Actions (/admin?action=X can also be used instead of /X):\n\n"
#if SV_CAN_MULTIBIND
            "  /reload      Hot-swap the config file. New workers and the admin API will\n"
            "               listen on new bind points, old workers continue with old ones.\n\n"
#else
            "  /reload      Hot-swap the config file. Changes in bind points are ignored.\n\n"
#endif
            << (logPath ? "  /reopenlog   Flush and reopen the log file in case it was rotated.\n\n" : "") <<
            "  /stat        Print all non-zero unistat signals.\n\n"
            "  /ping        Pong.\n\n"
            "  /help        This message.\n\n"
            "Graceful shutdown is initiated by SIGINT or SIGTERM. Send again to terminate.\n");
    };

    adminActions["ping"] = [&](NSv::IStreamPtr req) { return Respond(req, 200, "pong\n"); };

    // XXX /_golovan is allowed because we sometimes want itype=apphost. Don't use it.
    adminActions["stat"] = adminActions["_golovan"] = [&](NSv::IStreamPtr req) {
        ui64 cones = 0;
        ui64 delay = 0;
        if (auto g = coneStatsMut.guard()) {
            for (auto v : coneStats) {
                cones += v.first->load();
                delay += v.second->load();
            }
        }
        // The delay is averaged (rounded up) over `workers.size()`. Admin listeners are
        // only started after at least one worker exists (`workers` is validated as n > 0).
        return Respond(req, 200, NSv::SerializeSignals(
            NSv::TNumber<ui64>("coroutines-active_ammv", cones),
            NSv::TNumber<ui64>("coroutines-delay_avvv", (delay + workers.size() - 1) / workers.size()),
            NSv::TNumber<ui64>("log-overflow_dmmm", log.Overflow()),
            NSv::TNumber<ui64>("log-capacity_ammv", log.Capacity()),
            configReloads, connections, connectionsNow, connectionTime,
            actions.AtomicLoad()->Aux.Signals()) + "\n");
    };

    // Serialized with its own mutex; a request interrupted while waiting for the lock
    // returns false.
    adminActions["reopenlog"] = [&, lock = MakeHolder<cone::mutex>()](NSv::IStreamPtr req) {
        auto guard = lock->guard(cone::mutex::interruptible);
        if (!guard) {
            return false;
        }
        auto err = reopenLog();
        return err ? Respond(req, 500, *err) : Respond(req, 200, "OK\n");
    };

    // Serialized with its own mutex so that reloads never overlap. An optional
    // `new_config_path` query parameter replaces the config path before reloading.
    adminActions["reload"] = [&, lock = MakeHolder<cone::mutex>()](NSv::IStreamPtr req) {
        auto guard = lock->guard(cone::mutex::interruptible);
        if (!guard) {
            return false;
        }

        // NOTE(review): `configPath` is replaced even if the reload below fails, so
        // later plain /reload calls keep using the new (possibly broken) path.
        if (req->Head()) {
            auto cgi = TCgiParameters(req->Head()->Query());
            auto ptr = cgi.find("new_config_path");
            if (ptr != cgi.end()) {
                configPath = ptr->second;
            }
        }

        auto err = reloadConfig();
        return err ? Respond(req, 500, *err) : Respond(req, 200, "OK\n");
    };

    if (printOnly) {
        auto err = printConfig();
        if (err) {
            Cerr << *err;
            return 1;
        } else {
            return 0;
        }
    }

    if (validateOnly) {
        auto err = validateConfig();
        if (err) {
            Cerr << *err;
            return 1;
        } else {
            return 0;
        }
    }

    if (auto err = reopenLog()) {
        Cerr << *err;
        return 1;
    }
    if (auto err = reloadConfig()) {
        Cerr << *err;
        return 1;
    }
    // Block until SignalHandler writes a signal number into the pipe.
    int num;
    Y_DEFER {
        Cerr << Endl;
    }; // so that ^C-ing in a terminal does not break readline
    return !SignalPipe.first.ReadInto({(const char*)&num, sizeof(int)});
}
