#include "send.h"

#include "encoding.h"
#include "types.h"
#include <yxiva/core/ec_crypto.h>
#include <yxiva/core/operation_result.h>
#include <yplatform/encoding/base64.h>
#include <yplatform/util/split.h>
#include <boost/algorithm/string/predicate.hpp>
#include <boost/lexical_cast.hpp>

namespace yxiva { namespace web { namespace webpushapi {

namespace {
// Looks up header `name` in req->headers and returns a reference to its value,
// or a reference to a function-local static empty string when it is absent.
// For a present header the reference points into req->headers.
// NOTE(review): the lookup is an exact key match and every caller in this file
// passes a lowercase name — presumably the webserver lowercases header keys; confirm.
const string& get_header(const ymod_webserver::request_ptr& req, const string& name)
{
    static const string missing_header_value;
    auto&& all_headers = req->headers;
    auto found = all_headers.find(name);
    if (found == all_headers.end()) return missing_header_value;
    return found->second;
}

// Returns the value of parameter `param_name` inside header `header_name`, or
// an empty string when the header or the parameter is absent.
//
// The header value is treated as a list of "name=value" entries separated by
// ';' or ',' (e.g. Crypto-Key: "keyid=p256dh;dh=BNoR..." or
// Encryption: "salt=..."). Entry names are matched exactly, whitespace around
// names and values is ignored, and one pair of surrounding double quotes is
// removed from the value (HTTP parameter values may be quoted-strings).
// Matching whole entry names — rather than a substring search — keeps "dh"
// from matching inside "p256dh", and the parser never indexes past the end of
// the header, so malformed input cannot throw std::out_of_range.
string get_header_parameter(
    const ymod_webserver::request_ptr& req,
    const string& header_name,
    const string& param_name)
{
    static constexpr const char* WHITESPACE = " \t";
    auto&& header = get_header(req, header_name);

    // Returns header[from, to) without leading/trailing spaces and tabs.
    auto trimmed = [&header](string::size_type from, string::size_type to) -> string {
        auto first = header.find_first_not_of(WHITESPACE, from);
        if (first == string::npos || first >= to) return string();
        // first < to and header[first] is not whitespace, so last >= first.
        auto last = header.find_last_not_of(WHITESPACE, to - 1);
        return header.substr(first, last - first + 1);
    };

    string::size_type entry_begin = 0;
    while (entry_begin <= header.size())
    {
        auto entry_end = header.find_first_of(";,", entry_begin);
        if (entry_end == string::npos) entry_end = header.size();

        // Entries without '=' (or with '=' only in a later entry) are skipped.
        auto eq_pos = header.find('=', entry_begin);
        if (eq_pos < entry_end && trimmed(entry_begin, eq_pos) == param_name)
        {
            auto value = trimmed(eq_pos + 1, entry_end);
            if (value.size() >= 2 && value.front() == '"' && value.back() == '"')
            {
                value = value.substr(1, value.size() - 2);
            }
            return value;
        }
        entry_begin = entry_end + 1;
    }
    return string();
}

// Accepts only requests whose "content-encoding" header is exactly "aesgcm"
// (case-sensitive comparison). Otherwise returns an error naming the value
// that was received (empty when the header is missing).
operation::result valid_encoding(const ymod_webserver::request_ptr& req)
{
    auto&& received_encoding = get_header(req, "content-encoding");
    if (received_encoding != "aesgcm")
    {
        return "unexpected encoding \"" + received_encoding + "\"";
    }
    return operation::success;
}

// Decodes URL-safe base64 `source` into raw bytes.
// Nothing is caught here: decoder errors propagate to the caller
// (presumably as std::exception — verify_signature, the only caller in this
// file, wraps the call in a try/catch for std::exception).
std::vector<unsigned char> base64_decode_binary(const string& source)
{
    auto decoded = yplatform::base64_urlsafe_decode(source);
    std::vector<unsigned char> bytes(decoded.begin(), decoded.end());
    return bytes;
}

// Decodes URL-safe base64 `source` into `destination`.
// Returns success, or an error carrying the message of any std::exception
// raised while decoding or copying. On error `destination` may be unchanged or
// partially written, depending on where the decoder throws.
operation::result base64_decode(string& destination, const string& source)
{
    try
    {
        auto decoded = yplatform::base64_urlsafe_decode(source);
        auto first = decoded.begin();
        auto last = decoded.end();
        destination.assign(first, last);
        return operation::success;
    }
    catch (const std::exception& decode_error)
    {
        return decode_error.what();
    }
}

// Verifies `signature_base64` over the bytes data[0, data_size) with the
// prime256v1 (P-256) public key `public_key_base64`.
// Both inputs are URL-safe base64. The decoding and ec_crypto helpers report
// failures by throwing; any std::exception is turned into an error result
// carrying its message. Returns success only when nothing threw.
operation::result verify_signature(
    const void* data,
    size_t data_size,
    const string& signature_base64,
    const string& public_key_base64)
{
    try
    {
        auto signature_bytes = base64_decode_binary(signature_base64);
        auto key_bytes = base64_decode_binary(public_key_base64);
        // Sole owner of the curve group; freed when this scope exits.
        std::unique_ptr<EC_GROUP, decltype(&EC_GROUP_free)> curve(
            EC_GROUP_new_by_curve_name(OBJ_txt2nid("prime256v1")), &EC_GROUP_free);
        auto public_key = ec_crypto::evp_from_public_key(key_bytes, curve.get());
        ec_crypto::verify_sign(data, data_size, signature_bytes, public_key);
        return operation::success;
    }
    catch (const std::exception& failure)
    {
        return string(failure.what());
    }
}

std::tuple<operation::result, jwt> check_vapid_draft_01(
    const string& value,
    const string& public_key_base64)
{
    auto splitted = yplatform::util::split(value, ".");
    if (splitted.size() != 3)
    {
        return { "bad jwt component count: " + std::to_string(splitted.size()), {} };
    }

    string json_header;
    if (auto res = base64_decode(json_header, splitted[0]); !res)
    {
        return { "malformed jwt header: \"" + res.error_reason + "\"", {} };
    }
    json_value jwt_header;
    if (auto error = jwt_header.parse(json_header))
    {
        return { "malformed jwt header: \"" + *error + "\"", {} };
    }
    if (jwt_header["typ"] != "JWT")
    {
        return { "jwt type is " + json_get<string>(jwt_header, "typ", "missing"), {} };
    }
    if (jwt_header["alg"] != "ES256")
    {
        return { "jwt algorithm is " + json_get<string>(jwt_header, "alg", "missing"), {} };
    }

    string json_payload;
    if (auto res = base64_decode(json_payload, splitted[1]); !res)
    {
        return { "malformed jwt payload: \"" + res.error_reason + "\"", {} };
    }
    json_value payload;
    if (auto error = payload.parse(json_payload))
    {
        return { "malformed jwt payload: \"" + *error + "\"", {} };
    }
    jwt jwt{ json_get(payload, "aud", string{}),
             json_get(payload, "sub", string{}),
             json_get<int64_t>(payload, "exp", -1) };
    if (jwt.exp < std::time(NULL))
    {
        return { "bad vapid expiration", jwt };
    }

    auto&& signature = splitted[2];
    size_t signed_data_size = value.size() - signature.size() - 1;
    // TODO standard recommends to cache verify results.
    auto res = verify_signature(value.data(), signed_data_size, signature, public_key_base64);
    return { res, jwt };
}

// Validates the request's "authorization" header as VAPID draft-01
// ("WebPush <jwt>", scheme matched case-insensitively) against
// `public_key_base64`.
//
// Returns (result, claims) as produced by check_vapid_draft_01. If the header
// is missing or uses a different scheme, the claims are default-constructed.
// When the token check fails, the token (truncated) is logged at error level.
std::tuple<operation::result, jwt> check_vapid(
    const ymod_webserver::request_ptr& req,
    const string& public_key_base64)
{
    static const string WEBPUSH_SCHEME_PREFIX = "WebPush ";
    // Limits on how much client-supplied text is copied into errors and logs.
    // Typed as string::size_type so std::min deduces a single type on every
    // platform (a UL literal does not match size_t on LLP64/32-bit targets).
    const string::size_type MAX_LOGGED_SCHEME_SIZE = 30;
    const string::size_type MAX_LOGGED_VAPID_SIZE = 300;

    auto&& auth_header = get_header(req, "authorization");
    if (auth_header.empty()) return { "missing auth header", {} };
    // TODO check for senders using deprecated Vapid or Bearer.
    if (!boost::algorithm::istarts_with(auth_header, WEBPUSH_SCHEME_PREFIX))
    {
        // Log actual value of scheme, trim by reasonable amount.
        auto scheme =
            auth_header.substr(0, std::min(auth_header.find(' '), MAX_LOGGED_SCHEME_SIZE));
        return { "unexpected auth scheme \"" + scheme + "\"", {} };
    }
    auto vapid = auth_header.substr(WEBPUSH_SCHEME_PREFIX.size());
    auto res = check_vapid_draft_01(vapid, public_key_base64);
    // Log the actual vapid (maybe trimmed) if the check has failed.
    if (!std::get<operation::result>(res))
    {
        vapid.resize(std::min(vapid.size(), MAX_LOGGED_VAPID_SIZE));
        YLOG_CTX_GLOBAL(req->ctx(), error) << "invalid vapid: \"" << vapid << "\"";
    }
    return res;
}

// Returns a copy of the first subscription in `subs` whose id equals
// `subscription_id`, or an empty optional when none matches.
boost::optional<const sub_t> find_subscription(
    const std::vector<sub_t>& subs,
    const string& subscription_id)
{
    for (std::size_t index = 0; index < subs.size(); ++index)
    {
        const sub_t& candidate = subs[index];
        if (candidate.id == subscription_id)
        {
            return candidate;
        }
    }
    return boost::none;
}
} // namespace

// Handles a Web Push API "send" request whose push resource (endpoint) is
// url.path[2].
//
// Synchronous part:
//   - answers not_found for a path that does not have exactly 3 segments;
//   - masks the endpoint tail in raw_url;
//   - rejects a bad content-encoding via send_bad_webpushapi_request;
//   - answers not_found when decode_push_resource fails;
//   - rejects an over-long x-request-id via send_bad_request.
//
// Asynchronous part: lists the user's subscriptions via hub "/list_json".
// The handler then:
//   - answers gone if the configured subscription is missing or has no
//     session_key;
//   - answers unauthorized if the VAPID check fails;
//   - otherwise posts the message to hub "/binary_notify" and answers ok with
//     Location/TransitID headers on a 200 from the hub.
// Outcomes are recorded in the request's custom_log_data.
void send(settings_ptr settings, const ymod_webserver::http::stream_ptr& stream)
{
    auto&& log_data = stream->ctx()->custom_log_data;
    log_data["api_group"] = "webpushapi notifications";

    auto& url = stream->request()->url;
    if (url.path.size() != 3) return stream->result(http_codes::not_found);

    // Hide part of endpoint before logging. This runs before any other
    // response can be produced (including the bad-encoding one below) so the
    // full endpoint never reaches the logs.
    // NOTE(review): assumes raw_url ends with the endpoint (no query string).
    {
        string& raw_url = stream->request()->raw_url;
        // Signed arithmetic: an endpoint shorter than the shown prefix gives a
        // non-positive count instead of a wrapped-around size_t.
        auto length_of_hidden = static_cast<std::ptrdiff_t>(url.path[2].size()) -
            static_cast<std::ptrdiff_t>(settings->webpushapi.endpoint_characters_shown);
        // Never step before raw_url.begin().
        length_of_hidden =
            std::min(length_of_hidden, static_cast<std::ptrdiff_t>(raw_url.size()));
        if (length_of_hidden > 0)
        {
            raw_url.replace(raw_url.end() - length_of_hidden, raw_url.end(), "...");
        }
    }

    if (auto res = valid_encoding(stream->request()); !res)
    {
        send_bad_webpushapi_request(stream, "send", "bad request", res.error_reason);
        return;
    }

    string uidsetid, uid;
    auto decoded = decode_push_resource(url.path[2], uidsetid, uid);
    if (!decoded)
    {
        log_data["api_result"] = "send failed";
        log_data["error"] = decoded.error_reason;
        return stream->result(http_codes::not_found);
    }
    log_data["uid"] = uid;

    string transit_id = stream->ctx()->uniq_id();
    auto external_request_id = stream->request()->header_value("x-request-id", "");
    if (external_request_id.size() > settings->api.max_external_id_length)
    {
        return send_bad_request(stream, "external request id is too long");
    }
    if (external_request_id.size())
    {
        transit_id += ".";
        transit_id += external_request_id;
    }
    log_data["transit_id"] = transit_id;

    auto handler = handle_hub_list_json(
        stream, [stream, settings, uid, &log_data, transit_id](const std::vector<sub_t>& subs) {
            auto&& req = stream->request();
            auto sub = find_subscription(subs, settings->webpushapi.subscription_id);
            if (!sub || sub->session_key.empty())
            {
                // It seems subscription was deleted because of valid uid.
                log_data["api_result"] = "no subscription";
                return send_gone(stream, "", { { "TransitID", transit_id } });
            }
            auto& public_key = sub->session_key;
            auto [vapid_res, vapid_data] = check_vapid(req, public_key);
            log_data["jwt_aud"] = vapid_data.aud;
            log_data["jwt_sub"] = vapid_data.sub;
            log_data["jwt_exp"] = std::to_string(vapid_data.exp);
            if (!vapid_res)
            {
                log_data["api_result"] = "send failed";
                log_data["error"] = vapid_res.error_reason;
                return send_unauthorized(stream, "", { { "TransitID", transit_id } });
            }

            // Client TTL is capped by the configured message_ttl; an absent or
            // unparsable "ttl" header falls back to message_ttl.
            ttl_t ttl;
            using boost::conversion::try_lexical_convert;
            if (try_lexical_convert<ttl_t>(get_header(req, "ttl"), ttl))
            {
                ttl = std::min(settings->webpushapi.message_ttl, ttl);
            }
            else
            {
                ttl = settings->webpushapi.message_ttl;
            }

            message msg;
            msg.uid = uid;
            msg.service = settings->webpushapi.service;
            msg.operation = "push";
            msg.transit_id = transit_id;
            msg.ttl = ttl;
            auto& raw_body = req->raw_body;
            msg.raw_data.assign(raw_body.begin(), raw_body.end());
            msg.data["content-encoding"] = get_header(req, "content-encoding");
            msg.data["dh"] = get_header_parameter(req, "crypto-key", "dh");
            msg.data["salt"] = get_header_parameter(req, "encryption", "salt");

            find_hubrpc()->async_post_binary(
                stream->ctx(),
                uid,
                "/binary_notify",
                { { "uid", uid }, { "operation", msg.operation }, { "ratelimit", public_key } },
                pack(msg),
                [stream, settings, &log_data](
                    const boost::system::error_code& ec, yhttp::response response) {
                    if (ec || response.status != 200)
                    {
                        log_data["api_result"] = "send failed";
                        log_data["error"] =
                            "binary_notify " + format_http_error(ec, response.status);
                        return handle_default_hub_codes(stream, ec, response);
                    }
                    stream->set_code(http_codes::ok);
                    auto it = response.headers.find("transitid");
                    auto transit_id = it == response.headers.end() ? string{} : it->second;
                    if (transit_id.size())
                    {
                        stream->add_header(
                            "Location",
                            settings->webpushapi.message_url + transit_id + "-" + generate_uuid());
                    }
                    log_data["api_result"] = "send success";
                    stream->add_header("TransitID", transit_id);
                    stream->result_body("");
                });
        });

    find_hubrpc()->async_get(
        stream->ctx(),
        uid,
        "/list_json",
        { { "uid", uid }, { "service", settings->webpushapi.service } },
        [handler = std::move(handler), &log_data](auto& ec, auto response) mutable {
            if (ec || response.status != 200)
            {
                log_data["api_result"] = "send failed";
                log_data["error"] = "list " + format_http_error(ec, response.status);
            }
            handler(ec, std::move(response));
        });
}

}}}
