#include "helpers.h"

#include <dict/json/json_object.h>
#include <dict/json/json_reader.h>

#include <library/cpp/testing/unittest/registar.h>

#include <google/protobuf/text_format.h>

#include <library/cpp/http/io/stream.h>

#include <util/string/cast.h>
#include <util/generic/string.h>
#include <util/generic/yexception.h>
#include <library/cpp/protobuf/json/proto2json.h>

/// Performs a minimal HTTP GET of `query` against localhost:`port`.
///
/// The request is produced by SendMinimalHttpRequest with "localhost" as host.
/// The reply is read in full and returned as {status code, headers, body}.
/// The body is also echoed to DEBUG_LOG.
THttpResponse GetHttpData(ui16 port, const char* query) {
    TNetworkAddress endpoint("localhost", port);
    TSocket connection(endpoint);
    SendMinimalHttpRequest(connection, "localhost", query);

    TSocketInput rawInput(connection);
    THttpInput http(&rawInput);
    // Braced-init evaluates left to right: status line, then headers, then body.
    THttpResponse result = {
        ParseHttpRetCode(http.FirstLine()),
        http.Headers(),
        http.ReadAll(),
    };
    DEBUG_LOG << result.Body << Endl;
    return result;
}

/// Performs an HTTP/1.1 GET of `query` against localhost:`port` with extra headers.
///
/// The request line and a "Host: localhost" header are followed by one
/// "name: value" line per entry of `headers` (in the map's iteration order).
/// The request is sent with keep-alive and compression disabled. The reply is
/// read in full and returned as {status code, headers, body}; the body is
/// also echoed to DEBUG_LOG.
THttpResponse GetHttpData(ui16 port, const char* query, const THashMap<TString, TString>& headers) {
    TNetworkAddress endpoint("localhost", port);
    TSocket connection(endpoint);
    TSocketOutput rawOutput(connection);
    THttpOutput request(&rawOutput);
    request.EnableKeepAlive(false);
    request.EnableCompression(false);

    using TPart = IOutputStream::TPart;
    TVector<TPart> chunks;
    // 7 fixed request-line/Host parts, 4 per header, 1 terminating CRLF.
    chunks.reserve(8 + 4 * headers.size());
    chunks.push_back(TPart(TStringBuf("GET ")));
    chunks.push_back(TPart(query));
    chunks.push_back(TPart(TStringBuf(" HTTP/1.1")));
    chunks.push_back(TPart::CrLf());
    chunks.push_back(TPart(TStringBuf("Host: ")));
    chunks.push_back(TPart(TStringBuf("localhost")));
    chunks.push_back(TPart::CrLf());
    for (const auto& header : headers) {
        chunks.push_back(TPart(header.first));
        chunks.push_back(TPart(TStringBuf(": ")));
        chunks.push_back(TPart(header.second));
        chunks.push_back(TPart::CrLf());
    }
    // An empty line terminates the header section.
    chunks.push_back(TPart::CrLf());

    request.Write(chunks.data(), chunks.size());
    request.Finish();

    TSocketInput rawInput(connection);
    THttpInput http(&rawInput);
    THttpResponse result = {
        ParseHttpRetCode(http.FirstLine()),
        http.Headers(),
        http.ReadAll(),
    };
    DEBUG_LOG << result.Body << Endl;
    return result;
}

/// Returns the value of the first header whose name equals `name` exactly
/// (case-sensitive TString comparison).
/// @throws yexception if no header with that name is present.
TString GetHeader(const THttpHeaders& headers, const TString& name) {
    THttpHeaders::TConstIterator current = headers.Begin();
    const THttpHeaders::TConstIterator last = headers.End();
    while (current != last) {
        if (current->Name() == name) {
            return current->Value();
        }
        ++current;
    }
    throw yexception() << "Header " << name.Quote() << " not found";
}

