#include "session.h"

#include <yplatform/encoding/url_encode.h>
#include <yplatform/util/split.h>

namespace yxiva { namespace web { namespace websocket_rpc {

static void send_bad_request(websocket_stream_ptr stream, const std::string& reason)
{
    static const string EMPTY;
    send_response(*stream, EMPTY, ymod_webserver::codes::bad_request, reason);
}

// Converts a JSON id-like value to its string form.
// Strings and numbers use their plain textual value; objects are serialized
// with stringify(). Any other JSON type (null, bool, array, ...) is rejected:
// `out` is left untouched and false is returned.
static bool json_to_string(const json_value& in, string& out)
{
    const bool is_scalar = in.is_string() || in.is_number();
    if (!is_scalar && !in.is_object())
    {
        return false;
    }

    if (is_scalar)
        out.assign(in.to_string());
    else
        out.assign(in.stringify());
    return true;
}

// Returns true for requests exempt from the rps check in rate_limit_exceeded();
// currently only requests whose full path is exactly "ping".
static bool unlimited(const ymod_webserver::request& req)
{
    const auto& path = req.url.make_full_path();
    return path == "ping";
}

static void send_limit_exceeded(websocket_stream_ptr& stream)
{
    stream->close_connection(
        ymod_webserver::websocket::codes::too_many_requests, "messages limit exceeded");
}

// Creates an RPC session on top of an accepted websocket stream.
// The session keeps the stream's context (ctx_) for logging and stores the
// stream in ws_stream_, which is lock()ed before every use elsewhere in this
// file (i.e. a weak reference). Construction bumps both the total and the
// active session counters; the destructor decrements active_sessions.
session::session(
    const websocket_stream_ptr& ws_stream,
    settings_ptr settings,
    std::shared_ptr<dispatcher> dispatcher,
    std::shared_ptr<stats> stats,
    const yplatform::log::source& logger,
    const yplatform::log::tskv_logger& typed_logger)
    : yplatform::log::contains_logger(logger)
    // Session callbacks are serialized on a strand of the stream's io_service.
    , strand_(ws_stream->get_io_service())
    , ctx_(ws_stream->ctx())
    , ws_stream_(ws_stream)
    , settings_(settings)
    , dispatcher_(dispatcher)
    , stats_(stats)
    , typed_logger_(typed_logger)
{
    ++stats_->sessions;
    ++stats_->active_sessions;
    YLOG_CTX_LOCAL(ctx_, info) << "new RPC session";
}

// Tears the session down: closes any remaining RPC streams and the websocket
// (via shutdown_nosync), then updates the active-session gauge and logs.
// A destructor must not throw, so every step is best-effort. The shutdown step
// is guarded separately from the bookkeeping: if closing fails, active_sessions
// is still decremented instead of staying inflated forever.
session::~session()
{
    try
    {
        shutdown_nosync();
    }
    catch (...)
    {
        // Best-effort: nothing can be reported from a destructor.
    }

    try
    {
        --stats_->active_sessions;
        // TODO add stats, duration
        YLOG_CTX_LOCAL(ctx_, info) << "RPC session closed";
    }
    catch (...)
    {
    }
}

// Wires the session to its websocket and starts receiving messages.
// Both callbacks are wrapped into strand_ and capture a strong self reference,
// so the session outlives the stream's use of them. Does nothing if the
// websocket is already gone.
void session::run()
{
    auto self = shared_from_this();
    auto locked_ws_stream = ws_stream_.lock();
    if (!locked_ws_stream) return;

    // Exceptions in callbacks are caught by webserver.
    locked_ws_stream->set_close_callback(
        strand_.wrap([this, self](uint16_t code, const string& reason) {
            // Propagate the websocket close to every RPC stream that is still alive.
            for (auto& pair : streams_)
            {
                if (auto stream = pair.second.lock())
                {
                    stream->close(code, reason);
                }
            }
        }));
    locked_ws_stream->add_message_callback(strand_.wrap([this, self](const ws_message_t& msg) {
        // Only text frames and continuation frames are collected; other opcodes
        // are ignored.
        if (msg.opcode == ws_message_t::opcode_text ||
            msg.opcode == ws_message_t::opcode_continuation)
        {
            // Accumulate fragments until the final frame of the message arrives.
            // NOTE(review): input_buffer_ has no size cap here; consider bounding it.
            input_buffer_.append(msg.data.begin(), msg.data.end());
            if (msg.is_finished())
            {
                // Detach the complete message *before* executing it. execute() may
                // throw (the webserver catches it), and the buffer must not keep
                // stale data that would be prepended to the next message.
                string request;
                request.swap(input_buffer_);
                execute(request);
            }
        }
    }));
    locked_ws_stream->begin_receive();
}

// Thread-safe entry point for shutdown_nosync(): dispatches it onto strand_
// (Asio dispatch may run it inline if the caller is already on the strand).
// The captured self reference keeps the session alive until the handler has run.
void session::shutdown()
{
    auto keep_alive = shared_from_this();
    strand_.dispatch([this, keep_alive]() { shutdown_nosync(); });
}

// Closes all RPC streams that are still alive with code 1001 / "goaway",
// forgets them, and then closes the websocket connection itself (if it is
// still reachable) with the go-away close code.
// Must run on strand_ (or from the destructor); it does no synchronization itself.
void session::shutdown_nosync()
{
    for (auto it = streams_.begin(); it != streams_.end(); ++it)
    {
        auto alive = it->second.lock();
        if (!alive) continue;
        alive->close(1001, "goaway");
    }
    streams_.clear();

    auto ws = ws_stream_.lock();
    if (!ws) return;
    ws->close_connection(ymod_webserver::websocket::codes::close_opcode_go_away, "goaway");
}

// Decodes a JSON-RPC call into a new RPC stream.
//
// Builds a synthetic request that inherits connection attributes (addresses,
// ports, start time, state, headers, origin, host) from the websocket's own
// context and request. The path comes from "method" split on '/'; the URL
// params come from the "params" object.
//
// Returns an error reason when:
//   * "params" is not an object, or any of its values is not a string;
//   * "id" is present but is not a string, number or object.
// On success, rpc_stream owns the new stream. Its deleter holds a strong
// reference to this session.
operation::result session::make_stream(
    const websocket_stream_ptr& stream,
    const json_value& json_msg,
    std::shared_ptr<websocket_rpc::stream>& rpc_stream)
{
    auto req = boost::make_shared<request>();
    req->context = boost::make_shared<ymod_webserver::context>();
    req->context->local_address = stream->ctx()->local_address;
    req->context->local_port = stream->ctx()->local_port;
    req->context->remote_address = stream->ctx()->remote_address;
    req->context->remote_port = stream->ctx()->remote_port;
    req->context->start_time = stream->ctx()->start_time;
    req->context->state = stream->ctx()->state;
    req->context->set_request(req);
    req->headers = stream->request()->headers;
    req->origin = stream->request()->origin;
    req->url.host.domain = stream->request()->url.host.domain;
    req->url.host.port = stream->request()->url.host.port;
    string full_path;
    full_path = json_get<string>(json_msg, "method", full_path);
    yplatform::util::split(req->url.path, full_path, "/");

    auto&& json_params = json_msg["params"];
    if (!json_params.is_object()) return "invalid params node type";
    for (auto it = json_params.members_begin(); it != json_params.members_end(); ++it)
    {
        auto&& key = it.key();
        if (!(*it).is_string())
        {
            return "invalid param value type";
        }
        req->url.params.emplace(string(key), (*it).to_string());
    }

    // Bind by reference (like "params" above) instead of copying the JSON node.
    auto&& json_request_id = json_msg["id"];
    string request_id;
    if (!json_request_id.is_null() && !json_to_string(json_request_id, request_id))
    {
        return "invalid request id type";
    }

    auto stream_logger = logger();
    stream_logger.set_log_prefix(stream_logger.get_log_prefix() + ctx_->uniq_id());
    // Build the deleter before allocating the stream. When `new` and
    // shared_from_this() are sibling call arguments, a throwing
    // shared_from_this() evaluated after `new` would leak the raw pointer.
    stream_deleter deleter{ shared_from_this() };
    rpc_stream.reset(
        new websocket_rpc::stream(stream, request_id, req, stream_logger, typed_logger_),
        deleter);
    return operation::success;
}

// Handles one complete websocket text message.
//
// Enforces the per-session message limit first. It then parses the message as JSON:
//   * a message with "result" or "error" is treated as an incoming response.
//     A successful result whose "id" matches a live stream is routed to that
//     stream; anything else is only logged.
//   * any other message is decoded into a new RPC stream via make_stream(),
//     checked against the rate limit, registered in streams_, and handed to
//     the dispatcher.
// Exceptions (e.g. from route_incoming_message) propagate to the caller.
void session::execute(const string& websocket_rpc_request)
{
    // The websocket may already be gone. It can expire on its own, and this
    // function itself drops it after the message limit is hit, while messages
    // that were already delivered can still arrive here. Every branch below
    // dereferences it, so bail out early (mirrors the check in run()).
    auto locked_ws_stream = ws_stream_.lock();
    if (!locked_ws_stream) return;

    if (message_limit_exceeded())
    {
        send_limit_exceeded(locked_ws_stream);
        // Forget the connection: later messages are ignored by the check above.
        ws_stream_.reset();
        return;
    }

    json_value json_msg;
    if (auto error = json_msg.parse(websocket_rpc_request))
    {
        send_bad_request(locked_ws_stream, *error);
        return;
    }

    if (!json_msg["result"].is_null() || !json_msg["error"].is_null())
    {
        // Session and stream share knowledge about response routing by stream ctx id.
        // Only successful results are routed; error responses are just logged.
        bool processed = false;
        string stream_id;
        if (json_msg["error"].is_null() && json_to_string(json_msg["id"], stream_id) &&
            route_incoming_message(stream_id, json_msg))
        {
            processed = true;
        }

        if (!processed)
        {
            // TODO limit response length
            // stringify to drop original formatting.
            YLOG_CTX_LOCAL(locked_ws_stream->ctx(), info)
                << "got unexpected response: " << json_msg.stringify();
        }
        return;
    }

    std::shared_ptr<websocket_rpc::stream> rpc_stream;
    auto decoded = make_stream(locked_ws_stream, json_msg, rpc_stream);
    if (!decoded)
    {
        send_bad_request(locked_ws_stream, decoded.error_reason);
        return;
    }
    ++stats_->streams;
    ++stats_->active_streams;

    if (rate_limit_exceeded(rpc_stream))
    {
        rpc_stream->result(ymod_webserver::codes::too_many_requests, "rate limit");
        return;
    }

    stats_->accept_request();

    streams_.emplace(rpc_stream->ctx()->uniq_id(), rpc_stream);
    (*dispatcher_)(rpc_stream);
}

// Counts one more received message, both in shared stats and in this session's
// own counter. Returns true once the session's counter goes above
// settings_->websocket_rpc.session_max_messages.
bool session::message_limit_exceeded()
{
    ++stats_->messages_received;
    ++messages_received_;
    return messages_received_ > settings_->websocket_rpc.session_max_messages;
}

// Registers a new request via stats_->new_request() and reports whether the
// returned current rps is above settings_->rps_limit.
// new_request() is called for every request, including those matched by
// unlimited(), which are never reported as exceeding the limit.
bool session::rate_limit_exceeded(const std::shared_ptr<websocket_rpc::stream>& rpc_stream)
{
    const auto rps_now = stats_->new_request();
    if (unlimited(*rpc_stream->request()))
    {
        return false;
    }
    const int limit = static_cast<int>(settings_->rps_limit);
    return rps_now > limit;
}

// Delivers the "result" part of an incoming response to the stream registered
// under stream_id. Returns false when no such stream is registered or it has
// already expired. Exceptions from the stream are passed to the caller.
bool session::route_incoming_message(const string& stream_id, const json_value& json_msg)
{
    auto found = streams_.find(stream_id);
    if (found == streams_.end()) return false;

    auto target = found->second.lock();
    if (!target) return false;

    target->pick_incoming_message(json_msg["result"]);
    return true;
}

}}}
