#include <library/cpp/timezone_conversion/convert.h>
#include "segment_finders.h"

using namespace NRasp;

// Builds the ordering key used to pick the preferred segment of a thread.
// The tuple is compared lexicographically, so earlier components dominate:
//   1. majority of the departure station,
//   2. majority of the arrival station,
//   3. whether the departure stop has an arrival offset (false sorts first),
//   4. negated departure offset (a larger departure offset gives a smaller key),
//   5. arrival offset.
// Missing offsets are treated as 0.
auto segmentKey(const THolder<TRawSegment>& segment) {
    const auto& from = segment->Departure();
    const auto& to = segment->Arrival();

    const auto fromHasArrival = from.ArrivalOffset().Defined();
    const auto negatedDepartureOffset = -(from.DepartureOffset().GetOrElse(0));
    const auto arrivalOffset = to.ArrivalOffset().GetOrElse(0);

    return std::make_tuple(
        from.StationMajority(),
        to.StationMajority(),
        fromHasArrival,
        negatedDepartureOffset,
        arrivalOffset);
}

// Collects at most one segment per thread among the threads that have stops
// registered for both `fromPointKey` and `toPointKey`.
//
// searchIndex   - index providing per-point stops and thread objects.
// majorityLimit - passed through to FindSegmentsForThread.
// threadNumber  - when non-empty, only threads with exactly this number are kept.
//
// For every thread the candidate with the smallest segmentKey() is returned;
// threads without any candidate contribute nothing.
TVector<THolder<TRawSegment>> TCommonSegmentFinder::Find(
    const TRaspSearchIndex& searchIndex,
    const TPointKey& fromPointKey,
    const TPointKey& toPointKey,
    const TLimitConditions& majorityLimit,
    const TString& threadNumber) const {
    const TPointStops& stopsAtOrigin = searchIndex.GetPointStops(fromPointKey);
    const TPointStops& stopsAtDestination = searchIndex.GetPointStops(toPointKey);
    const bool filterByNumber = !threadNumber.Empty();

    TVector<object_id_t> sharedThreadIds = stopsAtOrigin.GetCommonKeysWith(stopsAtDestination);

    TVector<THolder<TRawSegment>> result;
    for (auto threadId : sharedThreadIds) {
        const auto& thread = searchIndex.GetItemWithId<TRThreadWrapper>(threadId);

        if (filterByNumber && thread.Number() != threadNumber) {
            continue;
        }

        auto candidates = FindSegmentsForThread(
            stopsAtOrigin.at(threadId),
            stopsAtDestination.at(threadId),
            thread,
            majorityLimit);

        if (!candidates.empty()) {
            // Keep only the candidate with the minimal ordering key.
            const auto bestPos = MinElementBy(candidates, segmentKey) - candidates.begin();
            result.emplace_back(std::move(candidates[bestPos]));
        }
    }
    return result;
}

// Builds raw segments for a station timetable.
//
// For EEventType::Departure each qualifying stop of the station is paired with
// the stop returned by GetThreadArrivalStop(); for EEventType::Arrival the stop
// returned by GetThreadDepartureStop() is paired with each qualifying stop.
// A stop qualifies when it has the corresponding event, is flagged
// in_station_schedule and is not code-sharing for that event.
// Threads of type INTERVAL_ID and THROUGH_TRAIN_ID are skipped.
//
// Throws TUnknownEventTypeException for any other event type; the check happens
// per processed thread, so nothing is thrown when no thread reaches it.
TVector<THolder<TRawSegment>> TCommonSegmentFinder::FindStationSegments(
    const TRaspSearchIndex& searchIndex,
    const object_id_t stationId,
    const EEventType eventType) const {
    const bool wantDepartures = eventType == EEventType::Departure;
    const bool wantArrivals = eventType == EEventType::Arrival;

    TVector<THolder<TRawSegment>> result;

    for (const auto& [threadId, threadStops] : searchIndex.GetStationStops(stationId)) {
        const auto& thread = searchIndex.GetItemWithId<TRThreadWrapper>(threadId);

        const auto threadType = thread.TypeId();
        if (threadType == TRThread::INTERVAL_ID || threadType == TRThread::THROUGH_TRAIN_ID) {
            continue;
        }

        if (!wantDepartures && !wantArrivals) {
            ythrow TUnknownEventTypeException() << "Unknown event type";
        }

        if (wantDepartures) {
            const auto lastStop = searchIndex.GetThreadArrivalStop(thread);
            if (!lastStop.Defined()) {
                continue;
            }
            for (const auto stopPtr : threadStops) {
                const auto& item = stopPtr->Item();
                const bool listed =
                    item.has_departure() && item.in_station_schedule() && !item.departure_code_sharing();
                if (listed) {
                    result.emplace_back(MakeHolder<TRawSegment>(*stopPtr, *lastStop.GetRef(), thread));
                }
            }
        } else {
            const auto firstStop = searchIndex.GetThreadDepartureStop(thread);
            if (!firstStop.Defined()) {
                continue;
            }
            for (const auto stopPtr : threadStops) {
                const auto& item = stopPtr->Item();
                const bool listed =
                    item.has_arrival() && item.in_station_schedule() && !item.arrival_code_sharing();
                if (listed) {
                    result.emplace_back(MakeHolder<TRawSegment>(*firstStop.GetRef(), *stopPtr, thread));
                }
            }
        }
    }

    return result;
}

