#include <util/generic/string.h>
#include <util/stream/str.h>
#include <util/generic/deque.h>
#include <util/string/join.h>
#include <util/generic/maybe.h>
#include <util/string/split.h>
#include <util/random/shuffle.h>
#include <vector>
#include <string>
#include <util/system/yield.h>
#include <library/cpp/coroutine/engine/impl.h>
#include <library/cpp/coroutine/engine/network.h>

#include "Error.h"
#include "Work.h"
#include "Connection.h"

namespace sql {
    /// Parses a space-separated `key=value` connection string.
    ///
    /// Recognised keys: dbname, user, password, host, port. Unknown keys are ignored.
    /// Each token is split on its FIRST '=', so values may themselves contain '='
    /// (libpq reads an unquoted value up to the next whitespace).
    /// Quoted values and whitespace around '=' are not supported.
    ///
    /// @throws yexception if a non-empty token has no '=' or if `port` is not a valid ui32.
    TConnectionString TConnectionString::Parse(const TString & s) {
        TVector<TStringBuf> parts;

        Split(s, " ", parts);

        TConnectionString connectionString;
        for (const TStringBuf part : parts) {
            if (part.empty())
                continue;

            TStringBuf key;
            TStringBuf value;
            if (!part.TrySplit('=', key, value)) {
                ythrow TWithBackTrace<yexception>() << "malformed connection string token (expected key=value): " << part;
            }

            if ("dbname" == key) {
                connectionString.SetDb(TString{value});
            } else if ("user" == key) {
                connectionString.SetUser(TString{value});
            } else if ("password" == key) {
                connectionString.SetPass(TString{value});
            } else if ("host" == key) {
                connectionString.SetHost(TString{value});
            } else if ("port" == key) {
                connectionString.SetPort(FromString<ui32>(value));
            }
        }

        return connectionString;
    }

    // Renders one libpq conninfo value. Plain values are returned unchanged.
    // Values that are empty or contain whitespace, a single quote or a backslash are
    // wrapped in single quotes, with ' and \ backslash-escaped. Unquoted, libpq
    // would skip the blanks after '=' (an empty value then swallows the next
    // keyword) and would also interpret backslashes.
    static TString QuoteConnInfoValue(const TString & value) {
        bool needsQuoting = value.empty();
        for (const char c : value) {
            if (c == ' ' || c == '\t' || c == '\n' || c == '\r' || c == '\f' || c == '\v' || c == '\'' || c == '\\') {
                needsQuoting = true;
                break;
            }
        }
        if (!needsQuoting)
            return value;

        TString quoted;
        quoted.reserve(value.size() + 2);
        quoted.push_back('\'');
        for (const char c : value) {
            if (c == '\'' || c == '\\')
                quoted.push_back('\\');
            quoted.push_back(c);
        }
        quoted.push_back('\'');
        return quoted;
    }

    /// Builds a libpq `key=value` conninfo string from the fields that are set,
    /// in the order dbname, host, user, password, port, separated by single spaces.
    /// The password is included only when `safe == ESafe::No`.
    /// String values are quoted/escaped as needed (see QuoteConnInfoValue).
    TString TConnectionString::MakeRawConnectionString(ESafe safe) const {
        TVector<TString> parts;

        if (db)
            parts.emplace_back("dbname=" + QuoteConnInfoValue(*db));
        if (host)
            parts.emplace_back("host=" + QuoteConnInfoValue(*host));
        if (user)
            parts.emplace_back("user=" + QuoteConnInfoValue(*user));
        if (pass && safe == ESafe::No)
            parts.emplace_back("password=" + QuoteConnInfoValue(*pass));
        if (port)
            parts.emplace_back("port=" + ToString(*port));

        return JoinSeq(" ", parts);
    }

    void TConnectionsTRaits::UpdateTraits() {
        traitsMutex.AcquireRead();
        auto traitsCopy = connectionsTraits;
        traitsMutex.ReleaseRead();
//        GetConnectionAsync

        for (auto &p : traitsCopy) {
            try {
                TVector<TString> ips = ares.Resolve(p.second.host);
                for(const TString & ip : ips) {
                    p.second.connectionString.SetHost(ip);
                    std::tie(p.second.ok, p.second.isReplica) = Check(p.second.connectionString.MakeRawConnectionString());

                    if(p.second.ok)
                        break;
                }
            } catch (...) {
                std::tie(p.second.ok, p.second.isReplica) = Check(p.second.connectionString.MakeRawConnectionString());
            }
        }

        TWriteGuard g(traitsMutex);
        connectionsTraits = std::move(traitsCopy);
    }

