#include "s3_client.h"

#include <util/stream/file.h>
#include <util/stream/str.h>
#include <util/string/printf.h>
#include <util/folder/path.h>
#include <util/datetime/cputimer.h>
#include <util/system/mutex.h>
#include <library/cpp/http/simple/http_client.h>
#include <library/cpp/json/json_reader.h>

#include <aws/core/Aws.h>
#include <aws/core/auth/AWSCredentialsProvider.h>
#include <aws/core/client/RetryStrategy.h>
#include <aws/s3/S3Client.h>

#include <aws/s3/model/DeleteObjectsRequest.h>
#include <aws/s3/model/GetObjectRequest.h>
#include <aws/s3/model/HeadObjectRequest.h>
#include <aws/s3/model/ListObjectsV2Request.h>
#include <aws/s3/model/ListObjectsV2Result.h>
#include <aws/s3/model/PutObjectRequest.h>

#include <contrib/libs/aws-sdk-cpp/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h>
#include <contrib/libs/aws-sdk-cpp/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h>
#include <contrib/libs/aws-sdk-cpp/aws-cpp-sdk-transfer/include/aws/transfer/Transfer_EXPORTS.h>

#include <cstdio>

namespace {

    // Issues a blocking HTTP GET for `uri` against host:port and returns the
    // whole response body. Error reporting is whatever TSimpleHttpClient::DoGet
    // does (presumably an exception on failure — not handled here).
    TString HttpGet(const TString &host,
                    const ui32 port,
                    const TString &uri)
    {
        TStringStream responseBody;
        TSimpleHttpClient httpClient(host, port);
        httpClient.DoGet(uri, &responseBody);
        return responseBody.Str();
    }

    // Converts a signed size reported by the AWS SDK into ui64.
    // Plain integral conversion: no range check, negative values wrap.
    ui64 LongLong2Ui64(long long n) {
        return static_cast<ui64>(n);
    }

    class TJsonFile {
    private:
        TString             Content;
        NJson::TJsonValue   Json;
    public:
        TJsonFile(const TString &fileName) {
            Content = TFileInput(fileName).ReadAll();
            if (!NJson::ReadJsonTree(Content, &Json)) {
                Cerr << "Failed to parse json from " << fileName << Endl;
            }
        }
        const std::string operator [](const TString &key) {
            return Json[key].GetString();
        }
    };

    // Per-request retry policy for the S3 client: retry only errors the SDK
    // marks as retryable, at most 3 times, sleeping 300 ms for every retry
    // already attempted (0 ms before the first one).
    struct RetryStrategy : public Aws::Client::RetryStrategy {
        bool ShouldRetry(
                const Aws::Client::AWSError<Aws::Client::CoreErrors> &error,
                long attemptedRetries) const override
        {
            constexpr long MaxRetries = 3;
            return attemptedRetries < MaxRetries && error.ShouldRetry();
        }
        long CalculateDelayBeforeNextRetry(
                const Aws::Client::AWSError<Aws::Client::CoreErrors> &,
                long attemptedRetries) const override
        {
            constexpr long DelayStepMs = 300;
            return DelayStepMs * attemptedRetries;
        }
    };


    class TS3TransferManager {
    private:
        static TMutex Mutex;

        static const constexpr size_t BufferSize = 8 * 1024 * 1024; // 8 Mb
        static const constexpr size_t ThreadPoolSize = 4;

        static bool Debug;
        TString S3Bucket;
        THolder<Aws::Utils::Threading::PooledThreadExecutor> Executor;
        THolder<Aws::Transfer::TransferManagerConfiguration> Config;
        std::shared_ptr<Aws::Transfer::TransferManager> S3TM;

