#pragma once

#include <memory>
#include <type_traits>
#include <pgg/database/database.h>
#include <pgg/database/performer.h>
#include <pgg/logging.h>
#include <boost/asio/coroutine.hpp>
#include <boost/asio/yield.hpp>

namespace pgg {
namespace database {
namespace fallback {

// Short local names for types used throughout the fallback namespace.
using LogPtr = logging::LogPtr;
using Query = query::Query;
using EndpointType = query::Traits::EndpointType;

namespace rules {

namespace detail {

/**
 * Decides whether a "no endpoint found for role" error should be retried.
 *
 * The policy is to retry this error at most once per role, so the function
 * returns true only on the FIRST (nTry == 0) and SECOND (nTry == 1) tries of
 * a rule:
 *  - on the first try the error was produced by the previous rule;
 *  - on the second try the error was produced by the current rule itself.
 *
 * EXAMPLE
 *
 * With a "Master then Replica" strategy, the first Master::nextTry() call
 * fails with "no master role found". Master::recoverable() is then called
 * with that error on the SECOND try (nTry == 1), so the error is retried and
 * Master::nextTry() is called once more. If it fails with the same error
 * again, retrying obviously cannot recover it: now nTry == 2, so
 * Master::recoverable() must return false and execution switches to the
 * Replica rule. Replica::recoverable() is then called with the "no role
 * found" error on its FIRST try (nTry == 0); an error inherited from the
 * previous rule is retryable, so Replica::recoverable() must return true.
 *
 * @param nTry number of nextTry() calls already made by the current rule.
 * @param code error of the most recent attempt.
 * @return true iff @p code is errc::noEndpointFound and nTry is 0 or 1.
 */
inline bool needRetryRole(std::size_t nTry, const pgg::error_code& code) {
    const bool noEndpoint = code == errc::noEndpointFound;
    return noEndpoint && nTry < 2;
}

/**
 * Reports a "user removed from shard" error that was not already the error
 * of the previous attempt.
 *
 * @param ec   error of the current attempt.
 * @param prev error that triggered the current attempt.
 * @return true iff @p ec is errc::userRemovedFromShard and @p prev is not.
 */
inline bool userMigrated(const pgg::error_code& ec, const pgg::error_code& prev) {
    return ec == errc::userRemovedFromShard
        ? prev != errc::userRemovedFromShard
        : false;
}

} // namespace detail

/**
 * Rule that sends the query to the master endpoint.
 *
 * At most maxTries attempts are made. Every attempt after the first one (and
 * any attempt triggered by errc::userRemovedFromShard) asks the connection
 * provider to force an endpoint update.
 */
class Master {
public:
    /**
     * Requests a master connection from @p db.
     *
     * @param db connection provider, called as db(EndpointQuery, LogError).
     * @param ec error that triggered this attempt (empty for the first one).
     */
    template <typename ConnectionProvider>
    DatabasePtr nextTry(const ConnectionProvider & db, const pgg::error_code& ec = pgg::error_code{}) {
        const bool isRetry = (nTry != 0);
        const bool forceUpdate = isRetry || ec == errc::userRemovedFromShard;
        ++nTry;
        return db(EndpointQuery{EndpointType::master, forceUpdate, 0}, LogError::disable);
    }

    /**
     * Tells whether @p ec may be recovered by another attempt on master.
     *
     * @param ec   error of the last attempt.
     * @param prev error that triggered the last attempt.
     */
    bool recoverable(const pgg::error_code& ec, const pgg::error_code& prev) const {
        if (!ec) {
            return false;   // nothing to recover from
        }
        if (nTry == maxTries) {
            return false;   // attempt budget exhausted
        }

        const bool transient = ec == errc::communicationError
                            || ec == errc::databaseReadOnly;
        return transient
            || detail::needRetryRole(nTry, ec)
            || detail::userMigrated(ec, prev);
    }

