#include "lookalike.h"

#include <crypta/lib/native/dates/dates.h>
#include <crypta/lookalike/lib/native/segment_embedding_model.h>
#include <crypta/lookalike/lib/native/user_embedding_model.h>
#include <crypta/lookalike/lib/native/normalize.h>
#include <crypta/lookalike/proto/yt_node_names.pb.h>


// Decodes a packed embedding string (in this file: the "vector" column of lookalike input rows)
// into a float vector using NVectorOperations::Unpack.
TVector<float> UnpackVector(const TString& stringVector) {
    // Forwarded unchanged to Unpack; its meaning is defined there (TODO confirm — looks like a
    // decoding-mode flag).
    const bool unpackFlag = true;
    return NVectorOperations::Unpack<float>(stringVector, unpackFlag);
}

// Builds one TSegment per segment described in State.
// Each segment's user-data stats are embedded with SegmentEmbeddingModel, and the embedding is
// normalized with NCrypta::NLookalike::Normalize. The TSegment constructor stores pointers into
// the segment meta and the global stats, and both of those are owned by State, so State must
// outlive Segments.
void TPredictMapper::InitSegments() {
    const auto* globalStats = &(State->GetGlobalUserDataStats());
    const double maxFilterErrorRate = State->GetMaxFilterErrorRate();

    for (const auto& entry : State->GetSegments()) {
        const auto& meta = entry.second;
        auto embedding = SegmentEmbeddingModel->Embed(meta.GetUserDataStats());
        NCrypta::NLookalike::Normalize(embedding);
        Segments.emplace_back(entry.first, meta, globalStats, maxFilterErrorRate, embedding);
    }
}

// Appends a new intermediate record for `userEmbedding` to the Outputs buffer and returns a
// pointer to it. Only Yandexuid (the stringified user id) is filled here.
// Segments hand this pointer to their storages (TSegment::Add), and Outputs is cleared in YieldTop.
// NOTE(review): the returned pointer stays valid only if Outputs never reallocates while records
// are buffered. That holds for a deque/list, or for a vector reserved to MAX_COUNT; verify the
// declaration.
TLookalikeIntermediateScoredRecord* TPredictMapper::AddOutput(TUserEmbedding& userEmbedding) {
    Outputs.emplace_back();
    TLookalikeIntermediateScoredRecord& record = Outputs.back();
    record.SetYandexuid(ToString(userEmbedding.GetUserId()));
    return &record;
}

// Flushes every segment's retained top records to `output`, then drops the buffered output
// records those segments were pointing at.
void TPredictMapper::YieldTop(TTableWriter<TLookalikeIntermediateScoredRecord>* output) {
    for (TSegment& currentSegment : Segments) {
        currentSegment.YieldTop(output);
    }
    Outputs.clear();
}