    public:
        TS3TransferManager(std::shared_ptr<Aws::S3::S3Client> pS3Client,
                           const TString &bucket)
            : S3Bucket(bucket)
        {
            Debug       = true;
            Executor.Reset(new Aws::Utils::Threading::PooledThreadExecutor(ThreadPoolSize));
            Config.Reset(new Aws::Transfer::TransferManagerConfiguration(Executor.Get()));

            Config->s3Client                        = pS3Client;
            Config->bufferSize                      = BufferSize;
            Config->transferBufferMaxHeapSize       = ThreadPoolSize*BufferSize*2;
            Config->errorCallback                   = [](
                    const Aws::Transfer::TransferManager *,
                    const std::shared_ptr<const Aws::Transfer::TransferHandle> &h,
                    const Aws::Client::AWSError<Aws::S3::S3Errors> &err)
            {
                std::stringstream ss;
                ss << "failed: " << err.GetMessage();
                DebugLog(h, ss.str());
            };
            Config->transferStatusUpdatedCallback   = [](
                    const Aws::Transfer::TransferManager *,
                    const std::shared_ptr<const Aws::Transfer::TransferHandle> &h)
            {
                Aws::Transfer::TransferStatus status = h->GetStatus();
                if (status == Aws::Transfer::TransferStatus::IN_PROGRESS) {
                    return;
                }
                std::stringstream ss;
                ss << status;
                if (status == Aws::Transfer::TransferStatus::FAILED) {
                    ss << ": " << h->GetLastError().GetMessage();
                }
                DebugLog(h, ss.str());
            };
            S3TM = Aws::Transfer::TransferManager::Create(*Config);
        }

        void SetDebug(bool debug = true) {
            Debug = debug;
        }

        static void DebugLog(const std::shared_ptr<const Aws::Transfer::TransferHandle> &h,
                             const std::string &str)
        {
            if (Debug) {
                TGuard<TMutex> g(Mutex);

                const std::string &key      = h->GetKey();
                const std::string &filename = h->GetTargetFilePath();
                bool isUpload = (h->GetTransferDirection() == Aws::Transfer::TransferDirection::UPLOAD);

                std::cerr << "Transfer "
                          << ((filename.size() == 0) ? "" : "file://" + filename + " ")
                          << ((isUpload) ? "->" : "<-")
                          << " s3://" << key
                          << " " << str
                          << std::endl;
            }
        }

        i32 Wait(std::vector<std::shared_ptr<Aws::Transfer::TransferHandle>> &handles,
                 size_t &bytes,
                 std::vector<std::shared_ptr<Aws::Transfer::TransferHandle>> &handlesCompleted)
        {
            bool nonRetryableError = false;
            for (auto it = handles.begin(); it != handles.end();) {
                std::shared_ptr<Aws::Transfer::TransferHandle> &h = *it;

                h->WaitUntilFinished();
                bytes += h->GetBytesTransferred();
                if (h->GetStatus() == Aws::Transfer::TransferStatus::COMPLETED) {
                    handlesCompleted.push_back(h);
                    handles.erase(it);
                } else if (h->GetLastError().ShouldRetry()) {
                    ++it;
                } else {
                    handles.erase(it);
                    nonRetryableError = true;
                }
            }
            return (nonRetryableError && handles.size() == 0) ? -1 : handles.size();
        }

        i32 Wait(std::vector<std::shared_ptr<Aws::Transfer::TransferHandle>> &handles,
                 size_t &bytes)
        {
            std::vector<std::shared_ptr<Aws::Transfer::TransferHandle>> handlesCompleted;

            return Wait(handles, bytes, handlesCompleted);
        }

        bool Put(std::unordered_map<std::string, std::stringstream> &keyToStringStream,
                 size_t &bytes,
                 int retry = 0)
        {
            i32 result = -1;
            std::vector<std::shared_ptr<Aws::Transfer::TransferHandle>> handles;
            std::unordered_map<std::string, std::shared_ptr<std::iostream>> keyToStreams;

            for (auto &[key, ss]: keyToStringStream) {
                std::shared_ptr<std::iostream>stream = std::shared_ptr<std::iostream>(&ss, [](void*){});

                handles.push_back(S3TM->UploadFile(
                    stream,
                    S3Bucket,
                    key,
                    Aws::DEFAULT_CONTENT_TYPE,
                    {}));
                keyToStreams[key] = stream;
            }
            bytes = 0;
            while (retry-- >= 0) {
                result = Wait(handles, bytes);
                if (result <= 0) {
                    break;
                }
                for (const auto &h: handles) {
                    const std::string &key = h->GetKey();
                    std::shared_ptr<std::iostream> stream = keyToStreams[key];
                    if (Debug) {
                        Cerr << "Retry putting to s3://" << key
                             << " (" << retry << " retries left)"
                             << Endl;
                    }
                    stream->seekg(0, std::ios_base::beg);
                    S3TM->RetryUpload(stream, h);
                }
            }
            return (result == 0);
        }

