#include "curl.h"
#include <contrib/libs/curl/include/curl/curl.h>
#include <util/string/strip.h>
#include <util/string/join.h>
#include <util/digest/multi.h>

// THolder deleter specialization: a curl_slist must be released with curl_slist_free_all,
// which frees every node of the list.
template <>
void TDelete::Destroy<curl_slist>(curl_slist* list) noexcept {
    curl_slist_free_all(list);
}

namespace NCurl {
    // Returns the last HTTP response code seen by `easy`. CURLINFO_RESPONSE_CODE writes a long;
    // the value stays 0 when no response code is available or the query fails.
    static HttpCodes GetCode(CURL* easy) {
        long status = 0;
        curl_easy_getinfo(easy, CURLINFO_RESPONSE_CODE, &status);
        const HttpCodes code = static_cast<HttpCodes>(status);
        return code;
    }

    // Returns the effective (last used) URL of `easy`. The view points into curl-owned memory.
    static TStringBuf GetURL(const CURL* easy) {
        // curl_easy_getinfo takes a non-const handle, but this query only reads from it.
        CURL* const handle = const_cast<CURL*>(easy);
        const char* effectiveUrl = nullptr;
        curl_easy_getinfo(handle, CURLINFO_EFFECTIVE_URL, &effectiveUrl);
        return TStringBuf(effectiveUrl);
    }

    // Stores the request target in `url`; TCurl::Setup passes it verbatim as CURLOPT_URL.
    TRequestContext& TRequestContext::SetHost(TString v) {
        this->url = std::move(v);
        return *this;
    }

    // Overrides the port used for the request (CURLOPT_PORT).
    TRequestContext& TRequestContext::SetPort(ui16 v) {
        this->port = v;
        return *this;
    }

    // Limits the whole transfer duration (CURLOPT_TIMEOUT_MS).
    TRequestContext& TRequestContext::SetRequestTimeout(TDuration v) {
        this->requestTimeout = v;
        return *this;
    }

    // Limits the connection phase only (CURLOPT_CONNECTTIMEOUT_MS).
    TRequestContext& TRequestContext::SetConnectTimeout(TDuration v) {
        this->connectTimeout = v;
        return *this;
    }

    // Sets the request body; when present, TCurl::Setup issues a POST instead of a GET.
    TRequestContext& TRequestContext::SetPostData(TString v) {
        this->data = std::move(v);
        return *this;
    }

    // Convenience wrapper that appends a "Content-Type: <contentType>" header line.
    TRequestContext& TRequestContext::SetContentType(TStringBuf contentType) {
        TString header("Content-Type: ");
        header += contentType;
        return AddHeader(header);
    }

    // Appends one raw header line (e.g. "Name: value") to the list sent via CURLOPT_HTTPHEADER.
    // curl_slist_append copies the string. It returns the list head, or NULL on failure, in which
    // case the existing list is left untouched.
    // Throws (Y_ENSURE) if the append fails; the previously added headers stay intact.
    TRequestContext& TRequestContext::AddHeader(const TString& v) {
        curl_slist* const head = curlHeaders.Get();
        curl_slist* const appended = curl_slist_append(head, v.data());
        // Keep ownership of the existing list until the append is known to have succeeded;
        // releasing it first would leak it (and drop all headers) on failure.
        Y_ENSURE(appended, "curl_slist_append failed for header: " << v);
        if (!head) {
            // The list was empty: the new node is the head and must now be owned here.
            // For a non-empty list curl appends in place and returns the same head.
            curlHeaders.Reset(appended);
        }
        return *this;
    }