// Enumerates every admissible (departure stop, arrival stop) pair of one thread.
//
// A pair is produced when:
//   - the departure stop is searchable-from and the arrival stop searchable-to,
//   - the arrival stop's Id() is strictly greater than the departure stop's Id(),
//   - majorityLimit allows the two station majorities for the thread's transport type,
//   - the pair is not code-sharing on both ends at once.
//
// NOTE(review): the arrival cursor only moves forward across departures, so both
// lists are presumably ordered by ascending Id() - confirm against the index builder.
TVector<THolder<TRawSegment>> TCommonSegmentFinder::FindSegmentsForThread(
    const TVector<const TThreadStationWrapper*>& departures,
    const TVector<const TThreadStationWrapper*>& arrivals,
    const TRThreadWrapper& thread,
    const TLimitConditions& majorityLimit) const {
    TVector<THolder<TRawSegment>> found;

    // First arrival stop whose Id() exceeds that of the departure being processed.
    auto firstLaterArrival = arrivals.begin();
    const auto arrivalsEnd = arrivals.end();

    for (const auto* from : departures) {
        if (!from->IsSearchableFrom()) {
            continue;
        }

        while (firstLaterArrival != arrivalsEnd && (*firstLaterArrival)->Id() <= from->Id()) {
            ++firstLaterArrival;
        }

        for (auto cursor = firstLaterArrival; cursor != arrivalsEnd; ++cursor) {
            const auto* to = *cursor;

            const bool rejected =
                !to->IsSearchableTo() ||
                !majorityLimit.Allow(from->StationMajority(), to->StationMajority(), thread.TransportType()) ||
                (from->DepartureCodeSharing() && to->ArrivalCodeSharing());
            if (rejected) {
                continue;
            }

            found.emplace_back(MakeHolder<TRawSegment>(*from, *to, thread));
        }
    }
    return found;
}

// Removes through-train segments that duplicate a regular segment.
//
// Segments are grouped by the key produced by TRawSegmentKeyGetter. Within one
// key, a THROUGH_TRAIN_ID segment is kept only if no other-type segment with
// that key has been seen before it, and it is discarded as soon as one shows up.
//
// The input vector is consumed: kept segments are moved out of it.
// The result lists regular groups first, then through-train groups; the order
// of groups follows hash-map iteration order.
TVector<THolder<TRawSegment>> TThroughTrainsFilter::Filter(
    const TRaspSearchIndex& searchIndex,
    TVector<THolder<TRawSegment>>&& threadSegments) const {
    using TGroups = THashMap<ui64, TVector<THolder<TRawSegment>>>;
    TGroups regularByKey;
    TGroups throughByKey;

    TRawSegmentKeyGetter keyOf(searchIndex.GetItems<TStationWrapper>().ysize());

    for (auto& segment : threadSegments) {
        const bool isThrough = segment->Thread().TypeId() == TRThread::THROUGH_TRAIN_ID;
        const auto key = keyOf(*segment);

        if (!isThrough) {
            // A regular segment supersedes any through trains collected for this key.
            throughByKey.erase(key);
            regularByKey[key].emplace_back(std::move(segment));
        } else if (!regularByKey.contains(key)) {
            throughByKey[key].emplace_back(std::move(segment));
        }
    }

    TVector<THolder<TRawSegment>> result;
    const auto drainInto = [&result](TGroups& groups) {
        for (auto& entry : groups) {
            for (auto& segment : entry.second) {
                result.emplace_back(std::move(segment));
            }
        }
    };
    drainInto(regularByKey);
    drainInto(throughByKey);
    return result;
}