    /// Registers traits for a raw connection string, unless already cached.
    ///
    /// Entries are keyed by THash<TString> of the raw string.
    /// NOTE(review): distinct strings with colliding hashes would share an entry.
    ///
    /// The host is resolved, and each address is checked until one reports ok; that
    /// address is stored in the cached connection string, and the original host name
    /// is kept alongside it. If resolution throws or returns no addresses, the host
    /// name itself is checked. If no check succeeds, nothing is cached, so
    /// GetTraits() keeps throwing for this string.
    ///
    /// @throws yexception if the connection string has no `host`.
    void TConnectionsTRaits::AddAndCheck(const TString & rawConnectionString) {
        const auto key = THash<TString>()(rawConnectionString);
        {
            TReadGuard g(traitsMutex);
            if (Y_LIKELY(connectionsTraits.contains(key)))
                return;
        }

        TConnectionString connectionString = TConnectionString::Parse(rawConnectionString);

        if (!connectionString.GetHost().Defined()) {
            ythrow TWithBackTrace<yexception>() << "connection string doesnt contain host";
        }

        auto defaultHost = *connectionString.GetHost();

        TVector<TString> ips;
        try {
            ips = ares.Resolve(defaultHost);
        } catch (...) {
            // Resolution failed: fall back to the host name below.
        }
        if (ips.empty())
            ips.emplace_back(defaultHost);

        for (const TString & ip : ips) {
            connectionString.SetHost(ip);

            bool ok, isReplica;
            std::tie(ok, isReplica) = Check(connectionString.MakeRawConnectionString());

            if (ok) {
                TWriteGuard g(traitsMutex);
                // emplace() keeps an entry another thread may have inserted meanwhile.
                connectionsTraits.emplace(key, TConnectionTRaits{connectionString, std::move(defaultHost), ok, isReplica});
                return;
            }
        }
    }

    /// Returns a copy of the traits cached for a raw connection string.
    /// The lookup key is THash<TString> of the string, as in AddAndCheck().
    /// @throws yexception ("traits not exists") if no entry is cached.
    TConnectionTRaits TConnectionsTRaits::GetTraits(const TString & connectionString) const {
        const size_t lookupKey = THash<TString>()(connectionString);

        TReadGuard guard(traitsMutex);
        const auto found = connectionsTraits.find(lookupKey);
        if (Y_UNLIKELY(found == connectionsTraits.cend()))
            ythrow TWithBackTrace<yexception>() << "traits not exists";

        return found->second;
    }

    /// Returns a snapshot copy of every cached entry, taken under the read lock.
    THashMap<size_t, TConnectionTRaits> TConnectionsTRaits::GetTraits() const {
        TReadGuard guard(traitsMutex);
        THashMap<size_t, TConnectionTRaits> snapshot = connectionsTraits;
        return snapshot;
    }

    /// Creates an empty traits cache with its DNS resolver.
    /// NOTE(review): the meaning of the resolver arguments (10 s, 6) is defined by the
    /// type of `ares`, which is not visible here. They are presumably a timeout and
    /// a retry/concurrency limit; confirm.
    TConnectionsTRaits::TConnectionsTRaits()
        : ares(TDuration::Seconds(10), 6)
    {
    }

    std::tuple<bool, bool> TConnectionsTRaits::Check(const TString & connectionString) {
        auto connectionHolder = sql::TPgConnectionHolder::PqConnectDb(connectionString);

        if (!connectionHolder.Ok()) {
            Cdbg << "Check: " << connectionString << "; res: " << connectionHolder.Info() << Endl;
            return {false, false};
        }

        try {
            return {true, connectionHolder.IsReplica()};
        } catch(const std::exception &e) {
            Cdbg << "Check: " << connectionString << "; exception: " << e.what() << Endl;
            return {false, false};
        }
    }