    // Snapshots the per-phase timings of the last transfer performed on `curl`.
    // A timing that curl cannot report is set to TDuration::Max(), the sentinel that
    // operator<<(IOutputStream&, const TTimings&) skips.
    TTimings::TTimings(CURL* curl) {
        // Pair every member with the curl info item that fills it. std::make_tuple unwraps
        // std::ref, so each tuple holds a TDuration& and `duration` aliases the member itself.
        for (auto& [duration, option] : {
                std::make_tuple(std::ref(Total), CURLINFO_TOTAL_TIME),
                std::make_tuple(std::ref(NsLookup), CURLINFO_NAMELOOKUP_TIME),
                std::make_tuple(std::ref(Connect), CURLINFO_CONNECT_TIME),
                std::make_tuple(std::ref(PreTransfer), CURLINFO_PRETRANSFER_TIME),
                std::make_tuple(std::ref(Redirect), CURLINFO_REDIRECT_TIME),
                std::make_tuple(std::ref(AppConnect), CURLINFO_APPCONNECT_TIME),
                std::make_tuple(std::ref(StartTransfer), CURLINFO_STARTTRANSFER_TIME),
        }) {
            // These CURLINFO_*_TIME items report seconds as a double. Start from a defined value and
            // check the result so a failed query never leaves an indeterminate or stale number behind.
            double seconds = 0.0;
            if (curl_easy_getinfo(curl, option, &seconds) == CURLE_OK && seconds >= 0.0) {
                // Convert through microseconds so sub-second precision survives the conversion.
                duration = TDuration::MicroSeconds(static_cast<ui64>(seconds * 1e6));
            } else {
                duration = TDuration::Max();
            }
        }
    }

    // Prints every available timing as "<Name> = <duration>;". Entries equal to TDuration::Max()
    // are treated as a sentinel and omitted.
    IOutputStream &operator<<(IOutputStream &stream, const TTimings &timings) {
        // LabeledOutput prints the argument's source spelling as its label, so each member is bound
        // to a local whose name is exactly the label to be printed.
        const auto& Total = timings.Total;
        const auto& NsLookup = timings.NsLookup;
        const auto& Connect = timings.Connect;
        const auto& PreTransfer = timings.PreTransfer;
        const auto& Redirect = timings.Redirect;
        const auto& AppConnect = timings.AppConnect;
        const auto& StartTransfer = timings.StartTransfer;

        if (Total != TDuration::Max()) {
            stream << LabeledOutput(Total) << ';';
        }
        if (NsLookup != TDuration::Max()) {
            stream << LabeledOutput(NsLookup) << ';';
        }
        if (Connect != TDuration::Max()) {
            stream << LabeledOutput(Connect) << ';';
        }
        if (PreTransfer != TDuration::Max()) {
            stream << LabeledOutput(PreTransfer) << ';';
        }
        if (Redirect != TDuration::Max()) {
            stream << LabeledOutput(Redirect) << ';';
        }
        if (AppConnect != TDuration::Max()) {
            stream << LabeledOutput(AppConnect) << ';';
        }
        if (StartTransfer != TDuration::Max()) {
            stream << LabeledOutput(StartTransfer) << ';';
        }
        return stream;
    }

    // Reads the optional "cert", "pass", "key" and "ca" string entries; absent keys keep defaults.
    TSSL::TSSL(const NConfig::TDict& config) {
        if (const auto* cert = MapFindPtr(config, "cert")) {
            SetCertFile(cert->Get<TString>());
        }
        if (const auto* pass = MapFindPtr(config, "pass")) {
            SetPassphrase(pass->Get<TString>());
        }
        if (const auto* key = MapFindPtr(config, "key")) {
            SetKeyFile(key->Get<TString>());
        }
        if (const auto* ca = MapFindPtr(config, "ca")) {
            SetCACertFile(ca->Get<TString>());
        }
    }

    // Reads the optional "count" and "time_to_wait_before_retry" entries; absent keys keep defaults.
    TRetries::TRetries(const NConfig::TDict& config) {
        if (const auto* count = MapFindPtr(config, "count")) {
            Count = count->As<size_t>();
        }
        if (const auto* pause = MapFindPtr(config, "time_to_wait_before_retry")) {
            TimeToWaitBeforeRetry = NTalkativeConfig::As<TDuration>(*pause);
        }
    }

