#include "accumulators.h"

#include <library/cpp/testing/unittest/registar.h>

#include <util/generic/xrange.h>
#include <util/random/easy.h>

#include <cfloat>
#include <vector>

using namespace NNetmon;

class TDescriptiveStatisticsTest: public TTestBase {
    UNIT_TEST_SUITE(TDescriptiveStatisticsTest);
    UNIT_TEST(TestMissingMean)
    UNIT_TEST(TestMissingVariance)
    UNIT_TEST(TestAppendAndMerge)
    UNIT_TEST(TestSerialization)
    UNIT_TEST_SUITE_END();

private:
    inline void TestMissingMean() {
        TDescriptiveStatistics stat;
        UNIT_ASSERT_VALUES_EQUAL(stat.Mean(), Nothing());
        UNIT_ASSERT_VALUES_EQUAL(stat.Variance(), Nothing());
    }

    inline void TestMissingVariance() {
        TDescriptiveStatistics stat;
        stat.Append(1.0);
        UNIT_ASSERT_VALUES_EQUAL(*stat.Mean(), 1.0);
        UNIT_ASSERT_VALUES_EQUAL(stat.Variance(), Nothing());
    }

    inline void TestAppendAndMerge() {
        TDescriptiveStatistics overall;
        TDescriptiveStatistics first;
        TDescriptiveStatistics second;

        std::vector<double> valueList;

        std::size_t count = 0;
        double mean = 0;
        for (const auto index : xrange(1000)) {
            const auto value(RandomNumber<double>());
            valueList.push_back(value);

            mean += (value - mean) / (count + 1);
            count++;

            overall.Append(value);
            if (index % 2) {
                first.Append(value);
            } else {
                second.Append(value);
            }
        }

        double delta = 0;
        for (const auto value : valueList) {
            delta += (value - mean) * (value - mean);
        }
        double variance = delta / (count - 1);

        first.Merge(second);
        second = first;

        UNIT_ASSERT_DOUBLES_EQUAL(mean, *overall.Mean(), FLT_EPSILON);
        UNIT_ASSERT_DOUBLES_EQUAL(mean, *first.Mean(), FLT_EPSILON);
        UNIT_ASSERT_DOUBLES_EQUAL(mean, *second.Mean(), FLT_EPSILON);

        UNIT_ASSERT_DOUBLES_EQUAL(variance, *overall.Variance(), FLT_EPSILON);
        UNIT_ASSERT_DOUBLES_EQUAL(variance, *first.Variance(), FLT_EPSILON);
        UNIT_ASSERT_DOUBLES_EQUAL(variance, *second.Variance(), FLT_EPSILON);
    }

    inline void TestSerialization() {
        TDescriptiveStatistics stat;
        stat.Append(1.0);
        stat.Append(2.0);
        stat.Append(3.0);

        flatbuffers::FlatBufferBuilder builder;
        builder.Finish(stat.ToProto(builder));

        const NCommon::TDescriptiveStatistics& dumped(
            *flatbuffers::GetRoot<NCommon::TDescriptiveStatistics>(builder.GetBufferPointer()));
        TDescriptiveStatistics loaded(dumped);

        UNIT_ASSERT_DOUBLES_EQUAL(2.0, *loaded.Mean(), FLT_EPSILON);
        UNIT_ASSERT_DOUBLES_EQUAL(1.0, *loaded.Variance(), FLT_EPSILON);
    }
};

UNIT_TEST_SUITE_REGISTRATION(TDescriptiveStatisticsTest);
