#include "id_hash.h"

#include <crypta/audience/lib/native/storage.h>

#include <crypta/lib/native/retargeting_ids/retargeting_id.h>

#include <ads/bsyeti/libs/experiments/plugins/bigb/plugin.h>

#include <util/digest/multi.h>
#include <util/generic/buffer.h>
#include <util/random/random.h>
#include <library/cpp/codecs/static/static.h>
#include <library/cpp/json/json_writer.h>
#include <library/cpp/protobuf/json/json2proto.h>

#include <random>

using namespace NCrypta;
using NJson::TJsonWriter;

// Converts an audience segment id into the id used in exported records.
// An id whose retargeting type is DirectLalSegment is returned unchanged;
// every other id is rebuilt as a MetrikaAudienceSegment retargeting id and
// returned in serialized form.
i64 ToExportSegmentId(i64 audienceSegmentId) {
    const auto currentType = NRetargetingIds::TRetargetingId(audienceSegmentId).GetType();

    if (currentType != NRetargetingIds::EType::DirectLalSegment) {
        return NRetargetingIds::TRetargetingId(audienceSegmentId, NRetargetingIds::EType::MetrikaAudienceSegment).Serialize();
    }

    return audienceSegmentId;
}

// Default constructor; nothing is initialized explicitly here.
TFilterStorage::TFilterStorage() = default;

// Copies to the output only those rows whose yandexuid is numerically greater
// than 946684800 (the Unix timestamp of 01/01/2000 12:00am UTC); such rows are
// treated as carrying a valid cookie. All other rows are dropped.
void TFilterStorage::Do(TTableReader<TStorageEntry>* input, TTableWriter<TStorageEntry>* output) {
    // TODO: is there any better way? icookie?
    constexpr auto oldestAcceptedValue = 946684800ul;

    for (; input->IsValid(); input->Next()) {
        const auto& entry = input->GetRow();
        if (!(entry.GetYandexuid() > oldestAcceptedValue)) {
            continue;
        }
        output->AddRow(entry);
    }
}

// Builds the JSON text of one export record: a map with
//   - the identifier, written under the key `idType` (as a number when `id`
//     holds ui64, as a string when it holds TString),
//   - "timestamp" set to `timestamp`,
//   - "segments": an array with ToExportSegmentId() applied to every element
//     of `segments`, in the given order.
//
// `id` and `segments` are taken by const reference so that the (potentially
// long) segment list is not copied on every call.
static TString ToJson(const TString& idType, const std::variant<ui64, TString>& id, i64 timestamp, const TVector<i64>& segments) {
    TStringStream out;
    TJsonWriter json(&out, false);
    json.OpenMap();
    if (const auto* numericId = std::get_if<ui64>(&id)) {
        json.Write(idType, *numericId);
    } else if (const auto* textId = std::get_if<TString>(&id)) {
        json.Write(idType, *textId);
    }
    json.Write("timestamp", timestamp);
    json.Write("segments");
    json.OpenArray();
    for (const auto segment : segments) {
        json.Write(ToExportSegmentId(segment));
    }
    json.CloseArray();
    json.CloseMap();
    // Push everything buffered by the writer into `out` before reading it.
    json.Flush();
    return out.Str();
}

// Default constructor: value-initializes Timestamp.
TCombineIntoJson::TCombineIntoJson()
    : Timestamp()
{
}

// Builds the job from a buffer whose bytes are the textual form of the
// timestamp; the text is parsed as ui64 and stored in Timestamp.
TCombineIntoJson::TCombineIntoJson(const TBuffer& timestamp)
    : Timestamp(FromString<ui64>(TStringBuf(timestamp.Data(), timestamp.Size())))
{
}

// Collapses all input rows into one output row.
// The yandexuid is taken from the first row, the segment ids from every row.
// The output row has a random float under "random" and the JSON produced by
// ToJson() under "record".
void TCombineIntoJson::Do(TTableReader<TStorageEntry>* input, TTableWriter<TNode>* output) {
    TVector<i64> segmentIds{};
    ui64 yandexuid = 0;

    // Only the first row contributes the identifier.
    if (input->IsValid()) {
        yandexuid = input->GetRow().GetYandexuid();
    }

    while (input->IsValid()) {
        segmentIds.push_back(input->GetRow().GetSegmentID());
        input->Next();
    }

    const auto randomKey = RandomNumber<float>();
    const TString record = ToJson("yandexuid", yandexuid, Timestamp, segmentIds);
    output->AddRow(TNode()("random", randomKey)("record", record));
}

