#pragma once

#include "group_settings.h"
#include "task.h"

#include <mail/ratesrv/src/common/post_wrapper.h>
#include <mail/ratesrv/src/common/format.h>

#include <yplatform/reactor.h>
#include <yplatform/log.h>
#include <yplatform/time_traits.h>

#include <util/random/random.h>
#include <util/system/types.h>

#include <boost/asio.hpp>

#include <atomic>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <unordered_map>
#include <vector>

namespace NRateSrv::NScheduler {

// State of a single task group. It is shared (via TGroupPtr) between the
// scheduler's group table and every handler posted to the group's strand.
struct TGroup {
    TGroup(boost::asio::io_context& io, TGroupSettings settings)
        : Strand(io)
        , Timer(io)
        , Settings(std::move(settings))
    {}

    // Set by TSchedulerImpl::StopGroup; checked before every task run so a
    // removed group stops executing tasks.
    std::atomic_bool Removed{false};
    // Serializes the scheduler's handlers for this group.
    boost::asio::io_context::strand Strand;
    // Periodic trigger armed by ProcessGroup; cancelled on the strand by StopGroup.
    yplatform::time_traits::timer Timer;
    TGroupSettings Settings;
    // Guards NewTasks, and is held whenever Tasks is mutated (merge/erase run
    // on Strand while holding it).
    std::mutex Lock;
    // Tasks visible to RunTasks, keyed by task id.
    std::unordered_map<ui64, ITaskPtr> Tasks;
    // Tasks accepted by AddTasks but not yet merged into Tasks by the strand.
    std::unordered_map<ui64, ITaskPtr> NewTasks;
};

using TGroupPtr = std::shared_ptr<TGroup>;

class TSchedulerImpl : public yplatform::log::contains_logger {
public:
    explicit TSchedulerImpl(yplatform::reactor& reactor);

    void Stop();

    ui64 CreateGroup(TGroupSettings settings);
    void RemoveGroup(ui64 groupId);

    ui64 AddTask(ITaskPtr&& task, ui64 groupId);
    std::vector<ui64> AddTasks(std::vector<ITaskPtr>&& tasks, ui64 groupId);
    void RemoveTask(ui64 taskId, ui64 groupId);

