#include <crypta/lab/lib/native/matching.h>
#include <library/cpp/digest/md5/md5.h>
#include <util/generic/maybe.h>
#include <util/generic/set.h>
#include <util/generic/hash.h>
#include <contrib/libs/openssl/include/openssl/sha.h>
#include <util/stream/str.h>
#include <util/stream/format.h>
#include <crypta/lib/native/identifiers/lib/generic.h>



/// Computes the SHA-256 digest of `value` and returns it as a hex string:
/// each digest byte is printed via `Hex(byte, HF_FULL)` (zero-padded to the byte's width).
TString GetSHA256(const TString& value) {
    unsigned char hash[SHA256_DIGEST_LENGTH];
    SHA256(reinterpret_cast<const unsigned char*>(value.data()), value.size(), hash);

    TStringStream hexStream;
    for (const unsigned char byte : hash) {
        hexStream << Hex(byte, HF_FULL);
    }
    return hexStream.Str();
}

/// Converts a raw identifier string into its mapped (normalized and/or hashed) form.
/// An empty result signals that the raw value could not be mapped; callers skip such values.
struct IIdValueMapper {
    virtual ~IIdValueMapper() = default;

    virtual TMaybe<TString> Map(const TString& raw) const = 0;
};

/// Hashes the raw value with MD5 and returns the digest normalized as a generic "md5" identifier.
struct TMd5IdValueMapper: public IIdValueMapper {
    TMaybe<TString> Map(const TString& raw) const override {
        const NIdentifiers::TGenericID md5Id("md5", MD5::Calc(raw));
        return md5Id.Normalize();
    }
};

/// Hashes the raw value with SHA-256 (see GetSHA256) and returns the digest normalized
/// as a generic "sha256" identifier.
struct TSha256IdValueMapper: public IIdValueMapper {
    TMaybe<TString> Map(const TString& raw) const override {
        const NIdentifiers::TGenericID sha256Id("sha256", GetSHA256(raw));
        return sha256Id.Normalize();
    }
};

/// Pass-through mapper: every raw value is accepted and returned unchanged.
struct TIdentityIdValueMapper: public IIdValueMapper {
    TMaybe<TString> Map(const TString& raw) const override {
        return TMaybe<TString>(raw);
    }
};

/// Validates and normalizes raw identifiers of one fixed identifier type using
/// NIdentifiers::TGenericID. Invalid values map to an empty TMaybe.
struct TGenericIDMapper: public IIdValueMapper {
    TGenericIDMapper() = default;

    // explicit: an identifier type should never silently convert into a mapper object.
    explicit TGenericIDMapper(const NCrypta::NIdentifiersProto::NIdType::EIdType& identifierType)
        : type(identifierType)
    {
    }

    /// Returns the normalized form of `raw`, or an empty TMaybe if `raw` is not a valid
    /// identifier of the configured type.
    TMaybe<TString> Map(const TString& raw) const override {
        const NIdentifiers::TGenericID genericId(type, raw);
        if (!genericId.IsValid()) {
            return Nothing();
        }
        return genericId.Normalize();
    }

private:
    NCrypta::NIdentifiersProto::NIdType::EIdType type = NCrypta::NIdentifiersProto::NIdType::DEFAULT;
};

/// Builds the mapper that validates/normalizes source identifiers.
///
/// The identifier type comes from ConvertLabType(id type), except that MD5 and SHA256
/// hashing methods force the MD5 / SHA256 identifier type respectively.
/// Throws yexception("Unsupported lab type") when no identifier type can be resolved.
THolder<IIdValueMapper> GetMapper(const NLab::TMatchingOptions& matchingOptions) {
    // Always resolved first, even when the hashing method overrides it below.
    auto resolvedType = ConvertLabType(matchingOptions.GetIdType());

    switch (matchingOptions.GetHashingMethod()) {
        case NLab::HM_MD5:
            resolvedType = MakeMaybe(NCrypta::NIdentifiersProto::NIdType::MD5);
            break;
        case NLab::HM_SHA256:
            resolvedType = MakeMaybe(NCrypta::NIdentifiersProto::NIdType::SHA256);
            break;
        default:
            break;
    }

    if (!resolvedType) {
        ythrow yexception() << "Unsupported lab type";
    }
    return MakeHolder<TGenericIDMapper>(resolvedType.GetRef());
}

