#include "calcer.h"
#include "model.h"

#include <drive/backend/proto/models.pb.h>

#include <library/cpp/json/json_reader.h>
#include <library/cpp/logger/global/global.h>
#include <library/cpp/testing/unittest/registar.h>
#include <library/cpp/testing/unittest/tests_data.h>

#include <rtline/library/json/builder.h>
#include <rtline/library/json/cast.h>
#include <rtline/util/algorithm/ptr.h>

#include <util/stream/file.h>

#include <cmath>

Y_UNIT_TEST_SUITE(OfferModelsSuite) {
    // JSON description of the model built by the SupplyDemandHistory test.
    // The test adds a "data" field (read from a file) before constructing the model.
    TString SupplyDemandHistoryModelDescription = R"(
{
    "name": "fakenews42",
    "type": "supply_demand_history",
    "polynom": "100E1003040000000SF0"
}
    )";

    // JSON feature fixture that the tests parse into NDrive::TOfferFeatures with
    // NJson::FromJson. It is an outer array holding an array of numbers followed
    // by an array of strings. The first number (9.28) is the value the
    // SupplyDemandHistory test reads back as FI_PRICE.
    TString Features = R"(
[
    [
        9.28,
        0,
        2.330200260027019,
        2.85,
        0,
        1.3480731482996928,
        55.65922546,
        37.53060532,
        3,
        0.7451041666666666,
        0.74,
        469,
        13.058360179622305,
        0,
        0,
        1,
        1,
        0,
        0
    ],
    [
        "kia_rio_xline"
    ]
]
    )";

    // Constructs a "supply_demand_history" model from its JSON description plus the
    // payload stored in <work path>/sdh/model.bin, verifies that its protobuf form
    // can be turned back into a model, and checks the value computed for Features.
    Y_UNIT_TEST(SupplyDemandHistory) {
        NJson::TJsonValue modelJson = NJson::ReadJsonFastTree(SupplyDemandHistoryModelDescription);
        {
            // The payload goes under the "data" key before the model is constructed.
            auto modelPath = TFsPath(GetWorkPath()) / "sdh" / "model.bin";
            TIFStream modelStream(modelPath);
            modelJson["data"] = modelStream.ReadAll();
        }

        auto model = NDrive::IOfferModel::Construct(modelJson);
        UNIT_ASSERT(model);
        UNIT_ASSERT_VALUES_EQUAL(model->GetName(), "fakenews42");
        UNIT_ASSERT_VALUES_EQUAL(model->GetType(), NDrive::TSupplyDemandHistoryModel::Type());

        // Round trip through the protobuf representation; only constructibility is checked.
        auto modelProto = model->Serialize<NDrive::NProto::TOfferModel>();
        auto restoredModel = NDrive::IOfferModel::Construct(modelProto);
        UNIT_ASSERT(restoredModel);

        auto parsedFeatures = NJson::ReadJsonFastTree(Features);
        auto offerFeatures = NJson::FromJson<NDrive::TOfferFeatures>(parsedFeatures);

        // Sanity check of the fixture itself before applying the model.
        UNIT_ASSERT_DOUBLES_EQUAL(offerFeatures.Floats[NDriveOfferFactors::FI_PRICE], 9.28, 0.001);
        auto modelValue = model->Calc(offerFeatures);
        UNIT_ASSERT_DOUBLES_EQUAL(modelValue, 27.365, 0.01);
    }

    // JSON description of the model built by the DRIVEANALYTICS_72 test.
    // The test adds a "data" field (read from a file) before constructing the model.
    TString DRIVEANALYTICS_72ModelDescription = R"(
{
    "name": "fakenews42",
    "type": "supply_demand_history",
    "polynom": "100I1020000000U7"
}
    )";

    // Applies the model described by DRIVEANALYTICS_72ModelDescription to features
    // built from a fixed price, location and timestamp, and expects the result
    // to be 10 (the same number as the price set below).
    Y_UNIT_TEST(DRIVEANALYTICS_72) {
        NJson::TJsonValue modelJson = NJson::ReadJsonFastTree(DRIVEANALYTICS_72ModelDescription);
        {
            auto modelPath = TFsPath(GetWorkPath()) / "sdh" / "model.bin";
            TIFStream modelStream(modelPath);
            modelJson["data"] = modelStream.ReadAll();
        }
        auto model = NDrive::IOfferModel::Construct(modelJson);
        UNIT_ASSERT(model);

        const double latitude = 55.781154;
        const double longitude = 37.465340;

        /*
        Disabled computation, kept from the original test for context:

        double dLat = 9.0 * 1000 / 1e6;
        double dLon = dLat / std::cos(55.45 * M_PI / 180);

        double zLat = 55.14;
        double zLon = 36.58;

        auto i = std::floor((latitude - zLat) / dLat);
        auto j = std::floor((longitude - zLon) / dLon);

        auto id = 1 * 1024 * 1024 + i * 1024 + j;
        */

        NDrive::TOfferFeatures offerFeatures;
        offerFeatures.Floats[NDriveOfferFactors::FI_PRICE] = 10;
        offerFeatures.Floats[NDriveOfferFactors::FI_LATITUDE] = latitude;
        offerFeatures.Floats[NDriveOfferFactors::FI_LONGITUDE] = longitude;
        NDrive::CalcRequestFeatures(offerFeatures, TInstant::Seconds(1550640824));

        /*
        Disabled inspection of the request features, kept for context:

        auto weekday = offerFeatures.Floats[NDriveOfferFactors::FI_DAY_OF_THE_WEEK];
        auto dayMinutes = TDuration::Seconds(offerFeatures.Floats[NDriveOfferFactors::FI_TIME_OF_THE_DAY] * 24 * 60 * 60).Minutes();
        */

        auto modelValue = model->Calc(offerFeatures);

        UNIT_ASSERT_DOUBLES_EQUAL(modelValue, 10, 0.01);
    }

    // JSON description of the model built by the VariationOptimizer test.
    // The test appends two payloads (read from files) to a "data" array before
    // constructing the model.
    TString VariationOptimizerModelDescription = R"(
{
    "name": "fakenews42",
    "type": "variation_optimizer",
    "polynom": "10091020000000UN0000I03"
}
    )";

    // Constructs a "variation_optimizer" model whose "data" array holds the same
    // <work path>/vo/model.cbm payload twice, checks the protobuf round trip, and
    // verifies both the computed value and the MatrixNet-related floats that are
    // non-zero in the features after Calc.
    Y_UNIT_TEST(VariationOptimizer) {
        NJson::TJsonValue modelJson = NJson::ReadJsonFastTree(VariationOptimizerModelDescription);
        // The same file is read and appended two times, once per iteration.
        for (size_t payloadIndex = 0; payloadIndex < 2; ++payloadIndex) {
            auto payloadPath = TFsPath(GetWorkPath()) / "vo" / "model.cbm";
            TIFStream payloadStream(payloadPath);
            modelJson["data"].AppendValue(payloadStream.ReadAll());
        }

        auto model = NDrive::IOfferModel::Construct(modelJson);
        UNIT_ASSERT(model);
        UNIT_ASSERT_VALUES_EQUAL(model->GetName(), "fakenews42");
        UNIT_ASSERT_VALUES_EQUAL(model->GetType(), NDrive::TVariationOptimizerModel::Type());

        // Round trip through the protobuf representation; only constructibility is checked.
        auto modelProto = model->Serialize<NDrive::NProto::TOfferModel>();
        auto restoredModel = NDrive::IOfferModel::Construct(modelProto);
        UNIT_ASSERT(restoredModel);

        auto parsedFeatures = NJson::ReadJsonFastTree(Features);
        auto offerFeatures = NJson::FromJson<NDrive::TOfferFeatures>(parsedFeatures);

        auto modelValue = model->Calc(offerFeatures);
        UNIT_ASSERT_DOUBLES_EQUAL(modelValue, 9.38, 0.01);
        // After Calc these three floats must be clearly away from zero.
        UNIT_ASSERT(std::abs(offerFeatures.Floats[NDriveOfferFactors::FI_MATRIXNET]) > 0.001);
        UNIT_ASSERT(std::abs(offerFeatures.Floats[NDriveOfferFactors::FI_MATRIXNET_2]) > 0.001);
        UNIT_ASSERT(std::abs(offerFeatures.Floats[NDriveOfferFactors::FI_MATRIXNET_COMPOSITE]) > 0.001);
    }

    // JSON description of the model built by the CatboostModel test.
    // The test adds a "data" field (read from a file) before constructing the model.
    TString CatboostModelDescription = R"(
{
    "name": "fakenews42",
    "type": "catboost",
    "polynom": "10091020000000UN0000I03"
}
    )";

    // Constructs a "catboost" model from its JSON description plus the payload in
    // <work path>/vo/model.cbm, checks the protobuf round trip, and verifies the
    // value computed for Features.
    Y_UNIT_TEST(CatboostModel) {
        NJson::TJsonValue modelJson = NJson::ReadJsonFastTree(CatboostModelDescription);
        {
            auto payloadPath = TFsPath(GetWorkPath()) / "vo" / "model.cbm";
            TIFStream payloadStream(payloadPath);
            modelJson["data"] = payloadStream.ReadAll();
        }

        auto model = NDrive::IOfferModel::Construct(modelJson);
        UNIT_ASSERT(model);
        UNIT_ASSERT_VALUES_EQUAL(model->GetName(), "fakenews42");
        UNIT_ASSERT_VALUES_EQUAL(model->GetType(), NDrive::TCatboostModel::Type());

        // Round trip through the protobuf representation; only constructibility is checked.
        auto modelProto = model->Serialize<NDrive::NProto::TOfferModel>();
        auto restoredModel = NDrive::IOfferModel::Construct(modelProto);
        UNIT_ASSERT(restoredModel);

        auto parsedFeatures = NJson::ReadJsonFastTree(Features);
        auto offerFeatures = NJson::FromJson<NDrive::TOfferFeatures>(parsedFeatures);

        auto modelValue = model->Calc(offerFeatures);
        UNIT_ASSERT_DOUBLES_EQUAL(modelValue, 5.129, 0.01);
    }

    // JSON description of the "DA-256" model built by the DRIVEANALYTICS_256 test.
    // The test appends two payloads (read from files) to a "data" array before
    // constructing the model.
    TString DRIVEANALYTICS_256ModelDescription = R"(
{
    "name": "DA-256",
    "type": "variation_optimizer",
    "variation_absolute_min": 6,
    "variation_absolute_max": 10,
    "variation_step": 0.07,
    "polynom": "100810C4040000000SF0"
}
    )";

    // JSON feature fixture shared by several tests in this suite; it is parsed into
    // NDrive::TOfferFeatures with NJson::FromJson. It is an outer array holding an
    // array of numbers followed by an array of strings. The last string is the same
    // street name the GeoLocalModel test uses when computing its hash by hand.
    TString DRIVEANALYTICS_256Features = R"(
