#include "accumulators.h"

#include <library/cpp/testing/unittest/registar.h>

using namespace NZoom::NAccumulators;
using namespace NZoom::NValue;
using namespace NZoom::NHgram;

Y_UNIT_TEST_SUITE(TZoomTAccumulatorsTest) {

    // Average accumulator: GetValue() is expected to expose a two-component
    // value (presumably running sum and sample count) after every step.
    Y_UNIT_TEST(AverageAccumulator) {
        TAccumulator average(EAccumulatorType::Average);
        // Freshly constructed: both components are zero.
        UNIT_ASSERT_VALUES_EQUAL(TValue(average.GetValue()), TValue(0.0, 0));

        // A default-constructed TValue must not change the state.
        average.Mul(TValue());
        UNIT_ASSERT_VALUES_EQUAL(TValue(average.GetValue()), TValue(0.0, 0));

        // A single number is expected to add 1.0 to the first component and 1 to the second.
        average.Mul(TValue(1.0));
        UNIT_ASSERT_VALUES_EQUAL(TValue(average.GetValue()), TValue(1.0, 1));

        // Three list items: first component grows by 6, second by 3.
        average.Mul(TValue(TVector<double>{1.0, 2.0, 3.0}));
        UNIT_ASSERT_VALUES_EQUAL(TValue(average.GetValue()), TValue(7.0, 4));

        // A two-argument TValue is expected to be added component-wise.
        average.Mul(TValue(4.0, 5));
        UNIT_ASSERT_VALUES_EQUAL(TValue(average.GetValue()), TValue(11.0, 9));

        // Feeding the accumulator its own current value doubles both components.
        average.Mul(average.GetValue());
        UNIT_ASSERT_VALUES_EQUAL(TValue(average.GetValue()), TValue(22.0, 18));

        // Clean() must bring the accumulator back to its initial value.
        average.Clean();
        UNIT_ASSERT_VALUES_EQUAL(TValue(average.GetValue()), TValue(0.0, 0));
    }

    // Avg and List accumulators are driven through the same scenario and are
    // expected to report a plain averaged number at every step.
    Y_UNIT_TEST(RollupAverageAccumulator) {
        const EAccumulatorType rollupKinds[] = {EAccumulatorType::Avg, EAccumulatorType::List};

        for (const EAccumulatorType kind : rollupKinds) {
            TAccumulator rollup(kind);
            // Nothing accumulated yet: the value equals a default-constructed TValue.
            UNIT_ASSERT_VALUES_EQUAL(TValue(rollup.GetValue()), TValue());

            // Neither a default TValue nor an empty list may change the state.
            rollup.Mul(TValue());
            UNIT_ASSERT_VALUES_EQUAL(TValue(rollup.GetValue()), TValue());

            rollup.Mul(TValue(TVector<double>{}));
            UNIT_ASSERT_VALUES_EQUAL(TValue(rollup.GetValue()), TValue());

            // One sample: the average is the sample itself.
            rollup.Mul(TValue(1.0));
            UNIT_ASSERT_VALUES_EQUAL(TValue(rollup.GetValue()), TValue(1.0));

            // Samples so far are 1, 1, 2, 3, so the expected average is 7 / 4.
            rollup.Mul(TValue(TVector<double>{1.0, 2.0, 3.0}));
            UNIT_ASSERT_VALUES_EQUAL(TValue(rollup.GetValue()), TValue(1.75));

            // The two-argument TValue is expected to bring the totals to 11 over 9.
            rollup.Mul(TValue(4.0, 5));
            UNIT_ASSERT_VALUES_EQUAL(TValue(rollup.GetValue()), TValue(11.0 / 9.0));

            // The accumulator's own output is expected to count as one more sample.
            rollup.Mul(rollup.GetValue());
            UNIT_ASSERT_VALUES_EQUAL(TValue(rollup.GetValue()), TValue((11.0 + 11.0 / 9.0) / 10.0));

            // Clean() must bring the accumulator back to the empty value.
            rollup.Clean();
            UNIT_ASSERT_VALUES_EQUAL(TValue(rollup.GetValue()), TValue());
        }
    }

    // Hgram accumulator: every input kind is expected to end up in a small histogram.
    Y_UNIT_TEST(HgramAccumulator) {
        TAccumulator hgram(EAccumulatorType::Hgram);
        // Initial value is the default histogram and reports the SMALL_HGRAM type.
        UNIT_ASSERT_VALUES_EQUAL(TValue(hgram.GetValue()), TValue(THgram::Default()));
        UNIT_ASSERT_VALUES_EQUAL(hgram.GetValue().GetType(), EValueType::SMALL_HGRAM);

        // A zero sample: expected result has no listed values and 1 as the second
        // argument of THgram::Small (presumably the zero counter).
        hgram.Mul(TValue(0.0));
        UNIT_ASSERT_VALUES_EQUAL(TValue(hgram.GetValue()), TValue(THgram::Small({}, 1)));

        // List items are expected to be appended as individual values.
        hgram.Mul(TValue(TVector<double>{1.0, 2.0, 3.0}));
        UNIT_ASSERT_VALUES_EQUAL(TValue(hgram.GetValue()), TValue(THgram::Small({1.0, 2.0, 3.0}, 1)));

        // TValue(4.0, 5) is expected to contribute the single value 0.8.
        hgram.Mul(TValue(4.0, 5));
        UNIT_ASSERT_VALUES_EQUAL(TValue(hgram.GetValue()), TValue(THgram::Small({1.0, 2.0, 3.0, 0.8}, 1)));

        // Another small histogram: values are appended and the second arguments add up.
        hgram.Mul(TValue(THgram::Small({1.0, 2.0, 3.0}, 1)));
        UNIT_ASSERT_VALUES_EQUAL(TValue(hgram.GetValue()), TValue(THgram::Small({1.0, 2.0, 3.0, 0.8, 1.0, 2.0, 3.0}, 2)));

        // Clean() must restore the default histogram.
        hgram.Clean();
        UNIT_ASSERT_VALUES_EQUAL(TValue(hgram.GetValue()), TValue(THgram::Default()));
    }

    // Last accumulator: the reported value is expected to follow the most recent input.
    Y_UNIT_TEST(LastAccumulator) {
        TAccumulator last(EAccumulatorType::Last);
        // Nothing accumulated yet: the value equals a default-constructed TValue.
        UNIT_ASSERT_VALUES_EQUAL(TValue(last.GetValue()), TValue());

        last.Mul(TValue(0.0));
        UNIT_ASSERT_VALUES_EQUAL(TValue(last.GetValue()), TValue(0.0));

        // The list {1, 2, 3} is expected to be reported as 2.0.
        last.Mul(TValue(TVector<double>{1.0, 2.0, 3.0}));
        UNIT_ASSERT_VALUES_EQUAL(TValue(last.GetValue()), TValue(2.0));

        // TValue(4.0, 5) is expected to be reported as 0.8.
        last.Mul(TValue(4.0, 5));
        UNIT_ASSERT_VALUES_EQUAL(TValue(last.GetValue()), TValue(0.8));

        // Clean() must bring the accumulator back to the empty value.
        last.Clean();
        UNIT_ASSERT_VALUES_EQUAL(TValue(last.GetValue()), TValue());
    }

    // Max accumulator: the reported value must never drop when smaller inputs arrive.
    Y_UNIT_TEST(MaxAccumulator) {
        TAccumulator maximum(EAccumulatorType::Max);
        // Nothing accumulated yet: the value equals a default-constructed TValue.
        UNIT_ASSERT_VALUES_EQUAL(TValue(maximum.GetValue()), TValue());

        maximum.Mul(TValue(0.0));
        UNIT_ASSERT_VALUES_EQUAL(TValue(maximum.GetValue()), TValue(0.0));

        // The list {1, 2, 3} is expected to raise the maximum to 2.0.
        maximum.Mul(TValue(TVector<double>{1.0, 2.0, 3.0}));
        UNIT_ASSERT_VALUES_EQUAL(TValue(maximum.GetValue()), TValue(2.0));

        // The list {1, 1, 1} must not lower the maximum.
        maximum.Mul(TValue(TVector<double>{1.0, 1.0, 1.0}));
        UNIT_ASSERT_VALUES_EQUAL(TValue(maximum.GetValue()), TValue(2.0));

        // TValue(4.0, 5) must not change the maximum either.
        maximum.Mul(TValue(4.0, 5));
        UNIT_ASSERT_VALUES_EQUAL(TValue(maximum.GetValue()), TValue(2.0));

        // Clean() must bring the accumulator back to the empty value.
        maximum.Clean();
        UNIT_ASSERT_VALUES_EQUAL(TValue(maximum.GetValue()), TValue());
    }

    // Min accumulator: the reported value must never rise when larger inputs arrive.
    Y_UNIT_TEST(MinAccumulator) {
        TAccumulator minimum(EAccumulatorType::Min);
        // Nothing accumulated yet: the value equals a default-constructed TValue.
        UNIT_ASSERT_VALUES_EQUAL(TValue(minimum.GetValue()), TValue());

        minimum.Mul(TValue(10.0));
        UNIT_ASSERT_VALUES_EQUAL(TValue(minimum.GetValue()), TValue(10.0));

        // The list {1, 2, 3} is expected to lower the minimum to 2.0.
        minimum.Mul(TValue(TVector<double>{1.0, 2.0, 3.0}));
        UNIT_ASSERT_VALUES_EQUAL(TValue(minimum.GetValue()), TValue(2.0));

        // The list {1, 1, 1} is expected to lower it further to 1.0.
        minimum.Mul(TValue(TVector<double>{1.0, 1.0, 1.0}));
        UNIT_ASSERT_VALUES_EQUAL(TValue(minimum.GetValue()), TValue(1.0));

        // Repeating {1, 2, 3} must not raise the minimum.
        minimum.Mul(TValue(TVector<double>{1.0, 2.0, 3.0}));
        UNIT_ASSERT_VALUES_EQUAL(TValue(minimum.GetValue()), TValue(1.0));

        // TValue(4.0, 5) is expected to lower the minimum to 0.8.
        minimum.Mul(TValue(4.0, 5));
        UNIT_ASSERT_VALUES_EQUAL(TValue(minimum.GetValue()), TValue(0.8));

        // Clean() must bring the accumulator back to the empty value.
        minimum.Clean();
        UNIT_ASSERT_VALUES_EQUAL(TValue(minimum.GetValue()), TValue());
    }

    // Summ accumulator: starts at 0.0 and is expected to grow with every input.
    Y_UNIT_TEST(SummAccumulator) {
        TAccumulator total(EAccumulatorType::Summ);
        // The initial value is the number 0.0, not a default-constructed TValue.
        UNIT_ASSERT_VALUES_EQUAL(TValue(total.GetValue()), TValue(0.0));

        total.Mul(TValue(10.0));
        UNIT_ASSERT_VALUES_EQUAL(TValue(total.GetValue()), TValue(10.0));

        // The list {1, 2, 3} is expected to add 2.0.
        total.Mul(TValue(TVector<double>{1.0, 2.0, 3.0}));
        UNIT_ASSERT_VALUES_EQUAL(TValue(total.GetValue()), TValue(12.0));

        // The list {1, 1, 1} is expected to add 1.0.
        total.Mul(TValue(TVector<double>{1.0, 1.0, 1.0}));
        UNIT_ASSERT_VALUES_EQUAL(TValue(total.GetValue()), TValue(13.0));

        total.Mul(TValue(TVector<double>{1.0, 2.0, 3.0}));
        UNIT_ASSERT_VALUES_EQUAL(TValue(total.GetValue()), TValue(15.0));

        // TValue(4.0, 5) is expected to add 0.8.
        total.Mul(TValue(4.0, 5));
        UNIT_ASSERT_VALUES_EQUAL(TValue(total.GetValue()), TValue(15.8));

        // Clean() must reset the sum to 0.0.
        total.Clean();
        UNIT_ASSERT_VALUES_EQUAL(TValue(total.GetValue()), TValue(0.0));
    }

    // SummNone accumulator: same inputs as the Summ scenario, but the empty state
    // is expected to equal a default-constructed TValue instead of 0.0.
    Y_UNIT_TEST(SummNoneAccumulator) {
        TAccumulator total(EAccumulatorType::SummNone);
        // Nothing accumulated yet: the value equals a default-constructed TValue.
        UNIT_ASSERT_VALUES_EQUAL(TValue(total.GetValue()), TValue());

        total.Mul(TValue(10.0));
        UNIT_ASSERT_VALUES_EQUAL(TValue(total.GetValue()), TValue(10.0));

        // The list {1, 2, 3} is expected to add 2.0.
        total.Mul(TValue(TVector<double>{1.0, 2.0, 3.0}));
        UNIT_ASSERT_VALUES_EQUAL(TValue(total.GetValue()), TValue(12.0));

        // The list {1, 1, 1} is expected to add 1.0.
        total.Mul(TValue(TVector<double>{1.0, 1.0, 1.0}));
        UNIT_ASSERT_VALUES_EQUAL(TValue(total.GetValue()), TValue(13.0));

        total.Mul(TValue(TVector<double>{1.0, 2.0, 3.0}));
        UNIT_ASSERT_VALUES_EQUAL(TValue(total.GetValue()), TValue(15.0));

        // TValue(4.0, 5) is expected to add 0.8.
        total.Mul(TValue(4.0, 5));
        UNIT_ASSERT_VALUES_EQUAL(TValue(total.GetValue()), TValue(15.8));

        // Clean() must bring the accumulator back to the empty value.
        total.Clean();
        UNIT_ASSERT_VALUES_EQUAL(TValue(total.GetValue()), TValue());
    }

    // Hgram accumulator constructed with `false` as the second argument: it is
    // expected to report NORMAL_HGRAM values even when fed a small histogram.
    Y_UNIT_TEST(LegacyTypesConvertingAccumulator) {
        TAccumulator legacy(EAccumulatorType::Hgram, false);
        // The initial value matches THgram::Default(false) and is already NORMAL_HGRAM.
        UNIT_ASSERT_VALUES_EQUAL(TValue(legacy.GetValue()), TValue(THgram::Default(false)));
        UNIT_ASSERT_VALUES_EQUAL(legacy.GetValue().GetType(), EValueType::NORMAL_HGRAM);

        // A small-histogram input must come out as the equivalent normal histogram.
        legacy.Mul(TValue(THgram::Small({3.0}, 1)));
        UNIT_ASSERT_VALUES_EQUAL(TValue(legacy.GetValue()), TValue(THgram::Normal({0, 0, 1}, 1, 0)));
        UNIT_ASSERT_VALUES_EQUAL(legacy.GetValue().GetType(), EValueType::NORMAL_HGRAM);
    }

    // Basic life cycle of TCompactAccumulatorsArray: empty -> one value -> cleaned.
    Y_UNIT_TEST(CompactAccumulators) {
        TCompactAccumulatorsArray compact(EAccumulatorType::Summ, 1);
        // Right after construction the array reports itself as empty with length 0.
        UNIT_ASSERT(compact.Empty());
        UNIT_ASSERT_VALUES_EQUAL(compact.Len(), 0);

        // A single Mul at index 0 makes the slot readable and the array non-empty.
        compact.Mul(TValue(1.0), 0);
        UNIT_ASSERT_VALUES_EQUAL(TValue(compact.GetValue(0)), TValue(1.0));
        UNIT_ASSERT(!compact.Empty());
        UNIT_ASSERT_VALUES_EQUAL(compact.Len(), 1);

        // After Clean() the array is empty again and reading index 0 must throw.
        compact.Clean();
        UNIT_ASSERT(compact.Empty());
        UNIT_ASSERT_EXCEPTION(compact.GetValue(0), yexception);
    }

    // Merge() on TCompactAccumulatorsArray: merging 1.0 twice into index 0 is
    // expected to yield 1.0, not 2.0.
    Y_UNIT_TEST(CompactAccumulatorsMerge) {
        TCompactAccumulatorsArray compact(EAccumulatorType::Summ, 1);
        // Right after construction the array reports itself as empty with length 0.
        UNIT_ASSERT(compact.Empty());
        UNIT_ASSERT_VALUES_EQUAL(compact.Len(), 0);

        compact.Merge(TValue(1.0), 0);
        compact.Merge(TValue(1.0), 0);

        // One occupied slot holding 1.0.
        UNIT_ASSERT_VALUES_EQUAL(TValue(compact.GetValue(0)), TValue(1.0));
        UNIT_ASSERT(!compact.Empty());
        UNIT_ASSERT_VALUES_EQUAL(compact.Len(), 1);

        // After Clean() the array is empty again and reading index 0 must throw.
        compact.Clean();
        UNIT_ASSERT(compact.Empty());
        UNIT_ASSERT_EXCEPTION(compact.GetValue(0), yexception);
    }

    // Mul() on TCompactAccumulatorsArray: two Mul calls with 1.0 at index 0 are
    // expected to accumulate to 2.0 for the Summ type.
    Y_UNIT_TEST(CompactAccumulatorsMul) {
        TCompactAccumulatorsArray compact(EAccumulatorType::Summ, 1);
        // Right after construction the array reports itself as empty with length 0.
        UNIT_ASSERT(compact.Empty());
        UNIT_ASSERT_VALUES_EQUAL(compact.Len(), 0);

        compact.Mul(TValue(1.0), 0);
        compact.Mul(TValue(1.0), 0);

        // One occupied slot holding the sum of both inputs.
        UNIT_ASSERT_VALUES_EQUAL(TValue(compact.GetValue(0)), TValue(2.0));
        UNIT_ASSERT(!compact.Empty());
        UNIT_ASSERT_VALUES_EQUAL(compact.Len(), 1);

        // After Clean() the array is empty again and reading index 0 must throw.
        compact.Clean();
        UNIT_ASSERT(compact.Empty());
        UNIT_ASSERT_EXCEPTION(compact.GetValue(0), yexception);
    }

}
