#pragma once

#include <cstdlib>
#include <utility>

#include <yplatform/net/stream/detail/sslkeylog.h>
#include <yplatform/net/stream/detail/reduce_ssl_buffers.h>
#include <boost/asio/detail/throw_error.hpp>
#include <boost/asio/ssl/stream.hpp>
#include <boost/function.hpp>
#include <boost/make_shared.hpp>
#include <boost/noncopyable.hpp>
#include <boost/shared_ptr.hpp>

// Name of the environment variable consulted by optional_ssl_socket::async_start_tls:
// if it is set (to any value), TLS key tap state is initialized before the handshake.
#define SSLKEYLOG_ENV "SSLKEYLOG"

namespace yplatform { namespace net { namespace stream {

/// Wraps a plain socket and optionally layers a TLS stream on top of it.
///
/// The ssl stream is created lazily by the first call to set_tls_server_name(),
/// set_tls_verify_callback() or async_start_tls(). Until then every read/write
/// goes straight to the underlying socket; afterwards it goes through the ssl
/// stream. The underlying socket is held by reference (not owned) and must
/// outlive this object.
template <typename Socket>
class optional_ssl_socket : private boost::noncopyable
{
public:
    typedef Socket socket_type;
    typedef boost::asio::ssl::stream<socket_type&> ssl_stream_type;
    typedef boost::asio::socket_base::shutdown_type shutdown_type;

    /// @param socket underlying transport; stored by reference, not owned.
    optional_ssl_socket(socket_type& socket) : socket_(socket), ssl_stream_(), tap_state_()
    {
    }

    /// Returns the io_service of the underlying socket.
    boost::asio::io_service& io_service()
    {
        return socket_.get_io_service();
    }

    /// Same as io_service(); provided for Asio-style naming.
    boost::asio::io_service& get_io_service()
    {
        return socket_.get_io_service();
    }

    /// Sets the TLS SNI host name, creating the ssl stream from @p context
    /// if it does not exist yet.
    /// NOTE(review): the return value of SSL_set_tlsext_host_name is ignored.
    void set_tls_server_name(boost::asio::ssl::context& context, const char* name)
    {
        if (!ssl_stream_) init_ssl_stream(context);
        // SSL_set_tlsext_host_name is an OpenSSL macro containing a C-style cast.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wold-style-cast"
        SSL_set_tlsext_host_name(ssl_stream_->native_handle(), name);
#pragma GCC diagnostic pop
    }

    /// Installs a peer-certificate verification callback on the ssl stream,
    /// creating the stream from @p context if it does not exist yet.
    template <typename Handler>
    void set_tls_verify_callback(boost::asio::ssl::context& context, Handler&& handler)
    {
        if (!ssl_stream_) init_ssl_stream(context);
        ssl_stream_->set_verify_callback(std::forward<Handler>(handler));
    }

    /// Starts an asynchronous TLS handshake.
    ///
    /// Creates the ssl stream from @p context if needed and, when the
    /// SSLKEYLOG_ENV environment variable is set, initializes the key tap
    /// state first so tap_ssl_key() can be used later.
    /// @param type    client or server side of the handshake.
    /// @param handler completion handler passed to ssl::stream::async_handshake.
    template <typename Handler>
    void async_start_tls(
        boost::asio::ssl::context& context,
        boost::asio::ssl::stream_base::handshake_type type,
        Handler handler)
    {
        if (!ssl_stream_) init_ssl_stream(context);
        if (sslkeylog_enabled()) init_tap_state();
        ssl_stream_->async_handshake(type, std::move(handler));
    }

    /// Same as the 3-argument overload, but installs @p verify_handler as the
    /// certificate verification callback before starting the handshake.
    template <typename Handler, typename VerifyHandler>
    void async_start_tls(
        boost::asio::ssl::context& context,
        boost::asio::ssl::stream_base::handshake_type type,
        Handler handler,
        VerifyHandler&& verify_handler)
    {
        // set_tls_verify_callback needs the context: it may have to create the stream.
        set_tls_verify_callback(context, std::forward<VerifyHandler>(verify_handler));
        async_start_tls(context, type, std::move(handler));
    }

    /// Cancels outstanding asynchronous operations on the underlying socket.
    void cancel()
    {
        socket_.cancel();
    }

    /// Shuts down the underlying socket, ignoring errors.
    /// No TLS-level shutdown is performed.
    void shutdown(shutdown_type what)
    {
        boost::system::error_code ec;
        shutdown(what, ec);
    }

