#include "session.h"
#include "command_composer.h"
#include "dot_stuffer.h"

#include <boost/algorithm/string/join.hpp>
#include <boost/algorithm/string/trim.hpp>
#include <boost/asio/ssl/error.hpp>
#include <boost/asio/yield.hpp>

#include <functional>

namespace ymod_smtpclient {

namespace p = std::placeholders;

/// Renders an IP address as text. An IPv6 address that is v4-mapped or
/// v4-compatible is printed in its IPv4 form; anything else is printed as is.
inline std::string get_ip_string(const boost::asio::ip::address& address) {
    if (!address.is_v6()) {
        return address.to_string();
    }

    const auto asV6 = address.to_v6();
    const bool wrapsV4 = asV6.is_v4_mapped() || asV6.is_v4_compatible();
    if (!wrapsV4) {
        return address.to_string();
    }

    // Go through the generic address type so the same to_string() overload
    // is used for every branch.
    return boost::asio::ip::address(asV6.to_v4()).to_string();
}

/// Formats an endpoint as "[<ip>]:<port>" for log lines.
inline std::string format_address(const boost::asio::ip::address& address, unsigned short port) {
    std::string text = "[";
    text += get_ip_string(address);
    text += "]:";
    text += std::to_string(port);
    return text;
}

/// Builds a client session.
///
/// @param iodata   I/O data forwarded to the underlying SmtpSessionImpl.
/// @param logger   Log source used for this object and for both the
///                 clientLog and serverLog members.
/// @param settings Settings; stored in the `settings` member and also
///                 forwarded to the underlying SmtpSessionImpl.
Session::Session(
    yplatform::net::io_data& iodata,
    const yplatform::log::source& logger,
    const Settings& settings
)
    : yplatform::log::contains_logger(logger)
    , settings(settings)
    , clientLog(logger)
    , serverLog(logger)
    , session(std::make_shared<SmtpSessionImpl>(iodata, settings, logger))
{
}

/// Closes the session (see close()) when the object is destroyed.
Session::~Session() {
    close();
}

/// Marks the session as non-reusable and, if it is still open, cancels
/// pending operations and closes the underlying session.
///
/// This is called from the destructor, so it must not let exceptions escape:
/// both the diagnostic log line and the shutdown calls are best-effort.
void Session::close() {
    reusable = false;
    if (!isOpen()) {
        return;
    }

    if (settings.enableLogging) {
        // Building the log line queries endpoints and formats addresses.
        // A failure there must neither escape nor prevent the actual
        // shutdown below.
        try {
            auto now = yplatform::time_traits::clock::now();
            YLOG_L(debug) << "conn=" << session->getId() << " close session "
                << format_address(session->getLocalAddr(), session->getLocalPort()) << " -> "
                << format_address(session->getRemoteAddr(), session->getRemotePort())
                << ", duration=" << yplatform::time_traits::to_string(now - createdAt)
                << ", count=" << requestsProcessed;
        } catch (...) {
            // Logging is diagnostic only.
        }
    }

    // Guard the two calls separately so that a throwing cancel() does not
    // skip close().
    try {
        session->cancel();
    } catch (...) {
    }
    try {
        session->close();
    } catch (...) {
    }
}

/// Drives one request through the session as a stackless coroutine
/// (reenter/yield on the `coro` member).
///
/// Each yielded step starts an asynchronous operation; the completion
/// callbacks in this file either call complete() on failure or call
/// runCoro() again, which resumes right after the last yield.
/// Any std::exception thrown while running a step is logged and reported
/// to the caller as error::Unknown.
void Session::runCoro() {
    try { reenter(coro) {
        // Not open yet: connect, read the greeting, say hello, and upgrade
        // to TLS when the link is not encrypted and the server advertises
        // STARTTLS (a second hello follows the upgrade).
        if (!isOpen()) {
            yield return connect();

            yield return greeting();
            yield return hello();

            if (!session->isEncrypted() && session->getServerExtensions().starttls) {
                yield return starttls();
                yield return hello();
            }
        }

        // Authenticate once per session, and only over an encrypted link.
        if (!authorized && request->data.auth) {
            if (!session->isEncrypted()) {
                return complete(error::AuthWithoutSSL);
            }
            yield return authorize();
        }

        // The previous request on this session did not finish with Success:
        // issue RSET before starting a new transaction.
        if (lastErrorCode != error::Success) {
            yield return rset();
        }

        if (settings.enableLogging) {
            YLOG_CTX_LOCAL(request->ctx, info)
                << "from=" << request->data.mailfrom.email
                << ", nrcpts=" << request->data.rcpts.size()
                << ", msg_size=" << request->data.message->size();
        }

        if (session->getServerExtensions().pipelining) {
            // Pipelining: MAIL FROM, every RCPT TO and DATA are written in
            // one buffer by pipelineCommands(); the replies are then
            // handled one at a time in the same order.
            yield pipelineCommands();

            yield session->asyncHandleMailFrom(
                std::bind(&Session::handleMailfrom, shared_from_this(), p::_1, p::_2));

            // rcptActive is a member, so the loop position survives
            // across yields.
            rcptActive = request->data.rcpts.begin();
            for(; rcptActive != request->data.rcpts.end(); ++rcptActive) {
                yield session->asyncHandleRcptTo(
                    std::bind(&Session::handleRcptTo, shared_from_this(), p::_1, p::_2));
            }
            yield session->asyncHandleDataStart(
                std::bind(&Session::handleDataStart, shared_from_this(), p::_1, p::_2));
        } else {
            // No pipelining: send each command and wait for its reply
            // before sending the next one.
            yield session->asyncMailFrom(request->data.mailfrom,
                std::bind(&Session::handleMailfrom, shared_from_this(), p::_1, p::_2));

            rcptActive = request->data.rcpts.begin();
            for(; rcptActive != request->data.rcpts.end(); ++rcptActive) {
                yield session->asyncRcptTo(*rcptActive, session->getServerExtensions().dsn,
                    std::bind(&Session::handleRcptTo, shared_from_this(), p::_1, p::_2));
            }
            yield session->asyncDataStart(
                std::bind(&Session::handleDataStart, shared_from_this(), p::_1, p::_2));
        }

        if (request->data.address.proto == SmtpPoint::lmtp) {
            // LMTP: position rcptActive on the first recipient that has no
            // response recorded yet (handleRcptTo records one for recipients
            // it rejects); handleLmtpResponse then stores the per-recipient
            // replies starting from that recipient.
            rcptActive = request->data.rcpts.begin();
            while(rcptActive != request->data.rcpts.end() &&
                request->response.rcpts.count(rcptActive->email) > 0)
            {
                ++rcptActive;
            }
            yield session->asyncWriteMessage(*request->data.message, request->dotStuffing,
                std::bind(&Session::handleLmtpResponse, shared_from_this(), p::_1, p::_2));
        } else {
            // Otherwise the message body is answered by a single reply.
            yield session->asyncWriteMessage(*request->data.message, request->dotStuffing,
                std::bind(&Session::handleSmtpResponse, shared_from_this(), p::_1, p::_2));
        }

        {
            // Derive the overall result from the recorded replies; the
            // session stays reusable only after Success or PartialSend.
            auto errc = calcSessionErrorCode();
            reusable &= (errc == error::Success || errc == error::PartialSend);

            return complete(errc);
        }

    } } catch (const std::exception& e) {   // reenter
        // Logging itself is guarded so that it cannot throw out of the handler.
        try { YLOG_L(error) << "session coro exception: " << e.what(); } catch(...) {}
        return complete(error::Unknown);
    }
}

void Session::connect() {
    auto self = shared_from_this();
    session->asyncConnect(
        request->data.address.host,
        request->data.address.port,
        request->useSsl,
        [self, this](error::Code errc) {
            if (errc) {
                return complete(errc);
            }
            if (settings.enableLogging) {
                YLOG_L(debug) << "conn=" << session->getId() << " open new session "
                    << format_address(session->getLocalAddr(), session->getLocalPort()) << " -> "
                    << format_address(session->getRemoteAddr(), session->getRemotePort());
            }
            runCoro();  // continue transaction
        });
}

void Session::hello() {
    auto self = shared_from_this();
    session->asyncHelo(request->data.address.proto,
        [self, this](error::Code errc, Response resp) {
            if (errc) {
                return complete(errc, resp, "on helo command: ");
            }
            return runCoro();
    });
}

void Session::starttls() {
    auto self = shared_from_this();
    session->asyncStartTls([self, this](error::Code errc, Response resp) {
        if (errc) {
            return complete(errc, resp, "on command STARTTLS: ");
        }
        runCoro();
    });
}

void Session::greeting() {
    auto self = shared_from_this();
    session->asyncGreeting(
        [self, this](error::Code errc, Response resp) {
            if (errc) {
                return complete(errc, resp, "bad greeting: ");
            }
            runCoro();
    });
}

void Session::rset() {
    auto self = shared_from_this();
    session->asyncRset([self, this](error::Code errc, Response resp) {
        if (errc) {
            return complete(errc, resp, "on command RSET: ");
        }
        runCoro();
    });
}

/// Composes MAIL FROM, one RCPT TO per recipient and DATA into a single
/// buffer and writes it in one go. The replies are consumed later by the
/// coroutine, which is resumed once the write succeeds.
void Session::pipelineCommands() {
    std::string buffer;
    buffer += Composer::MAIL_FROM(request->data.mailfrom).str();

    // The member cursor is used on purpose: it is left at end() afterwards.
    rcptActive = request->data.rcpts.begin();
    while (rcptActive != request->data.rcpts.end()) {
        buffer += Composer::RCPT_TO(*rcptActive, session->getServerExtensions().dsn).str();
        ++rcptActive;
    }

    buffer += Composer::DATA_START().str();

    session->asyncWriteCommand(buffer,
        [self = shared_from_this(), this](error::Code errc) {
            if (!errc) {
                runCoro();
                return;
            }
            // A timeout keeps its own code; every other failure is reported
            // as a write error.
            if (errc == error::RequestTimedOut) {
                return complete(error::RequestTimedOut, "write command timeout");
            }
            return complete(error::WriteError);
        });
}

void Session::authorize() {
    assert(request->data.auth);
    auto mech = request->data.auth->mechanism;
    const auto& extensions = session->getServerExtensions();
    if (mech != sasl::Mechanism::None && extensions.authMechanisms.count(mech) == 0) {
        return complete(error::AuthError, "not found auth mechanism: " + to_string(mech));
    } else if (mech == sasl::Mechanism::None) {
        if (extensions.authMechanisms.count(sasl::Mechanism::Login) > 0) {
            mech = sasl::Mechanism::Login;
        } else if (extensions.authMechanisms.count(sasl::Mechanism::Plain) > 0) {
            mech = sasl::Mechanism::Plain;
        } else {
            return complete(error::AuthError, "not found auth mechanisms: LOGIN, PLAIN");
        }
    }
    request->data.auth->mechanism = mech;

    if (settings.enableLogging) {
        YLOG_CTX_LOCAL(request->ctx, info)
            << "start auth, login='" << request->data.auth->login << "'"
            << ", method='" << to_string(request->data.auth->mechanism) << "'";
    }

    auto self = shared_from_this();
    session->asyncAuth(request->data.auth.get(),
        [self, this](error::Code errc, Response resp) {
            if (errc) {
                return complete(error::AuthError);
            }
            if (resp.replyCode != 235) {
                request->response.session = multiline2OneLineResponse(resp);
                return complete(error::AuthError,
                    "bad reply code: " + std::to_string(resp.replyCode));
            }
            authorized = true;
            return runCoro();   // continue transaction
    });
}

void Session::handleLmtpResponse(error::Code, const Response& resp) {

    request->response.rcpts.emplace(rcptActive->email, multiline2OneLineResponse(resp));
    ++rcptActive;
    while(rcptActive != request->data.rcpts.end() &&
        request->response.rcpts.count(rcptActive->email) > 0)
    {
        ++rcptActive;
    }
    if (rcptActive != request->data.rcpts.end()) {
        session->asyncReadMessageResponse(
            std::bind(&Session::handleLmtpResponse, shared_from_this(), p::_1, p::_2));
    } else {
        request->response.session = multiline2OneLineResponse(resp);
        runCoro();
    }
}

/// Handles the single reply to the message body (SMTP).
/// A timeout or any other error completes the request; otherwise the reply
/// is stored as the session response and the coroutine is resumed.
void Session::handleSmtpResponse(error::Code errc, const Response& resp) {
    if (errc == error::RequestTimedOut) {
        complete(errc, "write message");
    } else if (errc) {
        complete(errc);
    } else {
        request->response.session = multiline2OneLineResponse(resp);
        runCoro();
    }
}

/// Derives the overall result of the request from the recorded replies.
///
/// Every recorded per-recipient reply is classified as accepted or not; for
/// plain SMTP the session reply is classified as well.
/// @return Success when everything was accepted, PartialSend when accepted
///         and non-accepted replies are mixed, BadReplyCode when nothing
///         was accepted.
error::Code Session::calcSessionErrorCode() {
    bool anyAccepted = false;
    bool anyRefused = false;
    const auto classify = [&anyAccepted, &anyRefused](const auto& replyCode) {
        if (reply::smtp_code_to_status(replyCode) == reply::Status::Accept) {
            anyAccepted = true;
        } else {
            anyRefused = true;
        }
    };

    for (const auto& entry : request->response.rcpts) {
        classify(entry.second.replyCode);
    }
    if (request->data.address.proto == SmtpPoint::smtp) {
        classify((*request->response.session).replyCode);
    }

    if (!anyAccepted) {
        return error::Code::BadReplyCode;
    }
    return anyRefused ? error::Code::PartialSend : error::Code::Success;
}

/// Finishes the current request and hands it to the completion handler.
///
/// Does nothing when no request is in flight. Any result other than Success
/// or PartialSend makes the session non-reusable. The request and the
/// handler are detached from the session before the handler is invoked.
///
/// @param errc   Final result of the request.
/// @param reason Optional text added to the log line.
void Session::complete(error::Code errc, const std::string& reason) {
    if (!request) {
        return;
    }

    const bool delivered = (errc == error::Success) || (errc == error::PartialSend);
    if (!delivered) {
        reusable = false;
    }
    lastErrorCode = errc;
    request->totalTime = yplatform::time_traits::clock::now() - request->startedAt;

    if (settings.enableLogging) {
        // Optional tail describing the stored session response.
        std::ostringstream responseStream;
        if (const auto& sessionResp = request->response.session) {
            responseStream << ", code=" << sessionResp->replyCode;
            if (const auto& enhanced = sessionResp->enhancedStatusCode) {
                responseStream << ", enhanced_code=" << *enhanced;
            }
            responseStream << ", text='" << sessionResp->data << "'";
        }

        std::string reasonPart;
        if (!reason.empty()) {
            reasonPart = ", reason='" + reason + "'";
        }

        YLOG_CTX_LOCAL(request->ctx, info) << "conn=" << id()
            << " req finished, errc='" << error::message(errc) << "'"
            << reasonPart
            << ", duration=" << yplatform::time_traits::to_string(request->totalTime)
            << responseStream.str();
    }

    // Detach both before calling out, so the members are already empty
    // while the handler runs.
    CompletionHandler handler;
    handler.swap(completionHandler);
    RequestDataPtr req;
    req.swap(request);
    if (handler) {
        handler(req, shared_from_this(), errc);
    }
}

/// Finishes the current request after a failed command.
///
/// For error::BadReplyCode the server reply is stored as the session
/// response and its reply code is appended to `reason`; any other error is
/// reported without a reason.
///
/// @param errc   Result of the failed command.
/// @param resp   Server reply that accompanied the failure.
/// @param reason Prefix for the log reason (e.g. "on command RSET: ").
void Session::complete(error::Code errc, const Response& resp, const std::string& reason) {
    // Same guard as the two-argument overload: without it a callback that
    // arrives after the request was completed dereferences a null request.
    if (!request) {
        return;
    }
    if (errc == error::BadReplyCode) {
        request->response.session = multiline2OneLineResponse(resp);
        return complete(errc, reason + std::to_string(resp.replyCode));
    }
    return complete(errc);
}

/// Returns true when the error code is one of: eof, SSL stream_truncated,
/// shut_down, connection_reset, connection_aborted.
inline bool isErrCorrectEOF(boost::system::error_code errc) {
    static auto truncatedTls = boost::asio::ssl::error::stream_truncated;

    if (errc == boost::asio::error::eof) {
        return true;
    }
    if (errc == truncatedTls) {
        return true;
    }
    if (errc == boost::asio::error::shut_down) {
        return true;
    }
    if (errc == boost::asio::error::connection_reset) {
        return true;
    }
    return errc == boost::asio::error::connection_aborted;
}

/// Reports whether `errc` passes isErrCorrectEOF(); logs a warning with the
/// error message when it does not.
bool Session::checkCorrectEOF(boost::system::error_code errc) {
    const bool expected = isErrCorrectEOF(errc);
    if (!expected) {
        YLOG_L(warning) << "server did not close connection: " << errc.message();
    }
    return expected;
}

/// Handles the reply to MAIL FROM: completes the request on error,
/// otherwise resumes the coroutine.
void Session::handleMailfrom(error::Code errc, const Response& resp) {
    if (errc) {
        complete(errc, resp, "on command MAIL FROM: ");
    } else {
        runCoro();  // continue transaction
    }
}

/// Handles the reply to DATA: completes the request on error, otherwise
/// resumes the coroutine.
void Session::handleDataStart(error::Code errc, const Response& resp) {
    if (errc) {
        complete(errc, resp, "on command DATA: ");
    } else {
        runCoro();  // continue transaction
    }
}

/// Handles the reply to RCPT TO for the recipient `rcptActive` points at.
///
/// Errors other than BadReplyCode end the request. A reply code of 250
/// resumes the coroutine. Any other reply either is recorded for the
/// recipient (when the request allows RCPT TO errors) and the transaction
/// continues, or ends the request with BadRecipient / BadReplyCode.
void Session::handleRcptTo(error::Code errc, const Response& resp) {
    if (errc && errc != error::BadReplyCode) {
        return complete(errc);
    }

    if (resp.replyCode == 250) {
        return runCoro();  // continue transaction
    }

    if (!request->allowRcptToErrors) {
        // Not tolerated: the reply becomes the session response and the
        // request fails.
        auto status = reply::smtp_code_to_status(resp.replyCode);
        request->response.session = multiline2OneLineResponse(resp);
        const auto failure = (status == reply::Status::Reject)
            ? error::BadRecipient
            : error::BadReplyCode;
        return complete(failure, "on command RCPT TO for <"
            + rcptActive->email + ">: " + std::to_string(resp.replyCode));
    }

    // Tolerated: remember the bad reply for this recipient and carry on.
    auto badResult = multiline2OneLineResponse(resp);
    if (settings.enableLogging) {
        YLOG_CTX_LOCAL(request->ctx, warning)
            << "bad response on command RCPT TO for "
            << rcptActive->email << ": '" << badResult << "'";
    }
    request->response.rcpts.emplace(rcptActive->email, std::move(badResult));
    runCoro();  // continue transaction
}

}   // namespace ymod_smtpclient