// Default constructor; nothing is initialized explicitly here.
TPartitionStorageCryptaID::TPartitionStorageCryptaID() = default;

// Identity pass: forwards every input row to the output unchanged.
void TPartitionStorageCryptaID::Do(TTableReader<TCryptaIdStorageEntry>* input, TTableWriter<TCryptaIdStorageEntry>* output) {
    while (input->IsValid()) {
        output->AddRow(input->GetRow());
        input->Next();
    }
}

// Default constructor: initializes the TStateful base with no arguments.
TCombineIntoJsonCryptaID::TCombineIntoJsonCryptaID()
    : TStateful()
{
}

// Constructs the job from `buffer`, which is passed straight to the TStateful
// base constructor.
TCombineIntoJsonCryptaID::TCombineIntoJsonCryptaID(const TBuffer& buffer)
    : TStateful(buffer)
{
}

// Collapses all input rows into one output row.
// The cryptaId is taken from the first row. At most
// State->GetOptions().GetMaxSegmentsPerUser() segment ids are kept: once that
// many have been collected, further rows are folded in by reservoir sampling
// so that every row has the same chance of ending up in the result.
// The output row has a random float under "random" and the JSON produced by
// ToJson() under "record".
void TCombineIntoJsonCryptaID::Do(TTableReader<TCryptaIdStorageEntry>* input, TTableWriter<TNode>* output) {
    ui64 cryptaId = 0;

    // Number of rows processed before the current one.
    ui64 count = 0;
    const ui64 limit = State->GetOptions().GetMaxSegmentsPerUser();
    TVector<i64> segments{};
    segments.reserve(limit);

    // `count` must advance on every row: the replacement index is drawn from
    // [0, count], so a counter stuck at zero would always overwrite slot 0.
    for (; input->IsValid(); input->Next(), ++count) {
        const auto& row = input->GetRow();
        if (count == 0) {
            cryptaId = row.GetCryptaID();
        }
        if (segments.size() == limit) {
            // Reservoir is full: the current row replaces a stored one with
            // probability limit / (count + 1).
            const ui64 candidate = RandomNumber<ui64>() % (count + 1);
            if (candidate < segments.size()) {
                segments[candidate] = row.GetSegmentID();
            }
        } else {
            segments.push_back(row.GetSegmentID());
        }
    }
    output->AddRow(TNode()("random", RandomNumber<float>())("record", ToJson("cryptaId", cryptaId, State->GetTimestamp(), segments)));
}

// Default constructor; nothing is initialized explicitly here.
TPartitionStorageDevices::TPartitionStorageDevices() = default;

// Identity pass: forwards every input row to the output unchanged.
void TPartitionStorageDevices::Do(TTableReader<TNode>* input, TTableWriter<TNode>* output) {
    while (input->IsValid()) {
        output->AddRow(input->GetRow());
        input->Next();
    }
}

// Default constructor: initializes the TStateful base with no arguments.
TCombineIntoJsonDevices::TCombineIntoJsonDevices()
    : TStateful()
{
}

// Constructs the job from `buffer`, which is passed straight to the TStateful
// base constructor.
TCombineIntoJsonDevices::TCombineIntoJsonDevices(const TBuffer& buffer)
    : TStateful(buffer)
{
}

