#include "yasm_decoder.h"
#include "common.h"
#include "encoder.h"
#include "instance_key.h"
#include "message_string_pool.h"

#include <solomon/libs/cpp/yasm/constants/labels.h>
#include <solomon/libs/cpp/yasm/shard_config/shard_config.h>

#include <infra/yasm/stockpile_client/common/base_types.h>

#include <library/cpp/consistent_hashing/consistent_hashing.h>
#include <library/cpp/containers/absl_flat_hash/flat_hash_set.h>
#include <library/cpp/monlib/metrics/metric_consumer.h>

using namespace NHistDb::NStockpile;
using namespace NMonitoring;
using namespace NSolomon::NYasm;
using namespace NYasm::NInterfaces::NInternal;
using namespace NZoom;

namespace NSolomon::NFetcher::NYasm {
namespace {

    class TRecordWrapper {
    public:
        TRecordWrapper(IMultiShardEncoder& encoder, const IMessageStringPool& pool)
            : Encoder_{encoder}
            , StringPool_{pool}
        {
        }

        virtual ~TRecordWrapper() = default;
        virtual void ConsumeTags(IMultiShardEncoder&) { }
        virtual TYasmShardKey MakeShardKey(TStringBuf signalName) const = 0;
        virtual TYasmShardKey MakeGroupShardKey(TStringBuf signalName) const = 0;

        void ConsumeSignal(TStringBuf signalName, const TValue& value) {
            Encoder_.SwitchShards();
            Encoder_.AddShard(MakeShardKey(signalName));
            Encoder_.AddShard(MakeGroupShardKey(signalName));

            if (Encoder_.SupportedValue(value)) {
                Encoder_.WriteLabel(LABEL_SIGNAL, signalName);
                ConsumeTags(Encoder_);
                Encoder_.WriteValue(value);
            }
        }

        void Consume(const TRecord& record) {
            if (record.ValuesSize() == 0) {
                return;
            }

            for (const auto& signal: record.GetValues()) {
                auto signalName = StringPool_.SignalName(signal.GetSignalName());
                ConsumeSignal(signalName, signal.GetValue());
            }
        }

        void ConsumeCommon(TStringBuf instanceType, const TRecord& record, const TRecord& commonRecord) {
            if (commonRecord.ValuesSize() == 0) {
                return;
            }

            const auto& commonTable = GetCommonRulesTable();
            auto it = commonTable.find(instanceType);
            if (it == commonTable.end()) {
                return;
            }

            const auto& rule = it->second;
            if (!rule.empty()) {
                TCommonRulesMatchTagsConsumer consumer(rule);
                ConsumeTags(consumer);
                if (!consumer.Match()) {
                    return;
                }
            }

            absl::flat_hash_set<ui32> signals;
            signals.reserve(record.ValuesSize());
            for (const auto& signal: record.GetValues()) {
                signals.emplace(signal.GetSignalName().index());
            }

            for (const auto& signal: commonRecord.GetValues()) {
                const auto& signalName = signal.GetSignalName();
                if (!signals.contains(signalName.index())) {
                    ConsumeSignal(StringPool_.SignalName(signalName), signal.GetValue());
                }
            }
        }

    protected:
        IMultiShardEncoder& Encoder_;
        const IMessageStringPool& StringPool_;
    };

    // Record wrapper for per-instance data: tags come from every instance key listed in the
    // TPerInstanceData, and the per-source shard is keyed by the data's own host name.
    class TPerInstanceWrapper: public TRecordWrapper {
    public:
        TPerInstanceWrapper(
                IMultiShardEncoder& encoder,
                IMessageStringPool& pool,
                const TPerInstanceData& data,
                ui64 shardCount)
            : TRecordWrapper{encoder, pool}
            , InstanceType_{data.GetInstanceType()}
            , Data_{data}
            , ShardCount_{shardCount}
        {
        }

        // Resolves `protoKey` through the string pool and writes its tags as labels, followed by
        // one label (name -> AGGREGATED_MARKER) for each entry yielded by ForEachAggr.
        void ConsumeInstanceKey(const ::NYasm::NInterfaces::NInternal::TInstanceKey& protoKey, IMultiShardEncoder& encoder) {
            TInstanceKey resolved{StringPool_.InstanceKey(protoKey)};

            resolved.ForEachTag([&encoder](auto tagName, auto tagValue) {
                encoder.WriteLabel(tagName, tagValue);
            });

            resolved.ForEachAggr([&encoder](auto aggrName) {
                encoder.WriteLabel(aggrName, NSolomon::NYasm::AGGREGATED_MARKER);
            });
        }

