#pragma once

#include "settings.h"
#include "thread_safe_accumulator_wrapper.h"
#include "conninfo.h"
#include <ymod_pq/call.h>
#include <yplatform/log/contains_logger.h>

#include <algorithm>
#include <ctime>
#include <functional>
#include <iomanip>
#include <memory>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

namespace ymod_pq {
namespace detail {
// Renders a rate (error ratio or requests per second) as fixed-point text with
// exactly three fractional digits, e.g. 0.5 -> "0.500", 2.0 / 3 -> "0.667".
inline std::string rate_str(double rate)
{
    std::ostringstream out;
    // Same effect as `<< std::fixed << std::setprecision(3)`.
    out.setf(std::ios_base::fixed, std::ios_base::floatfield);
    out.precision(3);
    out << rate;
    return out.str();
}
}

/// Routes requests for one database across its master and replicas, and switches
/// individual hosts to "fallback" connection settings either automatically (when
/// the observed error rate crosses a threshold) or on request (forced fallback).
///
/// Hosts managed:
///  - master: the conninfo given to the constructor, with `settings.master`;
///  - default_replica: the master conninfo with its full "host" value, using
///    `settings.replica`; returned by best_replica() when no listed replica is usable;
///  - replicas: one entry per host in `conninfo.hosts()`, using `settings.replica`.
///
/// @tparam CallImpl       per-host call implementation providing request/update/
///                        execute (as called below) and get_stats().
/// @tparam ReplicaMonitor provides start()/stop() and get(); get() yields `replica`
///                        entries with `host` and `lag` members.
/// @tparam time           clock used for rate accounting; defaults to std::time
///                        (presumably replaceable in tests).
///
/// Thread safety: the host table, fallback state and stats timestamp are guarded
/// by mutex_. Per-host rate accumulators are updated from request completion
/// callbacks without holding mutex_; they are wrapped in
/// thread_safe_accumulator_wrapper (assumed thread-safe; confirm in that header).
template <typename CallImpl, typename ReplicaMonitor, std::time_t (*time)(std::time_t*) = std::time>
class basic_cluster : public yplatform::log::contains_logger
{
    using call_impl = CallImpl;
    using call_impl_ptr = std::shared_ptr<call_impl>;
    using call_impl_factory = std::function<call_impl_ptr(const ymod_pq::settings&)>;
    using replica_monitor = ReplicaMonitor;
    using replica_monitor_ptr = std::shared_ptr<replica_monitor>;
    using replica_info = typename replica_monitor::replica;

    /// Counters accumulated per time bucket: completed requests and failed ones.
    struct rate_stat
    {
        uint64_t requests = 0;
        uint64_t errors = 0;

        rate_stat& operator+=(const rate_stat& rhs)
        {
            requests += rhs.requests;
            errors += rhs.errors;
            return *this;
        }
    };

    using rate_accumulator = yplatform::per_second_accumulator<rate_stat>;
    using wrapped_rate_accumulator = thread_safe_accumulator_wrapper<rate_accumulator>;
    using rate_accumulator_ptr = std::shared_ptr<wrapped_rate_accumulator>;
    using duration = yplatform::time_traits::duration;
    using lock_guard = std::lock_guard<std::mutex>;

public:
    /// Cluster configuration.
    struct settings
    {
        ymod_pq::settings master;   ///< base settings for the master host
        ymod_pq::settings replica;  ///< base settings for replica hosts
        ymod_pq::settings fallback; ///< settings used by a host while it is in fallback
        /// A healthy host whose error rate (errors / requests over the window) exceeds
        /// this value is switched to fallback, if auto fallback is enabled.
        double enable_fallback_error_rate = 1.0;
        /// A host in automatic fallback whose error rate drops below this value is
        /// switched back to its regular settings.
        double disable_fallback_error_rate = 0.0;
        /// For try_master / try_replica targets: a host in fallback keeps receiving
        /// requests while its rps is below this value.
        double probe_rps = 0.0;
        /// Window length used for error_rate and rps (rps = requests / window).
        unsigned rate_accumulator_window = 5;
        /// Initial value of the runtime auto-fallback switch.
        bool auto_fallback_enabled = false;
    };

