#include "extract_userdata.h"

#include <crypta/lib/native/sampler/rest_sampler.h>
#include <crypta/lib/native/vectors/vectors.h>

#include <library/cpp/dot_product/dot_product.h>

#include <util/generic/xrange.h>
#include <util/string/ascii.h>
#include <util/string/cast.h>

#include <iostream>
#include <tuple>
#include <utility>

using namespace NLab;

namespace {
    void SetTokenAndWeight(const TString& columnName, const TNode& row, TUserData& userData) {
        auto* token = columnName == "host" ? userData.MutableAffinities()->MutableHosts()->AddToken() : userData.MutableAffinities()->MutableWords()->AddToken();
        token->SetToken(row.At(columnName).AsString());
        token->SetWeight(static_cast<float>(row.At("weight").AsUint64()));
    }

    void UnpackVector(const TString& vectorAsString, TUserData& userData) {
        const auto& vectorToUnpack = NVectorOperations::Unpack<float>(vectorAsString);
        auto* mutableData = userData.MutableVectors()->MutableVector()->MutableData();
        mutableData->Resize(vectorToUnpack.size(), 0);
        for (auto i : xrange(vectorToUnpack.size())) {
            mutableData->mutable_data()[i] = vectorToUnpack[i];
        }
    }

    TMaybe<TId> GetIdWithType(const TNode& row) {
    	static const TVector<std::pair<TString, TId::EIdType>> fieldTypes = {
    		{"yandexuid", TId::EIdType::Yandexuid},
    		{"yuid", TId::EIdType::Yandexuid},
    		{"crypta_id", TId::EIdType::CryptaId},
    	};

    	for (const auto& [fieldName, type] : fieldTypes) {
    		const auto& idNode = row[fieldName];
    		if (!idNode.IsUndefined()) {
    			return TId{
    				.Id = idNode.IsUint64() ? idNode.AsUint64() : FromString<ui64>(idNode.AsString()),
    				.Type = type
				};
    		}
    	}
    	return Nothing();
    }

    TDevice GetCryptaIdDeviceType(const TNode& row) {
        static const THashMap<TString, TDevice> fieldToDevice = {
            {"desk", DESKTOP},
            {"phone", PHONE},
            {"tablet", TABLET}
        };

        TMultiMap<size_t, TDevice, std::greater<>> counts;
        for (const auto& [field, node]: row.AsMap()) {
            auto it = fieldToDevice.find(field);
            if (const auto count = node.IsUint64() ? node.AsUint64() : FromString<ui64>(node.AsString()); it != fieldToDevice.end() && count != 0) {
                counts.insert({count, it->second});
            }
        }

        size_t countsSize = counts.size();
        if (countsSize == 0) {
            return UNKNOWN_DEVICE;
        } else if (countsSize == 1) {
            return counts.begin()->second;
        }

        return MIXED;
    }
}

// Builds the mapper from its serialized state; all decoding is delegated to TStateful.
TExtractUserData::TExtractUserData(const TBuffer& buffer)
    : TStateful(buffer) {}

// Keywords whose map-form segments also carry a score (used by AddSegments).
const std::unordered_set<ui64> TExtractUserData::PROBABILISTIC_KEYWORDS{544, 546};

