#pragma once

#include <cstdint>
#include <iomanip>
#include <iostream>
#include <latch>
#include <limits>
#include <memory>
#include <semaphore>
#include <set>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "channel.h"
#include "histogram.h"

/// Outcome of one benchmarked request, produced by the function under test
/// and consumed (and freed) by Bench::gather_results().
struct bench_result {
    int64_t duration;  ///< Latency recorded into the histogram (units presumably match BUCKETS_US — confirm).
    bool failed;       ///< Counted toward the "failed requests" total.
    bool invalid;      ///< Counted toward the "invalid requests" total.
};
// Bucket boundaries used to construct every Bench's latency histogram; the
// definition lives elsewhere. NOTE(review): the _US suffix suggests
// microseconds — confirm against the definition.
extern const std::vector<unsigned long> BUCKETS_US;

/// Abstract source of benchmark request parameters.
///
/// Bench calls next() once per issued request, from the thread that runs
/// Bench::bench().
template <typename T>
class BenchGen {
public:
    /// Virtual so a concrete generator can be safely destroyed through a
    /// BenchGen<T>* (without it, such a delete is undefined behaviour).
    virtual ~BenchGen() = default;

    /// @return the parameter for the next request.
    virtual T next() = 0;
};

template <auto Func, typename FParam>
class Bench {
    int numThreads;
    int numRequests;
    int issuedRequests;
    std::counting_semaphore<1000> sem;
    BenchGen<FParam>* gen;
    Channel<FParam>** channels;
    Channel<bench_result*>** results;
    std::thread** threads;
    Histogram<unsigned long> hgram;
    unsigned long failed;
    unsigned long invalid;
    std::string name;
    std::latch done;
    double releaseFactor;

public:
    Bench(std::string name, int threads, int requests, BenchGen<FParam>* gen, double releaseFactor)
        : sem(threads)
        , hgram(BUCKETS_US)
        , name(name)
        , done(threads)
    {
        this->releaseFactor = releaseFactor;
        issuedRequests = 0;
        failed = 0;
        invalid = 0;
        numThreads = threads;
        numRequests = requests;
        this->gen = gen;
        results = new Channel<bench_result*>*[threads];
        channels = new Channel<FParam>*[threads];
        this->threads = new std::thread*[threads];
        for (int i = 0; i < threads; i++) {
            channels[i] = new Channel<FParam>();
            results[i] = new Channel<bench_result*>();
            sem.acquire();
        }
    }

    ~Bench() {
        for (int i = 0; i < numThreads; i++) {
            delete channels[i];
            delete threads[i];
            delete results[i];
        }
        delete channels;
        delete threads;
        delete results;
    }

    void bench() {
        bool released = false;
        std::cout << "Starting benchmark " << name << std::endl;
        for (int i = 0; i < numThreads; i++) {
            auto th = new std::thread(&Bench::bench_thread, this, channels[i], results[i]);
            th->detach();
            threads[i] = th;
        }
        std::cout << "Prepared " << numThreads << " threads" << std::endl;
        while (numRequests > 0) {
            for (int i = 0; i < numThreads; i++, numRequests--, issuedRequests++) {
                auto task = gen->next();
                channels[i]->put(task);
            }
            if (!released && issuedRequests > (issuedRequests + numRequests) * releaseFactor) {
                std::cout << "Prepared " << issuedRequests << " requests" << std::endl;
                sem.release(numThreads);
                released = true;
                std::cout << "Released semaphore for " << numThreads << " threads" << std::endl;
            }
        }
        for (int i = 0; i < numThreads; i++) {
            channels[i]->close();
        }
        done.wait();
        std::cout << "Threads finished" << std::endl;
    }

    void gather_results() {
        std::cout << "Gathering results" << std::endl;
        for (int i = 0; i < numThreads; i++) {
            for (auto r : results[i]->get_all()) {
                if (r->failed) {
                    failed++;
                }
                if (r->invalid) {
                    invalid++;
                }
                hgram.put(r->duration);
                delete r;
            }
        }
    }

    void print_results(std::ostream& s) {
        s << name << std::endl;
        s << "issued requests: " << issuedRequests << std::endl;
        s << "failed requests: " << failed << std::endl;
        s << "invalid requests: " << invalid << std::endl;
        s << hgram;
    }

private:
    void bench_thread(Channel<FParam>* chan, Channel<bench_result*>* rchan) {
        std::set<FParam> missedRequests;
        sem.acquire();
        while (chan->opened() || chan->has_data()) {
            while (chan->has_data()) {
                auto task = chan->get();
                auto result = Func(task);
                if (result->failed || result->invalid) {
                    if (!missedRequests.contains(task)) {
                        std::cerr << "missed request " << name << ": " << task << std::endl;
                        missedRequests.insert(task);
                    }
                }
                rchan->put(result);
            }
        }
        sem.release();
        done.count_down();
    }
};