// Builds the scoring state of one segment for TPredictMapper.
//
// `meta` and `*globalUserDataStats` are not copied. Options, SegmentAttributes and
// GlobalAttributes keep raw pointers into them, so both must outlive this object.
//
// Derived quantities:
//   * GlobalTotal    - total user count of the global population.
//   * OutputSize     - the requested output count. When input users are excluded
//                      (IncludeInput == false) it is enlarged by the input count plus
//                      maxFilterErrorRate * GlobalTotal (presumably headroom for users that the
//                      input filter removes; TODO confirm).
//   * TopProbability - the requested output count divided by GlobalTotal, plus maxFilterErrorRate
//                      when input is excluded.
// Region and device statistics are built only when the corresponding enforcement option is on.
//
// NOTE(review): GlobalTotal == 0 makes TopProbability inf/NaN.
// NOTE(review): if OutputSize is an integer member, the += below truncates the fractional
// maxFilterErrorRate * GlobalTotal term. Check the declaration.
TPredictMapper::TSegment::TSegment(TString segmentId, const TLookalikeMapping::TSegmentMeta& meta,
                                   const TUserDataStats* globalUserDataStats, double maxFilterErrorRate,
                                   const TEmbedding& segmentDssmVector)
    : MaxFilterErrorRate(maxFilterErrorRate)
    , GlobalTotal(0)
    , OutputSize(0)
    , SegmentId(segmentId)
    , ExternalId()
    , DssmVector(segmentDssmVector)
    , Options()
    , SegmentAttributes()
    , GlobalAttributes()
    , OutputStorages()
    , OutputSizes()
    , RegionCounts()
    , DeviceProbabilities()
{
    auto& userDataStats = meta.GetUserDataStats();
    // Non-owning pointers into the caller-provided protos (see the lifetime note above).
    Options = &(meta.GetOptions());
    SegmentAttributes = &(userDataStats.GetAttributes());
    GlobalAttributes = &(globalUserDataStats->GetAttributes());
    GlobalTotal = static_cast<double>(globalUserDataStats->GetCounts().GetTotal());
    OutputSize = Options->GetCounts().GetOutput();
    TopProbability = OutputSize / GlobalTotal;
    // When input users will be filtered out, over-select so that enough candidates survive.
    if (!Options->GetIncludeInput()) {
        OutputSize += Options->GetCounts().GetInput() + maxFilterErrorRate * GlobalTotal;
        TopProbability += maxFilterErrorRate;
    }

    // These must run after OutputSize is final: InitRegionStats distributes OutputSize over regions.
    if (Options->GetEnforceRegion()) {
        InitRegionStats();
    }
    if (Options->GetEnforceDeviceAndPlatform()) {
        InitDeviceStats();
    }
}

// Flushes every per-cell storage of this segment into `output` and starts a new batch.
//
// Each storage key is a serialized TUserData::TAttributes that identifies a region/device cell.
// For each cell, the quotas attached to every emitted record are computed once:
//   * MinusRegionSize: the requested output count, scaled by the cell region's share of OutputSize
//     when region is enforced.
//   * MinusDeviceProbability: the segment's share of the cell's device when device/platform is
//     enforced, otherwise 1.
// The storage then yields its retained records with a correction factor >= 1. The factor grows
// when the cell's share of this batch is below its global share.
void TPredictMapper::TSegment::YieldTop(TTableWriter<TLookalikeIntermediateScoredRecord>* output) {
    for (auto& outputStorage : OutputStorages) {
        const TString& index = outputStorage.first;
        auto& storage = outputStorage.second;

        TUserData::TAttributes attributes;
        Y_PROTOBUF_SUPPRESS_NODISCARD attributes.ParseFromString(index);
        const auto region = attributes.GetRegion();
        const auto device = attributes.GetDevice();

        // These values depend only on the cell, not on the record, so compute them once per
        // storage. The previous per-record lambda recomputed them for every record, and also
        // computed two values that were never used.
        double regionSize = Options->GetCounts().GetOutput();
        if (Options->GetEnforceRegion()) {
            regionSize *= OutputSizes[region] / OutputSize;
        }
        double deviceProbability = 1.;
        if (Options->GetEnforceDeviceAndPlatform()) {
            deviceProbability = DeviceProbabilities[device].Segment;
        }

        // Invoked synchronously by storage.YieldTop below, so capturing locals by reference is safe.
        auto toOutput = [&](std::pair<double, TLookalikeIntermediateScoredRecord*>& scoreWithOut) {
            TLookalikeIntermediateScoredRecord out = *(scoreWithOut.second);
            out.SetMinusScore(-scoreWithOut.first);
            out.SetIndex(index);
            out.SetRegion(region);
            out.SetDevice(device);
            out.SetMinusRegionSize(-regionSize);
            out.SetMinusDeviceProbability(-deviceProbability);
            out.SetGroupID(SegmentId);
            output->AddRow(out);
        };

        // Share of this batch's users that landed in this cell.
        double currentProbability = 0.;
        if (CurrentCount > 0) {
            currentProbability = static_cast<double>(storage.GetSize()) / CurrentCount;
        }

        // NOTE(review): when region (or device) enforcement is off, RegionCounts (or
        // DeviceProbabilities) has no real entry for the cell. The looked-up Global is then the
        // default (presumably 0), so globalProbability is 0 and the correction collapses to 1.
        // Confirm that disabling the correction in that case is intended.
        const double globalProbability =
            RegionCounts[region].Global * DeviceProbabilities[device].Global / GlobalTotal;
        double correction = 1.;
        if (currentProbability > 0) {
            correction = std::max(1., globalProbability / currentProbability);
        }
        storage.YieldTop(toOutput, correction);
    }
    CurrentCount = 0;
}