    enum class host_state
    {
        ok,             ///< regular (master/replica) settings
        fallback,       ///< switched to fallback settings automatically
        forced_fallback ///< switched by enable_fallback(); automatic logic leaves it alone
    };

    /// @param conninfo_string master connection string; its host list defines the replicas
    /// @param make_impl       factory creating a call implementation from settings
    /// @param replica_monitor source of replica lag information
    /// @param settings        cluster settings
    /// @throws std::runtime_error if no replica hosts can be extracted from the conninfo
    basic_cluster(
        const std::string& conninfo_string,
        const call_impl_factory& make_impl,
        const replica_monitor_ptr& replica_monitor,
        const settings& settings)
        : conninfo_(conninfo_string)
        , dbname_(conninfo_.get("dbname"))
        , make_impl_(make_impl)
        , replica_monitor_(replica_monitor)
        , settings_(settings)
        , auto_fallback_enabled_(settings.auto_fallback_enabled)
        , stats_update_ts_(time(nullptr))
    {
        master_ = make_host("master", conninfo_, settings_.master, true);
        default_replica_ = make_host(
            "default_replica",
            make_replica_conninfo(conninfo_.get("host")),
            settings_.replica,
            false);
        for (auto& name : conninfo_.hosts())
        {
            replicas_.push_back(make_host(name, make_replica_conninfo(name), settings_.replica, false));
        }
        if (replicas_.empty())
        {
            throw std::runtime_error("no replicas could be extracted from master conninfo");
        }
    }

    void start()
    {
        replica_monitor_->start();
    }

    void stop()
    {
        replica_monitor_->stop();
    }

    /// Sends a request to the host selected for @p target. The completion is
    /// counted (and, on exception, also counted as an error) in that host's rate.
    future_result request(
        yplatform::task_context_ptr ctx,
        const std::string& request,
        bind_array_ptr bind_vars,
        response_handler_ptr handler,
        bool log_timings,
        const yplatform::time_traits::duration& deadline,
        request_target target)
    {
        auto [pq, rate] = get_pq(ctx, target);
        return count_errors<promise_result>(
            rate, pq->request(ctx, "", request, bind_vars, handler, log_timings, deadline));
    }

    /// Same as request(), via the implementation's update().
    future_up_result update(
        yplatform::task_context_ptr ctx,
        const std::string& request,
        bind_array_ptr bind_arr,
        bool log_timings,
        const yplatform::time_traits::duration& deadline,
        request_target target)
    {
        auto [pq, rate] = get_pq(ctx, target);
        return count_errors<promise_up_result>(
            rate, pq->update(ctx, "", request, bind_arr, log_timings, deadline));
    }

    /// Same as request(), via the implementation's execute().
    future_result execute(
        yplatform::task_context_ptr ctx,
        const std::string& request,
        bind_array_ptr bind_arr,
        bool log_timings,
        const yplatform::time_traits::duration& deadline,
        request_target target)
    {
        auto [pq, rate] = get_pq(ctx, target);
        return count_errors<promise_result>(
            rate, pq->execute(ctx, "", request, bind_arr, log_timings, deadline));
    }

    /// Returns a stats tree with one subtree per host ("master", "default_replica",
    /// and each replica keyed by its name with '.' replaced by '_'), plus a
    /// "monitor" subtree with the lags reported by the replica monitor.
    auto get_stats() const
    {
        // Copy relevant state under lock, construct the tree after unlocking.
        std::unique_lock<std::mutex> l(mutex_);
        auto master = master_;
        auto default_replica = default_replica_;
        auto replicas = replicas_;
        std::vector<replica_info> replicas_info = replica_monitor_->get();
        l.unlock();

        // Every "lag" below comes from the same monitor snapshot, so per-host lags
        // agree with the "monitor" subtree.
        yplatform::ptree stats;
        stats.put_child("master", get_stats(master, replicas_info));
        stats.put_child("default_replica", get_stats(default_replica, replicas_info));
        for (auto& replica : replicas)
        {
            auto key = replica.name;
            std::replace(key.begin(), key.end(), '.', '_');
            stats.put_child(key, get_stats(replica, replicas_info));
        }
        stats.put_child("monitor", get_replica_monitor_stats(std::move(replicas_info)));
        return stats;
    }

