#include "offer_tracking_state.h"

namespace NTravel::NPriceChecker {
    // Creates a tracking state in the New state for the given key.
    //
    // key       - identifier stored in Key (also used in transition error messages).
    // requestPb - the search request that produced the offer; copied as is.
    // offerPb   - the offer to track; copied with its landing info cleared.
    // timestamp - stored as the initial timestamp of the tracking.
    // counters  - non-owning pointer to the counters updated by the To*() transitions.
    TOfferTrackingState::TOfferTrackingState(const TString& key,
                                             const NTravelProto::TSearchOffersReq& requestPb,
                                             const NTravelProto::TOffer& offerPb,
                                             TInstant timestamp,
                                             TOfferTrackerCounters* counters)
        : Key(key)
        , StateType_(EOfferTrackingStateType::New)
        , Counters_(counters)
    {
        auto& tracking = OfferTrackingResult;

        tracking.InitialTimestamp = timestamp;
        tracking.InitialRequest = requestPb;

        // Keep a copy of the offer, but drop the landing info to save memory.
        tracking.InitialOffer = offerPb;
        tracking.InitialOffer.ClearLandingInfo();
    }

    // Returns true while the tracking is in one of the states between the first
    // scheduling request and completion: WaitingForScheduling, Scheduled or InFly.
    bool TOfferTrackingState::IsActive() {
        const bool awaitingSchedule = StateType_ == EOfferTrackingStateType::WaitingForScheduling;
        const bool scheduled = StateType_ == EOfferTrackingStateType::Scheduled;
        const bool inFly = StateType_ == EOfferTrackingStateType::InFly;
        return awaitingSchedule || scheduled || inFly;
    }

    // Returns true once the tracking has reached Done or Reported.
    bool TOfferTrackingState::IsTerminating() {
        if (StateType_ == EOfferTrackingStateType::Done) {
            return true;
        }
        return StateType_ == EOfferTrackingStateType::Reported;
    }

    // Returns true only in the Reported state.
    bool TOfferTrackingState::IsReported() {
        const bool reported = (EOfferTrackingStateType::Reported == StateType_);
        return reported;
    }

    // Returns true only in the WaitingForScheduling state, i.e. the single state
    // from which ToScheduled() is allowed.
    bool TOfferTrackingState::IsReadyToSchedule() {
        const bool readyToSchedule = (EOfferTrackingStateType::WaitingForScheduling == StateType_);
        return readyToSchedule;
    }

    // Switches to WaitingForScheduling.
    //
    // Allowed from:
    //   New   - the tracking becomes active (NActiveTrackings++, NWaitingChecks++);
    //   InFly - a check has finished and the next one is awaited
    //           (NInFlyChecks--, NWaitingChecks++).
    // From any other state TrySwitch throws and no counter is changed.
    void TOfferTrackingState::ToWaitingForScheduling() {
        const auto enterWaiting = [this]() -> bool {
            const bool fromNew = StateType_ == EOfferTrackingStateType::New;
            const bool fromInFly = StateType_ == EOfferTrackingStateType::InFly;
            if (!fromNew && !fromInFly) {
                return false;
            }

            if (fromNew) {
                Counters_->NActiveTrackings.Inc();
            } else {
                Counters_->NInFlyChecks.Dec();
            }
            Counters_->NWaitingChecks.Inc();
            return true;
        };

        TrySwitch(EOfferTrackingStateType::WaitingForScheduling, enterWaiting);
    }

    // Returns true in WaitingForScheduling and Scheduled - the two states during
    // which this tracking is accounted for in the NWaitingChecks counter.
    bool TOfferTrackingState::IsWaiting() {
        if (StateType_ == EOfferTrackingStateType::WaitingForScheduling) {
            return true;
        }
        return StateType_ == EOfferTrackingStateType::Scheduled;
    }

    // Returns true only in the InFly state.
    bool TOfferTrackingState::IsInFly() {
        const bool inFly = (EOfferTrackingStateType::InFly == StateType_);
        return inFly;
    }

    // Switches WaitingForScheduling -> Scheduled and bumps NScheduledChecks.
    // From any other state TrySwitch throws and the counter is left untouched.
    //
    // NOTE(review): NScheduledChecks is only ever incremented in this file -
    // presumably it is a cumulative counter rather than a gauge; confirm.
    void TOfferTrackingState::ToScheduled() {
        TrySwitch(EOfferTrackingStateType::Scheduled, [this]() -> bool {
            if (StateType_ != EOfferTrackingStateType::WaitingForScheduling) {
                return false;
            }
            Counters_->NScheduledChecks.Inc();
            return true;
        });
    }

