#include <mail/template_master/lib/db/connection_provider.h>
#include <mail/template_master/ut/mock/resolver_mock.h>
#include <mail/template_master/ut/utils.h>

#include <library/cpp/testing/unittest/registar.h>

#include <memory>
#include <chrono>

namespace NTemplateMaster::NTests {

using namespace std::chrono_literals;
using ::testing::Return;
using ::testing::_;
using ::testing::InSequence;

using TConnectionProvider = NTemplateMaster::NDatabase::TConnectionProvider<std::shared_ptr<TResolver>>;
using EExecutionPolicy = NTemplateMaster::NDatabase::EExecutionPolicy;
using TMapShard = ::sharpei::client::MapShard;
using TShard = ::sharpei::client::Shard;
using TDatabase = ::sharpei::client::Shard::Database;
using EError = NTemplateMaster::NErrors::EError;

class TConnectionProviderTest : public TTestBase, public TWithSpawn {
    UNIT_TEST_SUITE(TConnectionProviderTest)
        UNIT_TEST(GetMaster)
        UNIT_TEST(GetReplicaOrMaster)
        UNIT_TEST(NotFoundDbWithRequestedExecutionPolicy)
        UNIT_TEST(Error)
    UNIT_TEST_SUITE_END();
public:
    void SetUp() override {
        Io = std::make_unique<boost::asio::io_context>();
        Context = GetContext();
    }

    void GetMaster() {
        Spawn([=](TYield yield) {
            const InSequence s;
            auto resolver = std::make_shared<TResolver>();
            resolver->mock = GetResolverMock();
            TConnectionProvider connProvider(nullptr, resolver, {});

            TDatabase dbMaster{{"host", 80, "", ""}, "master", "alive", {}};
            std::vector<TDatabase> databases({dbMaster});
            TShard shard{"1", "1", databases};
            TMapShard shards{{"shard1", shard}};
            EXPECT_CALL(*resolver->mock, GetAllShards())
                .WillOnce(Return(std::make_pair(shards, boost::system::error_code{})));
            auto dbInfoExp = connProvider.GetDatabaseInfo(Context, EExecutionPolicy::Master, yield);
            EXPECT_TRUE(dbInfoExp);
            const auto info = dbInfoExp.value();
            EXPECT_EQ(info.Db.role, "master");
            EXPECT_EQ(info.Db.status, "alive");
            EXPECT_EQ(info.Db.address.host, "host");
        });
    }

    void GetReplicaOrMaster() {
        Spawn([=](TYield yield) {
            const InSequence s;
            auto resolver = std::make_shared<TResolver>();
            resolver->mock = GetResolverMock();
            TConnectionProvider connProvider(nullptr, resolver, {});

            TDatabase dbMaster{{"host", 80, "", ""}, "master", "alive", {}};
            TDatabase dbReplica{{"hostReplica", 80, "", ""}, "replica", "alive", {}};
            std::vector<TDatabase> databases({dbMaster, dbReplica});
            TShard shard{"1", "1", databases};
            TMapShard shards{{"shard1", shard}};
            EXPECT_CALL(*resolver->mock, GetAllShards())
                    .WillOnce(Return(std::make_pair(shards, boost::system::error_code{})));
            auto dbInfoExp = connProvider.GetDatabaseInfo(Context, EExecutionPolicy::ReplicaOrMaster, yield);
            EXPECT_TRUE(dbInfoExp);
            const auto info = dbInfoExp.value();
            EXPECT_EQ(info.Db.status, "alive");
        });
    }

    void NotFoundDbWithRequestedExecutionPolicy() {
        Spawn([=](TYield yield) {
            const InSequence s;
            auto resolver = std::make_shared<TResolver>();
            resolver->mock = GetResolverMock();
            TConnectionProvider connProvider(nullptr, resolver, {});

            std::vector<TDatabase> databases;
            TShard shard{"1", "1", databases};
            TMapShard shards{{"shard1", shard}};
            EXPECT_CALL(*resolver->mock, GetAllShards())
                    .WillOnce(Return(std::make_pair(shards, boost::system::error_code{})));
            auto dbInfoExp = connProvider.GetDatabaseInfo(Context, EExecutionPolicy::Master, yield);
            EXPECT_FALSE(dbInfoExp);
            EXPECT_EQ(static_cast<EError>(dbInfoExp.error().value()), EError::RequestedDbNotFound);
        });
    }

    void Error() {
        Spawn([=](TYield yield) {
            const InSequence s;
            auto resolver = std::make_shared<TResolver>();
            resolver->mock = GetResolverMock();
            TConnectionProvider connProvider(nullptr, resolver, {});

            TMapShard shards;
            auto ec = boost::system::errc::make_error_code(boost::system::errc::not_supported);
            EXPECT_CALL(*resolver->mock, GetAllShards())
                    .WillOnce(Return(std::make_pair(shards, ec)));
            auto dbInfoExp = connProvider.GetDatabaseInfo(Context, EExecutionPolicy::Master, yield);
            EXPECT_FALSE(dbInfoExp);
            EXPECT_EQ(dbInfoExp.error(), ec);
        });
    }
private:
    NTemplateMaster::TContextPtr Context;
};

}

// Registers the suite with the library/cpp/testing/unittest runner.
UNIT_TEST_SUITE_REGISTRATION(NTemplateMaster::NTests::TConnectionProviderTest);