// Appends to userData->Segments every segment found in the `fields` columns of `row`, tagged
// with `keyword`. A column may be either
//   * a map {segment id (string) -> score}; the score is stored only for PROBABILISTIC_KEYWORDS;
//   * a list of segment ids (uint64, int64 or decimal string); no score is stored.
// Columns of any other type are ignored.
void TExtractUserData::AddSegments(const TNode& row, const TExtractUserDataState::TProfilesLogState::TFields& fields, ui64 keyword, TUserData* userData) {
    const bool storeScores = PROBABILISTIC_KEYWORDS.contains(keyword);

    // Numeric score -> double. Uint64 is accepted alongside double/int64;
    // any other node type still throws from AsInt64().
    const auto toScore = [](const TNode& scoreNode) -> double {
        if (scoreNode.IsDouble()) {
            return scoreNode.AsDouble();
        }
        if (scoreNode.IsUint64()) {
            return static_cast<double>(scoreNode.AsUint64());
        }
        return static_cast<double>(scoreNode.AsInt64());
    };

    // List element -> segment id; Nothing() for element types that cannot hold an id.
    const auto toSegmentId = [](const TNode& idNode) -> TMaybe<ui64> {
        if (idNode.IsUint64()) {
            return idNode.AsUint64();
        }
        if (idNode.IsInt64()) {
            return static_cast<ui64>(idNode.AsInt64());
        }
        if (idNode.IsString()) {
            return FromString<ui64>(idNode.AsString());
        }
        return Nothing();
    };

    for (const auto& field : fields.GetField()) {
        const auto& value = row[field];
        if (value.IsMap()) {
            for (const auto& [segmentId, score] : value.AsMap()) {
                auto* segment = userData->MutableSegments()->AddSegment();
                segment->SetKeyword(keyword);
                segment->SetID(FromString<ui64>(segmentId));
                if (storeScores) {
                    segment->SetScore(toScore(score));
                }
            }
        } else if (value.IsList()) {
            for (const auto& idNode : value.AsList()) {
                const auto segmentId = toSegmentId(idNode);
                // Skip elements that are not ids (e.g. null) rather than emitting a bogus segment 0.
                if (segmentId.Empty()) {
                    continue;
                }
                auto* segment = userData->MutableSegments()->AddSegment();
                segment->SetKeyword(keyword);
                segment->SetID(*segmentId);
            }
        }
    }
}

// ua_profile is a "|"-delimited record; the device type is its second field.
TString TExtractUserData::GetDeviceTypeFromUa(const TString& uaProfile) {
    TStringBuf remaining(uaProfile);
    remaining.NextTok("|");  // drop the first field
    return TString(remaining.NextTok("|"));
}

// Maps the device field of the row's "ua_profile" string to TDevice.
// A missing/non-string ua_profile or an unrecognised device type yields UNKNOWN_DEVICE.
TDevice TExtractUserData::GetDevice(const TNode& row) {
    const auto& uaProfileNode = row["ua_profile"];
    if (!uaProfileNode.IsString()) {
        return UNKNOWN_DEVICE;
    }

    const TString deviceType = GetDeviceTypeFromUa(uaProfileNode.AsString());
    if (deviceType == "desk") {
        return DESKTOP;
    }
    if (deviceType == "phone") {
        return PHONE;
    }
    if (deviceType == "tablet") {
        return TABLET;
    }
    return UNKNOWN_DEVICE;
}

bool TExtractUserData::HasActiveIp(const TNode& row) {
    auto ipActivityType = row["ip_activity_type"];
    if (!ipActivityType.IsString()) {
        return false;
    }
    return ipActivityType.AsString() == "active";
}

// Fills region/city/country attributes (and the device for yandexuid rows).
// Yandexuid rows without an active IP are rejected (returns false); all other rows return true.
bool TExtractUserData::ProcessRegionsAndDevice(const TNode& row, TUserData& userData, const TId::EIdType& type) {
    const bool isYandexuid = type == TId::EIdType::Yandexuid;
    if (!HasActiveIp(row) && isYandexuid) {
        return false;
    }

    auto& attributes = *userData.MutableAttributes();
    attributes.SetRegion(GetExactCity(row).GetOrElse(0));
    attributes.SetCity(GetCity(row));
    attributes.SetCountry(GetCountry(row));
    if (isYandexuid) {
        attributes.SetDevice(GetDevice(row));
    }
    return true;
}

bool TExtractUserData::ProcessDeviceCryptaID(const TNode& row, TUserData& userData) {
    auto* attributes = userData.MutableAttributes();
    attributes->SetDevice(GetCryptaIdDeviceType(row));
    return true;
}


// Returns "main_region_city" when it is an int64 column value, Nothing() otherwise.
TMaybe<ui64> TExtractUserData::GetExactCity(const TNode& row) {
    const auto& cityNode = row["main_region_city"];
    if (cityNode.IsInt64()) {
        return cityNode.AsInt64();
    }
    return Nothing();
}

