#include "zone_gator.h"
#include <kernel/keyinv/hitlist/full_pos.h>


// Constructs the gator from raw per-word weights and a sentence-zones reader.
//
// wordWeight          - array of wordWeightSize raw word weights.
// wordWeightSize      - number of entries in wordWeight; also passed to every
//                       per-zone level's Init().
// sentenceZonesReader - stored as-is (not copied) and queried for the zone count.
//
// The weights are divided by (1e-6 + their sum) and stored in
// NormalizedWordWeight; one Bm25Levels entry is created and initialised per zone.
TZoneGator::TZoneGator(const float* wordWeight, const ui32 wordWeightSize, const ISentenceZonesReader* sentenceZonesReader)
    : SentenceZonesReader(sentenceZonesReader)
{
    // The accumulator starts at a small epsilon, so the divisor below is
    // non-zero even when every weight is zero (or there are no weights at all).
    float weightTotal = 1e-6f;
    for (size_t idx = 0; idx != wordWeightSize; ++idx) {
        weightTotal += wordWeight[idx];
    }
    Y_ASSERT(weightTotal > 0.f);

    NormalizedWordWeight.resize(wordWeightSize);
    for (size_t idx = 0; idx != wordWeightSize; ++idx) {
        NormalizedWordWeight[idx] = wordWeight[idx] / weightTotal;
    }

    // One level accumulator per zone reported by the reader.
    const ui32 zoneCount = SentenceZonesReader->GetNumberOfZones();
    Bm25Levels.resize(zoneCount);
    for (auto& level : Bm25Levels) {
        level.Init(wordWeightSize);
    }
}


// Switches the gator to a new document: remembers the document id taken from
// gatorParams and calls NewDoc() on the level accumulator of every zone.
void TZoneGator::InitNextDoc(TNewDocParams* gatorParams) {
    DocId = gatorParams->DocId;

    const ui32 zoneCount = SentenceZonesReader->GetNumberOfZones();
    for (ui32 zone = 0; zone != zoneCount; ++zone) {
        Bm25Levels[zone].NewDoc();
    }
}


// Remembers the factor storage that CalcFeatures() writes its results into.
// Only the pointer is stored; nothing is copied here.
void TZoneGator::SetFactorStorage(TFactorStorage* factors) {
    this->Factors = factors;
}


// Feeds a batch of hits into the per-zone accumulators.
//
// pos   - array of `count` positions.
// count - number of entries in `pos`.
// rt    - relevance type of the batch; anything other than RT_TEXT is ignored.
void TZoneGator::AddPositions(TFullPositionEx* pos, const size_t count, ERelevanceType rt) {
    if (rt != RT_TEXT) {
        return;
    }

    for (size_t idx = 0; idx < count; ++idx) {
        TFullPositionEx& hit = pos[idx];
        Add(hit.Pos, hit.WordIdx);
    }
}


void TZoneGator::Add(const TFullPosition& pos, const size_t wordIdx) {
    const ui32 sent = TWordPosition::Break(pos.Beg);
    const TSentenceZones zones = SentenceZonesReader->GetSentZones(DocId, sent);
    ui32 mask = 1;
    for (ui32 i = 0; i < SentenceZonesReader->GetNumberOfZones(); ++i) {
        if (zones & mask) {
            Bm25Levels[i].Add(pos, wordIdx);
        }
        mask <<= 1;
    }
}


// Writes the zone factors of the current document into the factor storage.
//
// For every zone:
//   * zftZL receives the zone length (looked up with EQUAL_BY_STRING);
//   * for every form class from EQUAL_BY_STRING to EQUAL_BY_SYNONYM inclusive,
//     zftBM25, zftCM, zftCZ, zftCZL and zftIF receive the matching score of the
//     zone's level accumulator.
// A factor is written only when the reader reports an index for it.
// The parameter is unused.
void TZoneGator::CalcFeatures(TPosGatorCalcFeaturesParams&) {
    const ui32 zoneCount = SentenceZonesReader->GetNumberOfZones();

    for (ui32 zone = 0; zone != zoneCount; ++zone) {
        const ui32 zoneLength = SentenceZonesReader->GetZoneLength(DocId, zone);
        const ui32 zoneAvgLength = SentenceZonesReader->GetZoneAvgLength(zone);
        auto& level = Bm25Levels[zone];

        ui32 zoneLengthIndex;
        if (SentenceZonesReader->GetFeatureIndex(NZoneFactors::zftZL, zone, EQUAL_BY_STRING, zoneLengthIndex)) {
            (*Factors)[zoneLengthIndex] = zoneLength;
        }

        for (ui32 form = EQUAL_BY_STRING; form <= EQUAL_BY_SYNONYM; ++form) {
            const EFormClass formClass = static_cast<EFormClass>(form);
            ui32 slot;

            if (SentenceZonesReader->GetFeatureIndex(NZoneFactors::zftBM25, zone, formClass, slot)) {
                (*Factors)[slot] = level.CalcScore(NormalizedWordWeight, form, zoneLength, zoneAvgLength);
            }

            if (SentenceZonesReader->GetFeatureIndex(NZoneFactors::zftCM, zone, formClass, slot)) {
                (*Factors)[slot] = level.CalcCmScore(form);
            }

            if (SentenceZonesReader->GetFeatureIndex(NZoneFactors::zftCZ, zone, formClass, slot)) {
                (*Factors)[slot] = level.CalcCZScore(form);
            }

            if (SentenceZonesReader->GetFeatureIndex(NZoneFactors::zftCZL, zone, formClass, slot)) {
                // CZL is the CZ score divided by the zone length; a zero-length
                // zone yields 0 instead of dividing by zero.
                if (zoneLength == 0) {
                    (*Factors)[slot] = 0;
                } else {
                    (*Factors)[slot] = level.CalcCZScore(form) / zoneLength;
                }
            }

            if (SentenceZonesReader->GetFeatureIndex(NZoneFactors::zftIF, zone, formClass, slot)) {
                (*Factors)[slot] = level.CalcInvFreq(form);
            }
        }
    }
}
