#include <crypta/graph/soup/orgs/email.h>

#include <tensorflow/core/public/session.h>
#include <tensorflow/core/graph/default_device.h>

#include <cstddef>
#include <cstdio>
#include <iterator>

using tensorflow::DataType;
using tensorflow::GraphDef;
using tensorflow::Session;
using tensorflow::SessionOptions;
using tensorflow::Status;
using tensorflow::Tensor;
using tensorflow::TensorShape;

using NNativeYT::TProtoState;
using std::pair;

namespace {
    void CheckTensorflowStatus(const Status& status, const TString& msg) {
        if (not status.ok()) {
            TString message = msg.empty() ? "" : msg + ": ";
            message += status.ToString();
            ythrow yexception() << message;
        }
    }

    TVector<Tensor> Calculate(Session* const sessionPtr,
                              const TVector<pair<TString, Tensor>>& inputs,
                              const TVector<TString>& outputNames) {
        TVector<Tensor> outputTensors;
        outputTensors.reserve(outputNames.size());
        CheckTensorflowStatus(sessionPtr->Run(inputs, outputNames, {}, &outputTensors),
                              "Calculate");
        return outputTensors;
    }

    const TString ID_TYPE = "email";
}

// Fixed-capacity batch of row vectors stored in a single 2-D tensor of shape
// {batchSize, dim}. Rows are filled one at a time by Append().
// NOTE(review): all member functions access the tensor as float, so only
// DT_FLOAT is usable for `dtype`.
class TTFBatch {
public:
    explicit TTFBatch(ui32 batchSize, ui32 dim,
                      DataType dtype = DataType::DT_FLOAT);

    // Zeroes the whole tensor and rewinds the write position to row 0.
    void Reset();
    // Copies `item` into the next free row; throws if the batch is full or
    // `item` is longer than `dim`.
    void Append(const TVector<float>& item);

    // Number of rows appended since the last Reset().
    inline ui32 Size();
    inline bool IsFull();
    // The full backing tensor, including rows not yet appended.
    inline Tensor& getTensor();

private:
    void AssertDimension(const TVector<float>& item);

    ui32 Dim;
    ui32 BatchSize;
    ui32 CurrentItemIdx;
    Tensor BatchTensor;
};

class TJobBatchIterator;

// Accumulates up to `batchSize` emails with their encoded login/domain vectors,
// scores them with one session run in Apply() and exposes the results through
// range-for iteration.
class TJobBatch {
public:
    explicit TJobBatch(ui32 batchSize);

    // Encodes and stores one email; emails that fail conversion are skipped.
    void Append(const TString& email);
    // Runs the model over the stored rows and fills the outputs.
    void Apply(Session* const sessionPtr);
    // Drops all stored rows and outputs.
    inline void Reset();

    inline int Size() const;
    inline int IsFull() const;

    const TVector<TString>& GetEmails();
    const TVector<float>& GetOutputs();

    // One output row per stored email.
    TJobBatchIterator begin();
    TJobBatchIterator end();

private:
    ui32 CurrentIndex;
    ui32 BatchSize;
    TVector<TString> Emails;
    TTFBatch Logins;
    TTFBatch Domains;
    TVector<float> Outputs;
};

// Walks the rows of a TJobBatch by index, producing one output TNode per stored
// email (see operator*). Intended for range-for over a batch after Apply().
class TJobBatchIterator {
public:
    // std::iterator is deprecated since C++17, and the previous (private) base
    // advertised a wrong value_type. Declare the member types explicitly instead.
    // operator* returns a freshly built TNode by value, so this is an input
    // iterator rather than a forward one.
    using iterator_category = std::input_iterator_tag;
    using value_type = TNode;
    using difference_type = std::ptrdiff_t;
    using pointer = void;
    using reference = TNode;

    explicit TJobBatchIterator(TJobBatch* jobBatchPtr, int index);
    TNode operator*();
    bool operator==(const TJobBatchIterator& rhs);
    bool operator!=(const TJobBatchIterator& rhs);
    TJobBatchIterator& operator++();

private:
    int Index;
    TJobBatch* JobBatchPtr;
};