// Collapses all input rows into up to two output rows (one per device id kind).
//
// From each row the string-typed "Idfa" / "Gaid" columns (the last seen value
// wins) and the int64-typed "SegmentID" column are read. At most
// State->GetOptions().GetMaxSegmentsPerUser() segment ids are kept: once that
// many have been collected, further ids are folded in by reservoir sampling so
// that every id has the same chance of ending up in the result.
//
// Nothing is written when no segment was collected. Otherwise one row is
// written for the idfa (if any) and one for the gaid (if any); both rows share
// the same "random" value and carry the JSON from ToJson() under "record".
void TCombineIntoJsonDevices::Do(TTableReader<TNode>* input, TTableWriter<TNode>* output) {
    TMaybe<TString> idfa;
    TMaybe<TString> gaid;

    // Number of segment ids processed before the current one.
    ui64 count = 0;
    const ui64 limit = State->GetOptions().GetMaxSegmentsPerUser();
    TVector<i64> segments{};
    segments.reserve(limit);

    for (; input->IsValid(); input->Next()) {
        const auto& row = input->GetRow();

        // Columns are bound by reference to avoid copying a TNode per row.
        const auto& maybeIdfa = row["Idfa"];
        if (maybeIdfa.IsString()) {
            idfa = maybeIdfa.AsString();
        }

        const auto& maybeGaid = row["Gaid"];
        if (maybeGaid.IsString()) {
            gaid = maybeGaid.AsString();
        }

        const auto& maybeSegment = row["SegmentID"];
        if (!maybeSegment.IsInt64()) {
            continue;
        }

        if (segments.size() == limit) {
            // Reservoir is full: the current id replaces a stored one with
            // probability limit / (count + 1).
            const ui64 candidate = RandomNumber<ui64>() % (count + 1);
            if (candidate < segments.size()) {
                segments[candidate] = maybeSegment.AsInt64();
            }
        } else {
            segments.push_back(maybeSegment.AsInt64());
        }
        // Must advance for every id seen; a counter stuck at zero would make
        // the sampler always overwrite slot 0.
        ++count;
    }

    if (segments.empty()) {
        return;
    }

    const auto randomKey = RandomNumber<float>();
    const auto timestamp = State->GetTimestamp();
    if (idfa) {
        output->AddRow(TNode()("random", randomKey)("record", ToJson("idfa", idfa.GetRef(), timestamp, segments)));
    }
    if (gaid) {
        output->AddRow(TNode()("random", randomKey)("record", ToJson("gaid", gaid.GetRef(), timestamp, segments)));
    }
}

// Default constructor, compiler-generated.
TUpdateFullState::TUpdateFullState() = default;

// Constructs the job from `state`, which is passed straight to the TStateful
// base constructor.
TUpdateFullState::TUpdateFullState(const TBuffer& state)
    : TStateful(state) {
}

// Rebuilds Sampler from the denominator and rest stored in the state's
// sampler options, using TRestSampler::EMode::Equal. The writer is unused.
void TUpdateFullState::Start(TWriter*) {
    const auto& samplerOptions = State->GetSampler();
    Sampler = THashSampler<IdHash>(
        samplerOptions.GetDenominator(),
        samplerOptions.GetRest(),
        TRestSampler::EMode::Equal);
}

// Handles one group of input rows: resets the per-group members, reads the
// group, and - unless no binding status was collected - writes the state rows
// followed by the collector and meta rows.
void TUpdateFullState::Do(TReader* input, TWriter* output) {
    Clear();
    Read(input);

    if (BindingStatuses.empty()) {
        return;
    }

    WriteState(output);
    WriteCollectorAndMeta(output);
}

// Looks up the state entry for `segmentId` in State->GetSegments().
// Throws yexception when the segment is not present.
const TUpdateFullStateState::TSegmentState& TUpdateFullState::GetSegmentState(i64 segmentId) {
    const auto& knownSegments = State->GetSegments();
    const auto found = knownSegments.find(segmentId);
    if (found == knownSegments.end()) {
        ythrow yexception() << "Segment not found: " << segmentId;
    }
    return found->second;
}

// Resets the members that are filled while processing one group of rows:
// the collected segments and binding statuses, the remembered other columns
// and the id meta record.
void TUpdateFullState::Clear() {
    Segments.clear();
    BindingStatuses.clear();
    OtherColumns = Nothing();
    IdMeta.Clear();
}

