#include <library/cpp/testing/gtest/gtest.h>
#include <cstdlib>
#include <cstdio>
#include <fcntl.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
#include <unistd.h>
#include <thread>
#include <util/generic/scope.h>
#include <infra/cauth/agent/linux/yandex-cauth-userd/pkg/cauthrpc/models.pb.h>
#include <infra/cauth/agent/linux/nss_cauth_userd/http/http.h>

// Template for temporary socket paths; mktemp() replaces the trailing "XXXXXX".
const char *sockPathTemplate = "httpsockXXXXXX";

// Builds a unique, not-yet-existing path for a test socket.
//
// Returns a malloc()-allocated, NUL-terminated string that the caller must
// free(), or nullptr when allocation or name generation fails.
char *mkTmpSock() {
    // +1: strcpy() also writes the terminating NUL byte.
    char *path = static_cast<char *>(malloc(strlen(sockPathTemplate) + 1));
    if (path == nullptr) return nullptr;
    strcpy(path, sockPathTemplate);
    // Depending on the libc, mktemp() signals failure either by returning
    // NULL or by turning the template into an empty string.
    if (mktemp(path) == nullptr || path[0] == '\0') {
        free(path);
        return nullptr;
    }
    return path;
}

// Tears down a test socket: closes descriptor `fd`, then removes the
// filesystem entry at `path`. Failures of either call are ignored.
void cleanupSock(int fd, const char *path) {
    (void) close(fd);
    (void) unlink(path);
}

// Creates a listening AF_UNIX stream socket bound to `path`.
//
// Any existing filesystem entry at `path` is removed first. Returns the
// listening descriptor on success, or -1 when the path does not fit into
// sockaddr_un::sun_path, the stale entry cannot be removed, or any of
// socket()/bind()/listen() fails.
int mkSock(const char *path) {
    struct sockaddr_un addr = {};
    // sun_path is a fixed-size array; reject paths that would overflow it
    // (the terminating NUL must fit as well).
    if (strlen(path) >= sizeof(addr.sun_path)) {
        return -1;
    }
    // Remove a stale entry directly instead of stat()-then-unlink(), which
    // leaves a window between the check and the removal.
    if (unlink(path) < 0 && errno != ENOENT) {
        return -1;
    }
    int fd = socket(PF_UNIX, SOCK_STREAM, 0);
    if (fd < 0) return -1;
    addr.sun_family = AF_UNIX;
    strcpy(addr.sun_path, path);
    if (bind(fd, reinterpret_cast<struct sockaddr *>(&addr), sizeof(addr)) < 0) {
        // bind() failed, so this call did not create the entry at `path`;
        // only the descriptor needs releasing.
        close(fd);
        return -1;
    }
    if (listen(fd, 0) < 0) {
        close(fd);
        unlink(path);
        return -1;
    }
    return fd;
}

// Appends `data` to `s` as an HTTP/1.1 chunked body.
//
// The payload is split into chunks of at most `chunkSz` bytes, each written
// as "<hex size>\r\n<bytes>\r\n", followed by the terminating "0\r\n\r\n".
// A `chunkSz` of 0 is treated as "one chunk holding everything"; empty data
// produces only the terminator.
void writeChunks(size_t chunkSz, const char *data, size_t dataSz, TStringStream &s) {
    if (chunkSz == 0) {
        chunkSz = dataSz;
    }
    size_t offset = 0;
    while (offset < dataSz) {
        // Size each chunk from what is left, so the final chunk never reads
        // past the end of `data` and its advertised size matches its body.
        const size_t remaining = dataSz - offset;
        const size_t partSz = remaining < chunkSz ? remaining : chunkSz;
        // Two hex digits per byte of size_t, plus the terminating NUL.
        char hexSz[2 * sizeof(size_t) + 1] = {0};
        snprintf(hexSz, sizeof(hexSz), "%zx", partSz);
        s << hexSz << "\r\n";
        s.Write(data + offset, partSz);
        s << "\r\n";
        offset += partSz;
    }
    s << "0\r\n\r\n";
}

