#include <internal/db/ep/meta_cache_based_endpoint_provider.h>
#include <internal/dc_vanga.h>
#include <internal/logger.h>
#include <internal/poller/meta_shards_provider.h>

#include <boost/range/adaptor/filtered.hpp>

#include <algorithm>
#include <iterator>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>

using namespace sharpei;

namespace {

// Returns this translation unit's logger: the sharpei logger bound with
// where_file = "meta_cache_based_ep". The binding is created once, on first
// call, and every caller receives a copy of it.
auto getLogger() {
    static auto boundLogger = logdog::bind(log::GetLogger(log::sharpeiLogKey),
                                           log::where_file = "meta_cache_based_ep");
    return boundLogger;
}

}  // namespace

namespace sharpei::db {

// Round-robin selection: bumps the member counter `iterator` and returns the
// database at (counter % dbs.size()).
// Returns Error::appropriateHostNotFound when `dbs` is empty.
// NOTE(review): thread-safety depends on the declared type of `iterator`
// (e.g. std::atomic vs plain integer) — confirm in the header.
yamail::expected<Shard::Database> SimpleEndpointSelectionStrategy::selectEndpoint(
    const std::vector<Shard::Database>& dbs) {
    if (!dbs.empty()) {
        const auto ticket = ++iterator;
        return dbs[ticket % dbs.size()];
    }
    return yamail::make_unexpected(ExplainedError(Error::appropriateHostNotFound));
}

// Creates a strategy that only picks databases whose address().dataCenter equals
// `localDC`; selection fails unless at least `aliveHostsThreshold` such databases exist.
DCAwareMainStrategy::DCAwareMainStrategy(const std::string& localDC, const std::size_t aliveHostsThreshold)
    : localDC_(localDC), aliveHostsThreshold_(aliveHostsThreshold) {}

// Determines this host's data center via DCVanga(rules).getLocalhostDC() and
// forwards it to the string-based constructor.
DCAwareMainStrategy::DCAwareMainStrategy(const dc_vanga::Rules& rules, const std::size_t aliveHostsThreshold)
    : DCAwareMainStrategy(dc_vanga::DCVanga(rules).getLocalhostDC(), aliveHostsThreshold) {}

// Restricts `dbs` to those located in localDC_ and delegates the final choice to impl_.
// Fails with Error::appropriateHostNotFound ("not enough alive hosts in my DC") when
// fewer than aliveHostsThreshold_ databases are in the local DC.
// NOTE(review): the threshold name suggests `dbs` is already limited to alive hosts
// by the caller — confirm.
yamail::expected<Shard::Database> DCAwareMainStrategy::selectEndpoint(const std::vector<Shard::Database>& dbs) {
    const auto isLocal = [this] (const Shard::Database& candidate) {
        return candidate.address().dataCenter == localDC_;
    };

    const std::size_t localCount = std::count_if(dbs.begin(), dbs.end(), isLocal);
    if (localCount < aliveHostsThreshold_) {
        return yamail::make_unexpected(ExplainedError(Error::appropriateHostNotFound, "not enough alive hosts in my DC"));
    }

    // Exact-size buffer: the count above is precisely how many entries will be kept.
    std::vector<Shard::Database> localDbs;
    localDbs.reserve(localCount);
    for (const auto& candidate : dbs) {
        if (isLocal(candidate)) {
            localDbs.push_back(candidate);
        }
    }
    return impl_.selectEndpoint(localDbs);
}

// Builds a connection endpoint for `db` using the supplied credentials.
// Host, port and dbname come from db.address(); authInfo is passed through as-is.
// In practice the host is one of the |meta_connection.hostlist| entries, while port,
// dbname and authInfo originate from |meta_connection.db.adaptor.conn_info|.
IEndpointProvider::Endpoint toEndpoint(const Shard::Database& db, const AuthInfo& authInfo) {
    const auto& address = db.address();
    return ConnectionInfo(address.host, address.port, address.dbname, authInfo);
}

// Stores the meta cache, credentials, database filter and the main/fallback
// endpoint selection strategies. Pointer arguments are taken by value and moved in;
// authInfo is copied. No null checks are performed here.
MetaCacheBasedEndpointProvider::MetaCacheBasedEndpointProvider(cache::CachePtr metaCache, const db::AuthInfo& authInfo,
                                                               FiltrationStrategyPtr filter,
                                                               EndpointSelectionStrategyPtr mainStrategy,
                                                               EndpointSelectionStrategyPtr fallbackStrategy)
    : metaCache_(std::move(metaCache))
    , authInfo_(authInfo)
    , filter_(std::move(filter))
    , mainStrategy_(std::move(mainStrategy))
    , fallbackStrategy_(std::move(fallbackStrategy)) {}

// Picks the next meta database endpoint.
// Tries the main strategy first. If it fails, the failure is logged and the fallback
// strategy is tried on the same database list.
// Throws std::runtime_error (carrying the fallback error's full message) when both fail;
// exceptions from getMetaDatabases() propagate unchanged.
IEndpointProvider::Endpoint MetaCacheBasedEndpointProvider::getNext() {
    const auto databases = getMetaDatabases();

    const auto mainChoice = mainStrategy_->selectEndpoint(databases);
    if (mainChoice) {
        LOGDOG_(getLogger(), debug,
                log::message = "main strategy successfully selected endpoint: " + mainChoice->address().host);
        return toEndpoint(*mainChoice, authInfo_);
    }
    LOGDOG_(getLogger(), error, log::message = "main strategy failed; fallback strategy will be used",
            log::error_code = Error::endpointProviderError, log::reason = mainChoice.error());

    const auto fallbackChoice = fallbackStrategy_->selectEndpoint(databases);
    if (!fallbackChoice) {
        LOGDOG_(getLogger(), error, log::message = "fallback strategy failed",
                log::error_code = Error::endpointProviderError, log::reason = fallbackChoice.error());
        throw std::runtime_error(fallbackChoice.error().full_message());
    }
    LOGDOG_(getLogger(), notice,
            log::message = "fallback strategy selected endpoint: " + fallbackChoice->address().host);
    return toEndpoint(*fallbackChoice, authInfo_);
}

std::vector<Shard::Database> MetaCacheBasedEndpointProvider::getMetaDatabases() const {
    ExplainedError err;
    boost::optional<Shard> metabase;
    std::tie(err, metabase) = metaCache_->getShard(poller::MetaShardsProvider::shardId);
    if (err) {
        LOGDOG_(getLogger(), error, log::message = "metaCache_->getShard error",
                log::error_code = Error::endpointProviderError, log::reason = err);
        throw std::runtime_error("metaCache_->getShard error");
    }
    if (!metabase.is_initialized()) {
        throw std::logic_error("return code is ok but the list of databases is empty");
    }
    return filterDatabases(get(metabase).databases, *filter_);
}

// Factory for MetaCacheBasedEndpointProvider.
// All nodes are accepted by the filter. The main strategy is DC-aware when DCVanga
// `rules` are given and plain round-robin otherwise. The fallback is always plain
// round-robin (SimpleEndpointSelectionStrategy).
EndpointProviderPtr makeMetaCacheBasedEndpointProvider(cache::CachePtr metaCache,
                                                       const db::AuthInfo& authInfo,
                                                       const dc_vanga::Rules& rules,
                                                       const std::size_t aliveHostsThreshold) {
    FiltrationStrategyPtr acceptAll{new AcceptingAllNodesFiltrationStrategy};

    EndpointSelectionStrategyPtr primary;
    if (rules.empty()) {
        LOGDOG_(getLogger(), notice, log::message = "no DCVanga rules provided, so SimpleMainStrategy will be used");
        primary = std::make_shared<SimpleEndpointSelectionStrategy>();
    } else {
        LOGDOG_(getLogger(), notice, log::message = "DCVanga rules provided, so DCAwareMainStrategy will be used");
        primary = std::make_shared<DCAwareMainStrategy>(rules, aliveHostsThreshold);
    }

    EndpointSelectionStrategyPtr fallback = std::make_shared<SimpleEndpointSelectionStrategy>();
    return std::make_shared<MetaCacheBasedEndpointProvider>(std::move(metaCache), authInfo, std::move(acceptAll),
                                                            std::move(primary), std::move(fallback));
}

}  // namespace sharpei::db
