#pragma once

#include <functional>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <type_traits>
#include <utility>

#include <boost/asio/coroutine.hpp>
#include <boost/enable_shared_from_this.hpp>
#include <boost/make_shared.hpp>
#include <boost/noncopyable.hpp>
#include <boost/shared_ptr.hpp>

#include <pgg/database/database.h>
#include <pgg/database/fallback.h>
#include <pgg/query/boundaries.h>
#include <pgg/query/repository.h>
#include <pgg/share.h>
#include <pgg/chrono.h>

// Must stay the last include: it defines the `reenter`/`yield`/`fork`
// pseudo-keyword macros (undone by unyield.hpp at the end of this file),
// which would corrupt any header included after it.
#include <boost/asio/yield.hpp>

namespace pgg {

/// Invokes @p f with @p arg, appending @p conn as a third argument only when
/// @p f is callable with (Arg, ConnectionPtr); otherwise @p conn is ignored.
/// Lets completion handlers opt in to receiving the connection.
/// @return whatever @p f returns (decayed, as with plain `auto`).
template <typename F, typename Arg>
inline auto invokeWithConn(F&& f, Arg&& arg, ConnectionPtr conn) {
    constexpr bool acceptsConnection = std::is_invocable_v<F, Arg, ConnectionPtr>;
    if constexpr (!acceptsConnection) {
        return std::invoke(std::forward<F>(f), std::forward<Arg>(arg));
    } else {
        return std::invoke(std::forward<F>(f), std::forward<Arg>(arg), std::move(conn));
    }
}

namespace query {

/// Query that opens a transaction; query::Transaction obtains it from its
/// query repository and sends it before any other statement.
struct Begin : QueryImpl<Begin> {
    /// Forwards the traits and any extra arguments to QueryImpl unchanged.
    template <typename... Args>
    Begin(const Traits& traits, const Args&... args)
        : Inherited(traits, args...) {}
};

/// Query that commits the current transaction; query::Transaction obtains
/// it from its query repository in commit().
struct Commit : QueryImpl<Commit> {
    /// Forwards the traits and any extra arguments to QueryImpl unchanged.
    template <typename... Args>
    Commit(const Traits& traits, const Args&... args)
        : Inherited(traits, args...) {}
};

/// Query that rolls back the current transaction; used by
/// query::Transaction in rollback() and in its destructor.
struct Rollback : QueryImpl<Rollback> {
    /// Forwards the traits and any extra arguments to QueryImpl unchanged.
    template <typename... Args>
    Rollback(const Traits& traits, const Args&... args)
        : Inherited(traits, args...) {}
};

/// Asynchronous transaction over a single database connection.
///
/// The transaction is "in progress" while `connection_` is non-null: a
/// successful BEGIN stores the connection, a successful COMMIT or ROLLBACK
/// resets it to nullptr. Completion handlers capture `share(this)`, so the
/// object is presumably expected to be owned by a boost::shared_ptr before
/// any asynchronous operation starts (confirm against pgg/share.h).
class Transaction : public boost::enable_shared_from_this<Transaction>,
                    public boost::noncopyable {
public:
#define ASSERT_NOT_NULL(arg) checkNullArg(arg, #arg, __PRETTY_FUNCTION__)
#define ASSERT_TRANSACTION_IN_PROGRESS checkInProgress(__PRETTY_FUNCTION__)
#define ASSERT_TRANSACTION_TIMEOUT_NOT_EXCEEDED checkTransactionTimeout(__PRETTY_FUNCTION__)

    /// Creates an idle transaction (no deadline); start it with begin().
    /// @throws std::invalid_argument if @p queryRepository is null.
    Transaction(RepositoryPtr queryRepository)
            : queryRepository_(queryRepository), connection_(nullptr) {
        ASSERT_NOT_NULL(queryRepository);
    }

    /// Creates a transaction and immediately sends BEGIN on @p connection;
    /// @p action receives the BEGIN result.
    /// @param timeout time budget counted from now; zero means no deadline.
    /// @throws std::invalid_argument if a nullptr-comparable argument is null.
    /// NOTE(review): the BEGIN handler is built with share(this) while the
    /// object is still being constructed; if share() wraps shared_from_this()
    /// this overload cannot succeed. Prefer the one-argument constructor plus
    /// begin(). Confirm against pgg/share.h.
    template <typename ExecuteHandler>
    Transaction(RepositoryPtr queryRepository, ConnectionPtr connection,
            ExecuteHandler&& action, Milliseconds timeout)
            : queryRepository_(queryRepository), connection_(nullptr) {
        ASSERT_NOT_NULL(queryRepository);
        ASSERT_NOT_NULL(connection);
        ASSERT_NOT_NULL(action);
        beginExecute(connection, std::forward<ExecuteHandler>(action), timeout);
    }

    /// Forwards @p q to the connection's request().
    /// @throws std::logic_error if no transaction is in progress or the
    ///         deadline has passed; std::invalid_argument for a null handler.
    template <typename RequestHandler>
    void request(const Query & q, RequestHandler&& handler) const {
        ASSERT_TRANSACTION_IN_PROGRESS;
        ASSERT_NOT_NULL(handler);
        ASSERT_TRANSACTION_TIMEOUT_NOT_EXCEEDED;
        connection_->request(q, std::forward<RequestHandler>(handler));
    }
    /// Forwards @p q to the connection's fetch(); same checks as request().
    template <typename FetchHandler>
    void fetch(const Query & q, FetchHandler&& handler) const {
        ASSERT_TRANSACTION_IN_PROGRESS;
        ASSERT_NOT_NULL(handler);
        ASSERT_TRANSACTION_TIMEOUT_NOT_EXCEEDED;
        connection_->fetch(q, std::forward<FetchHandler>(handler));
    }
    /// Forwards @p q to the connection's update(); same checks as request().
    template <typename UpdateHandler>
    void update(const Query & q, UpdateHandler&& handler) const {
        ASSERT_TRANSACTION_IN_PROGRESS;
        ASSERT_NOT_NULL(handler);
        ASSERT_TRANSACTION_TIMEOUT_NOT_EXCEEDED;
        connection_->update(q, std::forward<UpdateHandler>(handler));
    }
    /// Forwards @p q to the connection's execute(); same checks as request().
    template <typename ExecuteHandler>
    void execute(const Query & q, ExecuteHandler&&  handler) const {
        ASSERT_TRANSACTION_IN_PROGRESS;
        ASSERT_NOT_NULL(handler);
        ASSERT_TRANSACTION_TIMEOUT_NOT_EXCEEDED;
        connection_->execute(q, std::forward<ExecuteHandler>(handler));
    }

    /// Sends BEGIN on @p connection and invokes @p action with the result;
    /// on success the transaction becomes in progress on @p connection.
    /// @param timeout time budget counted from now; zero means no deadline.
    /// @throws std::invalid_argument for a null @p action / @p connection;
    ///         std::logic_error if a transaction is already in progress, in
    ///         which case the running transaction's deadline is left untouched.
    template <typename ExecuteHandler>
    void begin(ConnectionPtr connection, ExecuteHandler&& action, Milliseconds timeout) {
        ASSERT_NOT_NULL(action);
        ASSERT_NOT_NULL(connection);
        beginExecute(connection, std::forward<ExecuteHandler>(action), timeout);
    }
    /// Sends COMMIT; on success the transaction is no longer in progress.
    /// The deadline is not checked, so a timed-out transaction can still be
    /// finished. @throws std::logic_error if no transaction is in progress.
    template <typename ExecuteHandler>
    void commit(ExecuteHandler&& action) {
        ASSERT_TRANSACTION_IN_PROGRESS;
        ASSERT_NOT_NULL(action);
        finalExecute(query<Commit>(), std::forward<ExecuteHandler>(action));
    }
    /// Sends ROLLBACK; on success the transaction is no longer in progress.
    /// The deadline is not checked. @throws std::logic_error if no
    /// transaction is in progress.
    template <typename ExecuteHandler>
    void rollback(ExecuteHandler&& action) {
        ASSERT_TRANSACTION_IN_PROGRESS;
        ASSERT_NOT_NULL(action);
        finalExecute(query<Rollback>(), std::forward<ExecuteHandler>(action));
    }

    /// True while a successfully begun transaction has not been finished.
    bool inProgress() const noexcept { return connection_ != nullptr; }

    /// True if a deadline is set and now() is past it.
    bool timedOut() const noexcept {
        if (!haveDeadline()) {
            return false;
        }
        return now().time_since_epoch() > deadline_;
    }

    /// Fires a best-effort ROLLBACK (result ignored) if still in progress;
    /// any exception is reported to std::cerr and swallowed.
    ~Transaction() noexcept {
        if (inProgress()) try {
            connection_->execute(query<Rollback>(), [](error_code){});
        } catch( const std::exception & e ) {
            std::cerr << "Transaction dtor: " << e.what() << std::endl;
        } catch( ... ) {
            std::cerr << "Transaction dtor: unknown exception" << std::endl;
        }
    }
#undef ASSERT_NOT_NULL
#undef ASSERT_TRANSACTION_IN_PROGRESS
#undef ASSERT_TRANSACTION_TIMEOUT_NOT_EXCEEDED
private:
    RepositoryPtr queryRepository_;
    // Non-null exactly while the transaction is in progress.
    ConnectionPtr connection_;
    // Deadline expressed as now().time_since_epoch(); Duration::max() means
    // "no deadline". Initialised here so timedOut() is well-defined before
    // the first begin() (the one-argument constructor never sets it).
    Milliseconds deadline_ = Duration::max();

    // Throws std::invalid_argument when `arg == nullptr` is well-formed and true.
    template <typename Arg>
    auto checkNullArg( const Arg & arg, const char * name, const char * func ) const
        -> std::void_t<decltype(bool(arg==nullptr))> {
        if (arg== nullptr) {
            std::ostringstream s;
            s << func << " argument \"" << name << "\" is nullptr";
            throw std::invalid_argument(s.str());
        }
    }

    // Fallback for arguments not comparable with nullptr (e.g. lambdas): no-op.
    template <typename Arg>
    void checkNullArg(const Arg&, ...) const {
    }

    void checkInProgress(const char * func ) const {
        if (!inProgress()) {
            std::ostringstream s;
            s << func << " there is no transaction in progress";
            throw std::logic_error(s.str());
        }
    }

    bool haveDeadline() const {
        return deadline_ != Duration::max();
    }

    void checkTransactionTimeout(const char* func) const {
        if (timedOut()) {
            std::ostringstream s;
            s << func << " transaction timeout exceeded";
            throw std::logic_error(s.str());
        }
    }

    // Zero timeout disables the deadline; otherwise it is now + timeout.
    void setDeadline(Milliseconds timeout) {
        if (timeout != Duration::zero()) {
            deadline_ = now().time_since_epoch() + timeout;
        } else {
            deadline_ = Duration::max();
        }
    }

    // Sends COMMIT/ROLLBACK; a successful result clears connection_.
    template <typename ExecuteHandler>
    void finalExecute(const Query & q, ExecuteHandler&& action) {
        connection_->execute(q, makeSetConnectionHandler(nullptr, std::forward<ExecuteHandler>(action)));
    }

    // Rejects a second BEGIN before touching any state, then arms the
    // deadline and sends BEGIN; success stores @p connection in connection_.
    template <typename ExecuteHandler>
    void beginExecute(ConnectionPtr connection, ExecuteHandler&& action, Milliseconds timeout) {
        if (inProgress()) {
            std::ostringstream s;
            s << __PRETTY_FUNCTION__ << " transaction in progress";
            throw std::logic_error(s.str());
        }
        setDeadline(timeout);
        connection->execute(query<Begin>(), makeSetConnectionHandler(connection, std::forward<ExecuteHandler>(action)));
    }

    // Builds a completion handler that, on success, replaces connection_
    // with @p connection, then calls @p action with the error and the
    // (possibly updated) connection_. Holds share(this) to keep *this alive.
    template <typename ExecuteHandler>
    auto makeSetConnectionHandler(ConnectionPtr connection, ExecuteHandler action) {
        return [action = std::move(action), connection, self = share(this)](database::error_code e) mutable {
            if(!e) {
                self->connection_ = connection;
            }
            invokeWithConn(action, std::move(e), self->connection_);
        };
    }

    // Fetches query object Q from the repository.
    template <typename Q>
    Q query() const {
        return queryRepository_->query<Q>();
    }
};

/// Shared-ownership handle for query::Transaction (boost::shared_ptr, to
/// match the class's boost::enable_shared_from_this base).
using TransactionPtr = boost::shared_ptr<Transaction>;

namespace fallback {

// Endpoint type from the query traits, made available in this namespace.
using EndpointType = query::Traits::EndpointType;
// Short alias for the database fallback namespace (provides rules::Master).
namespace fb = database::fallback;

/// Transaction whose connection is chosen through a fallback rule.
///
/// Wraps a query::Transaction (`base`). begin() asks @p Rule for the next
/// database to try, acquires a connection from it and sends BEGIN, retrying
/// while the rule deems the error recoverable. Every other operation is
/// forwarded to `base` unchanged.
template <typename DatabaseGenerator, typename Rule = fb::rules::Master>
class Transaction : public boost::enable_shared_from_this<Transaction<DatabaseGenerator, Rule>>,
                    public boost::noncopyable {
public:
    using Base = query::Transaction;
    using RepositoryPtr = query::RepositoryPtr;

    /// Creates an idle transaction; start it with begin().
    Transaction(RepositoryPtr repo, Rule rule = Rule())
    : base(boost::make_shared<Base>(repo)), rule(std::move(rule)) {}

    /// Creates a transaction and immediately runs begin(db, action, timeout).
    template <typename ExecuteHandler>
    Transaction(RepositoryPtr repo, DatabaseGenerator db,
            ExecuteHandler&& action, Milliseconds timeout, Rule rule = Rule())
    : base(boost::make_shared<Base>(repo)), rule(std::move(rule)) {
        begin(std::move(db), std::forward<ExecuteHandler>(action), timeout);
    }

    /// Forwarded to query::Transaction::request().
    template <typename RequestHandler>
    void request(const Query & q, RequestHandler&& handler) const {
        base->request(q, std::forward<RequestHandler>(handler));
    }
    /// Forwarded to query::Transaction::fetch().
    template <typename FetchHandler>
    void fetch(const Query & q, FetchHandler&& handler) const {
        base->fetch(q, std::forward<FetchHandler>(handler));
    }
    /// Forwarded to query::Transaction::update().
    template <typename UpdateHandler>
    void update(const Query & q, UpdateHandler&& handler) const {
        base->update(q, std::forward<UpdateHandler>(handler));
    }
    /// Forwarded to query::Transaction::execute().
    template <typename ExecuteHandler>
    void execute(const Query & q, ExecuteHandler&& handler) const {
        base->execute(q, std::forward<ExecuteHandler>(handler));
    }

    /// Starts the connect-and-BEGIN retry loop; @p action is invoked once
    /// with the final error and connection.
    /// @param timeout forwarded to query::Transaction::begin (zero = none).
    template <typename ExecuteHandler>
    void begin(DatabaseGenerator db, ExecuteHandler&& action, Milliseconds timeout) const {
        // Instantiate with the decayed type so the context owns its own copy
        // of the handler; a reference type here would make Context move out
        // of the caller's lvalue handler.
        using Handler = std::decay_t<ExecuteHandler>;
        auto ctx = std::make_shared<Context<Handler>>(rule, std::move(db), base,
                std::forward<ExecuteHandler>(action), timeout);
        Beginner<Handler>(std::move(ctx)).run();
    }

    /// Forwarded to query::Transaction::commit().
    template <typename ExecuteHandler>
    void commit(ExecuteHandler&& action) const { base->commit(std::forward<ExecuteHandler>(action)); }
    /// Forwarded to query::Transaction::rollback().
    template <typename ExecuteHandler>
    void rollback(ExecuteHandler&& action) const { base->rollback(std::forward<ExecuteHandler>(action)); }

    bool inProgress() const noexcept { return base->inProgress(); }
    bool timedOut() const noexcept { return base->timedOut(); }
private:
    using TransactionPtr = query::TransactionPtr;
    TransactionPtr base;
    Rule rule;

    // State shared by every resumption of one Beginner run.
    template <typename ExecuteHandler>
    struct Context {
        Rule rule;
        DatabaseGenerator db;
        TransactionPtr trans;
        std::decay_t<ExecuteHandler> action;
        Milliseconds timeout;
        // Error of the previous attempt, passed to Rule::recoverable().
        error_code lastError;
        // Handler taken by value: copied from lvalues, moved from rvalues,
        // never moved out of a caller-owned object.
        Context( Rule rule, DatabaseGenerator db, TransactionPtr trans,
                std::decay_t<ExecuteHandler> action, Milliseconds timeout )
        : rule(std::move(rule)), db(std::move(db)), trans(std::move(trans)),
          action(std::move(action)), timeout(timeout) {}
    };

    // Stackless coroutine: copies of *this are passed as completion handlers
    // and resumed with (error, connection).
    template <typename ExecuteHandler>
    struct Beginner : boost::asio::coroutine {
        using CtxPtr = std::shared_ptr<Context<ExecuteHandler>>;

        CtxPtr ctx;
        explicit Beginner(CtxPtr ctx) : ctx(std::move(ctx)) {}

        void run() { (*this)(database::error_code()); }

        // Each pass asks the rule for the next database and a connection;
        // on success it sends BEGIN through the wrapped transaction. The
        // loop repeats while the attempt failed and the rule says the error
        // is recoverable; finally the user action receives the outcome.
        void operator() (database::error_code e, ConnectionPtr conn = ConnectionPtr() ) {
            reenter( this ) {
                do {
                    ctx->lastError = e;
                    yield ctx->rule.nextTry(ctx->db)->withinConnection(*this);
                    if(!e) {
                        yield ctx->trans->begin(conn, *this, ctx->timeout);
                    }
                } while (e && ctx->rule.recoverable(e, ctx->lastError));
                invokeWithConn(ctx->action, std::move(e), std::move(conn));
            }
        }
    };

};

} // namespace fallback

} //namespace query
} //namespace pgg

#include <boost/asio/unyield.hpp>
