#include "function.h"

#include <library/cpp/testing/unittest/registar.h>
#include <util/generic/algorithm.h>
#include <util/string/join.h>

namespace NRTYFeatures {
    namespace {
        // Returns the keys of `collection` sorted and joined with single spaces,
        // so assertions on the set of imported names do not depend on hash-map
        // iteration order.
        template <typename TItem>
        TString PrintKeys(const THashMap<TString, TItem>& collection) {
            TVector<TString> names;
            names.reserve(collection.size());
            for (const auto& kv : collection) {
                names.push_back(kv.first);
            }
            Sort(names);
            return JoinSeq(" ", names);
        }
    }

    // Checks that TImportedFunctionsBuilder imports only the functions that were
    // requested via Request(), and reports requested names no component provided.
    class TImportedFunctionsTest: public NUnitTest::TTestBase {
        UNIT_TEST_SUITE(TImportedFunctionsTest)
        UNIT_TEST(TestImport);
        UNIT_TEST_SUITE_END();

    private:
        // Fake component exposing one function of each kind the builder accepts:
        // component GTAs (AddGta), user functions with/without a context, and a
        // layered function (Add<...>).
        class TMockComponent {
        public:
            // Offset used by GetAnother; the test assigns it before use.
            ui32 Vl = 0;

            // Always returns the smallest i32, so the test can check how that value is rendered.
            i32 GetSomething(ui32) {
                return std::numeric_limits<i32>::min();
            }

            // Returns Vl + docId - 2 (unsigned arithmetic).
            ui32 GetAnother(ui32 docId) {
                return Vl + docId - 2;
            }

            // Returns 10 plus the first argument, or 10 when there are no arguments.
            float CalcMockFunction(ui32, TArrayRef<const float> args) {
                return 10 + (args.empty() ? 0.f : args.front());
            }

            // Returns 17.5 plus the first argument (or 17.5 with no arguments); the context is ignored.
            float CalcMockFunctionEx(ui32, TArrayRef<const float> args, const TRTYFunctionCtx&) {
                return 17.5f + (args.empty() ? 0.f : args.front());
            }

            // Returns docId + subId + number of arguments; the context is ignored.
            float CalcLayered(ui32 docId, i64 subId, TArrayRef<const float> args, const TRTYFunctionCtx&) {
                // Sum in signed 64-bit: adding the unsigned size_t from args.size() to a
                // negative i64 partial sum would otherwise wrap to a huge unsigned value.
                const i64 sum = static_cast<i64>(docId) + subId + static_cast<i64>(args.size());
                return static_cast<float>(sum);
            }

            // Offers every mock function to the builder under a fixed name; the test
            // below shows that only the names requested on the builder get imported.
            void Register(TImportedFunctionsBuilder& b) {
                using T = TMockComponent;

                b.AddGta(&T::GetSomething, this, "GetSomething");
                b.AddGta(&T::GetAnother, this, "GetAnother");
                // NOTE(review): the trailing `2` is presumably the declared argument count — confirm
                // against TImportedFunctionsBuilder::Add.
                b.Add<TFactorCalcerUserFunc>(&T::CalcMockFunction, this, "UserFoo", 2);
                b.Add<TFactorCalcerUserFuncEx>(&T::CalcMockFunctionEx, this, "UserFooEx", 2);
                b.Add<TFactorCalcerLayeredFunc>(&T::CalcLayered, this, "StreamFoo", 2);
            }
        };

    public:
        void TestImport() {
            TMockComponent x;
            x.Vl = 41;

            TImportedFunctions coll;

            // Phase 1: request a subset of the mock's functions ("GetSomething" is deliberately omitted).
            TImportedFunctionsBuilder b(coll);
            b.Request("GetAnother");
            b.Request("UserFoo");
            b.Request("UserFooEx");
            b.Request("StreamFoo");
            x.Register(b);

            TRTYFunctionCtx uctx = Default<TRTYFunctionCtx>();

            // `imported` is a reference into `coll`, so it also observes phase 2 additions.
            const auto& imported = coll.GetFunctions();
            UNIT_ASSERT_VALUES_EQUAL(2.5f, imported.UserFuncs.at("UserFoo")(/*docId=*/532312, {-7.5f, 0.0f}, uctx));      // 10 - 7.5
            UNIT_ASSERT_VALUES_EQUAL(10.0f, imported.UserFuncs.at("UserFooEx")(/*docId=*/532312, {-7.5f, 0.0f}, uctx));   // 17.5 - 7.5
            UNIT_ASSERT_VALUES_EQUAL("42", imported.ComponentGtas.at("GetAnother")(3));                                    // 41 + 3 - 2
            UNIT_ASSERT_VALUES_EQUAL("142", imported.ComponentGtas.at("GetAnother")(103));                                 // 41 + 103 - 2
            UNIT_ASSERT_VALUES_EQUAL(9.0f, imported.LayeredFuncs.at("StreamFoo")(/*docId=*/5, /*subId=*/2, {0.5f, 0.0f}, uctx)); // 5 + 2 + 2 args
            // Not requested, hence not imported: lookup must throw.
            UNIT_ASSERT_EXCEPTION(imported.ComponentGtas.at("GetSomething")(1), yexception);
            UNIT_ASSERT_VALUES_EQUAL("GetAnother", PrintKeys(imported.ComponentGtas));
            UNIT_ASSERT_VALUES_EQUAL("UserFoo UserFooEx", PrintKeys(imported.UserFuncs));
            UNIT_ASSERT_VALUES_EQUAL("StreamFoo", PrintKeys(imported.LayeredFuncs));
            UNIT_ASSERT_VALUES_EQUAL("", JoinSeq(" ", b.GetMissing()));

            // Phase 2: a second builder over the same collection adds to it without dropping earlier imports.
            TImportedFunctionsBuilder b2(coll); // different builder - does not matter - might have been the same
            b2.Request("GetSomething");
            b2.Request("NonExistent");
            x.Register(b2);
            UNIT_ASSERT_VALUES_EQUAL("NonExistent", JoinSeq(" ", b2.GetMissing()));
            UNIT_ASSERT_VALUES_EQUAL("-2147483648", imported.ComponentGtas.at("GetSomething")(1));
            UNIT_ASSERT_VALUES_EQUAL("GetAnother GetSomething", PrintKeys(imported.ComponentGtas));
            UNIT_ASSERT_VALUES_EQUAL("UserFoo UserFooEx", PrintKeys(imported.UserFuncs));
        }
    };
}

// Registers the suite with the library/cpp/testing/unittest runner so TestImport is executed.
UNIT_TEST_SUITE_REGISTRATION(NRTYFeatures::TImportedFunctionsTest);
