#include <util/string/split.h>
#include <util/stream/pipe.h>
#include "main.h"

#include <cerrno>

/**
 * YT reducer that pipes the "data" payloads of one reduce group (rows sharing
 * the same "hash") through tshark, running the attached tshark.lua script, and
 * writes one output row per JSON packet object printed by that script.
 *
 * Start() prepares the job sandbox once per job; Do() starts one tshark
 * process per reduce group. Failures are reported as rows that carry only a
 * "stderr" column.
 */
class TYtPacketsReducer
    : public NYT::IReducer<NYT::TTableReader<NYT::TNode>, NYT::TTableWriter<NYT::TNode>> {
private:
    // Contents of the "header" file attached to the job. It is written to
    // tshark's stdin before the row payloads (presumably a capture-file header
    // for `-r -` — TODO confirm against the header producer).
    TString Header_ = "";

    // Runs a shell command synchronously and throws if it did not finish
    // successfully, so a broken sandbox fails the job immediately instead of
    // producing confusing tshark errors later.
    static void RunSetupCommand(const TString& commandLine) {
        TShellCommand command(commandLine);
        command.Run().Wait();
        if (command.GetStatus() != TShellCommand::SHELL_FINISHED) {
            ythrow yexception() << "Setup command failed: " << commandLine << ": " << command.GetError();
        }
    }

    // Writes the whole buffer to `pipe`. TPipeHandle::Write is a single raw
    // write() that may transfer fewer bytes than requested or fail; ignoring
    // that would silently corrupt the capture stream fed to tshark.
    static void WriteAll(TPipeHandle& pipe, const char* data, size_t size) {
        while (size > 0) {
            const ssize_t written = pipe.Write(data, size);
            if (written < 0) {
                const int error = errno;
                if (error == EINTR) {
                    continue;  // interrupted before any byte was written: retry
                }
                ythrow yexception() << "Failed to write to tshark input pipe, errno " << error;
            }
            if (written == 0) {
                ythrow yexception() << "Failed to write to tshark input pipe: no progress";
            }
            data += written;
            size -= static_cast<size_t>(written);
        }
    }

    // Builds a row that carries only a diagnostic in the "stderr" column.
    static NYT::TNode MakeErrorRow(const TString& message) {
        NYT::TNode node;
        node["stderr"] = message;
        return node;
    }

    // Formats the diagnostic used when tshark's stdout cannot be interpreted.
    static TString FormatIncorrectJson(const TString& output, const TString& details) {
        TStringStream buf;
        buf << "Incorrect json: " << output << " Error: " << details << "\n";
        return buf.Str();
    }

    // Converts a JSON string or number into `out` and returns true.
    // Other JSON types (bool, null, containers) leave `out` untouched and
    // return false; callers drop such values.
    static bool TryConvertScalar(const rapidjson::Value& value, NYT::TNode& out) {
        if (value.IsString()) {
            out = value.GetString();
        } else if (value.IsUint64()) {
            out = value.GetUint64();
        } else if (value.IsDouble() || value.IsLosslessDouble()) {
            out = value.GetDouble();
        } else {
            return false;
        }
        return true;
    }

    // Converts a JSON object into a YT map keeping only scalar members;
    // nested arrays/objects are dropped.
    static NYT::TNode ConvertFlatObject(const rapidjson::Value& object) {
        NYT::TNode map = NYT::TNode::CreateMap();
        for (const auto& member : object.GetObject()) {
            NYT::TNode scalar;
            if (TryConvertScalar(member.value, scalar)) {
                map[member.name.GetString()] = scalar;
            }
        }
        return map;
    }

    // Converts a JSON array into a YT list: object elements become flat maps,
    // string elements are copied, all other element types are dropped.
    static NYT::TNode ConvertArray(const rapidjson::Value& array) {
        NYT::TNode list = NYT::TNode::CreateList();
        for (const auto& element : array.GetArray()) {
            if (element.IsObject()) {
                list.Add(ConvertFlatObject(element));
            } else if (element.IsString()) {
                list.Add(element.GetString());
            }
        }
        return list;
    }

    // Converts one packet object printed by tshark.lua into an output row
    // tagged with the reduce group's `session` hash.
    static NYT::TNode ConvertPacket(const rapidjson::Value& packet, uint64_t session) {
        NYT::TNode node;
        node["session"] = session;
        for (const auto& field : packet.GetObject()) {
            const char* name = field.name.GetString();
            if (field.value.IsArray()) {
                node[name] = ConvertArray(field.value);
            } else if (field.value.IsObject()) {
                node[name] = ConvertFlatObject(field.value);
            } else {
                NYT::TNode scalar;
                if (TryConvertScalar(field.value, scalar)) {
                    node[name] = scalar;
                }
            }
        }
        return node;
    }

public:
    // Prepares the job sandbox: unpacks libraries.tar.gz (Do() points
    // LD_LIBRARY_PATH at "./"), copies the custom HTTP/HTTP2 header field files
    // into ./.wireshark (Do() sets HOME to "./"), and loads the header file.
    // Throws if any setup step fails.
    void Start(TWriter* writer) override {
        Y_UNUSED(writer);

        RunSetupCommand("tar -zxf libraries.tar.gz");
        RunSetupCommand("mkdir -p '.wireshark' && "
                        "cp custom_http_header_fields '.wireshark' && "
                        "cp custom_http2_header_fields '.wireshark'");

        TUnbufferedFileInput headerInput("header");
        Header_ = headerInput.ReadAll();
    }

    // Streams Header_ followed by every row's "data" into tshark's stdin,
    // then converts tshark's JSON stdout into output rows. If tshark wrote
    // anything to stderr, a single error row with that text is emitted instead.
    void Do(TReader* reader, TWriter* writer) override {
        // All rows of a group share the same "hash" (ReduceBy("hash")); it is
        // copied into every packet row as "session".
        uint64_t session = 0;

        TPipeHandle writerPipe, readerPipe;
        TPipeHandle::Pipe(readerPipe, writerPipe);

        // TShellCommand pumps this stream (the read end of the pipe) into
        // tshark's stdin while we write into the other end below.
        THolder<IInputStream> pipedInputHolder;
        pipedInputHolder.Reset(new TPipedInput(readerPipe.Release()));

        TStringStream jsonPacket;
        TStringStream stdErr;

        TShellCommandOptions options;
        options.SetAsync(true);
        options.SetLatency(1);
        options.SetClearSignalMask(true);
        options.SetCloseAllFdsOnExec(true);
        options.SetCloseInput(false);

        options.SetInputStream(pipedInputHolder.Get());
        options.SetOutputStream(&jsonPacket);
        options.SetErrorStream(&stdErr);

        options.Environment["HOME"] = "./";
        options.Environment["LD_LIBRARY_PATH"] = "./";

        TList<TString> args = {
            "-qX", "lua_script:tshark.lua",
            "-r", "-",
            "-o", "ssl.desegment_ssl_records:TRUE",
            "-o", "ssl.desegment_ssl_application_data:TRUE",
            "-o", "http.desegment_body:FALSE",
            "-o", "ssl.keylog_file:secrets"};

        TShellCommand command("./tshark", args, options);

        command.Run();

        WriteAll(writerPipe, Header_.data(), Header_.size());
        for (; reader->IsValid(); reader->Next()) {
            const auto& row = reader->GetRow();
            const auto& data = row["data"].AsString();
            WriteAll(writerPipe, data.data(), data.size());
            session = row["hash"].AsUint64();
        }

        // Closing our write end delivers EOF to tshark.
        writerPipe.Close();
        command.CloseInput();
        command.Wait();

        if (!stdErr.Empty()) {
            writer->AddRow(MakeErrorRow(stdErr.Str()));
            return;
        }

        // tshark.lua output is parsed as a sequence of JSON documents, each
        // expected to be an array of packet objects.
        TRapidStream rapidStream(jsonPacket.Str());
        for (;;) {
            rapidjson::Document json;
            const rapidjson::ParseResult result = json.ParseStream<rapidjson::kParseStopWhenDoneFlag>(rapidStream);

            if (!result) {
                // With kParseStopWhenDoneFlag, reaching the end of input after the
                // last document is reported as kParseErrorDocumentEmpty: that is
                // the normal termination. Anything else is a genuine parse error.
                if (result.Code() != rapidjson::kParseErrorDocumentEmpty) {
                    TStringStream details;
                    details << rapidjson::GetParseError_En(result.Code()) << " " << result.Offset();
                    writer->AddRow(MakeErrorRow(FormatIncorrectJson(jsonPacket.Str(), details.Str())));
                }
                break;
            }

            if (!json.IsArray()) {
                writer->AddRow(MakeErrorRow(FormatIncorrectJson(jsonPacket.Str(), "top-level value is not an array")));
                break;
            }

            for (const auto& packet : json.GetArray()) {
                if (!packet.IsObject()) {
                    // GetObject() on a non-object is an assertion failure / UB.
                    writer->AddRow(MakeErrorRow("Incorrect json: packet is not an object\n"));
                    continue;
                }
                writer->AddRow(ConvertPacket(packet, session));
            }
        }
    }
};
REGISTER_REDUCER(TYtPacketsReducer);