        bool Put(const std::unordered_map<std::string, std::string> &keyToFilename,
                 size_t &bytes,
                 int retry = 0)
        {
            i32 result = -1;
            std::vector<std::shared_ptr<Aws::Transfer::TransferHandle>> handles;

            for (const auto &[key, filename]: keyToFilename) {
                handles.push_back(S3TM->UploadFile(
                    filename,
                    S3Bucket,
                    key,
                    Aws::DEFAULT_CONTENT_TYPE,
                    {}));
            }
            bytes = 0;
            while (retry-- >= 0) {
                result = Wait(handles, bytes);
                if (result <= 0) {
                    break;
                }
                for (const auto &h: handles) {
                    const std::string &filename = h->GetTargetFilePath();
                    if (Debug) {
                        Cerr << "Retry transfer " << filename
                             << " -> s3://" << h->GetKey()
                             << " (" << retry << " retries left)"
                             << Endl;
                    }
                    S3TM->RetryUpload(filename, h);
                }
            }
            return (result == 0);
        }

        bool Get(std::unordered_map<std::string, std::iostream *> &keyToIOStream,
                 size_t &bytes,
                 int retry = 0)
        {
            i32 result = -1;
            std::vector<std::shared_ptr<Aws::Transfer::TransferHandle>> handles;

            for (auto &[key, ios]: keyToIOStream) {
                handles.push_back(S3TM->DownloadFile(
                    S3Bucket,
                    key,
                    [s = static_cast<Aws::IOStream *>(ios)]() {
                        return s;
                    }));
            }
            bytes = 0;
            while (retry-- >= 0) {
                result = Wait(handles, bytes);
                if (result <= 0) {
                    break;
                }
                for (const auto &h: handles) {
                    const std::string &key = h->GetKey();
                    keyToIOStream[key]->seekp(0, std::ios_base::beg);
                    if (Debug) {
                        Cerr << "Retry getting s3://" << key
                             << " (" << retry << " retries left)"
                             << Endl;
                    }
                    S3TM->RetryDownload(h);
                }
            }
            return (result == 0);
        }

        bool Get(std::unordered_map<std::string, std::stringstream> &keyToStringStream,
                 size_t &bytes,
                 int retry = 0)
        {
            std::unordered_map<std::string, std::iostream *> keyToIostream;

            for (auto &[key, ss]: keyToStringStream) {
                auto headRequest = Aws::S3::Model::HeadObjectRequest().WithBucket(S3Bucket).WithKey(key);
                auto out = Config->s3Client->HeadObject(headRequest);
                if (!out.IsSuccess()) {
                    if (Debug) {
                        Cerr << "Cannot get head for s3://" << key << ": "
                             << out.GetError().GetMessage()
                             << Endl;
                    }
                    continue;
                }
                // since aws is using multipart download for larger files (>TM bufsize) unconditionally,
                // we need to preallocate target buffer
                ss.str(std::string(out.GetResult().GetContentLength(), '\0'));
                keyToIostream[key] = new std::iostream(ss.rdbuf());
            }
            return Get(keyToIostream, bytes, retry);
        }

        bool Get(const std::unordered_map<std::string, std::string> &keyToFilename,
                 size_t &bytes,
                 int retry = 0)
        {
            i32 result = -1;
            std::vector<std::shared_ptr<Aws::Transfer::TransferHandle>> handles;

            for (const auto &[key, filename]: keyToFilename) {
                TFsPath filePath(filename);
                filePath.Parent().MkDirs(MODE0775);
                handles.push_back(S3TM->DownloadFile(
                    S3Bucket,
                    key,
                    filename));
            }
            bytes = 0;
            while (retry-- >= 0) {
                result = Wait(handles, bytes);
                if (result <= 0) {
                    break;
                }
                for (const auto &h: handles) {
                    if (Debug) {
                        Cerr << "Retry transfer s3://" << h->GetKey()
                             << " -> " << h->GetTargetFilePath()
                             << " (" << retry << " retries left)"
                             << Endl;
                    }
                    S3TM->RetryDownload(h);
                }
            }
            return (result == 0);
        }
    };
    TMutex TS3TransferManager::Mutex;
    bool TS3TransferManager::Debug;