/// Returns the value mapper implementing `hashingMethod`: identity, MD5 or SHA-256.
/// Throws yexception("Unsupported hashing method") for any other method.
THolder<IIdValueMapper> GetHasher(NLab::EHashingMethod hashingMethod) {
    if (hashingMethod == NLab::HM_IDENTITY) {
        return MakeHolder<TIdentityIdValueMapper>();
    }
    if (hashingMethod == NLab::HM_MD5) {
        return MakeHolder<TMd5IdValueMapper>();
    }
    if (hashingMethod == NLab::HM_SHA256) {
        return MakeHolder<TSha256IdValueMapper>();
    }
    ythrow yexception() << "Unsupported hashing method";
}

/// Returns the graph `id_type` names that correspond to the lab identifier type of
/// `matchingOptions`.
///
/// Email and phone also include their "_md5" variants when the options use MD5 hashing.
/// Identifier types not listed below yield an empty set.
TSet<TString> GetTypes(const NLab::TMatchingOptions& matchingOptions) {
    const bool withMd5 = matchingOptions.GetHashingMethod() == NLab::HM_MD5;
    TSet<TString> types;

    // Inserts `base` and, when MD5 hashing is requested, its "<base>_md5" companion type.
    const auto insertWithMd5 = [&types, withMd5](const TString& base) {
        types.insert(base);
        if (withMd5) {
            types.insert(base + "_md5");
        }
    };

    switch (matchingOptions.GetIdType()) {
        case NLab::LAB_ID_LOGIN:
        case NLab::LAB_ID_DIRECT_CLIENT_ID: // ClientIDs are pre-mapped into logins
            types.insert("login");
            break;
        case NLab::LAB_ID_CRYPTA_ID:
            types.insert("crypta_id");
            break;
        case NLab::LAB_ID_YANDEXUID:
            types.insert("yandexuid");
            break;
        case NLab::LAB_ID_ICOOKIE:
            types.insert("icookie");
            break;
        case NLab::LAB_ID_IDFA_GAID:
            types.insert("idfa_gaid");
            types.insert("idfa");
            types.insert("gaid");
            break;
        case NLab::LAB_ID_MM_DEVICE_ID:
            types.insert("mm_device_id");
            break;
        case NLab::LAB_ID_EMAIL:
            insertWithMd5("email");
            break;
        case NLab::LAB_ID_PHONE:
            insertWithMd5("phone");
            break;
        case NLab::LAB_ID_PUID:
            types.insert("puid");
            break;
        case NLab::LAB_ID_UUID:
            types.insert("uuid");
            break;
        default:
            break;
    }
    return types;
}

/// Map step: normalizes the source identifier of every input row.
///
/// Rows whose source key is null/undefined, or whose value the mapper rejects, are skipped.
/// The mapped value is written to "cryptaId" when the source id type is CryptaID outside the
/// IN_DEVICE scope, and to "id" otherwise. With include-original enabled the whole input row is
/// kept and a pre-existing value of that column is backed up into "_<name>"; otherwise only the
/// mapped column is emitted. Unless the source hashing method is HM_IDENTITY, a second row is
/// emitted with the mapped value additionally passed through the source hasher.
void TComputeMatchingIdMapper::Do(TTableReader<TNode>* input, TTableWriter<TNode>* output) {
    const auto& source = State->GetSource();
    const auto& destination = State->GetDestination();

    auto mapper = GetMapper(source);
    auto hasher = GetHasher(source.GetHashingMethod());
    const auto idKey = source.GetKey();

    const bool emitHashed = source.GetHashingMethod() != NLab::HM_IDENTITY;
    const bool includeOriginal = destination.GetIncludeOriginal();
    const bool inDevice = destination.GetScope() == NLab::IN_DEVICE;
    const TString mappedRowName = (inDevice || source.GetIdType() != NLab::LAB_ID_CRYPTA_ID) ? "id" : "cryptaId";

    for (; input->IsValid(); input->Next()) {
        auto row = input->GetRow();
        const TNode idNode = row[idKey];
        if (idNode.IsNull() || idNode.IsUndefined()) {
            continue;
        }

        const TMaybe<TString> mapped = mapper->Map(idNode.ConvertTo<TString>());
        if (!mapped.Defined()) {
            continue;
        }

        TNode outRow;
        if (includeOriginal) {
            outRow = std::move(row);
            // Preserve a clashing original column so later reducers can restore it.
            if (outRow.HasKey(mappedRowName)) {
                outRow["_" + mappedRowName] = outRow[mappedRowName];
            }
        }

        outRow[mappedRowName] = mapped.GetRef();
        output->AddRow(outRow);

        if (emitHashed) {
            outRow[mappedRowName] = hasher->Map(mapped.GetRef()).GetRef();
            output->AddRow(outRow);
        }
    }
}