void
TMapOperations::RunMap(const TYtOptions& options)
{
    auto schema = NYT::TTableSchema()
        .AddColumn("session", NYT::VT_UINT64)
        .AddColumn("frame_protocols", NYT::VT_ANY)
        .AddColumn("frame_time_epoch", NYT::VT_INT64)
        .AddColumn("ip_src", NYT::VT_STRING)
        .AddColumn("tcp_srcport", NYT::VT_UINT64)
        .AddColumn("ip_dst", NYT::VT_STRING)
        .AddColumn("tcp_dstport", NYT::VT_UINT64)
        .AddColumn("http", NYT::VT_ANY)
        .AddColumn("http2", NYT::VT_ANY)
        .AddColumn("ssl_alert_message_desc", NYT::VT_UINT64)
        .AddColumn("ssl_alert_message_level", NYT::VT_UINT64)
        .AddColumn("ssl_handshake_ciphersuite", NYT::VT_UINT64)
        .AddColumn("ssl_handshake_extensions_alpn_str", NYT::VT_STRING)
        .AddColumn("ssl_handshake_random", NYT::VT_STRING)
        .AddColumn("ssl_handshake_type", NYT::VT_UINT64)
        .AddColumn("ssl_handshake_version", NYT::VT_UINT64)
        .AddColumn("ssl_ignored_unknown_record", NYT::VT_UINT64)
        .AddColumn("ssl_record_content_type", NYT::VT_UINT64)
        .AddColumn("ssl_record_length", NYT::VT_UINT64)
        .AddColumn("tcp_ack", NYT::VT_UINT64)
        .AddColumn("tcp_analysis_ack_rtt", NYT::VT_INT64)
        .AddColumn("tcp_analysis_duplicate_ack_num", NYT::VT_UINT64)
        .AddColumn("tcp_analysis_fast_retransmission", NYT::VT_UINT64)
        .AddColumn("tcp_analysis_flags", NYT::VT_UINT64)
        .AddColumn("tcp_analysis_retransmission", NYT::VT_UINT64)
        .AddColumn("tcp_checksum", NYT::VT_UINT64)
        .AddColumn("tcp_flags_ack", NYT::VT_UINT64)
        .AddColumn("tcp_flags_cwr", NYT::VT_UINT64)
        .AddColumn("tcp_flags_ecn", NYT::VT_UINT64)
        .AddColumn("tcp_flags_fin", NYT::VT_UINT64)
        .AddColumn("tcp_flags_push", NYT::VT_UINT64)
        .AddColumn("tcp_flags_rst", NYT::VT_UINT64)
        .AddColumn("tcp_flags_syn", NYT::VT_UINT64)
        .AddColumn("tcp_flags_urg", NYT::VT_UINT64)
        .AddColumn("tcp_mss_val", NYT::VT_UINT64)
        .AddColumn("tcp_nxtseq", NYT::VT_UINT64)
        .AddColumn("tcp_options_sack_le", NYT::VT_UINT64)
        .AddColumn("tcp_options_sack_re", NYT::VT_UINT64)
        .AddColumn("tcp_seq", NYT::VT_UINT64)
        .AddColumn("tcp_tsecr", NYT::VT_UINT64)
        .AddColumn("tcp_tsval", NYT::VT_UINT64)
        .AddColumn("tcp_window_size_value", NYT::VT_UINT64)
        .AddColumn("stderr", NYT::VT_STRING);

    YtTrx_->Reduce(
        NYT::TReduceOperationSpec()
            .ReducerSpec(
                NYT::TUserJobSpec()
                    .AddFile(NYT::TRichYPath(options.ytHeaderFile).FileName("header"))
                    .AddFile(NYT::TRichYPath(options.tsharkPath).FileName("tshark").Executable(true))
                    .AddFile(NYT::TRichYPath(options.libsPath).FileName("libraries.tar.gz"))
                    .AddFile(NYT::TRichYPath(options.luaPath).FileName("tshark.lua"))
                    .AddFile(NYT::TRichYPath(options.httpHeadersPath).FileName("custom_http_header_fields"))
                    .AddFile(NYT::TRichYPath(options.http2HeadersPath).FileName("custom_http2_header_fields"))
                    .AddFile(NYT::TRichYPath(options.ytSecretsFile).FileName("secrets"))
                )
            .ReduceBy("hash")
            .AddInput<NYT::TNode>(options.ytInputTable)
            .AddOutput<NYT::TNode>(NYT::TRichYPath(options.ytOutputTable).Schema(schema)),
        new TYtPacketsReducer(),
        NYT::TOperationOptions().Spec(
            NYT::TNode()("max_failed_job_count", 20)
                ("memory_limit", options.memoryLimit * 1024 * 1024)
                ("memory_reserve_factor", 0.8)
                ("data_size_per_job", 16 * 1024 * 1024))
    );

    YtTrx_->Sort(
        NYT::TSortOperationSpec()
            .SortBy("frame_time_epoch")
            .AddInput(options.ytOutputTable)
            .Output(options.ytOutputTable)
    );
    YtTrx_->Commit();
}

