#include "clock_tower_provider.h"

#include <yandex_io/libs/ipc/i_ipc_factory.h>
#include <yandex_io/libs/signals/live_data.h>
#include <yandex_io/protos/quasar_proto.pb.h>

#include <time.h>
#include <unordered_set>

using namespace quasar;

// Concrete IClock instance handed out by ClockTowerProvider.
//
// The identity of a clock (device id, host, port, clock id) is fixed at
// construction. The offset against the local CLOCK_MONOTONIC_RAW clock, the
// sync level and the expired flag are stored in atomics so the provider can
// update them while other code holds and reads the clock.
class ClockTowerProvider::Clock: public IClock {
public:
    // diff is in nanoseconds and is subtracted from CLOCK_MONOTONIC_RAW in now().
    Clock(std::string deviceId, std::string host, int port, std::string clockId, int64_t diff, SyncLevel syncLevel)
        : deviceId_(std::move(deviceId))
        , host_(std::move(host))
        , port_(port)
        , clockId_(std::move(clockId))
        , diff_(diff)
        , syncLevel_(syncLevel)
    {
        // SyncLevel values are produced by static_cast from the proto enum,
        // so the two enumerations must stay numerically aligned.
        static_assert((int)SyncLevel::NONE == (int)proto::ClockTowerSync::NOSYNC);
        static_assert((int)SyncLevel::WEAK == (int)proto::ClockTowerSync::WEAK);
        static_assert((int)SyncLevel::STRONG == (int)proto::ClockTowerSync::STRONG);
    }

    const std::string& deviceId() const noexcept override {
        return deviceId_;
    }

    const std::string& host() const noexcept override {
        return host_;
    }

    int port() const noexcept override {
        return port_;
    }

    const std::string& clockId() const noexcept override {
        return clockId_;
    }

    // Current time of this clock: local CLOCK_MONOTONIC_RAW in nanoseconds
    // minus the last received diff.
    std::chrono::nanoseconds now() const noexcept override {
        constexpr int64_t NANOSECONDS_IN_SECOND = 1000000000;
        // Value-initialized so a failing clock_gettime yields a deterministic
        // value instead of reading uninitialized memory.
        struct timespec ts{};
        clock_gettime(CLOCK_MONOTONIC_RAW, &ts);
        return std::chrono::nanoseconds{ts.tv_sec * NANOSECONDS_IN_SECOND + ts.tv_nsec - diff_.load()};
    }

    std::chrono::nanoseconds diff() const noexcept override {
        return std::chrono::nanoseconds{diff_.load()};
    }

    SyncLevel syncLevel() const noexcept override {
        return syncLevel_.load();
    }

    // True once the provider has stopped tracking this clock.
    bool expired() const noexcept override {
        return expired_.load();
    }

    void setDiff(int64_t diff) {
        diff_ = diff;
    }

    void setSyncLevel(SyncLevel syncLevel) {
        syncLevel_ = syncLevel;
    }

    void setExpired() {
        expired_ = true;
    }

private:
    const std::string deviceId_;
    const std::string host_;
    const int port_;
    const std::string clockId_;
    std::atomic<int64_t> diff_;
    std::atomic<SyncLevel> syncLevel_{SyncLevel::NONE};
    std::atomic<bool> expired_{false};
};

ClockTowerProvider::ClockTowerProvider(std::shared_ptr<ipc::IIpcFactory> ipcFactory)
    : ClockTowerProvider(ipcFactory->createIpcConnector("audioclient"))
{
}

// Takes ownership of the connector, starts with an empty ClockTowerState,
// installs a message handler that forwards ClockTowerSync payloads to
// updateClockTowerState(), then connects to the service.
ClockTowerProvider::ClockTowerProvider(std::shared_ptr<ipc::IConnector> connector)
    : connector_(std::move(connector))
    , clockTowerState_(std::make_shared<ClockTowerState>())
{
    // Messages without a clock_tower_sync payload are ignored.
    auto onMessage = [this](const auto& incoming) {
        if (!incoming->has_clock_tower_sync()) {
            return;
        }
        updateClockTowerState(incoming->clock_tower_sync());
    };
    // Bound to lifetime_ via makeSafeCallback (presumably so the handler is
    // not invoked after this provider is destroyed — confirm in its header).
    connector_->setMessageHandler(makeSafeCallback(std::move(onMessage), lifetime_));
    connector_->connectToService();
}

// Exposes the provider's clock-tower state member through its interface type.
// The returned reference refers to a member, so it is valid as long as this
// provider is alive.
ClockTowerProvider::IClockTowerState& ClockTowerProvider::clockTowerState()
{
    IClockTowerState& observable = clockTowerState_;
    return observable;
}

// Builds a message with an add_clock_tower submessage describing the given
// clock (device id, host, port, clock id) and sends it over the connector.
void ClockTowerProvider::addClock(std::string deviceId, std::string host, int port, std::string clockId) {
    auto request = ipc::buildMessage([&](auto& out) {
        auto* addClockTower = out.mutable_add_clock_tower();
        addClockTower->set_device_id(TString(deviceId));
        addClockTower->set_host(TString(host));
        addClockTower->set_port(port);
        addClockTower->set_clock_id(TString(clockId));
    });
    connector_->sendMessage(request);
}