    // Reads the optional "alive_connections", "parse_headers" and "parse_body" entries;
    // absent keys keep defaults.
    TOptions::TOptions(const NConfig::TDict& config) {
        if (const auto* alive = MapFindPtr(config, "alive_connections")) {
            maxAliveConnections = alive->As<size_t>();
        }
        if (const auto* headers = MapFindPtr(config, "parse_headers")) {
            parseHeaders = headers->As<bool>();
        }
        if (const auto* body = MapFindPtr(config, "parse_body")) {
            parseBody = body->As<bool>();
        }
    }

    // Caps the handle's connection cache (applied as CURLOPT_MAXCONNECTS by TCurl::Setup).
    TOptions& TOptions::SetMaxAliveConnections(size_t count) {
        this->maxAliveConnections = count;
        return *this;
    }
    // When false, response headers are discarded instead of collected.
    TOptions& TOptions::SetParseHeaders(bool v) {
        this->parseHeaders = v;
        return *this;
    }
    // When false, the response body is discarded instead of collected.
    TOptions& TOptions::SetParseBody(bool v) {
        this->parseBody = v;
        return *this;
    }

    void TShare::TDestroyer::Destroy(CURLSH* c) {
        curl_share_cleanup(c);
    }

    // CURLSHOPT_LOCKFUNC callback: userptr is the owning TShare (installed as CURLSHOPT_USERDATA).
    // One lock per data kind; the shared/exclusive access hint is ignored.
    static void lock_function(CURL* /*handle*/, curl_lock_data data, curl_lock_access /*access*/, void* userptr) {
        TShare* const owner = static_cast<TShare*>(userptr);
        owner->locks[data].Acquire();
    }
    // CURLSHOPT_UNLOCKFUNC callback: releases the lock taken by lock_function for `data`.
    static void unlock_function(CURL* /*handle*/, curl_lock_data data, void* userptr) {
        TShare* const owner = static_cast<TShare*>(userptr);
        owner->locks[data].Release();
    }

    // Creates a curl share handle that shares DNS cache, connections, TLS sessions and cookies.
    // With `multithread`, lock callbacks are installed so the share may be used from several threads.
    // NOTE(review): the callbacks receive `this`, so the TShare must not move after construction.
    TShare::TShare(bool multithread)
        : locks(CURL_LOCK_DATA_LAST)
        , share(curl_share_init())
    {
        Y_VERIFY(share);
        CURLSH* const handle = share.Get();

        if (multithread) {
            curl_share_setopt(handle, CURLSHOPT_USERDATA, static_cast<void*>(this));
            curl_share_setopt(handle, CURLSHOPT_LOCKFUNC, lock_function);
            curl_share_setopt(handle, CURLSHOPT_UNLOCKFUNC, unlock_function);
        }

        for (const curl_lock_data kind : {CURL_LOCK_DATA_DNS, CURL_LOCK_DATA_CONNECT, CURL_LOCK_DATA_SSL_SESSION, CURL_LOCK_DATA_COOKIE}) {
            curl_share_setopt(handle, CURLSHOPT_SHARE, kind);
        }
    }

    // Ties libcurl's process-wide state to an object lifetime: curl_global_init on construction,
    // curl_global_cleanup on destruction. The init result is intentionally not checked.
    struct TCurlGlobal {
        TCurlGlobal() noexcept {
            (void)curl_global_init(CURL_GLOBAL_DEFAULT);
        }

        ~TCurlGlobal() {
            curl_global_cleanup();
        }
    };

    // The first TCurlBase constructed initializes libcurl globally. Function-local static
    // initialization runs exactly once even under concurrent construction, and cleanup runs
    // during static destruction at exit.
    TCurlBase::TCurlBase() noexcept {
        [[maybe_unused]] static TCurlGlobal curlGlobalGuard;
    }

    // Points the userdata of the write and header callbacks at the caller's sinks.
    // Both sinks must outlive every transfer performed with these settings.
    void TCurlBase::Prepare(CURL* easy, TArtifacts::TBodyOutput& bodyWriter, TArtifacts::THeaders& headers) {
        void* const bodySink = &bodyWriter;
        void* const headerSink = &headers;
        curl_easy_setopt(easy, CURLOPT_WRITEDATA, bodySink);
        curl_easy_setopt(easy, CURLOPT_HEADERDATA, headerSink);
    }