/**
 * Validates that every YT path the packets reducer needs is set, then runs
 * the reduce + sort pipeline on options.ytCluster.
 *
 * @param argv0   program name, used in the "--help" hint of error messages.
 * @param options parsed command-line options.
 * @throws yexception if a required YT path is empty.
 */
void RunMainMap(TStringBuf argv0, const TYtOptions &options)
{
    // Throws when `path` is empty; `what` names the path in the message.
    const auto requireNonEmpty = [argv0](const auto& path, TStringBuf what) {
        if (path.empty()) {
            ythrow yexception() << "Empty " << what << " YT path.\nTry '"
                                << argv0 << " --help' for more information.\n";
        }
    };

    requireNonEmpty(options.tsharkPath, "tshark");
    requireNonEmpty(options.libsPath, "libraries package");
    requireNonEmpty(options.luaPath, "lua script");
    requireNonEmpty(options.httpHeadersPath, "http headers");
    requireNonEmpty(options.http2HeadersPath, "http2 headers");
    requireNonEmpty(options.ytSecretsFile, "secrets");
    // RunMap attaches this file as "header" and the reducer reads it
    // unconditionally in Start(), so it is required as well.
    requireNonEmpty(options.ytHeaderFile, "header");

    TMapOperations mapOperations(options.ytCluster);
    mapOperations.RunMap(options);
}