        void ConsumeTags(IMultiShardEncoder& encoder) override {
            const auto& keys = Data_.GetInstanceKeys();
            for (const auto& key: keys) {
                ConsumeInstanceKey(key, encoder);
            }
        }

        TYasmShardKey MakeShardKey(TStringBuf signalName) const override {
            return {InstanceType_, Data_.GetHostName(), signalName, ShardCount_};
        }

        // Group shard: same itype, empty host.
        TYasmShardKey MakeGroupShardKey(TStringBuf signalName) const override {
            return {InstanceType_, {}, signalName, ShardCount_};
        }

    private:
        TStringBuf InstanceType_;
        const TPerInstanceData& Data_;
        ui64 ShardCount_;
    };

    // Record wrapper for aggregated data: tags come from the parsed instance key. The per-source
    // shard is keyed by the agent host name. The group shard is keyed by the host name carried
    // in the instance key (presumably the group name; TODO confirm).
    class TAggregatedWrapper final: public TRecordWrapper {
    public:
        TAggregatedWrapper(
                IMultiShardEncoder& encoder,
                IMessageStringPool& pool,
                TInstanceKey instanceKey,
                TStringBuf hostName,
                ui64 shardCount)
            : TRecordWrapper{encoder, pool}
            // `instanceKey` is a by-value sink parameter: move it rather than copying it again.
            , InstanceKey_{std::move(instanceKey)}
            , HostName_{hostName}
            , ShardCount_{shardCount}
        {
        }

        // Writes the instance key's tags as labels, followed by one label
        // (name -> AGGREGATED_MARKER) for each entry yielded by ForEachAggr.
        void ConsumeTags(IMultiShardEncoder& encoder) override {
            InstanceKey_.ForEachTag([&](auto key, auto value) {
                encoder.WriteLabel(key, value);
            });

            InstanceKey_.ForEachAggr([&](auto aggr) {
                encoder.WriteLabel(aggr, NSolomon::NYasm::AGGREGATED_MARKER);
            });
        }

        TYasmShardKey MakeShardKey(TStringBuf signalName) const override {
            return {InstanceKey_.GetItype(), HostName_, signalName, ShardCount_};
        }

        TYasmShardKey MakeGroupShardKey(TStringBuf signalName) const override {
            return {InstanceKey_.GetItype(), InstanceKey_.GetHostName(), signalName, ShardCount_};
        }

    private:
        TInstanceKey InstanceKey_;
        // Non-owning view; the caller must keep the underlying string alive while this wrapper lives.
        TStringBuf HostName_;
        ui64 ShardCount_;
    };

    class TYasmRecordsDecoder {
    public:
        TYasmRecordsDecoder(
                const IShardConfig& shardConfig,
                IMultiShardEncoder& encoder,
                TStringBuf hostName,
                const IYasmItypeWhiteList* whiteList)
            : ShardConfig_{shardConfig}
            , Encoder_{encoder}
            , HostName_{hostName}
            , WhiteList_{whiteList}
        {
        }

        void Decode(const TPerInstanceRecords& records) {
            auto pool = CreateStringPool(records.GetSignalNameTable(), records.GetInstanceKeyTable());

            const TPerInstanceData* commonData = nullptr;
            for (const auto& data: records.GetRecords()) {
                if (TInstanceKey::IsCommonItype(data.GetInstanceType())) {
                    commonData = &data;
                    break;
                }
            }

            // consume per-instance data
            for (const auto& data: records.GetRecords()) {
                TStringBuf instanceType = data.GetInstanceType();
                if (WhiteList_ && !WhiteList_->Contain(instanceType)) {
                    continue;
                }

                const ui64 shardCount = ShardConfig_.GetShardCount(instanceType);
                TPerInstanceWrapper wrapped{Encoder_, *pool, data, shardCount};
                wrapped.Consume(data.GetRecord());

                if (commonData) {
                    wrapped.ConsumeCommon(instanceType, data.GetRecord(), commonData->GetRecord());
                }
            }
        }

        void Decode(const TAggregatedRecords& records) {
            auto pool = CreateStringPool(records.GetSignalNameTable(), records.GetInstanceKeyTable());

            const TAggregatedData* commonData = nullptr;
            for (const auto& data: records.GetRecords()) {
                TStringBuf instanceKey = pool->InstanceKey(data.GetInstanceKey());
                if (TInstanceKey::IsCommonItype(instanceKey)) {
                    commonData = &data;
                    break;
                }
            }

            // consume aggregated data
            for (const auto& data: records.GetRecords()) {
                TInstanceKey instanceKey{pool->InstanceKey(data.GetInstanceKey())};
                TStringBuf instanceType = instanceKey.GetItype();
                if (WhiteList_ && !WhiteList_->Contain(instanceType)) {
                    continue;
                }

                const ui64 shardCount = ShardConfig_.GetShardCount(instanceType);
                TAggregatedWrapper wrapped{Encoder_, *pool, instanceKey, HostName_, shardCount};
                wrapped.Consume(data.GetRecord());

                if (commonData) {
                    wrapped.ConsumeCommon(instanceType, data.GetRecord(), commonData->GetRecord());
                }
            }
        }