// Computes the first thread start that can produce an event not earlier than
// `minEventTM`.
//
// timezone        - zone in which the calculation is carried out.
// minEventTM      - lower bound for the event; its AsTimeT() value is converted
//                   from UTC into `timezone`.
// threadStartTime - start of the thread, in minutes added to the civil date
//                   built for the lower bound's day.
// eventOffset     - minutes from thread start to the event; subtracted from the
//                   lower bound to get the earliest acceptable thread start.
//
// Returns the thread start on the lower bound's date, or on the following day
// if that start would be earlier than the lower bound.
NDatetime::TSimpleTM NRasp::CalculateThreadStart(const NDatetime::TTimeZone& timezone,
                                                 const NDatetime::TSimpleTM& minEventTM,
                                                 ui32 threadStartTime,
                                                 i64 eventOffset) {
    using namespace NDatetime;

    // FIXME: AsTimeT handles daylight saving time incorrectly; TInstant should be used instead,
    // which requires reworking TScheduleRangeProvider::GetRange.
    auto earliestStartTM = ChangeTimeZone(TSimpleTM::New(minEventTM.AsTimeT()), GetUtcTimeZone(), timezone);

    // Shift the event lower bound back to the matching thread-start lower bound.
    earliestStartTM.Add(TSimpleTM::F_MIN, -eventOffset);

    auto startTM = CreateCivilTime(
        timezone,
        earliestStartTM.RealYear(),
        earliestStartTM.RealMonth(),
        earliestStartTM.MDay);
    startTM.Add(TSimpleTM::F_MIN, threadStartTime);

    // A start on the same date may still precede the bound; then take the next day.
    if (startTM < earliestStartTM) {
        startTM.Add(TSimpleTM::F_DAY, 1);
    }

    return startTM;
}

// Expands raw (date-less) segments into dated segments within an event window.
//
// segments          - raw segments to expand.
// minEventDate      - lower bound for the event (departure or arrival, see eventType).
// maxEventDate      - upper bound for the event; ignored when findAll is true.
// eventType         - selects which end of the segment supplies the offset and timezone.
// findAll           - when true, the upper bound is not applied.
// maxThreadSegments - cap on dated segments per raw segment; 0 means no cap.
//
// For each raw segment, consecutive thread-start days are probed (at most 100),
// starting from the first start matching minEventDate; a dated segment is
// emitted for every day on which the thread runs. The result is sorted by
// DepartureDt() for Departure events and by ArrivalDt() otherwise.
TVector<THolder<TSegment>> TDateSegmentFilter::Filter(
    const TVector<THolder<TRawSegment>>& segments,
    const NDatetime::TSimpleTM minEventDate,
    const NDatetime::TSimpleTM maxEventDate,
    const TCommonSegmentFinder::EEventType eventType,
    bool findAll,
    int maxThreadSegments) const {
    using namespace NDatetime;

    constexpr ui32 maxDaysToSearch = 100;
    const bool byDeparture = eventType == TCommonSegmentFinder::EEventType::Departure;
    const bool unlimitedPerThread = maxThreadSegments == 0;

    TVector<THolder<TSegment>> result;

    for (const auto& rawSegment : segments) {
        const auto& from = rawSegment->Departure();
        const auto& to = rawSegment->Arrival();
        const auto& thread = rawSegment->Thread();

        const auto eventOffset = byDeparture
                                     ? from.DepartureOffset().GetRef()
                                     : to.ArrivalOffset().GetRef();
        const auto& eventTimezone = byDeparture
                                        ? from.Timezone()
                                        : to.Timezone();

        auto startMinutes = thread.StartTimeInMinutes();

        auto candidateStartTM = CalculateThreadStart(eventTimezone, minEventDate, startMinutes, eventOffset);
        auto lastStartTM = CalculateThreadStart(eventTimezone, maxEventDate, startMinutes, eventOffset);
        const auto lastStartInstant = ToAbsoluteTime(lastStartTM, eventTimezone);

        int found = 0;

        for (ui32 day = 0; day < maxDaysToSearch; ++day) {
            if (!unlimitedPerThread && found >= maxThreadSegments) {
                break;
            }

            const TInstant candidateInstant = ToAbsoluteTime(candidateStartTM, eventTimezone);
            if (!findAll && candidateInstant > lastStartInstant) {
                break;
            }

            if (thread.RunsAt(candidateStartTM.RealMonth(), candidateStartTM.MDay)) {
                result.emplace_back(MakeHolder<TSegment>(*rawSegment, candidateStartTM));
                ++found;
            }

            candidateStartTM.Add(TSimpleTM::F_DAY, 1);
        }
    }

    if (byDeparture) {
        SortBy(result, [](const THolder<TSegment>& ptr) { return ptr->DepartureDt(); });
    } else {
        SortBy(result, [](const THolder<TSegment>& ptr) { return ptr->ArrivalDt(); });
    }
    return result;
}