    /// Sends `query` asynchronously without waiting for a result.
    /// In binary mode, parameters go through PqSendQueryParams and the result format
    /// is binary (1). Otherwise PqSendQuery sends the bare query text.
    /// NOTE(review): in non-binary mode query.params are not sent.
    /// @throws TInterfaceError(CannotAsyncExec) if the send call does not return 1.
    void Connection::SendQuery(const Query& query, bool binary) {
        int sendSuccess = 0;
        if (binary) {
            const bool hasParams = !query.params.empty();
            sendSuccess = conn.PqSendQueryParams(
                query.query.Data(),
                static_cast<int>(query.params.size()),
                hasParams ? &query.oids[0] : nullptr,
                hasParams ? &query.params[0] : nullptr,
                hasParams ? &query.lengths[0] : nullptr,
                hasParams ? &query.formats[0] : nullptr,
                1);
        } else {
            sendSuccess = conn.PqSendQuery(query.query.Data());
        }

        if (sendSuccess != 1)
            ythrow TInterfaceError(TInterfaceError::CannotAsyncExec) << conn;
    }

    /// Executes `query` synchronously and returns the result holder unchecked.
    /// Binary mode uses PqExecParams with binary result format (1). Otherwise PqExec
    /// runs the bare query text, and the parameters are not sent.
    TPgResultHolder Connection::ExecQuery(const Query& query, bool binary) {
        if (!binary)
            return conn.PqExec(query.query.Data());

        const bool hasParams = !query.params.empty();
        return conn.PqExecParams(
            query.query.Data(),
            static_cast<int>(query.params.size()),
            hasParams ? &query.oids[0] : nullptr,
            hasParams ? &query.params[0] : nullptr,
            hasParams ? &query.lengths[0] : nullptr,
            hasParams ? &query.formats[0] : nullptr,
            1);
    }

    /// Calls PqFlush once. Data may remain queued if the socket would block.
    /// @throws TInterfaceError(Flush) if PqFlush reports an error (< 0).
    void Connection::Flush() {
        const auto flushStatus = conn.PqFlush();
        if (flushStatus >= 0)
            return;
        ythrow TInterfaceError(TInterfaceError::Flush) << conn;
    }

    /// Flushes the outgoing libpq buffer, blocking the thread until `deadline`.
    ///
    /// While PqFlush reports pending data (> 0), the function waits for socket
    /// write-readiness with a TSocketPoller and flushes again.
    /// NOTE(review): libpq also recommends calling PQconsumeInput when the socket
    /// becomes read-ready during a flush; that case is not handled here.
    ///
    /// @throws TInterfaceError(Flush) if PqFlush fails, or if the deadline passes
    ///         with data still queued. The TCont overload throws the same error on
    ///         timeout; previously this overload returned silently.
    void Connection::Flush(TInstant deadline) {
        auto res = conn.PqFlush();
        if (res < 0)
            ythrow TInterfaceError(TInterfaceError::Flush) << conn;
        if (res == 0)
            return;

        TSocketPoller poller;
        poller.WaitWrite(conn.PqSocket(), (void*)42);
        void * cookie = nullptr;
        while ((cookie = poller.WaitD(deadline))) {
            if (cookie != (void*)42) {
                ThreadYield();
                continue;
            }
            res = conn.PqFlush();
            if (res == 0)
                return;
            if (res < 0)
                ythrow TInterfaceError(TInterfaceError::Flush) << conn;
        }

        // Deadline reached while data is still pending.
        ythrow TInterfaceError(TInterfaceError::Flush) << conn;
    }

    /// Flushes the outgoing libpq buffer from a coroutine, until `deadline`.
    /// While PqFlush reports pending data (> 0), the coroutine waits for write
    /// readiness with NCoro::PollD and flushes again.
    /// @throws TInterfaceError(Flush) if PqFlush fails or PollD returns non-zero
    ///         (e.g. the deadline passed).
    void Connection::Flush(TCont *cont, TInstant deadline) {
        auto pending = conn.PqFlush();
        for (;;) {
            if (pending < 0)
                ythrow TInterfaceError(TInterfaceError::Flush) << conn;
            if (pending == 0)
                return;
            if (NCoro::PollD(cont, conn.PqSocket(), CONT_POLL_WRITE, deadline) != 0)
                ythrow TInterfaceError(TInterfaceError::Flush) << conn;
            pending = conn.PqFlush();
        }
    }