    std::string name() const { return "master"; }

private:
    unsigned nTry = 0;                          // nextTry() calls made so far
    constexpr static unsigned maxTries = 2;
};

/**
 * Rule that sends the query to an endpoint of kind @p endpointType
 * (one of the replica kinds, see the aliases below).
 *
 * Unlike Master there is no fixed attempt limit here; the running attempt
 * counter is passed as the third EndpointQuery field (presumably a replica
 * index — confirm against EndpointQuery).
 */
template <EndpointType::Enum endpointType>
class BaseReplica {
public:
    /**
     * Requests a connection of kind @p endpointType from @p db.
     *
     * @param db connection provider, called as db(EndpointQuery, LogError).
     * @param ec error that triggered this attempt (empty for the first one);
     *           errc::userRemovedFromShard forces an endpoint update.
     */
    template <typename ConnectionProvider>
    DatabasePtr nextTry(const ConnectionProvider & db, const pgg::error_code& ec = pgg::error_code{}) {
        const bool forceUpdate = ec == errc::userRemovedFromShard;
        const unsigned attempt = nTry;
        ++nTry;
        return db(EndpointQuery{endpointType, forceUpdate, attempt}, LogError::disable);
    }

    /**
     * Tells whether @p ec may be recovered by another attempt with this rule.
     *
     * @param ec   error of the last attempt.
     * @param prev error that triggered the last attempt.
     */
    bool recoverable(const pgg::error_code& ec, const pgg::error_code& prev) const {
        if (!ec) {
            return false;
        }
        if (ec == errc::communicationError) {
            return true;
        }
        return detail::needRetryRole(nTry, ec) || detail::userMigrated(ec, prev);
    }

    std::string name() const {
        EndpointType type(endpointType);
        return type.toString();
    }

private:
    unsigned nTry = 0;   // nextTry() calls made so far
};

using Replica = BaseReplica<EndpointType::replica>;
using NoLagReplica = BaseReplica<EndpointType::noLagReplica>;
using LagReplica = BaseReplica<EndpointType::lagReplica>;

} // namespace rules

/// An ordered rule set: rules are tried from first to last.
template <typename ... RulesList>
using Rules = std::tuple<RulesList...>;

/// Type of the N-th rule of the rule set RulesT.
template <std::size_t N, typename RulesT>
using Rule = std::tuple_element_t<N, RulesT>;

namespace details {

// Method selectors: each one forwards its arguments to the matching
// connection member function. They are used as the Method template argument
// of the rule executor.

/// Invokes conn.request(params...).
struct Request {
    template <typename Conn, typename ... Args>
    static auto call(Conn&& conn, Args&& ... params) -> void {
        conn.request(std::forward<Args>(params)...);
    }
};

/// Invokes conn.fetch(params...).
struct Fetch {
    template <typename Conn, typename ... Args>
    static auto call(Conn&& conn, Args&& ... params) -> void {
        conn.fetch(std::forward<Args>(params)...);
    }
};

/// Invokes conn.update(params...).
struct Update {
    template <typename Conn, typename ... Args>
    static auto call(Conn&& conn, Args&& ... params) -> void {
        conn.update(std::forward<Args>(params)...);
    }
};

/// Invokes conn.execute(params...).
struct Execute {
    template <typename Conn, typename ... Args>
    static auto call(Conn&& conn, Args&& ... params) -> void {
        conn.execute(std::forward<Args>(params)...);
    }
};

} // namespace details

namespace strategy {

using ReadMasterThenReplica = Rules<rules::Master, rules::Replica>;
using MasterOnly = Rules<rules::Master>;
using ReadReplicaThenMaster = Rules<rules::Replica, rules::Master>;
using ReplicaOnly = Rules<rules::Replica>;
using ReadNoLagReplicaThenMasterThenReplica = Rules<rules::NoLagReplica, rules::Master, rules::LagReplica>;

/// Maps a rule-set type to its human-readable strategy name (used in logs).
/// Only the specializations below are defined; other rule sets do not compile.
template <typename T>
struct StrategyName;

template <>
struct StrategyName<ReadMasterThenReplica> {
    constexpr static const char* value = "read_master_then_replica";
};

template <>
struct StrategyName<ReadReplicaThenMaster> {
    constexpr static const char* value = "read_replica_then_master";
};

template <>
struct StrategyName<MasterOnly> {
    constexpr static const char* value = "master_only";
};

template <>
struct StrategyName<ReplicaOnly> {
    constexpr static const char* value = "replica_only";
};

template <>
struct StrategyName<ReadNoLagReplicaThenMasterThenReplica> {
    constexpr static const char* value = "read_no_lag_replica_then_master_then_replica";
};

template <>
struct StrategyName<Rules<rules::NoLagReplica>> {
    constexpr static const char* value = "no_lag_replica_only";
};

template <>
struct StrategyName<Rules<rules::LagReplica>> {
    constexpr static const char* value = "lag_replica_only";
};

/**
 * Returns the name registered in StrategyName for the type of the argument.
 *
 * Only the argument's type matters, so it is taken by const reference; taking
 * it by value would copy the whole rules tuple just to deduce Strategy.
 */
template <typename Strategy>
constexpr std::string_view strategyName(const Strategy&) {
    return {StrategyName<Strategy>::value};
}

} // namespace strategy

namespace details {

template <typename Rules>
class Caller {
public:
    Caller(Rules rules) : rules(std::move(rules)) {}

