#include "split_path.h"

#include "path_search.h"

#include <cstdlib>

namespace maps::mrc::gen_targets {

// Number of task-start offsets tried by
// TaskSplitter::splitPathInto2TasksWithMinOverhead(). Declared as size_t so
// that it compares against the size_t attempt counter without a
// signed/unsigned mix.
constexpr size_t MAX_ATTEMPTS_FOR_OVERHEAD_SPLITTING = 1000;

// Builds the splitter and immediately runs the split selected by
// allowedOverhead:
//  - Overhead::BigOverhead: splitPathInto2TasksWithMinOverhead() on the path
//    as given;
//  - anything else: splitPath() on the path padded with one extra edge at
//    each end.
// More allowed overhead means more chances to split. An empty path is left
// untouched and no splitting is performed.
TaskSplitter::TaskSplitter(const RoadNetworkData& roadNetwork,
                           const LoopsPath& path,
                           Meters minTaskLength,
                           Overhead allowedOverhead,
                           const std::unordered_set<EdgeId>& allTargetEdges)
    : roadNetwork_(roadNetwork)
    , sourcePath_(path)
    , minTaskLength_(minTaskLength)
    , allowedOverhead_(allowedOverhead)
    , curPathLength_(0)
    , remainingPathLength_(0)
    , allTargetEdges_(allTargetEdges)
{
    if (path.empty()) {
        return;
    }

    if (allowedOverhead_ == Overhead::BigOverhead) {
        splitPathInto2TasksWithMinOverhead();
        return;
    }

    // Pad the path cyclically: a copy of the last edge goes in front and a
    // copy of the first edge goes behind. splitPath() uses them as
    // look-behind / look-ahead context and drops the leading one again
    // before storing the final path.
    const auto firstEdge = sourcePath_.front();
    const auto lastEdge = sourcePath_.back();
    sourcePath_.insert(sourcePath_.begin(), lastEdge);
    sourcePath_.push_back(firstEdge);
    splitPath();
}

// Splits sourcePath_ into tasks: walks the path edge by edge, growing
// curPath_, and hands suffixes of curPath_ that look like complete loops to
// tryExtractTask(). Whatever is left in curPath_ at the end is appended to
// resultPaths_ as the last task.
//
// Expects sourcePath_ to be padded the way the constructor does it: element 0
// is a copy of the real last edge and the last element is a copy of the real
// first edge, so sourcePath_ holds at least three elements here.
void TaskSplitter::splitPath()
{
    // Per-loop bookkeeping, keyed by loop id:
    //  loopStart        - index in curPath_ recorded the first time an edge
    //                     of the loop is appended
    //  loopEnd          - index in sourcePath_ of the last edge of the loop
    //  loopPreviousEdge - index of the last element of curPath_ at the end of
    //                     the most recent iteration that finished on the loop
    std::unordered_map<LoopId, size_t> loopStart;        // in curPath
    std::unordered_map<LoopId, size_t> loopEnd;          // in sourcePath
    std::unordered_map<LoopId, size_t> loopPreviousEdge; // in curPath

    // First pass: count how often every edge is visited, sum up the total
    // length and remember where each loop occurs for the last time.
    for (size_t i = 0; i < sourcePath_.size(); i++) {
        numberOfEdgeVisits_[sourcePath_[i].edgeId]++;
        remainingPathLength_
            += roadNetwork_.edge(sourcePath_[i].edgeId).length;
        loopEnd[sourcePath_[i].loopId] = i;
    }

    // Copies edges from sourcePath to curPath. If at some moment some
    // suffix of curPath looks like a loop, extracts that suffix as
    // a separate task

    // Seed curPath_ with the leading padding edge.
    curPathLength_ = roadNetwork_.edge(sourcePath_[0].edgeId).length;
    remainingPathLength_ -= curPathLength_;
    curPath_.push_back(sourcePath_[0]);
    curPathPrefixLength_.push_back(curPathLength_);

    // The first and the last elements are skipped here; sourcePath_[i + 1]
    // serves as the look-ahead edge for element i.
    for (size_t i = 1; i < sourcePath_.size() - 1; i++) {
        EdgeId edgeId = sourcePath_[i].edgeId;
        const Edge& edge = roadNetwork_.edge(edgeId);
        LoopId curLoopId = sourcePath_[i].loopId;

        // Move the edge from the "remaining" part to the "current" part.
        curPathLength_ += edge.length;
        remainingPathLength_ -= edge.length;
        curPath_.push_back(sourcePath_[i]);
        curPathPrefixLength_.push_back(curPathLength_);

        // First edge of a loop not seen so far: remember where it starts.
        if (!loopPreviousEdge.count(curLoopId)) {
            loopStart[curLoopId] = curPath_.size() - 1;
        }

        if (i == loopEnd[curLoopId]) {
            // extract last loop
            tryExtractTask(loopStart[curLoopId], sourcePath_[i + 1].edgeId);
        }

        // NOTE: tryExtractTask() may have shrunk or extended curPath_, so
        // from here on curPath_.back() is not necessarily sourcePath_[i].

        // extract all the loops(as one task) on the current crossroad
        if (sourcePath_[i + 1].loopId != curPath_.back().loopId
            && loopPreviousEdge.count(sourcePath_[i + 1].loopId)) {
            tryExtractTask(loopPreviousEdge[sourcePath_[i + 1].loopId] + 1,
                           sourcePath_[i + 1].edgeId);
        }
        loopPreviousEdge[curPath_.back().loopId] = curPath_.size() - 1;
    }
    // NOTE(review): the push_back / pop_back pair below cancels out; the net
    // effect of these three statements is dropping the leading padding edge.
    curPath_.push_back(sourcePath_.back());
    curPath_.erase(curPath_.begin());
    curPath_.pop_back();
    resultPaths_.push_back(curPath_);
}


// extract [curPath[loopStart], curPath.end()] as separated task
// if neccessary
void TaskSplitter::tryExtractTask(size_t loopStart, EdgeId nextEdgeInPath)
{
    if (loopStart >= curPath_.size() - 1) {
        return;
    }

    Meters pathLengthBeforeLoop = curPathPrefixLength_[loopStart - 1];

    if (roadNetwork_.edgesAreConnected(curPath_[loopStart - 1].edgeId,
                                       nextEdgeInPath)) {
        bool canDelete = true;
        for (size_t j = loopStart; j < curPath_.size(); j++) {
            if (numberOfEdgeVisits_[curPath_[j].edgeId] < 2) {
                canDelete = false;
                break;
            }
        }
        if (canDelete) {
            for (size_t j = loopStart; j < curPath_.size(); j++) {
                numberOfEdgeVisits_[curPath_[j].edgeId]--;
                curPathLength_
                    -= roadNetwork_.edge(curPath_[j].edgeId).length;
            }
            resizeCurPath(loopStart);
            return;
        }
        if (curPathLength_ - pathLengthBeforeLoop < minTaskLength_
            || remainingPathLength_ + pathLengthBeforeLoop < minTaskLength_) {
            return;
        }
        if (roadNetwork_.edgesAreConnected(curPath_.back().edgeId,
                                           curPath_[loopStart].edgeId)
            || allowedOverhead_ == Overhead::SmallOverhead) {
            extractTask(loopStart, nextEdgeInPath);
            return;
        }
    }
}

// Resizes curPath_ together with curPathPrefixLength_ to newSize, so that
// both containers end up with the same number of elements.
void TaskSplitter::resizeCurPath(size_t newSize)
{
    curPathPrefixLength_.resize(newSize);
    curPath_.resize(newSize);
}

// extract [curPath[loopStart], curPath.end()] as separated task
void TaskSplitter::extractTask(size_t loopStart, EdgeId nextEdgeInSourcePath)
{
    LoopsPath newTask(curPath_.begin() + loopStart, curPath_.end());
    fixTask(newTask);
    resultPaths_.push_back(newTask);

    for (size_t j = loopStart; j < curPath_.size(); j++) {
        curPathLength_ -= roadNetwork_.edge(curPath_[j].edgeId).length;
    }
    resizeCurPath(loopStart);

    fixCurPath(nextEdgeInSourcePath);
}

// Connects the end of the task to its beginning if necessary.
//
// When the last edge of the task is not connected to the first one, a route
// from the last edge to the first edge is requested from PathSearch and all
// of its edges except the final one are appended to the task, tagged with
// the loop id of the task's first edge.
void TaskSplitter::fixTask(LoopsPath& task)
{
    if (roadNetwork_.edgesAreConnected(task.back().edgeId, task[0].edgeId)) {
        return;
    }
    std::vector<EdgeId> connector
        = PathSearch(roadNetwork_, task.back().edgeId,
                     std::unordered_set<EdgeId>{task[0].edgeId},
                     allTargetEdges_)
              .getResult();
    // pop_back() on an empty vector is undefined behaviour; with nothing to
    // append the task is left as it is.
    if (connector.empty()) {
        return;
    }
    // The final edge of the route is not appended (presumably it is the
    // target edge task[0] itself - confirm against PathSearch).
    connector.pop_back();

    const LoopId loopId = task[0].loopId;
    for (EdgeId pathEdge : connector) {
        task.push_back(LoopEdge{pathEdge, loopId});
    }
}

// Reconnects curPath_ with the remaining part of sourcePath_.
//
// When the last edge of curPath_ is not connected to nextEdgeInSourcePath, a
// route between the two is requested from PathSearch and all of its edges
// except the final one are appended to curPath_. The appended edges get the
// loop id of the edge that was last in curPath_ before the call;
// curPathLength_ and curPathPrefixLength_ are updated along the way.
void TaskSplitter::fixCurPath(EdgeId nextEdgeInSourcePath)
{
    if (roadNetwork_.edgesAreConnected(curPath_.back().edgeId,
                                       nextEdgeInSourcePath)) {
        return;
    }

    std::vector<EdgeId> connector
        = PathSearch(roadNetwork_, curPath_.back().edgeId,
                     std::unordered_set<EdgeId>{nextEdgeInSourcePath},
                     allTargetEdges_)
              .getResult();
    // pop_back() on an empty vector is undefined behaviour; with nothing to
    // append curPath_ is left as it is.
    if (connector.empty()) {
        return;
    }
    // The final edge of the route is not appended (presumably it is
    // nextEdgeInSourcePath itself - confirm against PathSearch).
    connector.pop_back();

    const LoopId loopId = curPath_.back().loopId;
    for (EdgeId pathEdge : connector) {
        curPath_.push_back(LoopEdge{pathEdge, loopId});
        curPathLength_ += roadNetwork_.edge(pathEdge).length;
        curPathPrefixLength_.push_back(curPathLength_);
    }
}

void TaskSplitter::splitPathInto2TasksWithMinOverhead()
{
    LoopsPaths bestTasks{};
    Meters bestLength = std::numeric_limits<Meters>::max(); // i.e. with the smallest overhead

    // trying extract one task from sourcePath with different
    // offsets of the first edge of the extracted task
    for (size_t offset = 0; offset < MAX_ATTEMPTS_FOR_OVERHEAD_SPLITTING;
         offset++) {
        extractOneTaskWithAnyOverhead(offset);
        if (resultPaths_.size() > 1) {
            Meters l = getPathLength(roadNetwork_, resultPaths_[0])
                       + getPathLength(roadNetwork_, resultPaths_[1]);
            if (l < bestLength) {
                bestLength = l;
                bestTasks = resultPaths_;
            }
        }

        // erase data for the next attempts
        resultPaths_.resize(0);
        resizeCurPath(0);
        curPathLength_ = 0;
        remainingPathLength_ = 0;
    }
    if (bestTasks.size() > 1) {
        resultPaths_ = bestTasks;
    } else {
        resultPaths_ = LoopsPaths{sourcePath_};
    }
}

// Makes one attempt to cut a single task out of sourcePath_ and appends the
// outcome to resultPaths_: the extracted task (if any, via extractTask())
// followed by the remaining path.
//
// taskStartOffset selects the start of the task: it is the number of
// "first edge of a not yet seen loop" positions to skip before one is taken.
//
// remainingPathLength_ is accumulated with +=, and curPath_ /
// curPathPrefixLength_ are appended to, so the function relies on the caller
// having reset them beforehand.
void TaskSplitter::extractOneTaskWithAnyOverhead(int taskStartOffset)
{
    // loop id -> index of the last element of curPath_ at the end of the most
    // recent iteration that finished on that loop
    std::unordered_map<LoopId, size_t> loopPreviousEdge;
    bool taskStartWasSelected = false;
    // NOTE(review): the three locals below stay uninitialized until a task
    // start is selected; every read is guarded by taskStartWasSelected, which
    // comes first in the && chain further down.
    size_t taskStart;
    LoopId loopIdBeforeTask;
    Meters pathLengthBeforeTask;

    // Total length of the source path.
    for (size_t i = 0; i < sourcePath_.size(); i++) {
        remainingPathLength_
            += roadNetwork_.edge(sourcePath_[i].edgeId).length;
    }

    // Copies edges from sourcePath to curPath. Selects two good
    // points and extract a task between that points

    // Seed curPath_ with the first edge of the source path.
    curPathLength_ = roadNetwork_.edge(sourcePath_[0].edgeId).length;
    remainingPathLength_ -= curPathLength_;
    curPath_.push_back(sourcePath_[0]);
    curPathPrefixLength_.push_back(curPathLength_);
    loopPreviousEdge[sourcePath_[0].loopId] = 0;

    // The last edge is left out of the loop (sourcePath_[i + 1] is used as a
    // look-ahead) and appended after it.
    for (size_t i = 1; i < sourcePath_.size() - 1; i++) {
        EdgeId edgeId = sourcePath_[i].edgeId;
        const Edge& edge = roadNetwork_.edge(edgeId);
        LoopId curLoopId = sourcePath_[i].loopId;

        // Move the edge from the "remaining" part to the "current" part.
        curPathLength_ += edge.length;
        remainingPathLength_ -= edge.length;
        curPath_.push_back(sourcePath_[i]);
        curPathPrefixLength_.push_back(curPathLength_);

        // takes the point on loop loopIdBeforeTask
        // (candidates are edges whose loop has no entry in loopPreviousEdge
        // yet; the candidate at which the countdown reaches zero is taken)
        if (!loopPreviousEdge.count(curLoopId)) {
            if (taskStartOffset == 0) {
                taskStartWasSelected = true;
                taskStart = curPath_.size() - 1;
                loopIdBeforeTask = curPath_[curPath_.size() - 2].loopId;
                pathLengthBeforeTask = curPathPrefixLength_[taskStart - 1];
            }
            taskStartOffset--;
        }

        // takes another point on loop loopIdBeforeTask
        // (the next source edge switches to loopIdBeforeTask, and both the
        // task and the rest of the path are longer than minTaskLength_)
        if (taskStartWasSelected
            && sourcePath_[i + 1].loopId != curPath_.back().loopId
            && sourcePath_[i + 1].loopId == loopIdBeforeTask
            && curPathLength_ - pathLengthBeforeTask > minTaskLength_
            && remainingPathLength_ + pathLengthBeforeTask > minTaskLength_) {
            // extract the path from the first point to the second point
            extractTask(taskStart, sourcePath_[i + 1].edgeId);
            taskStartWasSelected = false; // don't extract task anymore
        }
        loopPreviousEdge[curPath_.back().loopId] = curPath_.size() - 1;
    }
    // Whatever is left, plus the last source edge, becomes the final path.
    curPath_.push_back(sourcePath_.back());
    resultPaths_.push_back(curPath_);
}

// Returns true if every task in resultTasks is strictly shorter than
// sourceTask (which also holds when resultTasks is empty).
bool splitIsGood(const RoadNetworkData& roadNetwork,
                 const LoopsPath& sourceTask,
                 const LoopsPaths& resultTasks)
{
    // The source length does not change between comparisons, so it is
    // measured once instead of once per result task.
    const auto sourceLength = getPathLength(roadNetwork, sourceTask);
    for (const auto& resultTask : resultTasks) {
        if (getPathLength(roadNetwork, resultTask) >= sourceLength) {
            return false;
        }
    }
    return true;
}

// Splits a loop route into tasks in three stages of growing overhead:
//  1. the whole path is split with Overhead::NoOverhead;
//  2. every task longer than 2 * minTaskLength gets one split attempt with
//     Overhead::SmallOverhead;
//  3. while a task is still longer than 2 * minTaskLength it is split with
//     Overhead::BigOverhead, until a split stops being accepted.
// A split is accepted only if splitIsGood() approves it; the first piece then
// replaces the task in place and the other pieces are appended to the end of
// the list, where the same loop visits them later.
LoopsPaths splitLoopRouteIntoTasks(const RoadNetworkData& roadNetwork,
                                   const LoopsPath& path,
                                   Meters minTaskLength)
{
    std::unordered_set<EdgeId> allTargetEdges;
    for (const auto& [edgeId, edge] : roadNetwork.edges()) {
        if (edge.isTarget) {
            allTargetEdges.insert(edgeId);
        }
    }

    // Stage 1: split the path without overhead.
    LoopsPaths tasks
        = TaskSplitter(roadNetwork, path, minTaskLength,
                       TaskSplitter::Overhead::NoOverhead, allTargetEdges)
              .getResult();

    const auto lengthLimit = 2 * minTaskLength;
    const auto isTooLong = [&](size_t index) {
        return getPathLength(roadNetwork, tasks[index]) > lengthLimit;
    };

    // Splits tasks[index] with the given overhead. Returns false and leaves
    // the task list untouched if the split is not good.
    const auto trySplit = [&](size_t index, TaskSplitter::Overhead overhead) {
        LoopsPaths pieces
            = TaskSplitter(roadNetwork, tasks[index], minTaskLength, overhead,
                           allTargetEdges)
                  .getResult();
        if (!splitIsGood(roadNetwork, tasks[index], pieces)) {
            return false;
        }
        tasks[index] = pieces[0];
        tasks.insert(tasks.end(), pieces.begin() + 1, pieces.end());
        return true;
    };

    for (size_t i = 0; i < tasks.size(); ++i) {
        // Stage 2: a single attempt with small overhead.
        if (isTooLong(i)) {
            trySplit(i, TaskSplitter::Overhead::SmallOverhead);
        }
        // Stage 3: keep splitting with big overhead while it helps.
        while (isTooLong(i)
               && trySplit(i, TaskSplitter::Overhead::BigOverhead)) {
        }
    }

    return tasks;
}

} // namespace maps::mrc::gen_targets
