
#include <library/cpp/json/json_value.h>
#include <util/stream/file.h>
#include <library/cpp/json/json_reader.h>
#include <mapreduce/yt/interface/init.h>
#include <mapreduce/yt/interface/client.h>
#include <util/generic/hash_set.h>
#include <util/string/split.h>
#include <library/cpp/json/json_writer.h>
#include <util/folder/path.h>
#include <util/stream/tee.h>
#include <util/generic/map.h>
#include <library/cpp/accurate_accumulate/accurate_accumulate.h>
#include <util/generic/ymath.h>
#include <util/generic/iterator_range.h>

// Confusion-matrix counts at one decision threshold together with the
// precision, recall and F1 derived from them.
struct TStat{
    double threshold{};
    double precision{};
    double recall{};
    double f1{};

    size_t tp{}, fp{}, tn{}, fn{};

    // Orders stats by their F1 score only; the other fields are ignored.
    bool operator < (const TStat & stat) const {
        return f1 < stat.f1;
    }

    // Writes every field as a nested JSON map stored under `key` in `writer`.
    void WriteJson(NJson::TJsonWriter & writer, const TStringBuf & key) const {
        writer.OpenMap(key);

        writer.Write("threshold", threshold);
        writer.Write("precision", precision);
        writer.Write("recall", recall);
        writer.Write("f1", f1);
        writer.Write("tp", tp);
        writer.Write("fp", fp);
        writer.Write("tn", tn);
        writer.Write("fn", fn);

        writer.CloseMap();
    }

    TStat() = default;
    TStat(const TStat &) noexcept = default;
    TStat&operator=(const TStat&) noexcept = default;

    // Builds the stat from raw counts.
    //   precision = tp / (tp + fp)
    //   recall    = tp / (tp + fn)
    //   f1        = 2 * precision * recall / (precision + recall)
    // A metric whose denominator is zero is left at 0 instead of becoming NaN,
    // so the value can always be compared and serialized.
    TStat(double threshold, size_t tp, size_t fp, size_t tn, size_t fn)
            : threshold(threshold), tp(tp), fp(fp), tn(tn), fn(fn) {

        const size_t predictedPositive = tp + fp;
        if (predictedPositive != 0)
            precision = static_cast<double>(tp) / predictedPositive;

        const size_t actualPositive = tp + fn;
        if (actualPositive != 0)
            recall = static_cast<double>(tp) / actualPositive;

        const double denominator = precision + recall;
        if (denominator > 0.)
            f1 = 2. * precision * recall / denominator;
    }
};

// Incremental ROC AUC accumulator.
//
// Points are fed one at a time through AddPoint. Every negative point adds the
// number of positive points that were fed before it, so the final sum counts
// (positive, negative) pairs in which the positive was fed first. GetAuc
// divides that sum by the total number of such pairs.
// NOTE(review): points with equal scores get no special treatment here; the
// result for ties depends on the order in which the caller feeds them.
struct TRoc{
    // Returns the accumulated pair count normalized by
    // totalPositives * totalNegatives. When either class is empty the ratio is
    // undefined (0/0); 0.0 is returned in that case instead of NaN.
    double GetAuc() const {
        if (totalPositives == 0 || totalNegatives == 0)
            return 0.0;

        // Multiply in floating point so the pair count cannot wrap around size_t.
        const double pairs = static_cast<double>(totalPositives) * static_cast<double>(totalNegatives);
        return static_cast<double>(auc) / pairs;
    }

    // Registers the next point: a positive bumps the running positive counter,
    // a negative adds the current counter to the accumulated pair count.
    void AddPoint(bool positive) {
        if (positive)
            ++value;
        else
            auc += value;
    }

    TRoc(size_t totalPositives, size_t totalNegatives)
            : totalPositives(totalPositives), totalNegatives(totalNegatives) {}

    size_t totalPositives;
    size_t totalNegatives;
    // value: positives seen so far; auc: accumulated (positive, negative) pair count.
    size_t value{}, auc{};
};


// Splits `rawPositiveValues` on "," with Split and returns the resulting
// pieces as a set, so repeated values collapse into one entry.
THashSet<TString> ParsePositiveValues(const TString & rawPositiveValues) {
    TVector<TString> labels;
    Split(rawPositiveValues, ",", labels);

    return THashSet<TString>(labels.begin(), labels.end());
}

// A single labelled observation: whether it belongs to the positive class and
// the numeric score attached to it.
struct TMark{
    bool positive;
    double value;

    TMark(bool isPositive, double score)
        : positive(isPositive)
        , value(score)
    {}

    // Marks are ordered by score alone; the class label plays no part.
    bool operator < (const TMark & other) const {
        return other.value > value;
    }
};

// Bundles a vector of marks with the positive and negative totals that
// accompany it.
struct TMarksWithStats{
    TVector<TMark> marks;
    size_t totalPositive{};
    size_t totalNegative{};

    // Takes the vector by value and moves it into place; the two counters are
    // stored exactly as given.
    TMarksWithStats(TVector<TMark> collectedMarks, size_t positives, size_t negatives)
        : marks(std::move(collectedMarks))
        , totalPositive(positives)
        , totalNegative(negatives)
    {}
};