    // Switches Scheduled -> InFly, moving the tracking from the waiting counter
    // to the in-fly counter (NWaitingChecks--, NInFlyChecks++).
    // From any other state TrySwitch throws and no counter is changed.
    void TOfferTrackingState::ToInFly() {
        TrySwitch(EOfferTrackingStateType::InFly, [this]() -> bool {
            if (StateType_ != EOfferTrackingStateType::Scheduled) {
                return false;
            }
            Counters_->NWaitingChecks.Dec();
            Counters_->NInFlyChecks.Inc();
            return true;
        });
    }

    // Switches to Done and releases the counters held by the tracking.
    //
    // Allowed from:
    //   WaitingForScheduling - NActiveTrackings--, NWaitingChecks--;
    //   InFly                - NActiveTrackings--, NInFlyChecks--.
    // From any other state (including Scheduled) TrySwitch throws and no counter
    // is changed.
    void TOfferTrackingState::ToDone() {
        const auto leaveActive = [this]() -> bool {
            const bool fromWaiting = StateType_ == EOfferTrackingStateType::WaitingForScheduling;
            const bool fromInFly = StateType_ == EOfferTrackingStateType::InFly;
            if (!(fromWaiting || fromInFly)) {
                return false;
            }

            Counters_->NActiveTrackings.Dec();
            if (fromWaiting) {
                Counters_->NWaitingChecks.Dec();
            } else {
                Counters_->NInFlyChecks.Dec();
            }
            return true;
        };

        TrySwitch(EOfferTrackingStateType::Done, leaveActive);
    }

    // Switches Done -> Reported. No counters are involved in this transition.
    // From any other state TrySwitch throws.
    void TOfferTrackingState::ToReported() {
        const auto isDone = [this]() -> bool {
            return EOfferTrackingStateType::Done == StateType_;
        };
        TrySwitch(EOfferTrackingStateType::Reported, isDone);
    }

    // Estimates the size of this state in bytes.
    //
    // useSpaceUsedForPb - when true, protobuf messages are measured with
    //                     SpaceUsedLong(); otherwise with ByteSizeLong().
    //
    // Counted: the key (object + characters), the lock, the initial timestamp,
    // the result variant, the initial request/offer and every check result with
    // its matched offer (if any). LastSuccessiveErrors, StateType_ and Counters_
    // are not included in the estimate.
    size_t TOfferTrackingState::GetSize(bool useSpaceUsedForPb) {
        const auto pbFootprint = [useSpaceUsedForPb](const google::protobuf::Message& message) -> size_t {
            if (useSpaceUsedForPb) {
                return message.SpaceUsedLong();
            }
            return message.ByteSizeLong();
        };

        // Fixed-size parts plus the key characters.
        size_t total = sizeof(Key) + Key.length() +
                       sizeof(Lock) +
                       sizeof(OfferTrackingResult.InitialTimestamp) +
                       sizeof(OfferTrackingResult.Result) +
                       sizeof(OfferTrackingResult.CheckResults);

        // Protobuf payloads captured at construction time.
        total += pbFootprint(OfferTrackingResult.InitialRequest);
        total += pbFootprint(OfferTrackingResult.InitialOffer);

        // Per-check entries and their optional matched offers.
        for (const auto& entry : OfferTrackingResult.CheckResults) {
            total += sizeof(entry);
            if (!entry.MatchedOffer.Defined()) {
                continue;
            }
            total += pbFootprint(*entry.MatchedOffer.Get());
        }

        return total;
    }

    // Performs a guarded state transition.
    //
    // targetState - the state to assign when the transition is allowed.
    // condition   - predicate evaluated exactly once against the current state;
    //               the callers in this file also update counters inside it.
    //
    // Throws yexception (naming the current state, the target state and the key)
    // when the predicate returns false; StateType_ is left unchanged in that case.
    void TOfferTrackingState::TrySwitch(EOfferTrackingStateType targetState, const std::function<bool()>& condition) {
        const bool allowed = condition();
        if (!allowed) {
            ythrow yexception() << "Can't switch from " << StateType_ << " to " << targetState << ". Key: " << Key;
        }
        StateType_ = targetState;
    }

