#include <maps/wikimap/mapspro/services/mrc/libs/graph_matcher_adapter/include/compact_graph_matcher_adapter.h>
#include <maps/wikimap/mapspro/services/mrc/libs/track_classifier/include/track_classifier.h>

#include <maps/libs/chrono/include/time_point.h>
#include <maps/libs/log8/include/log8.h>
#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/libs/common/include/exception.h>

#include <opencv2/opencv.hpp>

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <map>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

namespace fs = std::filesystem;

namespace {
    // Ground-truth track read from JSON. `types` and `tpts` are parallel arrays kept
    // in file order; `sortedIndices` is a permutation of their indices that visits the
    // points by ascending timestamp. size() and the getSorted* accessors go through
    // that permutation.
    struct Track {
        std::vector<size_t> sortedIndices;
        std::vector<maps::mrc::track_classifier::TrackType> types;
        std::vector<maps::mrc::db::TrackPoint> tpts;

        // Timestamp of the point at position `idx` in timestamp order (unchecked access).
        maps::chrono::TimePoint getSortedTimestamp(size_t idx) const {
            const size_t storedAt = sortedIndices[idx];
            return tpts[storedAt].timestamp();
        }

        // Ground-truth type of the point at position `idx` in timestamp order (unchecked access).
        maps::mrc::track_classifier::TrackType getSortedType(size_t idx) const {
            const size_t storedAt = sortedIndices[idx];
            return types[storedAt];
        }

        // Number of points addressable through the permutation.
        size_t size() const { return sortedIndices.size(); }

        // Resizes both point arrays to `newSize` and resets the permutation to the
        // identity (file order).
        void resize(size_t newSize) {
            types.resize(newSize);
            tpts.resize(newSize);
            sortedIndices.assign(tpts.size(), size_t{0});
            std::iota(sortedIndices.begin(), sortedIndices.end(), size_t{0});
        }

        // Rebuilds `sortedIndices` so that it lists point indices by ascending timestamp.
        void sortTrackByTimestamp() {
            sortedIndices.assign(tpts.size(), size_t{0});
            std::iota(sortedIndices.begin(), sortedIndices.end(), size_t{0});
            const auto earlier = [this](size_t lhs, size_t rhs) {
                return tpts[lhs].timestamp() < tpts[rhs].timestamp();
            };
            std::sort(sortedIndices.begin(), sortedIndices.end(), earlier);
        }
    };

    // Spellings of the ground-truth "type" field of a track point in the input JSON,
    // and the TrackType each spelling maps to. (Namespace-scope const objects inside
    // this anonymous namespace already have internal linkage, so no `static` is needed.)
    const std::string TRACK_POINT_TYPE_UNKNOWN{"unknown"};
    const std::string TRACK_POINT_TYPE_PEDESTRIAN{"pedestrian"};
    const std::string TRACK_POINT_TYPE_VEHICLE{"vehicle"};
    const std::map<std::string, maps::mrc::track_classifier::TrackType> TRACK_POINT_TYPE_FROM_STRING{
        {TRACK_POINT_TYPE_UNKNOWN,    maps::mrc::track_classifier::TrackType::Undefined},
        {TRACK_POINT_TYPE_PEDESTRIAN, maps::mrc::track_classifier::TrackType::Pedestrian},
        {TRACK_POINT_TYPE_VEHICLE,    maps::mrc::track_classifier::TrackType::Vehicle},
    };

    // Converts the ground-truth "type" string of a track point ("unknown",
    // "pedestrian" or "vehicle") to a TrackType.
    // Throws std::out_of_range naming the offending value for any other string
    // (the exception type std::map::at would throw, but with a useful message).
    maps::mrc::track_classifier::TrackType fromString(const std::string& typeString) {
        const auto it = TRACK_POINT_TYPE_FROM_STRING.find(typeString);
        if (it == TRACK_POINT_TYPE_FROM_STRING.end()) {
            throw std::out_of_range("Unknown track point type: \"" + typeString + "\"");
        }
        return it->second;
    }