    class TTimedS3Client {
    private:
        static const constexpr bool UseHttps = true;

        static const constexpr TDuration ConnectTimeout = TDuration::Seconds(5);
        static const constexpr TDuration RequestTimeout = TDuration::Seconds(3);
        static const constexpr TDuration ProxyLifetime  = TDuration::Seconds(10);

        TString S3Endpoint;
        TString S3Bucket;
        bool BalancerBypass;
        bool Debug;

        TSimpleTimer ProxyTimer;

        Aws::Auth::AWSCredentials Credentials;
        Aws::SDKOptions Options;
        std::shared_ptr<Aws::S3::S3Client> S3Client;
        THolder<TS3TransferManager> TM;

    public:
        TTimedS3Client(TString keyFile,
                       TString endpoint,
                       TString bucket,
                       bool balancerBypass = false,
                       bool verbose = true)
            : S3Endpoint(endpoint)
            , S3Bucket(bucket)
            , BalancerBypass(balancerBypass)
            , Debug(verbose)
        {
            TJsonFile jsonFile(keyFile);

            Aws::InitAPI(Options);
            Credentials.SetAWSAccessKeyId(jsonFile["AccessKeyId"]);
            Credentials.SetAWSSecretKey(jsonFile["AccessSecretKey"]);
            if (Debug) {
                Cerr << "TS3Client::Init using"
                     << " S3 access key "   << Credentials.GetAWSAccessKeyId()
                     << ", S3 endpoint "    << S3Endpoint
                     << ", S3 bucket "      << S3Bucket
                     << Endl;
            }
            Init();
        }
        ~TTimedS3Client() {
            Aws::ShutdownAPI(Options);
        }

        void Init() {
            Aws::Client::ClientConfiguration s3Config;

            s3Config.endpointOverride   = S3Endpoint;
            if (BalancerBypass) {
                s3Config.proxyHost      = HttpGet(((UseHttps) ? "https://" : "http://") + S3Endpoint, (UseHttps) ? 443 : 80, "/hostname");
                s3Config.proxyPort      = 4080;
                s3Config.proxyScheme    = Aws::Http::Scheme::HTTP;
                s3Config.scheme         = Aws::Http::Scheme::HTTP;
                ProxyTimer.Reset();
                if (Debug) {
                    Cerr << "Proxy host = " << s3Config.proxyHost
                         << Endl;
                }
            } else {
                s3Config.scheme         = (UseHttps) ? Aws::Http::Scheme::HTTPS : Aws::Http::Scheme::HTTP;
                s3Config.verifySSL      = false;
                if (Debug) {
                    Cerr << "Using direct connection to " << S3Endpoint
                         << Endl;
                }
            }
            s3Config.connectTimeoutMs   = ConnectTimeout.MilliSeconds();
            s3Config.requestTimeoutMs   = RequestTimeout.MilliSeconds();
            s3Config.retryStrategy.reset(new RetryStrategy());
            S3Client = std::shared_ptr<Aws::S3::S3Client>(new Aws::S3::S3Client(Credentials, s3Config));
            TM = THolder<TS3TransferManager>(new TS3TransferManager(S3Client, S3Bucket));
            SetDebug(Debug);
        }

        void SetDebug(bool debug = true) {
            Debug = debug;
            TM->SetDebug(debug);
        }

        void UpdateS3Proxy() {
            if (BalancerBypass && ProxyTimer.Get() >= ProxyLifetime) {
                Init();
            }
        }

        Aws::S3::S3Client *GetClient() {
            UpdateS3Proxy();
            return S3Client.get();
        }

        TS3TransferManager *GetTransferManager() {
            UpdateS3Proxy();
            return TM.Get();
        }
    };


    class TS3Client : public IS3Client {
    private:
        static const constexpr char *DirectoryDelimiter = "/";
        static const constexpr ui32 ListChunkSize = 1000;

        TString S3Bucket;
        bool Debug;

        THolder<TTimedS3Client> Client;

