#include "read_rules.h"

namespace NSaasLB {

    TReadRule::TReadRule(const TString& topicPath, const TString& consumerPath, EDataCenter dc)
        : TopicPath(GetCorrectEntityPath(topicPath))
        , ConsumerPath(GetCorrectEntityPath(consumerPath))
        , DataCenter(dc)
    {}

    TReadRule::TReadRule(const NLogBroker::ReadRule& readRule)
        : TReadRule(
            readRule.read_rule().topic().path(),
            readRule.read_rule().consumer().path(),
            GetDataCenter(readRule.read_rule())
        )
    {}

    /// Returns the topic path, as normalized by GetCorrectEntityPath at construction.
    const TString& TReadRule::GetTopicPath() const {
        return this->TopicPath;
    }

    /// Returns the consumer path, as normalized by GetCorrectEntityPath at construction.
    const TString& TReadRule::GetConsumerPath() const {
        return this->ConsumerPath;
    }

    /// Returns the data center this rule is bound to (EDataCenter::original means "all original").
    EDataCenter TReadRule::GetDataCenter() const {
        return this->DataCenter;
    }

    /// Serializes this rule into a NLogBroker::ReadRule message (only the read_rule key is filled).
    NLogBroker::ReadRule TReadRule::GetProto() const {
        NLogBroker::ReadRule proto;
        auto* key = proto.mutable_read_rule();
        FillEntity(*key);
        return proto;
    }

    /// Builds a SingleModifyRequest whose create_read_rule payload describes this rule.
    NLogBroker::SingleModifyRequest TReadRule::GetAddCommand() const {
        NLogBroker::SingleModifyRequest command;
        auto& key = *command.mutable_create_read_rule()->mutable_read_rule();
        FillEntity(key);
        return command;
    }

    /// Builds a SingleModifyRequest whose remove_read_rule payload describes this rule.
    NLogBroker::SingleModifyRequest TReadRule::GetRemoveCommand() const {
        NLogBroker::SingleModifyRequest command;
        auto& key = *command.mutable_remove_read_rule()->mutable_read_rule();
        FillEntity(key);
        return command;
    }

    /// Two rules are equal when topic path, consumer path and data center all match.
    bool TReadRule::operator== (const TReadRule& other) const {
        if (TopicPath != other.TopicPath) {
            return false;
        }
        if (ConsumerPath != other.ConsumerPath) {
            return false;
        }
        return DataCenter == other.DataCenter;
    }

    /// Lexicographic order over (topic path, consumer path, data center).
    /// Equivalence under this order coincides with operator==, so TSet<TReadRule> lookups
    /// agree with equality comparisons.
    bool TReadRule::operator< (const TReadRule& other) const {
        if (TopicPath == other.TopicPath) {
            if (ConsumerPath == other.ConsumerPath) {
                return DataCenter < other.DataCenter;
            }
            return ConsumerPath < other.ConsumerPath;
        }
        return TopicPath < other.TopicPath;
    }

    /// Maps a protobuf ReadRuleKey to an EDataCenter.
    ///
    /// - `all_original` set            -> EDataCenter::original
    /// - mirror cluster "kafka-bs"     -> EDataCenter::original (LOGBROKER-5057)
    /// - any other mirror cluster name -> parsed with FromString
    ///
    /// NOTE(review): if neither oneof branch is set, the cluster name is empty; FromString
    /// presumably throws for it — confirm this is the intended failure mode.
    EDataCenter TReadRule::GetDataCenter(const NLogBroker::ReadRuleKey& readRuleKey) {
        // Short-circuit keeps the mirror_to_cluster() access off the all_original path.
        if (readRuleKey.has_all_original() || readRuleKey.mirror_to_cluster().cluster() == "kafka-bs") {  // LOGBROKER-5057
            return EDataCenter::original;
        }
        const auto& cluster = readRuleKey.mirror_to_cluster().cluster();
        return FromString(cluster);
    }