    // Like Prepare, but also stashes `artifacts` in CURLOPT_PRIVATE so it can be recovered later
    // by LoadArtifactAndProcessResult from the handle alone.
    void TCurlBase::PrepareAndStoreArtifact(CURL* easy, TArtifacts& artifacts) {
        void* const stash = &artifacts;
        curl_easy_setopt(easy, CURLOPT_PRIVATE, stash);
        Prepare(easy, artifacts.bodyWriter, artifacts.headers);
    }

    // Recovers the TArtifacts stored by PrepareAndStoreArtifact and finalizes it with `code`
    // (a CURLcode). Throws (Y_ENSURE) if no artifacts were attached to `easy`.
    TMaybe<TError> TCurlBase::LoadArtifactAndProcessResult(CURL* easy, ui32 code) {
        // CURLINFO_PRIVATE is documented to store into a char*; read it as such instead of writing
        // a char* through a TArtifacts** and convert back explicitly.
        char* stored = nullptr;
        const CURLcode rc = curl_easy_getinfo(easy, CURLINFO_PRIVATE, &stored);
        Y_ENSURE(rc == CURLE_OK && stored != nullptr);
        TArtifacts* const artifact = reinterpret_cast<TArtifacts*>(stored);
        return ProcessResult(easy, *artifact, code);
    }

    // Finalizes one transfer: flushes the body sink and appends this attempt's timings.
    // On success it splits off the first header line and records the HTTP code.
    // Otherwise it returns an error describing the curl failure, the URL and all recorded timings.
    TMaybe<TError> TCurlBase::ProcessResult(CURL* easy, TArtifacts& artifacts, ui32 _code) {
        artifacts.bodyWriter.Flush();
        artifacts.Timings.emplace_back(easy);

        const CURLcode result = static_cast<CURLcode>(_code);
        if (result == CURLE_OK) {
            // The first captured header line (normally the HTTP status line) is kept separately.
            if (!artifacts.headers.empty()) {
                artifacts.firstLine = std::move(artifacts.headers.front());
                artifacts.headers.pop_front();
            }
            artifacts.code = GetCode(easy);
            return Nothing();
        }

        if (result == CURLE_OPERATION_TIMEDOUT) {
            return TError{} << "timeout: " << curl_easy_strerror(result)
                            << " for url " << GetURL(easy)
                            << " with timings: " << MakeRangeJoiner("->", artifacts.Timings);
        }

        return TError{} << curl_easy_strerror(result)
                        << " for url " << GetURL(easy)
                        << " with timings: " << MakeRangeJoiner("->", artifacts.Timings);
    }

    void TCurl::TDestroyer::Destroy(CURL* context) {
        curl_easy_cleanup(context);
    }

    // Performs the request, repeating it up to max(1, retryOpts->Count) extra times while the
    // transfer fails or the server answers 5xx (unless the code is listed in UnretriableCodes).
    // Only the final attempt's body and headers reach `artifacts`; the timings of every failed
    // attempt are appended to artifacts.Timings. Requires retryOpts to be set.
    TMaybe<TError> TCurl::PerformWithRetries(TArtifacts& artifacts) {
        // Each attempt writes into these local buffers, so output of failed attempts is dropped.
        TStringStream bodyWriter;
        TArtifacts::THeaders headers;

        CURL* const easy = context.Get();
        Prepare(easy, bodyWriter, headers);

        // Query the response code once per check instead of twice.
        const auto shouldRetry = [&](CURLcode result) {
            const auto code = GetCode(easy);
            return (result != CURLE_OK || code >= HTTP_INTERNAL_SERVER_ERROR)
                && !retryOpts->UnretriableCodes.count(code);
        };

        auto r = curl_easy_perform(easy);
        size_t retries = Max(size_t(1), retryOpts->Count);
        while (retries-- && shouldRetry(r)) {
            bodyWriter.Clear();
            headers.clear();
            // Keep the timings of the failed attempt for diagnostics.
            artifacts.Timings.emplace_back(easy);
            if (retryOpts->TimeToWaitBeforeRetry) {
                Sleep(retryOpts->TimeToWaitBeforeRetry);
            }
            r = curl_easy_perform(easy);
        }

        bodyWriter.ReadAll(artifacts.bodyWriter);
        artifacts.headers = std::move(headers);

        // The handle still points at the local buffers above, which die when this function returns.
        // Detach them so a later transfer that skips Prepare() cannot write into dead stack memory;
        // body_callback and header_callback ignore null userdata.
        curl_easy_setopt(easy, CURLOPT_WRITEDATA, static_cast<void*>(nullptr));
        curl_easy_setopt(easy, CURLOPT_HEADERDATA, static_cast<void*>(nullptr));

        return ProcessResult(easy, artifacts, r);
    }