// Runs the point-to-point search pipeline for `query`:
//   find per-thread raw segments -> drop duplicate through trains ->
//   expand to dated segments -> trim the list.
//
// Returns an empty vector when either point key has Id() == 0 or when the
// limit conditions contain no transport types.
// Trimming: with an empty SearchDate the EmptyDateFilter is used (minimum
// count 3); otherwise the MaxDepartureFilter anchored at MinDepartureTime.
TVector<THolder<TSegment>> TSegmentSearcher::FindDirectionSegments(
    const TRaspSearchIndex& searchIndex,
    const TDirectionQuery& query) const {
    const auto& origin = query.FromPointKey;
    const auto& destination = query.ToPointKey;

    if (origin.Id() == 0 || destination.Id() == 0) {
        return {};
    }

    auto majorityLimit = TLimitConditions(
        origin,
        destination,
        searchIndex.GetTransportTypes(),
        searchIndex.GetItems<TStationWrapper>());

    if (majorityLimit.TransportTypes().empty()) {
        return {};
    }

    auto perThread = Finder.Find(searchIndex, origin, destination, majorityLimit, query.ThreadNumber);
    auto withoutDuplicates = ThroughFilter.Filter(searchIndex, std::move(perThread));
    auto datedSegments = DateFilter.Filter(withoutDuplicates, query.MinDepartureTime, query.MaxDepartureTime);

    if (query.SearchDate.Empty()) {
        return EmptyDateFilter.Filter(datedSegments, 3);
    }
    return MaxDepartureFilter.Filter(datedSegments, query.MinDepartureTime);
}

// Builds the dated station timetable for one event type (departures or arrivals).
//
// Raw station segments are expanded to dated ones within
// [query.MinEventTM, query.MaxEventTM] and then trimmed: with an empty
// EventDate by the EmptyDateFilter (minimum count 3), otherwise by the
// MaxDepartureFilter anchored at MinEventTM.
//
// NOTE(review): both trimming filters compare DepartureDt(), while for Arrival
// events the dated list is sorted by ArrivalDt() - confirm this is intended.
TVector<THolder<TSegment>> TSegmentSearcher::FindEventSegments(
    const TRaspSearchIndex& searchIndex,
    const TStationQuery& query,
    const TCommonSegmentFinder::EEventType eventType) const {
    auto stationSegments = Finder.FindStationSegments(searchIndex, query.StationId, eventType);
    auto datedSegments = DateFilter.Filter(stationSegments, query.MinEventTM, query.MaxEventTM, eventType);

    if (query.EventDate.Empty()) {
        return EmptyDateFilter.Filter(datedSegments, 3);
    }
    return MaxDepartureFilter.Filter(datedSegments, query.MinEventTM);
}

