#include "smtp_session_impl.h"
#include "reply_codes.h"

#include <boost/algorithm/string.hpp>

namespace ymod_smtpclient {

namespace {

// Appends the end-of-data marker "<CRLF>.<CRLF>" to `msg`.
// When `msg` already ends with CRLF, that CRLF serves as the first half of the
// marker, so only ".\r\n" is added; otherwise the full sequence is appended.
void appendTerminationOctet(std::string& msg) {
    const bool endsWithCrlf = msg.size() >= 2
        && msg.compare(msg.size() - 2, 2, "\r\n") == 0;
    msg.append(endsWithCrlf ? ".\r\n" : "\r\n.\r\n");
}

std::string applyDotStuff(const std::string& msg) {
    DotStuffer dotStuffer;
    std::string newMsg;
    dotStuffer.apply(msg.begin(), msg.end(), newMsg);
    return newMsg;
}

}  // namespace anonymous

// Creates a session on top of `ioData`.
// - The socket, connect operation, SASL engine line limit, timeouts and the
//   implicit-TLS port all come from `settings`.
// - `logger` becomes this object's log source and is also installed on the
//   connect operation.
// NOTE(review): member initializers run in the order the members are declared
// in the header. connectOp is constructed from `socket`, so `socket` must be
// declared before `connectOp` — confirm against the header.
SmtpSessionImpl::SmtpSessionImpl(
    yplatform::net::io_data& ioData,
    const Settings& settings,
    const yplatform::log::source& logger
)
    : yplatform::log::contains_logger(logger)
    , socket(ioData, settings.socket)
    , connectOp(socket, settings.socket.resolve_order)
    , authEngine(settings.maxAuthLineLength)
    , timeouts(settings.connectTimeout, settings.connectAttemptTimeout,
               settings.commandTimeout, settings.dataTimeout)
    , smtpSslDefaultPort(settings.smtpSslDefaultPort)
{
    connectOp.logger(logger);
}

// Convenience overload: identical to the three-argument constructor, using a
// default-constructed log source.
SmtpSessionImpl::SmtpSessionImpl(
    yplatform::net::io_data& ioData,
    const Settings& settings
)
    : SmtpSessionImpl(ioData, settings, yplatform::log::source{})
{
}

// Best-effort teardown: closes the socket. Any exception thrown by close() is
// swallowed on purpose, because an exception escaping a destructor would
// terminate the program.
SmtpSessionImpl::~SmtpSessionImpl() {
    try {
        this->close();
    } catch (...) {
        // Intentionally ignored: nothing useful can be done here.
    }
}

void SmtpSessionImpl::setContext(const ContextPtr& newCtx) {
    ctx = newCtx;
}

void SmtpSessionImpl::setTimeouts(const Timeouts& newTimeouts) {
    timeouts = newTimeouts;
}

void SmtpSessionImpl::cancel() {
    socket.cancel_operations();
}

void SmtpSessionImpl::close() {
    socket.close();
}

void SmtpSessionImpl::shutdown(bool gracefull) {
    socket.shutdown(gracefull);
}

bool SmtpSessionImpl::isOpen() const {
    return socket.is_open();
}

bool SmtpSessionImpl::isEncrypted() const {
    return encrypted;
}

// Local address of the underlying socket.
boost::asio::ip::address SmtpSessionImpl::getLocalAddr() const {
    return this->socket.local_addr();
}

// Peer (server) address of the underlying socket.
boost::asio::ip::address SmtpSessionImpl::getRemoteAddr() const {
    return this->socket.remote_addr();
}

uint16_t SmtpSessionImpl::getLocalPort() const {
    return socket.local_port();
}

uint16_t SmtpSessionImpl::getRemotePort() const {
    return socket.remote_port();
}

void const* SmtpSessionImpl::getId() const {
    return socket.id();
}

// Returns a copy of the extensions parsed from the last successful
// asyncHelo reply.
ServerExtensions SmtpSessionImpl::getServerExtensions() const {
    return this->serverExtensions;
}

// Parses the server extension list out of `resp` (called by asyncHelo with
// the greeting-command reply) and stores it in serverExtensions.
// Returns true on success. If parsing throws a std::exception, the error is
// logged, the stored extensions are not reassigned, and false is returned.
bool SmtpSessionImpl::updateServerExtensions(const Response& resp) {
    try {
        serverExtensions = parse_smtp_server_extensions(resp);
        return true;
    } catch (const std::exception& e) {
        YLOG_CTX_LOCAL(ctx, error)
            << "Unable to parse server extensions"
            << ", reason: " << e.what();
    }
    return false;
}

// Discards everything currently held in the read and write buffers.
void SmtpSessionImpl::clearBuffers() {
    const auto pendingRead = readBuf.size();
    readBuf.consume(pendingRead);
    const auto pendingWrite = writeBuf.size();
    writeBuf.consume(pendingWrite);
}

// Decides whether TLS must be started right after the TCP connect.
// An explicit `useSsl` always wins. Without it, TLS is used only when `port`
// equals the configured smtpSslDefaultPort.
bool SmtpSessionImpl::isSslRequired(uint16_t port, Optional<bool> useSsl) const {
    return useSsl ? *useSsl : (port == smtpSslDefaultPort);
}

void SmtpSessionImpl::asyncConnect(std::string host, uint16_t port, Optional<bool> useSsl, Handler handler) {
    auto self = shared_from_this();
    auto needSsl = isSslRequired(port, useSsl);
    connectOp.perform(host, port, timeouts.connectAttempt, timeouts.connect,
        [self, needSsl, handler](auto errc) {
            if (errc) {
                auto status = (errc == boost::asio::error::operation_aborted)
                    ? error::ConnectionTimedOut
                    : error::ConnectError;
                return handler(status);
            } else if (needSsl) {
                return self->asyncTlsHandshake(handler);
            }
            handler(error::Success);
        });
}

// Performs a client-side TLS handshake on the connected socket, with a
// deadline of now + timeouts.connect. On success the session is flagged as
// encrypted. `handler` receives one of:
// - Success;
// - ConnectionTimedOut when the error is operation_aborted;
// - SslError for any other error.
void SmtpSessionImpl::asyncTlsHandshake(Handler handler) {
    auto self = shared_from_this();
    const auto deadline = yplatform::time_traits::clock::now() + timeouts.connect;
    socket.async_tls_handshake(Socket::handshake_type::client, deadline,
        [self, this, handler](auto errc) {
            if (!errc) {
                encrypted = true;
                handler(error::Success);
            } else if (errc == boost::asio::error::operation_aborted) {
                handler(error::ConnectionTimedOut);
            } else {
                handler(error::SslError);
            }
        });
}

// Reads the server greeting. A reply code other than
// EXPECTED_GREETING_REPLY_CODE is reported as error::BadReplyCode.
void SmtpSessionImpl::asyncGreeting(SmtpHandler handler) {
    this->asyncHandleCommandResponse(EXPECTED_GREETING_REPLY_CODE, handler);
}

// Sends the greeting command built by Composer::HELLO(proto, hostname) and
// expects EXPECTED_HELO_REPLY_CODE. On a 250 reply the server extension list
// is re-parsed; a parse failure turns the result into error::ProtocolError.
void SmtpSessionImpl::asyncHelo(SmtpPoint::Proto proto, const std::string& hostname, SmtpHandler handler) {
    auto self = shared_from_this();
    auto onReply = [self, this, handler](error::Code errc, Response resp) {
        const bool accepted = !errc && resp.replyCode == 250;
        if (accepted && !updateServerExtensions(resp)) {
            errc = error::ProtocolError;
        }
        handler(errc, resp);
    };
    asyncExecCommand(Composer::HELLO(proto, hostname), EXPECTED_HELO_REPLY_CODE, onReply);
}

// Same as the three-argument overload, with an empty hostname.
void SmtpSessionImpl::asyncHelo(SmtpPoint::Proto proto, SmtpHandler handler) {
    const std::string noHostname;
    asyncHelo(proto, noHostname, std::move(handler));
}

void SmtpSessionImpl::asyncAuth(const AuthData& authData, SmtpHandler handler) {
    authEngine.reset(authData);
    Command cmd;
    authEngine.next(Response{}, cmd);
    asyncExecCommand(cmd, [this, handler](error::Code errc, Response resp) {
        if (errc) {
            return handler(errc, resp);
        }
        asyncAuthProceed(resp, handler);
    });
}

// Feeds the server's reply `authResp` into the SASL engine.
// - Status::More: the next client line is sent, and this function runs again
//   with the reply to it.
// - Otherwise the exchange ends. It succeeds only if the engine reports Done
//   and the reply code is 235; any other outcome is reported as
//   error::AuthError together with `authResp`.
void SmtpSessionImpl::asyncAuthProceed(Response authResp, SmtpHandler handler) {
    Command nextCmd;
    const auto status = authEngine.next(authResp, nextCmd);

    if (status != sasl::ClientEngine::Status::More) {
        const bool authenticated = status == sasl::ClientEngine::Status::Done
            && authResp.replyCode == 235;
        if (authenticated) {
            handler(error::Success, authResp);
        } else {
            handler(error::AuthError, authResp);
        }
        return;
    }

    auto self = shared_from_this();
    asyncExecCommand(nextCmd,
        [self, this, handler](error::Code errc, Response resp) {
            if (errc) {
                handler(errc, resp);
                return;
            }
            asyncAuthProceed(resp, handler);
        });
}

// Sends STARTTLS, expecting EXPECTED_STARTTLS_REPLY_CODE, and then upgrades
// the connection with asyncTlsHandshake. `handler` always receives the
// STARTTLS reply, paired with either the command error or, if the command
// succeeded, the handshake status.
//
// Any bytes already buffered in readBuf when the STARTTLS reply is handled
// were received in plaintext, before TLS started. They are discarded so that
// a man-in-the-middle cannot have them parsed later as a reply to a command
// sent over TLS. This is the client-side form of the STARTTLS plaintext
// injection problem.
//
// NOTE(review): RFC 3207 section 4.2 also says the client MUST discard server
// extensions learned before TLS. serverExtensions is left unchanged here, so
// callers are expected to re-issue EHLO after a successful upgrade.
void SmtpSessionImpl::asyncStartTls(SmtpHandler handler) {
    auto self = shared_from_this();
    asyncExecCommand(Composer::STARTTLS(), EXPECTED_STARTTLS_REPLY_CODE,
        [self, this, handler](error::Code errc, Response resp) {
            if (errc) {
                return handler(errc, resp);
            }
            // Drop plaintext that arrived after the STARTTLS reply (see above).
            readBuf.consume(readBuf.size());
            asyncTlsHandshake([self, resp, handler](error::Code tlsErrc) {
                handler(tlsErrc, resp);
            });
        });
}

void SmtpSessionImpl::asyncMailFrom(const MailFrom& mailFrom, SmtpHandler handler) {
    auto cmd = Composer::MAIL_FROM(mailFrom);
    auto self = shared_from_this();
    asyncWriteCommand(cmd.str(),
        [self, this, handler](error::Code errc) {
            if (errc) {
                return handler(errc, Response());
            }
            asyncHandleMailFrom(handler);
        });
}

void SmtpSessionImpl::asyncRcptTo(const RcptTo& rcptTo, bool enableDsn, SmtpHandler handler) {
    auto cmd = Composer::RCPT_TO(rcptTo, enableDsn);
    auto self = shared_from_this();
    asyncWriteCommand(cmd.str(),
        [self, this, handler](error::Code errc) {
            if (errc) {
                return handler(errc, Response());
            }
            asyncHandleRcptTo(handler);
        });
}

// Sends the DATA command and hands reply processing to asyncHandleDataStart.
// A failed write is reported immediately with an empty Response.
void SmtpSessionImpl::asyncDataStart(SmtpHandler handler) {
    auto command = Composer::DATA_START();
    auto self = shared_from_this();
    auto onWritten = [self, this, handler](error::Code errc) {
        if (!errc) {
            asyncHandleDataStart(handler);
        } else {
            handler(errc, Response());
        }
    };
    asyncWriteCommand(command.str(), onWritten);
}

// Empties both I/O buffers and then issues RSET, expecting
// EXPECTED_RSET_REPLY_CODE.
// NOTE(review): clearing readBuf presumably drops unread (e.g. pipelined)
// replies from the aborted transaction — confirm this is intended.
void SmtpSessionImpl::asyncRset(SmtpHandler handler) {
    this->clearBuffers();
    this->asyncExecCommand(Composer::RSET(), EXPECTED_RSET_REPLY_CODE, handler);
}

// Issues QUIT, expecting EXPECTED_QUIT_REPLY_CODE.
void SmtpSessionImpl::asyncQuit(SmtpHandler handler) {
    this->asyncExecCommand(Composer::QUIT(), EXPECTED_QUIT_REPLY_CODE, handler);
}

// Reads the MAIL FROM reply. A code other than EXPECTED_MAIL_FROM_REPLY_CODE
// becomes error::BadReplyCode.
void SmtpSessionImpl::asyncHandleMailFrom(SmtpHandler handler) {
    this->asyncHandleCommandResponse(EXPECTED_MAIL_FROM_REPLY_CODE, handler);
}

// Reads the RCPT TO reply. A code other than EXPECTED_RCPT_TO_REPLY_CODE
// becomes error::BadReplyCode.
void SmtpSessionImpl::asyncHandleRcptTo(SmtpHandler handler) {
    this->asyncHandleCommandResponse(EXPECTED_RCPT_TO_REPLY_CODE, handler);
}

// Reads the DATA reply. A code other than EXPECTED_DATA_START_REPLY_CODE
// becomes error::BadReplyCode.
void SmtpSessionImpl::asyncHandleDataStart(SmtpHandler handler) {
    this->asyncHandleCommandResponse(EXPECTED_DATA_START_REPLY_CODE, handler);
}

void SmtpSessionImpl::asyncExecCommand(const Command& cmd, SmtpHandler handler) {
    auto self = shared_from_this();
    asyncWriteCommand(cmd.str(),
        [self, this, handler](error::Code errc) {
            if (errc) {
                return handler(errc, Response());
            }
            asyncReadCommandResponse(handler);
        });
}

void SmtpSessionImpl::asyncExecCommand(const Command& cmd, uint16_t expectedReplyCode, SmtpHandler handler) {
    auto self = shared_from_this();
    asyncWriteCommand(cmd.str(),
        [self, this, expectedReplyCode, handler](error::Code errc) {
            if (errc) {
                return handler(errc, Response());
            }
            asyncHandleCommandResponse(expectedReplyCode, handler);
        });
}

// Sends the message content followed by the "<CRLF>.<CRLF>" terminator, then
// reads the server's final reply. No reply-code check is done here; the raw
// reply goes to `handler`. Both the write and the read are bounded by
// timeouts.data. A failed write is reported with an empty Response.
//
// When `enableDotStuffing` is false the content is sent unchanged, so the
// caller is presumably responsible for having dot-stuffed it already.
void SmtpSessionImpl::asyncWriteMessage(const std::string& msg, bool enableDotStuffing, SmtpHandler handler) {
    // Build the wire payload from a single copy of the message: the stuffed
    // one or the plain one. Copying the plain message first and then
    // overwriting it would waste a full-body copy.
    std::string payload = enableDotStuffing ? applyDotStuff(msg) : msg;
    appendTerminationOctet(payload);
    auto self = shared_from_this();
    auto now = yplatform::time_traits::clock::now();
    asyncWrite(payload, now + timeouts.data,
        [self, this, handler](error::Code errc) {
            if (errc) {
                return handler(errc, Response());
            }
            asyncReadMessageResponse(handler);
        });
}

// Writes a command line with a deadline of now + timeouts.command.
void SmtpSessionImpl::asyncWriteCommand(const std::string& cmd, Handler handler) {
    const auto deadline = yplatform::time_traits::clock::now() + timeouts.command;
    asyncWrite(cmd, deadline, handler);
}

void SmtpSessionImpl::asyncWrite(const std::string& buffer, TimePoint deadline, Handler handler) {
    writeBuf.consume(writeBuf.size());
    std::ostream stream(&writeBuf);
    stream << buffer;

    auto self = shared_from_this();
    socket.async_write(writeBuf, deadline,
        [self, handler](boost::system::error_code errc, std::size_t) {
            auto status = error::Success;
            if (errc) {
                status = (errc == boost::asio::error::operation_aborted)
                    ? error::RequestTimedOut
                    : error::WriteError;
            }
            handler(status);
        });
}

// Spawns a ResponseReader that reads one server reply from the socket (through
// readBuf) within `timeout` and passes the result to `handler`.
void SmtpSessionImpl::asyncReadResponse(Duration timeout, SmtpHandler handler) {
    auto self = shared_from_this();
    // The wrapper owns `self`, keeping the session alive for as long as the
    // reader holds this callback.
    auto onResponse = [self, handler](error::Code errc, Response resp) {
        handler(errc, resp);
    };
    auto reader = std::make_shared<ResponseReader>(
        socket, readBuf, timeout, ctx, logger(), std::move(onResponse));
    yplatform::spawn(reader);
}

// Reads one reply using the command timeout.
void SmtpSessionImpl::asyncReadCommandResponse(SmtpHandler handler) {
    const auto limit = timeouts.command;
    asyncReadResponse(limit, handler);
}

void SmtpSessionImpl::asyncHandleCommandResponse(uint16_t expectedReplyCode, SmtpHandler handler) {
    asyncReadCommandResponse([expectedReplyCode, handler](error::Code errc, Response resp) {
        if (!errc && resp.replyCode != expectedReplyCode) {
            errc = error::BadReplyCode;
        }
        return handler(errc, resp);
    });
}

// Reads the reply to the message body using the data timeout.
void SmtpSessionImpl::asyncReadMessageResponse(SmtpHandler handler) {
    const auto limit = timeouts.data;
    asyncReadResponse(limit, handler);
}

}  // namespace ymod_smtpclient
