#pragma once

#include <map>
#include <yplatform/task_context.h>
#include <ymod_webserver/error.h>
#include <ymod_webserver/websocket.h>
#include <ymod_webserver/header.h>

using ymod_webserver::websocket::message;

namespace ymod_webserver { namespace parser {

// Log stream for the websocket parser: logs through the parser's task context
// (ctx_) and prefixes each record with log_prefix_ followed by a space.
// Only usable where ctx_ and log_prefix_ are in scope, i.e. inside websocket<>
// member functions. NOTE(review): not used in the visible part of this header.
#define WS_PARSER_LOG(severity) YLOG_CTX_LOCAL(ctx_, severity) << log_prefix_ << " "

template <typename Iterator>
class websocket
{
    typedef Iterator (websocket::*state_handler_t)(Iterator begin, Iterator& start, Iterator end);

public:
    typedef std::pair<Iterator, Iterator> message_t;

    websocket(
        yplatform::task_context_ptr ctx,
        string const& log_prefix,
        upgrade_proto_header upgrade_to,
        size_t max_length)
        : ctx_(ctx)
        , log_prefix_("websocket-session-id=" + log_prefix)
        , handler_(0)
        , binary_length_(0)
        , finished_(false)
        , max_message_length_(max_length)
    {
        if (upgrade_to == upgrade_to_websocket75 || upgrade_to == upgrade_to_websocket76)
        {
            parser_ = &websocket::parse_websocket_1;
        }
        else
        {
            parser_ = &websocket::parse_websocket_2;
        }
        msg_.mask.fill(0);
    }

    websocket(upgrade_proto_header upgrade_to = upgrade_to_websocket75)
        : handler_(0)
        , binary_length_(0)
        , finished_(false)
        , max_message_length_(std::numeric_limits<size_t>::max())
    {
        if (upgrade_to == upgrade_to_websocket75 || upgrade_to == upgrade_to_websocket76)
        {
            parser_ = &websocket::parse_websocket_1;
        }
        else
        {
            parser_ = &websocket::parse_websocket_2;
        }
        msg_.mask.fill(0);
    }

    Iterator operator()(Iterator begin, Iterator& start, Iterator end)
    {
        return (this->*(parser_))(begin, start, end);
    }

    void reset()
    {
        handler_ = 0;
        binary_length_ = 0;
        finished_ = false;
        msg_.mask.fill(0);
    }

    bool is_finished() const
    {
        return finished_;
    }

    const message_t& msg() const
    {
        return message_;
    }

    const ymod_webserver::websocket::message& headers() const
    {
        return msg_;
    }

protected:
    void detect_transfer_mode(uint8_t c)
    {
        if (c == 0) handler_ = &websocket::parse_text;
        else if (c == 0x80)
            handler_ = &websocket::parse_binary_size;
    }

    Iterator parse_websocket_1(Iterator begin, Iterator& start, Iterator end)
    {
        if (!handler_)
        {
            if (begin == end || start == end) return begin;
            detect_transfer_mode(static_cast<uint8_t>(*start));
            if (!handler_)
                throw parse_error("websocket_parser", "invalid message start") << BOOST_ERROR_INFO;
            begin = ++start;
        }
        return parse_websocket_2(begin, start, end);
    }

    Iterator parse_websocket_2(Iterator begin, Iterator& start, Iterator end)
    {
        if (handler_ == 0) handler_ = &websocket::parse_opcode;
        state_handler_t prev_handler;
        if (begin == end || start == end) return begin;
        auto parse_begin = begin;
        while (true)
        {
            prev_handler = handler_;
            parse_begin = (this->*(handler_))(parse_begin, start, end);
            if (prev_handler == handler_ || is_finished() || parse_begin == end) break;
        }
        return parse_begin;
    }

    Iterator parse_text(Iterator begin, Iterator& start, Iterator end)
    {
        while (start != end && *start != static_cast<char>(0xFF))
            ++start;
        if (start == end) return begin;
        message_ = std::make_pair(begin, start++);
        finished_ = true;
        return begin;
    }

    Iterator parse_binary_size(Iterator begin, Iterator& start, Iterator end)
    {
        while (start != end)
        {
            uint8_t v = static_cast<uint8_t>(*start);
            binary_length_ = 128 * binary_length_ + (v & 0x7F);
            if ((v & 0x80) == 0) break; // last size byte
            ++start;
        }
        if (start == end) return start;
        if (binary_length_ > max_message_length_)
            throw parse_error("websocket_parser", "message is too large") << BOOST_ERROR_INFO;

        handler_ = &websocket::parse_binary;
        begin = ++start;
        return parse_binary(begin, start, end);
    }