    // Runs the configured request. With a non-zero retry count it delegates to PerformWithRetries;
    // otherwise it performs a single attempt writing straight into `artifacts`.
    TMaybe<TError> TCurl::Perform(TArtifacts& artifacts) {
        const bool retriesEnabled = retryOpts && retryOpts->Count;
        if (retriesEnabled) {
            return PerformWithRetries(artifacts);
        }

        CURL* const easy = context.Get();
        Prepare(easy, artifacts.bodyWriter, artifacts.headers);
        const auto result = curl_easy_perform(easy);
        return ProcessResult(easy, artifacts, result);
    }

    // Restores every option of the easy handle to libcurl's defaults (curl_easy_reset),
    // including the write/header callbacks installed by Setup(const TOptions&).
    TCurl& TCurl::Reset() {
        CURL* const easy = context.Get();
        curl_easy_reset(easy);
        return *this;
    }

    // Exposes the underlying easy handle; ownership stays with this TCurl.
    CURL* TCurl::Native() const noexcept {
        CURL* const handle = context.Get();
        return handle;
    }

    // Effective (last used) URL of the handle; same query as the file-local GetURL helper.
    TStringBuf TCurl::GetEffectiveURL() const noexcept {
        return GetURL(Native());
    }

    // Total duration of the last transfer, or TDuration::Max() if curl cannot report it.
    TDuration TCurl::GetTotalTime() const noexcept {
        // CURLINFO_TOTAL_TIME stores a double (seconds) through a double*. It must be read into a
        // double, not into a pointer that would then be dereferenced.
        double seconds = 0.0;
        if (curl_easy_getinfo(Native(), CURLINFO_TOTAL_TIME, &seconds) != CURLE_OK || !(seconds >= 0.0)) {
            return TDuration::Max();
        }
        return TDuration::MicroSeconds(static_cast<ui64>(1e6 * seconds));
    }

    // Time until the connection was established in the last transfer, or TDuration::Max()
    // if curl cannot report it.
    TDuration TCurl::GetConnectTime() const noexcept {
        // CURLINFO_CONNECT_TIME stores a double (seconds) through a double*. It must be read into a
        // double, not into a pointer that would then be dereferenced.
        double seconds = 0.0;
        if (curl_easy_getinfo(Native(), CURLINFO_CONNECT_TIME, &seconds) != CURLE_OK || !(seconds >= 0.0)) {
            return TDuration::Max();
        }
        return TDuration::MicroSeconds(static_cast<ui64>(1e6 * seconds));
    }

    // CURLOPT_WRITEFUNCTION: forwards body bytes to the TBodyOutput set as WRITEDATA.
    // With null userdata the bytes are dropped. All bytes are always reported as consumed.
    static size_t body_callback(void* contents, size_t size, size_t nitems, void* userdata) {
        const size_t total = size * nitems;
        if (auto* sink = static_cast<TArtifacts::TBodyOutput*>(userdata)) {
            sink->Write(contents, total);
        }
        return total;
    }

