#include "starter.h"

#include <boost/bind/protect.hpp>
#include "validator.h"
#include "parser/body.h"
#include "websocket_stream.h"
#include <yplatform/encoding/url_encode.h>
#include <yplatform/net/handlers/timer_handler.h>
#include <yplatform/context_repository.h>
#include "net_server.h"
#include "http_stream.h"

#include <array>

// CATCH_BLOCK(func, parse_msg, http_msg, def_msg)
// A sequence of catch clauses meant to follow a `try { ... }` inside a starter
// member function. Every handler builds an error response, logs
// "<func>: failed to <msg>, connection broken. reason: ...", marks
// process_result_ as finished and returns from the enclosing function:
//   ymod_webserver::parse_error -> codes::bad_request, logs parse_msg
//   ymod_webserver::http_error  -> response filled by process_exception(), logs http_msg
//   std::exception              -> codes::internal_server_error, logs def_msg
//   anything else               -> codes::internal_server_error, logs def_msg
// Requires TRW_LOG, make_response(), process_exception() and process_result_ to be
// visible where the macro is expanded.
#define CATCH_BLOCK(func, parse_msg, http_msg, def_msg)                                            \
    catch (const ymod_webserver::parse_error& e)                                                   \
    {                                                                                              \
        make_response()->result(codes::bad_request);                                               \
        TRW_LOG(error) << func << ": failed to " << parse_msg                                      \
                       << ", connection broken. reason: " << e.private_message();                  \
        process_result_.state = process_result::finished;                                          \
        return;                                                                                    \
    }                                                                                              \
    catch (const ymod_webserver::http_error& e)                                                    \
    {                                                                                              \
        process_exception(e, make_response());                                                     \
        TRW_LOG(error) << func << ": failed to " << http_msg                                       \
                       << ", connection broken. reason: " << e.private_message();                  \
        process_result_.state = process_result::finished;                                          \
        return;                                                                                    \
    }                                                                                              \
    catch (const std::exception& e)                                                                \
    {                                                                                              \
        make_response()->result(codes::internal_server_error);                                     \
        TRW_LOG(error) << func << ": failed to " << def_msg                                        \
                       << ", connection broken. reason: " << e.what();                             \
        process_result_.state = process_result::finished;                                          \
        return;                                                                                    \
    }                                                                                              \
    catch (...)                                                                                    \
    {                                                                                              \
        make_response()->result(codes::internal_server_error);                                     \
        TRW_LOG(error) << func << ": failed to " << def_msg                                        \
                       << ", connection broken. reason: unknown";                                  \
        process_result_.state = process_result::finished;                                          \
        return;                                                                                    \
    }

// TRW_LOG(severity): log stream bound to the session's context (YLOG_CTX_LOCAL).
// Every message is prefixed with "starter-session-id=<address of the session object>"
// so lines from one connection can be correlated. Usable only where session_ is in scope.
#define TRW_LOG(severity)                                                                          \
    YLOG_CTX_LOCAL(session_->ctx(), severity)                                                      \
        << "starter-session-id=" << boost::lexical_cast<string>(session_.get()) << " "