    template <typename Method, typename Handler, typename ConnProvider >
    void run(ConnProvider db,
            Query::QueryPtr qptr, Handler handler, LogPtr log) const {

        using Ctx = Context<Handler, ConnProvider>;
        auto ctx = std::make_shared<Ctx>(std::move(db), qptr, std::move(handler), rules, log);
        RuleExecutor<Method, decltype(ctx), 0>(ctx).run();
    }

private:
    Rules rules;

    static void fill(logging::Attributes& attributes, const error_code::info_type& info) {
        using namespace logging;
        attributes.insert(attributes.end(), {
            QueryText(info.query_text),
            QueryValues(info.query_values),
            ConnectionInfo(info.connstr)
        });
    }

    template <typename Handler, typename ConnProvider >
    struct Context {
        Query::QueryPtr qptr;
        Handler handler;
        Rules rules;
        ConnProvider db;
        LogPtr log;
        pgg::error_code prevError;
        logging::StrategyName strategyName;

        Context(ConnProvider db, Query::QueryPtr qptr,
                Handler&& handler, Rules rules, LogPtr log)
        : qptr(qptr), handler(std::move(handler)), rules(std::move(rules)),
              db(std::move(db)), log(log), strategyName(strategy::strategyName(rules)) {}

        template <typename ...Args>
        void handle(database::error_code ec, Args && ... args) {
            if(ec && log) {
                using namespace logging;
                logError(*log, [&] {
                    std::ostringstream s;
                    s << "query failed with {error_category=\"" << ec.category().name() << "\","
                         " \"error_message=\"" << ec.message() << "\"}";

                    Attributes attributes = { &ec, strategyName };

                    if(ec.info()) {
                        const auto& info = ec.info().get();
                        s << std::endl;
                        streamErrorInfo(s, info);
                        fill(attributes, info);
                    }

                    attributes.push_back(OldMessage(s.str()));
                    return Record{ qptr->name(), "", std::move(attributes) };
                });
            }
            handler(std::move(ec), std::forward<Args>(args)...);
        }
    };

    template <typename Method, typename CtxPtr, std::size_t N>
    struct RuleExecutor : boost::asio::coroutine {
        CtxPtr ctx;

        RuleExecutor(CtxPtr ctx) : ctx(ctx) {}

