#include "impl.h"
#include <ymod_cache/error.h>
#include "simple_mapper.h"
#include "consistent_mapper.h"
#include "protocol.h"

#include <yplatform/exception.h>

#include <boost/functional/hash.hpp>
#include <boost/lexical_cast.hpp>
#include <sstream>
#include <boost/property_tree/ptree.hpp>
#include <boost/range/algorithm/transform.hpp>

namespace ymod_cache {
namespace memcached {

using boost::make_shared;

// Computes a hash of a segment by feeding the range [seg.begin(), seg.end())
// to boost::hash_range.
std::size_t hash_value(const segment& seg)
{
    const std::size_t digest = boost::hash_range(seg.begin(), seg.end());
    return digest;
}

// Constructs the module: passes "memcached_client" to the client_module base
// constructor and stores the result of
// yplatform::reactor::make_not_owning_copy(reactor) in reactor_.
impl::impl(yplatform::reactor& reactor)
    : yplatform::net::client_module<session>("memcached_client")
    , reactor_(yplatform::reactor::make_not_owning_copy(reactor))
{
}

// Configures the module from `xml`.
//
// Steps, in order:
//   - runs impl_base::init(xml);
//   - opens the net client with the "memcached" subtree;
//   - reads recovery_timeout, max_hops_count and the per-operation
//     *_queue_timeout values (wrapped in milliseconds) with their defaults;
//   - reads "host:port" entries from the file named by "servers"; for every
//     entry it creates one session pool per reactor pool and a recovery timer;
//   - creates the hash mapper chosen by hash_mapper.<xmlattr>.method
//     ("simple" by default, or "consistent", which also reads "seed").
//
// Throws internal_error when the net client options are rejected, when no
// server was read, or when the mapper method is unknown. Nothing is caught
// here, so exceptions from the property tree accessors and from
// boost::lexical_cast propagate to the caller.
void impl::init(const yplatform::ptree& xml)
{
    using yplatform::time_traits::milliseconds;

    impl_base::init(xml);

    // Everything below is read from the "memcached" subtree.
    const yplatform::ptree& memcached_xml = xml.get_child("memcached");
    if (!yplatform::net::client_module<session>::open_client(memcached_xml, logger()))
    {
        throw internal_error("invalid net client options") << BOOST_ERROR_INFO;
    }

    recovery_timeout_ = memcached_xml.get("recovery_timeout", 60);
    max_hops_count_ = memcached_xml.get("max_hops_count", 0);
    get_queue_timeout_ = milliseconds(memcached_xml.get("get_queue_timeout", 0));
    has_queue_timeout_ = milliseconds(memcached_xml.get("has_queue_timeout", 0));
    remove_queue_timeout_ = milliseconds(memcached_xml.get("remove_queue_timeout", 0));
    set_queue_timeout_ = milliseconds(memcached_xml.get("set_queue_timeout", 0));

    // Load the server list: the text before ':' is the host, the rest of the
    // line is the port.
    const string servers_path = memcached_xml.get<string>("servers");
    std::ifstream servers_file(servers_path.c_str());
    for (;;)
    {
        if (!servers_file.good())
            break;

        string host;
        std::getline(servers_file, host, ':');
        // Reaching end of file while looking for ':' means there is no
        // further complete entry.
        if (servers_file.eof())
            break;

        string port_text;
        std::getline(servers_file, port_text);
        const unsigned short port = boost::lexical_cast<unsigned short>(port_text);
        const std::size_t max_sessions_count = memcached_xml.get("max_connections", 1);
        const std::size_t max_queue_size = memcached_xml.get("max_queue_size", 1);

        YLOG_LOCAL(info) << "add server " << host << ":" << port;

        // One session pool per reactor pool, keyed by that pool's io().
        server_info::session_pools_t pools;
        for (const auto& reactor_pool : reactor_->get_pools())
        {
            const auto io = reactor_pool->io();
            YLOG_LOCAL(info) << "add pool for server " << host << ":" << port
                             << " io=" << static_cast<const void*>(io)
                             << " max_sessions_count=" << max_sessions_count
                             << " max_queue_size=" << max_queue_size;
            pools[io].reset(new session_pool(max_sessions_count, max_queue_size));
        }

        servers_.emplace_back(server_info(host, port, std::move(pools)));
        servers_.back().recovery_timer.reset(new yplatform::net::timer_t(*client_->get_io()));
    }

    if (servers_.empty())
    {
        throw internal_error("no servers specified") << BOOST_ERROR_INFO;
    }

    // Build the hash mapper over the server index range [0, size - 1].
    const yplatform::ptree& mapper_xml = memcached_xml.get_child("hash_mapper");
    const string mapper_method = mapper_xml.get("<xmlattr>.method", "simple");
    if (mapper_method == "simple")
    {
        hash_mapper_.reset(new simple_mapper<MAPPER_TEMPLATE_ARGS>(0, servers_.size()-1));
    }
    else if (mapper_method == "consistent")
    {
        size_t dots_per_output = mapper_xml.get("dots_per_server", 100);
        uint32_t seed = mapper_xml.get<uint32_t>("seed");
        YLOG_LOCAL(info) << "using supplied seed: " << seed;

        hash_mapper_.reset(new consistent_mapper<MAPPER_TEMPLATE_ARGS>(
            0, servers_.size()-1, seed, dots_per_output));
    }
    else
    {
        std::ostringstream message;
        message << "unknown hash mapping method: '" << mapper_method << "'";
        throw internal_error(message.str()) << BOOST_ERROR_INFO;
    }
}

// Starts the module; the only action taken here is the call to run_client().
void impl::start()
{
    run_client();
}

void impl::stop()
{
    for (server_vector::iterator it = servers_.begin(); it != servers_.end(); ++it)
    {
        it->recovery_timer.reset();
    }
}

// Finalizes the module: calls stop_client() first, then impl_base::fini().
void impl::fini()
{
    stop_client();
    impl_base::fini();
}

// Recovery timer handler for server `server_id` (armed in mask_server()).
// Unless the wait was aborted, the server is unmasked in the hash mapper and
// the event is logged.
void impl::server_recovery_hook(
    const boost::system::error_code& err,
    std::size_t server_id)
{
    // An aborted wait must not unmask the server.
    if (err == boost::asio::error::operation_aborted)
        return;

    hash_mapper_->unmask(server_id);
    const server_info& recovered = servers_[server_id];
    YLOG_LOCAL(error) << "server '" << recovered.address << ":" << recovered.port << "' unmasked";
}

// Stores `value` under `key`.
//
// Builds a PROTOCOL_BINARY_CMD_SET request with flags 0 and expiration
// value_ttl_, submits it through execute() with set_queue_timeout_, and
// attaches set_cb() to translate the response into the returned result.
//
// ctx   - task context used for logging and passed on to execute().
// key   - cache key bytes.
// value - payload bytes to store.
// Returns the future associated with the promise that set_cb() completes.
future_result impl::set(
    task_context_ptr ctx,
    const segment& key,
    const segment& value)
{
    // Log the actual payload under "value=" (the key bytes were previously
    // printed twice).
    YLOG_CTX_LOCAL(ctx, debug) << "set key=" << std::string(key.begin(), key.end())
                     << " value=" << std::string(value.begin(), value.end());

    promise_result out;

    request_ptr req(new_request(PROTOCOL_BINARY_CMD_SET, key, value));
    req->variant.set.message.body.flags = 0;
    req->variant.set.message.body.expiration = value_ttl_;

    future_response resp_f = execute(ctx, req, set_queue_timeout_);
    resp_f.add_callback(boost::bind(&impl::set_cb, this, resp_f, out));
    return out;
}

// Fetches the value stored under `key`.
//
// Sends a PROTOCOL_BINARY_CMD_GET request (no value) through execute() with
// get_queue_timeout_; get_cb() converts the response and completes the
// promise whose future is returned.
future_segment impl::get(
    task_context_ptr ctx,
    const segment& key)
{
    YLOG_CTX_LOCAL(ctx, debug) << "get key=" << std::string(key.begin(), key.end());

    promise_segment out;

    request_ptr request(new_request(PROTOCOL_BINARY_CMD_GET, key, boost::none));
    future_response response = execute(ctx, request, get_queue_timeout_);

    // The callback holds its own copies of the future and the promise.
    response.add_callback([this, response, out]() mutable {
        this->get_cb(response, out);
    });
    return out;
}

// Checks whether `key` is present.
//
// Issues a PROTOCOL_BINARY_CMD_GET request (no value) through execute() with
// has_queue_timeout_; has_cb() maps the response to a boolean and completes
// the promise whose future is returned.
future_bool impl::has(
    task_context_ptr ctx,
    const segment& key)
{
    YLOG_CTX_LOCAL(ctx, debug) << "has key=" << std::string(key.begin(), key.end());

    promise_bool out;

    request_ptr request(new_request(PROTOCOL_BINARY_CMD_GET, key, boost::none));
    future_response response = execute(ctx, request, has_queue_timeout_);

    // The callback holds its own copies of the future and the promise.
    response.add_callback([this, response, out]() mutable {
        this->has_cb(response, out);
    });
    return out;
}

// Deletes the entry stored under `key`.
//
// Sends a PROTOCOL_BINARY_CMD_DELETE request (no value) through execute()
// with remove_queue_timeout_; remove_cb() converts the response and completes
// the promise whose future is returned.
future_result impl::remove(
    task_context_ptr ctx,
    const segment& key)
{
    YLOG_CTX_LOCAL(ctx, debug) << "remove key=" << std::string(key.begin(), key.end());

    promise_result out;

    request_ptr request(new_request(PROTOCOL_BINARY_CMD_DELETE, key, boost::none));
    future_response response = execute(ctx, request, remove_queue_timeout_);

    // The callback holds its own copies of the future and the promise.
    response.add_callback([this, response, out]() mutable {
        this->remove_cb(response, out);
    });
    return out;
}

// Returns a newly allocated impl_stats object constructed from servers_.
const impl::module_stats_ptr impl::get_module_stats() const {
    module_stats_ptr stats = boost::make_shared<impl_stats>(servers_);
    return stats;
}

// Routes `req` to a server and asks that server for a session.
//
// The server index comes from hash_mapper_->map() applied to the hash of the
// request key. If mapping throws, the exception is stored in `resp_p` and no
// session is requested. Otherwise get_session_cb() is invoked once the
// server's get_session() reports a result.
//
// ctx           - task context forwarded to the callbacks.
// req           - request to send.
// queue_timeout - timeout passed to get_session().
// hops_count    - attempt counter; execute_cb() compares it against
//                 max_hops_count_ and passes an incremented value on retry.
// resp_p        - promise that eventually receives the response.
// Returns the future associated with `resp_p`.
future_response impl::execute(
    task_context_ptr ctx,
    const request_ptr& req,
    const time_duration& queue_timeout,
    unsigned hops_count,
    promise_response resp_p)
{
    std::size_t target;

    try
    {
        target = hash_mapper_->map(hash_value(req->key));
    }
    catch (const yplatform::exception& mapping_error)
    {
        resp_p.set_exception(mapping_error);
        return resp_p;
    }
    catch (...)
    {
        resp_p.set_current_exception();
        return resp_p;
    }

    // Everything the continuation needs is captured by value.
    const auto on_session = [this, ctx, req, queue_timeout, target, resp_p, hops_count]
            (const boost::system::error_code& ec, session_handle handle) {
        this->get_session_cb(ctx, req, queue_timeout, target, resp_p, hops_count, ec, std::move(handle));
    };

    servers_[target].get_session(io(), on_session, queue_timeout);

    return resp_p;
}

// Continuation of execute(): called with the outcome of get_session().
//
// On error the failure is logged and `resp_p` receives
// error_code::get_session_error. Otherwise an empty handle is filled with a
// freshly created session pointed at the chosen server, while a non-empty
// handle has its session marked as not new. The request is then run on the
// session and execute_cb() is scheduled for the session's response.
void impl::get_session_cb(task_context_ptr ctx,
    const request_ptr& req,
    const time_duration& queue_timeout,
    std::size_t server_id,
    promise_response resp_p,
    unsigned hops_count,
    const boost::system::error_code& error,
    session_handle sess_handle)
{
    using namespace yamail;

    if (error) {
        YLOG_CTX_LOCAL(ctx, error) << "get session error: " << error.message();
        resp_p.set(error_code::get_session_error);
        return;
    }

    if (!sess_handle.empty()) {
        // The handle already carries a session.
        sess_handle.get()->is_new(false);
    } else {
        // No session was handed out: create one for this server.
        session_ptr fresh = create_session();
        const server_info& target = servers_[server_id];
        fresh->set_server(target.address, target.port);
        sess_handle.reset(std::move(fresh));
    }

    const auto& active = sess_handle.get();
    YLOG_CTX_LOCAL(ctx, debug) << "use session=" << active;
    active->set_context(ctx);
    future_response sess_resp_f = active->run(req);

    // The handle is moved into shared storage so that the callback can later
    // move it on to execute_cb().
    auto handle_holder = std::make_shared<session_handle>(std::move(sess_handle));
    auto on_response = [this, ctx, req, queue_timeout, server_id, sess_resp_f, resp_p, hops_count, handle_holder] () mutable {
        this->execute_cb(ctx, req, queue_timeout, server_id, std::move(*handle_holder), sess_resp_f, resp_p, hops_count);
    };
    sess_resp_f.add_callback(std::move(on_response));
}

// Handles the session's response to a request.
//
// When sess_resp_f.get() succeeds, the result is forwarded to `resp_p`; the
// session is recycled first if the response carries one of the listed status
// codes. When anything in that path throws, the request is retried through
// execute() with an incremented hop counter until max_hops_count_ is reached,
// at which point the current exception is stored in `resp_p`.
void impl::execute_cb(
    task_context_ptr ctx,
    const request_ptr& req,
    const time_duration& queue_timeout,
    std::size_t server_id,
    session_handle sess_handle,
    future_response sess_resp_f,
    promise_response resp_p,
    unsigned hops_count)
{
    server_info& server = servers_[server_id];

    try {
        response_result resp = sess_resp_f.get();

        if (resp) {
            // Recycle the session only for statuses that show the request
            // was parsed properly by the server.
            switch(resp.value->variant.generic.message.header.response.status)
            {
            case PROTOCOL_BINARY_RESPONSE_SUCCESS:
            case PROTOCOL_BINARY_RESPONSE_KEY_ENOENT:
            case PROTOCOL_BINARY_RESPONSE_KEY_EEXISTS:
            case PROTOCOL_BINARY_RESPONSE_NOT_STORED:
            case PROTOCOL_BINARY_RESPONSE_ENOMEM:
                sess_handle.recycle();
                break;
            default:
                break;
            }
        }

        resp_p.set(resp);
    } catch (...) {
        if (hops_count < max_hops_count_)
        {
            // A failure on a newly created session masks the server.
            if (sess_handle.get()->is_new()) {
                mask_server(ctx, server_id);
            }

            sess_handle.waste();

            execute(ctx, req, queue_timeout, hops_count + 1, resp_p);
            return;
        }

        // Hop limit reached: report the failure to the caller.
        YLOG_CTX_LOCAL(ctx, error)
            << "server '"
            << server.address << ":" << server.port << "' failed";
        resp_p.set_current_exception();
    }
}

// Completion handler for set(): translates the raw response into `out`.
//
// resp_f - future holding the response produced by execute().
// out    - promise to complete. It receives the transport error when the
//          result is not ok, a backend_error when the response opcode is not
//          SET or the status is not SUCCESS, and void_result() on success.
// Exceptions raised while processing are stored in `out`.
void impl::set_cb(
    future_response resp_f,
    promise_result out)
{
    try
    {
        response_result result = resp_f.get();

        if (!result) {
            out.set(result.error);
            return;
        }

        response_ptr resp = result.value;

        if (resp->variant.generic.message.header.response.opcode != PROTOCOL_BINARY_CMD_SET)
        {
            // XXX: mask server in this case?
            out.set_exception(backend_error() << BOOST_ERROR_INFO << yplatform::error_private_info("bad response from server"));
            // `out` is complete now; falling through to the status switch
            // would complete the same promise a second time.
            return;
        }

        switch (resp->variant.set.message.header.response.status)
        {
        case PROTOCOL_BINARY_RESPONSE_SUCCESS:
            out.set(void_result());
            return;
        default:
        {
            std::stringstream ss;
            ss << "memcached: " << resp->value;
            out.set_exception(backend_error() << BOOST_ERROR_INFO << yplatform::error_private_info(ss.str()));
            return;
        }
        }
    }
    catch (const ymod_cache::error& e)
    {
        out.set_exception(e);
    }
    catch(const std::exception& e)
    {
        out.set_exception(internal_error() << BOOST_ERROR_INFO << yplatform::error_private_info(e.what()));
    }
    catch(...)
    {
        out.set_current_exception();
    }
}

// Completion handler for get(): translates the raw response into `out`.
//
// resp_f - future holding the response produced by execute().
// out    - promise to complete. It receives the transport error when the
//          result is not ok, the response value on SUCCESS, boost::none on
//          KEY_ENOENT, and a backend_error when the opcode is not GET or the
//          status is anything else.
// Exceptions raised while processing are stored in `out`.
void impl::get_cb(
    future_response resp_f,
    promise_segment out)
{
    try
    {
        response_result result = resp_f.get();

        if (!result) {
            out.set(result.error);
            return;
        }

        response_ptr resp = result.value;

        if (resp->variant.generic.message.header.response.opcode != PROTOCOL_BINARY_CMD_GET)
        {
            // XXX: mask server in this case?
            out.set_exception(backend_error() << BOOST_ERROR_INFO << yplatform::error_private_info("bad response from server"));
            // `out` is complete now; falling through to the status switch
            // would complete the same promise a second time.
            return;
        }

        // Read the status through the same `generic` view that was used for
        // the opcode check above.
        switch (resp->variant.generic.message.header.response.status)
        {
        case PROTOCOL_BINARY_RESPONSE_SUCCESS:
            out.set(resp->value);
            return;
        case PROTOCOL_BINARY_RESPONSE_KEY_ENOENT:
            out.set(boost::none);
            return;
        default:
        {
            std::stringstream ss;
            ss << "memcached: " << resp->value;
            out.set_exception(backend_error() << BOOST_ERROR_INFO << yplatform::error_private_info(ss.str()));
            return;
        }
        }
    }
    catch (const ymod_cache::error& e)
    {
        out.set_exception(e);
    }
    catch(const std::exception& e)
    {
        out.set_exception(internal_error() << BOOST_ERROR_INFO << yplatform::error_private_info(e.what()));
    }
    catch(...)
    {
        out.set_current_exception();
    }
}

// Completion handler for has(): translates the raw response into `out`.
//
// resp_f - future holding the response produced by execute().
// out    - promise to complete. It receives the transport error when the
//          result is not ok, true on SUCCESS, false on KEY_ENOENT, and a
//          backend_error when the opcode is not GET or the status is
//          anything else.
// Exceptions raised while processing are stored in `out`.
void impl::has_cb(
    future_response resp_f,
    promise_bool out)
{
    try
    {
        response_result result = resp_f.get();

        if (!result) {
            out.set(result.error);
            return;
        }

        response_ptr resp = result.value;

        if (resp->variant.generic.message.header.response.opcode != PROTOCOL_BINARY_CMD_GET)
        {
            // XXX: mask server in this case?
            out.set_exception(backend_error() << BOOST_ERROR_INFO << yplatform::error_private_info("bad response from server"));
            // `out` is complete now; falling through to the status switch
            // would complete the same promise a second time.
            return;
        }

        // Read the status through the same `generic` view that was used for
        // the opcode check above.
        switch (resp->variant.generic.message.header.response.status)
        {
        case PROTOCOL_BINARY_RESPONSE_SUCCESS:
            out.set(true);
            return;
        case PROTOCOL_BINARY_RESPONSE_KEY_ENOENT:
            out.set(false);
            return;
        default:
        {
            std::stringstream ss;
            ss << "memcached: " << resp->value;
            out.set_exception(backend_error() << BOOST_ERROR_INFO << yplatform::error_private_info(ss.str()));
            return;
        }
        }
    }
    catch (const ymod_cache::error& e)
    {
        out.set_exception(e);
    }
    catch(const std::exception& e)
    {
        out.set_exception(internal_error() << BOOST_ERROR_INFO << yplatform::error_private_info(e.what()));
    }
    catch(...)
    {
        out.set_current_exception();
    }
}

// Completion handler for remove(): translates the raw response into `out`.
//
// resp_f - future holding the response produced by execute().
// out    - promise to complete. It receives the transport error when the
//          result is not ok, void_result() on SUCCESS or KEY_ENOENT, and a
//          backend_error when the opcode is not DELETE or the status is
//          anything else.
// Exceptions raised while processing are stored in `out`.
void impl::remove_cb(
    future_response resp_f,
    promise_result out)
{
    try
    {
        response_result result = resp_f.get();

        if (!result) {
            out.set(result.error);
            return;
        }

        response_ptr resp = result.value;

        if (resp->variant.generic.message.header.response.opcode != PROTOCOL_BINARY_CMD_DELETE)
        {
            // XXX: mask server in this case?
            out.set_exception(backend_error() << BOOST_ERROR_INFO << yplatform::error_private_info("bad response from server"));
            // `out` is complete now; falling through to the status switch
            // would complete the same promise a second time.
            return;
        }

        // Read the status through the same `generic` view that was used for
        // the opcode check above.
        switch (resp->variant.generic.message.header.response.status)
        {
        case PROTOCOL_BINARY_RESPONSE_SUCCESS:
        case PROTOCOL_BINARY_RESPONSE_KEY_ENOENT:
            out.set(void_result());
            return;
        default:
        {
            std::stringstream ss;
            ss << "memcached: " << resp->value;
            out.set_exception(backend_error() << BOOST_ERROR_INFO << yplatform::error_private_info(ss.str()));
            return;
        }
        }
    }
    catch (const ymod_cache::error& e)
    {
        out.set_exception(e);
    }
    catch(const std::exception& e)
    {
        out.set_exception(internal_error() << BOOST_ERROR_INFO << yplatform::error_private_info(e.what()));
    }
    catch(...)
    {
        out.set_current_exception();
    }
}

// Takes server `server_id` out of rotation: logs the failure, masks the
// server in the hash mapper and arms its recovery timer for
// recovery_timeout_ seconds. server_recovery_hook() is the timer's handler.
void impl::mask_server(task_context_ptr ctx, std::size_t server_id)
{
    const server_info& failed = servers_[server_id];

    YLOG_CTX_LOCAL(ctx, error) << "server '"
        << failed.address << ":" << failed.port << "' failed, masking it for "
        << recovery_timeout_ << " seconds";

    hash_mapper_->mask(server_id);

    failed.recovery_timer->expires_from_now(yplatform::time_traits::seconds(recovery_timeout_));
    failed.recovery_timer->async_wait(
        [this, server_id](const boost::system::error_code& ec) {
            this->server_recovery_hook(ec, server_id);
        });
}

// Builds a property tree describing one session pool. Entries, in order:
// "size", "available", "used" and "queue_size" (the size of the pool
// implementation's queue).
ptree make_session_pool_stats(const session_pool& pool)
{
    ptree node;

    node.put("size", pool.size());
    node.put("available", pool.available());
    node.put("used", pool.used());
    node.put("queue_size", pool.impl().queue().size());

    return node;
}

// Builds a property tree with one child per session pool of a server.
// Each child is named "pool_" followed by the streamed map key and contains
// the statistics of the corresponding pool.
ptree make_session_pool_stats(const server_info::session_pools_t& pools)
{
    ptree result;

    for (auto it = pools.begin(); it != pools.end(); ++it) {
        std::ostringstream label;
        label << "pool_" << it->first;
        result.add_child(label.str(), make_session_pool_stats(*it->second));
    }

    return result;
}

// Builds the statistics tree of a single server: a "session_pool" child that
// holds the statistics of all of the server's session pools.
ptree make_server_stats(const server_info& server)
{
    ptree node;
    node.add_child("session_pool", make_session_pool_stats(server.pools));
    return node;
}

// Formats the statistics key of a server as "<address>:<port>".
string make_server_key(const server_info& server)
{
    std::ostringstream key;
    key << server.address << ":" << server.port;
    return key.str();
}

// Appends the statistics of `server` to `ptree` as a (server key, server
// stats) pair.
void ptree_inserter(ptree& ptree, const server_info& server)
{
    auto key = make_server_key(server);
    auto stats = make_server_stats(server);
    ptree.push_back(make_pair(std::move(key), std::move(stats)));
}

// Returns the base-class statistics tree extended with one entry per server
// in servers_, added in container order.
impl_stats::ptree_ptr impl_stats::core_get_stat() const
{
    ptree_ptr stats = base::core_get_stat();

    for (const auto& server : servers_)
    {
        ptree_inserter(*stats, server);
    }

    return stats;
}
}}