// Mock server: accepts one connection on `fd`, waits for a single successful
// recv() of request data, then sends `dataSz` bytes of `data` in one Send()
// call. `*srvOk` is set to false when accept(), recv() or the send fails; it
// is left untouched otherwise.
void plainServerThreadFunc(int fd, const char *data, size_t dataSz, bool *srvOk) {
    const int client = accept(fd, nullptr, nullptr);
    if (client < 0) {
        *srvOk = false;
        return;
    }

    char request[4096];
    for (;;) {
        const ssize_t got = recv(client, request, 4096, 0);
        if (got >= 0) {
            break;
        }
        // Interrupted or would-block reads are simply retried.
        if (errno == EINTR) continue;
        if (errno == EAGAIN || errno == EWOULDBLOCK) continue;
        *srvOk = false;
        close(client);
        return;
    }

    const ssize_t sent = NCAuth::NSS::Send(client, data, dataSz, [](){return false;});
    if (sent != ssize_t(dataSz)) {
        *srvOk = false;
    }
    close(client);
}

// A Content-Length response sent by plainServerThreadFunc in one piece:
// QueryUserd must return 200 and the received body must match the payload.
TEST(HTTP, SinglePlainResponse) {
    char *sockPath = mkTmpSock();
    ASSERT_NE(sockPath, nullptr);
    Y_DEFER { free(sockPath); };
    const int listenFd = mkSock(sockPath);
    if (listenFd < 0) {
        FAIL();
    }

    // Payload: a small serialized Group message.
    ::NCAuth::NCAuthAgent::NRpc::Group mockGroup;
    mockGroup.set_name("mock");
    mockGroup.set_gid(100000);
    mockGroup.add_members("warwish");
    TString payload;
    ASSERT_TRUE(mockGroup.SerializeToString(&payload));

    // Canned reply: headers followed by the raw payload.
    TStringStream reply;
    reply << "HTTP/1.1 200 OK\r\n"
             "Content-Type: application/x-protobuf\r\n"
             "Date: Wed, 06 Oct 2021 11:47:59 GMT\r\n"
             "Content-Length: " << payload.size() << "\r\n"
             "Connection: close\r\n\r\n";
    reply.Write(payload.data(), payload.size());

    bool serverOk = true;
    std::thread server(plainServerThreadFunc, listenFd, reply.Data(), reply.Size(), &serverOk);
    Y_DEFER { cleanupSock(listenFd, sockPath); };

    TStringStream body;
    body.Reserve(payload.size());
    const long status = ::NCAuth::NSS::QueryUserd(sockPath, "/nss/v1/group/mock", &body);
    server.join();

    ASSERT_TRUE(serverOk);
    ASSERT_EQ(status, 200);
    // Only the common prefix of the received body and the payload is compared.
    const auto cmpLen = Min(payload.size(), body.Size());
    ASSERT_TRUE(strncmp(body.Data(), payload.data(), cmpLen) == 0);
}

// A "Transfer-Encoding: chunked" response (10-byte chunks) sent by
// plainServerThreadFunc in one piece: QueryUserd must return 200 and the
// received body must match the payload.
TEST(HTTP, SingleChunkedResponse) {
    char *sockPath = mkTmpSock();
    ASSERT_NE(sockPath, nullptr);
    Y_DEFER { free(sockPath); };
    const int listenFd = mkSock(sockPath);
    if (listenFd < 0) {
        FAIL();
    }

    // Payload: a small serialized Group message.
    ::NCAuth::NCAuthAgent::NRpc::Group mockGroup;
    mockGroup.set_name("mock");
    mockGroup.set_gid(100000);
    mockGroup.add_members("warwish");
    TString payload;
    ASSERT_TRUE(mockGroup.SerializeToString(&payload));

    // Canned reply: headers, then the payload in chunked encoding.
    TStringStream reply;
    reply << "HTTP/1.1 200 OK\r\n"
             "Content-Type: application/x-protobuf\r\n"
             "Date: Wed, 06 Oct 2021 11:47:59 GMT\r\n"
             "Transfer-Encoding: chunked\r\n"
             "Connection: close\r\n\r\n";

    writeChunks(10, payload.data(), payload.size(), reply);

    bool serverOk = true;
    std::thread server(plainServerThreadFunc, listenFd, reply.Data(), reply.Size(), &serverOk);
    Y_DEFER { cleanupSock(listenFd, sockPath); };

    TStringStream body;
    body.Reserve(payload.size());
    const long status = ::NCAuth::NSS::QueryUserd(sockPath, "/nss/v1/group/mock", &body);
    server.join();

    ASSERT_TRUE(serverOk);
    ASSERT_EQ(status, 200);
    // Only the common prefix of the received body and the payload is compared.
    const auto cmpLen = Min(payload.size(), body.Size());
    ASSERT_TRUE(strncmp(body.Data(), payload.data(), cmpLen) == 0);
}

