#include "tools.h"

#include "loader.h"
#include "pubsub.h"

#include <maps/libs/concurrent/include/threadpool.h>
#include <maps/libs/common/include/make_batches.h>
#include <yandex/maps/wiki/revision/revisionsgateway.h>
#include <yandex/maps/wiki/common/string_utils.h>
#include <maps/libs/log8/include/log8.h>

#include <boost/noncopyable.hpp>

#include <atomic>
#include <csignal>
#include <map>
#include <numeric>
#include <sstream>
#include <fstream>

namespace maps {
namespace wiki {
namespace contours {
namespace {

/**
 * Number of object ids grouped into one unit of work: one task submitted by
 * write() and one delete transaction issued by removeDifference().
 *
 * Postgres hangs on batch size:
 * 8000 - production commits 2400-2500
 * 2000 - production commits 6000-7000
 * 1000 - production commits 19000-20000
 */
constexpr size_t OBJECTS_BATCH = 1;
/**
 * Number of items returned by dumpToSql() that are concatenated into a single
 * exec() call when a batch is written.
 */
constexpr size_t OBJECTS_DATA_BATCH = 100;

/**
 * Returns the value of RevisionsGateway::headCommitId().
 *
 * The gateway is created on top of a read-only transaction taken from the
 * master of the given pool; the transaction lives until the function returns.
 */
revision_meta::TCommitId
headCommitId(common::PoolHolder& pool)
{
    auto txn = pool.pool().masterReadOnlyTransaction();
    return revision::RevisionsGateway(*txn).headCommitId();
}

/**
 * Thread pool that keeps exceptions reported by its error handler.
 *
 * The handler installed in the constructor stores the current exception in a
 * queue; checkExceptions() rethrows one stored exception, if there is any, on
 * the calling thread.
 */
class ThreadpoolWrapper : boost::noncopyable {
public:
    /// Creates a pool whose threads number and queue capacity both equal threadsNumber.
    explicit ThreadpoolWrapper(size_t threadsNumber)
        : threads_{
            concurrent::ThreadsNumber(threadsNumber),
            concurrent::QueueCapacity(threadsNumber)}
    {
        // NOTE(review): std::current_exception() is only meaningful while the
        // exception is still being handled; presumably the pool calls the
        // handler from inside its catch block - confirm.
        threads_.setErrorHandler([this](const std::exception& /*unused*/) {
            exceptions_.push(std::current_exception());
        });
    }

    /// Rethrows a stored exception; returns normally when nothing is stored.
    void checkExceptions()
    {
        auto stored = exceptions_.tryPop();
        if (!stored) {
            return;
        }
        std::rethrow_exception(*stored);
    }

    /// Gives access to the underlying pool.
    concurrent::ThreadPool& get() { return threads_; }

private:
    // Declared before threads_, hence constructed first and destroyed last:
    // the queue exists for the whole lifetime of the pool that pushes into it.
    concurrent::UnboundedBlockingQueue<std::exception_ptr> exceptions_;
    concurrent::ThreadPool threads_;
};

/**
 * \brief thread-safe read/write batch handler
 *
 * For every batch of object ids the handler loads the objects through the
 * Loader, converts them with dumpToSql() and executes the result inside one
 * writeable master transaction of the destination pool. Running totals are
 * kept in atomic counters and logged after each batch.
 *
 * NOTE(review): concurrent use also relies on Loader::load() and the pool
 * being safe to call from several threads - confirm.
 */
class BatchHandler : boost::noncopyable {
public:
    /**
     * \param from source of the objects
     * \param to pool the produced SQL is written to
     * Both are stored by reference and must outlive the handler.
     */
    BatchHandler(Loader& from, common::PoolHolder& to)
        : from_(from)
        , to_(to)
        , objectCount_(0)
        , polygonCount_(0)
        , invalidCount_(0)
        , deletedCount_(0)
    {}