void TPredictMapper::TSegment::Add(const TUserEmbedding& userEmbedding, TLookalikeIntermediateScoredRecord* out,
        const TEmbedding& dssmVector) {
    auto score = ComputeScore(dssmVector);
    auto attributes = GetFilterAttributes(userEmbedding.GetAttributes());
    TString storageIndex = GetStorageIndex(attributes);
    auto storageIt = OutputStorages.find(storageIndex);
    if (storageIt != OutputStorages.end()) {
        storageIt->second.Add(score, out);
    } else {
        auto& storage = OutputStorages[storageIndex];
        storage.SetProbability(GetProbability(attributes));
        storage.SetRate(2.);
        storage.Add(score, out);
    }
    CurrentCount += 1;
}

// Returns the storage key for a set of user attributes: the serialized form of the attributes
// after projection onto the enforced dimensions (GetFilterAttributes).
// Pass `yetFiltered = true` when `attributes` has already been filtered, to skip the projection.
TString TPredictMapper::TSegment::GetStorageIndex(const TUserData::TAttributes& attributes, bool yetFiltered) {
    if (yetFiltered) {
        TString serialized;
        Y_PROTOBUF_SUPPRESS_NODISCARD attributes.SerializeToString(&serialized);
        return serialized;
    }
    return GetStorageIndex(GetFilterAttributes(attributes), true);
}

// Affinely maps the dot product of the segment embedding and a user embedding from [-1, 1]
// onto [0, 1]. The [-1, 1] range assumes both vectors are unit-normalized; this file passes
// both through Normalize, presumably to unit length (TODO confirm).
double TPredictMapper::TSegment::ComputeScore(const TEmbedding& rowDssmVector) {
    const double similarity = static_cast<double>(NVectorOperations::Dot(DssmVector, rowDssmVector));
    return 0.5 * (similarity + 1.0);
}

