#include "subscribe.h"

#include "web/auth/user.h"
#include "web/auth/secure_sign.h"
#include "web/extract_requests.h"
#include "web/sse_subscriber.h"
#include "web/websocket_subscriber.h"
#include "web/websocket_rpc_subscriber.h"
#include "web/control_session.h"
#include "web/auth/oauth.h"
#include "web/rproxy_enable.h"
#include "web/utils/uniq.h"
#include <yxiva/core/iabstract.h>
#include <yxiva/core/methods/request_adaptor.h>
#include <yxiva/core/user_info.h>
#include <yxiva/core/authorizer.h>
#include "web/formatters/interface.h"
#include <yplatform/find.h>
#include <ymod_blackbox/client.h>
#include <algorithm>
#include <numeric>

namespace yxiva { namespace web {

// Short alias for the std::bind placeholders (ph::_1, ph::_2) used when
// binding the auth completion handlers in this file.
namespace ph = std::placeholders;

namespace {

// Transport a subscription request arrived on. It is used as a template
// argument to pick the per-protocol helpers in this file.
enum class protocol_type
{
    sse,
    websocket,
    websocket_rpc
};

// Per-request state, filled from the request arguments and carried through
// authorization into subscriber creation.
struct task
{
    // What to subscribe to.
    std::vector<string> uids_list;
    service_with_filter_list services_with_filters;
    std::vector<std::shared_ptr<const service_data>> services_data;
    std::vector<string> services;              // sorted and deduplicated names
    std::vector<string> strong_order_services; // names in request order
    string topic;

    // Client identification.
    string client;
    string session;

    // Credentials; check_auth decides which of them is verified.
    string oauth_token;
    string cookie;
    string secret_sign;
    std::time_t sign_ts{ 0 };
    string bb_connection_id;