// Coarse city bucket: MOSCOW or SAINT_PETERSBURG when the exact city is one of them,
// DEFAULT_CITY for any other or missing city.
TCity TExtractUserData::GetCity(const TNode& row) {
    const auto exactCity = GetExactCity(row);
    if (!exactCity.Defined()) {
        return DEFAULT_CITY;
    }

    const ui64 cityId = exactCity.GetRef();
    switch (cityId) {
        case MOSCOW:
        case SAINT_PETERSBURG:
            return TCity(cityId);
        default:
            break;
    }
    return DEFAULT_CITY;
}

// Returns "main_region_country" if it is one of the explicitly supported countries,
// DEFAULT_COUNTRY otherwise (including missing / non-int64 values).
TCountry TExtractUserData::GetCountry(const TNode& row) {
    const auto& countryNode = row["main_region_country"];
    if (!countryNode.IsInt64()) {
        return DEFAULT_COUNTRY;
    }

    const i64 countryId = countryNode.AsInt64();
    switch (countryId) {
        case RUSSIA:
        case TURKEY:
        case UKRAINE:
        case BELARUS:
        case KAZAKHSTAN:
            return TCountry(countryId);
        default:
            break;
    }
    return DEFAULT_COUNTRY;
}

// Reads exact_socdem["gender"]: "m" -> MALE, "f" -> FEMALE.
// A non-map exact_socdem or a missing/non-string gender yields UNKNOWN_GENDER;
// any other string throws yexception. Throws if the row has no "exact_socdem" column (At()).
TGender TExtractUserData::GetGender(const TNode& row) {
    // Bind by reference: plain `auto` deep-copied the whole exact_socdem map on every call.
    const auto& exactSocdem = row.At("exact_socdem");
    if (!exactSocdem.IsMap()) {
        return UNKNOWN_GENDER;
    }
    const auto& value = exactSocdem["gender"];
    if (!value.IsString()) {
        return UNKNOWN_GENDER;
    }
    const auto& stringValue = value.AsString();
    if (stringValue == "m") {
        return MALE;
    } else if (stringValue == "f") {
        return FEMALE;
    } else {
        ythrow yexception() << "Unknown gender " << stringValue;
    }
}

// Reads exact_socdem["age_segment"] ("0_17", "18_24", ..., "55_99") into TAge.
// A non-map exact_socdem or a missing/non-string segment yields UNKNOWN_AGE;
// any other string throws yexception. Throws if the row has no "exact_socdem" column (At()).
TAge TExtractUserData::GetAge(const TNode& row) {
    // Bind by reference: plain `auto` deep-copied the whole exact_socdem map on every call.
    const auto& exactSocdem = row.At("exact_socdem");
    if (!exactSocdem.IsMap()) {
        return UNKNOWN_AGE;
    }
    const auto& value = exactSocdem["age_segment"];
    if (!value.IsString()) {
        return UNKNOWN_AGE;
    }
    const auto& stringValue = value.AsString();
    if (stringValue == "0_17") {
        return FROM_0_TO_17;
    } else if (stringValue == "18_24") {
        return FROM_18_TO_24;
    } else if (stringValue == "25_34") {
        return FROM_25_TO_34;
    } else if (stringValue == "35_44") {
        return FROM_35_TO_44;
    } else if (stringValue == "45_54") {
        return FROM_45_TO_54;
    } else if (stringValue == "55_99") {
        return FROM_55_TO_99;
    } else {
        ythrow yexception() << "Unknown age " << stringValue;
    }
}