// Mock server that delivers its response in several pieces.
//
// Accepts one connection on `fd`, waits for a single successful recv() of
// request data, then sends `data` in portions of at most `partSz` bytes,
// sleeping 50 ms after each portion. A `partSz` of 0 is treated as "send
// everything in one portion". `*srvOk` is set to false when accept(), recv()
// or a Send() call fails; it is left untouched otherwise.
void partialSendServerThreadFunc(int fd, size_t partSz, const char *data, size_t dataSz, bool *srvOk) {
    int conn = accept(fd, nullptr, nullptr);
    if (conn < 0) {
        *srvOk = false;
        return;
    }
    char buf[4096];
    while (recv(conn, buf, sizeof(buf), 0) < 0) {
        if (errno == EINTR) continue;
        if (errno == EAGAIN || errno == EWOULDBLOCK) continue;
        *srvOk = false;
        // Release the accepted connection on this error path too.
        close(conn);
        return;
    }
    if (partSz == 0) {
        // Avoid a zero-sized step, which would never make progress.
        partSz = dataSz;
    }
    size_t offset = 0;
    while (offset < dataSz) {
        const size_t remaining = dataSz - offset;
        const size_t portionSz = remaining < partSz ? remaining : partSz;
        if (NCAuth::NSS::Send(conn, data + offset, portionSz, [](){return false;}) < 0) {
            close(conn);
            *srvOk = false;
            return;
        }
        // Pause so the client observes the response arriving piecewise.
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
        offset += portionSz;
    }
    close(conn);
}

// A Content-Length response delivered by partialSendServerThreadFunc in
// portions of a quarter of the response size: QueryUserd must return 200 and
// the received body must match the payload.
TEST(HTTP, PartialPlainResponse) {
    char *sockPath = mkTmpSock();
    ASSERT_NE(sockPath, nullptr);
    Y_DEFER { free(sockPath); };
    const int listenFd = mkSock(sockPath);
    if (listenFd < 0) {
        FAIL();
    }

    // Payload: a small serialized Group message.
    ::NCAuth::NCAuthAgent::NRpc::Group mockGroup;
    mockGroup.set_name("mock");
    mockGroup.set_gid(100000);
    mockGroup.add_members("warwish");
    TString payload;
    ASSERT_TRUE(mockGroup.SerializeToString(&payload));

    // Canned reply: headers followed by the raw payload.
    TStringStream reply;
    reply << "HTTP/1.1 200 OK\r\n"
             "Content-Type: application/x-protobuf\r\n"
             "Date: Wed, 06 Oct 2021 11:47:59 GMT\r\n"
             "Content-Length: " << payload.size() << "\r\n"
             "Connection: close\r\n\r\n";
    reply.Write(payload.data(), payload.size());

    bool serverOk = true;
    std::thread server(partialSendServerThreadFunc, listenFd, reply.Size()/4, reply.Data(), reply.Size(), &serverOk);
    Y_DEFER { cleanupSock(listenFd, sockPath); };

    TStringStream body;
    body.Reserve(payload.size());
    const long status = ::NCAuth::NSS::QueryUserd(sockPath, "/nss/v1/group/mock", &body);
    server.join();

    ASSERT_TRUE(serverOk);
    ASSERT_EQ(status, 200);
    // Only the common prefix of the received body and the payload is compared.
    const auto cmpLen = Min(payload.size(), body.Size());
    ASSERT_TRUE(strncmp(body.Data(), payload.data(), cmpLen) == 0);
}