// Builds RegionCounts and OutputSizes for region enforcement.
//
// 1. Keeps the `sizeRegions` most populous non-zero regions of the segment. Ties at the
//    threshold count are kept too. Their segment counts are scaled by
//    rate = 1 + otherSum / topSum so that, together, they absorb the mass of the dropped regions.
// 2. Region 0 (NONE) gets a segment pseudo-count of 1. On the global side it pools every region
//    that was not kept, including an explicit region-0 row.
// 3. Distributes OutputSize over the kept regions in proportion to their segment share. Each
//    region is capped by its global population, and the remainder is re-spread for up to 10 passes.
void TPredictMapper::TSegment::InitRegionStats(size_t sizeRegions) {
    const size_t NONE_REGION = 0;

    TVector<size_t> regionCounts;
    regionCounts.reserve(SegmentAttributes->GetRegion().size());
    for (const auto& regionStat : SegmentAttributes->GetRegion()) {
        if (regionStat.GetRegion()) {
            regionCounts.push_back(regionStat.GetCount());
        }
    }
    sizeRegions = std::min(regionCounts.size(), sizeRegions);

    double segmentTotalRegionCount = 0;
    // Guard: with no kept region there is no threshold. The previous code read
    // regionCounts[-1] here (out of bounds) when the segment had no non-zero regions.
    if (sizeRegions > 0) {
        std::partial_sort(regionCounts.begin(), regionCounts.begin() + sizeRegions, regionCounts.end(), std::greater<size_t>());
        size_t topSum = 0;
        size_t otherSum = 0;
        for (size_t i = 0; i < regionCounts.size(); ++i) {
            if (i < sizeRegions) {
                topSum += regionCounts[i];
            } else {
                otherSum += regionCounts[i];
            }
        }
        // topSum == 0 means every count is 0. The old 0/0 produced NaN, which later reached
        // static_cast<int>(NaN), and that is undefined behaviour.
        const double rate = topSum > 0
            ? 1. + static_cast<double>(otherSum) / static_cast<double>(topSum)
            : 1.;
        const size_t minCount = regionCounts[sizeRegions - 1];
        for (const auto& stat : SegmentAttributes->GetRegion()) {
            const double count = static_cast<double>(stat.GetCount());
            if (count >= minCount && stat.GetRegion()) {
                auto& region = RegionCounts[stat.GetRegion()];
                region.Segment = count * rate;
                segmentTotalRegionCount += region.Segment;
            }
        }
    }

    RegionCounts[NONE_REGION].Segment = 1;
    segmentTotalRegionCount += 1;

    for (const auto& stat : GlobalAttributes->GetRegion()) {
        const auto regionId = stat.GetRegion();
        if (regionId != NONE_REGION && RegionCounts.contains(regionId)) {
            RegionCounts[regionId].Global = stat.GetCount();
        } else {
            // Always accumulate into NONE. Previously a global row for region 0 took the
            // assignment branch and overwrote whatever had been pooled so far (order-dependent).
            RegionCounts[NONE_REGION].Global += stat.GetCount();
        }
    }

    THashMap<ui64, double> realProbabilities;
    for (const auto& count : RegionCounts) {
        realProbabilities[count.first] = count.second.Segment / segmentTotalRegionCount;
    }

    double reminder = static_cast<double>(OutputSize);
    double sumSizes = 0;
    for (int iteration = 0; reminder > 0 && iteration < 10; ++iteration) {
        for (const auto& count : RegionCounts) {
            const auto region = count.first;
            const double size = OutputSizes[region];
            // At least one extra slot per pass, so small regions still make progress.
            const double addition = static_cast<double>(static_cast<int>(realProbabilities[region] * reminder + 1));
            const double availableSize = std::min(size + addition, count.second.Global);
            sumSizes += availableSize - size;
            OutputSizes[region] = availableSize;
        }
        reminder = OutputSize - sumSizes;
    }
}

// Builds DeviceProbabilities for device/platform enforcement.
//   * Segment:     share of each device among the segment's device counts.
//   * Global:      share of each device among the global device counts.
//   * Conditional: Segment / Global divided by min(1, sum over devices of Segment / Global).
//                  It is 0 for devices with no global share.
// DeviceConditionalSum is the sum of all Conditional values.
void TPredictMapper::TSegment::InitDeviceStats() {
    double sumCount = 0;
    for (const auto& stat : SegmentAttributes->GetDevice()) {
        sumCount += stat.GetCount();
    }
    if (sumCount > 0) {
        for (const auto& stat : SegmentAttributes->GetDevice()) {
            DeviceProbabilities[stat.GetDevice()].Segment = stat.GetCount() / sumCount;
        }
    }

    sumCount = 0;
    for (const auto& stat : GlobalAttributes->GetDevice()) {
        sumCount += stat.GetCount();
    }
    if (sumCount > 0) {
        for (const auto& stat : GlobalAttributes->GetDevice()) {
            DeviceProbabilities[stat.GetDevice()].Global = stat.GetCount() / sumCount;
        }
    }

    double ratioSum = 0;
    for (const auto& device : DeviceProbabilities) {
        const auto& stat = device.second;
        if (stat.Global) {
            ratioSum += stat.Segment / stat.Global;
        }
    }
    // Capped at 1, as before. A zero sum (the segment has no device counts) previously caused
    // a division by zero below, so fall back to no rescaling in that case.
    const double normalizer = ratioSum > 0 ? std::min(1., ratioSum) : 1.;

    DeviceConditionalSum = 0;
    for (auto& device : DeviceProbabilities) {
        auto& stat = device.second;
        // Devices seen in the segment but not globally have no defined ratio. They were divided
        // by zero before (inf/NaN in Conditional and DeviceConditionalSum). GetProbability
        // ignores Conditional for such devices anyway.
        stat.Conditional = stat.Global ? stat.Segment / stat.Global / normalizer : 0.;
        DeviceConditionalSum += stat.Conditional;
    }
}

