#pragma once

#include <yplatform/log.h>
#include <yplatform/net/types.h>
#include <yplatform/net/io_data.h>
#include <yplatform/net/keep_alive.h>
#include <yplatform/net/settings.h>
#include <yplatform/net/stream/optional_ssl_socket.h>
#include <yplatform/net/handlers/mem_alloc.h>
#include <yplatform/net/handlers/protect_wrapper.h>
#include <yplatform/time_traits.h>
#include <boost/asio.hpp>

// Names the wrapper type produced by protect_handler() for a given Handler.
// It is used as the explicit template argument of the handle_* callbacks that
// receive the bound, protected handler. Undefined again at the end of this header.
#define PROTECTED_TYPE(Handler) detail::protect_wrapper<Handler>

namespace yplatform { namespace net {

// Short alias for the std::bind placeholders (_1, _2) used when binding completion handlers below.
namespace ph = std::placeholders;

// Bookkeeping for one direction (read or write) of a socket. It holds:
//  - an "operation in flight" flag (at most one operation at a time, asserted);
//  - a deadline timer that calls socket.cancel_operations() when it expires;
//  - a fixed-size (1024) handler allocator for that direction's handlers.
struct operation_context
{
    operation_context(boost::asio::io_service& io) : timer(io)
    {
    }

    // Marks an operation as started and arms the timer for an absolute deadline.
    // If the timer expires before complete() cancels it, the socket's pending
    // operations are cancelled. The socket is captured by reference, so it must
    // outlive the timer wait.
    template <typename Socket>
    void prepare(Socket& socket, const time_traits::time_point& deadline)
    {
        assert(!active);
        active = true;
        timer.expires_at(deadline);

        auto on_expired = [&socket](const boost::system::error_code& wait_error) {
            // Any error (e.g. the cancellation issued by complete()) means the
            // deadline was not reached, so the socket is left alone.
            if (wait_error) return;
            socket.cancel_operations();
        };
        timer.async_wait(on_expired);
    }

    // Relative-timeout variant. duration::max() means "no timeout": the
    // operation is flagged as in flight but no timer is armed.
    template <typename Socket>
    void prepare(Socket& socket, const time_traits::duration& timeout)
    {
        const bool unlimited = (timeout == yplatform::time_traits::duration::max());
        if (!unlimited)
        {
            prepare(socket, time_traits::clock::now() + timeout);
            return;
        }
        assert(!active);
        active = true;
    }

    // Ends the in-flight operation: disarms the timer and clears the flag.
    // TODO post (left by the original author; exact intent not recorded here)
    void complete()
    {
        assert(active);
        timer.cancel();
        active = false;
    }

    time_traits::timer timer;
    bool active = false;
    handler_allocator<1024> allocator;
};

// Implements basic operations on a raw/SSL socket with timeouts
// and custom memory allocators.
//
// Every asynchronous operation takes a time limit: either an absolute
// time_traits::time_point or a relative time_traits::duration, where
// duration::max() means "no timeout". When the limit expires,
// cancel_operations() cancels the pending operations on the stream.
//
// Reads are tracked by read_context_. Writes, connect and TLS handshake share
// write_context_, so only one of those may be in flight at a time
// (asserted by operation_context). Completion handlers are never invoked
// inline; they are always posted to the io_service. Pending handlers and
// timers refer to this object, so it must outlive every pending operation.
template <typename Socket = boost::asio::ip::tcp::socket>
class universal_socket
{
public:
    typedef settings settings_t;
    typedef Socket raw_socket_t;
    typedef typename raw_socket_t::endpoint_type endpoint_t;
    typedef stream::optional_ssl_socket<raw_socket_t> stream_t;

    typedef settings settings_type;
    typedef Socket raw_socket_type;
    typedef typename raw_socket_t::endpoint_type endpoint_type;
    typedef typename boost::asio::ip::address address_type;
    typedef stream::optional_ssl_socket<raw_socket_t> stream_type;
    typedef boost::asio::ssl::stream_base::handshake_type handshake_type;

    // Creates a socket bound to io_data's io_service. io_data also provides
    // the SSL context and SSL settings used by the TLS operations.
    // TODO extern set options
    universal_socket(io_data& io_data, const settings_t& settings = settings_t())
        : io_data_(&io_data)
        , io_(*io_data.get_io())
        , settings_(settings)
        , raw_socket_(*io_data.get_io())
        , stream_(raw_socket_)
        , remote_port_(0)
        , local_port_(0)
        , read_context_(*io_data.get_io())
        , write_context_(*io_data.get_io())
    {
    }