    // Builds a Track from one JSON object. Fields read:
    //   "source_id" -- optional string copied into every point ("" when absent);
    //   "track"     -- array of points, each with "type", "lon", "lat", "timestamp"
    //                  (parsed with parseSqlDateTime) and optional "accuracy_meters",
    //                  "heading", "speed_meters_per_sec".
    // Points are stored in file order; the timestamp permutation is built before returning.
    Track loadTrack(const maps::json::Value& jsonData) {
        std::string sourceId;
        if (jsonData.hasField("source_id")) {
            sourceId = jsonData["source_id"].as<std::string>();
        }

        Track track;
        const maps::json::Value& points = jsonData["track"];
        const size_t pointCount = points.size();
        track.resize(pointCount);

        for (size_t idx = 0; idx != pointCount; ++idx) {
            const maps::json::Value& src = points[idx];
            track.types[idx] = fromString(src["type"].as<std::string>());

            maps::mrc::db::TrackPoint& dst = track.tpts[idx];
            dst.setSourceId(sourceId);
            dst.setGeodeticPos({src["lon"].as<double>(), src["lat"].as<double>()});
            dst.setTimestamp(maps::chrono::parseSqlDateTime(src["timestamp"].as<std::string>()));

            if (src.hasField("accuracy_meters")) {
                dst.setAccuracyMeters(src["accuracy_meters"].as<double>());
            }
            if (src.hasField("heading")) {
                dst.setHeading(maps::geolib3::Heading(src["heading"].as<double>()));
            }
            if (src.hasField("speed_meters_per_sec")) {
                dst.setSpeedMetersPerSec(src["speed_meters_per_sec"].as<double>());
            }
        }

        track.sortTrackByTimestamp();
        return track;
    }

    std::vector<Track> loadTracks(const std::string& jsonPath) {
        std::vector<Track> result;
        maps::json::Value fileJson = maps::json::Value::fromFile(jsonPath);
        for(const auto& track : fileJson) {
            result.emplace_back(loadTrack(track));
        }
        return result;
    }

    // Returns numerator / denominator expressed in percent. A zero denominator
    // yields 100.
    double percent(int32_t numerator, int32_t denominator) {
        return denominator == 0
            ? 100.
            : static_cast<double>(numerator) / denominator * 100.;
    }
}// namespace

