#include "station_segments_index.h"

using namespace NRasp;

// Builds the per-station segment index.
//
// Pass 1 gathers every distinct station referenced by a thread station that has both a station
// and a thread, and whose thread's transport type is listed in `transportTypes`.
// Pass 2 stores, for each such station, its departure segments and then its arrival segments.
// Both lists are restricted to the window that starts at the current UTC time (taken once, here)
// and ends 7 days later.
// NOTE(review): the window is anchored at construction time; nothing in this block refreshes it.
TStationSegmentsIndex::TStationSegmentsIndex(
    const TWrappedRaspDatabase& database,
    const TRaspSearchIndex& searchIndex,
    const TTransportSet& transportTypes) {
    constexpr int indexHorizonDays = 7;

    THashSet<object_id_t> indexedStationIds;
    for (const auto& threadStation : database.GetItems<TThreadStationWrapper>()) {
        // Skip records that do not reference both a station and a thread.
        if (!threadStation.HasStation() || !threadStation.HasThread()) {
            continue;
        }
        // Skip threads whose transport type was not requested.
        if (!transportTypes.contains(threadStation.Thread().TransportType())) {
            continue;
        }
        indexedStationIds.insert(threadStation.StationId());
    }

    const auto windowBeginTM = NDatetime::TSimpleTM::CurrentUTC();
    auto windowEndTM = windowBeginTM;
    windowEndTM.Add(NDatetime::TSimpleTM::EField::F_DAY, indexHorizonDays);

    for (const object_id_t stationId : indexedStationIds) {
        // The departure list is computed before the arrival list (braced-init order).
        Index[stationId] = {
            FindSegments(
                searchIndex, stationId, windowBeginTM, windowEndTM, TCommonSegmentFinder::EEventType::Departure),
            FindSegments(
                searchIndex, stationId, windowBeginTM, windowEndTM, TCommonSegmentFinder::EEventType::Arrival)};
    }
}

// Returns the segments of `stationId` for a single event type (departure or arrival).
//
// Raw segments come from TCommonSegmentFinder::FindStationSegments. They are then narrowed by
// TDateSegmentFilter::Filter to the [minEventTM, maxEventTM] window; whether each bound is
// inclusive is decided by the filter. Each surviving holder is dereferenced and copied into the
// returned list, preserving the filter's output order.
TSegments TStationSegmentsIndex::FindSegments(
    const TRaspSearchIndex& searchIndex,
    const object_id_t stationId,
    const NDatetime::TSimpleTM& minEventTM,
    const NDatetime::TSimpleTM& maxEventTM,
    const TCommonSegmentFinder::EEventType eventType) const {
    const TCommonSegmentFinder segmentFinder;
    const TDateSegmentFilter dateFilter;

    const auto stationSegments = segmentFinder.FindStationSegments(searchIndex, stationId, eventType);
    // NOTE(review): the trailing `false, 0` arguments are forwarded unchanged; their meaning is
    // defined by TDateSegmentFilter::Filter — confirm there before changing them.
    const auto filteredHolders =
        dateFilter.Filter(stationSegments, minEventTM, maxEventTM, eventType, false, 0);

    TSegments copiedSegments;
    for (const auto& holder : filteredHolders) {
        copiedSegments.emplace_back(*holder);
    }
    return copiedSegments;
}

// Finds, by binary search, the half-open index range [first, last) of `segments`.
// `first` is the first segment whose event time (as reported by `getEventTM`) is not earlier
// than minEventTM. `last` is the first segment at or after `first` whose event time is not
// earlier than maxEventTM.
//
// Assumes `segments` is ordered by non-decreasing getEventTM(segment); otherwise the result is
// unspecified. Because the second search starts at `first`, the range is empty (never inverted)
// when maxEventTM <= minEventTM.
TArrayRange TStationSegmentsIndex::GetSegmentsRange(
    const TSegments& segments,
    const NDatetime::TSimpleTM& minEventTM,
    const NDatetime::TSimpleTM& maxEventTM,
    const TSegmentEventTMGetter getEventTM) const {
    const auto segmentsBegin = segments.begin();
    const auto segmentsEnd = segments.end();

    const auto firstIt = std::partition_point(segmentsBegin, segmentsEnd, [&](const TSegment& segment) {
        return getEventTM(segment) < minEventTM;
    });
    const auto lastIt = std::partition_point(firstIt, segmentsEnd, [&](const TSegment& segment) {
        return getEventTM(segment) < maxEventTM;
    });

    const size_t firstIndex = static_cast<size_t>(firstIt - segmentsBegin);
    const size_t lastIndex = static_cast<size_t>(lastIt - segmentsBegin);
    return {firstIndex, lastIndex};
}

// Looks up the precomputed segments of `stationId` and returns two entries. The first is the
// departure list, paired with the index range of departures whose departure time lies in
// [minEventTM, maxEventTM). The second is the arrival list, paired with the matching range of
// arrival times. The list pointers refer to storage owned by this index.
// A station that is not in the index yields a default `{}` result.
TStationSegmentsWithRanges TStationSegmentsIndex::GetSegments(
    const object_id_t stationId,
    const NDatetime::TSimpleTM& minEventTM,
    const NDatetime::TSimpleTM& maxEventTM) const {
    if (const auto found = Index.find(stationId); found != Index.end()) {
        const auto& [departures, arrivals] = found->second;
        const auto departureRange =
            GetSegmentsRange(departures, minEventTM, maxEventTM, TSegment::GetDepartureTM);
        const auto arrivalRange =
            GetSegmentsRange(arrivals, minEventTM, maxEventTM, TSegment::GetArrivalTM);
        return {{&departures, departureRange}, {&arrivals, arrivalRange}};
    }
    return {};
}

// Convenience overload: unpacks the station id and the time window from `query` and delegates
// to the explicit-argument overload.
TStationSegmentsWithRanges TStationSegmentsIndex::GetSegments(const TStationQuery& query) const {
    const auto& windowBegin = query.MinEventTM;
    const auto& windowEnd = query.MaxEventTM;
    return this->GetSegments(query.StationId, windowBegin, windowEnd);
}
