#include "websocket_stream.h"
#include <random>
#include "http_stream.h"
#include "version.h"
#include <yplatform/net/buffers/chunk.h>
#include <yplatform/net/streambuf.h>
#include <yplatform/encoding/base64.h>
#include <yplatform/context_repository.h>
#include <openssl/md5.h>
#include <openssl/sha.h>
#include <ymod_webserver/websocket_body.h>
#include <ymod_webserver/websocket_message.h>
#include "parser/date.h"
#include "net_server.h"

namespace ymod_webserver { namespace websocket {

namespace {

// Owning chunk handed to the session for sending. On destruction it frees the
// bytes it was given (using the allocator rebound to byte_t) and then completes
// the attached promise, so whoever holds the matching future learns that the
// buffer has been released.
template <typename Allocator = std::allocator<char>>
class send_buffer : public yplatform::net::buffers::const_base_chunk
{
    // std::allocator_traits instead of Allocator::rebind<>::other: the member
    // rebind is deprecated in C++17 and removed from std::allocator in C++20.
    // For allocators that still provide rebind, allocator_traits yields the same type.
    typedef typename std::allocator_traits<Allocator>::template rebind_alloc<byte_t>
        allocator_type;

public:
    // allocator - allocator that owns the storage behind `buff`
    // buff      - bytes to send; ownership is taken over by this object
    send_buffer(allocator_type const& allocator, const boost::asio::const_buffer& buff)
        : allocator_(allocator)
        , data_(boost::asio::buffer_cast<const byte_t*>(buff))
        , size_(boost::asio::buffer_size(buff))
        , prom_()
    {
    }

    // The object owns raw storage: a copy would deallocate the same bytes twice.
    send_buffer(const send_buffer&) = delete;
    send_buffer& operator=(const send_buffer&) = delete;

    ~send_buffer()
    {
        allocator_.deallocate(const_cast<byte_t*>(data_), size_);
        // Completing the promise must never throw out of a destructor.
        try
        {
            prom_.set();
        }
        catch (...)
        {
        }
    }

    // Attaches the promise completed when this buffer is destroyed.
    void set_promise(promise_void_t prom)
    {
        prom_ = prom;
    }

    // Future bound to the attached promise.
    future_void_t fres() const
    {
        return prom_;
    }