    public:
        TS3Client(TString keyFile,
                  TString endpoint,
                  TString bucket,
                  bool balancerBypass = false,
                  bool verbose = true)
            : S3Bucket(bucket)
        {
            Client = THolder<TTimedS3Client>(
                        new TTimedS3Client(keyFile,
                                           endpoint,
                                           bucket,
                                           balancerBypass,
                                           verbose));
        }
        ~TS3Client() override {
        }

        void SetDebug(bool debug) override {
            Debug = debug;
            Client->SetDebug(Debug);
        }

        bool List(TVector<TFileInfo> &files,
                  TVector<TDirInfo> &dirs,
                  const std::string &prefix = {},
                  i64 maxKeys = UNLIMITED) override
        {
            i64 nKeys = 0;
            std::string prefixModified = prefix;
            std::string token = {};

            while (true) {
                auto listRequest = Aws::S3::Model::ListObjectsV2Request().WithBucket(S3Bucket);

                listRequest.SetPrefix(prefixModified);
                listRequest.SetMaxKeys(ListChunkSize);
                listRequest.SetDelimiter(DirectoryDelimiter);
                if (nKeys > 0) {
                    listRequest.SetContinuationToken(token);
                }

                auto out = Client->GetClient()->ListObjectsV2(listRequest);

                if (!out.IsSuccess()) {
                    if (Debug) {
                        Cerr << "TS3Client::List error: " << out.GetError().GetMessage() << Endl;
                    }
                    return false;
                }
                // if no DirectoryDelimiter at the end, retry with it
                if (out.GetResult().GetCommonPrefixes().size() == 1
                        && out.GetResult().GetCommonPrefixes()[0].GetPrefix().compare(prefix + DirectoryDelimiter) == 0)
                {
                    prefixModified += DirectoryDelimiter;
                    continue;
                }
                for (const Aws::S3::Model::CommonPrefix &obj: out.GetResult().GetCommonPrefixes()) {
                    dirs.push_back({ obj.GetPrefix() });
                    if (++nKeys == maxKeys) {
                        return true;
                    }
                }
                for (const Aws::S3::Model::Object &obj: out.GetResult().GetContents()) {
                    files.push_back({
                                        obj.GetKey(),
                                        LongLong2Ui64(obj.GetSize()),
                                        TInstant::FromValue(obj.GetLastModified().Millis()*1000)
                                    });
                    if (++nKeys == maxKeys) {
                        return true;
                    }
                }
                if (!out.GetResult().GetIsTruncated()) {
                    return true;
                }
                token = out.GetResult().GetNextContinuationToken();
            }
            return true;
        }

        bool ListRecursive(TVector<TFileInfo> &files,
                           const std::string &prefix = {},
                           i64 maxKeys = UNLIMITED) override
        {
            i64 nKeys = 0;
            std::string prefixModified = prefix;
            std::string token = {};

            while (true) {
                auto listRequest = Aws::S3::Model::ListObjectsV2Request().WithBucket(S3Bucket);

                listRequest.SetPrefix(prefixModified);
                listRequest.SetMaxKeys(ListChunkSize);
                if (nKeys > 0) {
                    listRequest.SetContinuationToken(token);
                }

                auto out = Client->GetClient()->ListObjectsV2(listRequest);

                if (!out.IsSuccess()) {
                    if (Debug) {
                        Cerr << "TS3Client::List error: " << out.GetError().GetMessage() << Endl;
                    }
                    return false;
                }
                for (const Aws::S3::Model::Object &obj: out.GetResult().GetContents()) {
                    files.push_back({
                                        obj.GetKey(),
                                        LongLong2Ui64(obj.GetSize()),
                                        TInstant::FromValue(obj.GetLastModified().Millis()*1000)
                                    });
                    if (++nKeys == maxKeys) {
                        return true;
                    }
                }
                if (!out.GetResult().GetIsTruncated()) {
                    return true;
                }
                token = out.GetResult().GetNextContinuationToken();
            }
            return true;
        }

