#include "merge.h"

#include <solomon/libs/cpp/math/math.h>

#include <util/generic/algorithm.h>

namespace NSolomon::NTsModel {

namespace {

// True when both histograms have the same number of buckets and each pair of
// corresponding buckets has exactly the same UpperBound (compared with ==).
// Bucket values are not inspected.
bool BoundsEqual(const NTs::THistogramPoint::TValue& l, const NTs::THistogramPoint::TValue& r) {
    if (l.Buckets.size() != r.Buckets.size()) {
        return false;
    }
    for (size_t idx = 0; idx < l.Buckets.size(); ++idx) {
        if (l.Buckets[idx].UpperBound != r.Buckets[idx].UpperBound) {
            return false;
        }
    }
    return true;
}

// Scale factor the histogram merge helpers apply to the left operand's bucket values
// when the two denominators differ. Returns 1.0 when the denominators are equal, and
// also when exactly one of them is 0.
//
// NOTE(review): this is not a plain ratio of denominators. For rDenom > lDenom it
// yields rDenom / (rDenom - lDenom) (e.g. 1001 for lDenom = 1000, rDenom = 1001);
// otherwise it yields (lDenom - rDenom) / lDenom. Callers pass (r->Denom, l.Denom)
// and carry a FIXME about the argument order. Confirm the intended semantics before
// relying on this.
double DenomMultiplier(ui64 lDenom, ui64 rDenom) {
    const i64 diff = static_cast<i64>(rDenom) - static_cast<i64>(lDenom);
    if (diff < 0) {
        return -static_cast<double>(diff) / static_cast<double>(lDenom);
    }
    if (diff > 0) {
        return static_cast<double>(rDenom) / static_cast<double>(diff);
    }
    return 1.0;
}

// Adds the bucket values of `l` into `r`. Both histograms must have the same number of
// buckets, and the caller is expected to have checked that the bounds match.
// With equal denominators the values are summed as integers. Otherwise each value of
// `l` is scaled by DenomMultiplier and the sum is rounded back to an integer.
// `r->Denom` is never modified.
void MergeSameBounds(const NTs::THistogramPoint::TValue& l, NTs::THistogramPoint::TValue* r) {
    Y_ASSERT(!r->Buckets.empty() && r->Buckets.size() == l.Buckets.size());

    const size_t bucketCount = r->Buckets.size();

    if (l.Denom != r->Denom) {
        const double scale = DenomMultiplier(r->Denom, l.Denom);  // FIXME: shady arguments order
        for (size_t idx = 0; idx < bucketCount; ++idx) {
            const double scaledLeft = static_cast<double>(l.Buckets[idx].Value) * scale;
            const double current = static_cast<double>(r->Buckets[idx].Value);
            r->Buckets[idx].Value = static_cast<ui64>(round(scaledLeft + current));
        }
        return;
    }

    for (size_t idx = 0; idx < bucketCount; ++idx) {
        r->Buckets[idx].Value += l.Buckets[idx].Value;
    }
}

// Merges `l` into `r` when their bucket bounds differ. Each bucket of `l` is added to
// the first bucket of `r` whose UpperBound is >= the bucket's UpperBound (found with
// LowerBoundBy, so r's buckets are presumably sorted by UpperBound; verify). Buckets above
// r's last bound are folded into r's last bucket. `r->Denom` is never modified.
//
// With equal denominators, values are added as integers, as in MergeSameBounds. A round
// trip through double would silently lose precision for counts above 2^53.
void MergeDifferentBounds(const NTs::THistogramPoint::TValue& l, NTs::THistogramPoint::TValue* r) {
    Y_ASSERT(!r->Buckets.empty());

    const bool sameDenom = l.Denom == r->Denom;
    const double leftMultiplier = DenomMultiplier(r->Denom, l.Denom);  // FIXME: shady arguments order

    for (const auto& bucket : l.Buckets) {
        auto target = LowerBoundBy(
            r->Buckets.begin(), r->Buckets.end(), bucket.UpperBound,
            [](const NTs::THistogramPoint::TBucket& b) -> double { return b.UpperBound; });

        if (Y_UNLIKELY(target == r->Buckets.end())) {
            // Above every bound of r: fold into r's last bucket.
            --target;
        }

        if (sameDenom) {
            target->Value += bucket.Value;
        } else {
            const double lValue = static_cast<double>(bucket.Value) * leftMultiplier;
            target->Value += static_cast<ui64>(round(lValue));
        }
    }
}

// Grows `h` so that relative bucket index `expectedEnd - 1` is addressable. Returns the
// end index to use instead of `expectedEnd`, relative to the possibly shifted window.
//
// If `expectedEnd` fits into MaxBucketCount, `Values` is simply zero-padded to it.
// Otherwise the window moves up by `expectedEnd - MaxBucketCount` buckets:
//   - the buckets that fall off the bottom are summed into the new first bucket;
//   - StartPower grows by the shift;
//   - `Values` is padded to MaxBucketCount, and MaxBucketCount is returned.
//
// The shift must not be clamped to the number of existing buckets. Clamping it returns
// an end greater than MaxBucketCount whenever `expectedEnd` is far above the histogram,
// and callers then index past the end of `Values`.
i64 ExtendUp(NTs::TLogHistogramPoint::TValue* h, i64 expectedEnd) {
    Y_ASSERT(expectedEnd > static_cast<i64>(h->Values.size()));

    if (expectedEnd <= h->MaxBucketCount) {
        h->Values.resize(expectedEnd, 0.0);
        return expectedEnd;
    }

    // How far the window must move up so that expectedEnd lands on MaxBucketCount.
    const i64 shift = expectedEnd - h->MaxBucketCount;
    // Only buckets that exist can be dropped. If shift exceeds the current size,
    // every old bucket falls below the new window.
    const i64 toRemove = std::min(shift, static_cast<i64>(h->Values.size()));
    if (toRemove > 0) {
        const double firstWeight = std::accumulate(h->Values.begin(), h->Values.begin() + toRemove, 0.0);
        h->Values.erase(h->Values.begin(), h->Values.begin() + toRemove);
        if (h->Values.empty()) {
            h->Values.push_back(firstWeight);
        } else {
            h->Values[0] += firstWeight;
        }
    }

    h->Values.resize(h->MaxBucketCount, 0.0);
    h->StartPower += shift;

    return expectedEnd - shift;
}

// Makes room below the first bucket for relative index `expectedBegin` (< 0). It prepends
// zero buckets, as many as MaxBucketCount allows, and lowers StartPower by the same amount.
// Returns `expectedBegin` shifted by the number of prepended buckets. The result stays
// negative when the size limit prevents reaching `expectedBegin`.
i64 ExtendDown(NTs::TLogHistogramPoint::TValue* h, i64 expectedBegin) {
    Y_ASSERT(expectedBegin < 0);

    const i64 freeSlots = h->MaxBucketCount - static_cast<i64>(h->Values.size());
    const i64 wanted = -expectedBegin;
    const i64 grow = freeSlots < wanted ? freeSlots : wanted;

    if (grow > 0) {
        h->Values.insert(h->Values.begin(), grow, 0.0);
        h->StartPower -= grow;
    }

    return expectedBegin + grow;
}

// Prepares `h` for writes to the relative index range [begin, end). It first extends
// upward (ExtendUp), then downward (ExtendDown), carrying each window shift over to the
// other bound. Returns the same-length range re-expressed against the adjusted window.
// The returned begin may still be negative if MaxBucketCount prevented full extension.
std::pair<i64, i64> ExtendBounds(NTs::TLogHistogramPoint::TValue* h, i64 begin, i64 end) {
    if (end > static_cast<i64>(h->Values.size())) {
        const i64 shiftedEnd = ExtendUp(h, end);
        begin += shiftedEnd - end;
        end = shiftedEnd;
    }

    if (begin < 0) {
        const i64 shiftedBegin = ExtendDown(h, begin);
        end += shiftedBegin - begin;
        begin = shiftedBegin;
    }

    return {begin, end};
}

} // namespace

void Merge(const TGaugePoint& l, TGaugePoint* r) {
    if (r->Merge) {
        MergeRaw(l, r);
        r->Count += l.Count;
    }
}

void Merge(const TCounterPoint& l, TCounterPoint* r) {
    if (r->Merge) {
        MergeRaw(l, r);
        r->Count += l.Count;
    }
}

void Merge(const TRatePoint& l, TRatePoint* r) {
    if (r->Merge) {
        MergeRaw(l, r);
        r->Count += l.Count;
    }
}

void Merge(const TIGaugePoint& l, TIGaugePoint* r) {
    if (r->Merge) {
        MergeRaw(l, r);
        r->Count += l.Count;
    }
}

void Merge(const THistPoint& l, THistPoint* r) {
    if (r->Merge) {
        MergeRaw(l, r);
        r->Count += l.Count;
    }
}

void Merge(const THistRatePoint& l, THistRatePoint* r) {
    if (r->Merge) {
        MergeRaw(l, r);
        r->Count += l.Count;
    }
}

void Merge(const TDSummaryPoint& l, TDSummaryPoint* r) {
    if (r->Merge) {
        MergeRaw(l, r);
        r->Count += l.Count;
    }
}

void Merge(const TLogHistPoint& l, TLogHistPoint* r) {
    if (r->Merge) {
        MergeRaw(l, r);
        r->Count += l.Count;
    }
}

// Adds double value `l` into `r`.
// NaN operands are skipped:
//   - NaN on both sides leaves r->Num NaN and resets r->Denom to 0;
//   - NaN only on the left leaves `r` unchanged;
//   - NaN only on the right copies `l` into `r`.
// With matching denominators the numerators are summed. Otherwise r->Num becomes the sum
// of both ValueDivided() results and r->Denom is set to 0.
void MergeRaw(const NTs::TDoublePoint::TValue& l, NTs::TDoublePoint::TValue* r) {
    const bool leftIsNan = std::isnan(l.Num);
    const bool rightIsNan = std::isnan(r->Num);

    if (leftIsNan) {
        if (rightIsNan) {
            r->Denom = 0;
        }
        return;
    }

    if (rightIsNan) {
        r->Num = l.Num;
        r->Denom = l.Denom;
        return;
    }

    if (l.Denom != r->Denom) {
        r->Num = l.ValueDivided() + r->ValueDivided();
        r->Denom = 0;
        return;
    }

    r->Num += l.Num;
}

// Accumulates the integer value of `l` into `r`.
void MergeRaw(const NTs::TLongPoint::TValue& l, NTs::TLongPoint::TValue* r) {
    auto& total = r->Value;
    total += l.Value;
}

// Adds explicit-bucket histogram `l` into `r`.
//   - An empty `l` is a no-op.
//   - An empty `r` becomes a copy of l's buckets and denominator.
//   - Otherwise buckets are merged index-by-index when the bounds match
//     (MergeSameBounds), or by upper bound when they don't (MergeDifferentBounds).
void MergeRaw(const NTs::THistogramPoint::TValue& l, NTs::THistogramPoint::TValue* r) {
    if (l.Buckets.empty()) {
        return;
    }

    if (!r->Buckets.empty()) {
        if (BoundsEqual(l, *r)) {
            MergeSameBounds(l, r);
        } else {
            MergeDifferentBounds(l, r);
        }
        return;
    }

    r->Buckets = l.Buckets;
    r->Denom = l.Denom;
}

// Adds log-histogram `l` into `r`. ZeroCount is always summed.
//
// Equal bases: l's buckets are aligned to r's by StartPower and added position by
// position. ExtendBounds grows, and if needed shifts, r's bucket window first. Positions
// that still fall below r's window are folded into bucket 0.
//
// Different bases: each non-zero bucket i of `l` is re-binned by the value
// l.Base^(l.StartPower + i) into r's bucket floor(log_{r.Base}(value)) - r.StartPower.
// Out-of-range indices are handled by extending r up or down, or by clamping to the
// window edge.
void MergeRaw(const NTs::TLogHistogramPoint::TValue& l, NTs::TLogHistogramPoint::TValue* r) {
    r->ZeroCount += l.ZeroCount;

    if (l.Values.empty()) {
        // Nothing to add. Don't let an empty operand extend or shift (and collapse) r's window.
        return;
    }

    if (AreDoublesEqual(l.Base, r->Base)) {
        const i64 beginBeforeExtend = l.StartPower - r->StartPower;
        const i64 endBeforeExtend = beginBeforeExtend + static_cast<i64>(l.Values.size());

        auto [begin, end] = ExtendBounds(r, beginBeforeExtend, endBeforeExtend);

        Y_ASSERT(end - begin == static_cast<i64>(l.Values.size()));

        for (i64 resPos = begin, lPos = 0; resPos < end; ++resPos, ++lPos) {
            // Explicit <i64>: a `0l` literal only deduces when i64 happens to be `long`.
            r->Values[std::max<i64>(resPos, 0)] += l.Values[lPos];
        }
    } else {
        for (size_t i = 0; i < l.Values.size(); ++i) {
            auto value = l.Values[i];
            if (AreDoublesEqual(value, 0)) {
                continue;
            }

            double bucketBound = std::pow(l.Base, static_cast<double>(i) + l.StartPower);
            i64 indexInNewHist = static_cast<i64>(std::floor(std::log(bucketBound) / std::log(r->Base))) - r->StartPower;

            if (indexInNewHist >= static_cast<i64>(r->Values.size())) {
                indexInNewHist = ExtendUp(r, indexInNewHist + 1) - 1;
            } else if (indexInNewHist < 0) {
                // If the size limit stops ExtendDown short, the value lands in the lowest bucket.
                ExtendDown(r, indexInNewHist);
                indexInNewHist = 0;
            }

            r->Values[indexInNewHist] += value;
        }
    }
}

// Combines integer summary `l` into `r`.
//   - An `l` with CountValue == 0 is ignored.
//   - An `r` with CountValue == 0 takes all of l's fields.
//   - Otherwise Sum and CountValue are added, and Min/Max are widened.
// NOTE(review): r->Last is kept, not replaced by l.Last, when both sides are non-empty.
// Confirm this is intended.
void MergeRaw(const NTs::TSummaryIntPoint::TValue& l, NTs::TSummaryIntPoint::TValue* r) {
    if (l.CountValue == 0) {
        return;
    }

    if (r->CountValue != 0) {
        r->Sum += l.Sum;
        r->Min = std::min(l.Min, r->Min);
        r->Max = std::max(l.Max, r->Max);
        r->CountValue += l.CountValue;
        return;
    }

    r->Sum = l.Sum;
    r->Min = l.Min;
    r->Max = l.Max;
    r->Last = l.Last;
    r->CountValue = l.CountValue;
}

// Combines double summary `l` into `r`.
//   - An `l` with CountValue == 0 is ignored.
//   - An `r` with CountValue == 0 takes all of l's fields.
//   - Otherwise Sum and CountValue are added, and Min/Max are widened with std::min/std::max.
// NOTE(review): r->Last is kept, not replaced by l.Last, when both sides are non-empty.
// Confirm this is intended.
void MergeRaw(const NTs::TSummaryDoublePoint::TValue& l, NTs::TSummaryDoublePoint::TValue* r) {
    if (l.CountValue == 0) {
        return;
    }

    if (r->CountValue != 0) {
        r->Sum += l.Sum;
        r->Min = std::min(l.Min, r->Min);
        r->Max = std::max(l.Max, r->Max);
        r->CountValue += l.CountValue;
        return;
    }

    r->Sum = l.Sum;
    r->Min = l.Min;
    r->Max = l.Max;
    r->Last = l.Last;
    r->CountValue = l.CountValue;
}

} // namespace NSolomon::NTsModel
