#include "erasure.h"

extern "C" {
    #include <contrib/libs/jerasure/jerasure.h>
    #include <contrib/libs/jerasure/reed_sol.h>
}

#include <library/cpp/digest/crc32c/crc32c.h>

#include <util/stream/buffer.h>
#include <util/generic/hash.h>

#include <cstdlib>
#include <limits>
#include <memory>

namespace NSaas {
    // Returns a writable pointer to the payload of each of the first
    // MainNumber + ParityNumber parts, in part order. These pointer arrays are
    // what jerasure's encode/decode routines take as data/coding buffers.
    // NOTE(review): `parts` must hold at least MainNumber + ParityNumber
    // elements; this is not checked here.
    TVector<char *> TErasureHelper::GetDataStartsPtrs(TVector<NSaas::TErasurePart>& parts) const {
        const auto partsCount = MainNumber + ParityNumber;

        TVector<char *> starts;
        starts.reserve(partsCount);
        for (ui32 index = 0; index < partsCount; ++index) {
            // Mutable access on purpose: callers let jerasure write into these buffers.
            auto* payload = parts[index].MutableData();
            starts.push_back((char *)payload->begin());
        }
        return starts;
    }

    // Splits `object` into MainNumber data parts of equal length and computes
    // ParityNumber Reed-Solomon parity parts for them with jerasure (w = 8).
    //
    // Every data part carries the CRC32C of the whole object, the object size,
    // its own CRC32C, its index and its payload length without padding. The
    // first `object.size() % MainNumber` parts get one extra payload byte; every
    // other part is padded with a single '0' byte instead, so all parts end up
    // `object.size() / MainNumber + 1` bytes long. Parity parts carry the same
    // global checksum and object size, their own CRC32C and their index.
    //
    // Throws (via Y_ENSURE) if `object` is empty, does not fit the ui32 size
    // fields, or if jerasure cannot build the coding matrix.
    TVector<NSaas::TErasurePart> TErasureHelper::SplitToParts(TStringBuf object) const {
        Y_ENSURE(object.size() > 0, "Empty object");
        // Sizes are stored as ui32; refuse instead of silently truncating.
        Y_ENSURE(object.size() <= std::numeric_limits<ui32>::max(), "Object is too large");

        const ui32 totalSize = object.size();
        const ui32 globalCheckSum = Crc32c(object.data(), object.size());

        // Parts [0, longPartsCount) hold shortPartSize + 1 payload bytes, the rest hold
        // shortPartSize payload bytes plus one padding byte: all are dataSize bytes long.
        const ui32 shortPartSize = totalSize / MainNumber;
        const ui32 longPartsCount = totalSize % MainNumber;
        const ui32 dataSize = shortPartSize + 1;

        TVector<NSaas::TErasurePart> result(MainNumber + ParityNumber);

        ui32 currentPosition = 0;
        for (ui32 i = 0; i < MainNumber; ++i) {
            const bool paddingIsNecessary = (i >= longPartsCount);
            const ui32 sizeWithoutPadding = paddingIsNecessary ? shortPartSize : shortPartSize + 1;

            TString data(object.substr(currentPosition, sizeWithoutPadding));
            if (paddingIsNecessary) {
                data.push_back('0');
            }

            NSaas::TErasurePart& part = result[i];
            part.SetPartIndex(i);
            part.SetGlobalCheckSum(globalCheckSum);
            part.SetTotalSize(totalSize);
            part.SetLocalSizeWithoutPadding(sizeWithoutPadding);
            part.SetLocalCheckSum(Crc32c(data.data(), data.size()));
            part.SetData(std::move(data));

            currentPosition += sizeWithoutPadding;
        }

        // Parity buffers are filled in place by jerasure below.
        for (ui32 i = MainNumber; i < MainNumber + ParityNumber; ++i) {
            result[i].MutableData()->resize(dataSize);
        }

        // jerasure allocates the matrix with malloc(); release it on every exit path.
        std::unique_ptr<int, decltype(&std::free)> matrix(
            reed_sol_vandermonde_coding_matrix(MainNumber, ParityNumber, 8), &std::free);
        Y_ENSURE(matrix, "Unable to build Reed-Solomon coding matrix");

        // jerasure takes raw char** arrays: data pointers first, then coding pointers.
        TVector<char *> parts = GetDataStartsPtrs(result);
        jerasure_matrix_encode(MainNumber, ParityNumber, 8, matrix.get(),
                               parts.data(), parts.data() + MainNumber, static_cast<int>(dataSize));

        for (ui32 i = MainNumber; i < MainNumber + ParityNumber; ++i) {
            NSaas::TErasurePart& part = result[i];
            part.SetPartIndex(i);
            part.SetGlobalCheckSum(globalCheckSum);
            part.SetLocalCheckSum(Crc32c(part.GetData().data(), dataSize));
            part.SetTotalSize(totalSize);
        }

        return result;
    }