        bool Head(const std::string &key,
                  TFileInfo &file) override
        {
            auto headRequest = Aws::S3::Model::HeadObjectRequest().WithBucket(S3Bucket).WithKey(key);

            auto out = Client->GetClient()->HeadObject(headRequest);
            if (!out.IsSuccess()) {
                if (Debug) {
                    Cerr << "TS3Client::Head error: " << out.GetError().GetMessage() << Endl;
                }
                return false;
            }
            file.Name           = key;
            file.SizeBytes      = out.GetResult().GetContentLength();
            file.LastModified   = TInstant::FromValue(out.GetResult().GetLastModified().Millis()*1000);
            return true;
        }

        bool Put(const std::string &key,
                 std::stringstream &ss) override
        {
            std::shared_ptr<std::iostream>data = std::shared_ptr<std::iostream>(&ss, [](void*){});

            auto putRequest = Aws::S3::Model::PutObjectRequest().WithBucket(S3Bucket).WithKey(key);
            putRequest.SetBody(data);

            auto out = Client->GetClient()->PutObject(putRequest);
            if (!out.IsSuccess()) {
                if (Debug) {
                    Cerr << "TS3Client::Put error: " << out.GetError().GetMessage() << Endl;
                }
                return false;
            }
            return true;
        }

        bool Put(std::unordered_map<std::string, std::stringstream> &keyToStringStream,
                 size_t &bytes,
                 int retry = 0) override
        {
            return Client->GetTransferManager()->Put(keyToStringStream, bytes, retry);
        }

        bool Put(const std::unordered_map<std::string, std::string> &keyToFilename,
                 size_t &bytes,
                 int retry = 0) override
        {
            return Client->GetTransferManager()->Put(keyToFilename, bytes, retry);
        }

        bool Get(const std::string &key,
                 std::iostream *ios) override
        {
            auto getRequest = Aws::S3::Model::GetObjectRequest().WithBucket(S3Bucket).WithKey(key);

            auto out = Client->GetClient()->GetObject(getRequest);
            if (!out.IsSuccess()) {
                if (Debug) {
                    Cerr << "TS3Client::Get error: " << out.GetError().GetMessage() << Endl;
                }
                return false;
            }
            *ios << out.GetResult().GetBody().rdbuf();
            return true;
        }

        bool Get(std::unordered_map<std::string, std::iostream *> &keyToIOStream,
                 size_t &bytes,
                 int retry = 0) override
        {
            return Client->GetTransferManager()->Get(keyToIOStream, bytes, retry);
        }

        bool Get(std::unordered_map<std::string, std::stringstream> &keyToStringStream,
                 size_t &bytes,
                 int retry = 0) override
        {
            return Client->GetTransferManager()->Get(keyToStringStream, bytes, retry);
        }

        bool Get(const std::unordered_map<std::string, std::string> &keyToFilename,
                 size_t &bytes,
                 int retry = 0) override
        {
            return Client->GetTransferManager()->Get(keyToFilename, bytes, retry);
        }

        bool Del(const TVector<TString> &keys) override {
            bool result = true;
            const size_t maxCount = 1000;

            for (auto it = keys.begin(); it != keys.end();) {
                std::vector<Aws::S3::Model::ObjectIdentifier> objIds;

                for (size_t counter = 0; counter < maxCount && it != keys.end(); ++counter, ++it) {
                    objIds.push_back(Aws::S3::Model::ObjectIdentifier().WithKey(*it));
                }
                auto deleteRequest = Aws::S3::Model::DeleteObjectsRequest()
                        .WithBucket(S3Bucket)
                        .WithDelete(Aws::S3::Model::Delete().WithObjects(objIds));

                auto out = Client->GetClient()->DeleteObjects(deleteRequest);
                if (!out.IsSuccess()) {
                    if (Debug) {
                        Cerr << "TS3Client::Del error: " << out.GetError().GetMessage() << Endl;
                    }
                    result = false;
                } else if (Debug) {
                    Cerr << "Removed " << objIds.size() << " objects" << Endl;
                }
            }
            return result;
        }
    };
}


// Factory for the public IS3Client interface; the concrete TS3Client stays
// private to this translation unit.
THolder<IS3Client> CreateS3Client(TString keyFile,
                                  TString endpoint,
                                  TString bucket,
                                  bool balancerBypass,
                                  bool verbose)
{
    THolder<IS3Client> client(new TS3Client(keyFile, endpoint, bucket, balancerBypass, verbose));
    return client;
}