[
    [
        8.600000381,
        0.9555556178,
        2.261763096,
        2.799999952,
        1.076923132,
        1.335001111,
        55.66846848,
        37.43664932,
        4,
        0.4179745317,
        0.75,
        0.4930000007,
        6.202535629,
        0,
        0,
        0,
        0,
        0,
        0.00976228714,
        0,
        0,
        1,
        0,
        0,
        0,
        5000,
        0,
        0,
        0,
        0,
        0,
        0,
        8.600000381,
        0,
        0,
        0.9606191516
    ],
    [
        "kia_rio_xline",
        "n",
        "Большой Спасоглинищевский переулок"
    ]
]
    )";

    // Constructs the "DA-256" variation optimizer from two different payload files
    // and checks the value it computes for DRIVEANALYTICS_256Features.
    Y_UNIT_TEST(DRIVEANALYTICS_256) {
        NJson::TJsonValue modelJson = NJson::ReadJsonFastTree(DRIVEANALYTICS_256ModelDescription);
        {
            // First entry of the "data" array.
            auto durationPath = TFsPath(GetWorkPath()) / "da256_1" / "order_duration_20190528.cbm";
            TIFStream durationStream(durationPath);
            modelJson["data"].AppendValue(durationStream.ReadAll());
        }
        {
            // Second entry of the "data" array.
            auto acceptPath = TFsPath(GetWorkPath()) / "da256_2" / "time_till_accept_20190528.cbm";
            TIFStream acceptStream(acceptPath);
            modelJson["data"].AppendValue(acceptStream.ReadAll());
        }

        auto model = NDrive::IOfferModel::Construct(modelJson);
        UNIT_ASSERT(model);
        UNIT_ASSERT_VALUES_EQUAL(model->GetName(), "DA-256");
        UNIT_ASSERT_VALUES_EQUAL(model->GetType(), NDrive::TVariationOptimizerModel::Type());

        auto parsedFeatures = NJson::ReadJsonFastTree(DRIVEANALYTICS_256Features);
        auto offerFeatures = NJson::FromJson<NDrive::TOfferFeatures>(parsedFeatures);

        auto modelValue = model->Calc(offerFeatures);
        UNIT_ASSERT_DOUBLES_EQUAL(modelValue, 9.92, 0.01);
    }

    // Exercises two TJitterOfferModel instances on the same features: each must
    // return the same value on repeated calls, stay inside the asserted
    // price-relative bounds, and differ measurably from the asserted centre.
    Y_UNIT_TEST(JitterModel) {
        const auto parsedFeatures = NJson::ReadJsonFastTree(Features);
        const auto baseFeatures = NJson::FromJson<NDrive::TOfferFeatures>(parsedFeatures);
        const auto basePrice = baseFeatures.GetPrice();

        NDrive::TJitterOfferModel uniformModel("jitter", "jitter", 0.5, 1.5, 0, 0);
        // Calc receives a mutable copy; the parsed fixture itself stays const.
        auto workingFeatures = baseFeatures;
        auto uniformFirst = uniformModel.Calc(workingFeatures);
        auto uniformSecond = uniformModel.Calc(workingFeatures);
        UNIT_ASSERT_DOUBLES_EQUAL(uniformFirst, uniformSecond, 0.001);
        INFO_LOG << uniformFirst << Endl;
        // Result lies within [0.5, 1.5] * price and is not (almost) the price itself.
        UNIT_ASSERT(uniformFirst >= 0.5 * basePrice);
        UNIT_ASSERT(uniformFirst <= 1.5 * basePrice);
        UNIT_ASSERT(std::abs(uniformFirst / basePrice - 1) > 0.01);

        NDrive::TJitterOfferModel normalModel("normal_jitter", "jitter", 0.1, 10, 1, 5);

        auto normalFirst = normalModel.Calc(workingFeatures);
        auto normalSecond = normalModel.Calc(workingFeatures);
        UNIT_ASSERT_DOUBLES_EQUAL(normalFirst, normalSecond, 0.001);
        INFO_LOG << normalFirst << Endl;
        // Result lies within [0.1, 10] * price; its ratio to the price is within 3
        // of 5 but not (almost) exactly 5.
        UNIT_ASSERT(normalFirst >= 0.1 * basePrice);
        UNIT_ASSERT(normalFirst <= 10. * basePrice);
        UNIT_ASSERT(std::abs(normalFirst / basePrice - 5) > 0.01);
        UNIT_ASSERT(std::abs(normalFirst / basePrice - 5) < 3);
    }

    // Checks TGeoLocalModel: the hash computed from explicit arguments must match
    // the one computed from the features, and the model result must follow the
    // element registered under that hash (with and without a polynom, and after
    // the street category is replaced by an unknown one).
    Y_UNIT_TEST(GeoLocalModel) {
        auto parsedFeatures = NJson::ReadJsonFastTree(DRIVEANALYTICS_256Features);
        auto offerFeatures = NJson::FromJson<NDrive::TOfferFeatures>(parsedFeatures);

        // Arguments for the hand-made hash.
        bool baseModel = true;
        bool weekend = false;
        ui8 hourOfDay = 10;
        auto streetName = TString("Большой Спасоглинищевский переулок");

        auto hashFromArgs = NDrive::TGeoLocalModel::CalcHash(baseModel, weekend, hourOfDay, streetName);
        auto hashFromFeatures = NDrive::TGeoLocalModel::CalcHash(offerFeatures, false);
        UNIT_ASSERT_VALUES_EQUAL(hashFromArgs, hashFromFeatures);

        // One unrelated element and one keyed by the hash of the fixture.
        NDrive::TGeoLocalModel::TElements modelElements;
        modelElements.emplace_back(42, 4.2);
        modelElements.emplace_back(hashFromFeatures, 13);

        NDrive::TGeoLocalModel::TOptions modelOptions;
        modelOptions.Name = "GeoLocal";
        modelOptions.Elements = modelElements;

        // Without a polynom the expected result is 13 * 8.6.
        NDrive::TGeoLocalModel plainModel(modelOptions);
        auto plainPrice = plainModel.Calc(offerFeatures);
        UNIT_ASSERT_DOUBLES_EQUAL(plainPrice, 13 * 8.6, 0.001);

        modelOptions.Polynom = "K0010000000V3"; // 1.0*MatrixNet
        NDrive::TGeoLocalModel polynomModel(modelOptions);
        auto polynomPrice = polynomModel.Calc(offerFeatures);
        UNIT_ASSERT_DOUBLES_EQUAL(polynomPrice, 13, 0.001);

        // With a street that matches no element the expected result is 8.6.
        offerFeatures.Categories2[NDriveOfferCatFactors2::FI_GEOCODED_STREET] = "fake";
        auto unknownStreetPrice = polynomModel.Calc(offerFeatures);
        UNIT_ASSERT_DOUBLES_EQUAL(unknownStreetPrice, 8.6, 0.001);
    }

    // Checks TAdditionThresholdModel with Threshold = 0.5 and Addition = 2 on a
    // price of 10: the result is 10 for a factor value of 0.3 and 12 for 0.9.
    Y_UNIT_TEST(InvAverageSessionClicksCount) {
        NDrive::TAdditionThresholdModel::TOptions thresholdOptions;
        thresholdOptions.Name = "fakemuse";
        thresholdOptions.Threshold = 0.5;
        thresholdOptions.Addition = 2;
        thresholdOptions.PolynomData = "51010000000V3"; // 1.0*InvAverageSessionClicksCount
        NDrive::TAdditionThresholdModel thresholdModel(thresholdOptions);

        NDrive::TOfferFeatures offerFeatures;
        offerFeatures.Floats[NDriveOfferFactors::FI_PRICE] = 10;

        // Sets the factor on the shared features object and applies the model.
        auto calcForFactor = [&](double factorValue) {
            offerFeatures.Floats[NDriveOfferFactors::FI_INV_AVERAGE_SESSION_CLICKS_COUNT] = factorValue;
            return thresholdModel.Calc(offerFeatures);
        };

        auto belowThreshold = calcForFactor(0.3);
        UNIT_ASSERT_DOUBLES_EQUAL(belowThreshold, 10, 0.001);

        auto aboveThreshold = calcForFactor(0.9);
        UNIT_ASSERT_DOUBLES_EQUAL(aboveThreshold, 12, 0.001);
    }

    // Checks TRoundModel with default options: a price of 9.16 yields 8.99,
    // while a price of 9.86 is returned unchanged.
    Y_UNIT_TEST(RoundModel) {
        NDrive::TRoundModel::TOptions roundOptions;
        roundOptions.Name = "fake42";
        NDrive::TRoundModel roundModel(roundOptions);

        NDrive::TOfferFeatures offerFeatures;

        // Sets the price on the shared features object and applies the model.
        auto calcForPrice = [&](double priceValue) {
            offerFeatures.Floats[NDriveOfferFactors::FI_PRICE] = priceValue;
            return roundModel.Calc(offerFeatures);
        };

        auto roundedDown = calcForPrice(9.16);
        UNIT_ASSERT_DOUBLES_EQUAL(roundedDown, 8.99, 0.001);

        auto keptAsIs = calcForPrice(9.86);
        UNIT_ASSERT_DOUBLES_EQUAL(keptAsIs, 9.86, 0.001);
    }

    // Checks that TRandomSetModel only ever returns one of its configured values
    // for a grid of prices (0..9) and latitudes (55.4..56.1 in steps of 0.1).
    Y_UNIT_TEST(RandomSetModel) {
        NDrive::TRandomSetModel::TOptions options;
        options.Name = "booze";
        options.Values = {
            1,
            2,
            3,
            4,
            42,
        };
        NDrive::TRandomSetModel model(options);

        // The latitude grid is driven by an integer counter. Accumulating 0.1 in a
        // double loop variable makes the number of iterations depend on rounding
        // error, so the set of inputs covered by the test would not be well defined.
        const size_t latitudeSteps = 8;
        for (size_t price = 0; price < 10; ++price) {
            for (size_t step = 0; step < latitudeSteps; ++step) {
                const double lat = 55.4 + 0.1 * step;
                NDrive::TOfferFeatures features;
                features.Floats[NDriveOfferFactors::FI_PRICE] = price;
                features.Floats[NDriveOfferFactors::FI_LATITUDE] = lat;
                auto result = model.Calc(features);
                UNIT_ASSERT(options.Values.contains(result));
            }
        }
    }

    // JSON description of the model built by the TimeScheduleModel test. Unlike
    // the other descriptions in this suite it is used as is, without a "data"
    // field; its "config" holds a "time_shift" and a list of "segments", each
    // with a "start" and a "value".
    TString TimeScheduleModelDescription = R"(
{
    "name": "time_schedule",
    "type": "time_schedule",
    "config": {
        "time_shift": 10800,
        "segments": [
            {
                "start": 240,
                "value": 980
            },
            {
                "start": 66540,
                "value": 970
            },
            {
                "start": 68520,
                "value": 970
            },
            {
                "start": 198120,
                "value": 910
            },
            {
                "start": 518640,
                "value": 1000
            },
            {
                "start": 603000,
                "value": 990
            }
        ]
    }
}
    )";

    // Constructs the "time_schedule" model straight from its JSON description and
    // checks the value computed for DRIVEANALYTICS_256Features with a fixed
    // FI_TIMESTAMP.
    Y_UNIT_TEST(TimeScheduleModel) {
        NJson::TJsonValue modelJson = NJson::ReadJsonFastTree(TimeScheduleModelDescription);
        auto scheduleModel = NDrive::IOfferModel::Construct(modelJson);
        UNIT_ASSERT(scheduleModel);
        UNIT_ASSERT_VALUES_EQUAL(scheduleModel->GetName(), "time_schedule");
        UNIT_ASSERT_VALUES_EQUAL(scheduleModel->GetType(), NDrive::TTimeScheduleModel::Type());

        auto parsedFeatures = NJson::ReadJsonFastTree(DRIVEANALYTICS_256Features);
        auto offerFeatures = NJson::FromJson<NDrive::TOfferFeatures>(parsedFeatures);
        // Pin the timestamp so the expected value below is reproducible.
        offerFeatures.Floats[NDriveOfferFactors::FI_TIMESTAMP] = 1571836032;

        auto scheduledValue = scheduleModel->Calc(offerFeatures);
        UNIT_ASSERT_DOUBLES_EQUAL(scheduledValue, 9.1, 0.01);
    }

    // JSON description of the model built by the DRIVEANALYTICS_488 test, which
    // adds a "data" field (read from a file) before constructing the model.
    // The MetaMulticlassModel test also uses it, unchanged, as the description of
    // both of its submodels.
    TString DRIVEANALYTICS_488ModelDescription = R"(
{
    "name": "fakenews42",
    "type": "catboost",
    "polynom": "10081020000000U7"
}
    )";

    // Constructs a "catboost" model from <work path>/da488/distance_to_discount.cbm
    // and prints the value it computes for walking durations 100, 200, ..., 2000.
    // The computed values are only printed; nothing is asserted about them.
    Y_UNIT_TEST(DRIVEANALYTICS_488) {
        NJson::TJsonValue modelJson = NJson::ReadJsonFastTree(DRIVEANALYTICS_488ModelDescription);
        {
            auto payloadPath = TFsPath(GetWorkPath()) / "da488" / "distance_to_discount.cbm";
            TIFStream payloadStream(payloadPath);
            modelJson["data"] = payloadStream.ReadAll();
        }

        auto model = NDrive::IOfferModel::Construct(modelJson);
        UNIT_ASSERT(model);
        UNIT_ASSERT_VALUES_EQUAL(model->GetName(), "fakenews42");
        UNIT_ASSERT_VALUES_EQUAL(model->GetType(), NDrive::TCatboostModel::Type());

        NDrive::TOfferFeatures offerFeatures;
        offerFeatures.Floats[NDriveOfferFactors::FI_PRICE] = 10;
        // Twenty steps of 100, written as "<duration>\t<value>" lines.
        for (size_t walkingDuration = 100; walkingDuration <= 2000; walkingDuration += 100) {
            offerFeatures.Floats[NDriveOfferFactors::FI_WALKING_DURATION] = walkingDuration;
            auto modelValue = model->Calc(offerFeatures);
            Cout << walkingDuration << '\t' << modelValue << Endl;
        }
    }

    // JSON description of the model built by the MulticlassModelWithSoftMax test.
    // The test adds a "data" field (read from a file) before constructing the model.
    TString MulticlassModelDescription = R"(
{
    "name": "mltcls",
    "type": "catboost_multiclass",
    "soft_max": true
}
    )";

    // Constructs a "catboost_multiclass" model with "soft_max" enabled and checks
    // that Predict returns 14 values, each within [0, 1], that sum to 1.
    Y_UNIT_TEST(MulticlassModelWithSoftMax) {
        NJson::TJsonValue modelJson = NJson::ReadJsonFastTree(MulticlassModelDescription);
        {
            auto payloadPath = TFsPath(GetWorkPath()) / "mltcls" / "small_region_predictor.v1.cbm";
            TIFStream payloadStream(payloadPath);
            modelJson["data"] = payloadStream.ReadAll();
        }

        auto genericModel = NDrive::IOfferModel::Construct(modelJson);
        UNIT_ASSERT(genericModel);
        UNIT_ASSERT_VALUES_EQUAL(genericModel->GetType(), NDrive::TCatboostMulticlassModel::Type());

        // Predict is reached through the concrete multiclass type.
        auto classifier = dynamic_cast<NDrive::TCatboostMulticlassModel*>(genericModel.Get());
        UNIT_ASSERT(classifier);

        auto parsedFeatures = NJson::ReadJsonFastTree(DRIVEANALYTICS_256Features);
        auto offerFeatures = NJson::FromJson<NDrive::TOfferFeatures>(parsedFeatures);
        // Floats from index 42 onwards get deterministic values in [0, 1).
        for (size_t index = 42; index < offerFeatures.Floats.size(); ++index) {
            offerFeatures.Floats[index] = ((index * index) % 37) / 37.0;
        }

        auto classScores = classifier->Predict(offerFeatures);
        UNIT_ASSERT_VALUES_EQUAL(classScores.size(), 14);
        double scoreTotal = 0;
        for (const auto& score : classScores) {
            UNIT_ASSERT(score <= 1);
            UNIT_ASSERT(score >= 0);
            scoreTotal += score;
        }
        UNIT_ASSERT_DOUBLES_EQUAL(scoreTotal, 1, 0.001);
    }

    // Builds a TMetaMulticlassModel from two identical submodel descriptions,
    // checks that it predicts two values, and verifies that a model rebuilt from
    // its TString serialization predicts the same values.
    Y_UNIT_TEST(MetaMulticlassModel) {
        auto firstSubmodel = NJson::ReadJsonFastTree(DRIVEANALYTICS_488ModelDescription);
        auto secondSubmodel = firstSubmodel;

        NJson::TJsonValue metaJson = NJson::TMapBuilder
            ("name", "blabla42")
            ("type", NDrive::TMetaMulticlassModel::Type())
            ("submodels", NJson::TArrayBuilder(firstSubmodel)(secondSubmodel))
        ;

        auto parsedFeatures = NJson::ReadJsonFastTree(DRIVEANALYTICS_256Features);
        auto offerFeatures = NJson::FromJson<NDrive::TOfferFeatures>(parsedFeatures);

        auto metaModel = Share(NDrive::IOfferModel::Construct(metaJson));
        UNIT_ASSERT(metaModel);
        UNIT_ASSERT_VALUES_EQUAL(metaModel->GetType(), NDrive::TMetaMulticlassModel::Type());

        // Predictions of the freshly constructed model: one value per submodel description.
        auto originalPredictor = std::dynamic_pointer_cast<NDrive::IMulticlassModel>(metaModel);
        UNIT_ASSERT(originalPredictor);
        auto originalScores = originalPredictor->Predict(offerFeatures);
        UNIT_ASSERT_VALUES_EQUAL(originalScores.size(), 2);

        // Predictions of the model rebuilt from the serialized form.
        auto serializedModel = metaModel->Serialize<TString>();
        auto restoredModel = Share(NDrive::IOfferModel::Construct(serializedModel));
        UNIT_ASSERT(restoredModel);
        auto restoredPredictor = std::dynamic_pointer_cast<NDrive::IMulticlassModel>(restoredModel);
        UNIT_ASSERT(restoredPredictor);
        auto restoredScores = restoredPredictor->Predict(offerFeatures);

        UNIT_ASSERT_VALUES_EQUAL(originalScores.size(), restoredScores.size());
        UNIT_ASSERT_DOUBLES_EQUAL(originalScores[0], restoredScores[0], 0.001);
        UNIT_ASSERT_DOUBLES_EQUAL(originalScores[1], restoredScores[1], 0.001);
    }
    // Checks TLuaModel with four scripts: a plain sum of two Floats2 values, the
    // same sum plus math.sin, a script-defined function applied to a category
    // string, and a script with a misspelled keyword whose construction must throw.
    Y_UNIT_TEST(LuaModel) {
        auto parsedFeatures = NJson::ReadJsonFastTree(DRIVEANALYTICS_256Features);
        auto offerFeatures = NJson::FromJson<NDrive::TOfferFeatures>(parsedFeatures);
        offerFeatures.Floats2[0] = 8.600000381;
        offerFeatures.Floats2[1] = 0.9555556178;

        // All four cases share the same description apart from the script text.
        auto describeLuaModel = [](const char* script) {
            NJson::TJsonValue luaJson = NJson::TMapBuilder
                ("name", "lua_model")
                ("type", NDrive::TLuaModel::Type())
                ("script", script)
            ;
            return luaJson;
        };

        {
            // Sum of the two Floats2 values set above.
            NJson::TJsonValue luaJson = describeLuaModel(
                "return coefs.floats2.Price + coefs.floats2.NormalizedPrice");

            auto luaModel = Share(NDrive::IOfferModel::Construct(luaJson));
            UNIT_ASSERT(luaModel);
            UNIT_ASSERT_VALUES_EQUAL(luaModel->GetType(), NDrive::TLuaModel::Type());
            auto luaValue = luaModel->Calc(offerFeatures);
            UNIT_ASSERT_DOUBLES_EQUAL(luaValue, offerFeatures.Floats2[0] + offerFeatures.Floats2[1], 1e-4);
        }
        {
            // Same sum plus math.sin of a script-level variable.
            NJson::TJsonValue luaJson = describeLuaModel(
                "x = 3.4\n"
                "return coefs.floats2.Price + coefs.floats2.NormalizedPrice + math.sin(x)");

            auto luaModel = Share(NDrive::IOfferModel::Construct(luaJson));
            UNIT_ASSERT(luaModel);
            UNIT_ASSERT_VALUES_EQUAL(luaModel->GetType(), NDrive::TLuaModel::Type());
            auto luaValue = luaModel->Calc(offerFeatures);
            UNIT_ASSERT_DOUBLES_EQUAL(luaValue, offerFeatures.Floats2[0] + offerFeatures.Floats2[1] + std::sin(3.4), 1e-4);
        }
        {
            // A function defined in the script, applied to a category string.
            NJson::TJsonValue luaJson = describeLuaModel(
                "function get_len(my_str)\n"
                    "return string.len(my_str)\n"
                "end\n"
                "return get_len(coefs.categories2.WeatherPrecipitationType)");

            auto luaModel = Share(NDrive::IOfferModel::Construct(luaJson));
            UNIT_ASSERT(luaModel);
            UNIT_ASSERT_VALUES_EQUAL(luaModel->GetType(), NDrive::TLuaModel::Type());
            auto luaValue = luaModel->Calc(offerFeatures);
            UNIT_ASSERT_VALUES_EQUAL(luaValue, offerFeatures.Categories2[NDriveOfferCatFactors2::FI_WEATHER_PRECIPITATION_TYPE].size());
        }
        {
            // "funcion" instead of "function": constructing the model must throw.
            NJson::TJsonValue luaJson = describeLuaModel(
                "funcion get_len(my_str)\n"
                    "return string.len(my_str)\n"
                "end\n"
                "return get_len(coefs.categories2.WeatherPrecipitationType)");

            UNIT_ASSERT_EXCEPTION(Share(NDrive::IOfferModel::Construct(luaJson)), yexception);
        }
    }
}