    /// Refreshes rates and automatic fallback state (at most once per clock tick),
    /// then reports health of the master and of each replica, keyed by conninfo
    /// string. default_replica is not reported.
    auto get_health()
    {
        auto now = time(nullptr);
        lock_guard l(mutex_);
        update_stats(boost::make_shared<yplatform::task_context>("get_health"), now);
        cluster_health ret;
        ret.master = get_health(master_);
        for (auto& replica : replicas_)
        {
            ret.replicas[replica.conninfo.to_string()] = get_health(replica);
        }
        return ret;
    }

    void enable_auto_fallback(const yplatform::task_context_ptr& ctx)
    {
        lock_guard l(mutex_);
        auto_fallback_enabled_ = true;
        YLOG_CTX_LOCAL(ctx, info) << "enabled auto fallback for " << dbname_;
    }

    void disable_auto_fallback(const yplatform::task_context_ptr& ctx)
    {
        lock_guard l(mutex_);
        auto_fallback_enabled_ = false;
        YLOG_CTX_LOCAL(ctx, info) << "disabled auto fallback for " << dbname_;
    }

    /// Forces fallback settings on the master (for master / try_master targets)
    /// or on default_replica and all replicas (for any other target).
    void enable_fallback(const yplatform::task_context_ptr& ctx, request_target t)
    {
        lock_guard l(mutex_);
        if (targets_master(t))
        {
            enable_fallback(ctx, master_, true);
            return;
        }
        enable_fallback(ctx, default_replica_, true);
        for (auto& replica : replicas_)
        {
            enable_fallback(ctx, replica, true);
        }
    }

    /// Restores regular settings on the hosts selected as in enable_fallback().
    void disable_fallback(const yplatform::task_context_ptr& ctx, request_target t)
    {
        lock_guard l(mutex_);
        if (targets_master(t))
        {
            disable_fallback(ctx, master_);
            return;
        }
        disable_fallback(ctx, default_replica_);
        for (auto& replica : replicas_)
        {
            disable_fallback(ctx, replica);
        }
    }

private:
    /// One routing target with its current call implementation and statistics.
    struct host
    {
        std::string name;
        conninfo conninfo;
        call_impl_ptr pq; ///< replaced whenever the host enters or leaves fallback
        bool is_master = false;
        host_state state = host_state::ok;
        rate_accumulator_ptr rate; ///< shared with in-flight request callbacks
        double error_rate = 0.0;   ///< errors / requests over the last window
        double rps = 0.0;          ///< requests / window length

        bool in_fallback() const
        {
            return state == host_state::fallback || state == host_state::forced_fallback;
        }
    };

    /// True when @p t selects the master for enable_fallback / disable_fallback.
    static bool targets_master(request_target t)
    {
        return t == request_target::master || t == request_target::try_master;
    }

    /// Builds a host in the ok state whose call implementation is created from
    /// @p base with the conninfo replaced by @p ci.
    host make_host(std::string name, conninfo ci, const ymod_pq::settings& base, bool is_master)
    {
        // Create the implementation before `ci` is moved into the host.
        auto pq = make_impl_(set_conninfo(base, ci));
        return host{
            std::move(name), std::move(ci), std::move(pq), is_master, host_state::ok, make_rate()
        };
    }

    /// Refreshes stats if needed and returns the implementation and rate
    /// accumulator of the host selected for @p target.
    std::tuple<call_impl_ptr, rate_accumulator_ptr> get_pq(
        const yplatform::task_context_ptr& ctx,
        request_target target)
    {
        auto now = time(nullptr);
        lock_guard l(mutex_);
        update_stats(ctx, now);
        host& h = get_host(target);
        return std::make_tuple(h.pq, h.rate);
    }

    /// Maps a request target to a host. Expects mutex_ to be held.
    /// @throws std::logic_error for a target value not handled below.
    host& get_host(request_target target)
    {
        switch (target)
        {
        case request_target::master:
            return master_;
        case request_target::try_master:
            return try_master();
        case request_target::try_replica:
            return try_replica();
        case request_target::replica:
            return best_replica();
        }
        // Previously control fell off the end here (undefined behavior).
        throw std::logic_error("unknown request_target");
    }