    // CURLOPT_WRITEFUNCTION used when body parsing is off: accepts and discards everything.
    static size_t null_body_callback(void* /*contents*/, size_t size, size_t nitems, void* /*userdata*/) {
        const size_t discarded = size * nitems;
        return discarded;
    }

    // CURLOPT_HEADERFUNCTION: stores each header line, stripped of surrounding whitespace, in the
    // THeaders set as HEADERDATA. Blank lines (such as the end-of-headers CRLF) are skipped.
    static size_t header_callback(char* buffer, size_t size, size_t nitems, void* userdata) {
        const size_t length = size * nitems;
        auto* sink = static_cast<TArtifacts::THeaders*>(userdata);
        if (sink == nullptr) {
            return length;
        }

        const auto line = StripString(TStringBuf{buffer, length});
        if (!line.empty()) {
            sink->emplace_back(line);
        }
        return length;
    }

    // CURLOPT_HEADERFUNCTION used when header parsing is off: accepts and discards everything.
    static size_t null_header_callback(char* /*buffer*/, size_t size, size_t nitems, void* /*userdata*/) {
        const size_t discarded = size * nitems;
        return discarded;
    }

    // Applies a request description to the easy handle. The context is stored in this TCurl
    // because curl keeps pointers into it (URL is copied by curl, but the header list and
    // POSTFIELDS buffer are not). Body present -> POST with that body, otherwise a plain GET.
    // Unset port and timeouts are passed as 0, i.e. curl's defaults.
    TCurl& TCurl::Setup(TRequestContext r) noexcept {
        requestContext = std::move(r);
        CURL* const easy = context.Get();

        // curl_easy_setopt is variadic: integer options must be passed as `long` and *_LARGE
        // options as curl_off_t, otherwise curl reads an argument of the wrong width.
        curl_easy_setopt(easy, CURLOPT_URL, requestContext.url ? requestContext.url->c_str() : nullptr);
        curl_easy_setopt(easy, CURLOPT_PORT, static_cast<long>(requestContext.port.GetOrElse(0)));
        curl_easy_setopt(easy, CURLOPT_HTTPHEADER, requestContext.curlHeaders.Get());
        curl_easy_setopt(easy, CURLOPT_TIMEOUT_MS,
                         static_cast<long>(requestContext.requestTimeout.GetOrElse(TDuration::Zero()).MilliSeconds()));
        curl_easy_setopt(easy, CURLOPT_CONNECTTIMEOUT_MS,
                         static_cast<long>(requestContext.connectTimeout.GetOrElse(TDuration::Zero()).MilliSeconds()));

        if (requestContext.data) {
            // The size is set explicitly so bodies containing NUL bytes are sent in full.
            curl_easy_setopt(easy, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(requestContext.data->size()));
            curl_easy_setopt(easy, CURLOPT_POSTFIELDS, requestContext.data->data());
        } else {
            curl_easy_setopt(easy, CURLOPT_HTTPGET, 1L);
        }

        return *this;
    }

    // Installs the retry policy. An empty TMaybe or a zero Count makes Perform use a single attempt.
    TCurl& TCurl::Setup(TMaybe<TRetries> config) {
        this->retryOpts = std::move(config);
        return *this;
    }

    // Applies handle-level options: the optional connection-cache cap, and whether body and headers
    // are collected (real callbacks) or dropped (null callbacks).
    TCurl& TCurl::Setup(const TOptions& options) {
        if (options.maxAliveConnections) {
            SetMaxConnections(*options.maxAliveConnections);
        }

        auto* const writeCallback = options.parseBody ? body_callback : null_body_callback;
        auto* const headerCallback = options.parseHeaders ? header_callback : null_header_callback;
        curl_easy_setopt(context.Get(), CURLOPT_WRITEFUNCTION, writeCallback);
        curl_easy_setopt(context.Get(), CURLOPT_HEADERFUNCTION, headerCallback);

        return *this;
    }