    // Creates a socket without io_data. TLS-related calls will throw
    // std::runtime_error from get_io_data().
    universal_socket(boost::asio::io_service& io, const settings_t& settings = settings_t())
        : io_data_(nullptr)
        , io_(io)
        , settings_(settings)
        , raw_socket_(io_)
        , stream_(raw_socket_)
        , remote_port_(0)
        , local_port_(0)
        , read_context_(io_)
        , write_context_(io_)
    {
    }

    universal_socket(const universal_socket&) = delete;
    universal_socket& operator=(const universal_socket&) = delete;

    // Only the raw socket and the cached connection data are moved: neither
    // universal_socket nor the asio socket is designed to be moved while
    // asynchronous operations are active. The operation contexts are created
    // fresh. The stream is not moved because it refers to this object's own
    // raw_socket_, so the source must not have an SSL layer yet (asserted).
    universal_socket(universal_socket&& other)
        : io_data_(other.io_data_)
        , io_(other.io_)
        , settings_(std::move(other.settings_))
        , raw_socket_(std::move(other.raw_socket_))
        , stream_(raw_socket_)
        , remote_addr_(std::move(other.remote_addr_))
        , remote_port_(other.remote_port_)
        , local_addr_(std::move(other.local_addr_))
        , local_port_(other.local_port_)
        , read_context_(io_)
        , write_context_(io_)
    {
        // Check the source: our own stream_ was just built and never has an SSL
        // layer, while the source's SSL state would be silently dropped here.
        // Active operations are not checked.
        assert(!other.stream_.get_ssl_stream());
    }

    // Same restrictions as the move constructor. Both sockets must use the
    // same io_service.
    universal_socket& operator=(universal_socket&& other)
    {
        if (this == &other) return *this;

        assert(&io_ == &other.io_);
        assert(!other.stream_.get_ssl_stream());

        // io_data_ carries the SSL context/settings for the moved settings_,
        // so it is transferred too, matching the move constructor.
        io_data_ = other.io_data_;
        settings_ = std::move(other.settings_);
        raw_socket_ = std::move(other.raw_socket_);
        remote_addr_ = std::move(other.remote_addr_);
        local_addr_ = std::move(other.local_addr_);
        remote_port_ = other.remote_port_;
        local_port_ = other.local_port_;
        return *this;
    }

    bool is_open() const
    {
        return raw_socket().is_open();
    }

    // Number of bytes readable without blocking; errors are reported via ec.
    std::size_t available(boost::system::error_code& ec) const
    {
        return raw_socket().available(ec);
    }

    // Synchronous connect on the raw socket. Unlike async_connect, this does
    // not call open(), so endpoints and socket options are not set up here.
    void connect(const endpoint_t& ep)
    {
        raw_socket().connect(ep);
    }

    // Closes the stream if it is open. Errors are deliberately ignored.
    void close()
    {
        boost::system::error_code ec;
        if (stream().is_open())
        {
            stream().close(ec);
        }
    }

    // Cancels pending asynchronous operations on the stream, if it is open.
    // Called by the operation timers when a time limit expires.
    void cancel_operations()
    {
        if (stream().is_open())
        {
            stream().cancel();
        }
    }

    // graceful: shuts down the sending side only.
    // otherwise: cancels pending operations and closes the stream.
    // Uses the stream overloads that take no error_code.
    void shutdown(bool graceful)
    {
        if (graceful)
        {
            stream().shutdown(raw_socket_t::shutdown_send);
        }
        else
        {
            stream().cancel();
            stream().close();
        }
    }

    // Opaque identifier of this socket object (its address).
    void const* id() const
    {
        return static_cast<void const*>(this);
    }

    stream_t& stream()
    {
        return stream_;
    }

    // Endpoint data cached by open(); defaults until open() succeeds.
    const boost::asio::ip::address& remote_addr() const
    {
        return remote_addr_;
    }

    unsigned short remote_port() const
    {
        return remote_port_;
    }

    const boost::asio::ip::address& local_addr() const
    {
        return local_addr_;
    }

    unsigned short local_port() const
    {
        return local_port_;
    }

    boost::asio::io_service* get_io()
    {
        return &io_;
    }

    const boost::asio::io_service* get_io() const
    {
        return &io_;
    }

    // Throws std::runtime_error if the socket was constructed without io_data.
    io_data& get_io_data()
    {
        if (!io_data_)
        {
            throw std::runtime_error("socket io_data is not initialized");
        }
        return *io_data_;
    }

    raw_socket_t& raw_socket()
    {
        return raw_socket_;
    }

    const raw_socket_t& raw_socket() const
    {
        return raw_socket_;
    }

