#include "registration.h"

#include <solomon/agent/lib/auth/auth.h>
#include <solomon/agent/lib/context/context.h>
#include <solomon/agent/lib/http/headers.h>
#include <solomon/agent/misc/logger.h>

#include <solomon/protos/agent/agent.pb.h>
#include <solomon/libs/cpp/http/client/curl/client.h>

#include <library/cpp/http/misc/http_headers.h>
#include <library/cpp/string_utils/url/url.h>

#include <util/generic/guid.h>
#include <util/system/env.h>
#include <util/system/hostname.h>

namespace NSolomon::NAgent {
namespace {

constexpr ui32 DEFAULT_PULL_INTERVAL_SECONDS{15};
constexpr auto REGISTRATION_MIN_DELAY = TDuration::Seconds(15);
constexpr auto REGISTRATION_MAX_DELAY = TDuration::Minutes(2);

// Values the Agent reports about itself when registering; converted into a
// TRegisterAgentRequest via ToProtobuf().
class TRegistrationData {
public:
    ui64 PullIntervalSeconds{0}; // parsed from Registration.PullInterval (or the default)
    ui32 DataPort{0};            // HttpServer.BindPort; must be non-zero
    ui32 ManagementPort{0};      // ManagementServer.BindPort, stays 0 when not configured
    TString Address;             // FQDN of the local host
    TString Version;             // Agent build version
    TString Description;         // free-form Registration.Description

public:
    // Builds a fresh registration request message from the collected values.
    // Does not modify the object, hence const.
    THolder<TRegisterAgentRequest> ToProtobuf() const {
        auto message = MakeHolder<TRegisterAgentRequest>();

        message->SetPullIntervalSeconds(PullIntervalSeconds);
        message->SetDataPort(DataPort);
        message->SetManagementPort(ManagementPort);
        message->SetAddress(Address);
        message->SetVersion(Version);
        message->SetDescription(Description);

        return message;
    }
};

// Extracts the host part of `url`, which must carry an explicit http:// or https:// scheme.
// Throws (Y_ENSURE) if the url cannot be parsed, has another/no scheme, or has an empty host.
// The result is a TStringBuf (presumably a view into `url`), so it must not outlive the argument.
TStringBuf RetrieveHost(const TString& url) {
    TStringBuf scheme;
    TStringBuf host;
    ui16 port = 0;

    const bool parsed = TryGetSchemeHostAndPort(url, scheme, host, port);
    Y_ENSURE(parsed, "unable to parse " << url << " as a valid host");

    const bool isHttpScheme = (scheme == "http://") || (scheme == "https://");
    Y_ENSURE(isHttpScheme,
             "wrong or no scheme is specified in the url \"" << url << "\". Choose http:// or https://");

    Y_ENSURE(!host.empty(), "empty host value");
    return host;
}

void InflateDataWithPullInterval(const TProviderRegistrationConfig& registrationConfig, TRegistrationData& data) {
    TString pullIntervalStr = registrationConfig.GetPullInterval();

    if (pullIntervalStr.empty()) {
        data.PullIntervalSeconds = DEFAULT_PULL_INTERVAL_SECONDS;
        return;
    }

    StripInPlace(pullIntervalStr);

    try {
        data.PullIntervalSeconds = TDuration::Parse(pullIntervalStr).Seconds();
    } catch (...) {
        ythrow yexception() << "failed to parse PullInterval: " << CurrentExceptionMessage();
    }
}

// Fills data.DataPort from HttpServer.BindPort.
// Registration requires the HttpServer section and a non-zero port; otherwise Y_ENSURE throws.
void InflateDataWithStoragePort(const TAgentConfig& config, TRegistrationData& data) {
    Y_ENSURE(config.HasHttpServer(), "cannot register the Agent instance with a missing HttpServer section");

    const auto bindPort = config.GetHttpServer().GetBindPort();
    Y_ENSURE(bindPort != 0, "port to an Agent data server cannot be 0");

    data.DataPort = bindPort;
}

// Fills data.ManagementPort from ManagementServer.BindPort.
// The section is optional: without it, or with a zero port, data.ManagementPort is left untouched.
void InflateDataWithManagementPort(const TAgentConfig& config, TRegistrationData& data) {
    if (!config.HasManagementServer()) {
        return;
    }

    if (const auto bindPort = config.GetManagementServer().GetBindPort(); bindPort != 0) {
        data.ManagementPort = bindPort;
    }
}

// Gathers everything sent in a registration request: pull interval, data/management ports,
// the local FQDN, the Agent version and the configured description.
// Propagates the configuration errors thrown by the Inflate* helpers.
TRegistrationData CollectDataForRegistration(const TAgentConfig& config) {
    TRegistrationData result;
    const auto& registration = config.GetRegistration();

    // Config-derived values first, so configuration errors surface in this order.
    InflateDataWithPullInterval(registration, result);
    InflateDataWithStoragePort(config, result);
    InflateDataWithManagementPort(config, result);

    result.Address = GetFQDNHostName();
    result.Version = AGENT_VERSION;
    result.Description = registration.GetDescription();

    return result;
}

// Content-Type of the serialized TRegisterAgentRequest body sent by SendRequest.
const TString CONTENT_TYPE_PROTOBUF{"application/x-protobuf"};

} // namespace

// Creates the curl-based HTTP client used for registration requests and sets the
// per-request timeouts/retry policy stored in ReqOpts_.
//  - bindBeforeConnect: bind the client socket to the local host name before connecting.
//  - SA_CAPATH (env): optional directory with CA certificates for TLS verification.
void TRegistrationTask::InitHttpClient(NMonitoring::TMetricRegistry& registry, bool bindBeforeConnect) {
    // Only a handful of Solomon installations are ever registered in, so a small limit is enough.
    constexpr auto maxInflight = 10;

    TCurlClientOptions curlOpts;
    curlOpts.MaxInflight = maxInflight;
    curlOpts.DnsCacheLifetime = TDuration::Hours(2);

    if (bindBeforeConnect) {
        curlOpts.BindOptions = TCurlBindOptions{.Host = HostName()};
    }

    if (TString caCertDir = Strip(GetEnv("SA_CAPATH", "")); !caCertDir.empty()) {
        curlOpts.CaCertDir = caCertDir;
    }

    HttpClient_ = CreateCurlClient(std::move(curlOpts), registry);

    // Per-request settings: connection/read timeouts plus client-side retries with backoff.
    ReqOpts_.ConnectTimeout = TDuration::Seconds(5);
    ReqOpts_.ReadTimeout = TDuration::Seconds(15);
    ReqOpts_.Retries = 3;
    ReqOpts_.BackoffMin = TDuration::Seconds(5);
    ReqOpts_.BackoffMax = TDuration::Minutes(1);
}

// Validates the Registration section of `config` and serializes the registration payload once.
// Builds one TClusterRegistrationState per configured endpoint, keyed by the endpoint host.
// Starts `timer` if it is not running yet. Any configuration error throws via Y_ENSURE.
TRegistrationTask::TRegistrationTask(
        const TAgentConfig& config,
        NMonitoring::TMetricRegistry& registry,
        TTimerThread& timer,
        TSimpleSharedPtr<IThreadPool> pool)
    : TimerThread_{timer}
    , ThreadPool_{std::move(pool)}
    , Backoff_{REGISTRATION_MIN_DELAY, REGISTRATION_MAX_DELAY}
{
    if (!TimerThread_.IsStarted()) {
        TimerThread_.Start();
    }

    // Hosts of the well-known installations an endpoint may reference by Type instead of an explicit Url.
    ClustersHosts_[EClusterType::PRODUCTION] = "solomon.yandex.net";
    ClustersHosts_[EClusterType::PRESTABLE] = "solomon-prestable.yandex.net";
    ClustersHosts_[EClusterType::TESTING] = "solomon-test.yandex.net";
    ClustersHosts_[EClusterType::CLOUD_PROD] = "solomon.cloud.yandex-team.ru";
    ClustersHosts_[EClusterType::CLOUD_PREPROD] = "solomon.cloud-preprod.yandex-team.ru";
    ClustersHosts_[EClusterType::CLOUD_GPN] = "monitoring.private-api.ycp.gpn.yandexcloud.net";
    ClustersHosts_[EClusterType::CLOUD_ISRAEL] = "solomon.yandexcloud.co.il";

    for (const auto& [_, host]: ClustersHosts_) {
        Y_ENSURE(!host.empty());
    }

    const auto& registrationConfig = config.GetRegistration();
    const auto& providerId = registrationConfig.GetProvider();

    Y_ENSURE(!providerId.empty(), "empty Provider value inside the Registration section");
    Y_ENSURE(!config.GetCluster().empty(), "empty Cluster value inside the Agent config");

    // The payload never changes, so it is serialized once and reused by every request.
    auto regData = CollectDataForRegistration(config);
    RegistrationData_ = regData.ToProtobuf()->SerializeAsStringOrThrow();

    Y_ENSURE(registrationConfig.EndpointsSize() > 0, "no endpoints for registration are specified");

    for (const auto& endpoint: registrationConfig.GetEndpoints()) {
        Y_ENSURE(endpoint.HasType() != endpoint.HasUrl(),
                 "either Type or Url field should be specified");

        TString url;
        TString host;

        if (endpoint.HasUrl()) {
            url = Strip(endpoint.GetUrl());
            Y_ENSURE(!url.empty(), "Endpoints::Url field cannot be empty");

            host = RetrieveHost(url);
        } else {
            const auto type = endpoint.GetType();
            Y_ENSURE(EClusterType_IsValid(type),
                     "unknown endpoint type: " << static_cast<ui32>(type));

            // find() rather than operator[]: a valid enum value with no known host must be rejected.
            // It must not be silently inserted as an empty host, which would yield "https:///...".
            const auto hostIt = ClustersHosts_.find(type);
            Y_ENSURE(hostIt != ClustersHosts_.end(),
                     "no known host for endpoint type: " << static_cast<ui32>(type));
            host = hostIt->second;

            url = TStringBuilder() << "https://" << host << "/api/v2/agents/register";
        }

        IAuthProviderPtr authProvider;
        if (const auto& authMethod = endpoint.GetAuthMethod()) {
            authProvider = GetAgentCtx()->GetAuthProvider(authMethod);
            Y_ENSURE(authProvider, "no auth provider for \"" << authMethod << "\"");
        }

        // TODO: validate a url
        // NOTE(review): providerId is appended without URL-escaping; this assumes provider ids are URL-safe.
        url += (url.Contains('?') ? '&' : '?');
        url += "provider=" + providerId;

        auto statePtr = MakeIntrusive<TClusterRegistrationState>(std::move(url), std::move(authProvider));
        auto isNewElement = ClusterNameToState_.emplace(host, statePtr).second;

        Y_ENSURE(isNewElement, "endpoint " << host << " is specified multiple times");
    }

    InitHttpClient(registry, GetAgentCtx()->DoBindBeforeConnect());
}

// Stops the timer via Stop() and then waits for each cluster's last registration continuation.
// The wait is up to 10 seconds per cluster; a timeout fails the Y_VERIFY check.
// Stop() is called first, presumably so that no new rounds get scheduled while waiting.
TRegistrationTask::~TRegistrationTask() {
    Stop();

    for (auto& [clusterName, state]: ClusterNameToState_) {
        // CurrentRequest is only assigned once a registration round for this cluster has started.
        if (state->CurrentRequest.Initialized()) {
            // NOTE(review): CurrentRequest is written from the thread-pool task built in
            // CreateTimerFuncClosure; reading it here assumes no such task is still running — confirm.
            Y_VERIFY(state->CurrentRequest.Wait(TDuration::Seconds(10)),
                    "failed to wait for an async request for %s", clusterName.c_str());
        }
    }
}

// Returns the timer task body for one cluster. It sends a registration request and, once the
// response arrives, schedules the next attempt for the same cluster. On success the delay is the
// server-provided one (jittered); on failure it is driven by the backoff.
// The in-flight continuation is stored in state->CurrentRequest.
std::function<void()> TRegistrationTask::CreateTimerFuncClosure(
        TString clusterName,
        TClusterRegistrationStatePtr state)
{
    return [this, clusterName, state]() {
        // SendRequest() reports most failures through its future. Anything thrown synchronously
        // (request construction, the HTTP client call) would escape this task, and the continuation
        // below would never run, so this cluster would never be retried. Route such errors through
        // the regular failure path instead.
        NThreading::TFuture<ui64> future;
        try {
            future = SendRequest(clusterName, state);
        } catch (...) {
            future = NThreading::MakeErrorFuture<ui64>(std::current_exception());
        }

        auto asyncReq = future.Apply([this, clusterName, state](const NThreading::TFuture<ui64>& result) {
            TDuration delay;

            try {
                // A zero delay (unset response field, or an empty 2xx body, which parses successfully)
                // would re-register immediately and spin. Fall back to the minimal delay in that case.
                const ui64 serverDelaySeconds = result.GetValue();
                const TDuration baseDelay = (serverDelaySeconds > 0)
                    ? TDuration::Seconds(serverDelaySeconds)
                    : REGISTRATION_MIN_DELAY;

                delay = Jitter_(baseDelay);
                state->IsRegistered = true;
                state->PreviousDelay = delay;

                // NOTE(review): Backoff_ is shared by all clusters, so one cluster's success also
                // resets the backoff of the others — confirm this is intended.
                Backoff_.Reset();
            } catch (...) {
                SA_LOG(ERROR) << "failed to register the Agent instance in " << clusterName
                              << ": " << CurrentExceptionMessage();

                state->IsRegistered = false;

                // Never retry sooner than the last successful cadence (when there was one).
                delay = (state->PreviousDelay)
                    ? Max(Jitter_(state->PreviousDelay), Backoff_())
                    : Backoff_();
            }

            auto task = MakeFuncTimerTask(ThreadPool_.Get(), CreateTimerFuncClosure(clusterName, state));
            TimerThread_.Schedule(task, delay);

            SA_LOG(DEBUG) << "scheduled the next registration request for " << clusterName
                          << " in " << delay.Seconds() << " seconds";
        });

        state->CurrentRequest = std::move(asyncReq);
    };
}

// Sends one registration POST (body: RegistrationData_) to state->Endpoint, tagged with a fresh
// requestId query parameter. The returned future resolves to RegisterDelaySeconds from the
// response. It fails with a "[rid=...]"-prefixed message on a transport error, an HTTP status
// >= 400, or an unparsable body. Errors while adding auth credentials are returned as a failed
// future as well.
NThreading::TFuture<ui64> TRegistrationTask::SendRequest(
        TString clusterName,
        TClusterRegistrationStatePtr state)
{
    // TODO: headers reusage
    auto headers = Headers({
        {TString{NHttpHeaders::CONTENT_TYPE}, CONTENT_TYPE_PROTOBUF},
        {TString{NHttpHeaders::USER_AGENT}, USER_AGENT_HEADER},
    });

    if (state->AuthProvider) {
        THashMap<TString, TString> credentials;
        try {
            state->AuthProvider->AddCredentials(credentials);
        } catch (...) {
            return NThreading::MakeErrorFuture<ui64>(std::current_exception());
        }

        for (auto&& [headerName, headerValue]: credentials) {
            headers->Add(headerName, headerValue);
        }
    }

    const TString requestId = CreateGuidAsString();
    const TString requestIdLog = TStringBuilder() << "[rid=" << requestId << "]";

    // state->Endpoint may already carry a query string (e.g. provider=...), so pick the separator.
    const char separator = state->Endpoint.Contains('?') ? '&' : '?';
    TString backendUrl = TStringBuilder() << state->Endpoint << separator << "requestId=" << requestId;

    auto req = Post(backendUrl, RegistrationData_, std::move(headers));

    auto promise = NThreading::NewPromise<ui64>();
    auto future = promise.GetFuture();

    auto onResponse = [=](IHttpClient::TResult result) mutable -> void {
        // Non-empty errMsg means the request failed; every message starts with requestIdLog.
        TString errMsg;
        TRegisterAgentResponse resp;

        if (!result.Success()) {
            errMsg = TStringBuilder() << requestIdLog << " " << result.Error().Message();
        } else if (result.Value()->Code() >= 400) {
            errMsg = TStringBuilder() << requestIdLog << " got an error code " << result.Value()->Code()
                                      << ": " << result.Value()->Data();
        } else {
            auto body = result.Value()->Data();
            if (!resp.ParseFromArray(body.data(), body.size())) {
                errMsg = TStringBuilder() << requestIdLog << " failed to parse a registration response";
            }
        }

        if (!errMsg.empty()) {
            promise.SetException(errMsg);
            return;
        }

        SA_LOG(DEBUG) << requestIdLog << " successfully registered in " << clusterName;
        promise.SetValue(resp.GetRegisterDelaySeconds());
    };

    SA_LOG(DEBUG) << requestIdLog << " registering the Agent in " << clusterName;

    HttpClient_->Request(std::move(req), std::move(onResponse), ReqOpts_);

    return future;
}

void TRegistrationTask::Start() {
    for (auto& [clusterName, state]: ClusterNameToState_) {
        TDuration initialDelay = Jitter_(REGISTRATION_MIN_DELAY);

        auto clusterTask = MakeFuncTimerTask(ThreadPool_.Get(), CreateTimerFuncClosure(clusterName, state));
        TimerThread_.Schedule(clusterTask, initialDelay);

        SA_LOG(INFO) << "scheduled an initial registration request for " << clusterName
                     << " in " << initialDelay.Seconds() << " seconds";
    }
}

// Stops the timer thread that registration attempts are scheduled on.
// TODO: the timer is injected through the constructor and may become shared between clients;
//       in that case it must not be stopped here, since other clients would still depend on it.
void TRegistrationTask::Stop() {
    TimerThread_.Stop();
}

// Reports whether the last registration attempt for `clusterName` (an endpoint host) succeeded.
// Unknown cluster names are reported as not registered.
bool TRegistrationTask::IsRegisteredIn(const TString& clusterName) const {
    const auto found = ClusterNameToState_.find(clusterName);
    return found != ClusterNameToState_.end() && found->second->IsRegistered;
}

} // namespace NSolomon::NAgent