    /**
     * Processes one batch, retrying on failure.
     *
     * Any std::exception (database errors included) triggers a retry; after
     * ATTEMPTS_LIMIT retries, i.e. ATTEMPTS_LIMIT + 1 attempts in total, the
     * last exception is rethrown to the caller.
     */
    void handle(const revision_meta::TObjectIdSet& batch)
    {
        constexpr size_t ATTEMPTS_LIMIT = 3;
        for (size_t attempt = 0; true; ++attempt) {
            try {
                handleImpl(batch);
                break;
            } catch (const std::exception& e) {
                if (attempt >= ATTEMPTS_LIMIT) {
                    throw;
                }
                WARN() << "retrying (" << e.what() << ")";
            }
        }
    }

private:
    /**
     * Single attempt: load, dump to SQL, execute in chunks of
     * OBJECTS_DATA_BATCH items, commit.
     *
     * The counters are updated only after a successful commit, so a failed
     * attempt that is retried by handle() is not counted twice.
     */
    void handleImpl(const revision_meta::TObjectIdSet& batch)
    {
        INFO() << "start objects: " << common::join(batch, ',');
        auto objs = from_.load(batch);
        auto txn = to_.pool().masterWriteableTransaction();
        auto& work = *txn;
        auto data = dumpToSql(objs, work);
        for (const auto& chunk: ::maps::common::makeBatches(data, OBJECTS_DATA_BATCH)) {
            // Glue the chunk into one statement string to reduce round trips.
            std::ostringstream os;
            for (const auto& str: chunk) {
                os << str;
            }
            work.exec(os.str());
        }
        work.commit();
        INFO() << "end objects: " << common::join(batch, ',');
        objectCount_ += objs.size();
        polygonCount_ += getPolygonCount(objs);
        invalidCount_ += getInvalidCount(objs);
        deletedCount_ += getDeletedCount(objs);
        // Each counter is read separately, so the logged line is not a
        // consistent snapshot when other threads are updating the totals.
        INFO()
            << "objects: " << objectCount_
            << ", polygons: " << polygonCount_
            << ", invalid: " << invalidCount_
            << ", deleted: " << deletedCount_;
    }

private:
    Loader& from_;
    common::PoolHolder& to_;
    std::atomic_size_t objectCount_;
    std::atomic_size_t polygonCount_;
    std::atomic_size_t invalidCount_;
    std::atomic_size_t deletedCount_;
};

/**
 * Chooses the number of worker threads from the pool limits.
 *
 * The result is the smaller of the slave connection limit of the revision
 * pool and the master connection limit of the view pool minus one connection
 * reserved for the pubsub lock.
 *
 * Fails (REQUIRE) when the view pool cannot provide more than the reserved
 * connection.
 */
size_t estimateWorkerThreadsNumber
    ( common::PoolHolder& revisionPool
    , common::PoolHolder& viewPool
    )
{
    constexpr size_t RESERVED_VIEW_CONNECTIONS = 1; // pubsub lock
    const auto viewMasterLimit = viewPool.pool().state().constants.masterMaxSize;
    REQUIRE
        ( viewMasterLimit > RESERVED_VIEW_CONNECTIONS
        , "write pool size must be greater than " << RESERVED_VIEW_CONNECTIONS
        );
    auto writers = viewMasterLimit - RESERVED_VIEW_CONNECTIONS;
    auto readers = revisionPool.pool().state().constants.slaveMaxSize;
    auto threadsNumber = std::min(writers, readers);
    DEBUG() << "threads number " << threadsNumber;
    return threadsNumber;
}

/**
 * Writes the given objects to the view pool using a pool of worker threads.
 *
 * Ids are grouped into chunks of OBJECTS_BATCH; every full chunk becomes a
 * task of the thread pool. A trailing incomplete chunk is processed on the
 * calling thread. Exceptions stored by the pool are rethrown before each
 * submission and once more after all tasks have been joined.
 */
void write
    ( Loader& loader
    , common::PoolHolder& viewPool
    , const revision_meta::TObjectIdSet& ids
    , size_t threadsNumber
    )
{
    BatchHandler handler(loader, viewPool);
    revision_meta::TObjectIdSet chunk;
    ThreadpoolWrapper workers {threadsNumber};

    for (const auto objectId: ids) {
        chunk.insert(objectId);
        if (chunk.size() != OBJECTS_BATCH) {
            continue;
        }
        workers.checkExceptions();
        // The task owns its ids; the handler is captured by reference and
        // outlives the pool because it is declared before it.
        workers.get().add([&handler, taskIds = std::move(chunk)] {
            handler.handle(taskIds);
        });
        // Bring the moved-from set back to a known empty state.
        chunk.clear();
    }

    const bool hasRemainder = !chunk.empty();
    if (hasRemainder) {
        handler.handle(chunk);
    }
    workers.get().join();
    workers.checkExceptions();
}

/**
 * Collects the first column of every row returned by the
 * sqlSelectObjectIds() query, executed on the view pool.
 */
revision_meta::TObjectIdSet getObjectIds(common::PoolHolder& viewPool)
{
    /**
     * We use master for data integrity.
     * Performance doesn't suffer because it is a rare function call.
     */
    auto txn = viewPool.pool().masterReadOnlyTransaction();
    const auto rows = txn->exec(sqlSelectObjectIds());

    revision_meta::TObjectIdSet objectIds;
    for (const auto& row: rows) {
        objectIds.insert(row[0].as<revision_meta::TObjectId>());
    }
    return objectIds;
}

/**
 * Returns the ids that are present in the view pool but absent from ids.
 */
revision_meta::TObjectIdSet getDifference
    ( common::PoolHolder& viewPool
    , const revision_meta::TObjectIdSet& ids
    )
{
    const auto storedIds = getObjectIds(viewPool);
    revision_meta::TObjectIdSet onlyStored;
    // std::set_difference needs both ranges sorted the same way.
    // NOTE(review): assumes TObjectIdSet is an ordered set - confirm.
    std::set_difference
        ( storedIds.begin(), storedIds.end()
        , ids.begin(), ids.end()
        , std::inserter(onlyStored, onlyStored.end())
        );
    return onlyStored;
}

/**
 * Executes the sqlDelete(ids) statement in its own writeable master
 * transaction of the view pool and commits it.
 */
void remove
    ( common::PoolHolder& viewPool
    , const revision_meta::TObjectIdSet& ids
    )
{
    auto txn = viewPool.pool().masterWriteableTransaction();
    txn->exec(sqlDelete(ids));
    txn->commit();
}

} // namespace

/**
 * Deletes from the view pool every object whose id is not listed in ids.
 *
 * The ids to delete are processed in chunks of OBJECTS_BATCH, each chunk in a
 * separate transaction.
 */
void removeDifference
    ( common::PoolHolder& viewPool
    , const revision_meta::TObjectIdSet& ids
    )
{
    const auto obsoleteIds = getDifference(viewPool, ids);

    revision_meta::TObjectIdSet chunk;
    const auto flush = [&viewPool, &chunk] {
        remove(viewPool, chunk);
        chunk.clear();
    };

    for (const auto objectId: obsoleteIds) {
        chunk.insert(objectId);
        if (chunk.size() == OBJECTS_BATCH) {
            flush();
        }
    }
    // Whatever did not fill a whole chunk.
    if (!chunk.empty()) {
        flush();
    }
}

/**
 * Consumes the next batch of commits and brings the view pool up to date.
 *
 * When the consumed range has no previous commit (its first element is
 * INVALID_COMMIT_ID) all objects are written and everything else is removed
 * from the view pool; otherwise only the objects affected since the previous
 * commit are written. The consumer transaction is committed last, after all
 * data has been written.
 *
 * \return true when a batch was processed; false when there was nothing to
 *         consume or pubsub::AlreadyLockedException was thrown.
 */
bool handleNextCommits
    ( common::PoolHolder& revisionPool
    , common::PoolHolder& viewPool
    , const Params& params
    ) try
{
    // Kept open for the whole run; estimateWorkerThreadsNumber() reserves one
    // view connection for it.
    auto consumerTxn = viewPool.pool().masterWriteableTransaction();
    PubsubWrapper consumer(*consumerTxn, params.commitsBatchSize);
    auto commitRange = consumer.consumeBatch(*revisionPool.pool().slaveTransaction());
    const bool nothingToDo = (commitRange.second == INVALID_COMMIT_ID);
    if (nothingToDo) {
        return false;
    }

    Loader loader(revisionPool, commitRange.second, params);
    const bool isInitialRun = (commitRange.first == INVALID_COMMIT_ID);
    revision_meta::TObjectIdSet ids;
    if (!isInitialRun) {
        ids = loader.affectedObjectIds(commitRange.first);
    } else {
        ids = loader.allObjectIds();
    }
    INFO() << "objects " << ids.size();

    const auto workers = estimateWorkerThreadsNumber(revisionPool, viewPool);
    write(loader, viewPool, ids, workers);

    if (isInitialRun) {
        removeDifference(viewPool, ids);
    }

    consumerTxn->commit();
    INFO() << "commit " << commitRange.second << " is done";
    return true;
} catch (const pubsub::AlreadyLockedException&) {
    return false;
}

namespace {
/**
 * Reads whitespace-separated object ids from a text file.
 *
 * \param idsFile path of the file to read
 * \return the set of ids found in the file (empty for an empty file)
 *
 * Fails (REQUIRE) when the file cannot be opened or when reading stops before
 * the end of the file, e.g. on a token that is not a valid id. Without the
 * second check everything after the first malformed token would be dropped
 * silently.
 */
revision_meta::TObjectIdSet
parseIdsFile(const std::string& idsFile)
{
    std::ifstream file(idsFile);
    REQUIRE(file, "Failed to open file: " << idsFile);

    revision_meta::TObjectIdSet ids;
    for (revision_meta::TObjectId id; file >> id; ) {
        ids.insert(id);
    }
    // Extraction ends either at end of file (eofbit is set) or on the first
    // unparsable token / read error (eofbit is not set).
    REQUIRE(file.eof(), "Failed to parse object id in file: " << idsFile);
    return ids;
}
} // namespace

/**
 * Writes the objects whose ids are listed in params.idsFile to the view pool.
 *
 * The objects are loaded at the head commit of the revision pool. Nothing is
 * done when the file contains no ids.
 */
void handleIdsFromFile
    ( common::PoolHolder& revisionPool
    , common::PoolHolder& viewPool
    , const Params& params
    )
{
    const revision_meta::TObjectIdSet requestedIds = parseIdsFile(params.idsFile);
    if (requestedIds.empty()) {
        return;
    }

    auto headCommit = headCommitId(revisionPool);
    INFO() << "Head commit id: " << headCommit;
    INFO() << "Object ids count: " << requestedIds.size();

    Loader loader(revisionPool, headCommit, params);
    const auto workers = estimateWorkerThreadsNumber(revisionPool, viewPool);
    write(loader, viewPool, requestedIds, workers);

    INFO() << "File " << params.idsFile << " done.";
}

/**
 * Resets the pubsub consumer and commits the change.
 *
 * The whole sequence is retried, without any delay, for as long as
 * pubsub::AlreadyLockedException is thrown; other exceptions propagate.
 */
void resetWatermark(common::PoolHolder& viewPool, const Params& params)
{
    for (;;) {
        try {
            auto txn = viewPool.pool().masterWriteableTransaction();
            PubsubWrapper consumer(*txn, params.commitsBatchSize);
            consumer.reset();
            txn->commit();
            return;
        } catch (const pubsub::AlreadyLockedException&) {
            // Locked by someone else: start over with a fresh transaction.
        }
    }
}

/**
 * Returns the process-wide listener.
 *
 * The instance is a function-local static, so it is constructed (and the
 * signal handlers are installed by its constructor) on the first call.
 */
SignalListener& SignalListener::singleton()
{
    static SignalListener instance;
    return instance;
}

/// Reads the stop flag under the listener mutex.
bool SignalListener::isStopped()
{
    std::lock_guard<std::mutex> guard(mtx_);
    const bool stopped = isStopped_;
    return stopped;
}

/// Starts in the "not stopped" state and routes SIGINT, SIGQUIT and SIGTERM to signalHandler().
SignalListener::SignalListener()
    : isStopped_ {false}
{
    for (const int sig: {SIGINT, SIGQUIT, SIGTERM}) {
        signal(sig, signalHandler);
    }
}

/// Makes the process ignore SIGINT, SIGQUIT and SIGTERM from now on.
void SignalListener::detach()
{
    for (const int sig: {SIGINT, SIGQUIT, SIGTERM}) {
        signal(sig, SIG_IGN);
    }
}

/**
 * Switches the listener to the stopped state.
 *
 * Further SIGINT/SIGQUIT/SIGTERM are ignored first; then the flag is raised
 * under the mutex and one waiter on the condition variable is notified.
 * Calling stop() again after the flag has been raised notifies nobody.
 */
void SignalListener::stop()
{
    detach();

    std::unique_lock<std::mutex> lock(mtx_);
    if (isStopped_) {
        return;
    }
    isStopped_ = true;
    // Notify after releasing the mutex.
    lock.unlock();

    condVar_.notify_one();
}

/**
 * Signal handler installed by the constructor.
 *
 * SIGINT, SIGQUIT and SIGTERM are logged and stop the singleton; any other
 * signal is only reported as unhandled.
 *
 * NOTE(review): stop() locks a mutex and notifies a condition variable, and
 * the handler writes to the log; none of these are on the POSIX list of
 * async-signal-safe functions, so running them from a signal handler is
 * formally unsafe - consider a flag polled by a regular thread instead.
 */
void SignalListener::signalHandler(int sig)
{
    const bool isShutdownSignal
        = sig == SIGINT
        || sig == SIGQUIT
        || sig == SIGTERM;

    if (!isShutdownSignal) {
        ERROR() << "unhandled signal (" << sig << ") " << strsignal(sig);
        return;
    }

    INFO() << "got " << strsignal(sig) << ", starting shutdown sequence...";
    singleton().stop();
}

} // contours
} // wiki
} // maps