    Iterator parse_binary(Iterator begin, Iterator& start, Iterator end)
    {
        for (; start != end && binary_length_ > 0; --binary_length_, ++start)
            ;
        if (start == end && binary_length_ > 0) return begin;
        message_ = std::make_pair(begin, start);
        finished_ = true;
        return begin;
    }

    Iterator parse_opcode(
        Iterator begin __attribute__((unused)),
        Iterator& start,
        Iterator end __attribute__((unused)))
    {
        msg_.bits = 0;
        if (*start & message::bit_fin_val) msg_.bits |= message::bit_fin_val;
        if (*start & message::bit_rsv1_val) msg_.bits |= message::bit_rsv1_val;
        if (*start & message::bit_rsv2_val) msg_.bits |= message::bit_rsv2_val;
        if (*start & message::bit_rsv3_val) msg_.bits |= message::bit_rsv3_val;
        msg_.opcode = static_cast<uint8_t>(*start & 0x0F);
        handler_ = &websocket::parse_length;
        return ++start;
    }

    Iterator parse_length(
        Iterator begin __attribute__((unused)),
        Iterator& start,
        Iterator end __attribute__((unused)))
    {
        if (*start & message::bit_mask_val) msg_.bits |= message::bit_mask;
        msg_.length = static_cast<uint8_t>(*start & 0x7F);
        handler_ = &websocket::parse_length_ext;
        if (msg_.length == 126)
        {
            binary_length_ = 2;
            msg_.length = 0;
        }
        else if (msg_.length == 127)
        {
            binary_length_ = 8;
            msg_.length = 0;
        }
        else if (msg_.bits & ymod_webserver::websocket::message::bit_mask)
        {
            handler_ = &websocket::parse_mask;
            binary_length_ = 4;
        }
        else
        {
            handler_ = &websocket::parse_binary;
            binary_length_ = msg_.length;
        }
        return ++start;
    }

    Iterator parse_length_ext(Iterator begin __attribute__((unused)), Iterator& start, Iterator end)
    {
        for (; start != end && binary_length_ > 0; --binary_length_, ++start)
        {
            msg_.length = (msg_.length << 8) + static_cast<uint8_t>(0xFF & *start);
        }
        if (msg_.length > max_message_length_)
            throw parse_error("websocket_parser", "message is too large") << BOOST_ERROR_INFO;

        if (binary_length_ == 0)
        {
            if (msg_.bits & ymod_webserver::websocket::message::bit_mask)
            {
                handler_ = &websocket::parse_mask;
                binary_length_ = 4;
            }
            else
            {
                handler_ = &websocket::parse_binary;
                binary_length_ = msg_.length;
            }
        }
        return start;
    }

    Iterator parse_mask(Iterator begin __attribute__((unused)), Iterator& start, Iterator end)
    {
        assert(binary_length_ <= sizeof(msg_.mask));
        for (; start != end && binary_length_ > 0; --binary_length_, ++start)
        {
            msg_.mask[sizeof(msg_.mask) - binary_length_] = static_cast<uint8_t>(*start);
        }
        if (binary_length_ == 0)
        {
            handler_ = &websocket::parse_binary;
            binary_length_ = msg_.length;
        }
        return start;
    }

private:
    yplatform::task_context_ptr ctx_;
    string log_prefix_;
    state_handler_t parser_;
    state_handler_t handler_;
    std::size_t binary_length_;
    message_t message_;
    ymod_webserver::websocket::message msg_;
    bool finished_;
    size_t max_message_length_;
};

/// Writes the legacy (draft-75/76) binary frame prefix for a payload of `size`
/// bytes: the 0x80 frame-type byte, then `size` in base-128 with the most
/// significant 7-bit group first and the 0x80 continuation bit set on every
/// length byte except the last. This is the encoding that
/// websocket<>::parse_binary_size() decodes.
/// @param stream any stream supporting `stream << char`.
/// @param size   payload length to encode; any std::size_t value is supported.
template <typename Stream>
void make_websocket_binary_size(Stream& stream, std::size_t size)
{
    stream << '\x80';
    // Enough 7-bit groups for any std::size_t value; filled least significant first.
    char groups[(sizeof(std::size_t) * 8 + 6) / 7];
    std::size_t count = 0;
    do
    {
        groups[count++] = static_cast<char>(size & 0x7F);
        size >>= 7;
    } while (size != 0);
    // Emit most significant group first; all but the final byte carry the continuation bit.
    while (count > 1)
    {
        stream << static_cast<char>(groups[--count] | 0x80);
    }
    stream << groups[0];
}

}}
