#include "cmd_base.h"
#include "lz4.h"

#include <util/generic/algorithm.h>
#include <util/generic/singleton.h>
#include <util/generic/utility.h>
#include <util/stream/format.h>
#include <util/folder/path.h>
#include <util/datetime/cputimer.h>
#include <library/cpp/json/json_writer.h>
#include <library/cpp/json/json_reader.h>

#include <sstream>

namespace {

template<class It>
std::pair<It, It> EqualPrefixRange(It it1,
                                   It it2,
                                   const TString &s)
{
    size_t sSize = s.size();

    auto lower = LowerBound(it1, it2, s, [sSize](const auto &f, const TString &s) {
        return f.Name.compare(0, sSize, s) < 0;
    });
    auto upper = UpperBound(it1, it2, s, [sSize](const TString &s, const auto &f) {
        return f.Name.compare(0, sSize, s) > 0;
    });
    return std::make_pair(lower, upper);
}

/**
 * Binary search for an element equal to `value` in the sorted range [first, last).
 *
 * @param comp  three-way comparator `comp(element, value)` returning a negative
 *              value if element < value, zero if equal, positive otherwise
 * @return iterator to the first element equal to `value`, or `last` if none
 */
template<class It, class T, class Compare>
It BinaryFind(It first,
              It last,
              const T &value,
              Compare comp)
{
    auto remaining = std::distance(first, last);

    while (remaining > 0) {
        const auto half = remaining / 2;
        It middle = std::next(first, half);
        if (comp(*middle, value) >= 0) {
            // answer is at middle or to its left
            remaining = half;
        } else {
            // everything up to and including middle is smaller than value
            first = std::next(middle);
            remaining -= half + 1;
        }
    }

    // first is now the lower bound; it is a hit only if it compares equal
    const bool found = (first != last) && comp(*first, value) == 0;
    return found ? first : last;
}

auto BinaryFindFileInfo(const TVector<IS3Client::TFileInfo> &s3Files,
                        const NSolomon::TFileInfo &kikimrFile,
                        size_t s3PrefixLength,
                        bool checkFileSizes = true)
{
    auto it = BinaryFind(s3Files.begin(), s3Files.end(), kikimrFile,
                        [s3PrefixLength, checkFileSizes](const auto &v, const auto &e) {
                            int n = v.Name.compare(s3PrefixLength + 1, v.Name.size(), e.Name);
                            return ((n == 0 && checkFileSizes) ? (e.SizeBytes != v.SizeBytes) : n);
                        });
    return it;
}

/**
 * Return the name of the current host as reported by gethostname().
 *
 * @throws yexception if gethostname() fails
 */
TString Hostname() {
    constexpr size_t maxNameSize = 1024;
    TString name;

    name.ReserveAndResize(maxNameSize);
    const int rc = gethostname(name.begin(), maxNameSize);
    if (rc < 0) {
        ythrow yexception() << "Cannot gethostname: " << LastSystemErrorText();
    }
    // cut the buffer at the NUL terminator written by gethostname()
    name.remove(name.find('\0'));

    return name;
}

/**
 * Check whether a process with the given pid exists on this host by probing
 * its /proc/<pid> directory.
 *
 * @param pid  process id (e.g. read back from a lock file); non-positive values
 *             cannot name a /proc entry and yield false
 * @return true if /proc/<pid> exists
 */
bool PidExists(pid_t pid) {
    if (pid <= 0) {
        return false;
    }
    // pid_t is a signed type: format it with %d, not %u
    TFsPath pidDir = Sprintf("/proc/%d", pid);

    return pidDir.Exists();
}


/**
 * Restore KV-tablet files from S3.
 */
class TCmd3SRestore: public TCliS3CommandTablet {
private:
    const TString RestorePrefix     = "ks3restore.";
    const TString ChunkPrefix       = "chunk.";
    const TString LockFile          = ".lock";
    const TString StateFile         = ".state";
    const TDuration ExpirationTime  = TDuration::Seconds(120);
    const size_t WriteChunkSize     = 16*1024*1024;

    TString         FQDN    = Hostname();
    pid_t           PID     = getpid();
    ui64            KGeneration;
    TSimpleTimer    LockTimer;
    bool            LockInited = false;