        template <typename ...Args>
        void operator() (database::error_code ec, Args && ... args) {
            reenter(this) {
                while(ec) {
                    if(rule().recoverable(ec, ctx->prevError)) {
                        logRetry(ec);
                        yield run(ec);
                    } else {
                        RuleExecutor<Method, CtxPtr, N+1> tryNextRule(ctx);
                        tryNextRule(std::move(ec), std::forward<Args>(args)...);
                        return;
                    }
                }
                for(;;) {
                    yield ctx->handle(std::move(ec), std::forward<Args>(args)...);
                }
            }
        }

        Rule<N, Rules> & rule() const { return std::get<N>(ctx->rules); }

        void run(const database::error_code& ec = database::error_code{}) const {
            const auto conn = rule().nextTry(ctx->db, ec);
            ctx->prevError = ec;
            Method::call(*conn,*(ctx->qptr), *this);
        }

        void logRetry(const database::error_code& ec) const {
            if(ctx->log) {
                using namespace logging;
                logNotice(*(ctx->log), [&] {
                    Attributes attributes = { &ec, RuleName(rule().name()), ctx->strategyName };

                    std::ostringstream s;
                    s << "query failed with {";
                    if(ec.info()) {
                        const auto& info = ec.info().get();
                        s << "connstr=\"" << info.connstr << "\", ";
                        fill(attributes, info);
                    }
                    s << "error_category=\"" << ec.category().name()
                      << "\", error_message=\"" << ec.message()
                      << "\"} and will be retried with {next_rule=\"" << rule().name() << "\"}";

                    attributes.push_back(OldMessage(s.str()));
                    return Record{ ctx->qptr->name(), "", std::move(attributes) };
                });
            }
        }
    };

    template <typename Method, typename CtxPtr>
    struct RuleExecutor<Method, CtxPtr, std::tuple_size<Rules>::value> {
        CtxPtr ctx;
        RuleExecutor(CtxPtr ctx) : ctx(ctx) {}
        template <typename ...Args>
        void operator() (Args && ... args) const {
            ctx->handle(std::forward<Args>(args)...);
        }
    };
};

} // namespace details

/**
 * Interface of a query-execution strategy. A strategy decides which
 * endpoint(s) a query is sent to and how failed attempts are retried.
 *
 * Every operation receives:
 *  - the connection provider used to obtain endpoints,
 *  - the query,
 *  - the completion handler,
 *  - a log.
 */
class Strategy {
public:
    using ConnProvider = std::function<DatabasePtr(EndpointQuery, LogError)>;
    using RequestHandler = Connection::RequestHandler;
    using FetchHandler = Connection::FetchHandler;
    using UpdateHandler = Connection::UpdateHandler;
    using ExecuteHandler = Connection::ExecuteHandler;
    using QueryPtr = Query::QueryPtr;

    virtual ~Strategy() = default;

    /// Runs @p query with "request" semantics, reporting to @p handler.
    virtual void request(ConnProvider provider, QueryPtr query, RequestHandler handler, LogPtr log) const = 0;

    /// Runs @p query with "fetch" semantics, reporting to @p handler.
    virtual void fetch(ConnProvider provider, QueryPtr query, FetchHandler handler, LogPtr log) const = 0;

    /// Runs @p query with "update" semantics, reporting to @p handler.
    virtual void update(ConnProvider provider, QueryPtr query, UpdateHandler handler, LogPtr log) const = 0;

    /// Runs @p query with "execute" semantics, reporting to @p handler.
    virtual void execute(ConnProvider provider, QueryPtr query, ExecuteHandler handler, LogPtr log) const = 0;
};

/**
 * Strategy that runs every query through the rule set @p Rules
 * (a std::tuple of rule objects, see fallback::Rules).
 */
template <typename Rules>
class StrategyImpl : public Strategy {
public:
    StrategyImpl(Rules rules = Rules()) : caller(std::move(rules)) {}

    // All parameters are by-value sinks that are not used after the call,
    // so they are moved into Caller::run instead of being copied once more.