// Same as the partial Content-Length scenario, but with a group carrying
// 1000 members so the payload is much larger: QueryUserd must return 200 and
// the received body must match the payload.
TEST(HTTP, PartialPlainResponseLarge) {
    char *sockPath = mkTmpSock();
    ASSERT_NE(sockPath, nullptr);
    Y_DEFER { free(sockPath); };
    const int listenFd = mkSock(sockPath);
    if (listenFd < 0) {
        FAIL();
    }

    // Payload: a serialized Group message with 1000 member entries.
    ::NCAuth::NCAuthAgent::NRpc::Group mockGroup;
    mockGroup.set_name("mock");
    mockGroup.set_gid(100000);
    for (int member = 0; member < 1000; ++member) {
        mockGroup.add_members("warwish");
    }
    TString payload;
    ASSERT_TRUE(mockGroup.SerializeToString(&payload));

    // Canned reply: headers followed by the raw payload.
    TStringStream reply;
    reply << "HTTP/1.1 200 OK\r\n"
             "Content-Type: application/x-protobuf\r\n"
             "Date: Wed, 06 Oct 2021 11:47:59 GMT\r\n"
             "Content-Length: " << payload.size() << "\r\n"
             "Connection: close\r\n\r\n";
    reply.Write(payload.data(), payload.size());

    bool serverOk = true;
    std::thread server(partialSendServerThreadFunc, listenFd, reply.Size()/4, reply.Data(), reply.Size(), &serverOk);
    Y_DEFER { cleanupSock(listenFd, sockPath); };

    TStringStream body;
    body.Reserve(payload.size());
    const long status = ::NCAuth::NSS::QueryUserd(sockPath, "/nss/v1/group/mock", &body);
    server.join();

    ASSERT_TRUE(serverOk);
    ASSERT_EQ(status, 200);
    // Only the common prefix of the received body and the payload is compared.
    const auto cmpLen = Min(payload.size(), body.Size());
    ASSERT_TRUE(strncmp(body.Data(), payload.data(), cmpLen) == 0);
}

// A chunked response (10-byte chunks) delivered by
// partialSendServerThreadFunc in portions of a quarter of the response size:
// QueryUserd must return 200 and the received body must match the payload.
TEST(HTTP, PartialChunkedResponse) {
    char *sockPath = mkTmpSock();
    ASSERT_NE(sockPath, nullptr);
    Y_DEFER { free(sockPath); };
    const int listenFd = mkSock(sockPath);
    if (listenFd < 0) {
        FAIL();
    }

    // Payload: a small serialized Group message.
    ::NCAuth::NCAuthAgent::NRpc::Group mockGroup;
    mockGroup.set_name("mock");
    mockGroup.set_gid(100000);
    mockGroup.add_members("warwish");
    TString payload;
    ASSERT_TRUE(mockGroup.SerializeToString(&payload));

    // Canned reply: headers, then the payload in chunked encoding.
    TStringStream reply;
    reply << "HTTP/1.1 200 OK\r\n"
             "Content-Type: application/x-protobuf\r\n"
             "Date: Wed, 06 Oct 2021 11:47:59 GMT\r\n"
             "Transfer-Encoding: chunked\r\n"
             "Connection: close\r\n\r\n";

    writeChunks(10, payload.data(), payload.size(), reply);

    bool serverOk = true;
    std::thread server(partialSendServerThreadFunc, listenFd, reply.Size()/4, reply.Data(), reply.Size(), &serverOk);
    Y_DEFER { cleanupSock(listenFd, sockPath); };

    TStringStream body;
    body.Reserve(payload.size());
    const long status = ::NCAuth::NSS::QueryUserd(sockPath, "/nss/v1/group/mock", &body);
    server.join();

    ASSERT_TRUE(serverOk);
    ASSERT_EQ(status, 200);
    // Only the common prefix of the received body and the payload is compared.
    const auto cmpLen = Min(payload.size(), body.Size());
    ASSERT_TRUE(strncmp(body.Data(), payload.data(), cmpLen) == 0);
}

