#include "config.h"

#include <drive/library/cpp/aes/crypto.h>

#include <util/generic/yexception.h>
#include <util/stream/file.h>
#include <util/string/builder.h>
#include <util/string/join.h>

#include <library/cpp/openssl/holders/hmac.h>
#include <library/cpp/string_utils/base64/base64.h>

#include <contrib/libs/openssl/include/openssl/hmac.h>


// Creates an encrypted-bucket description named after its config section.
// The remaining fields are filled in by Init(); `name` also serves as the
// fallback S3 bucket name when no "Bucket" directive is given.
TS3ClientConfig::TEncryptedBucket::TEncryptedBucket(const TString& name)
    : Name(name)
{}

// Loads one encrypted-bucket description from its config section.
//
// Directives read:
//   Bucket            - S3 bucket name (defaults to the section name);
//   EncryptionMode    - AES_GCM_128 or AES_GCM_256;
//   Default           - keeps its current value when absent;
//   EncryptionKey     - hex-encoded key given inline;
//   EncryptionKeyPath - file holding the hex-encoded key, consulted only
//                       when no inline EncryptionKey is configured.
//
// The hex key must have exactly 32 (AES-128) or 64 (AES-256) characters,
// otherwise AssertCorrectConfig fails. The decoded bytes are stored in
// EncryptionKey, which is rebuilt from scratch on every call.
void TS3ClientConfig::TEncryptedBucket::Init(const TYandexConfig::Section* section) {
    const auto& directives = section->GetDirectives();
    BucketName = directives.Value("Bucket", Name);
    EncryptionMode = directives.Value("EncryptionMode", EncryptionMode);
    Default = directives.Value("Default", Default);

    // An inline key takes precedence; the key file is only a fallback.
    TString encryptionKeyHR = directives.Value("EncryptionKey", TString());
    if (!encryptionKeyHR) {
        EncryptionKeyPath = directives.Value("EncryptionKeyPath", EncryptionKeyPath);
        if (EncryptionKeyPath) {
            encryptionKeyHR = Strip(TFileInput(EncryptionKeyPath).ReadAll());
        }
    }
    switch (EncryptionMode) {
        case EBucketEncryptionMode::AES_GCM_128:
            AssertCorrectConfig(encryptionKeyHR.size() == 32, "EncryptionKey has wrong size. Must be 32.");
            break;
        case EBucketEncryptionMode::AES_GCM_256:
            AssertCorrectConfig(encryptionKeyHR.size() == 64, "EncryptionKey has wrong size. Must be 64.");
            break;
    }

    // Start from an empty key: without this, a repeated Init() on the same
    // object would append to the previously decoded bytes and yield a key of
    // the wrong length.
    EncryptionKey.clear();
    // Decode two hex characters (high nibble first) into each key byte.
    for (size_t i = 0; i < encryptionKeyHR.size() / 2; ++i) {
        auto firstByte = TCryptoGCMProcessor::HexDigitToInt(encryptionKeyHR[i * 2]);
        auto secondByte = TCryptoGCMProcessor::HexDigitToInt(encryptionKeyHR[i * 2 + 1]);
        EncryptionKey.push_back((firstByte << 4) + secondByte);
    }
}

// Dumps the bucket description in the same directive form that Init() reads.
// "Default" is included so a dumped config keeps the default-bucket flag.
// The inline EncryptionKey is deliberately never printed (it is a secret);
// only the key file path is emitted, when one was configured.
void TS3ClientConfig::TEncryptedBucket::ToString(IOutputStream& os) const {
    os << "Bucket: " << BucketName << Endl;
    os << "EncryptionMode: " << EncryptionMode << Endl;
    os << "Default: " << Default << Endl;
    if (EncryptionKeyPath) {
        os << "EncryptionKeyPath: " << EncryptionKeyPath << Endl;
    }
}


// Loads the S3 client configuration from `section`.
//
// The base async-API config is initialized first. Then credentials, timeouts,
// the public-bucket list and the optional <EncryptedBuckets> subsection are
// read. Key files named by *Path directives are read eagerly, so a missing
// file makes TFileInput throw from here.
void TS3ClientConfig::Init(const TYandexConfig::Section* section, const TMap<TString, NSimpleMeta::TConfig>* requestPolicy) {
    TAsyncApiImpl::TConfig::Init(section, requestPolicy);
    const auto& directives = section->GetDirectives();

    // Access key id: the file named by AccessKeyIdPath is read first; an
    // explicit AccessKeyId directive, when present, replaces its contents.
    AccessKeyIdPath = directives.Value("AccessKeyIdPath", AccessKeyIdPath);
    if (!!AccessKeyIdPath) {
        AccessKeyId = Strip(TFileInput(AccessKeyIdPath).ReadAll());
    }
    AccessKeyId = directives.Value("AccessKeyId", AccessKeyId);

    // Secret key: the inline directive wins; the file is only a fallback.
    AccessSecretKey = directives.Value("AccessSecretKey", AccessSecretKey);
    AccessSecretKeyPath = directives.Value("AccessSecretKeyPath", AccessSecretKeyPath);
    if (!AccessSecretKey && AccessSecretKeyPath) {
        AccessSecretKey = Strip(TFileInput(AccessSecretKeyPath).ReadAll());
    }

    // Snapshot of the host as configured at Init time.
    OriginalHost = GetHost();
    RequestTimeout = directives.Value<TDuration>("RequestTimeout", RequestTimeout);
    HandlerThreads = directives.Value<ui32>("HandlerThreads", HandlerThreads);
    InitPeriod = directives.Value<TDuration>("InitPeriod", InitPeriod);

    // PublicBuckets: bucket names separated by commas and/or spaces.
    const TString publicBucketsList = directives.Value<TString>("PublicBuckets", TString());
    const TVector<TString> publicBucketNames = StringSplitter(publicBucketsList).SplitBySet(", ").SkipEmpty();
    PublicBuckets = TSet<TString>(publicBucketNames.cbegin(), publicBucketNames.cend());
    PublicBucketHost = directives.Value("PublicBucketHost", PublicBucketHost);

    // Every child of <EncryptedBuckets> describes one bucket; the resulting
    // entries are keyed by their S3 bucket name.
    const auto children = section->GetAllChildren();
    const auto encryptedIt = children.find("EncryptedBuckets");
    if (encryptedIt == children.end()) {
        return;
    }
    for (const auto& [sectionName, bucketSection] : encryptedIt->second->GetAllChildren()) {
        auto bucket = MakeAtomicShared<TEncryptedBucket>(sectionName);
        bucket->Init(bucketSection);
        EncryptedBuckets.emplace(bucket->GetBucketName(), std::move(bucket));
    }
}