    // Serializes the state into its protobuf representation.
    //
    // Written: the key, the state type, the recent errors, the initial
    // request/offer/timestamp, every check result (timestamp always, matched offer
    // only when defined) and the result variant. All timestamps are stored as
    // whole seconds.
    NTravelProto::NPriceChecker::TTrackingState TOfferTrackingState::ToProto() const {
        NTravelProto::NPriceChecker::TTrackingState pbState;

        pbState.SetKey(Key);
        pbState.SetStateType(TrackingStateFromInnerToPb(StateType_));

        for (const auto& failure : LastSuccessiveErrors) {
            auto* pbFailure = pbState.AddLastSuccessiveErrors();
            pbFailure->SetError(failure.Error);
            pbFailure->SetTimestamp(failure.Timestamp.Seconds());
        }

        auto* pbTracking = pbState.MutableOfferTrackingResult();
        pbTracking->SetInitialTimestamp(OfferTrackingResult.InitialTimestamp.Seconds());
        pbTracking->MutableInitialRequest()->CopyFrom(OfferTrackingResult.InitialRequest);
        pbTracking->MutableInitialOffer()->CopyFrom(OfferTrackingResult.InitialOffer);

        for (const auto& check : OfferTrackingResult.CheckResults) {
            auto* pbCheck = pbTracking->AddCheckResults();
            pbCheck->SetTimestamp(check.Timestamp.Seconds());
            if (!check.MatchedOffer.Defined()) {
                continue;
            }
            pbCheck->MutableMatchedOffer()->CopyFrom(*check.MatchedOffer.Get());
        }

        // The visitor fills whichever result submessage matches the active alternative.
        std::visit(TResultVisitorToProto(pbTracking), OfferTrackingResult.Result);

        return pbState;
    }

    // Restores a tracking state from its protobuf representation (see ToProto).
    //
    // trackingState - the serialized state.
    // counters      - non-owning pointer to the counters the restored state will
    //                 update on subsequent transitions.
    //
    // Returns a newly allocated state. The state type is assigned directly rather
    // than through the To*() transitions, so no counters are updated here.
    // Timestamps and lifetimes are read as whole seconds.
    std::shared_ptr<TOfferTrackingState> TOfferTrackingState::FromProto(const NTravelProto::NPriceChecker::TTrackingState& trackingState, TOfferTrackerCounters* counters) {
        const auto& pbOfferTrackingResult = trackingState.GetOfferTrackingResult();

        // The constructor copies the initial request/offer and starts in the New state.
        auto result = std::make_shared<TOfferTrackingState>(
            trackingState.GetKey(),
            pbOfferTrackingResult.GetInitialRequest(),
            pbOfferTrackingResult.GetInitialOffer(),
            TInstant::Seconds(pbOfferTrackingResult.GetInitialTimestamp()),
            counters);

        result->StateType_ = TrackingStateFromPbToInner(trackingState.GetStateType());
        auto& offerTrackingResult = result->OfferTrackingResult;

        for (const auto& pbCheckResult : pbOfferTrackingResult.GetCheckResults()) {
            auto& checkResult = offerTrackingResult.CheckResults.emplace_back();
            if (pbCheckResult.HasTimestamp()) {
                checkResult.Timestamp = TInstant::Seconds(pbCheckResult.GetTimestamp());
            }
            // The matched offer is restored independently of the timestamp, mirroring
            // ToProto, which writes the two fields independently. Previously an offer
            // was silently dropped when the timestamp field was absent.
            if (pbCheckResult.HasMatchedOffer()) {
                checkResult.MatchedOffer = NTravelProto::TOffer();
                checkResult.MatchedOffer->CopyFrom(pbCheckResult.GetMatchedOffer());
            }
        }

        // At most one result submessage is expected; if none is present the
        // default-constructed Result is kept.
        if (pbOfferTrackingResult.HasUnknownTrackingResult()) {
            offerTrackingResult.Result = TUnknownTrackingResult();
        } else if (pbOfferTrackingResult.HasPriceDifferTrackingResult()) {
            const auto& trackingResult = pbOfferTrackingResult.GetPriceDifferTrackingResult();
            offerTrackingResult.Result = TPriceDifferTrackingResult(TDuration::Seconds(trackingResult.GetLifetimeMin()),
                                                                    TDuration::Seconds(trackingResult.GetLifetimeMax()),
                                                                    trackingResult.GetOldPrice(),
                                                                    trackingResult.GetNewPrice());
        } else if (pbOfferTrackingResult.HasPriceNotFoundTrackingResult()) {
            const auto& trackingResult = pbOfferTrackingResult.GetPriceNotFoundTrackingResult();
            offerTrackingResult.Result = TPriceNotFoundTrackingResult(TDuration::Seconds(trackingResult.GetLifetimeMin()),
                                                                      TDuration::Seconds(trackingResult.GetLifetimeMax()),
                                                                      trackingResult.GetOldPrice());
        } else if (pbOfferTrackingResult.HasNoChangesTrackingResult()) {
            offerTrackingResult.Result = TNoChangesTrackingResult();
        } else if (pbOfferTrackingResult.HasErrorTrackingResult()) {
            const auto& trackingResult = pbOfferTrackingResult.GetErrorTrackingResult();
            offerTrackingResult.Result = TErrorTrackingResult(trackingResult.GetError());
        }

        for (const auto& error : trackingState.GetLastSuccessiveErrors()) {
            result->LastSuccessiveErrors.emplace_back(error.GetError(), TInstant::Seconds(error.GetTimestamp()));
        }

        return result;
    }