// Same as the partial chunked scenario, but with a group carrying 1000
// members so the payload spans many chunks: QueryUserd must return 200 and
// the received body must match the payload.
TEST(HTTP, PartialChunkedResponseLarge) {
    char *sockPath = mkTmpSock();
    ASSERT_NE(sockPath, nullptr);
    Y_DEFER { free(sockPath); };
    const int listenFd = mkSock(sockPath);
    if (listenFd < 0) {
        FAIL();
    }

    // Payload: a serialized Group message with 1000 member entries.
    ::NCAuth::NCAuthAgent::NRpc::Group mockGroup;
    mockGroup.set_name("mock");
    mockGroup.set_gid(100000);
    for (int member = 0; member < 1000; ++member) {
        mockGroup.add_members("warwish");
    }
    TString payload;
    ASSERT_TRUE(mockGroup.SerializeToString(&payload));

    // Canned reply: headers, then the payload in chunked encoding.
    TStringStream reply;
    reply << "HTTP/1.1 200 OK\r\n"
             "Content-Type: application/x-protobuf\r\n"
             "Date: Wed, 06 Oct 2021 11:47:59 GMT\r\n"
             "Transfer-Encoding: chunked\r\n"
             "Connection: close\r\n\r\n";

    writeChunks(10, payload.data(), payload.size(), reply);

    bool serverOk = true;
    std::thread server(partialSendServerThreadFunc, listenFd, reply.Size()/4, reply.Data(), reply.Size(), &serverOk);
    Y_DEFER { cleanupSock(listenFd, sockPath); };

    TStringStream body;
    body.Reserve(payload.size());
    const long status = ::NCAuth::NSS::QueryUserd(sockPath, "/nss/v1/group/mock", &body);
    server.join();

    ASSERT_TRUE(serverOk);
    ASSERT_EQ(status, 200);
    // Only the common prefix of the received body and the payload is compared.
    const auto cmpLen = Min(payload.size(), body.Size());
    ASSERT_TRUE(strncmp(body.Data(), payload.data(), cmpLen) == 0);
}

// Mock server that reads the request in small pieces.
//
// Accepts one connection on `fd` and reads the request through a 32-byte
// buffer until the "\r\n\r\n" terminator has been seen or the peer closes
// the connection, then sends `dataSz` bytes of `data` in one Send() call.
// `*srvOk` is set to false when accept(), recv() or the send fails; it is
// left untouched otherwise.
void partialRecvServerThreadFunc(int fd, const char *data, size_t dataSz, bool *srvOk) {
    int conn = accept(fd, nullptr, nullptr);
    if (conn < 0) {
        *srvOk = false;
        return;
    }
    char buf[32];
    // Sliding window over the last four bytes received. Keeping it across
    // recv() calls detects a terminator split between two reads and avoids
    // indexing before `buf` when a read returns fewer than four bytes.
    char tail[4] = {0, 0, 0, 0};
    bool requestDone = false;
    while (!requestDone) {
        const ssize_t recvSz = recv(conn, buf, sizeof(buf), 0);
        if (recvSz == 0) break;  // peer closed the connection
        if (recvSz < 0) {
            if (errno == EINTR) continue;
            if (errno == EAGAIN || errno == EWOULDBLOCK) continue;
            *srvOk = false;
            close(conn);
            return;
        }
        for (ssize_t k = 0; k < recvSz; k++) {
            tail[0] = tail[1];
            tail[1] = tail[2];
            tail[2] = tail[3];
            tail[3] = buf[k];
            if (tail[0] == '\r' && tail[1] == '\n' && tail[2] == '\r' && tail[3] == '\n') {
                requestDone = true;
                break;
            }
        }
    }
    if (NCAuth::NSS::Send(conn, data, dataSz, [](){return false;}) != ssize_t(dataSz)) {
        *srvOk = false;
    }
    close(conn);
}