// Writes a single {"email": id} row for the first input row whose idType is
// "email", then stops reading; rows of other id types are skipped.
// NOTE(review): presumably run as a reducer keyed by id, which makes the output
// one row per distinct email — confirm against the operation spec.
void NEmailOrgModelApply::TUniqueEmails::Do(TTableReader<TNode>* input, TTableWriter<TNode>* output) {
    while (input->IsValid()) {
        const auto& row = input->GetRow();
        if (row["idType"].AsString() == "email") {
            TNode result;
            result("email", row["id"].AsString());
            output->AddRow(result);
            return;
        }
        input->Next();
    }
}

TJobBatchIterator::TJobBatchIterator(TJobBatch* jobBatchPtr, int index)
    : Index(index)
    , JobBatchPtr(jobBatchPtr){};

// Builds the output row for the current index: the email, the constant id type
// and the model score computed for that email.
NYT::TNode TJobBatchIterator::operator*() {
    const TString& email = JobBatchPtr->GetEmails()[Index];
    const float score = JobBatchPtr->GetOutputs()[Index];

    NYT::TNode row;
    row("id_value", email);
    row("id_type", ID_TYPE);
    row("is_org_score", score);
    return row;
}

// Iterator at the first stored row.
TJobBatchIterator TJobBatch::begin() {
    return TJobBatchIterator{this, 0};
}

// Iterator one past the last stored row.
TJobBatchIterator TJobBatch::end() {
    return TJobBatchIterator{this, Size()};
}

// Two iterators are equal when they refer to the same batch and the same row.
bool TJobBatchIterator::operator==(const TJobBatchIterator& rhs) {
    return JobBatchPtr == rhs.JobBatchPtr && Index == rhs.Index;
}

bool TJobBatchIterator::operator!=(const TJobBatchIterator& rhs) {
    return Index != rhs.Index || JobBatchPtr != rhs.JobBatchPtr;
}

// Advances to the next row.
TJobBatchIterator& TJobBatchIterator::operator++() {
    Index += 1;
    return *this;
}

// Allocates a {batchSize, dim} tensor and zero-fills it.
// Throws yexception if `dtype` is not DT_FLOAT.
TTFBatch::TTFBatch(ui32 batchSize, ui32 dim, DataType dtype)
    : Dim(dim)
    , BatchSize(batchSize)
    , CurrentItemIdx(0)
    , BatchTensor(dtype, TensorShape{batchSize, dim})
{
    // Reset() and Append() access the tensor via tensor<float, 2>(). For any other
    // dtype that call trips a TensorFlow CHECK and aborts the whole process, so
    // reject the dtype here with a catchable error instead.
    if (dtype != DataType::DT_FLOAT) {
        ythrow yexception() << "TTFBatch supports only DT_FLOAT tensors, got dtype "
                            << static_cast<int>(dtype);
    }
    Reset();
}

// Rewinds the write position and zeroes every element, so rows (and trailing
// columns) that are never appended stay zero.
void TTFBatch::Reset() {
    CurrentItemIdx = 0;
    auto matrix = BatchTensor.tensor<float, 2>();
    matrix.setConstant(0.f);
}

void TTFBatch::Append(const TVector<float>& item) {
    AssertDimension(item);
    if (CurrentItemIdx == BatchSize) {
        ythrow yexception() << "Append to Full batch";
    }
    auto eigenTensor = BatchTensor.tensor<float, 2>();
    for (ui32 i = 0; i < item.size(); ++i) {
        eigenTensor(CurrentItemIdx, i) = item[i];
    }
    ++CurrentItemIdx;
}

// Throws if `item` has more than Dim elements; shorter items are accepted.
void TTFBatch::AssertDimension(const TVector<float>& item) {
    if (item.size() <= Dim) {
        return;
    }
    ythrow yexception() << "Vector dimension is not valid. Required: "
                        << Dim
                        << " But got: "
                        << item.size();
}

// Number of rows appended since the last Reset().
inline ui32 TTFBatch::Size() {
    const ui32 filledRows = CurrentItemIdx;
    return filledRows;
}

// True once every row of the batch has been appended.
inline bool TTFBatch::IsFull() {
    return BatchSize == CurrentItemIdx;
}

// The whole backing tensor, including rows that have not been appended yet.
inline Tensor& TTFBatch::getTensor() {
    return this->BatchTensor;
}