// Consumes all rows of the current group and fills the per-group members.
//
// Rows are dispatched by input table index:
//   - SegmentUpdates / State rows are bindings. An update row is skipped when
//     its timestamp differs from the timestamp recorded in the segment state.
//     The first binding supplies the id / id type (if IdMeta has no id yet)
//     and the other columns; each binding marks its segment's status as
//     "has update" or "has state".
//   - Meta rows are moved into IdMeta.
//   - Any other index raises yexception.
//
// Afterwards Segments holds the export ids of all segments not reported as
// removed, and OutputToSampleLog is set when the id passes the sampler and
// enough segments were collected.
void TUpdateFullState::Read(TReader* input) {
    const auto updatesIndex = static_cast<ui32>(EInputIndex::SegmentUpdates);
    const auto stateIndex = static_cast<ui32>(EInputIndex::State);
    const auto metaIndex = static_cast<ui32>(EInputIndex::Meta);

    TFullBinding binding;

    for (; input->IsValid(); input->Next()) {
        const auto tableIndex = input->GetTableIndex();
        const bool fromUpdates = (tableIndex == updatesIndex);
        const bool fromState = (tableIndex == stateIndex);

        if (!fromUpdates && !fromState) {
            if (tableIndex != metaIndex) {
                ythrow yexception() << "Unknown index: " << tableIndex;
            }
            input->MoveRow(&IdMeta);
            continue;
        }

        if (fromUpdates) {
            // Inspect the row in place before it is moved out of the reader.
            const auto& pending = input->GetRow<TFullBinding>();
            if (pending.GetTimestamp() != GetSegmentState(pending.GetSegmentID()).GetTimestamp()) {
                continue;
            }
        }

        input->MoveRow(&binding);

        if (!IdMeta.HasId()) {
            IdMeta.SetId(binding.GetId());
            IdMeta.SetIdType(binding.GetIdType());
        }

        if (!OtherColumns.Defined()) {
            OtherColumns = binding.GetOtherColumns();
        }

        const auto segmentId = binding.GetSegmentID();
        auto statusIt = BindingStatuses.find(segmentId);
        if (statusIt == BindingStatuses.end()) {
            statusIt = BindingStatuses.emplace(segmentId, TBindingStatus(GetSegmentState(segmentId))).first;
        }

        auto& bindingStatus = statusIt->second;
        if (fromUpdates) {
            bindingStatus.SetHasUpdate();
        } else {
            bindingStatus.SetHasState();
        }
    }

    Segments.reserve(BindingStatuses.size());
    for (const auto& [segmentId, status]: BindingStatuses) {
        if (status.WasRemoved()) {
            continue;
        }
        Segments.push_back(ToExportSegmentId(segmentId));
    }

    OutputToSampleLog = Sampler.Passes(IdMeta.GetId()) && Segments.size() >= State->GetSampleSizeLowerBound();
}

// Writes one binding row to the State output for every segment that is not
// reported as removed; when OutputToSampleLog is set, the same row is also
// written to the SampleLog output. Each row carries the id, id type and (if
// defined) other columns of the current group plus the segment id and the
// timestamp reported by its binding status.
void TUpdateFullState::WriteState(TWriter* output) {
    const auto stateTable = static_cast<ui32>(EOutputIndex::State);
    const auto sampleLogTable = static_cast<ui32>(EOutputIndex::SampleLog);

    // Fields shared by all rows of the group are set once up front.
    TFullBinding stateRow;
    if (OtherColumns.Defined()) {
        stateRow.SetOtherColumns(*OtherColumns);
    }
    stateRow.SetIdType(IdMeta.GetIdType());
    stateRow.SetId(IdMeta.GetId());

    for (const auto& [segmentId, status]: BindingStatuses) {
        if (status.WasRemoved()) {
            continue;
        }

        stateRow.SetSegmentID(segmentId);
        stateRow.SetTimestamp(status.GetTimestamp());

        output->AddRow(stateRow, stateTable);
        if (OutputToSampleLog) {
            output->AddRow(stateRow, sampleLogTable);
        }
    }
}