    /// Blocks the thread until the pending query completes or `deadline` passes.
    ///
    /// On each read-readiness event, input is consumed and results are drained while
    /// libpq is not busy. When PqGetResult returns a null result, the last non-null
    /// result is returned. This is a default holder if there was none.
    /// @throws TInterfaceError(AsyncConsume) if PqConsumeInput fails.
    /// @throws TInterfaceError(ExecTimeout) if the deadline passes first.
    TPgResultHolder Connection::GetResult(TInstant deadline) {
        void* const socketCookie = reinterpret_cast<void*>(42);

        TSocketPoller poller;
        poller.WaitRead(conn.PqSocket(), socketCookie);

        TPgResultHolder lastRes;
        for (void* cookie = poller.WaitD(deadline); cookie; cookie = poller.WaitD(deadline)) {
            if (cookie != socketCookie) {
                ThreadYield();
                continue;
            }

            if (!conn.PqConsumeInput())
                ythrow TInterfaceError(TInterfaceError::AsyncConsume) << conn;

            while (conn.PqIsBusy() == 0) {
                TPgResultHolder next = conn.PqGetResult();
                if (!next)
                    return lastRes;
                lastRes = std::move(next);
            }
        }

        ythrow TInterfaceError(TInterfaceError::ExecTimeout);
    }

    /// Coroutine variant of GetResult: waits for read readiness with NCoro::PollD.
    /// Consumes input and drains results while libpq is not busy and `deadline` has
    /// not passed. Returns the last non-null result once PqGetResult yields null.
    /// @throws TInterfaceError(AsyncConsume) if PqConsumeInput fails.
    /// @throws TInterfaceError(ExecTimeout) once PollD returns non-zero (e.g. timeout).
    TPgResultHolder Connection::GetResult(TInstant deadline, TCont* cont) {
        TPgResultHolder lastRes;
        for (;;) {
            if (NCoro::PollD(cont, conn.PqSocket(), CONT_POLL_READ, deadline) != 0)
                break;

            if (!conn.PqConsumeInput())
                ythrow TInterfaceError(TInterfaceError::AsyncConsume) << conn;

            while (Now() < deadline && conn.PqIsBusy() == 0) {
                TPgResultHolder next = conn.PqGetResult();
                if (!next)
                    return lastRes;
                lastRes = std::move(next);
            }
        }

        ythrow TInterfaceError(TInterfaceError::ExecTimeout);
    }

    /// Stores the execution timeout. In this file it is only read back by getExecTimeout().
    void Connection::setExecTimeout(const TDuration& timeout)
    {
        execTimeout = timeout;
    }

    /// Returns the timeout last stored by setExecTimeout().
    TDuration Connection::getExecTimeout() const
    {
        return execTimeout;
    }

    /// Returns the underlying connection holder's Info() description.
    TString Connection::Info() const
    {
        return conn.Info();
    }

    /// Reports whether the underlying connection holder is Ok().
    bool Connection::ok() const
    {
        return conn.Ok();
    }

    /// Resets the current connection through the holder's PqReset().
    void Connection::reset()
    {
        conn.PqReset();
    }

    /// Reports whether the current server is a replica.
    /// With a traits cache, the cached value for conn.GetConnectionString() is used.
    /// Without one, the connection itself is asked. A TInterfaceError is returned as
    /// the error value.
    /// NOTE(review): GetTraits() throws TWithBackTrace<yexception> for unknown strings,
    /// which is caught here only if it derives from TInterfaceError; confirm.
    TResWithError<bool> Connection::isReplica() {
        try {
            if (!traits)
                return conn.IsReplica();
            return traits->GetTraits(conn.GetConnectionString()).isReplica;
        } catch (const TInterfaceError & e) {
            return e;
        }
    }

