#include "generator.h"

#include "sandbox.h"

#include <tasklet/api/tasklet.pb.h>
#include <tasklet/gen/lib/names.h>
#include <tasklet/gen/lib/components.h>

#include <google/protobuf/compiler/cpp/cpp_helpers.h>
#include <google/protobuf/compiler/plugin.h>

#include <google/protobuf/io/zero_copy_stream.h>
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/stubs/strutil.h>

#include <util/generic/map.h>
#include <util/generic/ptr.h>
#include <util/generic/set.h>
#include <util/generic/vector.h>
#include <util/string/printf.h>

namespace NTaskletGen {

namespace {

    void MemberDefinition(io::Printer &printer, const TProtoStringType &field, const FieldDescriptor *descriptor) {
        const auto& module = ProtoModuleName(descriptor->message_type()->file());
        printer.Print("self.$fi$ = $mo$.$na$()\n", "fi", field, "mo", module, "na", descriptor->message_type()->name());
    }

    void ClassMemberDefinition(io::Printer &printer, const TProtoStringType &field, const FieldDescriptor *descriptor) {
        const auto& module = ProtoModuleName(descriptor->message_type()->file());
        printer.Print("$fi$ = $mo$.$na$\n", "fi", field, "mo", module, "na", descriptor->message_type()->name());
    }

} // anonymous namespace

// Special members are explicitly defaulted, but defined out of line in this
// translation unit rather than in the header.
// NOTE(review): presumably so that members declared in generator.h whose types
// are only complete here can be constructed/destroyed — confirm against the header.
TPythonGenerator::TPythonGenerator() = default;
TPythonGenerator::~TPythonGenerator() = default;

bool TPythonGenerator::Generate(
    const FileDescriptor* file,
    const TProtoStringType& parameter,
    GeneratorContext* generator_context,
    TProtoStringType* error
) const {
    Y_UNUSED(parameter, error);

    TProtoStringType taskletFilename = TaskletFileName(file);
    THolder<io::ZeroCopyOutputStream> taskletStream(generator_context->Open(taskletFilename));
    io::Printer taskletPrinter(taskletStream.Get(), '$');

    GenerateTaskletCode(taskletPrinter, file);

    TProtoStringType sbFilename = SandboxTaskFileName(file);
    THolder<io::ZeroCopyOutputStream> sbStream(generator_context->Open(sbFilename));
    io::Printer sbPrinter(sbStream.Get(), '$');

    GenerateSandboxTaskCode(sbPrinter, file);

    return true;
}

// Prints the import header of the tasklet module: the runtime base classes
// (aliased as `py_base`), the tasklet API schema module, and the pb2 module
// generated for `file` itself, followed by two blank lines separating the
// imports from the first generated class.
void TPythonGenerator::GenerateImports(io::Printer& printer, const FileDescriptor* file) const {
    const TProtoStringType moduleName = ProtoModuleName(file);

    printer.Print(
        "from tasklet.runtime.python import base as py_base\n"
        "import tasklet.api.tasklet_pb2\n"
        "import $pb2$\n"
        "\n\n",
        "pb2", moduleName
    );
}

// Prints the `py_base.TaskletHolder` subclass for tasklet message `descriptor`.
// The class carries the tasklet name, the `module:Class` import path of its
// sandbox task, class aliases (Input/Output/Context/Requirements) for each
// component field TTaskletComponents found, and an __init__ that forwards all
// arguments to the base class. Leaves the printer at its original indent level.
void TPythonGenerator::GenerateHolder(io::Printer& printer, const Descriptor* descriptor) const {
    const TProtoStringType& name = descriptor->name();
    // Computed once so the `class` header and the `super()` call below always
    // name the same class. The previous code passed descriptor->name() to
    // super(), which is only right if HolderClassName returns the bare message
    // name — not visible from here.
    const auto className = HolderClassName(descriptor);

    printer.Print("class $name$(py_base.TaskletHolder):\n", "name", className);
    printer.Indent();
    printer.Print("name = '$name$'\n", "name", name);
    printer.Print(
        "_sbtask_import_path = '$m$:$cls$'\n",
        "m", SandboxTaskModuleName(descriptor), "cls", SandboxTaskClassName(descriptor)
    );

    // Component fields are optional; emit an alias only for those present.
    TTaskletComponents tasklet(descriptor);
    if (tasklet.InputField != nullptr) {
        ClassMemberDefinition(printer, "Input", tasklet.InputField);
    }
    if (tasklet.OutputField != nullptr) {
        ClassMemberDefinition(printer, "Output", tasklet.OutputField);
    }
    if (tasklet.ContextField != nullptr) {
        ClassMemberDefinition(printer, "Context", tasklet.ContextField);
    }
    if (tasklet.RequirementsField != nullptr) {
        ClassMemberDefinition(printer, "Requirements", tasklet.RequirementsField);
    }
    printer.Print("\n");
    printer.Print("def __init__(self, *args, **kwargs):\n");
    printer.Indent();
    // Python 2-compatible super(): the first argument must be the class being defined.
    printer.Print("super($name$, self).__init__(*args, **kwargs)\n", "name", className);
    printer.Outdent(); // leave __init__

    printer.Outdent(); // leave class body

    printer.Print("\n");
    printer.Print("\n");
}

// Prints the `py_base.TaskletBase` subclass for tasklet message `descriptor`.
// The class points `__holder_cls__` at the holder class, and its __init__(model)
// initialises one empty message attribute per present component field
// (named after the proto field) before delegating to the base __init__.
// Leaves the printer at its original indent level.
void TPythonGenerator::GenerateBase(io::Printer& printer, const Descriptor* descriptor) const {
    // Computed once so the `class` header and the `super()` call below always
    // name the same class. The previous code rebuilt it by hand as
    // `<message name>Base`, which only matches if BaseTaskletClassName uses
    // exactly that scheme — not visible from here.
    const auto className = BaseTaskletClassName(descriptor);

    printer.Print("class $name$(py_base.TaskletBase):\n", "name", className);
    printer.Indent();
    printer.Print("__holder_cls__ = $cls$\n", "cls", HolderClassName(descriptor));
    printer.Print("\n");
    printer.Print("def __init__(self, model):\n");
    printer.Indent();

    TTaskletComponents tasklet(descriptor);
    // Fixed-size list: no need for a heap-allocated container here.
    const FieldDescriptor* const components[] = {
        tasklet.InputField, tasklet.OutputField, tasklet.ContextField, tasklet.RequirementsField
    };
    for (const FieldDescriptor* field : components) {
        if (field != nullptr) {
            MemberDefinition(printer, field->name(), field);
        }
    }
    // Python 2-compatible super(): the first argument must be the class being defined.
    printer.Print("super($name$, self).__init__(model)\n", "name", className);

    printer.Outdent(); // leave __init__

    printer.Outdent(); // leave class body
    printer.Print("\n");
    printer.Print("\n");
}

// Writes the whole tasklet Python module for `file`: the import header, then a
// holder class followed by a base class for every top-level message that
// TTaskletComponents::IsTasklet accepts. Nested messages are not visited.
void TPythonGenerator::GenerateTaskletCode(io::Printer& printer, const FileDescriptor* file) const {
    GenerateImports(printer, file);

    const int messageCount = file->message_type_count();
    for (int index = 0; index != messageCount; ++index) {
        const Descriptor* const message = file->message_type(index);
        if (!TTaskletComponents::IsTasklet(message)) {
            continue; // ordinary message: nothing to emit
        }
        GenerateHolder(printer, message);
        GenerateBase(printer, message);
    }
}

} // namespace NTaskletGen