    // Maps the protobuf state enum onto the internal state enum.
    //
    // The switch deliberately has no default label so that the compiler can flag
    // a newly added enumerator. A value outside the listed set (e.g. produced by
    // a cast or a newer schema) throws yexception instead of flowing off the end
    // of the function, which would be undefined behaviour.
    TOfferTrackingState::EOfferTrackingStateType TOfferTrackingState::TrackingStateFromPbToInner(TOfferTrackingState::PbTrackingStateType stateType) {
        switch (stateType) {
#define ST(_PB_NAME_, _INNER_NAME_)                                                                   \
    case TOfferTrackingState::PbTrackingStateType::TTrackingState_ETrackingStateType_TST_##_PB_NAME_: \
        return EOfferTrackingStateType::_INNER_NAME_
            ST(NEW, New);
            ST(WAITING_FOR_SCHEDULING, WaitingForScheduling);
            ST(SCHEDULED, Scheduled);
            ST(IN_FLY, InFly);
            ST(DONE, Done);
            ST(REPORTED, Reported);
#undef ST
        }
        ythrow yexception() << "Unknown protobuf tracking state type: " << static_cast<int>(stateType);
    }

    // Maps the internal state enum onto the protobuf state enum.
    //
    // The switch deliberately has no default label so that the compiler can flag
    // a newly added enumerator. A value outside the listed set throws yexception
    // instead of flowing off the end of the function, which would be undefined
    // behaviour.
    TOfferTrackingState::PbTrackingStateType TOfferTrackingState::TrackingStateFromInnerToPb(TOfferTrackingState::EOfferTrackingStateType stateType) {
        switch (stateType) {
#define ST(_INNER_NAME_, _PB_NAME_)             \
    case EOfferTrackingStateType::_INNER_NAME_: \
        return TOfferTrackingState::PbTrackingStateType::TTrackingState_ETrackingStateType_TST_##_PB_NAME_
            ST(New, NEW);
            ST(WaitingForScheduling, WAITING_FOR_SCHEDULING);
            ST(Scheduled, SCHEDULED);
            ST(InFly, IN_FLY);
            ST(Done, DONE);
            ST(Reported, REPORTED);
#undef ST
        }
        ythrow yexception() << "Unknown inner tracking state type: " << static_cast<int>(stateType);
    }

    // Visitor used by ToProto with std::visit: each operator() overload fills the
    // submessage of `trackingResult` that corresponds to one result alternative.
    // The pointer is stored as is (non-owning) and is dereferenced by every overload.
    TOfferTrackingState::TResultVisitorToProto::TResultVisitorToProto(NTravelProto::NPriceChecker::TTrackingState::TOfferTrackingResult* trackingResult)
        : TrackingResult(trackingResult)
    {
    }

    // Unknown result: no fields are set. Calling the Mutable accessor is what marks
    // the submessage as present, so FromProto can detect it via
    // HasUnknownTrackingResult().
    void TOfferTrackingState::TResultVisitorToProto::operator()(const TUnknownTrackingResult&) const {
        TrackingResult->MutableUnknownTrackingResult();
    }

    // Price-differ result: stores the old and new prices together with the
    // lifetime bounds, the latter converted to whole seconds.
    void TOfferTrackingState::TResultVisitorToProto::operator()(const TPriceDifferTrackingResult& result) const {
        auto& pbResult = *TrackingResult->MutablePriceDifferTrackingResult();
        pbResult.SetOldPrice(result.OldPrice);
        pbResult.SetNewPrice(result.NewPrice);
        pbResult.SetLifetimeMin(result.LifetimeMin.Seconds());
        pbResult.SetLifetimeMax(result.LifetimeMax.Seconds());
    }

    // Price-not-found result: stores the old price together with the lifetime
    // bounds, the latter converted to whole seconds.
    void TOfferTrackingState::TResultVisitorToProto::operator()(const TPriceNotFoundTrackingResult& result) const {
        auto& pbResult = *TrackingResult->MutablePriceNotFoundTrackingResult();
        pbResult.SetOldPrice(result.OldPrice);
        pbResult.SetLifetimeMin(result.LifetimeMin.Seconds());
        pbResult.SetLifetimeMax(result.LifetimeMax.Seconds());
    }

    // No-changes result: no fields are set. Calling the Mutable accessor is what
    // marks the submessage as present, so FromProto can detect it via
    // HasNoChangesTrackingResult().
    void TOfferTrackingState::TResultVisitorToProto::operator()(const TNoChangesTrackingResult&) const {
        TrackingResult->MutableNoChangesTrackingResult();
    }

    // Error result: stores the error value carried by the result.
    void TOfferTrackingState::TResultVisitorToProto::operator()(const TErrorTrackingResult& result) const {
        auto& pbResult = *TrackingResult->MutableErrorTrackingResult();
        pbResult.SetError(result.Error);
    }
}