/// Converts a meta-search protobuf report into TSearchResults.
///
/// Only the first grouping is inspected. Groups without documents are
/// skipped. Every other group becomes a TSearchResultsGroup whose relevance
/// is that of its first document. Its URLs are taken from TDocument::Url when
/// `isKv` is set; otherwise they come from ArchiveInfo.Url, and documents
/// without ArchiveInfo contribute nothing.
///
/// @param report     parsed report
/// @param httpStatus HTTP status stored in the returned results
/// @param isKv       selects where document URLs are read from
TSearchResults ExtractSearchResults(const NMetaProtocol::TReport& report,
    unsigned httpStatus, const bool isKv)
{
    TSearchResults results(httpStatus);
    if (report.GroupingSize() == 0) {
        return results;
    }

    for (const NMetaProtocol::TGroup& group : report.GetGrouping(0).GetGroup()) {
        if (group.DocumentSize() == 0) {
            continue;
        }
        results.push_back(
            TSearchResultsGroup(group.GetDocument(0).GetRelevance()));
        // Safe to hold: `results` is not grown again until the next group.
        auto& urls = results.back().Urls;
        for (const NMetaProtocol::TDocument& document : group.GetDocument()) {
            if (isKv) {
                urls.push_back(document.GetUrl());
            } else if (document.HasArchiveInfo()) {
                urls.push_back(document.GetArchiveInfo().GetUrl());
            }
        }
    }
    return results;
}

/// Converts a JSON search report into TSearchResults.
///
/// Walks `response.results[0].groups`. For every group that has a
/// "relevance" value, it appends a TSearchResultsGroup whose Urls are the
/// "url" values of that group's "documents" array, converted to UTF-8.
/// If any node along that path is missing or has an unexpected type, the
/// results collected so far are returned (possibly none). Such nodes are
/// never dereferenced.
///
/// @param report     parsed JSON document
/// @param httpStatus HTTP status stored in the returned results
TSearchResults ExtractSearchResults(const TJsonObject& report,
    unsigned httpStatus)
{
    TSearchResults results(httpStatus);

    // Each lookup below can yield null (absent key or non-array node);
    // the original code dereferenced several of them unchecked.
    const TJsonObject* response = report.Value("response");
    if (!response) {
        return results;
    }
    const TJsonArray* resultList =
        dynamic_cast<const TJsonArray*>(response->Value("results"));
    if (!resultList || resultList->Length() == 0) {
        return results;
    }
    const TJsonObject* firstResult = resultList->Item(0);
    if (!firstResult) {
        return results;
    }
    const TJsonArray* grouping =
        dynamic_cast<const TJsonArray*>(firstResult->Value("groups"));
    if (!grouping) {
        return results;
    }

    for (int i = 0; i < grouping->Length(); ++i) {
        const TJsonObject* group = grouping->Item(i);
        const TJsonObject* relevance = group ? group->Value("relevance") : nullptr;
        if (!relevance) {
            continue;
        }
        results.push_back(TSearchResultsGroup(FromString(relevance->ToString()),
            TSearchResultsGroup::TUrls()));
        const TJsonArray* documents = dynamic_cast<const TJsonArray*>(
            group->Value("documents"));
        for (int j = 0; documents && j < documents->Length(); ++j) {
            const TJsonObject* document = documents->Item(j);
            const TJsonObject* url = document ? document->Value("url") : nullptr;
            if (url) {
                results.back().Urls.push_back(WideToUTF8(url->ToString()));
            }
        }
    }
    return results;
}