    /// Master unless it is in fallback with at least probe_rps traffic; then the
    /// best replica if that one is not in fallback, otherwise still the master.
    host& try_master()
    {
        if (!master_.in_fallback() || master_.rps < settings_.probe_rps)
        {
            return master_;
        }
        auto& replica = best_replica();
        if (!replica.in_fallback())
        {
            return replica;
        }
        return master_;
    }

    /// Best replica unless it is in fallback with at least probe_rps traffic; then
    /// the master if that one is not in fallback, otherwise still the replica.
    host& try_replica()
    {
        auto& replica = best_replica();
        if (!replica.in_fallback() || replica.rps < settings_.probe_rps)
        {
            return replica;
        }
        if (!master_.in_fallback())
        {
            return master_;
        }
        return replica;
    }

    /// Replica not in fallback with the smallest lag known to the monitor; replicas
    /// the monitor does not report are never chosen. Falls back to default_replica_.
    host& best_replica()
    {
        // One monitor snapshot for all replicas (was one copy per replica).
        const std::vector<replica_info> lags = replica_monitor_->get();
        auto p_best = replicas_.end();
        duration min_lag = duration::max();
        for (auto it = replicas_.begin(); it != replicas_.end(); ++it)
        {
            if (it->in_fallback())
            {
                continue;
            }
            auto lag = find_replica_lag(lags, it->name);
            if (lag < min_lag)
            {
                p_best = it;
                min_lag = lag;
            }
        }
        return p_best != replicas_.end() ? *p_best : default_replica_;
    }

    /// Lag of @p hostname in @p replicas, or duration::max() if it is not listed.
    static duration find_replica_lag(
        const std::vector<replica_info>& replicas,
        const std::string& hostname)
    {
        for (auto& replica : replicas)
        {
            if (replica.host == hostname)
            {
                return replica.lag;
            }
        }
        return duration::max();
    }

    /// Recomputes rates and automatic fallback state of all hosts, at most once per
    /// distinct value returned by time(). Expects mutex_ to be held.
    void update_stats(const yplatform::task_context_ptr& ctx, time_t now)
    {
        if (stats_update_ts_ == now)
        {
            return;
        }
        stats_update_ts_ = now;
        update_stats(ctx, master_, now);
        update_stats(ctx, default_replica_, now);
        for (auto& replica : replicas_)
        {
            update_stats(ctx, replica, now);
        }
    }

    void update_stats(const yplatform::task_context_ptr& ctx, host& h, time_t now)
    {
        update_rates(h, now);
        update_fallback_state(ctx, h);
    }

    /// Applies the automatic fallback thresholds to @p h. Forced fallback and a
    /// disabled auto-fallback switch are left untouched. At most one transition
    /// happens per call, so misconfigured thresholds cannot flip a host twice.
    void update_fallback_state(const yplatform::task_context_ptr& ctx, host& h)
    {
        if (!auto_fallback_enabled_ || h.state == host_state::forced_fallback)
        {
            return;
        }
        if (h.state == host_state::ok && h.error_rate > settings_.enable_fallback_error_rate)
        {
            enable_fallback(ctx, h, false);
        }
        else if (
            h.state == host_state::fallback && h.error_rate < settings_.disable_fallback_error_rate)
        {
            disable_fallback(ctx, h);
        }
    }

    /// Recreates @p h's implementation with fallback settings.
    void enable_fallback(const yplatform::task_context_ptr& ctx, host& h, bool forced)
    {
        const char* forced_str = forced ? "forced " : "";
        YLOG_CTX_LOCAL(ctx, info) << "enabling " << forced_str
                                  << "fallback, error_rate: " << h.error_rate
                                  << " rps: " << h.rps << " dbname " << dbname_;
        h.pq = make_impl_(set_conninfo(settings_.fallback, h.conninfo));
        h.state = forced ? host_state::forced_fallback : host_state::fallback;
    }

    /// Recreates @p h's implementation with its regular (master or replica) settings.
    void disable_fallback(const yplatform::task_context_ptr& ctx, host& h)
    {
        YLOG_CTX_LOCAL(ctx, info) << "disabling fallback, error_rate: " << h.error_rate
                                  << " rps: " << h.rps << " dbname " << dbname_;
        const auto& base = h.is_master ? settings_.master : settings_.replica;
        h.pq = make_impl_(set_conninfo(base, h.conninfo));
        h.state = host_state::ok;
    }