    /// Shuts down the underlying socket (TLS shutdown is intentionally skipped).
    /// @return @p ec unchanged if the socket is not open, otherwise the result
    ///         of the socket shutdown.
    boost::system::error_code shutdown(shutdown_type what, boost::system::error_code& ec)
    {
        if (!socket_.is_open()) return ec;
        return socket_.shutdown(what, ec);
    }

    /// Shuts down (both directions) and closes the underlying socket, ignoring errors.
    void close()
    {
        boost::system::error_code ec;
        close(ec);
    }

    /// Shuts down (both directions) and closes the underlying socket.
    /// @return @p ec unchanged if the socket is not open; otherwise the result of
    ///         close() (a shutdown error is not reported separately).
    boost::system::error_code close(boost::system::error_code& ec)
    {
        if (!socket_.is_open()) return ec;

        socket_.shutdown(boost::asio::socket_base::shutdown_both, ec);
        return socket_.close(ec);
    }

    /// Synchronous read via the ssl stream if present, otherwise the raw socket.
    template <typename Buffer>
    std::size_t read_some(Buffer buffer)
    {
        return ssl_stream_ ? ssl_stream_->read_some(buffer) : socket_.read_some(buffer);
    }

    /// Non-throwing synchronous read; see read_some(Buffer).
    template <typename Buffer>
    std::size_t read_some(Buffer buffer, boost::system::error_code& ec)
    {
        return ssl_stream_ ? ssl_stream_->read_some(buffer, ec) : socket_.read_some(buffer, ec);
    }

    /// Asynchronous read via the ssl stream if present, otherwise the raw socket.
    template <typename Buffer, typename Handler>
    void async_read_some(Buffer buffer, Handler handler)
    {
        if (ssl_stream_)
            ssl_stream_->async_read_some(buffer, std::move(handler));
        else
            socket_.async_read_some(buffer, std::move(handler));
    }

    /// Synchronous write via the ssl stream if present, otherwise the raw socket.
    template <typename Buffer>
    std::size_t write_some(Buffer buffer)
    {
        return ssl_stream_ ? ssl_stream_->write_some(buffer) : socket_.write_some(buffer);
    }

    /// Non-throwing synchronous write; see write_some(Buffer).
    template <typename Buffer>
    std::size_t write_some(Buffer buffer, boost::system::error_code& ec)
    {
        return ssl_stream_ ? ssl_stream_->write_some(buffer, ec) : socket_.write_some(buffer, ec);
    }

    /// Asynchronous write via the ssl stream if present, otherwise the raw socket.
    template <typename Buffer, typename Handler>
    void async_write_some(Buffer buffer, Handler handler)
    {
        if (ssl_stream_)
            ssl_stream_->async_write_some(buffer, std::move(handler));
        else
            socket_.async_write_some(buffer, std::move(handler));
    }

    /// Whether the underlying socket is open.
    bool is_open() const
    {
        return socket_.is_open();
    }

    /// Returns the ssl stream, or an empty pointer if TLS was never set up.
    boost::shared_ptr<ssl_stream_type> get_ssl_stream()
    {
        return ssl_stream_;
    }

    /// Returns the tapped TLS key material, or an empty string if there is no ssl stream.
    string tap_ssl_key()
    {
        return ssl_stream_ ? detail::ssl_tap_key(ssl_stream_->native_handle(), tap_state_) :
                             string();
    }

    /// Shrinks the ssl stream's internal buffers to @p buffer_size; no-op without TLS.
    void reduce_ssl_buffers(size_t buffer_size)
    {
        if (!ssl_stream_) return;
        detail::reduce_ssl_buffers(*ssl_stream_, buffer_size);
    }

private:
    /// Creates the ssl stream over socket_ using @p context.
    void init_ssl_stream(boost::asio::ssl::context& context)
    {
        ssl_stream_ = boost::make_shared<ssl_stream_type>(socket_, context);
    }

    /// Initializes the key tap state. Precondition: ssl_stream_ is non-null.
    void init_tap_state()
    {
        detail::ssl_tap_state_init(ssl_stream_->native_handle(), tap_state_);
    }

    /// True if the SSLKEYLOG_ENV environment variable is set (any value).
    /// The environment is read on first call and cached in a function-local static.
    static bool sslkeylog_enabled()
    {
        static const bool enabled = std::getenv(SSLKEYLOG_ENV) != nullptr;
        return enabled;
    }

    socket_type& socket_;                        // not owned
    boost::shared_ptr<ssl_stream_type> ssl_stream_; // null until TLS is set up
    detail::ssl_tap_state tap_state_;             // value-initialized; filled by init_tap_state()
};

}}}
