#include "trace.h"
#include "udp_sender.h"

#include <solomon/libs/cpp/logging/logging.h>
#include <solomon/libs/cpp/trace/jaeger/jaeger_types.h>
#include <solomon/libs/cpp/selfmon/selfmon.h>

#include <library/cpp/actors/core/actor_bootstrapped.h>
#include <library/cpp/actors/core/events.h>
#include <library/cpp/actors/core/hfunc.h>
#include <library/cpp/containers/absl_flat_hash/flat_hash_map.h>

#include <thrift/transport/TBufferTransports.h>
#include <thrift/protocol/TCompactProtocol.h>

#include <util/generic/queue.h>
#include <util/stream/format.h>
#include <util/system/hostname.h>

namespace NSolomon::NTracing {

using namespace NActors;

// Renders the trace id as text: both 64-bit halves of Id_ formatted with
// HF_FULL and concatenated, Id_[0] first.
TString TTraceId::Hex() const {
    const auto firstHalf = ::Hex(Id_[0], HF_FULL);
    const auto secondHalf = ::Hex(Id_[1], HF_FULL);
    return TStringBuilder{} << firstHalf << secondHalf;
}

namespace NPrivate {

class TTracingServiceActor: public TActorBootstrapped<TTracingServiceActor> {
public:
    // Sets the default tuning values and prepares the Jaeger process
    // descriptor (service name plus a "hostname" tag) attached to every batch.
    // The UDP sender itself is created later, in Bootstrap().
    TTracingServiceActor(NMonitoring::TMetricRegistry& registry, const TString& serviceId)
        : Registry_(registry)
        , Sender_{nullptr}
        , CleanupDelay_(TDuration::Seconds(5))
        , MaxSize_(10'000)
        , SamplingRate_(0.001)
        , AgentPort_(6831)
    {
        // The registry is kept for future use only; silence the unused warning.
        Y_UNUSED(Registry_);

        ProcessInfo_.__set_serviceName(serviceId);

        jaegertracing::thrift::Tag hostTag;
        hostTag.__set_key("hostname");
        hostTag.__set_vType(jaegertracing::thrift::TagType::STRING);
        hostTag.__set_vStr(GetHostName());
        ProcessInfo_.__set_tags({hostTag});
    }

    // Actor entry point: switches to the Normal state, opens the UDP sender
    // towards the agent port and arms the first periodic wakeup.
    void Bootstrap() {
        Become(&TTracingServiceActor::Normal);

        // MakeUdpSender returns a temporary, so plain assignment already moves.
        Sender_ = MakeUdpSender(AgentPort_);

        Schedule(CleanupDelay_, new TEvents::TEvWakeup);
    }
private:
    // Event dispatch table for the Normal state (the only state this actor
    // enters, see Bootstrap()). Groups, in order: span reports from traced
    // code, trace read/erase requests, lifecycle events (periodic wakeup and
    // poison), and the selfmon settings page.
    // NOTE(review): the STRICT variant presumably fails on any event type not
    // listed here — confirm against the actors library.
    STRICT_STFUNC(Normal, {
        hFunc(TTracingEvents::TEvReportBegin, OnReportBegin);
        hFunc(TTracingEvents::TEvReportEnd, OnReportEnd);
        hFunc(TTracingEvents::TEvReportMissing, OnReportMissing);

        hFunc(TTracingEvents::TEvGetTraceRequest, OnGetTrace);
        hFunc(TTracingEvents::TEvEraseTraceRequest, OnEraseTrace);

        hFunc(TEvents::TEvWakeup, OnWakeup);
        hFunc(TEvents::TEvPoison, OnPoison);

        hFunc(NSelfMon::TEvPageDataReq, OnRequest)
    });

    // Registers the start of a span.
    //
    // The trace is identified by the parent span carried in the event; the new
    // span id comes from the trace id attached to the event handle. A trace
    // record is created on first use. If a record for the same trace id exists
    // but was erased by request, it is discarded together with its entry in
    // TracesByCreation_ and a fresh record is started, so the creation index
    // and Storage_ always stay in sync. Any not-yet-uploaded data of the erased
    // record is dropped in that case.
    //
    // Duplicate root spans and duplicate span ids are logged as errors; the
    // newest report wins in both cases.
    void OnReportBegin(TTracingEvents::TEvReportBegin::TPtr& evPtr) {
        auto& ev = *evPtr->Get();

        TTraceId traceId = TTraceId::From(ev.ParentSpan);
        ui64 parentSpanId = GetSpanId(ev.ParentSpan);
        ui64 spanId = GetSpanId(evPtr->TraceId);

        auto it = Storage_.find(traceId);
        if (it != Storage_.end() && it->second.Erased) {
            // insert() never replaces an existing key, so the erased record has
            // to be removed explicitly; otherwise it would be silently reused
            // (still flagged Erased) and the index would get a stale duplicate.
            TracesByCreation_.erase(TraceIdWithCreation{it->second.CreatedAt, traceId});
            Storage_.erase(it);
            it = Storage_.end();
        }
        if (it == Storage_.end()) {
            it = Storage_.insert({traceId, TTrace{ev.Moment}}).first;
            TracesByCreation_.emplace(ev.Moment, traceId);
        }
        auto& trace = it->second;

        if (ev.ParentSpan.IsRoot()) {
            if (trace.FirstSpanId) {
                MON_ERROR(Tracing, ev.Loc << " in " << ev.Function
                        << " top level span added twice to the same trace " << traceId.Hex());
            }
            trace.FirstSpanId = spanId;
        }

        if (trace.Spans.contains(spanId)) {
            MON_ERROR(Tracing, ev.Loc << " in " << ev.Function
                      << " same span " << Hex(spanId, HF_FULL) << " added to trace " << traceId.Hex());
        }
        // The description is moved out of the event: it is not needed afterwards.
        trace.Spans[spanId] = TSpan{spanId, parentSpanId, ev.Moment, TInstant::Zero(), std::move(ev.Description)};
    }

    // Registers the end of a span. Unknown or erased traces, unknown spans and
    // spans that were already closed are logged as errors and ignored.
    // Closing the trace's first (root) span marks the whole trace as completed
    // and warns about any spans that are still open at that moment.
    void OnReportEnd(TTracingEvents::TEvReportEnd::TPtr& evPtr) {
        auto& ev = *evPtr->Get();

        TTraceId traceId = TTraceId::From(evPtr->TraceId);
        ui64 spanId = GetSpanId(evPtr->TraceId);

        auto traceIt = Storage_.find(traceId);
        const bool traceKnown = traceIt != Storage_.end() && !traceIt->second.Erased;
        if (!traceKnown) {
            MON_ERROR(Tracing, ev.Loc << " in " << ev.Function
                    << " attempted to end a span " << Hex(spanId, HF_FULL) << " in an unknown trace " << traceId.Hex());
            return;
        }

        auto& trace = traceIt->second;
        auto spanIt = trace.Spans.find(spanId);
        if (spanIt == trace.Spans.end()) {
            MON_ERROR(Tracing, ev.Loc << " in " << ev.Function
                    << " attempted to end a nonexistent span " << Hex(spanId, HF_FULL) << " in trace " << traceId.Hex());
            return;
        }

        auto& closing = spanIt->second;
        if (closing.End) {
            MON_ERROR(Tracing, ev.Loc << " in " << ev.Function
                    << " attempted to end an already completed span " << Hex(spanId, HF_FULL) << " in trace " << traceId.Hex());
            return;
        }
        closing.End = ev.Moment;

        if (spanId != trace.FirstSpanId) {
            // Only the root span finishes the trace.
            return;
        }
        trace.CompletedAt = ev.Moment;

        TVector<ui64> stillOpen;
        for (const auto& entry: trace.Spans) {
            if (!entry.second.End) {
                stillOpen.push_back(entry.first);
            }
        }
        if (stillOpen.empty()) {
            return;
        }

        TStringBuilder sb;
        sb << ev.Loc << " in " << ev.Function
            << " upon closing trace " << traceId.Hex() << " there were active spans:";
        for (auto openId: stillOpen) {
            sb << ' ' << Hex(openId, HF_FULL);
        }
        MON_WARN(Tracing, sb);
    }

    // Logs a report that arrived without span information, noting whether it
    // was a begin or an end report.
    void OnReportMissing(TTracingEvents::TEvReportMissing::TPtr& evPtr) {
        auto& ev = *evPtr->Get();
        const char* phase = ev.Starting ? "begin" : "end";
        MON_ERROR(Tracing, ev.Loc << " in " << ev.Function
            << " is missing span on " << phase);
    }

    // Replies to the requester with a copy of all spans stored for the given
    // trace. Missing or erased traces produce a response with no spans. The
    // request cookie is echoed back.
    void OnGetTrace(TTracingEvents::TEvGetTraceRequest::TPtr& ev) const {
        const auto& traceId = ev->Get()->ReqTraceId;

        auto found = Storage_.find(traceId);
        const bool visible = found != Storage_.end() && !found->second.Erased;
        if (!visible) {
            Send(ev->Sender, new TTracingEvents::TEvGetTraceResponse{traceId, {}}, 0, ev->Cookie);
            return;
        }

        TVector<TSpan> result;
        for (const auto& entry: found->second.Spans) {
            result.push_back(entry.second);
        }
        Send(ev->Sender, new TTracingEvents::TEvGetTraceResponse{traceId, std::move(result)}, 0, ev->Cookie);
    }

    // Marks a trace as erased. The record itself stays in Storage_ until
    // RemoveErased() or DropOldest() disposes of it. Unknown ids are ignored.
    void OnEraseTrace(TTracingEvents::TEvEraseTraceRequest::TPtr& ev) {
        const auto& traceId = ev->Get()->ReqTraceId;

        if (auto found = Storage_.find(traceId); found != Storage_.end()) {
            found->second.Erased = true;
        }
    }

    // Upper bound (by estimate) on the payload accumulated before a batch is
    // flushed to the agent.
    static constexpr size_t BATCH_BYTES_THRESHOLD = 60'000;

    // Rough serialized-size estimate for one span: a fixed overhead, the
    // operation name length and a flat per-tag allowance. Tag contents are not
    // inspected.
    static size_t EstimateSpanBytes(const jaegertracing::thrift::Span& span) {
        // Fixed terms of the estimate; the trailing 10 is deliberate slack.
        constexpr size_t fixedOverhead = (10 + 10 + 10) + 2 + 1 + (10 + 10) + 10;
        constexpr size_t perTagAllowance = 20;

        const size_t nameBytes = span.operationName.size();
        const size_t tagBytes = span.tags.size() * perTagAllowance;
        return fixedOverhead + nameBytes + tagBytes;
    }

    struct TTrace;

    static size_t SerializeSpan(
        const TTraceId traceId,
        const TTrace& trace,
        const TSpan& span,
        jaegertracing::thrift::Span& spanWire)
    {
        spanWire.traceIdLow = traceId.Low();
        spanWire.traceIdHigh = traceId.High();
        spanWire.spanId = span.Id;
        spanWire.parentSpanId = span.ParentId;
        spanWire.operationName = span.Description;
        spanWire.flags = 1; // SAMPLED
        spanWire.startTime = (span.Begin + trace.WalltimeFromRealtimeAdjustment).MicroSeconds();
        auto end = span.End ? span.End : trace.CompletedAt;
        spanWire.duration = (end - span.Begin).MicroSeconds();
        if (!span.End) {
            jaegertracing::thrift::Tag tag;
            tag.__set_key("unfinished");
            tag.__set_vType(jaegertracing::thrift::TagType::BOOL);
            tag.__set_vBool(true);
            spanWire.__set_tags({tag});
        }
        return EstimateSpanBytes(spanWire);
    }

    void UploadTraces() {
        auto start = TInstant::Now();
        size_t uploadedSpans = 0;
        size_t bufferSizeEstimate = 0;
        std::vector<jaegertracing::thrift::Span> spanBuffer;
        for (auto& [traceId, trace]: Storage_) {
            if (trace.CompletedAt && !trace.Dumped) {
                auto trId = traceId;
                auto traceIdLow = traceId.Low();
                double sampleHash = std::ldexp(traceIdLow, -8 * static_cast<int>(sizeof(traceIdLow)));
                if (sampleHash < SamplingRate_) {
                    MON_DEBUG(Tracing, "Trace id " << trId.Hex() << " is sampled");
                    for (auto& [spanId, span]: trace.Spans) {
                        jaegertracing::thrift::Span spanWire;
                        auto spanBytes = SerializeSpan(traceId, trace, span, spanWire);
                        if (spanBytes > BATCH_BYTES_THRESHOLD) {
                            // remove when structured bindings are allowed in lambdas
                            auto spanId = span.Id;
                            MON_WARN(Tracing, "span " << Hex(spanId, HF_FULL) << " in trace "
                                    << trId.Hex() << " is too big, dropped");
                            continue;
                        }
                        if (bufferSizeEstimate + spanBytes > BATCH_BYTES_THRESHOLD) {
                            Flush(spanBuffer);
                            bufferSizeEstimate = 0;
                        }
                        spanBuffer.push_back(std::move(spanWire));
                        bufferSizeEstimate += spanBytes;
                        uploadedSpans++;
                    }
                }
                trace.Dumped = true;
            }
        }
        if (!spanBuffer.empty()) {
            Flush(spanBuffer);
        }
        if (uploadedSpans > 0) {
            auto dur = TInstant::Now() - start;
            MON_INFO(Tracing, "uploaded " << uploadedSpans << " spans to agent, took " << dur);
        }
    }

    // Serializes the given spans as a one-way "emitBatch" message in Thrift
    // compact protocol and hands the bytes to the UDP sender. The input vector
    // is always left empty. The sequence number advances only after a
    // successful send; std::exception failures are logged and swallowed, so the
    // batch is lost in that case.
    void Flush(std::vector<jaegertracing::thrift::Span>& spans) {
        namespace NProto = apache::thrift::protocol;
        namespace NTransport = apache::thrift::transport;

        jaegertracing::thrift::Batch batch;
        batch.process = ProcessInfo_;
        batch.spans = std::move(spans);
        // A moved-from vector is in an unspecified state; make it definitely empty.
        spans.clear();

        auto memory = std::make_shared<NTransport::TMemoryBuffer>(1 << 16);
        NProto::TCompactProtocol proto(memory);

        try {
            proto.writeMessageBegin("emitBatch", NProto::TMessageType::T_ONEWAY, BatchSeqNo_);
            proto.writeFieldBegin("batch", NProto::T_STRUCT, 1);
            batch.write(&proto);
            proto.writeFieldEnd();
            proto.writeFieldStop();
            proto.writeMessageEnd();

            uint8_t* payload;
            uint32_t payloadSize;
            memory->getBuffer(&payload, &payloadSize);

            Sender_->Send(reinterpret_cast<const char*>(payload), payloadSize);

            BatchSeqNo_++;
        } catch (const std::exception& e) {
            MON_ERROR(Tracing, "exception in Flush: " << e.what());
        }
    }

    // Physically removes traces that are both erased and already dumped, along
    // with their entries in the creation-time index.
    void RemoveErased() {
        auto cursor = Storage_.begin();
        while (cursor != Storage_.end()) {
            const auto& trace = cursor->second;
            if (!(trace.Dumped && trace.Erased)) {
                ++cursor;
                continue;
            }
            TracesByCreation_.erase(TraceIdWithCreation{trace.CreatedAt, cursor->first});
            // Post-increment keeps a valid iterator across the erase.
            Storage_.erase(cursor++);
        }
    }

    // Evicts the oldest traces (by creation time) until Storage_ holds at most
    // MaxSize_ entries. Returns the number of traces actually removed.
    //
    // The loop is driven by the real container size and stops if the creation
    // index runs out, so an index that is out of sync with Storage_ can neither
    // cause a dereference of an empty set nor inflate the returned count.
    // An index entry whose creation time does not match the stored trace is
    // treated as stale and discarded without touching the trace.
    size_t DropOldest() {
        size_t deleted = 0;
        while (Storage_.size() > MaxSize_ && !TracesByCreation_.empty()) {
            auto oldest = TracesByCreation_.begin();

            auto stored = Storage_.find(oldest->Id);
            if (stored != Storage_.end() && stored->second.CreatedAt == oldest->CreatedAt) {
                Storage_.erase(stored);
                ++deleted;
            }

            TracesByCreation_.erase(oldest);
        }
        return deleted;
    }

    // Periodic maintenance tick: upload finished traces, purge erased ones,
    // enforce the size limit, then re-arm the timer reusing the same event.
    void OnWakeup(TEvents::TEvWakeup::TPtr& ev) {
        // Upload all, even erased
        UploadTraces();
        RemoveErased();

        if (size_t deleted = DropOldest(); deleted > 0) {
            MON_INFO(Tracing, "cleaned " << deleted << " oldest traces");
        }

        // Take ownership of the wakeup event out of its handle and schedule it again.
        auto* wakeup = ev->Release().Release();
        Schedule(CleanupDelay_, wakeup);
    }

    // Shuts the actor down, acknowledging the poison to its sender when one is set.
    void OnPoison(TEvents::TEvPoison::TPtr& ev) {
        const auto& requester = ev->Sender;
        if (requester) {
            Send(requester, new TEvents::TEvPoisonTaken);
        }

        PassAway();
    }

    // Serves the selfmon "Settings" page. A POST carrying a parsable
    // "samplingRate" value updates the sampling rate (clamped to [0, 1]);
    // every request is answered with a form showing the current value.
    void OnRequest(NSelfMon::TEvPageDataReq::TPtr& ev) {
        namespace NPage = yandex::monitoring::selfmon;

        const auto& httpReq = ev->Get()->HttpReq;
        if (httpReq->Method == "POST") {
            TQuickCgiParam params{httpReq->Body};
            TStringBuf rawRate = params.Get("samplingRate");
            if (rawRate) {
                double parsed;
                if (TryFromString(rawRate, parsed)) {
                    SamplingRate_ = Max(0.0, Min(1.0, parsed));
                }
            }
        }

        NPage::Page page;
        page.set_title("Settings");
        if (auto* column = page.mutable_grid()->add_rows()->add_columns()) {
            column->set_width(4);

            auto* form = column->mutable_component()->mutable_form();
            form->set_layout(NPage::FormLayout::Vertical);
            form->set_method(NPage::FormMethod::Post);

            if (auto* rateItem = form->add_items()) {
                rateItem->set_label("Sampling rate");
                rateItem->set_help("Valid value is between 0 and 1");

                auto* rateInput = rateItem->mutable_input();
                rateInput->set_type(NPage::InputType::FloatNumber);
                rateInput->set_name("samplingRate");
                rateInput->set_value(Sprintf("%g", SamplingRate_));
            }

            auto* saveButton = form->add_submit();
            saveButton->set_title("Save");
        }

        Send(ev->Sender, new NSelfMon::TEvPageDataResp{std::move(page)});
    }

private:
    NMonitoring::TMetricRegistry& Registry_;
    std::unique_ptr<IUdpSender> Sender_;
    TDuration CleanupDelay_;
    ui32 MaxSize_;

    // In-memory record of a single trace and its spans.
    struct TTrace {
        // Moment passed at construction (the moment of the first reported span).
        TInstant CreatedAt;
        // Offset added to span begin moments when producing wire start times;
        // computed once as "now minus CreatedAt" when the record is created.
        TDuration WalltimeFromRealtimeAdjustment;
        // Id of the root span; 0 until a span with a root parent is reported.
        ui64 FirstSpanId = 0;
        // All spans reported for this trace, keyed by span id.
        absl::flat_hash_map<ui64, TSpan> Spans;
        // Moment the root span ended; zero while the trace is still open.
        TInstant CompletedAt = TInstant::Zero();
        // Set by an erase request; the record is hidden from readers afterwards.
        bool Erased = false;
        // Set once the upload pass has processed the trace.
        bool Dumped = false;

        TTrace(TInstant createdAt)
            : CreatedAt{createdAt}
            , WalltimeFromRealtimeAdjustment(TInstant::Now() - createdAt)
        {
        }
    };

    // Key of the creation-time index: orders traces by creation moment, with
    // the trace id as a tie-breaker, so the set's first element is the oldest.
    struct TraceIdWithCreation {
        TInstant CreatedAt;
        TTraceId Id;

        TraceIdWithCreation(TInstant createdAt, TTraceId id)
            : CreatedAt{createdAt}
            , Id{std::move(id)}
        {
        }

        // Lexicographic comparison on (CreatedAt, Id).
        bool operator<(const TraceIdWithCreation& other) const {
            if (CreatedAt != other.CreatedAt) {
                return CreatedAt < other.CreatedAt;
            }
            return Id < other.Id;
        }
    };

    std::set<TraceIdWithCreation> TracesByCreation_;
    absl::flat_hash_map<TTraceId, TTrace> Storage_;
    jaegertracing::thrift::Process ProcessInfo_;
    double SamplingRate_;
    i64 BatchSeqNo_{0};
    ui16 AgentPort_;
};

} // NPrivate


// Creates the tracing service actor on the given executor pool and publishes
// it under the well-known local service id returned by TracingServiceId().
void InitTracingService(
    NMonitoring::TMetricRegistry& registry,
    NActors::TActorSystem& actorSystem,
    ui32 executorPool,
    const TString& serviceId)
{
    // Ownership of the actor object passes to the actor system on Register().
    auto* tracingActor = new NPrivate::TTracingServiceActor(registry, serviceId);
    const auto actorId = actorSystem.Register(tracingActor, NActors::TMailboxType::HTSwap, executorPool);
    actorSystem.RegisterLocalService(TracingServiceId(), actorId);
}

} // NSolomon::NTracing