    /// Returns a promise fulfilled with @p pq_fres's value or exception once it
    /// completes, and records one request in @p rate (plus one error on exception).
    template <typename Promise, typename Future>
    Promise count_errors(const rate_accumulator_ptr& rate, Future pq_fres)
    {
        Promise prom;
        pq_fres.add_callback([rate, pq_fres, prom]() mutable {
            try
            {
                prom.set(pq_fres.get());
                rate->add(time(nullptr), { 1, 0 });
            }
            catch (...)
            {
                prom.set_exception(std::current_exception());
                rate->add(time(nullptr), { 1, 1 });
            }
        });
        return prom;
    }

    rate_accumulator_ptr make_rate()
    {
        return std::make_shared<wrapped_rate_accumulator>(
            std::make_shared<rate_accumulator>(settings_.rate_accumulator_window));
    }

    /// Recomputes error_rate and rps of @p h over the window ending at @p now.
    void update_rates(host& h, time_t now)
    {
        auto rate_counts = h.rate->get_sum(now, settings_.rate_accumulator_window);
        if (rate_counts.requests != 0)
        {
            h.error_rate = rate_counts.errors / static_cast<double>(rate_counts.requests);
            h.rps = rate_counts.requests / static_cast<double>(settings_.rate_accumulator_window);
        }
        else
        {
            h.error_rate = 0.0;
            h.rps = 0.0;
        }
    }

    /// Copy of @p st with its conninfo replaced by @p c.
    ymod_pq::settings set_conninfo(ymod_pq::settings st, const conninfo& c)
    {
        st.conninfo = c.to_string();
        return st;
    }

    /// Copy of the master conninfo pointed at @p host, with target_session_attrs
    /// removed (presumably so the connection is not restricted to a read-write server).
    conninfo make_replica_conninfo(const std::string& host)
    {
        auto replica_conninfo = conninfo_;
        replica_conninfo.reset("target_session_attrs");
        replica_conninfo.set("host", host);
        return replica_conninfo;
    }

    /// Stats subtree for @p h; replica lag is looked up in @p replicas_info.
    yplatform::ptree get_stats(const host& h, const std::vector<replica_info>& replicas_info) const
    {
        yplatform::ptree host_tree;
        host_tree.put("conninfo", h.conninfo.to_string());
        host_tree.put("is_master", h.is_master);
        if (!h.is_master)
        {
            host_tree.put("lag", find_replica_lag(replicas_info, h.name));
        }
        host_tree.put("state", to_string(h.state));
        host_tree.put("error_rate", detail::rate_str(h.error_rate));
        host_tree.put("rps", detail::rate_str(h.rps));
        host_tree.put_child("pq", h.pq->get_stats());
        return host_tree;
    }

    /// {"replicas": {<host with '.' -> '_'>: lag, ...}}
    auto get_replica_monitor_stats(std::vector<replica_info> replicas) const
    {
        yplatform::ptree ret;
        yplatform::ptree replica_stat;
        for (auto& r : replicas)
        {
            std::replace(r.host.begin(), r.host.end(), '.', '_');
            replica_stat.put(r.host, r.lag);
        }
        ret.put_child("replicas", replica_stat);
        return ret;
    }

    host_health get_health(const host& h)
    {
        return { h.state != host_state::ok, h.rps, h.error_rate };
    }

    std::string to_string(host_state s) const
    {
        using namespace std::string_literals;
        switch (s)
        {
        case host_state::ok:
            return "ok"s;
        case host_state::fallback:
            return "fallback"s;
        case host_state::forced_fallback:
            return "forced_fallback"s;
        }
        return "unknown"s;
    }

    const conninfo conninfo_;
    const std::string dbname_;
    const call_impl_factory make_impl_;
    const replica_monitor_ptr replica_monitor_;
    const settings settings_;

    host master_;
    host default_replica_;
    std::vector<host> replicas_;
    bool auto_fallback_enabled_;
    std::time_t stats_update_ts_;
    mutable std::mutex mutex_;
};

}