namespace ymod_webserver {

namespace {
// Completion handler that ignores its invocation. It is bound together with
// shared_from_this() (abort(), async_close() calls), so the bound copy of the
// pointer holds a reference to the starter for as long as the handler exists.
inline void do_nothing(boost::shared_ptr<starter> /*starter*/) {}

// Returns true when a request using `method` may carry a body that the starter
// should read before dispatching (DELETE, POST and PUT in the table below).
// Methods whose numeric value falls outside the table are treated as body-less
// instead of reading past the end of the array.
bool has_body(methods::http_method method)
{
    // Indexed by the numeric value of methods::http_method. The entry comments
    // assume the enum is declared in exactly this order -- TODO confirm against
    // the enum definition whenever it changes.
    static const std::array<bool, 18> body_allowed{ {
        false, // mth_invalid
        false, // mth_get
        false, // mth_head
        false, // mth_bmove
        true,  // mth_delete
        false, // mth_connect
        false, // mth_checkout
        true,  // mth_post
        false, // mth_propfind
        false, // mth_proppath
        true,  // mth_put
        false, // mth_trace
        false, // mth_options
        false, // mth_report
        false, // mth_move
        false, // mth_mkactivity
        false, // mth_mkcol
        false  // mth_merge
    } };
    const std::size_t index = static_cast<std::size_t>(method);
    // std::array::operator[] does no bounds checking; an enum value added after
    // this table was written (or a negative one) would otherwise be UB.
    if (index >= body_allowed.size()) return false;
    return body_allowed[index];
}

}

starter::starter(
    boost::asio::io_service& io,
    boost::weak_ptr<net_server> owner,
    net_session_ptr session,
    const settings& settings)
    : owner_(owner)
    , session_(session)
    , settings_(settings)
    , io_(&io)
    , readq_(boost::make_shared<buffer_t>())
    , ctx_(boost::make_shared<context>())
    , post_mode_(request_body::read_mode_parsed_buffer)
{
    // Copy the connection endpoints from the session context into ours.
    init_context();

    // Register our context as nested in the session's context (session_ holds the
    // same pointer as the `session` argument).
    session_->ctx()->set_nested(ctx());

    // Opened here, closed in ~starter().
    ctx()->profilers.push("preprocess");
}

starter::~starter()
{
    // Close the profiler scope opened by the constructor.
    ctx()->profilers.pop("preprocess");

    // Hand whatever is left in the read buffer over through the process result.
    process_result_.read_buffer = release_read_buffer();

    net_server_ptr powner = owner_.lock();
    if (!powner) return;

    // Report the outcome to the owning server; a destructor must not let exceptions escape.
    try
    {
        powner->on_destroy(this, process_result_);
    }
    catch (...)
    {
    }
}

void starter::run()
{
    begin_new_request(boost::system::error_code());
}

void starter::abort()
{
    session_->cancel_operations(boost::bind(do_nothing, shared_from_this()));
}

void starter::init_context()
{
    ctx_->remote_address = session_->ctx()->remote_address;
    ctx_->remote_port = session_->ctx()->remote_port;
    ctx_->local_address = session_->ctx()->local_address;
    ctx_->local_port = session_->ctx()->local_port;
}

void starter::begin_new_request(const boost::system::error_code& e)
{
    // Nothing to do if the triggering step reported an error.
    if (e) return;

    ctx()->state = context_state_t::read_headers;
    clear_data();

    // Parse bytes already sitting in readq_ before asking the socket for more.
    if (readq_->size() == 0)
    {
        begin_read_headers();
        return;
    }
    parse_headers();
}

void starter::clear_data()
{
    // Fresh header parser bound to our context, no body parser yet,
    // and the resume position back at the start of the read queue.
    parser_.reset(ctx());
    transfer_parser_.reset();
    i_saved_ = readq_->begin();
}

bool starter::check_request_line_or_headers_size()
{
    size_t request_line_size = parser_.req()->raw_request_line.size();
    if (!parser_.request_line_finished())
    {
        if (request_line_size + readq_->size() > settings_.max_request_line_size)
        {
            TRW_LOG(info) << "too large request line";
            make_response()->result(codes::request_uri_too_long);
            process_result_.state = process_result::finished;
            return false;
        }
    }
    else
    {
        size_t total_size = parser_.req()->ctx()->bytes_received - request_line_size;
        if (total_size > settings_.max_headers_size)
        {
            TRW_LOG(info) << "too large headers";
            make_response()->result(codes::request_entity_too_large);
            process_result_.state = process_result::finished;
            return false;
        }
    }
    return true;
}

void starter::begin_read_headers()
{
    // Asynchronously read at least one byte into up to read_chunk_size bytes of
    // fresh space in readq_. The completion goes to handle_read_headers(), and the
    // bound shared pointer keeps this starter alive meanwhile.
    session_->begin_read(
        boost::bind(&starter::handle_read_headers, shared_from_this(), _1, _2),
        readq_->prepare(settings_.read_chunk_size),
        1);
}

void starter::handle_read_headers(const boost::system::error_code& e, std::size_t bytes)
{
    // The byte count is accounted for even when the read failed.
    ctx_->bytes_received += bytes;

    if (!e)
    {
        // An empty successful read simply re-arms the header read.
        if (bytes == 0)
        {
            begin_read_headers();
            return;
        }
        if (process_result_.state == process_result::inited)
            process_result_.state = process_result::processing;
        readq_->commit(bytes);
        parse_headers();
        return;
    }

    const bool closed = connection_is_closed(e);

    // Connection closed before any request data was processed: finish quietly.
    if (closed && process_result_.state == process_result::inited)
    {
        process_result_.state = process_result::finished;
        return;
    }

    // Other errors on a still-open session: close it asynchronously.
    if (!closed && session_->is_open())
    {
        session_->async_close(boost::bind(do_nothing, shared_from_this()));
    }

    TRW_LOG(info) << "handle_read_headers error message=\"" << e.message() << "\"";
    process_result_.state = process_result::failed;
}

void starter::parse_headers()
{
    try
    {
        read_iterator start = readq_->begin();
        read_iterator end = readq_->end();
        readq_->detach(parser_(start, i_saved_, end));
        if (!parser_.is_finished())
        {
            if (!check_request_line_or_headers_size()) return;
            begin_read_headers();
            return;
        }
        if (parser_.is_policy_request())
        {
            begin_write_policy();
            process_result_.state = process_result::finished;
            return;
        }
        process_result_.request = parser_.req();
        validator headers_validator;
        bool expect = headers_validator(session_->endpoint(), process_result_.request);
        if (process_result_.request->vhost.port == 0)
            process_result_.request->vhost.port = session_->local_port();

        ctx()->state = context_state_t::execute;

        return execute(expect);
    }
    CATCH_BLOCK("parse_headers", "parse request", "execute request", "load headers")
}

void starter::begin_write_policy()
{
    session_->send_client_stream(
        yplatform::net::buffers::make_const_cstdcont_chunk(settings_.policy_file));
}

// Dispatches a request whose headers have been parsed and validated.
//  - Upgrade requests are handled entirely by upgrade().
//  - Methods for which has_body() is true either hand off to the HTTP layer
//    (streamed body, or empty identity-encoded body) or start buffering the body
//    through parse_body_mpost(). A buffered body above max_post_size gets 413.
//  - GET hands off to the HTTP layer (state continue_http).
//  - OPTIONS gets a CORS reply when access_control_allow_origin is configured.
//  - Any other method gets 501 Not Implemented.
// expect_100_continue: the value returned by the header validator in parse_headers();
// presumably true when the client sent "Expect: 100-continue" -- confirm in validator.
// Exceptions are logged and turned into an error response. Whenever this function
// builds a response itself, process_result_ ends up finished.
void starter::execute(bool expect_100_continue)
{
    request_ptr req = process_result_.request;
    req->context = ctx(); // TODO XXX redundant?
    ctx()->set_request(req);
    // Non-null once this function has produced the reply itself.
    response_ptr resp;
    try
    {
        if (upgrade()) return;
        if (has_body(req->method))
        {
            post_mode_ = session_->handler()->parse_post_mode(req);
            const bool buffered_body = post_mode_ != request_body::read_mode_stream;

            // An oversized buffered body can be rejected from the headers alone. Do it
            // before sending "100 Continue", so the client is not invited to upload a
            // body that is going to be refused (RFC 7231, section 5.1.1).
            if (buffered_body && req->content.length > settings_.max_post_size)
            {
                resp = make_response();
                resp->result(codes::request_entity_too_large);
                process_result_.state = process_result::finished;
                return;
            }

            if (expect_100_continue)
            {
                session_->client_stream() << "HTTP/1.1 100 Continue\r\n\r\n";
            }

            if (!buffered_body)
            {
                // TODO implement the receiver
                //        req->post_receiver = shared_from_this();
                process_result_.state = process_result::continue_http;
                return;
            }

            // [rfc] if transfer encoding is defined and != identity
            // message length is retrieved via chunked transfer-coding,
            // content-length header isn't used
            if (req->content.length == 0 &&
                req->transfer_encoding == transfer_encoding_identity)
            {
                process_result_.state = process_result::continue_http;
                return;
            }

            transfer_parser_.reset(new transfer_parser_t(req, readq_));
            start_read_body(&starter::parse_body_mpost, req->content.length);
        }
        else if (req->method == methods::mth_get)
        {
            process_result_.state = process_result::continue_http;
            return;
        }
        else if (
            req->method == methods::mth_options && settings_.access_control_allow_origin.size())
        {
            resp = make_response();
            resp->set_code(codes::ok);
            resp->add_header("Access-Control-Allow-Origin", settings_.access_control_allow_origin);
            resp->add_header(
                "Access-Control-Allow-Headers", settings_.access_control_allow_headers);
            resp->result_body("");
        }
        else
        {
            resp = make_response();
            resp->result(codes::not_implemented);
        }

        // TODO remove it
    }
    catch (const http_error& err)
    {
        TRW_LOG(error) << "execute http exception message=\"" << err.private_message() << "\"";
        if (!resp) resp = make_response();
        process_exception(err, resp);
    }
    catch (const std::exception& e)
    {
        TRW_LOG(error) << "execute exception message=\"" << e.what() << "\"";
        // Keep the response so the request is marked finished below. Report 500,
        // as CATCH_BLOCK does for the same kind of failure.
        if (!resp)
        {
            resp = make_response();
            resp->result(codes::internal_server_error);
        }
    }
    catch (...)
    {
        TRW_LOG(error) << "execute exception message=none";
        if (!resp)
        {
            resp = make_response();
            resp->result(codes::internal_server_error);
        }
    }

    if (resp)
    {
        process_result_.state = process_result::finished;
    }
}

// Handles a protocol-upgrade request.
// Returns false when no upgrade was requested. Otherwise it returns true after doing one of:
// starting to read the websocket76 body, handing off to the websocket layer
// (75/07/08/13), or replying 400 with the advertised Sec-WebSocket-Version list.
bool starter::upgrade()
{
    request_ptr req = process_result_.request;
    if (req->upgrade_to == upgrade_none) return false;

    if (req->upgrade_to == upgrade_to_websocket76)
    {
        // This handshake needs 8 bytes of body after the headers.
        start_read_body(&starter::parse_body_ws76_upgrade, 8);
        return true;
    }

    const bool handled_by_websocket_layer = req->upgrade_to == upgrade_to_websocket75 ||
        req->upgrade_to == upgrade_to_websocket07 || req->upgrade_to == upgrade_to_websocket08 ||
        req->upgrade_to == upgrade_to_websocket13;
    if (handled_by_websocket_layer)
    {
        process_result_.state = process_result::continue_websocket;
        return true;
    }

    // Any other upgrade target is refused.
    response_ptr resp = make_response();
    resp->set_code(ymod_webserver::codes::bad_request);
    resp->add_header("Sec-WebSocket-Version", "13, 8, 7");
    process_result_.state = process_result::finished;
    return true;
}

// Creates and initializes an http_stream that will carry the reply, handing it the
// remaining contents of the read buffer.
response_ptr starter::make_response()
{
    // Prefer the fully parsed request. Fall back to the parser's in-progress one so a
    // reply can still be produced when parsing failed midway.
    request_ptr req = process_result_.request;
    if (!req) req = parser_.req();

    // The protocol version may be unknown yet (e.g. the request line exceeded its
    // size limit); use the configured default HTTP version then.
    const bool version_unknown = req->proto_version.first == 0;
    if (version_unknown)
    {
        req->proto_version.first = settings_.http_default_version.major;
        req->proto_version.second = settings_.http_default_version.minor;
    }

    boost::shared_ptr<http_stream> stream = boost::make_shared<http_stream>(
        *io_, owner_, ctx(), session_, req, release_read_buffer());
    stream->init();
    return stream;
}

void starter::handle_read_body(
    const boost::system::error_code& e,
    parse_body_function parse_body,
    std::size_t bytes)
{
    // Counted even when the read failed.
    ctx_->bytes_received += bytes;

    if (!e)
    {
        if (bytes == 0)
        {
            // Nothing arrived: re-arm the read with the same body parser.
            begin_read_body(parse_body);
        }
        else
        {
            readq_->commit(bytes);
            (this->*parse_body)(bytes);
        }
        return;
    }

    // Read failed: close the session unless the connection is already closed.
    if (!connection_is_closed(e) && session_->is_open())
        session_->async_close(boost::bind(do_nothing, shared_from_this()));

    TRW_LOG(info) << "handle_read_body error message=\"" << e.message() << "\"";
    process_result_.state = process_result::failed;
}

void starter::parse_body_mpost(std::size_t bytes)
{
    try
    {
        read_iterator start = readq_->begin();
        read_iterator end = readq_->end();
        request_ptr req = process_result_.request;
        readq_->detach((*transfer_parser_)(start, i_saved_, end, bytes));
        if (!transfer_parser_->is_finished())
        {
            begin_read_body(&starter::parse_body_mpost);
            return;
        }
        req->raw_body = transfer_parser_->result(i_saved_);
        if (post_mode_ == request_body::read_mode_parsed_buffer) parser::parse_body(req);
        process_result_.state = process_result::continue_http;
        return;
    }
    CATCH_BLOCK("parse_body", "parse request body", "execute request", "load body")
}

void starter::parse_body_ws76_upgrade(std::size_t /*bytes*/)
{
    try
    {
        request_ptr req = process_result_.request;
        if (req->upgrade_to == upgrade_to_websocket76)
        {
            i_saved_ = readq_->end();
            req->raw_body = readq_->detach(i_saved_);
            process_result_.state = process_result::continue_websocket;
            return;
        }
        else
        {
            response_ptr resp = make_response();
            resp->set_code(ymod_webserver::codes::bad_request);
            resp->add_header("Sec-WebSocket-Version", "13, 8, 7");
            process_result_.state = process_result::finished;
        }
    }
    CATCH_BLOCK(
        "parse_body_ws76_upgrade", "parse ws76 upgrade body", "execute request", "load body")
}

// Feeds already-buffered body bytes to `parse_body`, or starts a socket read when
// nothing is buffered yet.
// min: desired number of bytes; 0 means "whatever is available" (at least one
// byte when a read has to be started). When more than `min` bytes are buffered,
// only `min` is reported to the parser.
void starter::start_read_body(parse_body_function parse_body, std::size_t min)
{
    const std::size_t available = readq_->size();
    if (available == 0)
    {
        begin_read_body(parse_body, min == 0 ? 1 : min);
        return;
    }

    std::size_t to_parse = available;
    if (min != 0 && min < available) to_parse = min;
    (this->*parse_body)(to_parse);
}

// Starts an asynchronous read of at least `min` body bytes. The completion is routed
// through handle_read_body() to `parse_body`.
// Throws http_error(request_entity_too_large) when readq_ already holds more than
// max_post_size bytes.
void starter::begin_read_body(parse_body_function parse_body, std::size_t min)
{
    const bool over_limit = readq_->size() > settings_.max_post_size;
    if (over_limit) throw http_error() << http_result_code(codes::request_entity_too_large);

    // boost::protect keeps the inner bind from being evaluated as a nested bind expression.
    session_->begin_read(
        boost::protect(boost::bind(
            &starter::handle_read_body,
            shared_from_this(),
            boost::asio::placeholders::error,
            parse_body,
            boost::asio::placeholders::bytes_transferred)),
        readq_->prepare(settings_.read_chunk_size),
        min);
}

}