// Dumps the configuration in directive form, after the base async-API part.
// AccessSecretKey is intentionally not printed. Encrypted buckets, if any,
// are emitted as a nested <EncryptedBuckets> section with one subsection per
// bucket, named after the bucket's config section.
void TS3ClientConfig::ToString(IOutputStream& os) const {
    TAsyncApiImpl::TConfig::ToString(os);
    os << "AccessKeyId: " << AccessKeyId << Endl
       << "AccessKeyIdPath: " << AccessKeyIdPath << Endl
       << "AccessSecretKeyPath: " << AccessSecretKeyPath << Endl
       << "RequestTimeout: " << RequestTimeout << Endl
       << "HandlerThreads: " << HandlerThreads << Endl
       << "PublicBuckets: " << JoinSeq(", ", PublicBuckets) << Endl
       << "PublicBucketHost: " << PublicBucketHost << Endl
       << "InitPeriod: " << InitPeriod << Endl;

    if (EncryptedBuckets.empty()) {
        return;
    }

    os << "<EncryptedBuckets>" << Endl;
    for (const auto& entry : EncryptedBuckets) {
        const auto& bucket = *entry.second;
        const auto& sectionName = bucket.GetName();
        os << "<" << sectionName << ">" << Endl;
        bucket.ToString(os);
        os << "</" << sectionName << ">" << Endl;
    }
    os << "</EncryptedBuckets>" << Endl;
}

// Builds an Authorization header value of the form
// "AWS <AccessKeyId>:<base64(HMAC-SHA1(AccessSecretKey, request))>".
//
// `request` is the string-to-sign prepared by the caller.
// Throws yexception if any OpenSSL HMAC step fails. Previously the failure was
// ignored and an uninitialized digest/length was read (undefined behavior),
// which produced a garbage signature.
TString TS3ClientConfig::SignString(const TString& request) const {
    unsigned char digest[EVP_MAX_MD_SIZE];
    unsigned int digestLen = 0;
    NOpenSSL::THmacCtx ctx;

    // OpenSSL's HMAC_* functions return 1 on success and 0 on failure.
    Y_ENSURE(HMAC_Init_ex(ctx, AccessSecretKey.data(), AccessSecretKey.size(), EVP_sha1(), nullptr) == 1,
             "HMAC_Init_ex failed while signing S3 request");
    Y_ENSURE(HMAC_Update(ctx, reinterpret_cast<const unsigned char*>(request.data()), request.size()) == 1,
             "HMAC_Update failed while signing S3 request");
    Y_ENSURE(HMAC_Final(ctx, digest, &digestLen) == 1,
             "HMAC_Final failed while signing S3 request");

    const TString signature = Base64Encode(TStringBuf(reinterpret_cast<const char*>(digest), digestLen));
    return "AWS " + AccessKeyId + ":" + signature;
}

// Creates an HTTP request skeleton shared by all S3 calls.
//
// `postData` is attached as the body only for POST and PUT. The Date header is
// set to `date` verbatim, and the Authorization header is produced by signing
// `authStr` with SignString().
NNeh::THttpRequest TS3ClientConfig::CreateCommonRequest(const TString& authStr, const TString& date, const TString& method, const TString& postData) const {
    const bool carriesBody = (method == "POST") || (method == "PUT");

    NNeh::THttpRequest request;
    if (carriesBody) {
        request.SetPostData(TBlob::FromString(postData));
    }
    request.SetRequestType(method);
    request.AddHeader("Date", date);
    request.AddHeader("Authorization", SignString(authStr));
    return request;
}

// Returns a copy of this config whose host is "<bucketName>.<current host>".
// All other settings, including OriginalHost, are copied unchanged.
TS3ClientConfig TS3ClientConfig::GetBucketConfig(const TString& bucketName) const {
    TS3ClientConfig result = *this;
    result.SetHost(bucketName + "." + GetHost());
    return result;
}

// Builds a URL for object `path` inside `bucketName`.
//
// Public buckets (listed in PublicBuckets) are addressed as
//   <scheme>://<bucket>.<PublicBucketHost or OriginalHost>/<path>;
// all other buckets as
//   <scheme>://<original host>/<bucket>/<path>.
// The scheme is https or http depending on GetIsHttps(). `path` is appended
// verbatim.
TString TS3ClientConfig::GetTmpFilePath(const TString& bucketName, const TString& path) const {
    TStringBuilder url;
    url << (GetIsHttps() ? "https" : "http") << "://";
    if (!PublicBuckets.contains(bucketName)) {
        url << GetOriginalHost() << '/' << bucketName << '/';
    } else {
        // Fall back to the original host when no dedicated public host is set.
        const auto& host = PublicBucketHost ? PublicBucketHost : OriginalHost;
        url << bucketName << '.' << host << '/';
    }
    url << path;
    return url;
}