    // Subscription options.
    fetch_history history;
    bool strict_history{ false };
    bool watch_subscribers{ false };
    string filter;
};

}

using task_ptr = boost::shared_ptr<task>;

// Allocates a value-initialized task. The pointer is later bound into the
// auth completion callbacks, which is why the task is shared.
inline task_ptr make_task()
{
    task_ptr fresh = boost::make_shared<task>();
    return fresh;
}

using websocket_stream_ptr = ymod_webserver::websocket::output_stream_ptr;

// Creates the timer that subscribe_impl passes to control_session together
// with the ping interval. The protocol argument is unused; it keeps the call
// shape uniform with the other per-protocol helpers.
template <protocol_type, typename Stream>
timer_ptr create_timer(Stream stream)
{
    timer_ptr timer = stream->make_timer();
    return timer;
}

// Logs at info level and answers 405 Method Not Allowed with an empty body.
// The default Stream is the HTTP response stream.
template <typename Stream = http_stream_ptr>
void send_not_allowed(Stream stream, ymod_webserver::context_ptr ctx)
{
    YLOG_CTX_GLOBAL(ctx, info) << "method not allowed";
    stream->result(http_codes::method_not_allowed, "");
}

// The websocket variant deliberately neither sends nor logs anything.
template <>
void send_not_allowed<websocket_stream_ptr>(websocket_stream_ptr, ymod_webserver::context_ptr)
{
}

template <protocol_type, typename Stream>
void send_internal_error(Stream stream, const string& info, ymod_webserver::context_ptr ctx)
{
    YLOG_CTX_GLOBAL(ctx, info) << "internal server error: " << info;
    stream->result(http_codes::internal_server_error, info);
}

template <>
void send_internal_error<protocol_type::websocket, websocket_stream_ptr>(
    websocket_stream_ptr stream,
    const string& info,
    ymod_webserver::context_ptr ctx)
{
    YLOG_CTX_GLOBAL(ctx, info) << "internal server error: " << info;
    json_value message;
    message["error"] = info;
    stream->close_connection(websocket_codes::internal_server_error, message.stringify());
}

template <typename Stream>
void send_attached(Stream /*stream*/, web_subscriber_ptr /*subscriber*/)
{
}

template <>
void send_attached<std::shared_ptr<websocket_rpc::stream>>(
    std::shared_ptr<websocket_rpc::stream> stream,
    web_subscriber_ptr subscriber)
{
    json_value result;
    result["operation"] = "attached";
    result["subscription_token"] = subscriber->id();
    stream->result_json(http_codes::ok, result);
}

// Creates the protocol-specific subscriber that will deliver notifications
// over `stream`. Only declared here: each supported (protocol, stream) pair
// gets an explicit specialization below, and no generic definition is
// provided in this file.
template <protocol_type, typename Stream>
web_subscriber_ptr create_subscriber(
    Stream stream,
    settings_ptr settings,
    formatters::formatter_ptr formatter,
    const task_ptr& task);

// SSE: writes the event-stream response head, switches the response to
// chunked mode and wraps the chunked stream in an sse_subscriber.
template <>
web_subscriber_ptr create_subscriber<protocol_type::sse, http_stream_ptr>(
    http_stream_ptr stream,
    settings_ptr /*settings*/,
    formatters::formatter_ptr formatter,
    const task_ptr& task)
{
    auto request = stream->request();

    // Status and CORS.
    stream->set_code(http_codes::ok);
    stream->add_header("Access-Control-Allow-Origin", "*");

    // Tell every caching layer not to keep the event stream.
    stream->add_header(
        "Cache-Control",
        "max-age=0, must-revalidate, proxy-revalidate, no-cache, no-store, private");
    stream->add_header("Expires", "Thu, 01 Jan 1970 00:00:01 GMT");
    stream->add_header("Pragma", "no-cache");

    stream->set_content_type("text", "event-stream");
    stream->set_connection(false);

    yplatform::net::streamable_ptr event_stream = stream->result_chunked();
    shared_ptr<sse_subscriber> sse_sub = ::yxiva::make_shared<sse_subscriber>(
        request->context, event_stream, formatter, task->sign_ts, task->client);
    return sse_sub;
}

// WebSocket: builds a websocket_subscriber with the inactivity timeout taken
// from settings, then initializes it before handing it out.
template <>
web_subscriber_ptr create_subscriber<protocol_type::websocket, websocket_stream_ptr>(
    websocket_stream_ptr stream,
    settings_ptr settings,
    formatters::formatter_ptr formatter,
    const task_ptr& task)
{
    auto request = stream->request();
    shared_ptr<websocket_subscriber> ws_sub = ::yxiva::make_shared<websocket_subscriber>(
        request->context,
        stream,
        formatter,
        task->sign_ts,
        settings->rand_websocket_inactive_timeout(),
        task->client);
    ws_sub->init();
    return ws_sub;
}

// WebSocket RPC: builds and initializes a websocket_rpc_subscriber.
template <>
web_subscriber_ptr create_subscriber<
    protocol_type::websocket_rpc,
    std::shared_ptr<websocket_rpc::stream>>(
    std::shared_ptr<websocket_rpc::stream> stream,
    settings_ptr /*settings*/,
    formatters::formatter_ptr formatter,
    const task_ptr& task)
{
    auto request = stream->request();
    auto rpc_sub = ::yxiva::make_shared<websocket_rpc_subscriber>(
        request->context, stream, formatter, task->sign_ts, task->client);
    rpc_sub->init();
    return rpc_sub;
}

// Ping interval for a transport: SSE has its own setting, while both
// WebSocket flavours share websocket_ping_interval.
template <protocol_type Proto>
time_duration get_ping_interval(settings_ptr settings)
{
    if constexpr (Proto == protocol_type::sse)
    {
        return settings->ss_event_ping_interval;
    }
    else
    {
        static_assert(
            Proto == protocol_type::websocket || Proto == protocol_type::websocket_rpc,
            "unsupported protocol_type");
        return settings->websocket_ping_interval;
    }
}

// Returns the first entry of `vinfo` whose uid, converted to string, equals
// `uid_to_search`, or vinfo.end() when there is none.
inline std::vector<user_info>::const_iterator find_ui(
    const std::vector<user_info>& vinfo,
    const string& uid_to_search)
{
    return std::find_if(vinfo.begin(), vinfo.end(), [&uid_to_search](const user_info& candidate) {
        return static_cast<string>(candidate.uid) == uid_to_search;
    });
}

// For every channel, finds the first task history entry with the same service
// and the same uid (or topic, for topic subscriptions). When one exists, the
// channel is marked to fetch history from that position and count.
void add_history_args_to_channel_keys(const task_ptr& task, multi_channels_set& requests)
{
    auto& history = task->history;
    for (auto& request : requests)
    {
        auto uid_or_topic = task->topic.size() ? task->topic : request.ui.uid;
        for (auto& channel : request.channels)
        {
            auto entry = std::find_if(history.begin(), history.end(), [&](auto& candidate) {
                if (channel.service != service_name(candidate.service)) return false;
                return !(uid_or_topic != candidate.uid);
            });
            if (entry == history.end()) continue;

            channel.position = entry->position;
            channel.history_count = entry->count;
            channel.strict_position = task->strict_history;
            channel.fetch_history = true;
        }
    }
}

// When the request asked for watch_subscribers, enables it on every channel
// whose service is listed in settings->api.watch_subscribers.enabled_for.
// Channels of other services are left untouched.
void set_watch_subscribers(
    const task_ptr& task,
    settings_ptr settings,
    multi_channels_set& requests)
{
    if (task->watch_subscribers)
    {
        auto&& allowed = settings->api.watch_subscribers.enabled_for;
        for (auto& request : requests)
        {
            for (auto& channel : request.channels)
            {
                if (allowed.count(channel.service) != 0)
                {
                    channel.watch_subscribers = true;
                }
            }
        }
    }
}

template <protocol_type Proto>
void set_subscription_id_generators(settings_ptr settings, multi_channels_set& requests)
{
    if constexpr (Proto != protocol_type::websocket && Proto != protocol_type::websocket_rpc)
    {
        return;
    }

    auto&& whitelist = settings->service_features.session_as_ws_subscription_id;
    for (auto& req : requests)
    {
        for (auto&& channel : req.channels)
        {
            if (whitelist.enabled_for(channel.service))
            {
                channel.make_sub_id = websocket_id_generator::make_from_session;
            }
        }
    }
}

// Returns the sum of all service name lengths plus one separator slot per
// name. This is an upper bound for the size of the comma-joined list, used to
// reserve capacity.
size_t task_services_string_length(const std::vector<string>& services)
{
    // Seed with size_t: std::accumulate carries the sum in the type of the
    // initial value, so an `int` seed would narrow every size_type result.
    return std::accumulate(
        services.begin(), services.end(), size_t{ 0 }, [](size_t sum, const string& str) {
            return sum + str.size() + 1;
        });
}

// Joins the task's service names with ',' (used for request log data).
// The separator is placed by position rather than by checking whether the
// result is still empty, so empty names do not swallow a separator.
string task_services_as_string(const task_ptr& task)
{
    const auto& services = task->services;
    string result;
    result.reserve(task_services_string_length(services));
    for (auto it = services.begin(); it != services.end(); ++it)
    {
        if (it != services.begin())
        {
            result += ',';
        }
        result += *it;
    }
    return result;
}

// Final step for an already-authorized task. It builds the channel set,
// creates the protocol-specific subscriber and hands it to control_session.
//
// In order:
//  - stores the comma-joined service list in the request log data;
//  - expands the task into one channels_set per uid, with one channel per
//    requested service (a non-empty task->filter overrides per-service
//    filters);
//  - for topic subscriptions, replaces the single request's uid with the
//    encoded topic name;
//  - applies history, watch_subscribers and subscription-id settings;
//  - creates the subscriber, starts control_session and acknowledges with
//    send_attached;
//  - for a single-uid request with services, calls enable_rproxy.
// Validation failures before subscriber creation are reported on the stream.
// NOTE(review): exceptions are only logged; the catch blocks send no response
// to the client.
template <protocol_type Proto, typename Stream>
void subscribe_impl(Stream stream, settings_ptr settings, task_ptr task)
{
    try
    {
        auto req = stream->request();
        req->context->custom_log_data["service"] = task_services_as_string(task);

        multi_channels_set expanded_requests;
        expanded_requests.reserve(task->uids_list.size());
        for (auto& uid : task->uids_list)
        {
            channels_set channels_by_user;
            for (auto& service : task->services_with_filters)
            {
                // A request-wide filter, when given, takes precedence over the
                // per-service one.
                auto&& filter = task->filter.size() ? task->filter : service.filter;
                channels_by_user.push_back({ service_name(service.name),
                                             filter,
                                             task->client,
                                             task->session,
                                             task->bb_connection_id });
            }
            user_info user{ user_id(uid) };
            expanded_requests.emplace_back(user, std::move(channels_by_user));
        }

        if (expanded_requests.empty())
        {
            send_bad_request(stream, "no such channels");
            return;
        }

        if (task->topic.size())
        {
            // api2::subscribe rejects topic requests with more than one uid.
            assert(expanded_requests.size() == 1);
            expanded_requests[0].ui.uid = encode_topic_name(task->topic);
        }

        auto future_close = make_future_close(stream);

        auto formatters_kit = find_processor()->formatters();
        if (!formatters_kit->has("json"))
        {
            send_internal_error<Proto>(stream, "formatter was not found", req->ctx());
            return;
        }
        auto formatter = formatters_kit->get("json");

        add_history_args_to_channel_keys(task, expanded_requests);
        set_watch_subscribers(task, settings, expanded_requests);
        set_subscription_id_generators<Proto>(settings, expanded_requests);

        auto subscriber = create_subscriber<Proto>(stream, settings, formatter, task);
        req->context->custom_log_data["subscriber"] = subscriber->ctx()->uniq_id();

        control_session(
            expanded_requests,
            subscriber,
            create_timer<Proto>(stream),
            get_ping_interval<Proto>(settings),
            future_close);
        send_attached(stream, subscriber);

        if (task->uids_list.size() == 1 && task->services.size())
        {
            auto&& uid = task->uids_list[0];
            enable_rproxy(stream, settings, uid, task);
        }
    }
    catch (const std::exception& e)
    {
        YLOG_CTX_GLOBAL(stream->request()->context, error)
            << "subscribe_impl exception=\"" << e.what() << "\"";
    }
    catch (...)
    {
        YLOG_CTX_GLOBAL(stream->request()->context, error)
            << "subscribe_impl exception=\"unknown\"";
    }
}

// Completion handler for auth_user_cookies (started by check_cookies).
// On success it requires every requested uid to belong to the authenticated
// session, stores the blackbox connection id and continues to subscribe_impl.
// The function is noexcept, so every exception is caught and logged here.
template <protocol_type Proto, typename Stream>
void handle_cookie_check(
    const auth_error::error_code& err,
    const auth_info& auth,
    Stream stream,
    settings_ptr settings,
    task_ptr task) noexcept
{
    try
    {
        auto req = stream->request();
        req->context->profilers.pop("passport::cookie");

        if (err)
        {
            send_unauthorized(stream, "cookie auth failed");
            return;
        }

        auto& requested = task->uids_list;
        auto foreign_uid =
            std::find_if(requested.begin(), requested.end(), [&auth](const string& uid) {
                return find_ui(auth.users, uid) == auth.users.end();
            });
        if (foreign_uid != requested.end())
        {
            send_unauthorized(
                stream, string("uid ") + *foreign_uid + " is not correct for this session");
            return;
        }

        // TODO: use blackbox to get connection_id from cookies.
        task->bb_connection_id = auth.bb_connection_id;
        subscribe_impl<Proto, Stream>(stream, settings, task);
    }
    catch (const std::exception& e)
    {
        YLOG_CTX_GLOBAL(stream->request()->context, error)
            << "handle_auth_call exception error=\"" << e.what() << "\"";
    }
    catch (...)
    {
        YLOG_CTX_GLOBAL(stream->request()->context, error)
            << "handle_auth_call exception error=\"unknown\"";
    }
}

// Completion handler for resolve_config_with_oauth (started by
// check_oauth_token). The token must be valid and must belong to the single
// requested uid. On success the blackbox connection id is stored and the
// request continues to subscribe_impl.
//
// The function is noexcept, so everything runs inside try/catch (as in
// handle_cookie_check). An escaping exception would otherwise terminate the
// process.
template <protocol_type Proto, typename Stream>
void handle_oauth_check(
    operation::result authorized,
    const user_authorization& user_auth,
    Stream stream,
    settings_ptr settings,
    task_ptr task) noexcept
{
    try
    {
        auto ctx = stream->request()->context;
        ctx->profilers.pop("passport::oauth");

        if (!authorized)
        {
            YLOG_CTX_GLOBAL(ctx, info) << "auth failed: error=" << authorized.error_reason;
            send_unauthorized(stream, "auth failed");
            return;
        }

        // check_auth only rejects lists with more than one uid, so an empty
        // list is treated as an auth failure instead of throwing from at(0).
        if (task->uids_list.empty() || user_auth.uid != task->uids_list.front())
        {
            YLOG_CTX_GLOBAL(ctx, info) << "auth failed: original user=" << user_auth.uid;
            send_unauthorized(stream, "auth failed");
            return;
        }

        task->bb_connection_id = user_auth.bb_connection_id;
        subscribe_impl<Proto, Stream>(stream, settings, task);
    }
    catch (const std::exception& e)
    {
        YLOG_CTX_GLOBAL(stream->request()->context, error)
            << "handle_oauth_check exception error=\"" << e.what() << "\"";
    }
    catch (...)
    {
        YLOG_CTX_GLOBAL(stream->request()->context, error)
            << "handle_oauth_check exception error=\"unknown\"";
    }
}

template <protocol_type Proto, typename Stream>
void check_oauth_token(Stream stream, settings_ptr settings, task_ptr task)
{
    stream->request()->context->profilers.push("passport::oauth");
    auto mod_bb = yplatform::find<ymod_blackbox::client>("bbclient");
    namespace ph = std::placeholders;
    auto user_ip = stream->request()->context->remote_address;
    auto user_port = stream->request()->context->remote_port;
    resolve_config_with_oauth(
        mod_bb,
        settings,
        task->oauth_token,
        task->services,
        { user_ip, user_port },
        std::bind(handle_oauth_check<Proto, Stream>, ph::_1, ph::_2, stream, settings, task));
}

template <protocol_type Proto, typename Stream>
void check_secret_sign(Stream stream, settings_ptr settings, task_ptr task)
{
    auto valid_sign = make_secure_sign(
        services_uids_topic_data(task->services, task->uids_list, task->topic),
        task->sign_ts,
        settings->sign_secret);
    if (std::time(nullptr) >= task->sign_ts || valid_sign != task->secret_sign)
    {
        send_unauthorized(stream, "bad sign");
        return;
    }
    subscribe_impl<Proto, Stream>(stream, settings, task);
}

// Starts passport cookie authentication. Every requested service must be a
// passport service, and the request origin must be on the allowed-domain
// list. The check runs against the request's Host header, and the result is
// delivered to handle_cookie_check.
template <protocol_type Proto, typename Stream>
void check_cookies(Stream stream, settings_ptr settings, task_ptr task)
{
    auto req = stream->request();

    auto& services = task->services_data;
    auto non_passport = std::find_if(services.begin(), services.end(), [](auto& entry) {
        return !entry->properties.is_passport;
    });
    if (non_passport != services.end())
    {
        send_unauthorized(stream, (*non_passport)->properties.name + " not a passport service");
        return;
    }

    const bool origin_allowed =
        req->origin && settings->origin_domains.check_allowed(req->origin->domain).first;
    if (!origin_allowed)
    {
        send_unauthorized(stream, "this origin domain is not allowed");
        return;
    }

    string auth_domain = req->headers["host"];
    req->context->profilers.push("passport::cookie");
    auth_user_cookies(
        req->context,
        task->cookie,
        auth_domain,
        { req->context->remote_address, req->context->remote_port },
        std::bind(handle_cookie_check<Proto, Stream>, ph::_1, ph::_2, stream, settings, task));
}

// Resolves the requested services, then uses the first applicable
// authentication method:
//  1. none, when every resolved service has auth_disabled;
//  2. the secret sign, when one was supplied;
//  3. the OAuth token (no topic, at most one uid);
//  4. passport cookies otherwise (401 when no cookie was sent).
// NOTE(review): std::all_of is true for an empty range, so a task with no
// services skips authentication. Confirm callers never allow that.
template <protocol_type Proto, typename Stream>
void check_auth(Stream stream, settings_ptr settings, task_ptr task)
{
    auto& resolved = task->services_data;
    resolved.reserve(task->services.size());
    auto manager = find_service_manager(settings->api.auth_service_manager);
    for (auto& name : task->services)
    {
        auto data = manager->find_service_by_name(name);
        if (!data)
        {
            send_bad_request(stream, "no service " + name);
            return;
        }
        resolved.push_back(data);
    }

    const bool auth_disabled_everywhere =
        std::all_of(resolved.begin(), resolved.end(), [](auto& entry) {
            return entry->properties.auth_disabled;
        });
    if (auth_disabled_everywhere)
    {
        subscribe_impl<Proto, Stream>(stream, settings, task);
        return;
    }

    if (task->secret_sign.size())
    {
        check_secret_sign<Proto, Stream>(stream, settings, task);
        return;
    }

    if (task->oauth_token.size())
    {
        if (task->topic.size())
        {
            send_bad_request(stream, "topics are not supported for OAuth authentication method");
            return;
        }
        if (task->uids_list.size() > 1)
        {
            send_bad_request(stream, "only one user id is allowed in OAuth authentication method");
            return;
        }
        check_oauth_token<Proto, Stream>(stream, settings, task);
        return;
    }

    if (task->cookie.empty())
    {
        send_unauthorized(stream, "");
        return;
    }
    check_cookies<Proto, Stream>(stream, settings, task);
}

namespace api2 {

// Shared entry point for all subscribe transports. It validates argument
// combinations, fills a task from the arguments and request headers, and
// passes the task to check_auth.
//
// Rejected with 400:
//  - a topic together with more than one uid or more than one service;
//  - uid lists that hacks::is_correct_uids refuses;
//  - both history and history_strict given.
// An Authorization header must use the "OAuth " prefix, otherwise 401. Without
// that header, an oauth_token URL parameter is used if present.
template <protocol_type Proto, typename Stream>
void subscribe(
    Stream stream,
    settings_ptr settings,
    const std::vector<string>& uids,
    const string& topic,
    const service_with_filter_list& services_with_filters,
    const string& client,
    const string& session,
    const string& secret_sign,
    const std::time_t sign_ts,
    const fetch_history& history,
    const fetch_history& history_strict,
    const string& filter)
{
    if (topic.size() && uids.size() > 1)
    {
        send_bad_request(stream, "request with topic should contain only one user id");
        return;
    }

    if (auto result = hacks::is_correct_uids(uids, services_with_filters); !result)
    {
        send_bad_request(stream, result.error_reason);
        return;
    }

    if (topic.size() && services_with_filters.size() > 1)
    {
        send_bad_request(stream, "request with topic should contain only one service");
        return;
    }

    if (history.size() && history_strict.size())
    {
        send_bad_request(stream, "request can't contain both fetch history arguments");
        return;
    }

    auto req = stream->request();
    auto task = make_task();
    task->uids_list = uids;
    task->services_with_filters = services_with_filters;
    task->services.reserve(services_with_filters.size());
    task->strong_order_services.reserve(services_with_filters.size());
    for (auto&& service : services_with_filters)
    {
        task->services.push_back(service.name);
        task->strong_order_services.push_back(service.name);
    }
    // strong_order_services keeps the request order; the other two lists
    // are sorted and deduplicated.
    sort_unique(task->services_with_filters);
    sort_unique(task->services);
    task->topic = topic;
    task->client = client;
    task->session = session;
    task->secret_sign = secret_sign;
    task->sign_ts = sign_ts;
    task->cookie = req->headers["cookie"];
    // At most one of the two history arguments is non-empty (checked above).
    // Strict positioning applies only when the plain one is absent.
    const bool use_plain_history = history.size() != 0;
    task->history = use_plain_history ? history : history_strict;
    task->strict_history = !use_plain_history;
    task->watch_subscribers = req->url.param_value("watch_subscribers", "0") == "1";
    task->filter = filter;
    // TODO: load x-bb-clientid header if needed.

    auto& ioauth_token = req->headers["authorization"];
    if (ioauth_token.size())
    {
        static const string OAUTH_PREFIX = "OAuth ";
        if (ioauth_token.compare(0, OAUTH_PREFIX.size(), OAUTH_PREFIX) != 0)
        {
            send_unauthorized(stream, "unsupported authorization header");
            return;
        }
        task->oauth_token = ioauth_token.substr(OAUTH_PREFIX.size());
    }
    else
    {
        auto it = req->url.params.find("oauth_token");
        if (it != req->url.params.end())
        {
            task->oauth_token = it->second;
        }
    }

    check_auth<Proto, Stream>(stream, settings, task);
}

// SSE subscribe handler. It requires HTTP/1.1 or newer and then runs the
// shared subscribe pipeline with the SSE protocol.
void subscribe_sse::operator()(
    http_stream_ptr stream,
    const std::vector<string>& uids,
    const string& topic,
    const service_with_filter_list& services,
    const string& client,
    const string& session,
    const string& secret_sign,
    const std::time_t sign_ts,
    const fetch_history& history,
    const fetch_history& history_strict,
    const string& filter)
{
    auto req = stream->request();
    // SSE requires HTTP v1.1 or newer. Compare (major, minor) lexicographically:
    // checking each part separately would also reject e.g. 2.0.
    const auto& version = req->proto_version;
    const bool at_least_1_1 = version.first > 1 || (version.first == 1 && version.second >= 1);
    if (!at_least_1_1)
    {
        send_bad_request(stream, "Protocol version 1.1 is required");
        return;
    }

    using ymod_webserver::response_ptr;
    subscribe<protocol_type::sse, response_ptr>(
        stream,
        settings,
        uids,
        topic,
        services,
        client,
        session,
        secret_sign,
        sign_ts,
        history,
        history_strict,
        filter);
}

// Plain WebSocket subscribe handler: forwards every argument unchanged to the
// shared subscribe pipeline with the websocket protocol.
void subscribe_websocket::operator()(
    websocket_stream_ptr stream,
    const std::vector<string>& uids,
    const string& topic,
    const service_with_filter_list& services,
    const string& client,
    const string& session,
    const string& secret_sign,
    const std::time_t sign_ts,
    const fetch_history& history,
    const fetch_history& history_strict,
    const string& filter)
{
    subscribe<protocol_type::websocket, websocket_stream_ptr>(
        stream,
        settings,
        uids,
        topic,
        services,
        client,
        session,
        secret_sign,
        sign_ts,
        history,
        history_strict,
        filter);
}

// WebSocket RPC subscribe handler. Each call subscribes a single uid to a
// single service, so both are wrapped into the one-element lists the shared
// pipeline expects.
void subscribe_websocketapi::operator()(
    std::shared_ptr<websocket_rpc::stream> stream,
    const string& uid,
    const string& topic,
    const service_with_filter& service,
    const string& client,
    const string& session,
    const string& secret_sign,
    const std::time_t sign_ts,
    const fetch_history& history,
    const fetch_history& history_strict,
    const string& filter)
{
    const std::vector<string> uids{ uid };
    const service_with_filter_list services{ service };
    subscribe<protocol_type::websocket_rpc, std::shared_ptr<websocket_rpc::stream>>(
        stream,
        settings,
        uids,
        topic,
        services,
        client,
        session,
        secret_sign,
        sign_ts,
        history,
        history_strict,
        filter);
}

} // api2

}}
