#include "subscribe.h"

#include "common.h"
#include "parse_helpers.h"
#include "web/methods/http_handlers.h"
#include "web/auth/user.h"
#include "web/extract_requests.h"
#include "web/sse_subscriber.h"
#include "web/websocket_subscriber.h"
#include "web/control_session.h"
#include "web/auth/oauth.h"
#include <yxiva/core/iabstract.h>
#include <yxiva/core/methods/request_adaptor.h>
#include <yxiva/core/user_info.h>
#include <yxiva/core/authorizer.h>
#include "web/formatters/interface.h"
#include <yplatform/find.h>
#include <ymod_blackbox/client.h>

namespace yxiva { namespace web { namespace api {

namespace ph = std::placeholders;

/// Transport used by a subscription. Selects the protocol-specific template
/// specializations (subscriber construction, error reporting, ping interval)
/// used by the shared subscribe flow.
enum class protocol_type : int
{
    sse,      ///< Server-Sent Events over a chunked HTTP response.
    websocket ///< WebSocket stream.
};

namespace {

/// Arguments of a single subscribe request.
/// Filled by params::parser<...>::fill in subscribe(); NOTE(review): the exact
/// mapping from request parameters to fields lives in the parser, confirm there.
struct args
{
    // Uids the client asks to subscribe; each must be present in the auth result
    // (checked in subscribe_impl) unless authentication is disabled.
    std::vector<string> uids_list;
    // Copied into every channel built in subscribe_impl.
    string client;
    // Copied into every channel built in subscribe_impl.
    string session;
    // Requested services; name and filter of each become one channel per user.
    service_with_filter_list services;
    // When non-empty, subscribe() takes the OAuth path instead of cookie auth.
    string oauth_token;
    // Passed to auth_user_cookies on the cookie auth path.
    string cookie;
    // Set from the auth result before subscribe_impl; copied into every channel.
    string bb_connection_id;
};

}

/// Shared handle to a request's arguments. It is shared because subscribe()
/// binds it (via std::bind) into the auth completion handlers.
using args_ptr = boost::shared_ptr<args>;

/// Allocates a value-initialized args instance.
inline args_ptr make_args()
{
    args_ptr fresh = boost::make_shared<args>();
    return fresh;
}

using websocket_stream_ptr = ymod_webserver::websocket::output_stream_ptr;

/// Creates the timer handed to control_session for a new subscription.
/// The protocol tag is unused: the same body serves both SSE and WebSocket
/// streams, each of which provides make_timer().
template <protocol_type /*proto*/, typename Stream>
timer_ptr create_timer(Stream stream)
{
    timer_ptr timer = stream->make_timer();
    return timer;
}

/// Logs the rejection at info level and answers the HTTP stream with
/// http_codes::method_not_allowed and an empty body.
template <typename Stream = http_stream_ptr>
void send_not_allowed(Stream stream, ymod_webserver::context_ptr ctx)
{
    YLOG_CTX_GLOBAL(ctx, info) << "method not allowed";
    stream->result(http_codes::method_not_allowed, "");
}

/// WebSocket streams: no-op, nothing is sent or logged.
template <>
void send_not_allowed<websocket_stream_ptr>(
    websocket_stream_ptr /*stream*/,
    ymod_webserver::context_ptr /*ctx*/)
{
    // Nothing to do for WebSocket streams.
}

template <protocol_type, typename Stream>
void send_internal_error(Stream stream, const string& info, ymod_webserver::context_ptr ctx)
{
    YLOG_CTX_GLOBAL(ctx, info) << "internal server error: " << info;
    stream->result(http_codes::internal_server_error, info);
}

template <>
void send_internal_error<protocol_type::websocket, websocket_stream_ptr>(
    websocket_stream_ptr stream,
    const string& info,
    ymod_webserver::context_ptr ctx)
{
    YLOG_CTX_GLOBAL(ctx, info) << "internal server error: " << info;
    json_value message;
    message["error"] = info;
    stream->close_connection(websocket_codes::internal_server_error, message.stringify());
}

/// Builds the subscriber that delivers notifications to the client over `stream`.
/// Only the protocol specializations below are defined; there is no generic body.
template <protocol_type, typename Stream>
web_subscriber_ptr create_subscriber(
    Stream stream,
    settings_ptr settings,
    formatters::formatter_ptr formatter,
    const string& client);

/// SSE: starts a 200 `text/event-stream` chunked response with a permissive
/// CORS header and no-cache headers, then wraps the chunked stream in an
/// sse_subscriber. `settings` is not used for SSE.
template <>
web_subscriber_ptr create_subscriber<protocol_type::sse, http_stream_ptr>(
    http_stream_ptr stream,
    settings_ptr /*settings*/,
    formatters::formatter_ptr formatter,
    const string& client)
{
    auto request = stream->request();

    // Status and headers must be set before the chunked body is started.
    stream->set_code(http_codes::ok);
    stream->add_header("Access-Control-Allow-Origin", "*");
    stream->add_header(
        "Cache-Control",
        "max-age=0, must-revalidate, proxy-revalidate, no-cache, no-store, private");
    stream->add_header("Expires", "Thu, 01 Jan 1970 00:00:01 GMT");
    stream->add_header("Pragma", "no-cache");
    stream->set_content_type("text", "event-stream");
    stream->set_connection(false);

    yplatform::net::streamable_ptr event_stream = stream->result_chunked();
    shared_ptr<sse_subscriber> sse_sub =
        ::yxiva::make_shared<sse_subscriber>(request->context, event_stream, formatter, 0, client);
    return sse_sub;
}

/// WebSocket: wraps the stream in a websocket_subscriber whose inactivity
/// timeout comes from settings->rand_websocket_inactive_timeout(), and calls
/// init() on it before returning.
template <>
web_subscriber_ptr create_subscriber<protocol_type::websocket, websocket_stream_ptr>(
    websocket_stream_ptr stream,
    settings_ptr settings,
    formatters::formatter_ptr formatter,
    const string& client)
{
    auto request = stream->request();
    auto inactive_timeout = settings->rand_websocket_inactive_timeout();
    shared_ptr<websocket_subscriber> ws_sub = ::yxiva::make_shared<websocket_subscriber>(
        request->context, stream, formatter, 0, inactive_timeout, client);
    ws_sub->init();
    return ws_sub;
}

/// Ping interval passed to control_session, chosen per protocol.
template <protocol_type>
time_duration get_ping_interval(settings_ptr settings);

/// SSE reads settings->ss_event_ping_interval.
template <>
time_duration get_ping_interval<protocol_type::sse>(settings_ptr settings)
{
    time_duration interval = settings->ss_event_ping_interval;
    return interval;
}

/// WebSocket reads settings->websocket_ping_interval.
template <>
time_duration get_ping_interval<protocol_type::websocket>(settings_ptr settings)
{
    time_duration interval = settings->websocket_ping_interval;
    return interval;
}

/// Returns an iterator to the first entry of `vinfo` whose uid, converted to
/// string, equals `uid_to_search`, or vinfo.end() if there is none.
inline std::vector<user_info>::const_iterator find_ui(
    const std::vector<user_info>& vinfo,
    const string& uid_to_search)
{
    auto cursor = vinfo.begin();
    const auto stop = vinfo.end();
    // Advance until a match is found or the range is exhausted.
    while (cursor != stop && static_cast<string>(cursor->uid) != uid_to_search)
    {
        ++cursor;
    }
    return cursor;
}

/// Subscribes the requested users to the requested services on `stream`.
///
/// Every uid in args->uids_list must appear in `auth.users`; the first one that
/// does not rejects the whole request via send_unauthorized. For each accepted
/// user, one channel per requested service is built (carrying args->client,
/// args->session and args->bb_connection_id). A protocol-specific subscriber is
/// then created with the "json" formatter and handed to control_session together
/// with a timer and the protocol's ping interval.
///
/// Exceptions are caught and logged, not propagated.
/// NOTE(review): the catch handlers send no response to the client; it cannot be
/// told from here whether the stream was already answered when the exception
/// occurred, so adding one needs care to avoid a double response.
template <protocol_type Proto, typename Stream>
void subscribe_impl(const auth_info& auth, Stream stream, settings_ptr settings, args_ptr args)
{
    try
    {
        auto req = stream->request();

        // Keep only the requested users, in request order. The result holds one
        // entry per requested uid, so reserve by that count.
        auth_info ui_to_subscribe;
        ui_to_subscribe.users.reserve(args->uids_list.size());
        for (const auto& uid_to_check : args->uids_list)
        {
            auto found_info = find_ui(auth.users, uid_to_check);
            if (found_info == auth.users.end())
            {
                send_unauthorized(
                    stream, string("uid ") + uid_to_check + " is not correct for this session");
                return;
            }
            ui_to_subscribe.users.push_back(*found_info);
        }

        // One channel per (user, service) pair.
        multi_channels_set expanded_requests;
        expanded_requests.reserve(ui_to_subscribe.users.size());
        for (const auto& user : ui_to_subscribe.users)
        {
            channels_set channels_by_user;
            for (const auto& service : args->services)
            {
                channels_by_user.push_back({ service_name(service.name),
                                             service.filter,
                                             args->client,
                                             args->session,
                                             args->bb_connection_id });
            }
            expanded_requests.emplace_back(user, std::move(channels_by_user));
        }
        // Every accepted uid adds an entry, so this is empty only when no uid was requested.
        if (expanded_requests.empty())
        {
            send_bad_request(stream, "no such channels were found");
            return;
        }

        auto future_close = make_future_close(stream);

        auto formatters_kit = find_processor()->formatters();
        if (!formatters_kit->has("json"))
        {
            send_internal_error<Proto>(stream, "formatter was not found", req->ctx());
            return;
        }
        auto formatter = formatters_kit->get("json");

        auto subscriber = create_subscriber<Proto>(stream, settings, formatter, args->client);

        control_session(
            expanded_requests,
            subscriber,
            create_timer<Proto>(stream),
            get_ping_interval<Proto>(settings),
            future_close);
    }
    catch (const std::exception& e)
    {
        YLOG_CTX_GLOBAL(stream->request()->context, error)
            << "subscribe_impl exception=\"" << e.what() << "\"";
    }
    catch (...)
    {
        YLOG_CTX_GLOBAL(stream->request()->context, error)
            << "subscribe_impl exception=\"unknown\"";
    }
}

/// Completion handler for passport cookie authentication (auth_user_cookies).
///
/// Pops the "passport::cookie" profiler pushed in subscribe(). On error it
/// rejects the request via send_unauthorized. Otherwise it stores the auth
/// result's bb_connection_id in `args` and continues with subscribe_impl.
/// The function is noexcept, so every exception is caught and logged here.
template <protocol_type Proto, typename Stream>
void handle_cookie_check(
    const auth_error::error_code& err,
    const auth_info& auth,
    Stream stream,
    settings_ptr settings,
    args_ptr args) noexcept
{
    try
    {
        auto request = stream->request();
        request->context->profilers.pop("passport::cookie");

        if (err)
        {
            send_unauthorized(stream, "");
        }
        else
        {
            args->bb_connection_id = auth.bb_connection_id;
            subscribe_impl<Proto, Stream>(auth, stream, settings, args);
        }
    }
    catch (const std::exception& e)
    {
        YLOG_CTX_GLOBAL(stream->request()->context, error)
            << "handle_auth_call exception error=\"" << e.what() << "\"";
    }
    catch (...)
    {
        YLOG_CTX_GLOBAL(stream->request()->context, error)
            << "handle_auth_call exception error=\"unknown\"";
    }
}

/// Completion handler for OAuth token resolution (resolve_config_with_oauth).
///
/// Pops the "passport::oauth" profiler pushed in subscribe(). On failure it logs
/// the error reason and rejects the request via send_unauthorized. On success it
/// subscribes the uid from `user_auth`, after copying its bb_connection_id into
/// `args`.
///
/// The function is declared noexcept, so any exception escaping it would call
/// std::terminate. Request access, profiler bookkeeping, logging,
/// send_unauthorized and the push_back allocation all run outside subscribe_impl's
/// own try block. Everything is therefore wrapped and logged here, the same way
/// handle_cookie_check does it.
template <protocol_type Proto, typename Stream>
void handle_oauth_check(
    operation::result authorized,
    const user_authorization& user_auth,
    Stream stream,
    settings_ptr settings,
    args_ptr args) noexcept
{
    try
    {
        auto ctx = stream->request()->context;
        ctx->profilers.pop("passport::oauth");

        if (!authorized)
        {
            YLOG_CTX_GLOBAL(ctx, info) << "auth failed: error=" << authorized.error_reason;
            send_unauthorized(stream, "");
            return;
        }

        auth_info auth;
        auth.users.push_back(user_info{ user_id(user_auth.uid) });
        args->bb_connection_id = user_auth.bb_connection_id;
        subscribe_impl<Proto, Stream>(auth, stream, settings, args);
    }
    catch (const std::exception& e)
    {
        YLOG_CTX_GLOBAL(stream->request()->context, error)
            << "handle_oauth_check exception error=\"" << e.what() << "\"";
    }
    catch (...)
    {
        YLOG_CTX_GLOBAL(stream->request()->context, error)
            << "handle_oauth_check exception error=\"unknown\"";
    }
}

template <protocol_type Proto, typename Stream>
void subscribe(Stream stream, settings_ptr settings)
{
    auto req = stream->request();
    auto args = make_args();

    if (!params::parser<
            params::uid_as_list,
            params::client,
            params::session,
            params::service_list_with_tags,
            params::oauth_token_or_cookie>::fill(stream, args))
        return;

    if (settings->disable_authentication)
    {
        YLOG_CTX_GLOBAL(req->ctx(), warning) << "SKIP AUTH";

        auth_info auth;
        for (auto& uid : args->uids_list)
        {
            auth.users.push_back(user_info(user_id(uid)));
        }
        subscribe_impl<Proto, Stream>(auth, stream, settings, args);
        return;
    }

    // check auth before composing channels
    auto user_ip = req->context->remote_address;
    auto user_port = req->context->remote_port;
    if (!args->oauth_token.empty())
    {
        req->context->profilers.push("passport::oauth");
        auto mod_bb = yplatform::find<ymod_blackbox::client>("bbclient");
        namespace ph = std::placeholders;
        resolve_config_with_oauth(
            mod_bb,
            settings,
            args->oauth_token,
            args->services,
            { user_ip, user_port },
            std::bind(handle_oauth_check<Proto, Stream>, ph::_1, ph::_2, stream, settings, args));
    }
    else
    {
        // Check requested services to allow passport auth.
        auto auth_service_manager = find_service_manager(settings->api.auth_service_manager);
        for (auto& service_with_tags : args->services)
        {
            auto service = auth_service_manager->find_service_by_name(service_with_tags.name);
            if (!service)
            {
                send_unauthorized(stream, "no service " + service_with_tags.name);
                return;
            }
            else if (!service->properties.is_passport)
            {
                send_unauthorized(stream, service_with_tags.name + " not a passport service");
                return;
            }
        }
        if (!req->origin || !settings->origin_domains.check_allowed(req->origin->domain).first)
        {
            send_unauthorized(stream, "this origin domain is not allowed");
            return;
        }
        string auth_domain = req->headers["host"];
        req->context->profilers.push("passport::cookie");
        auth_user_cookies(
            req->context,
            args->cookie,
            auth_domain,
            { user_ip, user_port },
            std::bind(handle_cookie_check<Proto, Stream>, ph::_1, ph::_2, stream, settings, args));
    }
}

/// Handles an SSE subscription request.
///
/// The SSE subscriber streams a chunked response (result_chunked), which needs
/// HTTP/1.1 or newer. Older protocol versions are rejected with send_bad_request.
void subscribe_sse(http_stream_ptr stream, settings_ptr settings)
{
    auto req = stream->request();
    const auto major = req->proto_version.first;
    const auto minor = req->proto_version.second;
    // Compare (major, minor) lexicographically. Testing `minor < 1` on its own
    // would wrongly reject HTTP/2.0, whose minor version is 0.
    if (major < 1 || (major == 1 && minor < 1))
    {
        send_bad_request(stream, "Protocol version 1.1 is required");
        return;
    }
    subscribe<protocol_type::sse, http_stream_ptr>(stream, settings);
}

/// Handles a WebSocket subscription request by running the shared subscribe
/// flow with the WebSocket protocol specializations. Unlike subscribe_sse, no
/// protocol-version check is performed here.
void subscribe_websocket(websocket_stream_ptr stream, settings_ptr settings)
{
    subscribe<protocol_type::websocket, websocket_stream_ptr>(stream, settings);
}

}}}
