#include "michurin_state_operator.h"

#include <library/cpp/blockcodecs/codecs.h>
#include <library/cpp/iterator/zip.h>

#include <yt/yt/client/api/rowset.h>
#include <yt/yt/client/table_client/helpers.h>
#include <yt/yt/client/table_client/name_table.h>
#include <yt/yt/client/table_client/row_buffer.h>
#include <yt/yt/client/table_client/unversioned_row.h>

#include <util/generic/buffer.h>

namespace NMichurin {
    // Stores the table/column configuration used by Load() and Write().
    TMichurinStateTableOperator::TMichurinStateTableOperator(
        TOptions options)
        : Options(std::move(options))
    {
    }

    // Looks up the values stored under `ui64Keys` in Options.Table.
    //
    // The returned vector is positionally aligned with `ui64Keys`: element i holds
    // the value for ui64Keys[i], or an empty string when the row is missing or its
    // value is empty. When Options.CodecColumn is set, a non-null codec column names
    // the NBlockCodecs codec used to decode the value; a null codec column means the
    // value is returned as stored.
    //
    // The lookup uses a fixed 10 second timeout. The future fails if the number of
    // returned rows differs from the number of keys, if a configured column is not
    // in the result schema, or if a value/codec cell has an unexpected type.
    NYT::TFuture<TVector<TString>> TMichurinStateTableOperator::Load(
        NYT::NApi::IClientPtr client, const TVector<ui64>& ui64Keys) const {
        auto nameTable = NYT::New<NYT::NTableClient::TNameTable>();
        const i32 keyField = nameTable->GetIdOrRegisterName(Options.KeyColumn);

        auto rowBuffer = NYT::New<NYT::NTableClient::TRowBuffer>();
        auto keys = TVector<NYT::NTableClient::TUnversionedRow>(
            Reserve(ui64Keys.size()));

        for (const ui64 key : ui64Keys) {
            NYT::NTableClient::TUnversionedRowBuilder builder;
            builder.AddValue(
                NYT::NTableClient::MakeUnversionedUint64Value(key, keyField));

            keys.push_back(rowBuffer->CaptureRow(builder.GetRow()));
        }

        // The shared range holds rowBuffer, which owns the captured key rows.
        const auto range = NYT::MakeSharedRange(
            std::move(keys), std::move(rowBuffer));

        NYT::NApi::TLookupRowsOptions options;
        options.Timeout = TDuration::Seconds(10);
        // One (possibly null) result row per requested key keeps results aligned
        // with ui64Keys; the count is verified below.
        options.KeepMissingRows = true;

        return client->LookupRows(
                         Options.Table, nameTable, range, options)
            .Apply(BIND([count = ui64Keys.size(),
                         valueColumn = Options.ValueColumn,
                         codecColumn = Options.CodecColumn](NYT::NApi::IUnversionedRowsetPtr rowset) {
                const auto rows = rowset->GetRows();
                Y_ENSURE(rows.size() == count,
                         "Lookup returned " << rows.size() << " rows for " << count << " keys");

                auto schema = rowset->GetSchema();
                const i32 valueIndex = schema->GetColumnIndexOrThrow(valueColumn);

                const bool isCompressed = codecColumn.Defined();

                i32 codecIndex = -1;
                if (isCompressed) {
                    codecIndex = schema->GetColumnIndexOrThrow(
                        codecColumn.GetRef());
                }

                // True for String cells; Null cells are accepted only when allowOptional is set.
                auto isStringLikeType = [](const auto& data, bool allowOptional) {
                    switch (data.Type) {
                        case NYT::NTableClient::EValueType::String:
                            return true;

                        case NYT::NTableClient::EValueType::Null:
                            return allowOptional;

                        default:
                            return false;
                    }
                };

                // Scratch buffer for decoded values, reused across rows.
                TBuffer buffer;

                auto values = TVector<TString>(Reserve(count));

                for (const auto& row : rows) {
                    // Missing rows (and empty values) yield an empty string.
                    values.emplace_back();

                    if (!row) {
                        continue;
                    }

                    const auto& valueData = row[valueIndex];
                    Y_ENSURE(isStringLikeType(valueData, false),
                             "Column " << valueColumn << " must hold a string value");

                    if (!valueData.Length) {
                        continue;
                    }

                    auto valueDataView = TStringBuf(
                        valueData.Data.String, valueData.Length);

                    if (isCompressed) {
                        const auto& codecData = row[codecIndex];
                        Y_ENSURE(isStringLikeType(codecData, true),
                                 "Column " << codecColumn.GetRef() << " must hold a string or null value");

                        // A null codec cell means the value is stored as-is.
                        if (isStringLikeType(codecData, false)) {
                            const auto codecPtr = NBlockCodecs::Codec({codecData.Data.String, codecData.Length});

                            buffer.Clear();
                            codecPtr->Decode(valueDataView, buffer);

                            valueDataView = TStringBuf(
                                buffer.data(), buffer.size());
                        }
                    }

                    // Copies out of `buffer` before it is reused for the next row.
                    values.back() = TString(valueDataView);
                }

                return values;
            }));
    }

