#include "sandbox.h"

#include <tasklet/gen/lib/components.h>
#include <tasklet/gen/lib/names.h>

#include <tasklet/api/tasklet.pb.h>

#include <google/protobuf/compiler/cpp/cpp_helpers.h>
#include <google/protobuf/compiler/plugin.h>

#include <google/protobuf/io/zero_copy_stream.h>
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/stubs/strutil.h>

#include <util/generic/map.h>
#include <util/generic/ptr.h>
#include <util/generic/set.h>
#include <util/generic/vector.h>
#include <util/generic/yexception.h>
#include <util/string/join.h>
#include <util/string/printf.h>

namespace {

using namespace google::protobuf;

// Depth (in levels of nested message fields) at which IsTooComplex() gives up
// flattening a message and the JSON representation is used instead.
constexpr int MAX_FIELD_DEPTH_ALLOWED = 8;
// Separator placed between field names when a nested field path is flattened
// into a single sdk2 parameter name; it is therefore forbidden inside field names.
const TString FIELDS_JOINER("__");
// Python attribute names of the single JSON parameter used for the tasklet
// input / output when the message is not flattened into separate fields.
const TString TASKLET_INPUT_PARAMETER_NAME("__tasklet_input__");
const TString TASKLET_OUTPUT_PARAMETER_NAME("__tasklet_output__");

// Maps a protobuf field type to the name of the sdk2.parameters class used to
// hold a single value of that type in the generated Sandbox task.
//
// Classification goes through the C++ representation of the type, so every
// integer wire encoding (int*, uint*, sint*, fixed*, sfixed*) becomes
// "Integer". Previously only int32/int64 were recognized and the other integer
// encodings silently degraded to untyped "String" parameters.
// float/double -> "Float", bool -> "Bool"; everything else (string, bytes,
// enum, message) -> "String".
TProtoStringType SDKParameterValueType(const FieldDescriptor::Type type) {
    switch (FieldDescriptor::TypeToCppType(type)) {
        case FieldDescriptor::CPPTYPE_FLOAT:
        case FieldDescriptor::CPPTYPE_DOUBLE:
            return "Float";
        case FieldDescriptor::CPPTYPE_INT32:
        case FieldDescriptor::CPPTYPE_INT64:
        case FieldDescriptor::CPPTYPE_UINT32:
        case FieldDescriptor::CPPTYPE_UINT64:
            return "Integer";
        case FieldDescriptor::CPPTYPE_BOOL:
            return "Bool";
        default:
            return "String";
    }
}

// Emits one sdk2 parameter declaration for `field` under the Python attribute
// name `fieldName`, which is also used as the parameter label.
// Repeated fields are declared as an sdk2.parameters.List whose element type
// is derived from the field type; singular fields use that type directly.
void PrintSDKParameterDecl(
    io::Printer& printer,
    const FieldDescriptor* field,
    const TString& fieldName)
{
    const char* const declTemplate = field->is_repeated()
        ? "$name$ = sdk2.parameters.List(\"$name$\", sdk2.parameters.$type$)\n"
        : "$name$ = sdk2.parameters.$type$(\"$name$\")\n";

    const TProtoStringType valueType = SDKParameterValueType(field->type());
    printer.Print(declTemplate, "name", fieldName, "type", valueType);
}

bool IsTooComplex(const Descriptor* descriptor, int depth = 0) {
    if (depth >= MAX_FIELD_DEPTH_ALLOWED) {
        return true;
    }
    for (int i = 0; i < descriptor->field_count(); ++i) {
        const FieldDescriptor* subfield = descriptor->field(i);
        if (subfield->message_type() == nullptr) {
            continue; // primitive value
        }

        if (subfield->is_repeated() || IsTooComplex(subfield->message_type(), depth + 1)) {
            return true;
        }
    }

    return false;
}

// Emits sdk2 parameter declarations for `field`, flattening nested messages.
// A scalar field produces one declaration named `prefix` + FIELDS_JOINER +
// field name (or just the field name when `prefix` is empty). A message field
// recurses into each of its subfields, using its own flattened name as the
// new prefix.
// Throws yexception when a field name contains FIELDS_JOINER, since such names
// would be indistinguishable from flattened nested paths.
void PrintTaskParameter(
    io::Printer& printer,
    const FieldDescriptor* field,
    const TString& prefix = "")
{
    const auto& ownName = field->name();
    if (ownName.Contains(FIELDS_JOINER)) {
        ythrow yexception() << "Using of '" << FIELDS_JOINER << "' in field names is not allowed: "
                     << field->full_name() << Endl;
    }

    TString flatName;
    if (prefix.Empty()) {
        flatName = ownName;
    } else {
        flatName = Join(FIELDS_JOINER, prefix, ownName);
    }

    const Descriptor* const nested = field->message_type();
    if (nested == nullptr) {
        PrintSDKParameterDecl(printer, field, flatName);
        return;
    }

    for (int idx = 0; idx < nested->field_count(); ++idx) {
        PrintTaskParameter(printer, nested->field(idx), flatName);
    }
}

// Emits a required sdk2.parameters.JSON declaration assigned to the Python
// attribute `complexParameterName`, with label `complexLabel` and the given
// `description`. The label and description are substituted verbatim inside
// Python double-quoted string literals, so they must not contain `"` or `\`.
void PrintRequiredJsonParameter(
    io::Printer& printer,
    const TString& complexParameterName,
    const TString& complexLabel,
    const TString& description)
{
    const char* const jsonParameterTemplate =
        "$name$ = sdk2.parameters.JSON(\"$label$\", required=True, description=\"$description$\")\n";

    printer.Print(
        jsonParameterTemplate,
        "description", description,
        "label", complexLabel,
        "name", complexParameterName
    );
}

void PrintTaskParameters(
    io::Printer& printer,
    const FieldDescriptor* field,
    const TString& complexParameterName,
    const TString& complexLabel)
{
    const Descriptor* message = field->message_type();
    bool showAsJson = field->options().GetExtension(tasklet::show_as_json);

    printer.Indent();

    if (IsTooComplex(message)) {
        PrintRequiredJsonParameter(
            printer,
            complexParameterName,
            complexLabel,
            "Protobuf message is too complex to represent as separated fields."
            " Json format is used"
        );
    } else if (showAsJson) {
        PrintRequiredJsonParameter(
            printer,
            complexParameterName,
            complexLabel,
            "Json format is forced by field annotation in protobuf message"
        );
    } else {
        for (int i = 0; i < message->field_count(); ++i) {
            const FieldDescriptor* field = message->field(i);

            PrintTaskParameter(printer, field);
        }
    }
    printer.Outdent();
}

}