        void Finish(IDataConsumer* handler) {
            Encoder_.Close(handler);
        }

    private:
        const IShardConfig& ShardConfig_;
        IMultiShardEncoder& Encoder_;
        TString HostName_;
        const IYasmItypeWhiteList* WhiteList_ = nullptr;
    };

} // namespace

    // One-shot decode: runs both record sets through a temporary decoder bound to `encoder`,
    // then collects every shard the encoder emits on close, each tagged as SPACK.
    TVector<TShardData> DecodeYasmAgentResponse(
        const TPerInstanceRecords& perInstance,
        const TAggregatedRecords& aggrRecords,
        const IShardConfig& shardConfig,
        THolder<IMultiShardEncoder> encoder,
        TStringBuf hostName,
        const IYasmItypeWhiteListPtr& whiteList)
    {
        // Accumulates the emitted shards in arrival order.
        struct TCollectingConsumer final: public IDataConsumer {
            TVector<TShardData> Shards;

            void OnShardData(TYasmShardKey key, TString data) override {
                TShardData shard{
                    .Key = std::move(key),
                    .Format = EFormat::SPACK,
                    .Data = std::move(data),
                };
                Shards.push_back(std::move(shard));
            }
        };

        // `encoder` is owned by this function and outlives `decoder`, which only references it.
        TYasmRecordsDecoder decoder{shardConfig, *encoder, hostName, whiteList.Get()};
        decoder.Decode(perInstance);
        decoder.Decode(aggrRecords);

        TCollectingConsumer collector;
        decoder.Finish(&collector);
        // Explicit move: a member of a local is not implicitly moved on return.
        return std::move(collector.Shards);
    }

    class TYasmAgentDecoder: public IYasmAgentDecoder {
    public:
        TYasmAgentDecoder(
                IShardConfigPtr shardConf,
                THolder<IMultiShardEncoder> encoder,
                TStringBuf hostName,
                IYasmItypeWhiteListPtr whiteList)
            : ShardConfig_{std::move(shardConf)}
            , Encoder_(std::move(encoder))
            , WhiteList_(std::move(whiteList))
            , Decoder_(MakeHolder<TYasmRecordsDecoder>(*ShardConfig_, *Encoder_, hostName, WhiteList_.Get()))
        {
        }

        void Decode(
                const TPerInstanceRecords& perInstanceRecords,
                const TAggregatedRecords& aggregatedRecords,
                IDataConsumer* consumer) override
        {
            Decoder_->Decode(perInstanceRecords);
            Decoder_->Decode(aggregatedRecords);
            Decoder_->Finish(consumer);
        }

    private:
        IShardConfigPtr ShardConfig_;
        THolder<IMultiShardEncoder> Encoder_;
        IYasmItypeWhiteListPtr WhiteList_;
        THolder<TYasmRecordsDecoder> Decoder_;
    };

    // Convenience overload: same as the shard-config overload, using CreateDefaultShardConfig().
    TIntrusivePtr<IYasmAgentDecoder> CreateYasmAgentDecoder(
            TStringBuf hostName,
            THolder<IMultiShardEncoder> encoder,
            IYasmItypeWhiteListPtr whiteList)
    {
        IShardConfigPtr defaultConfig = CreateDefaultShardConfig();
        return CreateYasmAgentDecoder(std::move(defaultConfig), std::move(encoder), hostName, std::move(whiteList));
    }

    // Creates a reusable decoder that takes ownership of the shard config, encoder and whitelist.
    TIntrusivePtr<IYasmAgentDecoder> CreateYasmAgentDecoder(
        IShardConfigPtr shardConf,
        THolder<IMultiShardEncoder> encoder,
        TStringBuf hostName,
        IYasmItypeWhiteListPtr whiteList)
    {
        TIntrusivePtr<IYasmAgentDecoder> decoder = MakeIntrusive<TYasmAgentDecoder>(
                std::move(shardConf),
                std::move(encoder),
                hostName,
                std::move(whiteList));
        return decoder;
    }

} // namespace NSolomon::NFetcher::NYasm