    bool ReadLock(ui64 tabletId,
                  TString &fqdn,
                  pid_t &pid,
                  ui32 &timestamp)
    {
        NJson::TJsonValue json;
        TString data = WaitAndCheck(KvClient_->ReadFile(tabletId, LockFile));

        if (!NJson::ReadJsonTree(data, &json)) {
            return false;
        }
        fqdn        = json["fqdn"].GetString();
        pid         = json["pid"].GetInteger();
        timestamp   = json["expireAt"].GetInteger();

        return true;
    }

    void WriteLock(ui64 tabletId) {
        NJson::TJsonValue json;

        json.InsertValue("fqdn", NJson::TJsonValue(FQDN));
        json.InsertValue("pid", NJson::TJsonValue(PID));
        json.InsertValue("expiredAt", NJson::TJsonValue(TInstant::Now().Seconds() + ExpirationTime.Seconds()));

        WaitAndCheck(KvClient_->WriteFile(tabletId, LockFile, NJson::WriteJson(json)));
        LockTimer.Reset();
    }

    bool TryLockTablet(ui64 tabletId) {
        TString fqdn;
        ui32    timestamp;
        pid_t   pid         = 0;
        ui32    expire      = 0;

        if (LockInited && LockTimer.Get() < ExpirationTime/2) {
            return true;
        }
        timestamp = TInstant::Now().Seconds();
        if (ReadLock(tabletId, fqdn, pid, expire) && expire > timestamp) {
            if (fqdn != FQDN) {
                Cout << "Tablet " << tabletId
                     << " is locked from host " << fqdn
                     << " (pid=" << pid
                     << ") will expire in " << expire - timestamp << "s"
                     << Endl;
                return false;
            }
            if (pid != PID && PidExists(pid)) {
                Cout << "Tablet " << tabletId
                     << " is locked from this host (pid=" << pid
                     << ") will expire in " << expire - timestamp << "s"
                     << Endl;
                return false;
            }
        }
        if (expire > 0) {
            Cout << "Writing lock. Current lock will expire in " << expire - timestamp << "s" << Endl;
        } else {
            Cout << "Writing lock. No good lock found." << Endl;
        }
        WriteLock(tabletId);
        LockInited = true;

        // increase generation of the tablet if we were too late
        if (expire <= timestamp) {
            ui64 generation = WaitAndCheck(KvClient_->IncrementGeneration(tabletId));
            if (generation == 0) {
                Cout << "Failed to increment generation for table " << tabletId << Endl;
                return false;
            }
            KGeneration = generation;
            Cout << "Incremented tablet generation to " << KGeneration << Endl;
        }
        return true;
    }

    void UnlockTablet(ui64 tabletId) {
        WaitAndCheck(KvClient_->RemoveFile(tabletId, LockFile));
        WaitAndCheck(KvClient_->IncrementGeneration(tabletId));
    }

private:
    void Options(NLastGetopt::TOpts *opts) override {
        TCliS3CommandTablet::Options(opts);

        opts->AddLongOption("prefix")
                .Help("prefix for storage paths in s3")
                .RequiredArgument("PREFIX")
                .Required();
        opts->AddLongOption("src-dir")
                .Help("source directory in PREFIX, instead of shardId")
                .RequiredArgument("DIRECTORY")
                .DefaultValue("")
                .Optional();
        opts->AddLongOption("force")
                .Help("overwrite any file already in kikimr")
                .NoArgument()
                .Optional();
        opts->AddLongOption("amp")
                .Help("amplify upload by this factor")
                .RequiredArgument("FACTOR")
                .DefaultValue("1")
                .Optional();
        opts->AddLongOption("stats")
                .Help("print statistics")
                .NoArgument()
                .Optional();
    }