    // Pointer/size view of the owned bytes.
    std::pair<const byte_t*, std::size_t> buff() override
    {
        return std::make_pair(data_, size_);
    }

private:
    allocator_type allocator_;
    const byte_t* data_;
    std::size_t size_;
    promise_void_t prom_;
};

// Computes one hixie-76 key number from the request header `key_name`.
// The decimal digits of the header value are concatenated into a number,
// which is divided by the count of space characters. The 32-bit result is
// returned with its bytes laid out in big-endian order in memory, ready to be
// hashed as raw bytes.
// Throws parse_error if the header is missing or contains no spaces.
static int32_t get_websocket_key_value(const request_ptr& req, const char* key_name)
{
    header_map_t::const_iterator i_key = req->headers.find(key_name);
    if (i_key == req->headers.end())
    {
        throw parse_error(
            "websocket_key_value", "incorrect websocket76 headers", "incorrect websocket76 headers")
            << BOOST_ERROR_INFO;
    }

    // Unsigned accumulator: a header with very many digits (untrusted input)
    // wraps around, which is well defined, instead of overflowing a signed
    // integer, which is undefined behaviour.
    uint64_t digits = 0;
    uint32_t spaces = 0;
    for (const char* pa = i_key->second.c_str(); *pa; ++pa)
    {
        if (*pa == ' ')
            ++spaces;
        else if (*pa >= '0' && *pa <= '9')
            digits = digits * 10 + static_cast<uint64_t>(*pa - '0');
    }
    if (spaces == 0)
    {
        throw parse_error(
            "websocket_key_value", "incorrect websocket76 headers", "incorrect websocket76 headers")
            << BOOST_ERROR_INFO;
    }
    const uint32_t number = static_cast<uint32_t>(digits / spaces);

    // Write the bytes explicitly in big-endian order so the result does not
    // depend on host endianness (byte-wise access through uint8_t is allowed).
    int32_t v32 = 0;
    uint8_t* b = reinterpret_cast<uint8_t*>(&v32);
    b[0] = static_cast<uint8_t>(number >> 24);
    b[1] = static_cast<uint8_t>(number >> 16);
    b[2] = static_cast<uint8_t>(number >> 8);
    b[3] = static_cast<uint8_t>(number);
    return v32;
}

// Computes the Sec-WebSocket-Accept value for RFC 6455 style handshakes:
// base64(SHA-1(Sec-WebSocket-Key + fixed GUID)).
// Throws parse_error when the request has no sec-websocket-key header.
string process_accept_sign(const request_ptr& req)
{
    header_map_t::const_iterator key_it = req->headers.find("sec-websocket-key");
    if (key_it == req->headers.end())
    {
        throw parse_error(
            "websocket_key_value", "incorrect websocket07 headers", "incorrect websocket07 headers")
            << BOOST_ERROR_INFO;
    }
    const auto& client_key = key_it->second;

    static const string ws_magic_guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    unsigned char digest[SHA_DIGEST_LENGTH];
    SHA_CTX sha;
    ::SHA1_Init(&sha);
    ::SHA1_Update(&sha, client_key.c_str(), client_key.size());
    ::SHA1_Update(&sha, ws_magic_guid.c_str(), ws_magic_guid.size());
    ::SHA1_Final(digest, &sha);

    // Pass plain pointers (not the array itself) as the iterator pair.
    const unsigned char* digest_begin = &digest[0];
    string accept;
    accept += yplatform::base64_encode(digest_begin, digest_begin + SHA_DIGEST_LENGTH);
    return accept;
}

void process_websocket_secure_key(const request_ptr& req, string& out)
{
    int32_t v1 = get_websocket_key_value(req, "sec-websocket-key1");
    int32_t v2 = get_websocket_key_value(req, "sec-websocket-key2");

    MD5_CTX md5handler;
    MD5_Init(&md5handler);
    MD5_Update(&md5handler, &v1, sizeof(v1));
    MD5_Update(&md5handler, &v2, sizeof(v2));
    string body_data(req->raw_body.begin(), req->raw_body.end());
    MD5_Update(&md5handler, body_data.c_str(), 8);

    out.resize(16);
    MD5_Final(reinterpret_cast<unsigned char*>(&out[0]), &md5handler);
}

// Streamer that turns everything written into it into a single websocket frame.
// The constructor writes the frame prefix/header. The destructor finishes the
// frame and hands the accumulated bytes to the session.
template <class Session>
class websocket_streamer : public yplatform::net::streamer_base
{
public:
    // s          - session the finished frame is sent through
    // sz         - payload size:
    //              hixie-75/76: non-zero selects a length-prefixed frame
    //              (parser::make_websocket_binary_size), zero selects a
    //              0x00 ... 0xFF text frame;
    //              RFC 6455 versions: written into the frame header, so the
    //              caller presumably has to write exactly sz payload bytes.
    // upgrade_to - negotiated protocol version
    // opcode     - RFC 6455 frame opcode (not used for hixie-75/76)
    // mask       - masking key forwarded to make_websocket_package_header
    // Throws std::runtime_error("not implemented") for any other protocol version.
    websocket_streamer(
        boost::shared_ptr<Session> const& s,
        std::size_t sz,
        upgrade_proto_header upgrade_to,
        uint8_t opcode,
        mask_t mask)
        : sess(s), is_binary_(sz > 0), upgrade_to_(upgrade_to)
    {
        switch (upgrade_to)
        {
        case upgrade_to_websocket75:
        case upgrade_to_websocket76:
            if (is_binary_)
            { // binary stream
                parser::make_websocket_binary_size(*this, sz);
            }
            else
            { // text stream
                this->operator<<('\x00');
            }
            break;
        case upgrade_to_websocket07:
        case upgrade_to_websocket08:
        case upgrade_to_websocket13:
            make_websocket_package_header(message::bit_fin, opcode, mask, sz, *this);
            break;
        default:
            throw std::runtime_error("not implemented");
        }
    }