// Private state of TTFModelApply: the model state built from a TBuffer, plus
// helpers that turn it into a running TensorFlow session.
struct NEmailOrgModelApply::TTFModelApply::TSelf {
    // `explicit`: a TBuffer must not silently convert into a TSelf.
    explicit TSelf(const TBuffer& buffer)
        : State(buffer)
    {
    }

    // Parses the serialized graph returned by State->GetModel().
    GraphDef GetGraphDef();
    SessionOptions GetOptions();
    // Creates a session and loads the graph into it.
    THolder<Session> CreateSession();
    // Scores `jobBatch`, writes its rows to `output`, then resets the batch.
    void Yield(Session* sessionPtr, TJobBatch& jobBatch, TTableWriter<TNode>* output);

    TProtoState<TApplyModelState> State;
};

NEmailOrgModelApply::TTFModelApply::TTFModelApply()
    : Self()
{
}

NEmailOrgModelApply::TTFModelApply::TTFModelApply(const TBuffer& buffer)
    : Self(new TSelf(buffer))
{
}

// Deserializes the TensorFlow graph stored in the model state.
// Throws yexception if the stored bytes are not a valid GraphDef. Previously the
// parse result was discarded, so a corrupt model went unnoticed and surfaced
// later as an empty or partial graph.
GraphDef NEmailOrgModelApply::TTFModelApply::TSelf::GetGraphDef() {
    GraphDef graphDef;
    if (!graphDef.ParseFromString(State->GetModel())) {
        ythrow yexception() << "Failed to parse TensorFlow GraphDef from the model state";
    }
    return graphDef;
}

// Session configuration used for scoring.
SessionOptions NEmailOrgModelApply::TTFModelApply::TSelf::GetOptions() {
    SessionOptions options;
    // Device-placement logging stays off. TensorFlow printf()s one line per graph
    // node to stdout when it is on, which is debug noise in a batch job and must
    // not reach the job's stdout.
    options.config.set_log_device_placement(false);
    // Let TensorFlow fall back to an available device if the graph asks for one
    // that is not present.
    options.config.set_allow_soft_placement(true);
    // 0 lets TensorFlow pick the thread-pool sizes itself.
    options.config.set_inter_op_parallelism_threads(0);
    options.config.set_intra_op_parallelism_threads(0);
    return options;
}

// Serializes the model state into `output`.
// Throws yexception if this object holds no state (default-constructed); that
// path used to dereference a null Self.
void NEmailOrgModelApply::TTFModelApply::Save(IOutputStream& output) const {
    if (!Self) {
        ythrow yexception() << "TTFModelApply::Save called on an object without model state";
    }
    Self->State.Save(output);
}
// Restores the model state from `input`.
// Throws yexception if this object holds no state to load into (default-constructed);
// that path used to dereference a null Self.
void NEmailOrgModelApply::TTFModelApply::Load(IInputStream& input) {
    if (!Self) {
        ythrow yexception() << "TTFModelApply::Load called on an object without model state";
    }
    Self->State.Load(input);
}

// Creates a TensorFlow session with GetOptions() and loads GetGraphDef() into it.
// Throws yexception if either step returns a non-OK status.
THolder<Session> NEmailOrgModelApply::TTFModelApply::TSelf::CreateSession() {
    const GraphDef graphDef = GetGraphDef();
    const SessionOptions options = GetOptions();

    Session* rawSession = nullptr;
    CheckTensorflowStatus(NewSession(options, &rawSession),
                          "Creating Session");
    THolder<Session> session(rawSession);
    CheckTensorflowStatus(session->Create(graphDef),
                          "Set loaded graph");
    return session;
}