// Returns both halves of a station timetable: the departure segments first,
// then the arrival segments, each produced by FindEventSegments().
TStationSegments TSegmentSearcher::FindStationSegments(
    const TRaspSearchIndex& searchIndex,
    const TStationQuery& query) const {
    using EEvent = TCommonSegmentFinder::EEventType;

    auto departures = FindEventSegments(searchIndex, query, EEvent::Departure);
    auto arrivals = FindEventSegments(searchIndex, query, EEvent::Arrival);

    return {std::move(departures), std::move(arrivals)};
}

// Trims a list of dated segments when the query has no explicit date.
//
// segments - dated segments, expected in ascending DepartureDt() order
//            (the loop stops at the first segment past the border).
//            Kept segments are moved out, leaving empty holders behind.
// minCount - number of segments the caller would like to get at least.
//
// Two borders are derived from the first segment's local departure time, in the
// timezone of its departure station:
//   - the civil date of that departure plus 28 hours,
//   - the departure time itself plus one day.
// Segments departing before the earlier border are always kept. If that yields
// fewer than `minCount` segments, the window is widened once to the later
// border. The first segment is always kept.
TVector<THolder<TSegment>>
TEmptySearchDateFilter::Filter(TVector<THolder<TSegment>>& segments,
                               size_t minCount) const {
    if (segments.empty())
        return {};

    const auto firstLocalDeparture = segments[0]->LocalDepartureTime();
    const auto& tz = segments[0]->DepartureWrapper().Station().Timezone();

    // Both borders are built as named locals in a fixed order. Building them
    // inside the MinMax() call would read and modify the same object in two
    // arguments whose evaluation order is unspecified, and would hand MinMax()
    // a reference into a temporary.
    auto nightBorder = NDatetime::CreateCivilTime(tz,
                                                  firstLocalDeparture.RealYear(),
                                                  firstLocalDeparture.RealMonth(),
                                                  firstLocalDeparture.MDay);
    nightBorder.Add(NDatetime::TSimpleTM::F_HOUR, 28);

    auto nextDayBorder = firstLocalDeparture;
    nextDayBorder.Add(NDatetime::TSimpleTM::F_DAY, 1);

    const auto [minTime, maxTime] = MinMax(nightBorder, nextDayBorder);

    auto maxDepartureDate = NDatetime::ToAbsoluteTime(minTime, tz);
    const auto rightBorder = NDatetime::ToAbsoluteTime(maxTime, tz);

    TVector<THolder<TSegment>> result;
    result.emplace_back(std::move(segments[0]));

    for (size_t i = 1; i < segments.size(); ++i) {
        if (maxDepartureDate <= segments[i]->DepartureDt()) {
            // Past the current border: stop if enough segments were collected
            // (i equals the number kept so far) or the wide border is exceeded;
            // otherwise widen the window to the later border and keep going.
            if (i >= minCount || segments[i]->DepartureDt() >= rightBorder)
                break;
            maxDepartureDate = rightBorder;
        }

        result.emplace_back(std::move(segments[i]));
    }

    return result;
}

// Keeps the leading segments that depart before the end of the day of `minDate`.
//
// segments - dated segments; iteration stops at the first one whose
//            DepartureDt() reaches the cut-off, so later entries are dropped
//            even if they would pass. Kept segments are moved out of the vector.
// minDate  - its year/month/day define the day; the cut-off is the civil time
//            for that date plus one day, in the timezone of the first
//            segment's departure wrapper.
TVector<THolder<TSegment>>
TMaxDepartureDateFilter::Filter(TVector<THolder<TSegment>>& segments, NDatetime::TSimpleTM minDate) const {
    if (segments.empty()) {
        return {};
    }

    const auto& firstDeparture = segments[0]->DepartureWrapper();

    auto cutoffLocal = NDatetime::CreateCivilTime(
        firstDeparture.Timezone(),
        minDate.RealYear(),
        minDate.RealMonth(),
        minDate.MDay);
    cutoffLocal.Add(NDatetime::TSimpleTM::F_DAY, 1);

    const auto cutoff = NDatetime::ToAbsoluteTime(cutoffLocal, firstDeparture.Timezone());

    TVector<THolder<TSegment>> kept;
    for (auto& segment : segments) {
        if (segment->DepartureDt() >= cutoff) {
            break;
        }
        kept.emplace_back(std::move(segment));
    }

    return kept;
}