std::map<std::string, std::chrono::nanoseconds> ClockTowerProvider::dumpAllClocks() const {
    std::map<std::string, std::chrono::nanoseconds> result;

    std::lock_guard lock(mutex_);
    if (localClock_) {
        result[std::string{localClock_->clockId()}] = localClock_->now();
    }
    for (const auto& [id, clock] : remoteClocks_) {
        result[std::string{clock->clockId()}] = clock->now();
    }
    return result;
}

// Reconciles the cached clocks with a ClockTowerSync snapshot.
//
// - A clock whose identity (device id, host, port, clock id) is unchanged is
//   updated in place (diff and sync level), so holders keep a live object.
// - A clock that disappears from the snapshot, or whose identity changed, is
//   marked expired and dropped or replaced by a new Clock object.
// - If the set of clock objects changed, a new ClockTowerState is published
//   through clockTowerState_.
//
// Holds clockTowerState_ and then mutex_ for the whole update.
void ClockTowerProvider::updateClockTowerState(const proto::ClockTowerSync& clockSync)
{
    // True when `clock` has exactly the given identity.
    auto eq =
        [](const Clock& clock, const std::string& deviceId, const std::string& host, int port, const std::string& clockId)
    {
        return clock.deviceId() == deviceId && clock.host() == host && clock.port() == port && clock.clockId() == clockId;
    };

    std::lock_guard hold(clockTowerState_);
    std::lock_guard lock(mutex_);
    auto state = clockTowerState_.value();
    if (!clockSync.has_local_clock()) {
        // The local clock is gone: expire it like every other dropped clock so
        // code still holding the old object can observe that.
        if (localClock_) {
            localClock_->setExpired();
        }
        localClock_.reset();
    } else {
        const auto& pClock = clockSync.local_clock();
        // A clock without ids is never reused; a clock with a different
        // identity replaces the current one. Either way the old one expires.
        if (pClock.clock_id().empty() ||
            pClock.device_id().empty() ||
            (localClock_ && !eq(*localClock_, pClock.device_id(), pClock.clock_host(), pClock.clock_port(), pClock.clock_id()))) {
            if (localClock_) {
                localClock_->setExpired();
            }
            localClock_.reset();
        }

        const auto syncLevel = static_cast<Clock::SyncLevel>(pClock.sync_level());
        if (localClock_) {
            localClock_->setDiff(pClock.diff_ns());
            localClock_->setSyncLevel(syncLevel);
        } else {
            localClock_ = std::make_shared<Clock>(pClock.device_id(), pClock.clock_host(), pClock.clock_port(), pClock.clock_id(), pClock.diff_ns(), syncLevel);
        }
    }

    // Views into clockSync; valid for the duration of this call only.
    std::unordered_set<std::string_view> knownClockIds;
    knownClockIds.reserve(clockSync.remote_clock_size());
    for (const auto& pClock : clockSync.remote_clock()) {
        const auto syncLevel = static_cast<Clock::SyncLevel>(pClock.sync_level());
        auto it = remoteClocks_.find(pClock.clock_id());
        const bool sameClock = it != remoteClocks_.end() &&
            eq(*it->second, pClock.device_id(), pClock.clock_host(), pClock.clock_port(), pClock.clock_id());
        if (sameClock) {
            it->second->setDiff(pClock.diff_ns());
            it->second->setSyncLevel(syncLevel);
        } else {
            if (it != remoteClocks_.end()) {
                // Same clock id but a different endpoint: the old object is
                // being replaced, so it must be expired, not silently dropped.
                it->second->setExpired();
            }
            remoteClocks_[pClock.clock_id()] =
                std::make_shared<Clock>(pClock.device_id(), pClock.clock_host(), pClock.clock_port(), pClock.clock_id(), pClock.diff_ns(), syncLevel);
        }
        knownClockIds.insert(pClock.clock_id());
    }

    // Drop (and expire) remote clocks absent from this snapshot.
    for (auto it = remoteClocks_.begin(); it != remoteClocks_.end();) {
        if (!knownClockIds.count(it->second->clockId())) {
            it->second->setExpired();
            it = remoteClocks_.erase(it);
        } else {
            ++it;
        }
    }

    // Publish only when the set of clock objects changed; in-place diff/sync
    // updates are visible through the existing objects.
    // NOTE(review): the lock-step walk assumes both containers iterate in the
    // same key order (e.g. both std::map) — confirm in the header.
    bool fChanges = (localClock_ != state->localClock) || remoteClocks_.size() != state->remoteClocks.size();
    if (!fChanges) {
        auto it = remoteClocks_.begin();
        auto jt = state->remoteClocks.begin();
        for (; !fChanges && it != remoteClocks_.end() && jt != state->remoteClocks.end(); ++it, ++jt) {
            fChanges = (it->first != jt->first || it->second != jt->second);
        }
    }

    if (fChanges) {
        ClockTowerState newState;
        newState.localClock = localClock_;
        for (const auto& [id, clock] : remoteClocks_) {
            newState.remoteClocks[id] = clock;
        }
        clockTowerState_ = std::make_shared<ClockTowerState>(std::move(newState));
    }
}