namespace NTaskletGen {

// Generates the Python module with sdk2 Sandbox task classes for every tasklet
// message (TTaskletComponents::IsTasklet) declared at the top level of `file`.
//
// Each such message produces a class derived from sandbox.BaseTasklet. The
// class contains:
//   * a docstring linking to the .proto source;
//   * `__holder_cls__`, pointing at the holder class in the base tasklet module;
//   * a nested Parameters class.
// Inside Parameters:
//   * fields annotated with tasklet.input become an sdk2.parameters.Group;
//   * fields annotated with tasklet.output become an sdk2.parameters.Output
//     section, but only when their message has at least one field.
// If neither kind of section is emitted, the Parameters body is `pass`, so
// the generated Python stays syntactically valid.
//
// Throws yexception if a tasklet.output field is not message-typed. Exceptions
// raised while printing individual parameters propagate unchanged.
void GenerateSandboxTaskCode(io::Printer& printer, const FileDescriptor* file) {
    TSet<TProtoStringType> toImport;
    for (int i = 0; i < file->message_type_count(); ++i) {
        const Descriptor* descriptor = file->message_type(i);
        if (TTaskletComponents::IsTasklet(descriptor)) {
            toImport.insert(BaseTaskletModuleName(descriptor->file()));
        }
    }

    printer.Print("from sandbox import sdk2\n\n");
    printer.Print("from tasklet.domain import sandbox\n\n");
    for (const auto& module: toImport) {
        printer.Print("import $module$\n", "module", module);
    }

    for (int i = 0; i < file->message_type_count(); ++i) {
        const Descriptor* descriptor = file->message_type(i);
        if (!TTaskletComponents::IsTasklet(descriptor)) {
            continue;
        }
        printer.Print("\n\n");

        const TProtoStringType& name = descriptor->name();

        printer.Print("class $name$(sandbox.BaseTasklet):\n",
                      "name",
                      SandboxTaskClassName(descriptor));
        printer.Indent();
        printer.Print(
            R"'(""" Tasklet for <a href="https://a.yandex-team.ru/arc_vcs/$path$">$name$</a> """)'",
            "path",
            descriptor->file()->name(),
            "name",
            descriptor->name()
        );
        printer.Print("\n");
        printer.Print(
            "__holder_cls__ = $m$.$cls$\n\n",
            "m",
            BaseTaskletModuleName(file),
            "cls",
            HolderClassName(descriptor)
        );
        printer.Print("class Parameters(sandbox.BaseTasklet.Parameters):\n");
        printer.Indent();

        // Tracks whether the Parameters class body received any statement.
        bool hasParameters = false;
        for (int j = 0; j < descriptor->field_count(); ++j) {
            const FieldDescriptor* field = descriptor->field(j);
            if (field->options().HasExtension(tasklet::input)) {
                printer.Print(
                    "with sdk2.parameters.Group(\"$name$ input parameters\") as TaskletInputParameters:\n",
                    "name",
                    name
                );
                PrintTaskParameters(printer, field, TASKLET_INPUT_PARAMETER_NAME, "Tasklet input");
                hasParameters = true;
            } else if (field->options().HasExtension(tasklet::output)) {
                const Descriptor* outputType = field->message_type();
                if (outputType == nullptr) {
                    ythrow yexception() << "Tasklet output field must have a message type: "
                                        << field->full_name();
                }
                // An empty output message gets no Output section at all.
                if (outputType->field_count() != 0) {
                    printer.Print("with sdk2.parameters.Output:\n");
                    PrintTaskParameters(
                        printer,
                        field,
                        TASKLET_OUTPUT_PARAMETER_NAME,
                        "Tasklet output"
                    );
                    hasParameters = true;
                }
            }
        }
        if (!hasParameters) {
            // A Python class body must contain at least one statement.
            printer.Print("pass\n");
        }
        printer.Outdent();

        printer.Outdent();
    }
}

} // namespace NTaskletGen