// Reads exact_socdem["income_5_segment"] ("A", "B1", "B2", "C1", "C2") into TIncome.
// A non-map exact_socdem or a missing/non-string segment yields UNKNOWN_INCOME;
// any other string throws yexception. Throws if the row has no "exact_socdem" column (At()).
TIncome TExtractUserData::GetIncome(const TNode& row) {
    // Bind by reference: plain `auto` deep-copied the whole exact_socdem map on every call.
    const auto& exactSocdem = row.At("exact_socdem");
    if (!exactSocdem.IsMap()) {
        return UNKNOWN_INCOME;
    }
    const auto& value = exactSocdem["income_5_segment"];
    if (!value.IsString()) {
        return UNKNOWN_INCOME;
    }
    const auto& stringValue = value.AsString();
    if (stringValue == "A") {
        return INCOME_A;
    } else if (stringValue == "B1") {
        return INCOME_B1;
    } else if (stringValue == "B2") {
        return INCOME_B2;
    } else if (stringValue == "C1") {
        return INCOME_C1;
    } else if (stringValue == "C2") {
        return INCOME_C2;
    } else {
        ythrow yexception() << "Unknown income " << stringValue;
    }
}

// Converts the "affinitive_site_ids" map {site id -> double weight} into tokens.
// Returns empty tokens when the column is not a map; throws if the column is absent (At()).
TTokens TExtractUserData::GetAffinitiveSites(const TNode& row) {
    TTokens result;
    // Bind by reference: plain `auto` deep-copied the whole map node on every call.
    const auto& affinitiveSiteIds = row.At("affinitive_site_ids");
    if (!affinitiveSiteIds.IsMap()) {
        return result;
    }
    for (const auto& [siteId, weight] : affinitiveSiteIds.AsMap()) {
        auto* token = result.AddToken();
        token->SetToken(siteId);
        token->SetWeight(weight.AsDouble());
    }
    return result;
}

// Converts the "top_common_site_ids" list of uint64 site ids into tokens of weight 1.0.
// Returns empty tokens when the column is not a list; throws if the column is absent (At()).
TTokens TExtractUserData::GetTopCommonSites(const TNode& row) {
    TTokens result;
    // Bind by reference: plain `auto` deep-copied the whole list node on every call.
    const auto& topCommonSiteIds = row.At("top_common_site_ids");
    if (!topCommonSiteIds.IsList()) {
        return result;
    }
    for (const auto& siteId : topCommonSiteIds.AsList()) {
        auto* token = result.AddToken();
        token->SetToken(ToString(siteId.AsUint64()));
        token->SetWeight(1.0);
    }
    return result;
}

// Unpacks the row's "vector" only when NDates::HasActiveDate accepts its "days_active"
// against `oldestActiveTimestamp`; returns whether the row was accepted.
bool TExtractUserData::ProcessMonthlyVectors(const TNode& row, TUserData& userData, time_t oldestActiveTimestamp) {
    const bool isActive = NDates::HasActiveDate(row, "days_active", oldestActiveTimestamp);
    if (isActive) {
        UnpackVector(row.At("vector").AsString(), userData);
    }
    return isActive;
}

bool TExtractUserData::ProcessApps(const TNode& row, TUserData& userData) {
    auto appsBundleIds = row.At("apps");
    if (!appsBundleIds.IsList()) {
        return false;
    }
    for (const auto& item : appsBundleIds.AsList()) {
        auto* token = userData.MutableAffinities()->MutableApps()->AddToken();
        auto appId = item.AsString();
        token->SetToken(appId);
        token->SetWeight(AppsWeights[appId]);
    }
    return true;
}

// A words-table row carries a single ("word", "weight") pair; it is always accepted.
bool TExtractUserData::ProcessWords(const TNode& row, TUserData& userData) {
    SetTokenAndWeight("word", row, userData);
    return true;
}

// A hosts-table row carries a single ("host", "weight") pair; it is always accepted.
bool TExtractUserData::ProcessHosts(const TNode& row, TUserData& userData) {
    SetTokenAndWeight("host", row, userData);
    return true;
}