// Emits the collector row (and its sample copy) and the meta row for the
// current group. Does nothing when the group's generic id is not valid.
//
// The collector row first accumulates a delta: removed segments and added
// segments (priority 0), both as export ids. The delta is replaced by a full
// overwrite list when any of these holds: the state requests overwriting
// everything, more segments were collected than the hard limit, or the last
// upload timestamp is below the state's threshold. The overwrite list is a
// sample of at most "hard limit" collected segments, drawn with a generator
// seeded from the user id value and the row timestamp.
//
// The collector row is written only if there is an overwrite or a non-empty
// delta and the user id is set. The meta row is written whenever at least one
// segment was collected.
void TUpdateFullState::WriteCollectorAndMeta(TWriter* output) {
    NIdentifiers::TGenericID genericId(IdMeta.GetIdType(), IdMeta.GetId());
    if (!genericId.IsValid()) {
        return;
    }

    TUserSegmentsRow bigbRow;
    auto& userSegments = *bigbRow.MutableUserSegments();

    userSegments.SetTimestamp(State->GetTimestamp());
    *userSegments.MutableUserId() = genericId.ToProto();

    for (const auto& [segmentId, status]: BindingStatuses) {
        const auto exportSegmentId = ToExportSegmentId(segmentId);

        // The update sub-message is touched only when there is something to
        // record, so it stays absent for groups without changes.
        if (status.WasRemoved()) {
            userSegments.MutableUpdate()->AddRemoved(exportSegmentId);
        }

        if (status.WasAdded()) {
            auto& addedSegment = *userSegments.MutableUpdate()->AddAdded();
            addedSegment.SetId(exportSegmentId);
            addedSegment.SetPriority(0);
        }
    }

    const bool overwriteRequested = State->GetOverwriteAll();
    const bool overHardLimit = Segments.size() > State->GetHardLimit();
    const bool uploadIsStale = IdMeta.GetLastUploadTimestamp() < State->GetUploadTimestampThreshold();
    const bool overwrite = overwriteRequested || overHardLimit || uploadIsStale;

    // Must be evaluated before the delta is cleared below.
    const bool hasDelta = !userSegments.GetUpdate().GetRemoved().empty() || !userSegments.GetUpdate().GetAdded().empty();

    if (overwrite) {
        userSegments.ClearUpdate();

        auto& segmentsPerPriority = *userSegments.MutableOverwrite()->MutableSegmentsPerPriority();
        auto& overwriteIds = *segmentsPerPriority[0].MutableIds();
        overwriteIds.Reserve(State->GetHardLimit());

        const auto seed = MultiHash(NIdentifiers::TGenericID(userSegments.GetUserId()).GetValue(), userSegments.GetTimestamp());
        std::minstd_rand generator(seed);
        std::sample(Segments.begin(), Segments.end(), google::protobuf::RepeatedFieldBackInserter(&overwriteIds), State->GetHardLimit(), generator);
    }

    const bool shouldUpload = (overwrite || hasDelta) && userSegments.HasUserId();
    if (shouldUpload) {
        if (overwrite) {
            IdMeta.SetLastUploadTimestamp(State->GetTimestamp());
        }
        if (OutputToSampleLog) {
            output->AddRow(bigbRow, static_cast<ui32>(EOutputIndex::SampleCollector));
        }
        output->AddRow(std::move(bigbRow), static_cast<ui32>(EOutputIndex::Collector));
    }

    if (!Segments.empty()) {
        output->AddRow(IdMeta, static_cast<ui32>(EOutputIndex::Meta));
    }
}

// Creates a status bound to the given segment state; SegmentState is
// initialized from `segmentState`.
TUpdateFullState::TBindingStatus::TBindingStatus(const TUpdateFullStateState::TSegmentState& segmentState)
    : SegmentState(segmentState) {
}

// Sets the HasState flag (Read() calls this for bindings that come from the
// State input).
void TUpdateFullState::TBindingStatus::SetHasState() {
    HasState = true;
}

// Sets the HasUpdate flag (Read() calls this for bindings that come from the
// SegmentUpdates input).
void TUpdateFullState::TBindingStatus::SetHasUpdate() {
    HasUpdate = true;
}

// Reports whether the binding counts as removed: either the segment state is
// flagged as removed on TTL, or the segment state is flagged as updated and
// the binding was seen in the state but not among the updates.
bool TUpdateFullState::TBindingStatus::WasRemoved() const {
    if (SegmentState.GetRemovedOnTTL()) {
        return true;
    }
    const bool missingFromUpdate = HasState && !HasUpdate;
    return SegmentState.GetUpdated() && missingFromUpdate;
}

// Reports whether the binding counts as newly added: it was seen among the
// updates, was not seen in the state, and the segment state is not flagged as
// removed on TTL.
bool TUpdateFullState::TBindingStatus::WasAdded() const {
    if (!HasUpdate || HasState) {
        return false;
    }
    return !SegmentState.GetRemovedOnTTL();
}

// Returns the timestamp stored in the bound segment state.
ui64 TUpdateFullState::TBindingStatus::GetTimestamp() const {
    const ui64 segmentTimestamp = SegmentState.GetTimestamp();
    return segmentTimestamp;
}