// Projects user attributes onto the dimensions this segment enforces.
//   * Region is kept only when region enforcement is on and the region is one of the tracked
//     RegionCounts; otherwise it stays unset.
//   * Device is kept when device/platform enforcement is on.
TUserData::TAttributes TPredictMapper::TSegment::GetFilterAttributes(const TUserData::TAttributes& userDataAttributes) {
    TUserData::TAttributes filtered;
    const auto userRegion = userDataAttributes.GetRegion();
    if (Options->GetEnforceRegion() && RegionCounts.contains(userRegion)) {
        filtered.SetRegion(userRegion);
    }
    if (Options->GetEnforceDeviceAndPlatform()) {
        filtered.SetDevice(userDataAttributes.GetDevice());
    }
    return filtered;
}

// Probability handed to a newly created storage for an (already filtered) attribute cell.
//   * With region enforcement: the region's allotted output (OutputSizes) divided by its global
//     population. It is left at 1 when that population is 0.
//   * Without region enforcement: TopProbability.
//   * With device enforcement: the result is also multiplied by the device's Conditional weight,
//     skipped for devices with no global share.
// The result is clamped to [0, 1]; a NaN collapses to 0.
double TPredictMapper::TSegment::GetProbability(const TUserData::TAttributes& attributes) {
    double probability = TopProbability;
    if (Options->GetEnforceRegion()) {
        const auto region = attributes.GetRegion();
        probability = 1.;
        if (RegionCounts[region].Global) {
            probability = OutputSizes[region] / RegionCounts[region].Global;
        }
    }

    if (Options->GetEnforceDeviceAndPlatform()) {
        const auto& deviceStat = DeviceProbabilities[attributes.GetDevice()];
        if (deviceStat.Global) {
            probability *= deviceStat.Conditional;
        }
    }

    // Keep this exact min/max nesting: it maps NaN to 0 (std::clamp would return NaN).
    return std::min(1., std::max(0., probability));
}

// Job start hook: loads the segment embedding model.
// The model and dictionary file names are the field defaults of a default-constructed
// NCrypta::NLookalike::TYtNodeNames.
void TPredictMapper::Start(TWriter* writer) {
    Y_UNUSED(writer);

    const NCrypta::NLookalike::TYtNodeNames nodeNames{};
    SegmentEmbeddingModel = MakeHolder<NCrypta::NLookalike::TSegmentEmbeddingModel>(
            nodeNames.GetDssmModelFile(),
            nodeNames.GetSegmentsDictFile());
}

// Mapper entry point. Builds the segments, then for every user embedding:
//   1. normalizes the embedding,
//   2. creates one shared output record for the user,
//   3. offers that record to every segment.
// Segment tops are flushed each time MAX_COUNT records are buffered, and once more at the end.
void TPredictMapper::Do(TTableReader<TUserEmbedding>* input, TTableWriter<TLookalikeIntermediateScoredRecord>* output) {
    InitSegments();

    while (input->IsValid()) {
        auto userEmbedding = input->GetRow();
        const auto& rawEmbedding = userEmbedding.GetEmbedding();
        TEmbedding dssmVector(rawEmbedding.begin(), rawEmbedding.end());
        NCrypta::NLookalike::Normalize(dssmVector);

        auto* sharedRecord = AddOutput(userEmbedding);
        for (TSegment& segment : Segments) {
            segment.Add(userEmbedding, sharedRecord, dssmVector);
        }

        const bool bufferFull = static_cast<int>(Outputs.size()) == MAX_COUNT;
        if (bufferFull) {
            YieldTop(output);
        }
        input->Next();
    }

    YieldTop(output);
}