// The server side (partialRecvServerThreadFunc) reads the request in small
// pieces before replying with a Content-Length response: QueryUserd must
// return 200 and the received body must match the payload.
TEST(HTTP, RemotePartialRecv) {
    char *sockPath = mkTmpSock();
    ASSERT_NE(sockPath, nullptr);
    Y_DEFER { free(sockPath); };
    const int listenFd = mkSock(sockPath);
    if (listenFd < 0) {
        FAIL();
    }

    // Payload: a small serialized Group message.
    ::NCAuth::NCAuthAgent::NRpc::Group mockGroup;
    mockGroup.set_name("mock");
    mockGroup.set_gid(100000);
    mockGroup.add_members("warwish");
    TString payload;
    ASSERT_TRUE(mockGroup.SerializeToString(&payload));

    // Canned reply: headers followed by the raw payload.
    TStringStream reply;
    reply << "HTTP/1.1 200 OK\r\n"
             "Content-Type: application/x-protobuf\r\n"
             "Date: Wed, 06 Oct 2021 11:47:59 GMT\r\n"
             "Content-Length: " << payload.size() << "\r\n"
             "Connection: close\r\n\r\n";
    reply.Write(payload.data(), payload.size());

    bool serverOk = true;
    std::thread server(partialRecvServerThreadFunc, listenFd, reply.Data(), reply.Size(), &serverOk);
    Y_DEFER { cleanupSock(listenFd, sockPath); };

    TStringStream body;
    body.Reserve(payload.size());
    const long status = ::NCAuth::NSS::QueryUserd(sockPath, "/nss/v1/group/mock", &body);
    server.join();

    ASSERT_TRUE(serverOk);
    ASSERT_EQ(status, 200);
    // Only the common prefix of the received body and the payload is compared.
    const auto cmpLen = Min(payload.size(), body.Size());
    ASSERT_TRUE(strncmp(body.Data(), payload.data(), cmpLen) == 0);
}

// Mock server that answers late: accepts one connection on `fd`, waits for a
// single successful recv() of request data, sleeps 500 ms, then sends
// `dataSz` bytes of `data` in one Send() call. `*srvOk` is set to false when
// accept(), recv() or the send fails; it is left untouched otherwise.
void slowServerThreadFunc(int fd, const char *data, size_t dataSz, bool *srvOk) {
    const int client = accept(fd, nullptr, nullptr);
    if (client < 0) {
        *srvOk = false;
        return;
    }

    char request[4096];
    for (;;) {
        const ssize_t got = recv(client, request, 4096, 0);
        if (got >= 0) {
            break;
        }
        // Interrupted or would-block reads are simply retried.
        if (errno == EINTR) continue;
        if (errno == EAGAIN || errno == EWOULDBLOCK) continue;
        *srvOk = false;
        close(client);
        return;
    }

    // Hold the reply back before answering.
    std::this_thread::sleep_for(std::chrono::milliseconds(500));

    const ssize_t sent = NCAuth::NSS::Send(client, data, dataSz, [](){return false;});
    if (sent != ssize_t(dataSz)) {
        *srvOk = false;
    }
    close(client);
}

// The server (slowServerThreadFunc) delays its reply by 500 ms. The test
// expects QueryUserd to return -1 and the server thread to report a failure
// (presumably because the client gives up before the reply arrives; confirm
// against the QueryUserd implementation).
TEST(HTTP, SlowServer) {
    char *sockPath = mkTmpSock();
    ASSERT_NE(sockPath, nullptr);
    Y_DEFER { free(sockPath); };
    const int listenFd = mkSock(sockPath);
    if (listenFd < 0) {
        FAIL();
    }

    // Payload: a small serialized Group message.
    ::NCAuth::NCAuthAgent::NRpc::Group mockGroup;
    mockGroup.set_name("mock");
    mockGroup.set_gid(100000);
    mockGroup.add_members("warwish");
    TString payload;
    ASSERT_TRUE(mockGroup.SerializeToString(&payload));

    // Canned reply: headers followed by the raw payload.
    TStringStream reply;
    reply << "HTTP/1.1 200 OK\r\n"
             "Content-Type: application/x-protobuf\r\n"
             "Date: Wed, 06 Oct 2021 11:47:59 GMT\r\n"
             "Content-Length: " << payload.size() << "\r\n"
             "Connection: close\r\n\r\n";
    reply.Write(payload.data(), payload.size());

    bool serverOk = true;
    std::thread server(slowServerThreadFunc, listenFd, reply.Data(), reply.Size(), &serverOk);
    Y_DEFER { cleanupSock(listenFd, sockPath); };

    TStringStream body;
    body.Reserve(payload.size());
    const long status = ::NCAuth::NSS::QueryUserd(sockPath, "/nss/v1/group/mock", &body);
    server.join();

    ASSERT_FALSE(serverOk);
    ASSERT_EQ(status, -1);
}