// Reads the prediction and target columns of `sourceTable` and converts every
// usable row into a TMark.
//
//   sourceTable     - path of the table to read
//   targetField     - column holding the ground-truth label
//   predictionField - column holding the numeric prediction
//   positiveValues  - string forms of the labels that count as positive
//
// Returns the marks sorted in ascending order of prediction together with the
// number of positive and negative marks.
// Throws yexception when a prediction is a List or a Map; FromString may also
// throw for a String prediction that is not a number.
TMarksWithStats ReadMarks(
        const TString & sourceTable,
        const TString & targetField,
        const TString & predictionField,
        const THashSet<TString> & positiveValues) {

    // NOTE(review): the cluster name is hard-coded.
    auto client = NYT::CreateClient("hahn");

    // The row count is only used to pre-size the vector; skipped rows simply
    // leave some of the reserved capacity unused.
    const auto rowsCount = client->Get(sourceTable + "/@row_count").AsInt64();

    TVector<TMark> marks(Reserve(static_cast<size_t>(rowsCount)));

    size_t totalPositive{};
    {
        auto reader = client->CreateTableReader<NYT::TNode>(
                NYT::TRichYPath(sourceTable)
                        .Columns({predictionField, targetField})
        );

        for (; reader->IsValid(); reader->Next()) {
            const auto & record = reader->GetRow();

            // Rows that lack either column are ignored.
            if(!record.HasKey(predictionField) || !record.HasKey(targetField))
                continue;

            double value{};
            const auto & recordValue = record[predictionField];

            switch(recordValue.GetType()) {
                case NYT::TNode::String:
                    value = FromString<double>(recordValue.AsString());
                    break;
                case NYT::TNode::Int64:
                    value = recordValue.AsInt64();
                    break;
                case NYT::TNode::Uint64:
                    value = recordValue.AsUint64();
                    break;
                case NYT::TNode::Double:
                    value = recordValue.AsDouble();
                    break;
                case NYT::TNode::Bool:
                    value = recordValue.AsBool();
                    break;

                case NYT::TNode::List:
                case NYT::TNode::Map:
                    // The offending column is the prediction one, so that is the name reported.
                    ythrow yexception() << "field " << predictionField << " has type not numeric type " << recordValue.GetType();
                case NYT::TNode::Null:
                case NYT::TNode::Undefined:
                    // No prediction for this row: skip it.
                    continue;
            }

            TString targetAsString;
            const auto& target = record[targetField];
            switch(target.GetType()){

                // Map, Null and List targets keep the row with an empty label string.
                case NYT::TNode::Map:break;
                case NYT::TNode::Null:break;
                case NYT::TNode::List:break;
                case NYT::TNode::Undefined:
                    continue;
                case NYT::TNode::String:
                    targetAsString = target.AsString();
                    break;
                case NYT::TNode::Int64:
                    targetAsString = ToString(target.AsInt64());
                    break;
                case NYT::TNode::Uint64:
                    targetAsString = ToString(target.AsUint64());
                    break;
                case NYT::TNode::Double:
                    targetAsString = ToString(target.AsDouble());
                    break;
                case NYT::TNode::Bool:
                    targetAsString = ToString(target.AsBool());
                    break;
            }

            const bool positive = positiveValues.contains(targetAsString);
            if(positive)
                ++totalPositive;

            marks.emplace_back(positive, value);
        }
    }

    std::sort(marks.begin(), marks.end());

    // Every stored mark is either positive or negative.
    const size_t totalNegative = marks.size() - totalPositive;

    return TMarksWithStats{std::move(marks), totalPositive, totalNegative};
}

int main(int argc, const char ** argv) {
    NYT::Initialize(argc, argv);

    const TString & sourceTable = argv[1];
    const TString & targetField = argv[2];
    const TString & predictionField = argv[3];
    const TString & rawPositiveValues = argv[4];
    const TFsPath & outputJsonPath = argv[5];

    const auto & positiveValues = ParsePositiveValues(rawPositiveValues);

    const TMarksWithStats & marksWithStats = ReadMarks(sourceTable, targetField, predictionField, positiveValues);

    size_t negativeLessThenThreshold{}, positiveLessThenThreshold{};

    TStat bestF1;
    TRoc roc(marksWithStats.totalPositive, marksWithStats.totalNegative);

    for(const auto & mark : MakeIteratorRange(marksWithStats.marks.crbegin(), marksWithStats.marks.crend())) {
        roc.AddPoint(mark.positive);
    }

    for(const auto & mark : MakeIteratorRange(marksWithStats.marks.cbegin(), marksWithStats.marks.cend() - 1)) {
        if(mark.positive)
            positiveLessThenThreshold ++;
        else
            negativeLessThenThreshold ++;

        size_t tp = marksWithStats.totalPositive - positiveLessThenThreshold;
        size_t fp = positiveLessThenThreshold;

        size_t tn = negativeLessThenThreshold;
        size_t fn = marksWithStats.totalNegative - negativeLessThenThreshold;

        TStat stat(mark.value, tp, fp, tn, fn);

        if(bestF1.f1 < stat.f1)
            bestF1 = stat;
    }
    {
        TUnbufferedFileOutput dst(outputJsonPath);
        TTeeOutput out(&dst, &Cout);
        NJson::TJsonWriter writer(&out, true);

        writer.OpenMap();

        writer.Write("auc", roc.GetAuc());
        bestF1.WriteJson(writer, "best_f1");

        writer.CloseMap();
    }

    Cout << Endl;
    return 0;
}
