#include "bm25levels.h"
#include <kernel/keyinv/hitlist/full_pos.h>
#include <library/cpp/wordpos/wordpos.h>
#include <util/system/yassert.h>


void TBm25LevelsTracker::Init(const size_t queryWordNumber) {
    Y_ASSERT(queryWordNumber > 0);
    QueryWordNumber = queryWordNumber;

    for (ui32 i = EQUAL_BY_STRING; i <= EQUAL_BY_SYNONYM; ++i) {
        WordFreq[i].resize(queryWordNumber);
    }

    NewDoc();
}


void TBm25LevelsTracker::NewDoc() {
    for (ui32 i = EQUAL_BY_STRING; i <= EQUAL_BY_SYNONYM; ++i) {
        for (size_t j = 0; j < QueryWordNumber; ++j) {
            WordFreq[i][j] = 0;
        }
    }
}


// Registers one hit of the query word with index `wordIdx`.
//
// The match form is extracted from `pos.End` with TWordPosition::Form().
// Counters are cumulative across levels: a hit increments the counter of every
// level whose value is >= its form, checked in the order
// EQUAL_BY_SYNONYM, EQUAL_BY_LEMMA, EQUAL_BY_STRING. Hits whose form is above
// EQUAL_BY_SYNONYM are ignored.
// NOTE(review): this presumes EQUAL_BY_STRING <= EQUAL_BY_LEMMA <= EQUAL_BY_SYNONYM;
// the enum is declared elsewhere - confirm.
//
// `wordIdx` must be less than QueryWordNumber (checked with Y_ASSERT).
void TBm25LevelsTracker::Add(const TFullPosition& pos, const size_t wordIdx) {
    Y_ASSERT(wordIdx < QueryWordNumber);

    const ui32 form = TWordPosition::Form(pos.End);

    if (form > EQUAL_BY_SYNONYM) {
        return;
    }
    ++WordFreq[EQUAL_BY_SYNONYM][wordIdx];

    if (form > EQUAL_BY_LEMMA) {
        return;
    }
    ++WordFreq[EQUAL_BY_LEMMA][wordIdx];

    if (form > EQUAL_BY_STRING) {
        return;
    }
    ++WordFreq[EQUAL_BY_STRING][wordIdx];
}


// Computes a weighted BM25-style score of the current document for match
// level `form`.
//
// For every query word i with accumulated frequency f_i the contribution is
//     wordWeight[i] * f_i / (f_i + k1 * (1 - b + b * zoneLength / avgZoneLength))
// with k1 = 2 and b = 0.75; the contributions are summed.
//
// Parameters:
//   wordWeight    - per-word weights; must have as many elements as the
//                   frequency vector of `form` (checked with Y_ASSERT).
//   form          - match level, must be <= EQUAL_BY_SYNONYM (Y_ASSERT).
//   zoneLength    - length of the scored zone.
//   avgZoneLength - average zone length, must be positive (Y_ASSERT).
//
// Returns the accumulated score; words with zero frequency contribute 0.
float TBm25LevelsTracker::CalcScore(const TVector<float>& wordWeight, const ui32 form, const ui32 zoneLength, const ui32 avgZoneLength) {
    Y_ASSERT(form <= EQUAL_BY_SYNONYM);
    Y_ASSERT(avgZoneLength > 0);
    Y_ASSERT(wordWeight.size() == WordFreq[form].size());

    // These play the roles of the k1 (saturation) and b (length normalisation)
    // parameters of the BM25 term-frequency formula.
    const float k1 = 2.f;
    const float b = 0.75f;

    // The length normalisation does not depend on the word, so compute it once
    // instead of on every loop iteration.
    const float lengthNorm = k1 * (1.f - b + b * static_cast<float>(zoneLength) / static_cast<float>(avgZoneLength));

    float score = 0;
    for (size_t i = 0; i < wordWeight.size(); ++i) {
        const float freq = static_cast<float>(WordFreq[form][i]);
        const float tf = freq / (freq + lengthNorm);
        score += tf * wordWeight[i];
    }

    return score;
}

// Returns CalcCZScore(form) divided by QueryWordNumber, i.e. the share of
// query words that have a non-zero counter at match level `form`.
float TBm25LevelsTracker::CalcCmScore(const ui32 form) {
    const float matched = static_cast<float>(CalcCZScore(form));
    const float total = static_cast<float>(QueryWordNumber);
    return matched / total;
}

float TBm25LevelsTracker::CalcCZScore(const ui32 form) {
    Y_ASSERT(form <= EQUAL_BY_SYNONYM);

    size_t matchedWords = 0;
    for (size_t i = 0; i < QueryWordNumber; ++i) {
        if (WordFreq[form][i] > 0) {
            ++matchedWords;
        }
    }

    return (float)matchedWords;
}

float TBm25LevelsTracker::CalcInvFreq(const ui32 form) {
    Y_ASSERT(form <= EQUAL_BY_SYNONYM);

    float inv = 0;
    size_t count = 0;
    for (size_t i = 0; i < QueryWordNumber; ++i) {
        const ui32 frequency = WordFreq[form][i];
        if (!frequency)
            continue;

        inv += 1.0f / frequency;
        count += 1;
    }

    return count ? inv / count : 0;
}