/// Extracts the Gemini castor response embedded in a meta-search report.
///
/// Every document of the first grouping is scanned for archive-info
/// GtaRelatedAttribute pairs keyed "Response". Each such value is parsed as
/// a serialized NGeminiProtos::TCastorResponse, so a later one overwrites an
/// earlier one (ParseFromString replaces the message contents). A value that
/// fails to parse sets the error to UNKNOWN_ERROR. If the report has no
/// groupings, a default-constructed response is returned.
NGeminiProtos::TCastorResponse ExtractGeminiResponse(const NMetaProtocol::TReport& report)
{
    NGeminiProtos::TCastorResponse resp;
    if (!report.GetGrouping().size()) {
        return resp;
    }

    const NMetaProtocol::TGrouping& grouping = report.GetGrouping(0);
    for (const auto& group : grouping.GetGroup()) {
        for (const auto& doc : group.GetDocument()) {
            // Range-for avoids the unsigned-index vs int-size comparison.
            for (const auto& attr : doc.GetArchiveInfo().GetGtaRelatedAttribute()) {
                if (attr.GetKey() != "Response") {
                    continue;
                }
                DEBUG_LOG << attr.GetValue() << Endl;
                if (!resp.ParseFromString(attr.GetValue())) {
                    resp.SetError(NGeminiProtos::EErrorType::UNKNOWN_ERROR);
                }
            }
        }
    }
    DEBUG_LOG << NProtobufJson::Proto2Json(resp) << Endl;
    return resp;
}

/// Builds a fake meta-search report that serves `data`.
///
/// DebugInfo is always created. For empty `data` the report contains only
/// DebugInfo with AnswerIsComplete = true. Otherwise there is a single
/// grouping with one group per entry of `data`. Each URL becomes a document
/// carrying the group's relevance. The URL goes into TDocument::Url when
/// `isKv` is set and into ArchiveInfo.Url otherwise. Every document gets two
/// fixed attributes: "prefix"="1" and "_MimeType". TotalDocCount is appended
/// NUM_PRIORITIES times with the total number of URLs.
NMetaProtocol::TReport GenerateReport(const TSearchResults& data, const bool isKv) {
    NMetaProtocol::TReport report;
    auto* debugInfo = report.MutableDebugInfo();
    if (data.size() == 0) {
        debugInfo->SetAnswerIsComplete(true);
        return report;
    }

    auto* grouping = report.AddGrouping();
    size_t totalDocCount = 0;
    for (const auto& source : data) {
        auto* group = grouping->AddGroup();
        group->SetRelevance(source.Relevance);
        totalDocCount += source.Urls.size();
        for (const auto& url : source.Urls) {
            auto* document = group->AddDocument();
            document->SetRelevance(source.Relevance);
            auto* archiveInfo = document->MutableArchiveInfo();
            if (isKv) {
                document->SetUrl(url);
            } else {
                archiveInfo->SetUrl(url);
            }
            auto* prefixAttr = archiveInfo->AddGtaRelatedAttribute();
            prefixAttr->SetKey("prefix");
            prefixAttr->SetValue("1");
            auto* mimeAttr = archiveInfo->AddGtaRelatedAttribute();
            mimeAttr->SetKey("_MimeType");
            mimeAttr->SetValue("7 0&d=23845&sh=-1&sg=");
        }
    }

    for (size_t priority = 0; priority < NUM_PRIORITIES; ++priority) {
        report.AddTotalDocCount(totalDocCount);
    }
    return report;
}

/// Builds a single-document report that a Gemini backend might return for `url`.
///
/// The document (relevance 1000) carries `url`. When `url` equals `mainUrl`
/// it is tagged with IsMain=1 and IsBeauty=1. Otherwise it points at the main
/// URL via Main=`mainUrl` and Beauty=`mainUrl`. TotalDocCount is 1 for each
/// of the NUM_PRIORITIES slots.
NMetaProtocol::TReport GenerateGeminiReport(const TString& url, const TString& mainUrl) {
    NMetaProtocol::TReport report;
    NMetaProtocol::TGroup* group = report.AddGrouping()->AddGroup();
    group->SetRelevance(1000);
    NMetaProtocol::TDocument* document = group->AddDocument();
    document->SetRelevance(1000);
    document->SetUrl(url);

    auto* archiveInfo = document->MutableArchiveInfo();
    const auto addAttribute = [archiveInfo](const char* key, const TString& value) {
        NMetaProtocol::TPairBytesBytes* attr = archiveInfo->AddGtaRelatedAttribute();
        attr->SetKey(key);
        attr->SetValue(value);
    };

    const bool isMain = (url == mainUrl);
    if (isMain) {
        addAttribute("IsMain", "1");
        addAttribute("IsBeauty", "1");
    } else {
        addAttribute("Main", mainUrl);
        addAttribute("Beauty", mainUrl);
    }

    for (size_t priority = 0; priority < NUM_PRIORITIES; ++priority) {
        report.AddTotalDocCount(1);
    }
    return report;
}