// Reduces one segment's intermediate scored records into the final lookalike output.
//
// Table 0 carries the segment's TUserDataStats, which is used to Init `segment`. All other tables
// carry TLookalikeIntermediateScoredRecord rows. Records are assumed to arrive grouped by cell
// (Index/Region/Device), presumably guaranteed by the reduce sort order; TODO confirm against the
// operation spec. Two quotas are enforced while yielding:
//   * Per region: regionSize = -MinusRegionSize + 1, tracked by currentRegionSize.
//   * Per index/device cell: deviceProbability * regionSize, tracked by currentIndexSize.
//     `reminder` carries a cell's unused quota over to the following cells of the same region.
// A record over its cell quota, with no carried-over quota left, is parked via AddToReserve.
// YieldFromReserve is called with the previous region's leftover quota when the region changes.
// Reading stops early once segment.IsReadyToFinish().
//
// NOTE(review): in the region-change check, `index != currentIndex` is always false, because the
// preceding block has already set currentIndex = index. Only `region != currentRegion` matters.
// NOTE(review): the reserve of the last region is not passed to YieldFromReserve after the loop.
// Confirm whether segment.WriteCustomStatistics() or TSegment covers this.
void TPredictReducer::Do(TReader* input, TTableWriter<TLookalikeOutput>* output) {
    TSegment segment;
    TString currentIndex = "";
    int currentRegionSize = 0;
    int currentIndexSize = 0;
    ui64 currentRegion = 0;
    // Value-initialized. It can be read (when the first record's index is empty) before any
    // assignment, and reading a default-initialized (indeterminate) enum is undefined behaviour.
    TDevice currentDevice{};
    double regionSize = 0;
    double deviceProbability = 0;
    double reminder = 0;

    for (; input->IsValid(); input->Next()) {
        auto tableIndex = input->GetTableIndex();
        if (tableIndex == 0) {
            auto& stats = input->GetRow<TUserDataStats>();
            segment.Init(stats, State);
            continue;
        }

        auto& row = input->GetRow<TLookalikeIntermediateScoredRecord>();
        auto region = row.GetRegion();
        auto device = row.GetDevice();
        auto index = row.GetIndex();
        auto yandexuid = row.GetYandexuid();
        auto score = -row.GetMinusScore();

        if (segment.IsReadyToFinish()) {
            break;
        }

        // New index/device cell: carry the previous cell's unused quota forward.
        if (index != currentIndex || device != currentDevice) {
            reminder += deviceProbability * regionSize - currentIndexSize;
            currentIndex = index;
            currentDevice = device;
            currentIndexSize = 0;
        }

        // New region: offer the previous region's leftover quota to the reserve, then reset counters.
        if (index != currentIndex || region != currentRegion) {
            segment.YieldFromReserve(regionSize - currentRegionSize, output);

            currentRegion = region;
            currentRegionSize = 0;
            currentIndexSize = 0;
            reminder = 0;
        }

        regionSize = -row.GetMinusRegionSize() + 1;
        deviceProbability = -row.GetMinusDeviceProbability();

        // Region quota exhausted: the record is skipped.
        if (regionSize < currentRegionSize) {
            continue;
        }
        // Cell quota exhausted: spend carried-over quota if there is any, otherwise park the record.
        bool remind = false;
        if (deviceProbability * regionSize < currentIndexSize) {
            if (reminder > 0) {
                remind = true;
            } else {
                segment.AddToReserve(yandexuid, score, regionSize);
                continue;
            }
        }

        if (segment.Yield(score, yandexuid, output)) {
            currentRegionSize += 1;
            currentIndexSize += 1;
            if (remind) {
                reminder--;
            }
        }
    }

    segment.WriteCustomStatistics();
}

// Mapper entry point. Every user row that was active since State's oldest timestamp
// ("days_active") is scored against every segment. Per-segment tops are flushed whenever
// MAX_COUNT output records are buffered, and once more at the end.
// The C RNG is seeded from the wall clock, so anything that relies on rand() (presumably the
// storages' sampling) differs between runs.
void TLookalikeMapper::Do(TTableReader<TNode>* input, TTableWriter<TNode>* output) {
    srand(time(nullptr));
    InitSegments();
    OldestTimestamp = State->GetOldestTimestamp();

    for (; input->IsValid(); input->Next()) {
        auto row = input->GetRow();
        const bool isActive = NDates::HasActiveDate(row, "days_active", OldestTimestamp);
        if (isActive) {
            TNode* sharedRecord = AddOutput(row);
            const auto userVector = UnpackVector(row["vector"].AsString());
            for (TSegment& segment : Segments) {
                segment.Add(row, sharedRecord, userVector);
            }

            if (static_cast<int>(Outputs.size()) == MAX_COUNT) {
                YieldTop(output);
            }
        }
    }

    YieldTop(output);
}