// Runs the model on the stored rows and stores one score per stored email in
// Outputs, using column 0 of "network_output".
// Throws yexception if the session run fails or the output tensor is not a float
// matrix with at least Size() rows.
void TJobBatch::Apply(Session* const sessionPtr) {
    // Start from a clean slate so Outputs stays index-aligned with Emails, even if
    // Apply() is called twice without a Reset() in between.
    Outputs.clear();
    if (!Size()) {
        return;
    }
    // The input tensors always carry BatchSize rows. Rows past Size() hold the
    // zeros left by TTFBatch::Reset(), and their predictions are ignored below.
    TVector<pair<TString, Tensor>> inputs = {
        {"login_input_1", Logins.getTensor()},
        {"domain_input_1", Domains.getTensor()},
    };
    TVector<TString> outputNames = {
        "network_output",
    };
    auto tfOutputs = Calculate(sessionPtr, inputs, outputNames);
    if (tfOutputs.empty()) {
        ythrow yexception() << "Model returned no tensor for network_output";
    }

    const Tensor& mainOutput = tfOutputs[0];
    const int rows = Size();
    // tensor<float, 2>() CHECK-aborts the process on a dtype/rank mismatch, and
    // Eigen index checks are only debug assertions. Validate before reading.
    if (mainOutput.dtype() != DataType::DT_FLOAT || mainOutput.dims() != 2 ||
        mainOutput.dim_size(0) < rows || mainOutput.dim_size(1) < 1) {
        ythrow yexception() << "Unexpected network_output tensor: dtype="
                            << static_cast<int>(mainOutput.dtype())
                            << " dims=" << mainOutput.dims()
                            << " elements=" << mainOutput.NumElements()
                            << ", expected a float matrix with at least "
                            << rows << " rows";
    }

    const auto predictions = mainOutput.tensor<float, 2>();
    Outputs.reserve(rows);
    for (int i = 0; i < rows; ++i) {
        Outputs.push_back(predictions(i, 0));
    }
}

// Scores `jobBatch`, writes one output row per stored email to `output` and then
// empties the batch so it can be refilled.
void NEmailOrgModelApply::TTFModelApply::TSelf::Yield(Session* const sessionPtr,
                                                      TJobBatch& jobBatch,
                                                      TTableWriter<NYT::TNode>* output) {
    jobBatch.Apply(sessionPtr);
    for (auto it = jobBatch.begin(), last = jobBatch.end(); it != last; ++it) {
        output->AddRow(*it);
    }
    jobBatch.Reset();
}

// Reads the "email" column of every input row, scores the emails in batches of
// State->GetBatchSize() and writes one scored row per accepted email.
void NEmailOrgModelApply::TTFModelApply::Do(TTableReader<TNode>* input, TTableWriter<TNode>* output) {
    // Closes the C stdout stream before TensorFlow is set up.
    // NOTE(review): presumably so nothing TensorFlow prints to stdout can reach
    // the job's output — confirm.
    fclose(stdout);

    THolder<Session> session = Self->CreateSession();
    TJobBatch batch(Self->State->GetBatchSize());

    while (input->IsValid()) {
        batch.Append(input->GetRow()["email"].AsString());
        if (batch.IsFull()) {
            Self->Yield(session.Get(), batch, output);
        }
        input->Next();
    }
    // Flush the trailing, possibly partial, batch; scoring an empty batch does nothing.
    Self->Yield(session.Get(), batch, output);

    CheckTensorflowStatus(session->Close(), "Close session");
}

TJobBatch::TJobBatch(ui32 batchSize)
    : CurrentIndex(0)
    , BatchSize(batchSize)
    , Emails()
    , Logins(batchSize, NEmailOrganization::MAX_LEN)
    , Domains(batchSize, NEmailOrganization::DOMAIN_DICT_SIZE)
{
}

void TJobBatch::Append(const TString& email) {
    if (CurrentIndex == BatchSize) {
        ythrow yexception() << "TJobBatch Batch overflow";
    }
    TVector<float> loginVector;
    TVector<float> domainVector;
    if (!NEmailOrganization::ConvertString(email, loginVector, domainVector)) {
        return;
    }
    Emails.push_back(email);
    Logins.Append(loginVector);
    Domains.Append(domainVector);
    ++CurrentIndex;
}

// Number of emails stored since the last Reset().
int TJobBatch::Size() const {
    return CurrentIndex;
}

// Non-zero once BatchSize emails have been stored.
int TJobBatch::IsFull() const {
    return BatchSize == CurrentIndex;
}

// Empties the batch: stored emails, scores and both input tensors.
void TJobBatch::Reset() {
    CurrentIndex = 0;
    Outputs.clear();
    Emails.clear();
    Domains.Reset();
    Logins.Reset();
}

const TVector<TString>& TJobBatch::GetEmails() {
    return this->Emails;
}

const TVector<float>& TJobBatch::GetOutputs() {
    return this->Outputs;
}