    // TCP only: caches the remote/local endpoints, then applies keep-alive,
    // TCP_NODELAY and TCP_QUICKACK according to the settings. The first failing
    // step stores its error in ec and skips the remaining steps.
    // NOTE(review): setup_keep_alive is declared elsewhere; whether it can
    // throw is not visible here.
    template <typename Dummy = void>
    void open(
        error_code& ec,
        typename std::enable_if<
            std::is_same<raw_socket_t, boost::asio::ip::tcp::socket>::value,
            Dummy>::type* = 0)
    {
        ec = error_code();
        endpoint_t ep = raw_socket().remote_endpoint(ec);
        if (ec) return;
        remote_addr_ = ep.address();
        remote_port_ = ep.port();
        ep = raw_socket().local_endpoint(ec);
        if (ec) return;
        local_addr_ = ep.address();
        local_port_ = ep.port();
        setup_keep_alive(raw_socket(), get_settings());

        // Use the error_code overloads: open() runs inside the async_connect
        // completion handler, where an exception would escape io_service::run().
        using boost::asio::detail::socket_option::boolean;
        if (settings_.tcp_no_delay)
        {
            raw_socket().set_option(boolean<IPPROTO_TCP, TCP_NODELAY>(true), ec);
            if (ec) return;
        }
        if (settings_.tcp_quickack)
        {
            raw_socket().set_option(boolean<IPPROTO_TCP, TCP_QUICKACK>(true), ec);
            if (ec) return;
        }
    }

    // Non-TCP sockets: nothing to set up; ec is left untouched.
    template <typename Dummy = void>
    void open(
        error_code& /*ec*/,
        typename std::enable_if<
            !std::is_same<raw_socket_t, boost::asio::ip::tcp::socket>::value,
            Dummy>::type* = 0)
    {
    }

    // Connects within the time limit. On success, open() is run and its error
    // (if any) is what the handler receives. Otherwise the connect error is
    // passed on. Handler signature: void(const error_code&).
    template <typename Handler, typename TimeLimit>
    void async_connect(const endpoint_t& ep, const TimeLimit& limit, Handler&& handler)
    {
        write_context_.prepare(*this, limit);
        using handler_t = typename std::decay<Handler>::type;
        raw_socket().async_connect(
            ep,
            make_custom_alloc_handler(
                write_context_.allocator,
                [this, handler = handler_t(std::forward<Handler>(handler))](const error_code& ec) {
                    write_context_.complete();
                    if (!ec)
                    {
                        error_code open_ec;
                        this->open(open_ec);
                        io_.post(std::bind(std::move(handler), open_ec));
                    }
                    else
                    {
                        io_.post(std::bind(std::move(handler), ec));
                    }
                }));
    }

    void set_tls_server_name(const char* name)
    {
        stream().set_tls_server_name(*get_io_data().get_ssl_context(), name);
    }

    template <typename Handler>
    void set_tls_verify_callback(Handler&& handler)
    {
        stream().set_tls_verify_callback(
            *get_io_data().get_ssl_context(), std::forward<Handler>(handler));
    }

    // Starts TLS over the connected socket within the time limit. Pending
    // operations on the raw socket are cancelled first.
    // Handler signature: void(const error_code&).
    template <typename Handler, typename TimeLimit>
    void async_tls_handshake(
        handshake_type tls_handshake_type,
        const TimeLimit& limit,
        Handler&& handler)
    {
        write_context_.prepare(*this, limit);
        this->raw_socket().cancel();
        stream().async_start_tls(
            *get_io_data().get_ssl_context(),
            tls_handshake_type,
            make_custom_alloc_handler(
                write_context_.allocator,
                std::bind(
                    &universal_socket::handle_tls_enabled<PROTECTED_TYPE(Handler)>,
                    this,
                    protect_handler(std::forward<Handler>(handler)),
                    ph::_1)));
    }

    // Same as above, additionally forwarding verify_handler to async_start_tls.
    template <typename Handler, typename VerifyHandler, typename TimeLimit>
    void async_tls_handshake(
        handshake_type tls_handshake_type,
        const TimeLimit& limit,
        Handler&& handler,
        VerifyHandler&& verify_handler)
    {
        write_context_.prepare(*this, limit);
        this->raw_socket().cancel();
        stream().async_start_tls(
            *get_io_data().get_ssl_context(),
            tls_handshake_type,
            make_custom_alloc_handler(
                write_context_.allocator,
                std::bind(
                    &universal_socket::handle_tls_enabled<PROTECTED_TYPE(Handler)>,
                    this,
                    protect_handler(std::forward<Handler>(handler)),
                    ph::_1)),
            std::forward<VerifyHandler>(verify_handler));
    }