    /// Writes this rule's topic, consumer and data-center selector into `readRule`.
    /// EDataCenter::original is encoded as the `all_original` branch; any other data center
    /// is encoded as `mirror_to_cluster` with the enum's string name.
    void TReadRule::FillEntity(NLogBroker::ReadRuleKey& readRule) const {
        readRule.mutable_topic()->set_path(TopicPath);
        readRule.mutable_consumer()->set_path(ConsumerPath);
        if (DataCenter != EDataCenter::original) {
            readRule.mutable_mirror_to_cluster()->set_cluster(ToString(DataCenter));
            return;
        }
        // Calling the mutable accessor is enough to select the (empty) all_original branch.
        readRule.mutable_all_original();
    }

    /// Returns true if `readRule` is currently stored.
    ///
    /// Only the per-consumer index is consulted. Add/Remove in this file update the
    /// per-consumer and per-topic indexes together, so either one is sufficient.
    bool TReadRuleStorage::Exist(const TReadRule& readRule) const {
        const auto consumerRules = RulesByConsumers.find(readRule.GetConsumerPath());
        if (consumerRules == RulesByConsumers.end()) {
            return false;
        }
        const auto& rules = consumerRules->second;
        // The bucket is an ordered set whose ordering (TReadRule::operator<) is consistent with
        // operator==, so its own O(log n) find replaces the previous linear std::find scan.
        return rules.find(readRule) != rules.end();
    }

    std::optional<NLogBroker::SingleModifyRequest> TReadRuleStorage::Add(const TReadRule& readRule) {
        if (Exist(readRule)) {
            return {};
        }
        RulesByConsumers[readRule.GetConsumerPath()].insert(readRule);
        RulesByTopics[readRule.GetTopicPath()].insert(readRule);
        return readRule.GetAddCommand();
    }

    std::optional<NLogBroker::SingleModifyRequest> TReadRuleStorage::Remove(const TReadRule& readRule) {
        if (!Exist(readRule)) {
            return {};
        }
        Remove(RulesByConsumers, readRule, readRule.GetConsumerPath());
        Remove(RulesByTopics, readRule, readRule.GetTopicPath());
        return readRule.GetRemoveCommand();
    }

    /// Returns all stored rules attached to `entity`.
    /// Topics are looked up in the per-topic index and consumers in the per-consumer index.
    /// Any other entity type yields an empty list.
    TVector<TReadRule> TReadRuleStorage::GetFor(const IEntity& entity) const {
        const auto type = entity.GetType();
        const THashMap<TString, TSet<TReadRule>>* index = nullptr;
        if (type == EEntityType::topic) {
            index = &RulesByTopics;
        } else if (type == EEntityType::consumer) {
            index = &RulesByConsumers;
        }
        if (index == nullptr) {
            return {};
        }
        return GetFor(*index, entity.GetPath());
    }

    /// Removes `readRule` from the bucket stored under `path` in `rulesByEntity`.
    /// It is a no-op if there is no bucket for `path` or the rule is not in it.
    /// A bucket left empty by the removal is dropped from the map.
    void TReadRuleStorage::Remove(
        THashMap<TString, TSet<TReadRule>>& rulesByEntity,
        const TReadRule& readRule,
        const TString& path
    ) {
        auto rulesForEntity = rulesByEntity.find(path);
        if (rulesForEntity == rulesByEntity.end()) {
            return;
        }
        auto& rules = rulesForEntity->second;
        // Erase by key: an O(log n) lookup in the ordered set instead of a linear std::find.
        // It does nothing when the rule is absent.
        rules.erase(readRule);
        // Without this, entities whose last rule was removed would linger as empty buckets.
        // Exist/GetFor treat a missing bucket exactly like an empty one.
        if (rules.empty()) {
            rulesByEntity.erase(rulesForEntity);
        }
    }

    /// Copies the rules stored under `path` in `rulesByEntity` into a vector, in set order.
    /// Returns an empty vector when `path` has no bucket.
    TVector<TReadRule> TReadRuleStorage::GetFor(
        const THashMap<TString, TSet<TReadRule>>& rulesByEntity,
        const TString& path
    ) const {
        const auto found = rulesByEntity.find(path);
        if (found != rulesByEntity.end()) {
            const auto& rules = found->second;
            return TVector<TReadRule>(rules.begin(), rules.end());
        }
        return {};
    }

}
