#pragma once

#include "payload_encryption.h"
#include "vapid.h"
#include <yxiva_mobile/push_task_context.h>
#include <yxiva_mobile/error.h>
#include <yxiva_mobile/ipusher.h>
#include <yxiva_mobile/reports.h>
#include <yxiva/core/callbacks.h>
#include <yplatform/module.h>
#include <yplatform/spinlock.h>
#include <yplatform/find.h>
#include <yplatform/time_traits.h>
#include <ymod_httpclient/call.h>
#include <memory>
#include <string>
#include <functional>

namespace yxiva { namespace mobile {

// Upper bound for ctx->payload.size(); webpush::push() rejects larger payloads
// with error::invalid_payload_length before doing any work.
const size_t MAX_PAYLOAD_SIZE{ 4077 };
// VAPID contact address.
// NOTE(review): not referenced in this file; the contact actually used comes from
// webpush_settings::vapid_sub (config key "vapid_contact").
const string VAPID_CONTACT{ "mailto:xiva-dev@yandex-team.ru" };
// Response reason that, together with HTTP 400, webpush::handle_response() maps to
// error::subscription_expired.
const string FCM_REASON_UNAUTHORIZED{ "UnauthorizedRegistration" };

/// Runtime configuration of the webpush module.
struct webpush_settings
{
    // Path of the VAPID key file handed to vapid_store::reset(); required in config.
    string vapid_key_file;
    // VAPID subject handed to vapid_store::reset(); read from the "vapid_contact" key; required.
    string vapid_sub;
    // Interval given to keys_store when it is constructed in webpush::init().
    yplatform::time_traits::duration key_generation_interval = yplatform::time_traits::hours(1);
    // First duration argument of the vapid_store constructor in webpush::init().
    yplatform::time_traits::duration vapid_validity = yplatform::time_traits::hours(24);
    // Second duration argument of the vapid_store constructor in webpush::init().
    yplatform::time_traits::duration vapid_validity_delta = yplatform::time_traits::hours(1);

    /// Fills the settings from `config`.
    /// The two string keys are mandatory (ptree::get throws when they are missing);
    /// each duration key is optional and falls back to the member's current value.
    void load(const yplatform::ptree& config)
    {
        using duration_t = yplatform::time_traits::duration;
        auto duration_or = [&config](const char* key, duration_t fallback) {
            return config.get<duration_t>(key, fallback);
        };

        vapid_key_file = config.get<string>("vapid_key_file");
        vapid_sub = config.get<string>("vapid_contact");
        key_generation_interval = duration_or("key_generation_interval", key_generation_interval);
        vapid_validity = duration_or("vapid_validity", vapid_validity);
        vapid_validity_delta = duration_or("vapid_validity_delta", vapid_validity_delta);
    }
};

/// Web Push delivery module.
///
/// push() validates the subscription JSON, optionally encrypts the payload with a
/// server key taken from keys_store, attaches a VAPID Authorization header from
/// vapid_store and POSTs the message to the subscription endpoint through the
/// "http_client" module. The outcome is reported to the caller's callback as a
/// yxiva_mobile error code.
class webpush : public yplatform::module
{
    using synchronization = yplatform::spinlock;

public:
    /// Loads settings and creates the payload-key and VAPID stores.
    /// @throws std::runtime_error if the configured VAPID key cannot be loaded.
    void init(const yplatform::ptree& config)
    {
        http_client_ = yplatform::find<yhttp::call, std::shared_ptr>("http_client");
        http_opts_.reuse_connection = true;

        settings_ = std::make_shared<webpush_settings>();
        settings_->load(config);
        auto global_reactor = yplatform::global_reactor_set->get_global();
        keys_ = std::make_shared<keys_store<synchronization>>(
            *global_reactor->io(), settings_->key_generation_interval);
        keys_->logger(logger());
        vapids_ = std::make_shared<vapid_store<synchronization>>(
            settings_->vapid_validity, settings_->vapid_validity_delta);
        vapids_->logger(logger());
        if (!vapids_->reset(settings_->vapid_key_file, settings_->vapid_sub))
            throw std::runtime_error("can't start without correct vapid key");
    }

    /// Re-reads the configuration and reloads the VAPID key when its file or subject changed.
    ///
    /// NOTE(review): key_generation_interval, vapid_validity and vapid_validity_delta are
    /// only recorded in settings_; keys_ and vapids_ keep the values given to them in init().
    void reload(const yplatform::ptree& config)
    {
        auto new_settings = std::make_shared<webpush_settings>();
        new_settings->load(config);
        auto settings = get_settings();
        const bool vapid_changed = settings->vapid_key_file != new_settings->vapid_key_file ||
            settings->vapid_sub != new_settings->vapid_sub;
        if (vapid_changed &&
            !vapids_->reset(new_settings->vapid_key_file, new_settings->vapid_sub))
        {
            // The reset failed, so do not record the new VAPID identity as applied.
            // Keeping the previous values makes the next reload with this config retry
            // the reset instead of skipping it as "unchanged".
            // Assumes vapid_store::reset() keeps the previous key on failure — TODO confirm.
            new_settings->vapid_key_file = settings->vapid_key_file;
            new_settings->vapid_sub = settings->vapid_sub;
        }
        {
            std::lock_guard<synchronization> lock(settings_sync_);
            settings_ = new_settings;
        }
    }

    /// Starts the payload-key store (keys_store::run()).
    void start()
    {
        keys_->run();
    }