    /// Opens a connection to `host` without blocking the coroutine's thread.
    /// Drives PqConnectPoll, waiting on the socket with NCoro::PollD for whichever
    /// direction libpq asks for, and yielding while libpq reports InProgress.
    /// Returns a default (not Ok) holder on start failure, poll failure, or when
    /// PollD returns non-zero (e.g. the deadline passed).
    TPgConnectionHolder ConnectAsync(const TString & host, TInstant deadline, TCont * cont) {
        TPgConnectionHolder connection = TPgConnectionHolder::PqConnectStart(host);
        if (!connection.Ok())
            return {};

        for (;;) {
            const auto status = connection.PqConnectPoll();
            if (status == EPollingStatus::Ok)
                return connection;
            if (status == EPollingStatus::Fail)
                return {};

            if (status == EPollingStatus::Read || status == EPollingStatus::Write) {
                // The socket is re-queried every round: libpq may switch sockets.
                const int events = status == EPollingStatus::Read ? CONT_POLL_READ : CONT_POLL_WRITE;
                if (NCoro::PollD(cont, connection.PqSocket(), events, deadline) != 0)
                    return {};
            } else if (status == EPollingStatus::InProgress) {
                cont->Yield();
            }
        }
    }

    /// Connects (blocking, via PqConnectDb) to every host in random order and returns
    /// all connections that succeeded. Each failure is written to Cerr and appended
    /// to `errStream`. The deadline argument is currently ignored.
    TResWithError<TDeque<TPgConnectionHolder>> GetConnection(TVector<TString> hosts, TStringStream &errStream, TInstant /*deadline*/) {
        Shuffle(hosts.begin(), hosts.end());

        TDeque<TPgConnectionHolder> opened;
        for (const TString & host : hosts) {
            TPgConnectionHolder candidate = TPgConnectionHolder::PqConnectDb(host);
            if (candidate.Ok()) {
                opened.emplace_back(std::move(candidate));
                continue;
            }

            Cerr << "Cannot connect: " << candidate << Endl;
            errStream << candidate;
        }
        return ResWithError(std::move(opened));
    }

    /// Picks the candidate hosts for `strategy`.
    /// Shuffles `hosts` IN PLACE, then keeps the hosts whose cached traits satisfy
    /// Connection::goodMode<strategy>. A host is dropped if its traits lookup (or
    /// goodMode) throws a std::exception. For RandomReplica with no suitable host,
    /// it falls back to the Master selection, which shuffles again.
    template<sql::Connection::InitStrategy strategy>
    static TDeque<TString> GetHostsByStrategy(TVector<TString>& hosts, TConnectionsTRaits& traits) {
        Shuffle(hosts.begin(), hosts.end());

        TDeque<TString> suitable;
        for (const TString & host : hosts) {
            bool good = false;
            try {
                good = Connection::goodMode<strategy>(traits.GetTraits(host).isReplica);
            } catch (const std::exception&) {
                good = false;
            }
            if (good)
                suitable.push_back(host);
        }

        if constexpr (strategy == Connection::RandomReplica) {
            if (suitable.empty())
                return GetHostsByStrategy<Connection::Master>(hosts, traits);
        }

        return suitable;
    }

    /// Blocking init: connects (PqConnectDb) to the first master candidate that
    /// succeeds, records lastROCheck, and switches the socket to non-blocking.
    /// The deadline argument is ignored.
    /// NOTE(review): `*traits` is dereferenced unconditionally, while isReplica()
    /// treats traits as optional; confirm traits is always set before Init.
    /// @throws yexception if no candidate connects.
    template <> TVoidWithError Connection::Init<Connection::Master>(TInstant /*deadline*/) {
        conn.Reset();

        const TDeque<TString> candidates = GetHostsByStrategy<Connection::Master>(hosts, *traits);
        for (const TString & host : candidates) {
            TPgConnectionHolder candidate = TPgConnectionHolder::PqConnectDb(host);
            if (candidate.Ok()) {
                conn = std::move(candidate);
                break;
            }
        }

        if (!conn.Ok())
            ythrow TWithBackTrace<yexception>() << "there is no valid connections: " << MakeRangeJoiner(",", hosts);

        lastROCheck = Now();
        return setNonBlockng(true);
    }