    // Terminates hixie text frames with 0xFF, flushes the stream and sends the
    // buffered frame through the session.
    // NOTE(review): destructors are implicitly noexcept, so an exception from
    // flush() or send_client_stream() here would call std::terminate.
    ~websocket_streamer()
    {
        if (upgrade_to_ == upgrade_to_websocket75 || upgrade_to_ == upgrade_to_websocket76)
        {
            if (!is_binary_) (*this) << '\xFF';
        }
        ostr->flush();
        sess->send_client_stream(buf->release_buffer());
    }

private:
    boost::shared_ptr<Session> sess; // keeps the session alive until the frame is sent
    bool is_binary_;                 // sz > 0 at construction (hixie framing choice)
    upgrade_proto_header upgrade_to_;
};

namespace {
// Completion handler that deliberately does nothing. Binding
// shared_from_this() into it (as close_connection_impl does for do_shutdown)
// keeps the stream alive for as long as the handler object exists.
inline void do_nothing(boost::shared_ptr<websocket_stream> /*keep_alive*/)
{
    // intentionally empty
}
}

}

// Builds a websocket stream on top of an already-upgraded HTTP session.
//  owner             - server notified from the destructor (held weakly)
//  context           - stored in context_
//  session           - transport session used for all I/O
//  req               - the upgrade request (protocol version, headers, body)
//  readq             - read queue; it may already hold buffered bytes
//  logger            - log source for this stream
//  max_length        - forwarded to the frame parser (NOTE(review): presumably
//                      the maximum accepted frame/payload length; confirm)
//  max_fragmentation - limit on continuation frames of one fragmented message,
//                      checked in validate_message()
// Opens an "exec" profiler section and registers ctx() with the context
// repository. Both are undone in the destructor.
// NOTE(review): i_saved_ is initialised from readq_ and parser_ from session_.
// Those members must therefore be declared earlier in websocket_stream.h; confirm.
websocket_stream::websocket_stream(
    boost::weak_ptr<net_server> owner,
    context_ptr context,
    session_ptr session,
    const request_ptr& req,
    read_buffer_ptr readq,
    const yplatform::log::source& logger,
    size_t max_length,
    size_t max_fragmentation)
    : yplatform::log::contains_logger(logger)
    , owner_(owner)
    , context_(context)
    , session_(session)
    , request_(req)
    , readq_(readq)
    , i_saved_(readq_->begin())
    , parser_(
          req->context,
          boost::lexical_cast<string>(session_->id()),
          req->upgrade_to,
          max_length)
    , is_reading_(false)
    , close_frame_sent_(false)
    , max_fragmentation_(max_fragmentation)
{
    ctx()->profilers.push("exec");
    yplatform::context_repository::instance().add_context(ctx());
}

websocket_stream::~websocket_stream()
{
    // Undo the profiler/context registration done in the constructor.
    ctx()->profilers.pop("exec");
    yplatform::context_repository::instance().rem_context(ctx());

    // Let the owning server know, if it still exists. Exceptions must not
    // escape a destructor, so they are swallowed.
    auto server = owner_.lock();
    if (!server) return;
    try
    {
        server->on_destroy(this);
    }
    catch (...)
    {
    }
}

// Builds the hixie (Sec-)WebSocket-Location value: the configured
// websocket_location prefix plus the text between the first and last space of
// the raw request line (the request target of "GET /path HTTP/1.1").
string websocket_stream::get_websocket_location()
{
    const auto& line = request_->raw_request_line;
    const string::size_type first_space = line.find(' ');
    const string::size_type last_space = line.find_last_of(' ');

    // No space at all: the whole line is used. A single space: everything after it.
    const string::size_type target_begin = (first_space == string::npos) ? 0 : first_space + 1;
    const string::size_type target_len =
        (last_space == string::npos || last_space < target_begin) ? string::npos
                                                                   : last_space - target_begin;

    string location = session_->session_settings().websocket_location;
    location += line.substr(target_begin, target_len);
    return location;
}

// Installs the session write-error hook and writes the handshake response.
void websocket_stream::init()
{
    // The hook only holds a weak reference, so it does not keep the stream
    // alive. Errors arriving after the stream is gone are ignored.
    auto weak_self = weak_from_this();
    session_->set_write_error_hook([this, weak_self](const boost::system::error_code& ec) {
        if (auto alive = weak_self.lock()) handle_io_error(ec, "write");
    });
    out_handshake();
}

// Writes the HTTP 101 upgrade response for the negotiated protocol version into
// the session's client stream:
//  - hixie-75: WebSocket-Origin / WebSocket-Location headers;
//  - hixie-76: Sec-WebSocket-Origin / Sec-WebSocket-Location headers followed
//    by the 16-byte MD5 challenge response as body (process_websocket_secure_key
//    may throw parse_error);
//  - 07/08/13 (RFC 6455 family): Sec-WebSocket-Accept derived from
//    Sec-WebSocket-Key (process_accept_sign throws parse_error if it is missing).
// Any other upgrade_to value writes nothing.
// NOTE(review): request_->headers["origin"] presumably inserts an empty
// "origin" entry when the header is absent (map-like operator[]); confirm.
// NOTE(review): the bytes are presumably sent when `stream` is released at
// the end of this function; confirm against streamer_wrapper.
void websocket_stream::out_handshake()
{
    yplatform::net::streamer_wrapper stream = session_->client_stream();
    switch (request_->upgrade_to)
    {
    case upgrade_to_websocket75:
    {
        stream << "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
                  "Upgrade: WebSocket\r\n"
                  "Connection: Upgrade\r\n"
                  "WebSocket-Origin: "
               << request_->headers["origin"]
               << "\r\n"
                  "WebSocket-Location: "
               << get_websocket_location() << "\r\n";
        // Optional subprotocol from the session settings.
        if (session_->session_settings().websocket_protocol.size() > 0)
            stream << "WebSocket-Protocol: " << session_->session_settings().websocket_protocol
                   << "\r\n";
        stream << "Y-Context: " << request_->context->uniq_id() << "\r\n";
        stream << "Date: ";
        parser::print_date(stream, std::time(0));
        stream << "\r\nServer: YaWEBServer/" << VERSION
               << "\r\n"
                  "Content-Type: text/plain\r\n"
                  "\r\n";
        break;
    }
    case upgrade_to_websocket76:
    {
        // 16-byte challenge response, written after the blank line as the body.
        string response_body;
        process_websocket_secure_key(request_, response_body);
        stream << "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
                  "Upgrade: WebSocket\r\n"
                  "Connection: Upgrade\r\n"
                  "Sec-WebSocket-Origin: "
               << request_->headers["origin"]
               << "\r\n"
                  "Sec-WebSocket-Location: "
               << get_websocket_location() << "\r\n";
        if (session_->session_settings().websocket_protocol.size() > 0)
            stream << "Sec-WebSocket-Protocol: " << session_->session_settings().websocket_protocol
                   << "\r\n";
        stream << "Y-Context: " << request_->context->uniq_id() << "\r\n";
        stream << "Date: ";
        parser::print_date(stream, std::time(0));
        stream << "\r\nServer: YaWEBServer/" << VERSION
               << "\r\n"
                  "Content-Type: text/plain\r\n"
                  "\r\n"
               << response_body;
        break;
    }
    case upgrade_to_websocket07:
    case upgrade_to_websocket08:
    case upgrade_to_websocket13:
    {
        stream << "HTTP/1.1 101 Switching Protocols\r\n"
                  "Upgrade: WebSocket\r\n"
                  "Connection: Upgrade\r\n"
                  "Sec-WebSocket-Accept: "
               << process_accept_sign(request_) << "\r\n";
        if (session_->session_settings().websocket_protocol.size() > 0)
            stream << "Sec-WebSocket-Protocol: " << session_->session_settings().websocket_protocol
                   << "\r\n";
        stream << "Y-Context: " << request_->context->uniq_id() << "\r\n";
        stream << "\r\n";
        break;
    }
    default:
        // Unsupported version: no handshake is written.
        break;
    }
}

// The stream counts as open while the transport is open and no close frame
// has been sent by us.
bool websocket_stream::is_open() const
{
    if (!session_->is_open()) return false;
    return !close_frame_sent_;
}

// Delegates to the underlying session's is_secure().
bool websocket_stream::is_secure() const
{
    const bool secure = session_->is_secure();
    return secure;
}

// Starts the read loop with the given inactivity timeout. Bytes already
// sitting in readq_ are parsed first; otherwise a socket read is issued.
void websocket_stream::start_read_message_strand(
    const yplatform::time_traits::duration& inactive_timeout)
{
    if (!is_open()) return;

    session_->update_read_timeout(inactive_timeout);

    if (!readq_->size())
    {
        begin_read_message();
        return;
    }
    process_read_buffer();
}

// Returns a streamer producing one frame. The frame header is written on
// construction, and the frame is sent from the streamer's destructor.
yplatform::net::streamer_wrapper websocket_stream::text_stream(std::size_t sz, uint8_t opcode)
{
    auto* streamer = new websocket_streamer<websocket_stream>(
        shared_from_this(), sz, request_->upgrade_to, opcode, generate_mask());
    return yplatform::net::streamer_wrapper(streamer);
}

// Same construction as text_stream(): the opcode and sz passed by the caller
// decide the framing.
yplatform::net::streamer_wrapper websocket_stream::bin_stream(std::size_t sz, uint8_t opcode)
{
    typedef websocket_streamer<websocket_stream> streamer_t;
    return yplatform::net::streamer_wrapper(
        new streamer_t(shared_from_this(), sz, request_->upgrade_to, opcode, generate_mask()));
}

// Forwards an already-framed chunk to the session. It is dropped once the
// stream is no longer open.
void websocket_stream::send_client_stream_strand(
    yplatform::net::buffers::const_chunk_buffer const& s)
{
    if (is_open()) session_->send_client_stream(s);
}

// Registers a message hook. Hooks added after the stream closed are ignored.
void websocket_stream::add_message_callback_strand(message_hook_t const& hook)
{
    if (is_open()) message_hooks_.push_back(hook);
}

// Stores the close hook, wrapped with the session's io_service.
// If the stream is already closed, the hook is invoked right away with code 0.
void websocket_stream::set_close_callback_strand(close_hook_t const& hook)
{
    if (is_open())
    {
        close_hook_ = session_->current_io_service().wrap(hook);
        return;
    }
    hook(0, "already closed");
}

void websocket_stream::send_text_strand(const string& msg, uint8_t opcode, promise_void_t prom)
{
    if (!is_open()) return;
    if (!session_->is_open())
    {
        prom.set_exception(connection_broken() << BOOST_ERROR_INFO);
        return;
    }
    switch (request_->upgrade_to)
    {
    case upgrade_to_websocket75:
    case upgrade_to_websocket76:
    {
        yplatform::net::basic_streambuf<> buffer(msg.size() + 20, msg.size() + 20);
        std::ostream stream(&buffer);
        stream << '\x00' << msg << '\xFF';
        stream.flush();
        typedef yplatform::net::basic_streambuf<>::allocator_type allocator_type;
        boost::shared_ptr<send_buffer<allocator_type>> buff_ptr =
            buffer.release_buffer<send_buffer<allocator_type>>();
        buff_ptr->set_promise(prom);
        session_->send_client_stream(yplatform::net::buffers::make_chunk_buffer(buff_ptr));
        break;
    }
    case upgrade_to_websocket07:
    case upgrade_to_websocket08:
    case upgrade_to_websocket13:
    {
        send_binary(msg, opcode);
        break;
    }
    default:
        prom.set_exception(std::runtime_error("not implemented"));
    }
}

// Sends `msg` as a binary message; `prom` is completed when the write finishes or fails.
//  - hixie-75/76: length-prefixed frame (parser::make_websocket_binary_size);
//  - 07/08/13: single final frame with `opcode`. Masking is applied only when
//    enable_websocket_masking is set.
// Errors are reported through `prom`: connection_broken when the stream is not
// open, std::runtime_error for an unsupported protocol version.
void websocket_stream::send_binary_strand(const string& msg, uint8_t opcode, promise_void_t prom)
{
    // Fail the promise instead of silently dropping it, so the caller's future
    // always completes. is_open() is false both when the session is closed and
    // after our close frame was sent.
    if (!is_open())
    {
        prom.set_exception(connection_broken() << BOOST_ERROR_INFO);
        return;
    }
    switch (request_->upgrade_to)
    {
    case upgrade_to_websocket75:
    case upgrade_to_websocket76:
    {
        yplatform::net::basic_streambuf<> buffer(msg.size() + 2, msg.size() + 16);
        std::ostream stream(&buffer);
        parser::make_websocket_binary_size(stream, msg.size());
        stream << msg;
        stream.flush();
        typedef yplatform::net::basic_streambuf<>::allocator_type allocator_type;
        boost::shared_ptr<send_buffer<allocator_type>> buff_ptr =
            buffer.release_buffer<send_buffer<allocator_type>>();
        buff_ptr->set_promise(prom);
        session_->send_client_stream(yplatform::net::buffers::make_chunk_buffer(buff_ptr));
        break;
    }
    case upgrade_to_websocket07:
    case upgrade_to_websocket08:
    case upgrade_to_websocket13:
        send_message_strand(
            msg,
            opcode,
            message::bit_fin,
            generate_mask(),
            session_->session_settings().enable_websocket_masking,
            prom);
        break;
    default:
        prom.set_exception(std::runtime_error("not implemented"));
    }
}

// Creates a timer bound to the session's io_service.
yplatform::time_traits::timer_ptr websocket_stream::make_timer() const
{
    auto& io = session_->current_io_service();
    return std::make_shared<yplatform::time_traits::timer>(io);
}

// The io_service this stream's session runs on.
boost::asio::io_service& websocket_stream::get_io_service()
{
    boost::asio::io_service& io = session_->current_io_service();
    return io;
}

void websocket_stream::send_message_strand(
    const string& msg,
    uint8_t opcode,
    uint8_t bits,
    const mask_t& mask,
    bool apply_mask,
    promise_void_t prom)
{
    yplatform::net::basic_streambuf<> buffer(msg.size() + 2, msg.size() + 16);
    std::ostream stream(&buffer);
    std::size_t message_length = prepare_ws_message(msg, opcode, bits, mask, apply_mask, stream);
    stream.flush();

    YLOG_CTX_LOCAL(session_->get_context(), info)
        << "sending ws message: opcode=" << static_cast<int>(opcode)
        << ", bits=" << static_cast<int>(bits) << ", length=" << message_length;

    typedef yplatform::net::basic_streambuf<>::allocator_type allocator_type;
    boost::shared_ptr<send_buffer<allocator_type>> buff_ptr =
        buffer.release_buffer<send_buffer<allocator_type>>();
    buff_ptr->set_promise(prom);
    session_->send_client_stream(yplatform::net::buffers::make_chunk_buffer(buff_ptr));
}

// Returns a random masking key when masking is enabled in the session settings
// or the protocol version is websocket08 or later; otherwise returns empty_mask.
// NOTE(review): RFC 6455 §5.3 asks for masking keys from a strong entropy
// source; mt19937 is not cryptographically secure.
mask_t websocket_stream::generate_mask() const
{
    if (session_->session_settings().enable_websocket_masking ||
        request_->upgrade_to >= upgrade_to_websocket08)
    {
        // The previous `0xFF * rand() / RAND_MAX` overflowed int (undefined
        // behaviour) with glibc's RAND_MAX, and rand() shares hidden global
        // state without a thread-safety guarantee. One engine per thread,
        // seeded once from std::random_device, avoids both problems.
        static thread_local std::mt19937 engine{ std::random_device{}() };
        mask_t mask;
        for (std::size_t i = 0; i < mask.size(); ++i)
            mask[i] = static_cast<uint8_t>(engine() & 0xFFu);
        return mask;
    }
    return empty_mask;
}

// Posts a copy of `msg` to every registered message hook through the
// session's io_service. Hooks therefore run later, not inside this call.
void websocket_stream::deliver_message(const message& msg)
{
    auto& io = session_->current_io_service();
    for (const auto& hook : message_hooks_)
        io.post(std::bind(hook, msg));
}

// Records `code` as the result code and, if a close hook is installed, invokes
// it once with (code, reason). The hook is then discarded.
void websocket_stream::maybe_notify_closed(uint16_t code, const string& reason)
{
    result_code_ = static_cast<codes::code>(code);
    if (!close_hook_) return;

    // TODO safe call
    try
    {
        close_hook_(code, reason);
    }
    catch (...)
    {
        // A throwing hook must not break the close path.
    }
    close_hook_ = close_hook_t();
}

// Issues one asynchronous read of at least max(min, 1) bytes into readq_,
// unless the stream is closed or a read is already outstanding.
void websocket_stream::begin_read_message(std::size_t min)
{
    if (!is_open()) return;
    if (is_reading_)
    {
        YLOG_CTX_LOCAL(session_->get_context(), warning)
            << "begin reading attempt refused because we're already reading";
        return;
    }

    const std::size_t at_least = min > 0 ? min : 1;
    session_->begin_read(
        boost::bind(&websocket_stream::handle_read_message, shared_from_this(), _1, _2),
        readq_->prepare(at_least),
        at_least);
    is_reading_ = true;
}

// Read completion handler. On error it reports an I/O error. On an empty
// read it reads again. Otherwise it commits the bytes to readq_ and parses them.
void websocket_stream::handle_read_message(boost::system::error_code const& e, std::size_t bytes)
{
    is_reading_ = false;
    if (!is_open()) return;

    if (e)
    {
        handle_io_error(e, "read");
        return;
    }

    if (bytes > 0)
    {
        readq_->commit(bytes);
        process_read_buffer();
        return;
    }
    begin_read_message();
}

void websocket_stream::process_read_buffer()
{
    try
    {
        optional<message> msg_opt = parse_message();
        while (msg_opt && is_open())
        {
            on_receive(msg_opt.get());
            msg_opt = parse_message();
        }
    }
    catch (const yplatform::exception& e)
    {
        YLOG_CTX_LOCAL(session_->get_context(), error)
            << "websocket_stream process_stream exception message=\"" << e.private_message()
            << "\"";
        handle_runtime_error(e.public_message());
    }
    catch (std::exception const& e)
    {
        YLOG_CTX_LOCAL(session_->get_context(), error)
            << "websocket_stream process_stream exception message=\"" << e.what() << "\"";
        handle_runtime_error(e.what());
    }
    catch (...)
    {
        YLOG_CTX_LOCAL(session_->get_context(), error)
            << "websocket_stream process_stream exception message=none";
        handle_runtime_error("unknown exception");
    }

    if (is_open()) begin_read_message();
}

// Tries to extract one complete frame from readq_.
// Returns the parsed message, or an empty optional when readq_ is empty or the
// buffered bytes do not yet form a complete frame.
// NOTE(review): parser_ appears to be incremental. i_saved_ is passed to it and
// presumably records how far parsing has progressed across calls; confirm
// against the parser implementation.
optional<message> websocket_stream::parse_message()
{
    read_iterator start = readq_->begin();
    read_iterator end = readq_->end();

    if (start == end) return optional<message>();

    // Run the parser and detach whatever prefix it reports as consumed
    // (presumably the frame header bytes).
    readq_->detach(parser_(start, i_saved_, end));
    if (!parser_.is_finished()) return optional<message>();

    message msg = parser_.headers();
    websocket_parser_t::message_t msg_body = parser_.msg();
    // Payload bytes are detached from the queue and combined with the frame's
    // mask (NOTE(review): presumably make_websocket_payload unmasks them).
    msg.data = websocket::make_websocket_payload(readq_->detach(msg_body.second), msg.mask);
    readq_->detach(i_saved_);
    // Prepare the parser for the next frame.
    parser_.reset();
    return optional<message>(msg);
}

// Checks protocol constraints on an incoming frame and throws websocket_error
// on violation:
//  - close frames cannot have a 1-byte payload;
//  - control frames (close/ping/pong) must be unfragmented with payload <= 125;
//  - continuation frames need an unfinished fragmented message and are limited
//    by max_fragmentation_;
//  - a new text/binary message cannot start while a fragmented one is pending.
void websocket_stream::validate_message(message const& msg)
{
    switch (msg.opcode)
    {
    case message::opcode_close:
        // A close payload is empty or starts with a 2-byte status code:
        // (\d{2})? | (\d{2}.{123})
        if (msg.length == 1)
            throw websocket_error("websocket_stream", "close frame invalid length");
        // Close is a control frame too, so the checks below apply as well.
        // fall through
    case message::opcode_ping:
    case message::opcode_pong:
        if (!msg.is_finished())
            throw websocket_error("websocket_stream", "control frames must not be fragmented");
        if (msg.length > 125)
            throw websocket_error(
                "websocket_stream", "control frames payload length must be <= 125");
        break;
    case message::opcode_continuation:
        if (!fragmented_message_ || fragmented_message_.get().is_finished)
            throw websocket_error("websocket_stream", "unexpected continuation frame");
        if (fragmented_message_.get().fragments_count >= max_fragmentation_)
            throw websocket_error("websocket_stream", "max fragmentation limit exceeded");
        break;
    case message::opcode_text:
    case message::opcode_binary:
        if (fragmented_message_)
            throw websocket_error("websocket_stream", "continuation frame was expected");
        break;
    default:
        // Unknown opcodes are rejected by perform_message().
        break;
    }
}

// Validates `msg` (validate_message throws websocket_error on violations) and
// dispatches it by opcode:
//  - ping: answered with a pong carrying the same payload;
//  - pong: ignored;
//  - close: handled as a close initiated by the remote peer;
//  - text/binary/continuation: delivered to the message hooks, while
//    fragmented_message_ accumulates an unfinished fragmented message.
// Throws websocket_error("unexpected opcode") for any other opcode.
void websocket_stream::perform_message(message const& msg)
{
    validate_message(msg);
    switch (msg.opcode)
    {
    case message::opcode_ping:
        // NOTE(review): every other sender in this file passes message::bit_fin
        // as `bits`; confirm that bit_fin_val is the intended value here.
        send_message_strand(
            string(msg.data.begin(), msg.data.end()),
            message::opcode_pong,
            message::bit_fin_val,
            empty_mask);
        break;
    case message::opcode_pong:
        break;
    case message::opcode_close:
        handle_close_by_remote_peer();
        break;
    case message::opcode_text:
    case message::opcode_binary:
        // First fragment of a fragmented message: remember it.
        if (!msg.is_finished()) fragmented_message_ = msg;
        deliver_message(msg);
        break;
    case message::opcode_continuation:
        if (fragmented_message_) fragmented_message_.get() << msg;
        deliver_message(msg);
        break;
    default:
        throw websocket_error("websocket_stream", "unexpected opcode");
    }
    // The fragmented message is complete: stop tracking it.
    if (fragmented_message_ && fragmented_message_.get().is_finished)
    {
        fragmented_message_.reset();
    }
}

// Logs an incoming frame and hands it to perform_message().
void websocket_stream::on_receive(message const& msg)
{
    // A frame is part of a fragmented message if it is not final itself or a
    // fragmented message is already being accumulated.
    const bool is_fragment = !msg.is_finished() || fragmented_message_;
    YLOG_CTX_LOCAL(session_->get_context(), info)
        << "new ws message: opcode=" << static_cast<int>(msg.opcode)
        << ", bits=" << static_cast<int>(msg.bits) << ", length=" << msg.length
        << ", is_fragment=" << (is_fragment ? "true" : "false");

    perform_message(msg);
}

void websocket_stream::close_connection_impl(uint16_t code, const string& reason)
{
    if (!is_open()) return;
    close_frame_sent_ = true;

    if (request_->upgrade_to >= upgrade_to_websocket07 &&
        request_->upgrade_to <= upgrade_to_websocket13)
    {
        std::ostringstream stream;
        stream << reinterpret_cast<char*>(&code)[1] << reinterpret_cast<char*>(&code)[0] << reason;
        promise_void_t prom;
        future_void_t future = prom;
        future.add_callback([this, self = shared_from_this()]() { session_->close(); });
        send_message_strand(
            stream.str(),
            message::opcode_close,
            message::bit_fin,
            generate_mask(),
            session_->session_settings().enable_websocket_masking,
            prom);
    }
    else
    {
        session_->do_shutdown(boost::bind(do_nothing, shared_from_this()), true);
    }
}

void websocket_stream::handle_runtime_error(const string& reason)
{
    YLOG_CTX_LOCAL(session_->get_context(), info) << "runtime error=\"" << reason << "\"";
    session_->close();
    maybe_notify_closed(codes::close_opcode_proto_error, "protocol runtime error");
}

// Reacts to a close frame from the peer: replies with a normal-closure frame
// unless we already sent one, then notifies the close hook.
void websocket_stream::handle_close_by_remote_peer()
{
    const bool must_reply = session_->is_open() && !close_frame_sent_;
    if (must_reply) close_connection_impl(codes::close_opcode_normal, "");
    maybe_notify_closed(codes::close_opcode_go_away, "closed by peer");
}

// Close requested by the application: records `code`, sends a close frame
// if one is still due, and drops the close hook without invoking it.
void websocket_stream::handle_close_by_application(uint16_t code, const string& reason)
{
    result_code_ = static_cast<codes::code>(code);
    const bool needs_close_frame = session_->is_open() && !close_frame_sent_;
    if (needs_close_frame) close_connection_impl(code, reason);
    // The application initiated the close, so it is not notified.
    close_hook_ = close_hook_t();
}

// Handles a failed read/write: closes the session and notifies the close hook.
// If connection_is_closed(e) holds, the hook gets "go away" and nothing is
// logged. Otherwise it gets a protocol error and the error is logged.
void websocket_stream::handle_io_error(boost::system::error_code const& e, const char* operation)
{
    session_->close();

    if (connection_is_closed(e))
    {
        maybe_notify_closed(codes::close_opcode_go_away, "connection closed");
        return;
    }

    maybe_notify_closed(codes::close_opcode_proto_error, "protocol error");
    YLOG_CTX_LOCAL(session_->get_context(), info)
        << "websocket_stream " << operation << " error=\"" << e.message() << "\"";
}

}}