    /// Sends one Web Push notification described by `ctx`.
    ///
    /// `cb` receives the result either synchronously (oversized payload, malformed
    /// subscription, missing encryption key, exception while building the request) or
    /// later from handle_response() once the HTTP request completes.
    void push(webpush_task_context_ptr ctx, callback_t&& cb)
    {
        auto& log_data = ctx->request->ctx()->custom_log_data;

        if (ctx->payload.size() > MAX_PAYLOAD_SIZE)
        {
            cb(make_error(error::invalid_payload_length));
            return;
        }
        try
        {
            json_value sub;
            if (auto parse_error = sub.parse(ctx->subscription))
            {
                report_invalid_webpush_subscription(ctx, *parse_error);
                cb(make_error(error::invalid_subscription)); // 205
                return;
            }
            auto endpoint = sub.cref["endpoint"].to_string("");
            auto client_auth = sub.cref["keys"]["auth"].to_string("");
            auto client_key = sub.cref["keys"]["p256dh"].to_string("");
            if (endpoint.empty() || client_auth.empty() || client_key.empty())
            {
                cb(make_error(error::invalid_subscription)); // 205
                return;
            }

            string origin;
            if (auto parsed = callback_uri::parse_webpush_uri_origin(endpoint, origin); !parsed)
            {
                // Best effort: the failure is only reported and the request is still sent.
                // The VAPID lookup below then uses whatever `origin` holds
                // (presumably empty) — confirm this is intended.
                report_invalid_webpush_endpoint(ctx, parsed.error_reason);
            }
            auto vapid = vapids_->get(origin);
            log_data["push_service"] = origin;
            log_data["vapid_exp"] = std::to_string(vapid->exp);

            string request_headers;
            request_headers.reserve(640);
            request_headers += "ttl: ";
            request_headers += std::to_string(ctx->ttl);
            request_headers += "\r\nX-Request-ID: ";
            request_headers += ctx->transit_id;
            request_headers += "\r\nAuthorization: WebPush ";
            request_headers += vapid->auth;
            request_headers += "\r\nCrypto-Key: p256ecdsa=";
            request_headers += vapid->key;

            string body;
            if (ctx->payload.size())
            {
                auto pl_key = keys_->get();
                if (!pl_key)
                {
                    report_no_webpush_encryption_key(ctx);
                    cb(make_error(error::internal_error));
                    // Without this return the request would still be sent (unencrypted,
                    // empty body) and `cb` would be invoked a second time from
                    // handle_response().
                    return;
                }

                ctx->request->ctx()->profilers.push("webpush::encrypt");
                auto pl_enc = encrypt_payload(ctx->payload, *pl_key, client_auth, client_key);
                ctx->request->ctx()->profilers.pop();
                body = std::move(pl_enc.data);
                // The ";dh=" parameter continues the Crypto-Key header started above.
                request_headers += ";dh=";
                request_headers += pl_key->public_key_base64;
                request_headers += "\r\nContent-Encoding: aesgcm\r\nEncryption: salt=";
                request_headers += pl_enc.salt;
            }
            request_headers += "\r\n";

            namespace p = std::placeholders;
            http_client_->async_run(
                ctx,
                yhttp::request::POST(string(endpoint), request_headers, std::move(body)),
                http_opts_,
                std::bind(
                    &webpush::handle_response,
                    shared_from(this),
                    p::_1,
                    p::_2,
                    ctx,
                    std::move(cb)));
        }
        catch (const std::exception& ex)
        {
            // NOTE(review): if async_run() throws after the handler was built, `cb` has
            // already been moved into std::bind and is in a moved-from state here.
            report_webpush_request_error(ctx, {}, ex.what());
            cb(make_error(error::internal_error));
        }
    }

private:
    /// Returns the current settings snapshot under settings_sync_.
    std::shared_ptr<webpush_settings> get_settings()
    {
        std::lock_guard<synchronization> lock(settings_sync_);
        return settings_;
    }

    /// Maps the outcome of a push HTTP request to a yxiva_mobile error and invokes `cb`.
    ///
    /// - transport error: task_cancelled on timeout/cancel, internal_error otherwise
    /// - 2xx: success
    /// - 410, or 400 with reason FCM_REASON_UNAUTHORIZED: subscription_expired
    /// - 401/403/404: invalid_subscription
    /// - 400/413: data_compose_error
    /// - anything else (including 429 and 5xx): cloud_error
    void handle_response(
        const boost::system::error_code& err,
        yhttp::response response,
        webpush_task_context_ptr ctx,
        callback_t& cb)
    {
        if (err)
        {
            report_webpush_request_error(ctx, err);
            cb(make_error(
                (err == yhttp::errc::request_timeout || err == yhttp::errc::task_canceled) ?
                    error::task_cancelled :
                    error::internal_error));
            return;
        }

        if (response.status / 100 == 2)
        {
            cb(make_error(error::success)); // 200
            return;
        }

        report_webpush_request_error(
            ctx,
            err,
            std::to_string(response.status) + " " + response.reason + " " + response.body);

        if (response.status == 410 ||
            (response.status == 400 && response.reason == FCM_REASON_UNAUTHORIZED))
        {
            cb(make_error(error::subscription_expired)); // 205
        }
        else if (response.status == 401 || response.status == 403 || response.status == 404)
        {
            cb(make_error(error::invalid_subscription)); // 205
        }
        else if (response.status == 400 || response.status == 413)
        {
            cb(make_error(error::data_compose_error)); // 400
        }
        else
        {
            // 429, 5xx and any other unexpected status are treated as a push-service failure.
            cb(make_error(error::cloud_error)); // 502
        }
    }

    synchronization settings_sync_;
    std::shared_ptr<yhttp::call> http_client_;
    yhttp::options http_opts_;
    std::shared_ptr<webpush_settings> settings_;
    std::shared_ptr<keys_store<synchronization>> keys_;
    std::shared_ptr<vapid_store<synchronization>> vapids_;
};
}}
