#include "request_context.h"

#include <library/cpp/http/misc/http_headers.h>
#include <library/cpp/openssl/io/stream.h>

#include <contrib/libs/openssl/include/openssl/ssl.h>

#include <unistd.h>

#include <algorithm>
#include <cctype>

namespace NSolomon {
namespace {

// CURLOPT_SSL_CTX_FUNCTION hook: installs the built-in CA store into the OpenSSL context curl is
// about to use, then also loads certificates from the directory handed over via
// CURLOPT_SSL_CTX_DATA (`caPath`). Always reports success to curl.
CURLcode UseBuiltinCaCertInSSLContext(CURL*, void* sslctx, void* caPath) {
    auto* const context = static_cast<SSL_CTX*>(sslctx);

    // SSL_CTX_set_cert_store() takes ownership of the store, hence Release()
    SSL_CTX_set_cert_store(context, GetBuiltinOpenSslX509Store().Release());

    // The result is deliberately ignored: the directory (e.g. "/etc/ssl/certs/") may be missing
    // on the host, and the built-in store installed above is still in effect.
    SSL_CTX_load_verify_locations(context, nullptr, static_cast<const char*>(caPath));

    return CURLE_OK;
}

// CURLOPT_HEADERFUNCTION adapter: forwards one header line to TRequestContext::OnHeader.
// This is called from libcurl's C code, so no exception may propagate out of it. Any failure
// is reported to curl as a short write (0 bytes consumed), which aborts the transfer with
// CURLE_WRITE_ERROR instead of unwinding through C frames.
size_t OnHeaderShim(char* buffer, size_t size, size_t nitems, TRequestContext* self) {
    try {
        return self->OnHeader({buffer, size * nitems});
    } catch (...) {
        return 0;
    }
}

// CURLOPT_WRITEFUNCTION adapter: forwards one chunk of the response body to
// TRequestContext::OnBody. This is called from libcurl's C code, so no exception may
// propagate out of it. Any failure (e.g. an allocation failure while appending) is reported
// to curl as a short write (0 bytes consumed), which aborts the transfer with CURLE_WRITE_ERROR.
size_t OnBodyShim(char* buffer, size_t size, size_t nitems, TRequestContext* self) {
    try {
        return self->OnBody({buffer, size * nitems});
    } catch (...) {
        return 0;
    }
}

// CURLOPT_OPENSOCKETFUNCTION adapter: lets TRequestContext::OnSocketOpen create the socket.
// This is called from libcurl's C code, so no exception may propagate out of it. Any failure
// is reported as CURL_SOCKET_BAD, which makes curl fail the connection attempt.
curl_socket_t OnSocketOpenShim(TRequestContext* self, curlsocktype purpose, struct curl_sockaddr* address) {
    try {
        return self->OnSocketOpen(purpose, address);
    } catch (...) {
        return CURL_SOCKET_BAD;
    }
}

// CURLOPT_SEEKFUNCTION adapter: curl passes the CURLOPT_SEEKDATA pointer first; the
// seek itself is delegated to TRequestContext::OnSeek.
int OnSeekShim(TRequestContext* self, curl_off_t offset, int origin) {
    const int status = self->OnSeek(offset, origin);
    return status;
}

// CURLOPT_READFUNCTION adapter: asks TRequestContext::OnRead to fill at most
// size * nitems bytes of curl's upload buffer and returns how many bytes were written.
size_t OnReadShim(char* buffer, size_t size, size_t nitems, TRequestContext* self) {
    const size_t capacity = size * nitems;
    return self->OnRead(buffer, capacity);
}

#ifdef CURL_VERBOSE
// CURLOPT_DEBUGFUNCTION adapter (compiled only with CURL_VERBOSE): wraps curl's
// (pointer, length) pair, which is not NUL-terminated, into a TStringBuf for OnDebug.
int OnDebugShim(CURL* handle, curl_infotype type, char* data, size_t size, TRequestContext* self) {
    const TStringBuf message{data, size};
    return self->OnDebug(handle, type, message);
}
#endif

} // namespace

// Completes the request with `error`: wraps it into an IHttpClient::TResult and hands it to
// the user callback. The function is noexcept, so anything thrown while building the result
// or inside the user-supplied callback is swallowed here.
void TRequestContext::Finalize(TRequestError error) noexcept {
    try {
        auto failure = IHttpClient::TResult::FromError(std::move(error));
        Callback_(std::move(failure));
    } catch (...) {
        // deliberately ignored: a throwing user callback must not escape a noexcept function
    }
}

// Completes one transfer attempt that curl finished with `result`.
//
// On a transport error the user callback receives GetError(result). On CURLE_OK the HTTP status
// is stored into Response_; a 5xx status with retries left decrements Retries_ and skips the
// callback, otherwise Response_ is handed to the callback.
//
// Returns false only when the attempt should be retried; true means the request is finished
// (the callback was invoked, or building/invoking the result threw and was swallowed).
bool TRequestContext::Finalize(CURLcode result) noexcept {
    bool done = true;

    try {
        if (result == CURLE_OK) {
            // Zero-initialised so a failing curl_easy_getinfo() cannot leave it indeterminate.
            long responseCode = 0;
            curl_easy_getinfo(Handle_->Handle(), CURLINFO_RESPONSE_CODE, &responseCode);

            const auto code = HttpCodes(responseCode);
            Response_->Code_ = code;
            // TODO: allow customizing the codes we should retry
            if (code >= 500 && code < 600 && Retries_ != 0) {
                // Only consume a retry while one is left, so an exhausted counter is never
                // decremented past zero.
                --Retries_;
                done = false;
                // NOTE(review): Response_ (headers/body of this failed attempt) is kept as-is here;
                // confirm EnsureResponse()/the retry path discards it before the next attempt.
            }

            if (done) {
                Callback_(IHttpClient::TResult::FromValue(std::move(Response_)));
            }
        } else {
            Callback_(GetError(result));
        }
    } catch (...) {
        // ignore the exception from a user supplied callback if any
    }

    return done;
}

// Takes ownership of `handle`, resets it and configures every option needed for one transfer of
// Request_: TLS verification, HTTP method and body size, the data/seek/socket callbacks,
// redirects, timeouts, local bind, DNS cache, TCP keep-alive, the share handle and the request
// headers. Returns TDuration::Seconds(Timer_.PassedReset()), i.e. what Timer_ reports since its
// previous reset.
//
// curl_easy_setopt() is variadic. libcurl reads integer options as `long` and *_LARGE options as
// `curl_off_t`, so every numeric argument below is passed with exactly that type. Handing an
// int/bool/ui64 through varargs where a long is read is undefined behaviour.
TDuration TRequestContext::Init(std::unique_ptr<ICurlHandle> handle, CURLSH* share) {
    Handle_ = std::move(handle);
    auto* h = Handle_->Handle();
    ErrorMessage_[0] = '\0';

    UploadDataView_ = Request_->Data();
    curl_easy_reset(h);

    curl_easy_setopt(h, CURLOPT_SSL_CTX_FUNCTION, UseBuiltinCaCertInSSLContext);
    curl_easy_setopt(h, CURLOPT_SSL_CTX_DATA, CaPath_);
    curl_easy_setopt(h, CURLOPT_SSL_VERIFYHOST, 0L); // because fetcher will use cached ip address in URL instead of FQDN
    curl_easy_setopt(h, CURLOPT_SSL_VERIFYPEER, 1L); // but we have to verify peer's certificate in any case

    curl_easy_setopt(h, CURLOPT_PRIVATE, this);
    curl_easy_setopt(h, CURLOPT_URL, Url_.c_str());
    curl_easy_setopt(h, CURLOPT_DEFAULT_PROTOCOL, "http");
    curl_easy_setopt(h, CURLOPT_HTTP_VERSION, static_cast<long>(CURL_HTTP_VERSION_1_1));
    curl_easy_setopt(h, CURLOPT_TCP_FASTOPEN, 1L);
    curl_easy_setopt(h, CURLOPT_NOSIGNAL, 1L);

    // The body itself is streamed by OnReadShim (set below), so only its size is announced here.
    // CURLOPT_POSTFIELDSIZE_LARGE accepts any size, so bodies under 2 GB need no separate path.
    const auto uploadSize = static_cast<curl_off_t>(Request_->Data().size());

    switch (Request_->Method()) {
        case EHttpMethod::Get:
            curl_easy_setopt(h, CURLOPT_HTTPGET, 1L);
            break;
        case EHttpMethod::Post:
            curl_easy_setopt(h, CURLOPT_POST, 1L);
            curl_easy_setopt(h, CURLOPT_POSTFIELDSIZE_LARGE, uploadSize);
            break;
        case EHttpMethod::Put:
            // configured like a POST (body from the read callback) with the method name overridden
            curl_easy_setopt(h, CURLOPT_POST, 1L);
            curl_easy_setopt(h, CURLOPT_POSTFIELDSIZE_LARGE, uploadSize);
            curl_easy_setopt(h, CURLOPT_CUSTOMREQUEST, "PUT");
            break;
        case EHttpMethod::Head:
            curl_easy_setopt(h, CURLOPT_HTTPGET, 1L);
            curl_easy_setopt(h, CURLOPT_NOBODY, 1L);
            break;
    }

    curl_easy_setopt(h, CURLOPT_HEADERFUNCTION, OnHeaderShim);
    curl_easy_setopt(h, CURLOPT_HEADERDATA, this);
    curl_easy_setopt(h, CURLOPT_WRITEFUNCTION, OnBodyShim);
    curl_easy_setopt(h, CURLOPT_WRITEDATA, this);
    curl_easy_setopt(h, CURLOPT_OPENSOCKETFUNCTION, OnSocketOpenShim);
    curl_easy_setopt(h, CURLOPT_OPENSOCKETDATA, this);
    curl_easy_setopt(h, CURLOPT_SEEKFUNCTION, OnSeekShim);
    curl_easy_setopt(h, CURLOPT_SEEKDATA, this);
    curl_easy_setopt(h, CURLOPT_READFUNCTION, OnReadShim);
    curl_easy_setopt(h, CURLOPT_READDATA, this);
    curl_easy_setopt(h, CURLOPT_ERRORBUFFER, ErrorMessage_);

#ifdef CURL_VERBOSE
    curl_easy_setopt(h, CURLOPT_DEBUGFUNCTION, OnDebugShim);
    curl_easy_setopt(h, CURLOPT_DEBUGDATA, this);
    curl_easy_setopt(h, CURLOPT_VERBOSE, 1L);
#endif

    // redirects are followed only when retries are enabled, and at most Opts_.Retries of them
    curl_easy_setopt(h, CURLOPT_FOLLOWLOCATION, Opts_.Retries > 0 ? 1L : 0L);
    curl_easy_setopt(h, CURLOPT_MAXREDIRS, static_cast<long>(Opts_.Retries));
    curl_easy_setopt(h, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(Opts_.ConnectTimeout.MilliSeconds()));
    // the whole transfer may take at most connect timeout + read timeout
    curl_easy_setopt(h, CURLOPT_TIMEOUT_MS,
        static_cast<long>(Opts_.ConnectTimeout.MilliSeconds() + Opts_.ReadTimeout.MilliSeconds()));
    curl_easy_setopt(h, CURLOPT_PROTOCOLS, static_cast<long>(CURLPROTO_HTTP | CURLPROTO_HTTPS));
    curl_easy_setopt(h, CURLOPT_FORBID_REUSE, 1L);

    if (BindOptions_.has_value()) {
        if (!BindOptions_->Host.empty()) {
            curl_easy_setopt(h, CURLOPT_INTERFACE, BindOptions_->Host.c_str());
        }
        curl_easy_setopt(h, CURLOPT_LOCALPORT, static_cast<long>(BindOptions_->Port));
        curl_easy_setopt(h, CURLOPT_LOCALPORTRANGE, static_cast<long>(BindOptions_->PortRange));
    }

    // according to https://curl.se/libcurl/c/CURLOPT_DNS_CACHE_TIMEOUT.html
    if (DnsCacheLifetime_ == TDuration::Zero()) {
        // set to 0 to completely disable caching
        curl_easy_setopt(h, CURLOPT_DNS_CACHE_TIMEOUT, 0L);
    } else if (DnsCacheLifetime_ == TDuration::Max()) {
        // set to -1 to make the cached entries remain forever
        curl_easy_setopt(h, CURLOPT_DNS_CACHE_TIMEOUT, -1L);
    } else {
        // curl works in whole seconds: round a positive sub-second lifetime up to 1s instead of
        // letting it truncate to 0, which would disable caching altogether
        curl_easy_setopt(h, CURLOPT_DNS_CACHE_TIMEOUT, std::max(1L, static_cast<long>(DnsCacheLifetime_.Seconds())));
    }

    auto&& keepAlive = Opts_.KeepAliveOpts;
    if (keepAlive) {
        curl_easy_setopt(h, CURLOPT_TCP_KEEPALIVE, static_cast<long>(keepAlive->TcpKeepAlive));
        curl_easy_setopt(h, CURLOPT_TCP_KEEPINTVL, static_cast<long>(keepAlive->KeepInterval.Seconds()));
        curl_easy_setopt(h, CURLOPT_TCP_KEEPIDLE, static_cast<long>(keepAlive->KeepIdle.Seconds()));
    }

    curl_easy_setopt(h, CURLOPT_SHARE, share);

    WriteHeaders(h);
    return TDuration::Seconds(Timer_.PassedReset());
}

// CURLOPT_DEBUGFUNCTION handler. Writes curl's informational text (prefixed with "== Info: ")
// to Cerr, and a one-line marker for outgoing/incoming header, data and SSL traffic (the
// payload itself is not printed). Other info types are ignored. Always returns 0, which is
// what libcurl requires from a debug callback.
int TRequestContext::OnDebug(CURL* handle, curl_infotype type, TStringBuf msg) {
    Y_UNUSED(handle);

    if (type == CURLINFO_TEXT) {
        Cerr << "== Info: " << msg << Endl;
        return 0;
    }

    const char* marker = nullptr;
    switch (type) {
        case CURLINFO_HEADER_OUT:   marker = "=> Send header";   break;
        case CURLINFO_DATA_OUT:     marker = "=> Send data";     break;
        case CURLINFO_SSL_DATA_OUT: marker = "=> Send SSL data"; break;
        case CURLINFO_HEADER_IN:    marker = "<= Recv header";   break;
        case CURLINFO_DATA_IN:      marker = "<= Recv data";     break;
        case CURLINFO_SSL_DATA_IN:  marker = "<= Recv SSL data"; break;
        default:
            return 0;
    }

    Cerr << marker << Endl;
    return 0;
}

// CURLOPT_HEADERFUNCTION handler. Receives one raw header line (the status line, a
// "Name: value" line, or the terminating empty line) and records name/value pairs in
// Response_->Headers_.
// Returns data.size() to accept the line, or 0 to abort the transfer when a Content-Length
// header announces a body that does not fit the configured response limit.
size_t TRequestContext::OnHeader(TStringBuf data) {
    EnsureResponse();

    // Split on the first ':' only. The whitespace after the colon is optional (RFC 7230 §3.2),
    // and a value may itself contain ": " (e.g. "X-Error: Error: ..."). Splitting on the string
    // ": " into exactly two parts would drop such lines.
    TStringBuf name, value;
    if (!data.TrySplit(':', name, value)) {
        return data.size();
    }

    // A field-name is a token and never contains whitespace. This also skips a status line whose
    // reason phrase happens to contain a ':' (e.g. "HTTP/1.1 500 Error: x").
    const bool validName = !name.empty()
        && std::none_of(name.begin(), name.end(), [](char c) { return c == ' ' || c == '\t'; });
    if (!validName) {
        return data.size();
    }

    value = StripString(value);

    // Header field names are case-insensitive, so "content-length" must trigger the limit check too.
    const auto equalsIgnoreCase = [](TStringBuf lhs, TStringBuf rhs) {
        return lhs.size() == rhs.size()
            && std::equal(lhs.begin(), lhs.end(), rhs.begin(), [](char a, char b) {
                   return std::tolower(static_cast<unsigned char>(a)) == std::tolower(static_cast<unsigned char>(b));
               });
    };

    if (HasResponseLimit()
        && equalsIgnoreCase(name, TStringBuf(NHttpHeaders::CONTENT_LENGTH))
        && !FitsResponseLimit(value))
    {
        ResponseLimitSizeExceeded(value);
        return 0;
    }

    Response_->Headers_.Add(name, value);

    return data.size();
}

// CURLOPT_WRITEFUNCTION handler: appends one chunk of the response body to Response_.
// Returns data.size() on success, or 0 (which aborts the transfer) when the accumulated body
// would exceed the configured response limit.
size_t TRequestContext::OnBody(TStringBuf data) {
    EnsureResponse();

    const size_t chunkSize = data.size();
    const auto totalSize = chunkSize + Response_->Data().size();

    const bool overLimit = HasResponseLimit() && !FitsResponseLimit(totalSize);
    if (overLimit) {
        ResponseLimitSizeExceeded(totalSize);
        return 0;
    }

    Response_->Data_.AppendNoAlias(data);
    return chunkSize;
}

// CURLOPT_SEEKFUNCTION handler: repositions the upload cursor, e.g. when curl has to rewind
// and resend the request body.
//
// The cursor is UploadDataView_, which is always a suffix of Request_->Data(). Init sets it to
// the full payload, and OnRead/OnSeek only move it forward from there. The target is therefore
// computed as an absolute position within the full payload. That way a backward SEEK_CUR, or a
// SEEK_END reaching before the current cursor, lands on the right byte instead of being clamped.
//
// Returns CURL_SEEKFUNC_OK, or CURL_SEEKFUNC_FAIL for an unknown origin or a target outside
// [0, payload size].
int TRequestContext::OnSeek(curl_off_t offset, int origin) {
    const TStringBuf payload = Request_->Data();
    const auto total = static_cast<curl_off_t>(payload.size());
    const auto current = total - static_cast<curl_off_t>(UploadDataView_.size());

    curl_off_t base = 0;
    switch (origin) {
        case SEEK_SET:
            base = 0;
            break;
        case SEEK_CUR:
            base = current;
            break;
        case SEEK_END:
            base = total;
            break;
        default:
            Y_VERIFY_DEBUG(false);
            return CURL_SEEKFUNC_FAIL;
    }

    // Same as `0 <= base + offset <= total`, written so that the sum cannot overflow.
    if (offset < -base || offset > total - base) {
        return CURL_SEEKFUNC_FAIL;
    }

    UploadDataView_ = payload;
    UploadDataView_.Skip(static_cast<size_t>(base + offset));
    return CURL_SEEKFUNC_OK;
}

// CURLOPT_READFUNCTION handler: copies as much of the remaining upload payload as fits into
// curl's `buffer` (at most `capacity` bytes), advances the upload cursor past it and returns
// the number of bytes copied (0 once the payload is exhausted).
int TRequestContext::OnRead(char* buffer, size_t capacity) {
    const size_t chunkSize = Min(UploadDataView_.size(), capacity);
    CopyN(UploadDataView_.begin(), chunkSize, buffer);
    UploadDataView_.Skip(chunkSize);
    return static_cast<int>(chunkSize);
}

// CURLOPT_OPENSOCKETFUNCTION handler: creates the socket curl will connect with and remembers
// it in Socket_. When a fixed local port is configured, SO_REUSEADDR is enabled on the socket.
//
// This runs inside libcurl's C code, so it must not throw. On failure the freshly opened socket
// is closed (so it does not leak) and CURL_SOCKET_BAD is returned, which makes curl fail the
// connection attempt.
curl_socket_t TRequestContext::OnSocketOpen(curlsocktype, struct curl_sockaddr* address) {
    Socket_ = socket(address->family, address->socktype, address->protocol);
    if (Socket_ == CURL_SOCKET_BAD) {
        return CURL_SOCKET_BAD;
    }

    const bool fixedLocalPort = BindOptions_.has_value() && BindOptions_->Port != 0;
    // SO_REUSEADDR takes an int option value, so pass an int rather than a long.
    if (fixedLocalPort && SetSockOpt(Socket_, SOL_SOCKET, SO_REUSEADDR, 1) != 0) {
        close(Socket_);
        Socket_ = CURL_SOCKET_BAD;
        return CURL_SOCKET_BAD;
    }

    return Socket_;
}

// Attaches the request headers to the easy handle `h`. HeaderArena_ is filled from
// Request_->Headers() only while it is still empty, so repeated calls (one per Init) reuse the
// already-built arena; it is then bound to `h`.
void TRequestContext::WriteHeaders(CURL* h) {
    const bool needsFill = HeaderArena_.Empty();
    if (needsFill) {
        auto addHeader = [this](TStringBuf name, TStringBuf value) {
            HeaderArena_.Add(name, value);
        };
        Request_->Headers().ForEach(addHeader);
    }

    HeaderArena_.Bind(h);
}

} // namespace NSolomon