/// Reduce step that attaches a CryptaID to the source rows of one key group.
///
/// Table 0 rows provide the CryptaID: the `cryptaId` column of a row whose `id_type` belongs to
/// GetTypes(source) (the last such row wins). Table 1 rows are the mapped source rows.
/// If the source id type already is CryptaID, table 0 is ignored and table 1 rows pass through.
/// If no CryptaID was found for the group, processing of the group stops.
/// NOTE(review): relies on all table-0 rows of a group arriving before its table-1 rows —
/// confirm against the reduce operation spec.
void TJoinCryptaIDReducer::Do(TTableReader<TNode>* input, TTableWriter<TNode>* output) {
    const auto& source = State->GetSource();
    const auto& destination = State->GetDestination();
    const auto srcTypes = GetTypes(source);
    const bool sourceIsCryptaId = source.GetIdType() == NLab::LAB_ID_CRYPTA_ID;

    TString mappedRowName = sourceIsCryptaId ? "cryptaId" : "id";
    if (destination.GetScope() == NLab::IN_DEVICE) {
        mappedRowName = "id";
    }
    const TString backupRowName = "_" + mappedRowName;

    TMaybe<TString> cryptaID;
    // Created on first use only, so groups that never emit a hashed CryptaID do not build it.
    THolder<IIdValueMapper> dstHasher;

    for (; input->IsValid(); input->Next()) {
        auto row = input->GetRow();
        const auto tableIndex = input->GetTableIndex();

        if (tableIndex == 0) {
            if (sourceIsCryptaId) {
                continue;
            }
            const TString idType = row["id_type"].ConvertTo<TString>();
            if (srcTypes.contains(idType)) {
                cryptaID = row["cryptaId"].ConvertTo<TString>();
            }
            continue;
        }

        if (tableIndex != 1) {
            continue;
        }
        if (sourceIsCryptaId) {
            output->AddRow(row);
            continue;
        }
        if (!cryptaID) {
            return;
        }

        if (destination.GetIdType() == NLab::LAB_ID_CRYPTA_ID) {
            const auto dstIdKey = destination.GetKey();

            // Hash into a local copy. Overwriting `cryptaID` itself would re-hash the already
            // hashed value for every further table-1 row of the same key group.
            TString dstCryptaID = cryptaID.GetRef();
            if (destination.GetHashingMethod() != NLab::HM_IDENTITY) {
                if (!dstHasher) {
                    dstHasher = GetHasher(destination.GetHashingMethod());
                }
                dstCryptaID = dstHasher->Map(dstCryptaID).GetRef();
            }

            if (destination.GetIncludeOriginal()) {
                // Undo the mapper's column rewrite: restore the backed-up original column,
                // or drop the column the mapper added.
                if (row.HasKey(backupRowName)) {
                    row[mappedRowName] = row[backupRowName];
                    row.AsMap().erase(backupRowName);
                } else {
                    row.AsMap().erase(mappedRowName);
                }
                row[dstIdKey] = dstCryptaID;
                output->AddRow(row);
            } else {
                TNode result;
                result[dstIdKey] = dstCryptaID;
                output->AddRow(result);
            }
        } else {
            TNode result;
            if (destination.GetIncludeOriginal()) {
                result = std::move(row);
                // Keep a pre-existing cryptaId value; the column itself is overwritten below.
                if (result.HasKey("cryptaId")) {
                    result["_cryptaId"] = result["cryptaId"];
                }
            }
            result["cryptaId"] = cryptaID.GetRef();
            output->AddRow(result);
        }
    }
}