// Writes every record retained by this segment's storages to `output`. Each record is annotated
// with "score", "_score" (= -score; presumably a sort key for descending score, TODO confirm) and
// the segment id under "volatile_id".
void TLookalikeMapper::TSegment::YieldTop(TTableWriter<TNode>* output) {
    const TString volatileId(SegmentId);
    auto emitRecord = [output, &volatileId](std::pair<double, TNode*>& scoreWithOut) {
        const double score = scoreWithOut.first;
        TNode record = *(scoreWithOut.second);
        record("score", score);
        record("_score", -score);
        record("volatile_id", volatileId);
        output->AddRow(record);
    };

    for (auto& storageEntry : OutputStorages) {
        storageEntry.second.YieldTop(emitRecord);
    }
}

void TLookalikeMapper::TSegment::Add(const TNode& row, TNode* out, const TVector<float>& rowVector) {
    auto score = ComputeScore(rowVector);
    auto factors = GetFactors(row);
    TString storageIndex = GetStorageIndex(factors);
    auto storageIt = OutputStorages.find(storageIndex);
    if (storageIt != OutputStorages.end()) {
        storageIt->second.Add(score, out);
    } else {
        auto& storage = OutputStorages[storageIndex];
        storage.SetProbability(GetProbability(factors));
        storage.Add(score, out);
    }
}

// Affinely maps the dot product of the segment vector and a user vector from [-1, 1] onto [0, 1].
// The [-1, 1] range assumes unit-length vectors; TODO confirm for this mapper's inputs.
double TLookalikeMapper::TSegment::ComputeScore(const TVector<float>& rowVector) {
    const auto dot = NVectorOperations::Dot(Vector, rowVector);
    return 0.5 * (static_cast<double>(dot) + 1.0);
}

// Builds the storage key for a factor assignment: SENTINEL, followed by SENTINEL + value for each
// factor in the iteration order of `factors`.
// NOTE(review): factor names are not part of the key, and the order is the hash-map iteration
// order. This is presumably stable only because GetFactors always builds the map the same way.
TString TLookalikeMapper::TSegment::GetStorageIndex(const THashMap<TString, TString>& factors) {
    TString key = SENTINEL;
    for (const auto& factor : factors) {
        key += SENTINEL;
        key += factor.second;
    }
    return key;
}

// Extracts the row's value as a string for every factor tracked in FactorProbabilities.
//   * Missing columns and null (entity) values map to UNDEFINED.
//   * String values are taken verbatim.
//   * Signed and unsigned integers are rendered in decimal.
// Other node types still throw from AsInt64(), as before.
THashMap<TString, TString> TLookalikeMapper::TSegment::GetFactors(const TNode& row) {
    THashMap<TString, TString> factors;
    for (const auto& factor : FactorProbabilities) {
        const auto& name = factor.first;
        TString rowValue = UNDEFINED;
        if (row.HasKey(name)) {
            // Bind by reference. The previous code copied the TNode for every factor of every row.
            const TNode& node = row[name];
            if (node.IsString()) {
                rowValue = node.AsString();
            } else if (node.IsUint64()) {
                // AsInt64() throws on Uint64 nodes, which would abort the job.
                rowValue = std::to_string(node.AsUint64());
            } else if (!node.IsNull()) {
                rowValue = std::to_string(node.AsInt64());
            }
            // Null values (absent optional columns in schematized tables) keep UNDEFINED
            // instead of throwing from AsInt64().
        }
        factors[name] = rowValue;
    }
    return factors;
}

double TLookalikeMapper::TSegment::GetProbability(const THashMap<TString, TString>& factors, double minValue) {
    double probability = TopProbability;
    for (auto& factor : factors) {
        auto probabilities = FactorProbabilities[factor.first];
        auto factorValue = factor.second;
        auto probabilityIt = probabilities.find(factorValue);
        if (probabilityIt == probabilities.end()) {
            factorValue = UNDEFINED;
        }
        probability *= probabilities[factorValue];
        probability = std::max(minValue, probability);
    }
    return probability;
}