    // Reads at least `min` bytes into buff within the time limit.
    // Handler signature: void(const error_code&, std::size_t).
    template <typename Buffer, typename Handler, typename TimeLimit>
    void async_read(Buffer&& buff, const TimeLimit& limit, Handler&& handler, std::size_t min = 1U)
    {
        read_context_.prepare(*this, limit);
        boost::asio::async_read(
            stream(),
            buff,
            boost::asio::transfer_at_least(min),
            make_custom_alloc_handler(
                read_context_.allocator,
                std::bind(
                    &universal_socket::handle_read<PROTECTED_TYPE(Handler)>,
                    this,
                    protect_handler(std::forward<Handler>(handler)),
                    ph::_1,
                    ph::_2)));
    }

    // Reads into buff until filter matches (boost::asio::async_read_until
    // semantics), within the time limit.
    // Handler signature: void(const error_code&, std::size_t).
    template <typename Buffer, typename Handler, typename Filter, typename TimeLimit>
    void async_read_until(
        Buffer&& buff,
        const Filter& filter,
        const TimeLimit& limit,
        Handler&& handler)
    {
        read_context_.prepare(*this, limit);
        boost::asio::async_read_until(
            stream(),
            buff,
            filter,
            make_custom_alloc_handler(
                read_context_.allocator,
                std::bind(
                    &universal_socket::handle_read<PROTECTED_TYPE(Handler)>,
                    this,
                    protect_handler(std::forward<Handler>(handler)),
                    ph::_1,
                    ph::_2)));
    }

    // Writes the whole buffer sequence within the time limit.
    // Handler signature: void(const error_code&, std::size_t).
    template <typename Buffer, typename Handler, typename TimeLimit>
    void async_write(Buffer&& buff, const TimeLimit& limit, Handler&& handler)
    {
        write_context_.prepare(*this, limit);
        boost::asio::async_write(
            stream(),
            buff,
            boost::asio::transfer_all(),
            make_custom_alloc_handler(
                write_context_.allocator,
                std::bind(
                    &universal_socket::handle_write<PROTECTED_TYPE(Handler)>,
                    this,
                    protect_handler(std::forward<Handler>(handler)),
                    ph::_1,
                    ph::_2)));
    }

    const settings_t& get_settings() const
    {
        return settings_;
    }

protected:
    // Completion of async_tls_handshake. All work on the stream happens
    // BEFORE the user handler is posted. With several threads running the
    // io_service, the posted handler may run immediately elsewhere and start
    // new I/O on the same stream.
    template <typename Handler>
    void handle_tls_enabled(Handler& handler, const error_code& e)
    {
        write_context_.complete();

        // Optional TLS key logging, enabled by the SSLKEYLOG_ENV environment variable.
        static const bool SSLKEYLOG = std::getenv(SSLKEYLOG_ENV) != nullptr;
        if (SSLKEYLOG)
        {
            auto sslkey = stream_.tap_ssl_key();
            if (sslkey.size()) YLOG_G(info) << sslkey;
        }

        auto& ssl_settings = get_io_data().get_ssl_settings();
        if (ssl_settings.reduce_excess_buffers)
        {
            stream().reduce_ssl_buffers(ssl_settings.reduced_buffer_size);
        }

        // Handler is a non-reference type, so std::forward moves it out of
        // the bound state.
        io_.post(std::bind(std::forward<Handler>(handler), e));
    }

    // Completion of async_read/async_read_until: ends the read operation
    // (disarming its timer) and posts the user handler.
    template <typename Handler>
    void handle_read(Handler& handler, const error_code& e, std::size_t bytes)
    {
        read_context_.complete();
        io_.post(std::bind(std::forward<Handler>(handler), e, bytes));
    }

    // Completion of async_write: ends the write operation and posts the user handler.
    template <typename Handler>
    void handle_write(Handler& handler, const error_code& e, std::size_t bytes)
    {
        write_context_.complete();
        io_.post(std::bind(std::forward<Handler>(handler), e, bytes));
    }

private:
    io_data* io_data_; // nullptr when constructed from a bare io_service
    boost::asio::io_service& io_;
    settings_t settings_;
    raw_socket_t raw_socket_;
    stream_t stream_; // wraps raw_socket_ (optionally with an SSL layer)

    boost::asio::ip::address remote_addr_;
    unsigned short remote_port_;
    boost::asio::ip::address local_addr_;
    unsigned short local_port_;

    operation_context read_context_;
    operation_context write_context_; // it is also used for connect and tls handshake
};

}}

#undef PROTECTED_TYPE