/// Reduce step that expands one key group into destination identifiers.
///
/// Table 0 rows contribute identifiers. In IN_DEVICE scope every row's `target_id` is used;
/// otherwise the `id` of rows whose `id_type` is in GetTypes(destination). Each value goes through
/// the destination hasher and is deduplicated (ordered set). Table 1 rows get the "_<name>"
/// backup column restored, or the mapped column dropped. With include-original enabled they are
/// then emitted once per identifier collected so far, with the identifier under the destination
/// key. Without include-original, one row per identifier is emitted after the group is consumed.
void TJoinIdentifiersReducer::Do(TTableReader<TNode>* input, TTableWriter<TNode>* output) {
    const auto& destination = State->GetDestination();
    auto hasher = GetHasher(destination.GetHashingMethod());
    const auto dstTypes = GetTypes(destination);
    const auto dstIdKey = destination.GetKey();
    const auto srcIdType = State->GetSource().GetIdType();

    const bool inDevice = destination.GetScope() == NLab::IN_DEVICE;
    const bool includeOriginal = destination.GetIncludeOriginal();
    // TODO this is some hack, needs better way: non-CryptaID sources must not leak cryptaId.
    const bool dropCryptaId = srcIdType != NLab::LAB_ID_CRYPTA_ID;
    const TString mappedRowName = (inDevice || srcIdType != NLab::LAB_ID_CRYPTA_ID) ? "id" : "cryptaId";
    const TString backupRowName = "_" + mappedRowName;

    TSet<TString> identifiers;
    for (; input->IsValid(); input->Next()) {
        auto row = input->GetRow();
        switch (input->GetTableIndex()) {
            case 0: {
                if (inDevice || dstTypes.contains(row["id_type"].ConvertTo<TString>())) {
                    const TString id = row[inDevice ? "target_id" : "id"].ConvertTo<TString>();
                    identifiers.insert(hasher->Map(id).GetRef());
                }
                break;
            }
            case 1: {
                if (row.HasKey(backupRowName)) {
                    row[mappedRowName] = row[backupRowName];
                    row.AsMap().erase(backupRowName);
                } else {
                    row.AsMap().erase(mappedRowName);
                }
                if (!includeOriginal) {
                    break;
                }
                if (dropCryptaId) {
                    row.AsMap().erase("cryptaId");
                }
                for (const auto& identifier : identifiers) {
                    row[dstIdKey] = identifier;
                    output->AddRow(row);
                }
                break;
            }
            default:
                break;
        }
    }

    if (includeOriginal) {
        return;
    }
    for (const auto& identifier : identifiers) {
        TNode result;
        result[dstIdKey] = identifier;
        output->AddRow(result);
    }
}

/// Reduce step that counts graph identifiers per `id_type` within one key group.
///
/// Every table-0 row increments the counter of its `id_type`. With include-original enabled,
/// each table-1 row is emitted with the counts map under the destination key, the "_<name>"
/// backup column restored (or the mapped column dropped), `cryptaId` removed and `ccIdType`
/// set to "cryptaId". Otherwise a single row holding the counts and `ccIdType` is emitted
/// after the whole group is consumed. A group without table-0 rows yields an entity (null)
/// counts value in both modes.
/// NOTE(review): relies on table-0 rows of a group preceding its table-1 rows — confirm
/// against the reduce operation spec.
void TJoinIdentifiersStatisticReducer::Do(TTableReader<TNode>* input, TTableWriter<TNode>* output) {
    const auto& destination = State->GetDestination();
    const auto dstIdKey = destination.GetKey();
    const bool includeOriginal = destination.GetIncludeOriginal();

    TString mappedRowName = State->GetSource().GetIdType() == NLab::LAB_ID_CRYPTA_ID ? "cryptaId" : "id";
    if (destination.GetScope() == NLab::IN_DEVICE) {
        mappedRowName = "id";
    }
    const TString backupRowName = "_" + mappedRowName;

    THashMap<TString, ui64> identifierCount;

    // Converts the counters into a TNode map. An empty group yields an entity; previously the
    // aggregated path wrote a default-constructed (undefined) TNode here, unlike the
    // include-original path.
    const auto buildCounts = [&identifierCount]() -> TNode {
        if (identifierCount.empty()) {
            return TNode::CreateEntity();
        }
        TNode counts = TNode::CreateMap();
        for (const auto& [idType, count] : identifierCount) {
            counts[idType] = count;
        }
        return counts;
    };

    for (; input->IsValid(); input->Next()) {
        auto row = input->GetRow();
        const auto tableIndex = input->GetTableIndex();

        if (tableIndex == 0) {
            ++identifierCount[row["id_type"].ConvertTo<TString>()];
            continue;
        }
        if (tableIndex != 1 || !includeOriginal) {
            continue;
        }

        row[dstIdKey] = buildCounts();
        if (row.HasKey(backupRowName)) {
            row[mappedRowName] = row[backupRowName];
            row.AsMap().erase(backupRowName);
        } else {
            row.AsMap().erase(mappedRowName);
        }
        row["ccIdType"] = "cryptaId";
        if (row.HasKey("cryptaId")) {
            row.AsMap().erase("cryptaId");
        }
        output->AddRow(row);
    }

    if (!includeOriginal) {
        TNode result;
        result[dstIdKey] = buildCounts();
        result["ccIdType"] = "cryptaId";
        output->AddRow(result);
    }
}