/// Builds a report with no groupings and TotalDocCount = 0 in each of the
/// NUM_PRIORITIES slots.
NMetaProtocol::TReport GenerateEmptyReport() {
    NMetaProtocol::TReport emptyReport;
    for (size_t priority = 0; priority != NUM_PRIORITIES; ++priority) {
        emptyReport.AddTotalDocCount(0);
    }
    return emptyReport;
}

/// Convenience wrapper: renders a search map containing exactly one line built
/// by GetSearchMapLine from the given arguments.
TString GetSearchMapRule(const char* service, ui16 port, ui16 shardMin,
    ui16 shardMax, const char* rank, const char* shardBy)
{
    const TVector<TSearchMapLine> singleLine = {
        GetSearchMapLine(service, port, shardMin, shardMax, rank, shardBy)
    };
    return GetSearchMap(singleLine);
}

/// Serializes `lines` into the JSON search-map format.
///
/// Each line appends a localhost replica ({host, search_port, shard_min,
/// shard_max[, group]}) to `<Service>.replicas.default`. When ShardBy is set,
/// it is stored in `<Service>.shard_by`. The resulting JSON is written to
/// DEBUG_LOG and returned.
TString GetSearchMap(const TVector<TSearchMapLine>& lines) {
    NJson::TJsonValue searchMap;
    for (const TSearchMapLine& line : lines) {
        NJson::TJsonValue replica;
        replica["host"] = "localhost";
        replica["search_port"] = line.SearchPort;
        replica["shard_min"] = line.ShardMin;
        replica["shard_max"] = line.ShardMax;
        if (!line.Group.empty()) {
            replica["group"] = line.Group;
        }

        NJson::TJsonValue& serviceNode = searchMap[line.Service];
        serviceNode["replicas"]["default"].AppendValue(replica);
        if (line.ShardBy) {
            serviceNode["shard_by"] = line.ShardBy;
        }
    }

    const TString rendered = searchMap.GetStringRobust();
    DEBUG_LOG << rendered << Endl;
    return rendered;
}

/// Returns a minimal static <DaemonConfig> section: LogLevel 0, null logger.
const char* GetDaemonConfig() {
    static const char kDaemonConfig[] =
        "<DaemonConfig>\n"
        "\tLogLevel: 0\n"
        "\tLoggerType: null\n"
        "</DaemonConfig>";
    return kDaemonConfig;
}

/// Returns a minimal <SearchProxy> section listening on localhost:`port`.
TString GetSearchProxyConfig(ui16 port) {
    TString config = "<SearchProxy>\n";
    config += "\tHost: localhost\n";
    config += "\tPort: ";
    config += ToString(port);
    config += '\n';
    config += "</SearchProxy>";
    return config;
}