    void ProcessGroup(TGroupPtr group, ui64 groupId, std::vector<ui64> newTasks);
    void RunTasks(TGroupPtr group, ui64 groupId, std::vector<ui64> tasks);
    bool RunTask(TGroup* group, ui64 groupId, ITask* task, ui64 taskId);

private:
    TGroupPtr GetGroup(ui64 groupId);
    void StopGroup(TGroupPtr group);

private:
    yplatform::reactor& Reactor;
    std::mutex Lock;
    std::unordered_map<ui64, TGroupPtr> Groups;
};

TSchedulerImpl::TSchedulerImpl(yplatform::reactor& reactor)
    : Reactor(reactor)
{}

// Stops every registered group and then empties the group table.
// The scheduler lock is held for the whole operation.
void TSchedulerImpl::Stop() {
    std::lock_guard guard(Lock);
    for (auto& entry : Groups) {
        // Hand ownership to StopGroup; the slot is cleared right below anyway.
        StopGroup(std::move(entry.second));
    }
    Groups.clear();
}

// Decides how a group is driven after a change. In this file it is called from
// handlers already running on group->Strand: the AddTasks merge handler and the
// end of a RunTasks run-all pass.
//
// - Settings.Duration is zero, or new tasks arrived under the RunAll policy:
//   post an immediate run of every task.
// - New tasks arrived under the RunOne policy: post an immediate run of just
//   those tasks.
// - Nothing new, or the new tasks are the group's only tasks: (re)arm the timer
//   so every task runs after Settings.Duration.
// - Otherwise nothing is scheduled here (presumably a timer armed earlier
//   covers the group — NOTE(review): confirm for other policies).
//
// NOTE(review): with a zero Duration every call posts a new RunTasks, and
// RunTasks calls back into ProcessGroup, so each merge may start an additional
// run loop on the strand — confirm this is intended.
void TSchedulerImpl::ProcessGroup(TGroupPtr group, ui64 groupId, std::vector<ui64> newTasks) {
    const bool hasNewTasks = !newTasks.empty();

    if ((hasNewTasks && group->Settings.Policy == EExecutionPolicyWhenTaskAdding::RunAll) ||
        group->Settings.Duration == TDuration::zero())
    {
        // `mutable` so the captured copies are really moved out; in a
        // non-mutable lambda the captures are const and std::move copies them.
        boost::asio::post(group->Strand, [this, group, groupId]() mutable {
            RunTasks(std::move(group), groupId, {});
        });
    } else if (hasNewTasks && group->Settings.Policy == EExecutionPolicyWhenTaskAdding::RunOne) {
        boost::asio::post(group->Strand, [this, group, groupId, tasks = std::move(newTasks)]() mutable {
            RunTasks(std::move(group), groupId, std::move(tasks));
        });
    } else if (!hasNewTasks || group->Tasks.size() == newTasks.size()) {
        // expires_after also aborts a previously pending wait, so only the most
        // recent arming ends up calling RunTasks.
        group->Timer.expires_after(group->Settings.Duration);
        group->Timer.async_wait(MakePostWrapper(
            group->Strand,
            [this, group, groupId](const boost::system::error_code& ec) {
                if (ec == boost::asio::error::operation_aborted) {
                    return;
                }
                // Plain copy: this handler keeps a const call operator because
                // it is stored by MakePostWrapper, so a move would copy anyway.
                RunTasks(group, groupId, {});
            }
        ));
    }
}

// Runs tasks of a group; invoked from handlers posted to the group's strand.
// An empty `tasks`, or one as large as the whole group, means "run everything":
// every task runs and, unless the group was removed or became empty, the group
// is handed back to ProcessGroup for rescheduling. Otherwise only the listed
// ids that are still present run, and nothing is rescheduled. Iteration stops
// as soon as RunTask reports the group was removed.
void TSchedulerImpl::RunTasks(TGroupPtr group, ui64 groupId, std::vector<ui64> tasks) {
    const bool runEverything = tasks.empty() || tasks.size() == group->Tasks.size();

    if (!runEverything) {
        YLOG_L(info) << Format("Run %1% task(s) in group %2%", tasks.size(), groupId);

        for (const ui64 taskId : tasks) {
            const auto found = group->Tasks.find(taskId);
            if (found == group->Tasks.end()) {
                // Removed since it was scheduled.
                continue;
            }
            if (!RunTask(group.get(), groupId, found->second.get(), taskId)) {
                break;
            }
        }
        return;
    }

    YLOG_L(info) << Format("Run all tasks in group %1%", groupId);

    for (const auto& [taskId, task] : group->Tasks) {
        if (!RunTask(group.get(), groupId, task.get(), taskId)) {
            break;
        }
    }

    const bool removed = group->Removed.load(std::memory_order_acquire);
    if (!removed && !group->Tasks.empty()) {
        ProcessGroup(std::move(group), groupId, {});
    }
}

// Runs one task of a group.
// Returns false without running anything when the group has been removed, so
// callers can stop iterating; returns true otherwise. An exception thrown by
// the task is passed to Settings.TaskErrorHandler when one is set, and logged
// otherwise; it never propagates to the caller.
bool TSchedulerImpl::RunTask(TGroup* group, ui64 groupId, ITask* task, ui64 taskId) {
    const bool removed = group->Removed.load(std::memory_order_acquire);
    if (removed) {
        return false;
    }

    try {
        task->Run();
    } catch (...) {
        auto& onError = group->Settings.TaskErrorHandler;
        if (!onError) {
            YLOG_L(error) << Format("Unhandled exception in task %1% of group %2%", taskId, groupId);
        } else {
            onError(std::current_exception());
        }
    }
    return true;
}

// Returns a shared reference to the group with the given id, looked up under
// the scheduler lock. Throws std::invalid_argument for an unknown id.
TGroupPtr TSchedulerImpl::GetGroup(ui64 groupId) {
    std::lock_guard guard(Lock);
    if (const auto found = Groups.find(groupId); found != Groups.end()) {
        return found->second;
    }
    throw std::invalid_argument(Format("Unknown group with id %1%", groupId));
}

// Creates a group bound to the reactor's io_context and returns its id.
// The id is a random ui64 that is redrawn until it does not clash with an
// existing group.
ui64 TSchedulerImpl::CreateGroup(TGroupSettings settings) {
    std::lock_guard guard(Lock);

    ui64 groupId = RandomNumber<ui64>();
    while (Groups.count(groupId) > 0) {
        groupId = RandomNumber<ui64>();
    }

    auto group = std::make_shared<TGroup>(*Reactor.io(), std::move(settings));
    Groups.emplace(groupId, std::move(group));
    return groupId;
}

// Stops the group with the given id and removes it from the table.
// Throws std::invalid_argument for an unknown id.
void TSchedulerImpl::RemoveGroup(ui64 groupId) {
    std::lock_guard guard(Lock);
    const auto found = Groups.find(groupId);
    if (found != Groups.end()) {
        StopGroup(std::move(found->second));
        Groups.erase(found);
        return;
    }
    throw std::invalid_argument(Format("Unknown group with id %1%", groupId));
}

// Shuts a group down:
// - marks it removed, so RunTask refuses to run any further task;
// - cancels its timer on the strand, where every other timer operation happens;
// - calls Cancel() on every task it owns, including tasks still waiting in
//   NewTasks. Those will be merged later but never run, because Removed is set.
//
// Tasks and NewTasks are mutated by strand handlers while holding group->Lock
// (merge in AddTasks, erase in RemoveTask), so they are walked under that lock
// here to avoid a data race. Lock order is scheduler Lock -> group Lock, which
// matches AddTasks/RemoveTask; no handler in this file takes the scheduler Lock
// while holding a group Lock.
void TSchedulerImpl::StopGroup(TGroupPtr group) {
    group->Removed.store(true, std::memory_order_release);

    boost::asio::post(group->Strand, [group]() {
        group->Timer.cancel();
    });

    std::lock_guard guard(group->Lock);
    for (auto& entry : group->Tasks) {
        entry.second->Cancel();
    }
    for (auto& entry : group->NewTasks) {
        entry.second->Cancel();
    }
}

// Adds a single task to a group and returns its id.
// This is a one-element AddTasks, so it has the same validation and errors.
ui64 TSchedulerImpl::AddTask(ITaskPtr&& task, ui64 groupId) {
    std::vector<ITaskPtr> batch;
    batch.emplace_back(std::move(task));
    const std::vector<ui64> ids = AddTasks(std::move(batch), groupId);
    return ids.front();
}

std::vector<ui64> TSchedulerImpl::AddTasks(std::vector<ITaskPtr>&& tasks, ui64 groupId) {
    if (tasks.empty()) {
        return {};
    }

    std::vector<ui64> taskIds(tasks.size());
    auto group = GetGroup(groupId);
    std::lock_guard guard(group->Lock);

    for (size_t i = 0; i < tasks.size(); ++i) {
        if (!tasks[i]) {
            std::invalid_argument("Task is undefined");
        }

        ui64& taskId = taskIds[i];
        do {
            taskId = RandomNumber<ui64>();
        } while (group->Tasks.count(taskId) > 0 && group->NewTasks.count(taskId) > 0);
    }

    for (size_t i = 0; i < tasks.size(); ++i) {
        group->NewTasks.emplace(taskIds[i], std::move(tasks[i]));
    }

    if (group->NewTasks.size() == tasks.size()) {
        boost::asio::post(group->Strand, [this, group, groupId]() {
            std::lock_guard guard(group->Lock);

            std::vector<ui64> taskIds;
            taskIds.reserve(group->NewTasks.size());
            for (const auto& [id, task] : group->NewTasks) {
                taskIds.push_back(id);
            }

            group->Tasks.merge(group->NewTasks);
            ProcessGroup(std::move(group), groupId, std::move(taskIds));
        });
    }

    return taskIds;
}

// Cancels a task and schedules its removal from the group.
//
// The task is looked up in group->Tasks and also in group->NewTasks. The second
// lookup matters because an id just returned by AddTask(s) stays in NewTasks
// until the strand merges it, and without it such an id would be rejected.
// Cancel() is called immediately, under group->Lock. The erase from Tasks is
// posted to the strand. If the task is still pending, the merge handler was
// posted earlier under this same lock, so the erase runs after the merge.
//
// Throws std::invalid_argument for an unknown group (via GetGroup) or task.
void TSchedulerImpl::RemoveTask(ui64 taskId, ui64 groupId) {
    auto group = GetGroup(groupId);

    std::lock_guard guard(group->Lock);

    ITask* task = nullptr;
    if (auto active = group->Tasks.find(taskId); active != group->Tasks.end()) {
        task = active->second.get();
    } else if (auto pending = group->NewTasks.find(taskId); pending != group->NewTasks.end()) {
        task = pending->second.get();
    } else {
        throw std::invalid_argument(Format("Unknown task with id %1% in group %2%", taskId, groupId));
    }

    task->Cancel();

    boost::asio::post(group->Strand, [taskId, group]() {
        std::lock_guard guard(group->Lock);
        group->Tasks.erase(taskId);
    });
}

} // namespace NRateSrv::NScheduler