    // Method believes that size of parts is n
    bool TErasureHelper::MainPartsAreCorrect(const TVector<NSaas::TErasurePart>& parts) const {
        if (parts.size() < MainNumber || parts[0].GetPartIndex() != 0) {
            return false;
        }

        for (ui32 i = 0; i < MainNumber; ++i) {
            ui32 checkSum = Crc32c(parts[i].GetData().begin(), parts[i].GetData().size());

            if (checkSum != parts[i].GetLocalCheckSum() || !parts[i].GetTotalSize()) {
                return false;
            }
        }

        return true;
    }

    // Concatenates the payloads (padding excluded) of the first MainNumber
    // parts, in vector order, into the original object.
    //
    // The parts are assumed to be already validated main parts. The only check
    // made here is that each declared size without padding fits into the part's
    // data, so corrupted metadata cannot trigger an out-of-bounds read.
    // Throws (via Y_ENSURE) if fewer than MainNumber parts are given or a
    // declared size exceeds its data.
    TString TErasureHelper::MergeFromMainParts(const TVector<NSaas::TErasurePart>& parts) const {
        Y_ENSURE(parts.size() >= MainNumber, "Some of main parts are missed");

        TString result;
        if (MainNumber > 0) {
            result.reserve(parts[0].GetTotalSize()); // allocating necessary memory once
        }

        for (ui32 i = 0; i < MainNumber; ++i) {
            const auto& data = parts[i].GetData();
            const ui32 sizeWithoutPadding = parts[i].GetLocalSizeWithoutPadding();
            Y_ENSURE(sizeWithoutPadding <= data.size(), "Part size exceeds its data");
            result.append(data.data(), sizeWithoutPadding);
        }

        return result;
    }