/// Builds a <SearchProxy> config with generous limits (3000 connections,
/// queue 10000, 64 threads, 30 s global search timeout).
///
/// @param port          listening port
/// @param searchConfig  extra line inserted inside <SearchConfig>; omitted
///                      when null (previously a null pointer was concatenated,
///                      which is undefined behavior)
/// @param logRoot       when non-null, adds a <Logger> section writing
///                      access.log and reqans.log under this directory
/// @param serviceConfig when non-null, extra line appended before
///                      </SearchProxy>
TString GetPatientSearchProxyConfig(ui16 port, const char* searchConfig, const char* logRoot, const char* serviceConfig) {
    TString config;
    config += "<SearchProxy>\n";
    config += "\tHost: localhost\n";
    config += "\tPort: " + ToString(port) + '\n';
    config += "\tMaxConnections: 3000\n";
    config += "\tMaxQueueSize: 10000\n";
    config += "\tThreads: 64\n";
    if (logRoot) {
        config += "<Logger>\n";
        config += TString("InfoLog:") + logRoot + "/access.log\n";
        config += TString("ReqAnsLog:") + logRoot + "/reqans.log\n";
        config += "</Logger>\n";
    }
    config += "\t<SearchConfig>\n";
    config += "\t\tThreads: 64\n";
    config += "\t\tConnectTimeout: 10\n";
    config += "\t\tGlobalTimeout: 30000\n";
    if (searchConfig) {
        config += TString("\t\t") + searchConfig + "\n";
    }
    config += "\t\t<SearchPageTemplate>\n";
    config += "\t\t\tMethod: builtin\n";
    config += "\t\t\tModule: searchproxy\n";
    config += "\t\t</SearchPageTemplate>\n";
    config += "\t\t<HttpStatuses>\n";
    config += "\t\t\tEmptySetStatus: 404\n";
    config += "\t\t\tIncompleteStatus: 502\n";
    config += "\t\t\tSyntaxErrorStatus: 400\n";
    config += "\t\t</HttpStatuses>\n";
    config += "\t</SearchConfig>\n";
    if (serviceConfig) {
        config += TString("\t") + serviceConfig + "\n";
    }
    config += "</SearchProxy>";
    return config;
}

/// Builds a proto-format search URL path for a one-step request.
///
/// "&how=<how>" is added unless `how` is "rlv", and "&kps=<kps>" is added
/// only for non-empty `kps`. The grouping parameter embeds `how` as its sort
/// mode. `userRequest` is accepted for interface compatibility but unused.
TString OneStepRequest(const TString& service, const TString& kps,
    const TString& request, const TString& /*userRequest*/, const TString& how)
{
    TString url = "/?";
    if (how != "rlv") {
        url += "&how=";
        url += how;
    }
    if (!kps.empty()) {
        url += "&kps=";
        url += kps;
    }
    url += "&service=";
    url += service;
    url += "&text=";
    url += request;
    url += "&ms=proto&p=0&g=0..10.1.-1.0.0.-1.";
    url += how;
    url += ".0.";
    return url;
}

/// Builds a key-value search URL path; "&kps=<kps>" is added only for a
/// non-empty `kps`.
TString KvRequest(const TString& service, const TString& kps, const TString& request) {
    TString url = "/?";
    if (!kps.empty()) {
        url += "&kps=";
        url += kps;
    }
    url += "&service=";
    url += service;
    url += "&text=";
    url += request;
    return url;
}

void CheckResults(const TSearchResults& actual, const TSearchResults& expected)
{
    UNIT_ASSERT_VALUES_EQUAL(actual.HttpStatus, expected.HttpStatus);
    UNIT_ASSERT_VALUES_EQUAL(actual.size(), expected.size());
    for (size_t i = 0; i < actual.size(); ++i) {
        UNIT_ASSERT_VALUES_EQUAL_C(actual[i].Urls.size(),
            expected[i].Urls.size(),
            ("While checking results group #" + ToString(i)).data());
        UNIT_ASSERT_VALUES_EQUAL_C(actual[i].Relevance,
            expected[i].Relevance,
            ("While checking results group #" + ToString(i)).data());
        for (size_t j = 0; j < actual[i].Urls.size(); ++j) {
            UNIT_ASSERT_VALUES_EQUAL_C(actual[i].Urls[j],
                expected[i].Urls[j],
                "While checking results[" + ToString(i) + "]["
                + ToString(j) + ']');
        }
    }
}