    /// Coroutine init: connects (ConnectAsync, bounded by `deadline`) to the first
    /// master candidate that succeeds, records lastROCheck, and switches the socket
    /// to non-blocking.
    /// NOTE(review): `*traits` is dereferenced unconditionally; confirm it is always set.
    /// @throws yexception if no candidate connects.
    template <> TVoidWithError Connection::Init<Connection::Master>(TCont * cont, TInstant deadline) {
        conn.Reset();

        const TDeque<TString> candidates = GetHostsByStrategy<Connection::Master>(hosts, *traits);
        for (const TString & host : candidates) {
            TPgConnectionHolder candidate = ConnectAsync(host, deadline, cont);
            if (candidate.Ok()) {
                conn = std::move(candidate);
                break;
            }
        }

        if (!conn.Ok())
            ythrow TWithBackTrace<yexception>() << "there is no valid connections: " << MakeRangeJoiner(",", hosts);

        lastROCheck = Now();
        return setNonBlockng(true);
    }

    /// Blocking init: connects (PqConnectDb) to the first replica candidate that
    /// succeeds. If there are no replica candidates, master candidates are used
    /// (see GetHostsByStrategy). Then it records lastROCheck and switches the socket
    /// to non-blocking. The deadline argument is ignored.
    /// NOTE(review): `*traits` is dereferenced unconditionally; confirm it is always set.
    /// @throws yexception if no candidate connects.
    template <> TVoidWithError Connection::Init<Connection::RandomReplica>(TInstant /*deadline*/) {
        conn.Reset();

        const TDeque<TString> candidates = GetHostsByStrategy<Connection::RandomReplica>(hosts, *traits);
        for (const TString & host : candidates) {
            TPgConnectionHolder candidate = TPgConnectionHolder::PqConnectDb(host);
            if (candidate.Ok()) {
                conn = std::move(candidate);
                break;
            }
        }

        if (!conn.Ok())
            ythrow TWithBackTrace<yexception>() << "there is no valid connections: " << MakeRangeJoiner(",", hosts);

        lastROCheck = Now();
        return setNonBlockng(true);
    }

    /// Coroutine init: connects (ConnectAsync, bounded by `deadline`) to the first
    /// replica candidate that succeeds. If there are no replica candidates, master
    /// candidates are used. Then it records lastROCheck and switches the socket to
    /// non-blocking.
    /// NOTE(review): `*traits` is dereferenced unconditionally; confirm it is always set.
    /// @throws yexception if no candidate connects.
    template <> TVoidWithError Connection::Init<Connection::RandomReplica>(TCont * cont, TInstant deadline) {
        conn.Reset();

        const TDeque<TString> candidates = GetHostsByStrategy<Connection::RandomReplica>(hosts, *traits);
        for (const TString & host : candidates) {
            TPgConnectionHolder candidate = ConnectAsync(host, deadline, cont);
            if (candidate.Ok()) {
                conn = std::move(candidate);
                break;
            }
        }

        if (!conn.Ok())
            ythrow TWithBackTrace<yexception>() << "there is no valid connections: " << MakeRangeJoiner(",", hosts);

        lastROCheck = Now();
        return setNonBlockng(true);
    }

    /// Switches the connection's non-blocking mode to `nonblocking`.
    /// The mode is only changed if it differs from the current one.
    /// Returns a TInterfaceError(SetNonblocking) value if PqSetnonblocking fails.
    TVoidWithError Connection::setNonBlockng(bool nonblocking) {
        const bool currentlyNonblocking = static_cast<bool>(conn.PqIsnonblocking());
        if (currentlyNonblocking == nonblocking)
            return TVoidWithError();

        if (conn.PqSetnonblocking(nonblocking ? 1 : 0) == -1)
            return TInterfaceError(TInterfaceError::SetNonblocking, conn.Info());

        return TVoidWithError();
    }

    /// Builds a connection over the given host connection strings, converting
    /// each std::string to TString.
    Connection::Connection(const std::vector<std::string>& _hosts)
            : hosts(_hosts.cbegin(), _hosts.cend())
    {
    }

    /// Builds a connection over the given host connection strings.
    Connection::Connection(const TVector<TString>& _hosts)
            : hosts(_hosts) {}
    /// Builds a connection from one string holding ','-separated host connection strings.
    Connection::Connection(const TString & connectionString)
    {
        Split(connectionString, ",", hosts);
    }
} /* namespace sql */
