#include "normal_suffucient_statistic_aggregator.h"

#include "utils.h"

#include <util/generic/xrange.h>

using namespace NLab;

namespace {
    /// Adds `source` to `destination` element-wise over their Data vectors.
    ///
    /// @param destination Vector modified in place: destination[i] += source[i].
    /// @param source      Vector that is only read; must have the same Data size.
    /// Y_ENSURE fails with "Sizes should be equal" when the sizes differ; in that
    /// case `destination` is not modified.
    template <class T>
    void AddInplace(T& destination, const T& source) {
        // `source` is read-only, so read it through the const accessor instead of
        // the mutable view returned by Of().
        const auto& sourceData = source.GetData();
        Y_ENSURE(sourceData.size() == destination.GetData().size(), "Sizes should be equal");
        for (auto i : xrange(sourceData.size())) {
            Of(destination)[i] += sourceData.Get(i);
        }
    }
}

void TNormalSufficientStatisticAggregator::UpdateWith(const TNormalSufficientStatistic& statistic) {
    if (!statistic.GetMean().GetData().size()) {
        return;
    }

    TVectorType mean;
    mean.CopyFrom(statistic.GetMean());

    TVectorType mean2;
    if (statistic.GetCount() == 1) {
        Normalize(mean);
        ui64 size = mean.GetData().size();
        mean2.CopyFrom(statistic.GetMean());
        for (auto i : xrange(size)) {
            auto p = mean.GetData().Get(i);
            Of(mean2)[i] = p*p;
        }
    } else {
        mean2.CopyFrom(statistic.GetMean2());
    }
    if (Count == 0) {
        Mean.CopyFrom(mean);
        Mean2.CopyFrom(mean2);
    } else {
        AddInplace(Mean, mean);
        AddInplace(Mean2, mean2);
    }
    Count += statistic.GetCount();
}

/// Writes the aggregated state into `statistic`, overwriting its Mean, Mean2
/// and Count with the aggregator's Mean, Mean2 and Count.
///
/// When Count is zero, `statistic` is left untouched.
void TNormalSufficientStatisticAggregator::MergeInto(TNormalSufficientStatistic& statistic) {
    if (Count == 0) {
        return;
    }

    auto* const targetMean = statistic.MutableMean();
    auto* const targetMean2 = statistic.MutableMean2();
    targetMean->CopyFrom(Mean);
    targetMean2->CopyFrom(Mean2);
    statistic.SetCount(Count);
}