// Asserts that a castor response matches the expected one on the fields the
// tests use; the first mismatch fails the current unit test.
//  - Error presence must agree; the error code is compared only when set.
//  - CanonizedUrl and OriginalUrl must be equal.
//  - MainUrl must have the same number of entries, but only entry 0 is
//    compared. NOTE(review): remaining MainUrl entries are not checked;
//    confirm this is intentional.
void CheckGeminiProtos(const NGeminiProtos::TCastorResponse& actual, const NGeminiProtos::TCastorResponse& expected)
{
    UNIT_ASSERT_VALUES_EQUAL(actual.HasError(), expected.HasError());
    if (actual.HasError()) {
        // Both sides have an error here (presence was asserted equal above).
        UNIT_ASSERT_EQUAL(actual.GetError(), expected.GetError());
    }
    UNIT_ASSERT_VALUES_EQUAL(actual.GetCanonizedUrl(), expected.GetCanonizedUrl());
    UNIT_ASSERT_VALUES_EQUAL(actual.GetOriginalUrl(), expected.GetOriginalUrl());
    UNIT_ASSERT_VALUES_EQUAL(actual.MainUrlSize(), expected.MainUrlSize());
    if (actual.MainUrlSize()) {
        UNIT_ASSERT_VALUES_EQUAL(actual.GetMainUrl(0), expected.GetMainUrl(0));
    }
}

/// Runs `query` against localhost:`port`, requesting a human-readable
/// (text-format) proto report with an event-log dump. It asserts the HTTP
/// status, parses the report and compares the extracted results with
/// `expected`. On any failure the query is printed to Cerr and the exception
/// is rethrown.
void CheckProtoQuery(ui16 port, const TString& query,
    const TSearchResults& expected, const bool isKv)
{
    try {
        const TString protoQuery = query + "&ms=proto&hr=da&dump=eventlog";
        const THttpResponse response = GetHttpData(port, protoQuery.data());
        UNIT_ASSERT_VALUES_EQUAL(expected.HttpStatus, response.HttpCode);

        NMetaProtocol::TReport report;
        UNIT_ASSERT_C(::google::protobuf::TextFormat::ParseFromString(response.Body, &report),
            ("Failed to parse message: " + response.Body.Quote()).data());
        const TSearchResults actual = ExtractSearchResults(report, response.HttpCode, isKv);
        CheckResults(actual, expected);
    } catch (...) {
        Cerr << "While checking proto query: " << query << Endl;
        throw;
    }
}

/// Runs `query` against localhost:`port` with JSON output. It asserts the
/// HTTP status, parses the body as UTF-8 JSON and compares the extracted
/// results with `expected`. On any failure the query is printed to Cerr and
/// the exception is rethrown.
void CheckJsonQuery(ui16 port, const TString& query,
    const TSearchResults& expected)
{
    try {
        const TString jsonQuery = query + "&format=json";
        const THttpResponse response = GetHttpData(port, jsonQuery.data());
        UNIT_ASSERT_VALUES_EQUAL(expected.HttpStatus, response.HttpCode);

        TStringInput bodyStream(response.Body);
        TJsonReader jsonReader(CODES_UTF8, &bodyStream);
        const TJsonObject* data = jsonReader.ReadAll();
        UNIT_ASSERT_C(data,
            ("Failed to parse message: " + response.Body.Quote()).data());
        const TSearchResults actual = ExtractSearchResults(*data, response.HttpCode);
        CheckResults(actual, expected);
    } catch (...) {
        Cerr << "While checking json query: " << query << Endl;
        throw;
    }
}

/// Runs `query` against localhost:`port`, requesting a binary proto report
/// with an event-log dump. It parses the report, extracts the embedded
/// castor response and compares it with `expected`. The HTTP status is not
/// checked. On any failure the query is printed to Cerr and the exception is
/// rethrown.
void CheckGeminiQuery(ui16 port, const TString& query,
    const NGeminiProtos::TCastorResponse& expected)
{
    try {
        const TString geminiQuery = query + "&ms=proto&dump=eventlog";
        const THttpResponse httpResponse = GetHttpData(port, geminiQuery.data());

        NMetaProtocol::TReport report;
        UNIT_ASSERT_C(report.ParseFromString(httpResponse.Body),
            ("Failed to parse message: " + httpResponse.Body.Quote()).data());
        CheckGeminiProtos(ExtractGeminiResponse(report), expected);
    } catch (...) {
        Cerr << "While checking proto query: " << query << Endl;
        throw;
    }
}