    int RunS3OnTablet(ui64 tabletId, ui64 shardId, const NLastGetopt::TOptsParseResult &opts) override {
        ui32 counter         = 0;
        ui64 bytesTotal      = 0;
        ui64 bytesDownload   = 0;
        ui64 bytesDownloaded = 0;
        ui64 bytesRemove     = 0;
        int retryCount       = 3;
        ui32 totalFiles;
        TVector<NSolomon::TFileInfo> kikimrFiles;
        TVector<NSolomon::TFileInfo> kikimrFilesNew;
        TVector<IS3Client::TFileInfo> s3FilesToRestore;
        TVector<IS3Client::TDirInfo> s3Dirs;
        TVector<TString> kikimrFilesToRemove;
        TString s3Prefix    = TFsPath(opts.Get("prefix")).Fix().GetPath();
        TString srcDir      = opts.Get("src-dir");
        bool forceOverwrite = opts.Has("force");
        bool printStats     = opts.Has("stats");
        ui32 ampFactor      = std::stoul(opts.Get("amp"));
        TString s3StateFile;

        s3Prefix += (s3Prefix.EndsWith('/') ? "" : "/") + ((srcDir.size() > 0) ? srcDir : Sprintf("%04lu", shardId));
        s3StateFile = s3Prefix + "/" + StateFile;

        // list files
        if (!S3Client_->List(s3FilesToRestore, s3Dirs, s3Prefix)) {
            Cout << "Failed to list S3 files at " << s3Prefix << Endl;
            return 1;
        }
        if (s3FilesToRestore.empty()) {
            Cout << "No files at s3://" << s3Prefix << Endl;
            return 1;
        }
        SortBy(s3FilesToRestore, [](const IS3Client::TFileInfo &f) { return f.Name; });

        // if s3StateFile is present, something bad happened on dump, check this manually
        {
            auto it = BinaryFind(s3FilesToRestore.begin(), s3FilesToRestore.end(), s3StateFile,
                                 [](const auto &s, const auto &k) { return s.Name.compare(k); });
            if (it != s3FilesToRestore.end()) {
                Cout << "Found state file " << s3StateFile << " - cannot restore" << Endl;
                return 1;
            }
        }
        if (printStats) {
            for (const IS3Client::TFileInfo &f: s3FilesToRestore) {
                bytesTotal += f.SizeBytes;
            }
        }

        // lock tablet
        if (!TryLockTablet(tabletId)) {
            Cout << "Cannot lock tablet " << tabletId << ". Exiting ..." << Endl;
            return 1;
        }

        // list kikimr files
        kikimrFiles = WaitAndCheck(KvClient_->ListFiles(tabletId));
        SortBy(kikimrFiles, [](const NSolomon::TFileInfo &f) { return f.Name; });
        {
            auto it = BinaryFind(kikimrFiles.begin(), kikimrFiles.end(), LockFile,
                                 [](const auto &s, const auto &k) { return s.Name.compare(k); });
            if (it != kikimrFiles.end()) {
                kikimrFiles.erase(it);
            }
        }

        // create list to remove, create list to upload. Do not compare by size, since they could be compressed
        bool compareBySize = false;
        for (const NSolomon::TFileInfo &f: kikimrFiles) {
            NSolomon::TFileInfo file;
            if (f.Name.StartsWith(RestorePrefix)) {
                TString shortName = f.Name.substr(RestorePrefix.size());
                file = {shortName, f.SizeBytes, f.CreatedAt};
            } else {
                file = f;
            }
            auto it = BinaryFindFileInfo(s3FilesToRestore, file, s3Prefix.size(), compareBySize);
            if (it != s3FilesToRestore.end()) {
                ++counter;
                if (forceOverwrite) {
                    kikimrFilesToRemove.push_back(f.Name);
                    bytesRemove += f.SizeBytes;
                } else {
                    s3FilesToRestore.erase(it);
                }
            } else {
                kikimrFilesToRemove.push_back(f.Name);
                bytesRemove += f.SizeBytes;
            }
        }
        Cout << "Working at s3 prefix='"    << s3Prefix << "'. "
             << counter                     << " files already in kikimr"
             << ((forceOverwrite) ? " (will be overwritten)" : "") << ". "
             << s3FilesToRestore.size()     << " files are to be restored. "
             << kikimrFilesToRemove.size()  << " extra files in kikimr will be removed."
             << Endl;
        if (printStats) {
            for (const IS3Client::TFileInfo &f: s3FilesToRestore) {
                bytesDownload += f.SizeBytes;
            }
            Cout << "STATS ON S3PREFIX "            << s3Prefix
                 << " BYTES_TOTAL_COMPRESSED: "     << bytesTotal
                 << " BYTES_DOWNLOAD_COMPRESSED: "  << bytesDownload
                 << " BYTES_REMOVE_FROM_KIKIMR: "   << bytesRemove
                 << Endl;
        }

        // upload files with RestorePrefix
        counter = 0;
        totalFiles = s3FilesToRestore.size();
        auto it = s3FilesToRestore.begin();
        while (it != s3FilesToRestore.end()) {
            ui32            fileCounter;
            TSimpleTimer    simpleTimer;
            size_t          sizeUnpack = 0;
            size_t          size = 0;
            ui64            timeGetUs;
            ui64            timePutUs;
            TMap<TString, TVector<NSolomon::TAsyncKvResult<void>>> nameToWrites;
            TMap<TString, TVector<TString>> nameToChunks;
            std::unordered_map<std::string, std::stringstream> keyToStringStream;

            keyToStringStream.clear();
            for (fileCounter = 0; fileCounter < ampFactor && it != s3FilesToRestore.end(); ++fileCounter, ++it, ++counter) {
                keyToStringStream[it->Name] = std::stringstream(std::ios_base::binary|std::ios_base::out|std::ios_base::in);
                Cout << "Reading from s3 " << it->Name << " (size=" << it->SizeBytes << ")." << Endl;
            }
            simpleTimer.Reset();
            if (!S3Client_->Get(keyToStringStream, size, retryCount)) {
                Cout << "Failed to get files from s3. Exiting ..." << Endl;
                return 1;
            }
            timeGetUs = simpleTimer.Get().MicroSeconds();
            Cout << Sprintf("Read %u files from s3 (%u/%u), %.3f MB/s on total of %.3f MB. Sending them to kikimr",
                            fileCounter,
                            counter,
                            totalFiles,
                            static_cast<double>(size)/static_cast<double>(timeGetUs),
                            static_cast<double>(size)/1024/1024)
                 << Endl;
            simpleTimer.Reset();
            for (auto &[ k, ss ]: keyToStringStream) {
                TString dataDst;
                TVector<NSolomon::TAsyncKvResult<void>> chunkToWrites;
                TVector<TString> writeChunks;

                try {
                    TLz4Decompress zFile(&ss);
                    dataDst = zFile.ReadAll();
                } catch (TLZ4TypeError) {
                    dataDst = TString(ss.str());
                }
                TString fileName = k.substr(s3Prefix.size() + 1);
                TString kikimrName = RestorePrefix + fileName;
                sizeUnpack += dataDst.size();

                if (dataDst.size() > WriteChunkSize) {
                    for (size_t offt = 0; offt < dataDst.size(); offt += WriteChunkSize) {
                        TString chunkName = Sprintf("%s%s%s.%03lu", RestorePrefix.c_str(), ChunkPrefix.c_str(), fileName.c_str(), writeChunks.size());
                        chunkToWrites.push_back(KvClient_->WriteFile(tabletId, chunkName, dataDst.substr(offt, WriteChunkSize)));
                        writeChunks.push_back(chunkName);
                    }
                    nameToChunks[kikimrName] = writeChunks;
                } else {
                    chunkToWrites.push_back(KvClient_->WriteFile(tabletId, kikimrName, dataDst));
                }
                nameToWrites[kikimrName] = chunkToWrites;
                Cout << "Sending to kikimr " << tabletId << "/" << kikimrName << " (size=" << dataDst.size() << " chunks=" << chunkToWrites.size() << ")." << Endl;

                auto itk = BinaryFind(kikimrFiles.begin(), kikimrFiles.end(), kikimrName,
                                      [](const auto &v, const auto &e) { return v.Name.compare(e); });
                if (itk == kikimrFiles.end()) {
                    kikimrFiles.push_back({kikimrName, static_cast<ui32>(dataDst.size()), {}});
                }
            }
            for (auto &[ n, ws ]: nameToWrites) {
                for (auto &w: ws) {
                    WaitAndCheck(w);
                }
                if (ws.size() > 1) {
                    WaitAndCheck(KvClient_->ConcatFiles(tabletId, nameToChunks[n], n, false));
                }
            }
            timePutUs = simpleTimer.Get().MicroSeconds();

            Cout << Sprintf("Uploaded %u files to kikimr, %.3f MB/s on total of %.3f MB (compress ratio %.3f -> %.3f MB)",
                            fileCounter,
                            static_cast<double>(sizeUnpack)/static_cast<double>(timePutUs),
                            static_cast<double>(sizeUnpack)/1024/1024,
                            static_cast<double>(sizeUnpack)/static_cast<double>(size),
                            static_cast<double>(size)/1024/1024) << Endl;
            if (printStats) {
                bytesDownloaded += sizeUnpack;
                bytesDownload -= size;
                Cout << "STATS ON RESTORE"
                     << " TOTAL_BYTES_SENT: "   << bytesDownloaded
                     << " TOTAL_BYTES_LEFT: "   << bytesDownload
                     << " CHUNK_SIZE: "         << size
                     << " TIME_GET_US: "        << timeGetUs
                     << " TIME_PUT_US: "        << timePutUs
                     << Endl;
            }
            // update lock file in KV
            if (!TryLockTablet(tabletId)) {
                Cout << "Cannot lock tablet " << tabletId << ". Exiting ..." << Endl;
                return 1;
            }
        }
        SortBy(kikimrFiles, [](const auto& f) { return f.Name; });

        // list files in kikimr after upload
        kikimrFilesNew = WaitAndCheck(KvClient_->ListFiles(tabletId));
        SortBy(kikimrFilesNew, [](const auto& f) { return f.Name; });
        {
            auto it = BinaryFind(kikimrFilesNew.begin(), kikimrFilesNew.end(), LockFile,
                                 [](const auto &v, const auto &e) { return v.Name.compare(e); });
            if (it != kikimrFilesNew.end()) {
                kikimrFilesNew.erase(it);
            }
        }

        // check if somebody changed files in kikimr during upload
        if (kikimrFilesNew.size() != kikimrFiles.size()) {
            Cout << "Kikimr files had changed as we copied files from s3:" << Endl;
            for (auto itNew = kikimrFilesNew.begin(); itNew != kikimrFilesNew.end(); ++itNew) {
                auto it = BinaryFind(kikimrFiles.begin(), kikimrFiles.end(), itNew->Name,
                                     [](const auto &v, const auto &e) { return v.Name.compare(e); });
                if (it == kikimrFiles.end()) {
                    Cout << "File " << itNew->Name << " (size=" << itNew->SizeBytes << ") appeared" << Endl;
                }
            }
            for (auto itOld = kikimrFiles.begin(); itOld != kikimrFiles.end(); ++itOld) {
                auto it = BinaryFind(kikimrFilesNew.begin(), kikimrFilesNew.end(), itOld->Name,
                                     [](const auto &v, const auto &e) { return v.Name.compare(e); });
                if (it == kikimrFilesNew.end()) {
                    Cout << "File " << itOld->Name << " (size=" << itOld->SizeBytes << ") disappeared" << Endl;
                }
            }
            Cout << "Exiting..." << Endl;
            return 1;
        } else {
            auto itNew = kikimrFilesNew.begin();
            auto itOld = kikimrFiles.begin();
            for (; itNew != kikimrFilesNew.end(); ++itNew, ++itOld) {
                if (itNew->Name != itOld->Name || itNew->SizeBytes != itOld->SizeBytes) {
                    Cout << "Kikimr files had changed as we copied files from s3: "
                         << itOld->Name << " (size=" << itOld->SizeBytes
                         << ") -> "
                         << itNew->Name << " (size=" << itNew->SizeBytes
                         << ")"
                         << Endl;
                    Cout << "Exiting..." << Endl;
                    return 1;
                }
            }
        }

        // remove files from kikimr according to created list
        for (const TString &f: kikimrFilesToRemove) {
            Cout << "Removing kikimr file " << f << " ..." << Endl;
            WaitAndCheck(KvClient_->RemoveFile(tabletId, f));
        }

        if (!TryLockTablet(tabletId)) {
            Cout << "Cannot lock tablet " << tabletId << ". Exiting ..." << Endl;
            return 1;
        }
        // rename files in kikimr
        Cout << "Renaming restored files ..." << Endl;
        WaitAndCheck(KvClient_->RemovePrefix(tabletId, RestorePrefix));

        // remove lock
        Cout << "Restore is done. Removing table lock ..." << Endl;
        UnlockTablet(tabletId);

        return 0;
    }
};

} // namespace

/// Entry point of the S3 restore command: returns the Singleton<> instance of TCmd3SRestore.
TMainClass* Cmd3SRestore() {
    TCmd3SRestore* command = Singleton<TCmd3SRestore>();
    return command;
}