    // Attaches this handle to a shared curl state. The shared pointer is kept in a member so the
    // share outlives its use by this handle.
    TCurl& TCurl::SetShare(TAtomicSharedPtr<TShare> s) {
        share = std::move(s);
        auto nativeShare = share->Get();
        curl_easy_setopt(context.Get(), CURLOPT_SHARE, nativeShare);
        return *this;
    }

    // Creates a fresh easy handle and applies `options`. Throws (Y_ENSURE) if curl_easy_init fails.
    TCurl::TCurl(const TOptions& options)
        : context(curl_easy_init())
    {
        Y_ENSURE(context);
        Setup(options);
    }

    // Configures TLS: an optional PEM client certificate, key and passphrase, an optional CA bundle,
    // and mandatory verification of the server certificate.
    TCurl& TCurl::Setup(const TSSL& ssl) {
        CURL* const easy = context.Get();

        // Client certificate for mutual TLS, read as PEM.
        curl_easy_setopt(easy, CURLOPT_SSLCERTTYPE, "PEM");
        if (ssl.certFile) {
            curl_easy_setopt(easy, CURLOPT_SSLCERT, ssl.certFile.c_str());
        }

        // Passphrase protecting the private key, if it has one.
        if (ssl.passPhrase) {
            curl_easy_setopt(easy, CURLOPT_KEYPASSWD, ssl.passPhrase.c_str());
        }

        // Private key file, also PEM.
        if (ssl.keyFile) {
            curl_easy_setopt(easy, CURLOPT_SSLKEYTYPE, "PEM");
            curl_easy_setopt(easy, CURLOPT_SSLKEY, ssl.keyFile.c_str());
        }

        // CA certificates used to validate the server.
        if (ssl.caCertFile) {
            curl_easy_setopt(easy, CURLOPT_CAINFO, ssl.caCertFile.c_str());
        }

        // Fail the handshake if the server certificate cannot be verified.
        curl_easy_setopt(easy, CURLOPT_SSL_VERIFYPEER, 1L);

        return *this;
    }

    // Caps the size of the handle's connection cache (CURLOPT_MAXCONNECTS).
    void TCurl::SetMaxConnections(size_t connections) {
        // CURLOPT_MAXCONNECTS reads a `long` from the variadic argument list; passing size_t is
        // only accidentally correct where the two types have the same width.
        curl_easy_setopt(context.Get(), CURLOPT_MAXCONNECTS, static_cast<long>(connections));
    }

    // Reads the optional "options" and "ssl" sub-dictionaries; absent keys keep defaults.
    TPoolTraits::TPoolTraits(const NConfig::TDict& config) {
        if (const auto* optionsNode = MapFindPtr(config, "options")) {
            options = TOptions(optionsNode->Get<NConfig::TDict>());
        }
        if (const auto* sslNode = MapFindPtr(config, "ssl")) {
            ssl = TSSL(sslNode->Get<NConfig::TDict>());
        }
    }

    // Factory for pool items: a handle with the pool-wide options, plus optional TLS settings
    // and the retry policy.
    THolder<TCurl> TPoolTraits::create() const {
        auto handle = MakeHolder<TCurl>(options);
        if (ssl) {
            handle->Setup(*ssl);
        }
        handle->Setup(retries);
        return handle;
    }

    // Gives access to the TCurl wrapped by a pool item holder.
    TCurl& TPoolTraits::extract(TPool::TItemHolder& holder) {
        TCurl& curl = holder.Get();
        return curl;
    }

    // Builds traits from explicit handle options and optional TLS settings; the retry policy
    // keeps its default value.
    TPoolTraits::TPoolTraits(TOptions options, TMaybe<TSSL> ssl)
        : options(std::move(options))
        , ssl(std::move(ssl))
    {
    }

    // Builds traits from optional TLS settings and an optional retry policy; handle options keep
    // their default value.
    TPoolTraits::TPoolTraits(TMaybe<TSSL> ssl, TMaybe<TRetries> retries)
        : ssl(std::move(ssl))
        , retries(std::move(retries))
    {
    }
} // namespace NCurl