    // Reassembles the original object from any subset of its parts.
    //
    // A part is accepted when:
    //   - its index is in range;
    //   - its data is non-empty;
    //   - its data matches its local CRC32C.
    // All accepted parts must agree on data length, total size and global
    // checksum. If the MainNumber data parts are all present and valid they are
    // concatenated directly. Otherwise the missing or corrupted parts (at most
    // ParityNumber of them) are first rebuilt with jerasure's Reed-Solomon
    // decoder. In both cases the merged object is checked against the total
    // size and global CRC32C carried by the parts.
    //
    // Throws (via Y_ENSURE) if any of the following happens:
    //   - accepted parts disagree;
    //   - too many parts are missing;
    //   - the coding matrix cannot be built or decoding fails;
    //   - the merged object fails verification.
    TString TErasureHelper::Merge(const TVector<NSaas::TErasurePart>& parts) const {
        const ui32 partsCount = MainNumber + ParityNumber;
        TVector<NSaas::TErasurePart> newParts(partsCount);
        TVector<bool> isCorrect(partsCount, false);
        ui32 dataSize = 0;
        ui32 totalSize = 0;
        ui32 globalCheckSum = 0;
        bool haveReference = false; // the first accepted part defines the expected metadata

        for (const auto& part : parts) {
            const auto& data = part.GetData();
            // The index is not covered by the local checksum, so validate it before using it
            // as a position. Empty data would also defeat the buffer-size consistency checks.
            if (part.GetPartIndex() >= partsCount || data.empty()
                || Crc32c(data.data(), data.size()) != part.GetLocalCheckSum()) {
                continue;
            }

            if (!haveReference) {
                dataSize = data.size();
                totalSize = part.GetTotalSize();
                globalCheckSum = part.GetGlobalCheckSum();
                haveReference = true;
            }

            Y_ENSURE(dataSize == data.size(), "Correct data contains incorrect info");
            Y_ENSURE(totalSize == part.GetTotalSize(), "Correct data contains incorrect info");
            Y_ENSURE(globalCheckSum == part.GetGlobalCheckSum(), "Correct data contains incorrect info");

            newParts[part.GetPartIndex()] = part;
            isCorrect[part.GetPartIndex()] = true;
        }

        // Verifies the merged object against the metadata shared by all accepted parts.
        const auto verified = [&](TString merged) {
            Y_ENSURE(merged.size() == totalSize, "Merged object has wrong size");
            Y_ENSURE(Crc32c(merged.data(), merged.size()) == globalCheckSum, "Merged object has wrong checksum");
            return merged;
        };

        if (MainPartsAreCorrect(newParts)) {
            return verified(MergeFromMainParts(newParts));
        }

        TVector<int> erasures;
        for (ui32 i = 0; i < partsCount; ++i) {
            if (!isCorrect[i]) {
                erasures.push_back(static_cast<int>(i));
            }
        }
        Y_ENSURE(erasures.size() <= ParityNumber, "Unable to recover the data, too many parts are corrupted");

        for (int i : erasures) {
            newParts[i].MutableData()->resize(dataSize); // buffer the decoder writes into
        }
        erasures.push_back(-1); // jerasure expects the erasure list to be terminated by -1

        // jerasure allocates the matrix with malloc(); release it on every exit path.
        std::unique_ptr<int, decltype(&std::free)> matrix(
            reed_sol_vandermonde_coding_matrix(MainNumber, ParityNumber, 8), &std::free);
        Y_ENSURE(matrix, "Unable to build Reed-Solomon coding matrix");

        TVector<char *> dataStarts = GetDataStartsPtrs(newParts);
        const int status = jerasure_matrix_decode(MainNumber, ParityNumber, 8, matrix.get(), 1,
            erasures.data(), dataStarts.data(), dataStarts.data() + MainNumber, static_cast<int>(dataSize));
        Y_ENSURE(status == 0, "Bad finish status"); // checking that normally finished

        for (int i : erasures) {
            if (i == -1) {
                continue;
            }

            NSaas::TErasurePart& part = newParts[i];
            part.SetPartIndex(i);
            part.SetLocalCheckSum(Crc32c(part.GetData().data(), part.GetData().size()));
            part.SetTotalSize(totalSize);
            part.SetGlobalCheckSum(globalCheckSum);

            if (static_cast<ui32>(i) < MainNumber) {
                // Which main parts are padded follows from the total size (the first
                // totalSize % MainNumber parts carry one extra payload byte instead of padding).
                // Deriving it this way avoids mistaking a genuine trailing '0' data byte for padding.
                const ui32 sizeWithoutPadding =
                    totalSize / MainNumber + (static_cast<ui32>(i) < totalSize % MainNumber ? 1 : 0);
                Y_ENSURE(sizeWithoutPadding <= dataSize, "Correct data contains incorrect info");
                part.SetLocalSizeWithoutPadding(sizeWithoutPadding);
            }
        }

        Y_ENSURE(MainPartsAreCorrect(newParts), "Got incorrect parts after recovering");

        return verified(MergeFromMainParts(newParts));
    }
}