// Extracts segments (per configured keyword), socdem attributes and site affinities from a
// profiles-log row. Always accepted.
bool TExtractUserData::ProcessProfilesLog(const TNode& row, TUserData& userData, const TExtractUserDataState::TProfilesLogState& state) {
    for (const auto& keywordAndFields : state.GetKeywordFields()) {
        AddSegments(row, keywordAndFields.second, keywordAndFields.first, &userData);
    }

    // Rows without any segments must get an empty Segments message instead of Null.
    if (!userData.HasSegments()) {
        userData.MutableSegments()->MergeFrom(TUserData_TSegments());
    }

    auto& attributes = *userData.MutableAttributes();
    attributes.SetIncome(GetIncome(row));
    attributes.SetGender(GetGender(row));
    attributes.SetAge(GetAge(row));

    auto& affinities = *userData.MutableAffinities();
    affinities.MutableAffinitiveSites()->MergeFrom(GetAffinitiveSites(row));
    affinities.MutableTopCommonSites()->MergeFrom(GetTopCommonSites(row));
    return true;
}

bool TExtractUserData::ProcessYuidCid(const TNode& row, TUserData& userData) {
    userData.SetYandexuid(row["yuid"].AsString());
    userData.SetCryptaID(row["cid"].AsString());

    return true;
}

// Loads the App -> Weight table from the "apps_weights" file into AppsWeights,
// but only when the state enables apps weights.
void TExtractUserData::Start(TWriter*) {
    // Userdata by yandexuid mapper spec doesn't have external table
    // TODO(terekhinam): remove condition after turning off userData by yandexuids
    if (!State->GetUseAppsWeights()) {
        return;
    }

    TFileInput appsWeightsDump("apps_weights");
    auto weightsReader = CreateTableReader<TNode>(&appsWeightsDump);
    for (auto& cursor : *weightsReader) {
        const auto& weightRow = cursor.GetRow();
        AppsWeights[weightRow["App"].AsString()] = weightRow["Weight"].AsUint64();
    }
}

// Main mapper loop. For each sampled row with a recognised id, dispatches on the input table
// index to the matching Process* handler and, if the handler accepts the row, writes a
// TUserData keyed by the row's yandexuid or crypta_id.
// Throws yexception for an unknown table index, or for a non-crypta_id row in table 6.
void TExtractUserData::Do(TTableReader<TNode>* input, TTableWriter<TUserData>* output) {
    NCrypta::TRestSampler restSampler(State->GetSampler().GetDenominator(), State->GetSampler().GetRest());

    const auto& profilesLogState = State->GetProfilesLogState();
    // A single message is reused for all rows and cleared before each one.
    TUserData userData;

    for (; input->IsValid(); input->Next()) {
        const auto& row = input->GetRow();

        const auto idWithType = GetIdWithType(row);
        if (idWithType.Empty()) {
            continue;
        }
        const auto id = idWithType->Id;
        const auto type = idWithType->Type;

        if (!restSampler.Passes(id)) {
            continue;
        }

        userData.Clear();
        bool writeToTable = false;
        const auto tableIndex = input->GetTableIndex();

        switch (tableIndex) {
            case 0:
                writeToTable = ProcessMonthlyVectors(row, userData, State->GetOldestTimestamp());
                break;
            case 1:
                writeToTable = ProcessRegionsAndDevice(row, userData, type);
                break;
            case 2:
                writeToTable = ProcessProfilesLog(row, userData, profilesLogState);
                break;
            case 3:
                if (type == TId::EIdType::CryptaId) {
                    writeToTable = ProcessDeviceCryptaID(row, userData);
                } else if (type == TId::EIdType::Yandexuid) {
                    writeToTable = ProcessYuidCid(row, userData);
                }
                break;
            case 4:
                writeToTable = ProcessWords(row, userData);
                break;
            case 5:
                writeToTable = ProcessHosts(row, userData);
                break;
            case 6:
                // Table 6 (apps) is only handled for crypta_id rows; report a wrong id type
                // explicitly instead of as an "invalid table index".
                if (type != TId::EIdType::CryptaId) {
                    ythrow yexception() << "Table index " << tableIndex << " expects crypta_id rows";
                }
                writeToTable = ProcessApps(row, userData);
                break;
            default:
                ythrow yexception() << "Invalid table index " << tableIndex;
        }

        if (writeToTable) {
            if (type == TId::EIdType::Yandexuid) {
                userData.SetYandexuid(ToString(id));
            } else {
                userData.SetCryptaID(ToString(id));
            }
            output->AddRow(userData);
        }
    }
}