int main(int argc, const char** argv) try {
    maps::cmdline::Parser parser("Check track classifier on json files with track");

    maps::cmdline::Option<std::string> roadGraphFolder = parser.string("road-graph-path")
        .required()
        .help("Path to road graph folder");

    maps::cmdline::Option<std::string> pedestrianGraphFolder = parser.string("pedestrian-graph-path")
        .required()
        .help("Path to pedestrian graph folder");

    parser.parse(argc, const_cast<char**>(argv));

    /*
        В stat[i,j] кладем количество точек класса i (по GT считанному из файла),
        отнесенных алгоритмом классификации к классу j
        0  -- Undefined
        1  -- Pedestrian
        2  -- Vehicle
    */
    static const std::map<maps::mrc::track_classifier::TrackType, size_t> idxType = {
        {maps::mrc::track_classifier::TrackType::Undefined,  0},
        {maps::mrc::track_classifier::TrackType::Pedestrian, 1},
        {maps::mrc::track_classifier::TrackType::Vehicle,    2}
    };

    maps::mrc::adapters::CompactGraphMatcherAdapter roadMatcher(roadGraphFolder);
    maps::mrc::adapters::CompactGraphMatcherAdapter pedestrianMatcher(pedestrianGraphFolder);
    const std::map<maps::mrc::db::GraphType, const maps::mrc::adapters::Matcher*> graphTypeToMatcherMap =
        {
            {maps::mrc::db::GraphType::Road, &roadMatcher},
            {maps::mrc::db::GraphType::Pedestrian, &pedestrianMatcher}
        };

    cv::Mat fullStat = cv::Mat::zeros(3, 3, CV_32SC1);
    const std::vector<std::string>& filesList = parser.argv();
    INFO() << "Files count: " << filesList.size();
    for (size_t i = 0; i < filesList.size(); i++) {
        cv::Mat stat = cv::Mat::zeros(3, 3, CV_32SC1);
        std::vector<Track> tracks = loadTracks(filesList[i]);
        for (size_t j = 0; j < tracks.size(); j++) {
            const Track& track = tracks[j];

            std::vector<maps::mrc::track_classifier::TrackInterval> intervals = maps::mrc::track_classifier::classify(track.tpts, graphTypeToMatcherMap);
            std::sort(intervals.begin(), intervals.end(),
                [&](const maps::mrc::track_classifier::TrackInterval& a, const maps::mrc::track_classifier::TrackInterval& b) {
                    if (a.begin < b.begin) {
                        REQUIRE(a.end <= b.begin, "Time intervals intersected");
                        return true;
                    } else if (b.begin < a.begin) {
                        REQUIRE(b.end <= a.begin, "Time intervals intersected");
                        return false;
                    } else {
                        return false;
                    };
                }
            );
            size_t intervalIdx = 0;
            size_t trackPtIdx = 0;
            for (; trackPtIdx < track.size(); trackPtIdx++) {
                const maps::chrono::TimePoint tptTime = track.getSortedTimestamp(trackPtIdx);
                const maps::mrc::track_classifier::TrackType type = track.getSortedType(trackPtIdx);
                if (tptTime < intervals[intervalIdx].begin) {
                    stat.at<int32_t>(idxType.at(type), idxType.at(maps::mrc::track_classifier::TrackType::Undefined))++;
                } else if (tptTime <= intervals[intervalIdx].end) {
                    stat.at<int32_t>(idxType.at(type), idxType.at(intervals[intervalIdx].type))++;
                } else { //if (tptTime > intervals[intervalIdx].end)
                    intervalIdx++;
                    if (intervalIdx >= intervals.size()) {
                        break;
                    }
                    trackPtIdx--;
                }
            }
            for (; trackPtIdx < track.size(); trackPtIdx++) {
                const maps::mrc::track_classifier::TrackType type = track.getSortedType(trackPtIdx);
                stat.at<int32_t>(idxType.at(type), idxType.at(maps::mrc::track_classifier::TrackType::Undefined))++;
            }
            INFO() << fs::path(filesList[i]).filename() << " intervals count: " << intervals.size();
            for (size_t k = 0; k < intervals.size(); k++) {
                const maps::mrc::track_classifier::TrackInterval& interval = intervals[i];
                std::string ttype = "Undefined";
                if (maps::mrc::track_classifier::TrackType::Pedestrian == interval.type) {
                    ttype = "Pedestrian";
                } else if (maps::mrc::track_classifier::TrackType::Vehicle == interval.type) {
                    ttype = "Vehicle";
                }
                INFO() << "     " << i << ". " << ttype << ": "
                    << maps::chrono::formatSqlDateTime(interval.begin) << " -- " << maps::chrono::formatSqlDateTime(interval.end);
            }
        }
        {
            const int32_t PedestrianGT = stat.at<int32_t>(1, 0) + stat.at<int32_t>(1, 1) + stat.at<int32_t>(1, 2);
            const int32_t PedestrianTP = stat.at<int32_t>(1, 1);
            const int32_t PedestrianFP = stat.at<int32_t>(0, 1) + stat.at<int32_t>(2, 1);

            const int32_t VehicleGT = stat.at<int32_t>(2, 0) + stat.at<int32_t>(2, 1) + stat.at<int32_t>(2, 2);
            const int32_t VehicleTP = stat.at<int32_t>(2, 2);
            const int32_t VehicleFP = stat.at<int32_t>(0, 2) + stat.at<int32_t>(1, 2);
            if (PedestrianFP != 0 || PedestrianTP != PedestrianGT ||
                VehicleFP != 0 || VehicleTP != VehicleGT) {
                INFO() << fs::path(filesList[i]).filename();
                INFO() << "  pedestrian: " << percent(PedestrianTP, PedestrianTP + PedestrianFP) << "% / "
                                           << percent(PedestrianTP, PedestrianGT) << "% ";
                INFO() << "  vehicle:    " << percent(VehicleTP, VehicleTP + VehicleFP) << "% / "
                                           << percent(VehicleTP, VehicleGT) << "%";
            }
        }
        fullStat += stat;
    }

    INFO() << "Undefined GT points as error";
    {
        const int32_t PedestrianGT = fullStat.at<int32_t>(1, 0) + fullStat.at<int32_t>(1, 1) + fullStat.at<int32_t>(1, 2);
        const int32_t PedestrianTP = fullStat.at<int32_t>(1, 1);
        const int32_t PedestrianFP = fullStat.at<int32_t>(0, 1) + fullStat.at<int32_t>(2, 1);

        const int32_t VehicleGT = fullStat.at<int32_t>(2, 0) + fullStat.at<int32_t>(2, 1) + fullStat.at<int32_t>(2, 2);
        const int32_t VehicleTP = fullStat.at<int32_t>(2, 2);
        const int32_t VehicleFP = fullStat.at<int32_t>(0, 2) + fullStat.at<int32_t>(1, 2);

        INFO() << "Pedestrian";
        INFO() << "  precision: " << (double)PedestrianTP / (PedestrianTP + PedestrianFP) * 100 << "%";
        INFO() << "  recall:    " << (double)PedestrianTP / PedestrianGT * 100 << "%";
        INFO() << "  F1 score:  " << 2.0 * PedestrianTP / (PedestrianTP + PedestrianFP + PedestrianGT) * 100 << "%";
        INFO() << "Vehicle";
        INFO() << "  precision: " << (double)VehicleTP / (VehicleTP + VehicleFP) * 100 << "%";
        INFO() << "  recall:    " << (double)VehicleTP / VehicleGT * 100 << "%";
        INFO() << "  F1 score:  " << 2.0 * VehicleTP / (VehicleTP + VehicleFP + VehicleGT) * 100 << "%";
    }
    INFO() << "---------------------------------------------------";
    INFO() << "Undefined GT points ignored";
    {
        const int32_t PedestrianGT = fullStat.at<int32_t>(1, 0) + fullStat.at<int32_t>(1, 1) + fullStat.at<int32_t>(1, 2);
        const int32_t PedestrianTP = fullStat.at<int32_t>(1, 1);
        const int32_t PedestrianFP = fullStat.at<int32_t>(2, 1);

        const int32_t VehicleGT = fullStat.at<int32_t>(2, 0) + fullStat.at<int32_t>(2, 1) + fullStat.at<int32_t>(2, 2);
        const int32_t VehicleTP = fullStat.at<int32_t>(2, 2);
        const int32_t VehicleFP = fullStat.at<int32_t>(1, 2);

        INFO() << "Pedestrian";
        INFO() << "  precision: " << (double)PedestrianTP / (PedestrianTP + PedestrianFP) * 100 << "%";
        INFO() << "  recall:    " << (double)PedestrianTP / PedestrianGT * 100 << "%";
        INFO() << "  F1 score:  " << 2.0 * PedestrianTP / (PedestrianTP + PedestrianFP + PedestrianGT) * 100 << "%";
        INFO() << "Vehicle";
        INFO() << "  precision: " << (double)VehicleTP / (VehicleTP + VehicleFP) * 100 << "%";
        INFO() << "  recall:    " << (double)VehicleTP / VehicleGT * 100 << "%";
        INFO() << "  F1 score:  " << 2.0 * VehicleTP / (VehicleTP + VehicleFP + VehicleGT) * 100 << "%";
    }
    return EXIT_SUCCESS;
}
catch (const maps::Exception& e) {
    FATAL() << "Application failed: " << e;
    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    FATAL() << "Application failed: " << e.what();
    return EXIT_FAILURE;
}