    // Adds modifications writing values[i] under keys[i] in Options.Table to `tx`.
    // This function does not commit `tx`.
    //
    // An empty value is written as a delete of the key's row (Load() reports such a
    // key as an empty string). When both Options.CodecColumn and Options.Codec are
    // set, non-empty values are encoded with Options.Codec and the codec name is
    // stored in the codec column. When only Options.CodecColumn is set, that column
    // is explicitly written as null.
    //
    // Throws if `keys` and `values` have different lengths.
    void TMichurinStateTableOperator::Write(
        NYT::NApi::ITransactionPtr tx,
        const TVector<ui64>& keys, const TVector<TString>& values) const {
        // Zip() does not check that its inputs have equal length; a mismatch would
        // otherwise silently lose writes.
        Y_ENSURE(keys.size() == values.size(),
                 "Keys/values size mismatch: " << keys.size() << " != " << values.size());

        auto nameTable = NYT::New<NYT::NTableClient::TNameTable>();
        const i32 keyField = nameTable->GetIdOrRegisterName(Options.KeyColumn);
        const i32 valueField = nameTable->GetIdOrRegisterName(Options.ValueColumn);

        const bool hasCodecColumn = Options.CodecColumn.Defined();

        i32 codecField = -1;
        if (hasCodecColumn) {
            codecField = nameTable->GetIdOrRegisterName(
                Options.CodecColumn.GetRef());
        }

        // Resolve the codec once rather than per row; null means values are stored
        // uncompressed.
        const auto* codec = hasCodecColumn && Options.Codec.Defined()
            ? NBlockCodecs::Codec(Options.Codec.GetRef())
            : nullptr;

        auto rowBuffer = NYT::New<NYT::NTableClient::TRowBuffer>();
        auto rows = TVector<NYT::NApi::TRowModification>(Reserve(keys.size()));

        // Captures the string bytes into rowBuffer; `value` may point into the
        // scratch buffer below, which is reused on the next iteration.
        auto makeStringValue = [&rowBuffer](i32 field, TStringBuf value) {
            return rowBuffer->CaptureValue(
                NYT::NTableClient::MakeUnversionedStringValue(value, field));
        };

        // Row layout: [0] key, [1] value, [2] codec (only when a codec column is configured).
        TBuffer buffer;
        for (auto [key, value] : Zip(keys, values)) {
            const auto keyValue = NYT::NTableClient::MakeUnversionedUint64Value(key, keyField);

            if (value.empty()) {
                auto row = rowBuffer->AllocateUnversioned(1);
                row[0] = keyValue;
                rows.push_back({NYT::NApi::ERowModificationType::Delete,
                                row.ToTypeErasedRow(), NYT::NTableClient::TLockMask()});
                continue;
            }

            auto row = rowBuffer->AllocateUnversioned(hasCodecColumn ? 3 : 2);
            row[0] = keyValue;

            auto valueView = TStringBuf(value);
            if (codec) {
                buffer.Clear();
                codec->Encode(valueView, buffer);

                valueView = TStringBuf(buffer.data(), buffer.size());

                row[2] = makeStringValue(codecField, Options.Codec.GetRef());
            } else if (hasCodecColumn) {
                // Never leave a codec name from an earlier compressed write next to an
                // uncompressed value (a Write that omits a column may keep its old
                // value); Load() treats a null codec as "stored as-is".
                row[2] = NYT::NTableClient::MakeUnversionedNullValue(codecField);
            }

            row[1] = makeStringValue(valueField, valueView);

            rows.push_back({NYT::NApi::ERowModificationType::Write,
                            row.ToTypeErasedRow(), NYT::NTableClient::TLockMask()});
        }

        const auto range = NYT::MakeSharedRange(
            std::move(rows), std::move(rowBuffer));

        tx->ModifyRows(Options.Table, nameTable, range);
    }
}