// Creates one TSegment per (id, meta) entry of State->segments().
void TLookalikeMapper::InitSegments() {
    // Iterate State's map by reference. The previous code iterated a local copy. That cost a full
    // copy of every segment description, and it would leave dangling pointers if the TSegment
    // constructor keeps pointers into the meta (as TPredictMapper::TSegment does). TODO confirm
    // against the TLookalikeMapper::TSegment constructor.
    const auto& segmentsMeta = State->segments();
    for (const auto& segmentMeta : segmentsMeta) {
        Segments.emplace_back(segmentMeta.first, segmentMeta.second);
    }
}

// Appends a new output map holding only the row's "yandexuid" to the Outputs buffer, and
// returns a pointer to it.
// NOTE(review): the pointer is kept by segment storages until YieldTop. It stays valid only if
// Outputs does not reallocate meanwhile; verify the container type.
TNode* TLookalikeMapper::AddOutput(TNode& row) {
    Outputs.emplace_back();
    TNode& record = Outputs.back();
    record("yandexuid", row["yandexuid"]);
    return &record;
}

// Flushes every segment's retained records to `output`, then drops the buffered output rows
// they referenced.
void TLookalikeMapper::YieldTop(TTableWriter<TNode>* output) {
    for (TSegment& currentSegment : Segments) {
        currentSegment.YieldTop(output);
    }
    Outputs.clear();
}


// Passes through the first counts().output() rows of one segment's reduce group and drops the rest.
//
// The segment is identified by the "volatile_id" of the first row. All rows of the group are
// presumed to share it; TODO confirm the reduce key. A segment unknown to State gets a quota of 0.
// Rows are written in input order, so the best-scored rows are kept only if the group is sorted
// by score (presumably by "_score"; TODO confirm).
void TLookalikeReducer::Do(TTableReader<TNode>* input, TTableWriter<TNode>* output) {
    int count = 0;
    int size = -1;  // Quota; -1 until it has been read from State on the first row.

    for (; input->IsValid(); input->Next()) {
        const auto& row = input->GetRow();
        if (size == -1) {
            // Look the segment up by reference. The previous code copied the whole segments map
            // and indexed the copy with operator[].
            const TString segmentId = row["volatile_id"].AsString();
            const auto& segments = State->segments();
            const auto segmentIt = segments.find(segmentId);
            size = segmentIt != segments.end()
                ? static_cast<int>(segmentIt->second.counts().output())
                : 0;
        }
        if (count == size) {
            break;
        }
        count++;
        output->AddRow(row);
    }
}

// Joins all per-segment scores of one user into a single output row.
//
// Each input row carries a "volatile_id" such as "a:b:c" and a "score". The id is split on ':'
// and the score is stored at out["a"]["b"]["c"]. The output row gets "id" set to the yandexuid
// of the first input row, and "id_type" set to "yandexuid".
void TLookalikeJoiner::Do(TTableReader<TNode>* input, TTableWriter<TNode>* output) {
    // Nothing to join: do not emit a row with the placeholder id "0".
    if (!input->IsValid()) {
        return;
    }
    const ui64 yandexuid = input->GetRow()["yandexuid"].AsUint64();

    TNode out;
    for (; input->IsValid(); input->Next()) {
        const auto& row = input->GetRow();
        const TString& segmentPath = row["volatile_id"].AsString();

        // Create or descend one nested map level per ':'-separated prefix; the last component is
        // the leaf key. Scanning by position replaces the old repeated re-copying of the suffix
        // and its int/size_t mixed arithmetic.
        TNode* segmentNode = &out;
        size_t start = 0;
        for (size_t delim = segmentPath.find(':'); delim != TString::npos; delim = segmentPath.find(':', start)) {
            segmentNode = &((*segmentNode)[segmentPath.substr(start, delim - start)]);
            start = delim + 1;
        }
        (*segmentNode)(segmentPath.substr(start), row["score"]);
    }

    out("id", ToString(yandexuid));
    out("id_type", "yandexuid");
    output->AddRow(out);
}