    void request(ConnProvider db, QueryPtr q, RequestHandler h, LogPtr log) const override {
        caller.template run<details::Request>(std::move(db), std::move(q), std::move(h), std::move(log));
    }

    void fetch(ConnProvider db, QueryPtr q, FetchHandler h, LogPtr log) const override {
        caller.template run<details::Fetch>(std::move(db), std::move(q), std::move(h), std::move(log));
    }

    void update(ConnProvider db, QueryPtr q, UpdateHandler h, LogPtr log) const override {
        caller.template run<details::Update>(std::move(db), std::move(q), std::move(h), std::move(log));
    }

    void execute(ConnProvider db, QueryPtr q, ExecuteHandler h, LogPtr log) const override {
        caller.template run<details::Execute>(std::move(db), std::move(q), std::move(h), std::move(log));
    }

private:
    details::Caller<Rules> caller;
};

class StrategyProvider {
public:
    template <typename T>
    StrategyProvider(T autoStrategyRules)
    : ctx(std::make_shared<Ctx>(autoStrategyRules)) {}

    const Strategy & operator()(EndpointType type) const {
        switch(type.value()) {
            case EndpointType::master : return ctx->masterStrategy;
            case EndpointType::replica : return ctx->replicaStrategy;
            case EndpointType::lagReplica : return ctx->noLagReplicaStrategy;
            case EndpointType::noLagReplica : return ctx->lagReplicaStrategy;
            case EndpointType::automatic : return *(ctx->autoStartegy);
        }
        return  *(ctx->autoStartegy);
    }

private:
    struct Ctx {
        StrategyImpl<strategy::MasterOnly> masterStrategy;
        StrategyImpl<strategy::ReplicaOnly> replicaStrategy;
        StrategyImpl<Rules<rules::NoLagReplica>> noLagReplicaStrategy;
        StrategyImpl<Rules<rules::LagReplica>> lagReplicaStrategy;
        std::unique_ptr<Strategy> autoStartegy;
        template <typename T>
        Ctx( T autoRules )
        : autoStartegy(new StrategyImpl<std::decay_t<T>>(autoRules)){}
    };

    std::shared_ptr<Ctx> ctx;
};

class DatabaseImpl : public Database {
public:
    DatabaseImpl(Strategy::ConnProvider db, StrategyProvider provider, LogPtr log)
    : db(std::move(db)), provider(std::move(provider)), log(log) {}

    void request(const Query & q, RequestHandler h) const override {
        strategy(q).request(db, share(q), h, log);
    }

    void fetch(const Query & q, FetchHandler h) const override {
        strategy(q).fetch(db, share(q), h, log);
    }

    void update(const Query & q, UpdateHandler h) const override {
        strategy(q).update(db, share(q), h, log);
    }

    void execute(const Query & q, ExecuteHandler h) const override {
        strategy(q).execute(db, share(q), h, log);
    }

    void withinConnection(ConnectionHandler h) const override {
        db(EndpointQuery{}, LogError::enable)->withinConnection(h);
    }
private:
    Strategy::ConnProvider db;
    StrategyProvider provider;
    LogPtr log;

    const Strategy & strategy(const Query & q) const {
        return provider(q.endpointType());
    }

    Query::QueryPtr share(const Query & q) const { return q.clone(); }
};

struct DatabaseGenerator {
    Strategy::ConnProvider db;
    StrategyProvider provider;
    LogPtr log;

    DatabaseGenerator(Strategy::ConnProvider db,
            StrategyProvider provider, LogPtr log)
    : db(std::move(db)), provider(std::move(provider)), log(log) {}

    ConnectionPtr operator()(void) const {
        return boost::make_shared<DatabaseImpl>(db, provider, log);
    }

    DatabasePtr operator()(EndpointQuery endpoint, LogError logError) const {
        return db(endpoint, logError);
    }
};

} // namespace fallback
} // namespace database

namespace fb = database::fallback;

} // namespace pgg

#include <boost/asio/unyield.hpp>
