package ru.yandex.solomon.experiments.gordiychuk.grid;

import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.ToLongFunction;
import java.util.stream.Collectors;
import java.util.stream.LongStream;

import com.fasterxml.jackson.databind.ObjectMapper;
import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap;
import org.apache.logging.log4j.Level;

import ru.yandex.monlib.metrics.JvmThreads;
import ru.yandex.monlib.metrics.registry.MetricRegistry;
import ru.yandex.solomon.core.db.dao.ServicesDao;
import ru.yandex.solomon.core.db.dao.ShardsDao;
import ru.yandex.solomon.core.db.dao.ydb.YdbServicesDao;
import ru.yandex.solomon.core.db.dao.ydb.YdbShardsDao;
import ru.yandex.solomon.core.db.model.Service;
import ru.yandex.solomon.core.db.model.Shard;
import ru.yandex.solomon.core.db.model.ShardSettings;
import ru.yandex.solomon.experiments.gordiychuk.recovery.MetricsPushScheduler;
import ru.yandex.solomon.main.logger.LoggerConfigurationUtils;
import ru.yandex.solomon.tool.YdbClient;
import ru.yandex.solomon.tool.YdbHelper;
import ru.yandex.solomon.tool.cfg.SolomonCluster;
import ru.yandex.solomon.util.actors.AsyncActorBody;
import ru.yandex.solomon.util.actors.AsyncActorRunner;
import ru.yandex.solomon.util.host.HostUtils;
import ru.yandex.solomon.util.time.DurationUtils;

import static java.util.concurrent.CompletableFuture.completedFuture;
import static ru.yandex.solomon.experiments.gordiychuk.grid.IoUtils.mkdir;

/**
 * One-off tool that estimates a grid for shards whose service has a zero grid.
 *
 * <p>For every such shard, per-metric estimations are loaded from the input directory via
 * {@link EstimationReader}. They are reduced to a single grid (milliseconds) by a chain of
 * heuristics, see {@link #estimate(Shard)}. A per-shard report is written to
 * {@code <dir>/stats/<numId>} and a summary to {@code <dir>/summary_stats}.
 *
 * <p>When updates are enabled, a service is updated only if all of its non-zero shard estimates
 * agree on one grid from {@link #VALID_GRID}. That grid (in seconds) is stored in the service's
 * shard settings.
 *
 * @author Vladimir Gordiychuk
 */
public class UpdateShardsCli implements AutoCloseable {
    /** Grids (milliseconds) accepted as an estimation result and allowed to be stored. */
    private static final Set<Long> VALID_GRID = Set.of(
            1_000L,
            5_000L,
            10_000L,
            15_000L,
            20_000L,
            30_000L,
            45_000L,
            60_000L,
            90_000L,
            TimeUnit.MINUTES.toMillis(5),
            TimeUnit.MINUTES.toMillis(10),
            TimeUnit.MINUTES.toMillis(15),
            TimeUnit.MINUTES.toMillis(30),
            TimeUnit.HOURS.toMillis(1),
            TimeUnit.HOURS.toMillis(2),
            TimeUnit.HOURS.toMillis(3),
            TimeUnit.HOURS.toMillis(4),
            TimeUnit.HOURS.toMillis(5),
            TimeUnit.HOURS.toMillis(6),
            TimeUnit.DAYS.toMillis(1));
    /** {@link #VALID_GRID} in ascending order; used as truncation steps. */
    private static final List<Long> VALID_SORTED = VALID_GRID.stream().sorted().collect(Collectors.toList());
    /** Coarse truncation steps in milliseconds: one second, minute, hour and day (ascending). */
    private static final List<Long> VALID_UNITS = List.of(
            TimeUnit.SECONDS.toMillis(1),
            TimeUnit.MINUTES.toMillis(1),
            TimeUnit.HOURS.toMillis(1),
            TimeUnit.DAYS.toMillis(1));

    private final YdbClient ydb;
    private final ServicesDao serviceDao;
    private final ShardsDao shardDao;
    private final ExecutorService executor;

    private Map<ServiceKey, Service> serviceByKey;
    private Map<Integer, Shard> shardsByNumId;
    private EstimationReader reader;
    private Path root;
    /**
     * Estimated grid (milliseconds) per shard numId; written concurrently by {@link #estimate}.
     * A value may be 0 or lie outside {@link #VALID_GRID} when no heuristic produced a valid grid.
     */
    private final ConcurrentMap<Integer, Long> gridByNumId = new ConcurrentHashMap<>();

    public UpdateShardsCli(SolomonCluster cluster) {
        this.ydb = YdbHelper.createYdbClient(cluster);
        // ObjectMapper is thread-safe once configured, so both DAOs share one instance.
        var mapper = new ObjectMapper();
        String configRoot = cluster.kikimrRootPath() + "/Config/V2";
        serviceDao = new YdbServicesDao(ydb.table, configRoot + "/Service", mapper, ForkJoinPool.commonPool());
        shardDao = new YdbShardsDao(ydb.table, configRoot + "/Shard", mapper, ForkJoinPool.commonPool());
        executor = new ForkJoinPool(Runtime.getRuntime().availableProcessors());
        JvmThreads.addExecutorMetrics("CpuLowPriority", executor, MetricRegistry.root());
    }

    /**
     * Entry point.
     *
     * <p>Expected arguments: {@code {dir} {enableUpdate}}. {@code dir} holds the estimations and
     * receives the reports. {@code enableUpdate} is parsed with {@link Boolean#parseBoolean}.
     */
    public static void main(String[] args) {
        LoggerConfigurationUtils.simpleLogger(Level.INFO);
        MetricsPushScheduler.schedulePush();
        if (args.length != 2) {
            System.err.println("Usage: tool {dir} {enableUpdate}");
            System.exit(1);
        }

        try (var cli = new UpdateShardsCli(SolomonCluster.PROD_FRONT)) {
            cli.run(Path.of(args[0]), Boolean.parseBoolean(args[1]));
        } catch (Throwable t) {
            t.printStackTrace();
            System.exit(1);
        }

        System.exit(0);
    }

    /**
     * Returns {@code true} when the shard must be SKIPPED.
     *
     * <p>A shard is skipped when its service is unknown or already has a non-zero grid.
     */
    private boolean filterShard(Shard shard) {
        var service = serviceByKey.get(ServiceKey.of(shard));
        if (service == null) {
            return true;
        }

        return service.getShardSettings().getGrid() != 0;
    }

    /**
     * Loads services and shards, estimates grids for eligible shards, and writes the reports.
     * When {@code enableUpdate} is set, it also stores the agreed grids into the services.
     *
     * @param path directory with the estimations; reports are written into it as well
     * @param enableUpdate whether to write the estimated grids back to the service settings
     */
    public void run(Path path, boolean enableUpdate) {
        serviceByKey = serviceDao.findAll().join().stream()
                .collect(Collectors.toMap(ServiceKey::of, Function.identity()));
        shardsByNumId = shardDao.findAll().join().stream()
                .collect(Collectors.toMap(Shard::getNumId, Function.identity()));

        reader = new EstimationReader(path);
        root = path;

        mkdir(root.resolve("stats"));
        var it = shardsByNumId.values().iterator();
        var progressMetric = MetricRegistry.root().gaugeDouble("progress");
        AtomicInteger completed = new AtomicInteger();
        AsyncActorBody body = () -> {
            while (it.hasNext()) {
                var shard = it.next();
                if (filterShard(shard) || !reader.hasEstimation(shard.getNumId())) {
                    // Skipped shards still count towards progress.
                    completed.incrementAndGet();
                    continue;
                }

                return CompletableFuture.runAsync(() -> estimate(shard), executor)
                        .thenRun(() -> {
                            progressMetric.set(completed.incrementAndGet() * 100. / shardsByNumId.size());
                            String progress = String.format("%.2f%%", progressMetric.get());
                            System.out.println("Processing at " + HostUtils.getShortName() + " progress: " + progress);
                        });
            }
            return completedFuture(AsyncActorBody.DONE_MARKER);
        };
        AsyncActorRunner runner = new AsyncActorRunner(body, executor, 10);
        runner.start().join();
        writeTotalStatistics();
        if (enableUpdate) {
            updateGrid();
        }
    }

    /**
     * Stores the estimated grid into each service whose shards agree on one valid grid.
     *
     * <p>Shards estimated as 0 (undetermined) are ignored. A service is left untouched when its
     * remaining shards disagree, or when the agreed grid is not in {@link #VALID_GRID}.
     */
    private void updateGrid() {
        var it = gridByService().entrySet().iterator();
        AsyncActorBody body = () -> {
            while (it.hasNext()) {
                var entry = it.next();
                var service = serviceByKey.get(entry.getKey());
                if (service == null) {
                    // Only shards of known services are estimated, so this is not expected.
                    continue;
                }

                // Grids are sorted ascending, so an undetermined (0) estimate comes first
                // (assuming estimates are never negative).
                long[] grids = entry.getValue();
                if (grids[0] == 0) {
                    grids = Arrays.copyOfRange(grids, 1, grids.length);
                }

                if (grids.length != 1 || !isValidGrid(grids[0])) {
                    continue;
                }

                long grid = grids[0];
                var settings = service.getShardSettings();
                return serviceDao.partialUpdate(service.toBuilder()
                        .setShardSettings(ShardSettings.of(settings.getType(),
                                settings.getPullSettings(),
                                (int) TimeUnit.MILLISECONDS.toSeconds(grid),
                                settings.getMetricsTtl(),
                                settings.getRetentionPolicy(),
                                settings.getAggregationSettings(),
                                settings.getInterval()))
                        .build());
            }
            return completedFuture(AsyncActorBody.DONE_MARKER);
        };
        AsyncActorRunner runner = new AsyncActorRunner(body, executor, 100);
        runner.start().join();
    }

    /**
     * Estimates the grid of one shard, writes its report, and records the result.
     *
     * <p>Heuristics are tried in order until one yields a grid from {@link #VALID_GRID}:
     * <ol>
     *   <li>the "dominant" values;</li>
     *   <li>the "min" values;</li>
     *   <li>the dominant values truncated to time units;</li>
     *   <li>the dominant values truncated to valid grids;</li>
     *   <li>{@link #minMaxBaseGrid}.</li>
     * </ol>
     * Shards with 100 or fewer records are skipped. They get no report and no entry.
     */
    private void estimate(Shard shard) {
        var records = reader.read(shard.getNumId());
        if (records.size() <= 100) {
            return;
        }

        long grid = byDistribution(calculateDistribution(records, value -> value.dominant));
        grid = isValidGrid(grid) ? grid : byDistribution(calculateDistribution(records, value -> value.min));
        grid = isValidGrid(grid) ? grid : byDistribution(truncateToUnitDistribution(records, value -> value.dominant));
        grid = isValidGrid(grid) ? grid : byDistribution(truncateToGridDistribution(records, value -> value.dominant));
        grid = isValidGrid(grid) ? grid : minMaxBaseGrid(records);

        writeStatistics(grid, shard, records, root.resolve("stats").resolve(Integer.toUnsignedString(shard.getNumId())));
        gridByNumId.put(shard.getNumId(), grid);
    }

    private boolean isValidGrid(long grid) {
        return VALID_GRID.contains(grid);
    }

    /**
     * Writes the per-shard report: identity, chosen grid, and the min/dominant distributions.
     * Raw records are appended when the grid is invalid or at least one minute.
     */
    private void writeStatistics(long grid, Shard shard, List<Estimation> records, Path file) {
        try (var writer = Files.newBufferedWriter(file)) {
            writer.write("ProjectId: " + shard.getProjectId() + "\n");
            writer.write("ShardId: " + shard.getId() + "\n");
            writer.write("Grid: " + DurationUtils.formatDurationMillis(grid) + "(" + grid + ")\n");
            writer.write("\nMin distribution:\n");
            writeDistribution(writer, calculateDistribution(records, value -> value.min));
            writer.write("\nDominant distribution:\n");
            writeDistribution(writer, calculateDistribution(records, value -> value.dominant));
            if (!isValidGrid(grid) || grid >= 60_000) {
                writeRecords(writer, records);
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Counts how many records map to each value of {@code fn}. */
    private Long2IntOpenHashMap calculateDistribution(List<Estimation> records, ToLongFunction<Estimation> fn) {
        var distribution = new Long2IntOpenHashMap();
        for (var record : records) {
            distribution.addTo(fn.applyAsLong(record), 1);
        }
        return distribution;
    }

    /** Like {@link #calculateDistribution}, but each value is truncated by {@link #VALID_SORTED}. */
    private Long2IntOpenHashMap truncateToGridDistribution(List<Estimation> records, ToLongFunction<Estimation> fn) {
        return truncatedDistribution(records, fn, VALID_SORTED);
    }

    /** Like {@link #calculateDistribution}, but each value is truncated by {@link #VALID_UNITS}. */
    private Long2IntOpenHashMap truncateToUnitDistribution(List<Estimation> records, ToLongFunction<Estimation> fn) {
        return truncatedDistribution(records, fn, VALID_UNITS);
    }

    /** Distribution of {@code fn} values after {@link #truncateToStep} with the given steps. */
    private static Long2IntOpenHashMap truncatedDistribution(
            List<Estimation> records,
            ToLongFunction<Estimation> fn,
            List<Long> steps)
    {
        var distribution = new Long2IntOpenHashMap();
        for (var record : records) {
            distribution.addTo(truncateToStep(fn.applyAsLong(record), steps), 1);
        }
        return distribution;
    }

    /**
     * Rounds {@code value} down to a multiple of the step just below it.
     *
     * <p>Let {@code steps[idx]} be the first step strictly greater than {@code value}. The value
     * is rounded down to a multiple of {@code steps[idx - 1]}. Only that one truncation is applied:
     * continuing with larger steps would collapse every value towards 0. Values below the second
     * step are rounded by the first step. Values at or above the largest step are returned as is.
     *
     * @param steps truncation steps in ascending order
     */
    private static long truncateToStep(long value, List<Long> steps) {
        for (int idx = 1; idx < steps.size(); idx++) {
            if (value < steps.get(idx)) {
                return value - value % steps.get(idx - 1);
            }
        }
        return value;
    }

    /**
     * Writes one line per key of the distribution, in ascending key order. Each line shows the key,
     * its share in percent, and a bar of one {@code *} per whole percent. Writes nothing when the
     * distribution is empty.
     */
    private void writeDistribution(BufferedWriter writer, Long2IntOpenHashMap distribution) throws IOException {
        if (distribution.isEmpty()) {
            return;
        }

        long[] sorted = distribution.keySet()
                .stream()
                .mapToLong(Long::longValue)
                .sorted()
                .toArray();
        long total = distribution.values().stream().mapToInt(Integer::intValue).sum();

        // Width of the key column: the longest printed key.
        int size = LongStream.of(sorted)
                .mapToInt(key -> String.valueOf(key).length())
                .max()
                .orElse(1);
        for (long gridMs : sorted) {
            double p = distribution.get(gridMs) * 100. / total;
            String percent = "*".repeat((int) Math.floor(p));
            String bucket = String.valueOf(gridMs);
            writer.write(String.format("%" + size + "s|%6.2f|%s\n", bucket, p, percent));
        }
    }

    /** Writes the raw estimation records as a table. */
    private void writeRecords(BufferedWriter writer, List<Estimation> records) throws IOException {
        writer.write("\nMetrics:\n");
        writer.write(String.format("%10s|%10s|%10s|%10s|%10s%n", "top", "min", "max", "mean", "std"));
        for (var record : records) {
            writer.write(String.format("%10d|%10d|%10d|%10.2f|%10.2f%n", record.dominant, record.min, record.max, record.mean, record.std));
        }
    }

    /**
     * Last-resort heuristic that uses only records whose min equals their max.
     *
     * <p>Returns 0 when at least a quarter of the records are ambiguous (min != max). If exactly
     * one unambiguous value exists, it is returned. Otherwise the most frequent unambiguous value is
     * returned only if all other records (including the ambiguous ones) are fewer than a fifth of
     * the total. In every other case 0 is returned. The result is not checked against
     * {@link #VALID_GRID}.
     */
    private long minMaxBaseGrid(List<Estimation> records) {
        int unknown = 0;
        var grid = new Long2IntOpenHashMap();
        for (var record : records) {
            if (record.min != record.max) {
                unknown++;
            } else {
                grid.addTo(record.min, 1);
            }
        }

        if (unknown >= records.size() / 4) {
            return 0;
        }

        if (grid.size() == 1) {
            return grid.keySet().iterator().nextLong();
        }

        // grid.get() returns 0 for absent keys, so key 0 starts as a zero-count candidate.
        long dominant = 0;
        for (var entry : grid.long2IntEntrySet()) {
            if (entry.getIntValue() > grid.get(dominant)) {
                dominant = entry.getLongKey();
            }
        }

        int others = 0;
        for (var entry : grid.long2IntEntrySet()) {
            if (entry.getLongKey() != dominant) {
                others += entry.getIntValue();
            }
        }

        if (others + unknown < records.size() / 5) {
            return dominant;
        }

        return 0;
    }

    /**
     * Returns the most frequent key when it covers at least 85% of the distribution, or the only
     * key of a single-key distribution. Otherwise returns 0, which is also the result for an empty
     * distribution.
     */
    private long byDistribution(Long2IntOpenHashMap distribution) {
        if (distribution.size() == 1) {
            return distribution.keySet().iterator().nextLong();
        }

        long dominant = 0;
        int dominantCount = 0;
        int total = 0;
        var it = distribution.long2IntEntrySet().fastIterator();
        while (it.hasNext()) {
            var entry = it.next();
            total += entry.getIntValue();
            if (entry.getIntValue() > dominantCount) {
                dominant = entry.getLongKey();
                dominantCount = entry.getIntValue();
            }
        }

        if (total == 0) {
            return 0;
        }

        double p = dominantCount * 100. / total;
        if (p >= 85) {
            return dominant;
        }

        return 0;
    }

    /** Closes the YDB client and always shuts down the executor, even if closing YDB fails. */
    @Override
    public void close() {
        try {
            ydb.close();
        } finally {
            executor.shutdownNow();
        }
    }

    /** Writes {@code summary_stats} under the root directory. */
    private void writeTotalStatistics() {
        try (var writer = Files.newBufferedWriter(root.resolve("summary_stats"))) {
            writeTotalDistributionStatistics(writer);
            writeTotalShardsStatistics(writer);
            writeTotalServiceStatistics(writer);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Writes how many estimated shards ended up with each grid. */
    private void writeTotalDistributionStatistics(BufferedWriter writer) throws IOException {
        var dist = new Long2IntOpenHashMap();
        for (var entry : gridByNumId.entrySet()) {
            long grid = entry.getValue();
            dist.addTo(grid, 1);
        }

        writer.write("Distribution:\n");
        writeDistribution(writer, dist);
    }

    /**
     * Orders estimated shards by grid, then projectId, serviceId and shard id.
     * Both numIds must be present in {@code gridByNumId} and {@code shardsByNumId}.
     */
    private int compareShards(int leftNumId, int rightNumId) {
        int compare = Long.compare(gridByNumId.get(leftNumId), gridByNumId.get(rightNumId));
        if (compare != 0) {
            return compare;
        }
        var leftShard = shardsByNumId.get(leftNumId);
        var rightShard = shardsByNumId.get(rightNumId);
        compare = leftShard.getProjectId().compareTo(rightShard.getProjectId());
        if (compare != 0) {
            return compare;
        }
        compare = leftShard.getServiceId().compareTo(rightShard.getServiceId());
        if (compare != 0) {
            return compare;
        }

        return leftShard.getId().compareTo(rightShard.getId());
    }

    /** Returns a text column width: the longest printed value, but never less than the header. */
    private static int columnWidth(String header, int maxValueLength) {
        return Math.max(header.length(), maxValueLength);
    }

    /** Writes one aligned row per estimated shard, ordered by {@link #compareShards}. */
    private void writeTotalShardsStatistics(BufferedWriter writer) throws IOException {
        writer.write("\nShards:\n");
        var sortedNumIds = gridByNumId.keySet()
                .stream()
                .sorted(this::compareShards)
                .mapToInt(Integer::intValue)
                .toArray();
        List<Shard> estimatedShards = Arrays.stream(sortedNumIds)
                .mapToObj(shardsByNumId::get)
                .collect(Collectors.toList());

        // Widths are measured on the exact strings that get printed (numId is printed unsigned).
        int numIdWidth = columnWidth("numId", Arrays.stream(sortedNumIds)
                .map(numId -> Integer.toUnsignedString(numId).length())
                .max().orElse(0));
        int projectIdWidth = columnWidth("projectId", estimatedShards.stream()
                .mapToInt(shard -> shard.getProjectId().length())
                .max().orElse(0));
        int idWidth = columnWidth("id", estimatedShards.stream()
                .mapToInt(shard -> shard.getId().length())
                .max().orElse(0));
        int serviceIdWidth = columnWidth("serviceId", estimatedShards.stream()
                .mapToInt(shard -> shard.getServiceId().length())
                .max().orElse(0));
        int gridMsWidth = columnWidth("gridMs", gridByNumId.values().stream()
                .mapToInt(grid -> Long.toString(grid).length())
                .max().orElse(0));
        String format = "%" + numIdWidth + "s|%" + projectIdWidth + "s|%" + idWidth + "s|%"
                + serviceIdWidth + "s|%" + gridMsWidth + "s|%10s%n";
        writer.write(String.format(format, "numId", "projectId", "id", "serviceId", "gridMs", "grid"));
        for (int numId : sortedNumIds) {
            var shard = shardsByNumId.get(numId);
            long grid = gridByNumId.get(numId);
            writer.write(String.format(format,
                    Integer.toUnsignedString(numId),
                    shard.getProjectId(),
                    shard.getId(),
                    shard.getServiceId(),
                    grid,
                    DurationUtils.formatDurationMillis(grid)));
        }
    }

    /**
     * Writes one aligned row per service with the distinct grids estimated for its shards.
     * Rows are ordered by the smallest grid, then projectId, then service id.
     */
    private void writeTotalServiceStatistics(BufferedWriter writer) throws IOException {
        // Named ServiceRow (not Record) to avoid shadowing java.lang.Record.
        record ServiceRow(String projectId, String id, String gridMs, String grid, long minGrid)
                implements Comparable<ServiceRow>
        {
            @Override
            public int compareTo(ServiceRow o) {
                int compare = Long.compare(minGrid(), o.minGrid());
                if (compare != 0) {
                    return compare;
                }
                compare = projectId().compareTo(o.projectId());
                if (compare != 0) {
                    return compare;
                }
                return id().compareTo(o.id());
            }
        }

        List<ServiceRow> rows = new ArrayList<>();
        for (var entry : gridByService().entrySet()) {
            long[] grids = entry.getValue();
            rows.add(new ServiceRow(
                    entry.getKey().projectId,
                    entry.getKey().id,
                    Arrays.toString(grids),
                    LongStream.of(grids)
                            .mapToObj(DurationUtils::formatDurationMillis)
                            .collect(Collectors.toList())
                            .toString(),
                    grids[0]));
        }
        rows.sort(ServiceRow::compareTo);

        int projectIdWidth = columnWidth("projectId", rows.stream().mapToInt(row -> row.projectId().length()).max().orElse(0));
        int idWidth = columnWidth("id", rows.stream().mapToInt(row -> row.id().length()).max().orElse(0));
        int gridMsWidth = columnWidth("gridMs", rows.stream().mapToInt(row -> row.gridMs().length()).max().orElse(0));
        int gridWidth = columnWidth("grid", rows.stream().mapToInt(row -> row.grid().length()).max().orElse(0));
        String format = "%" + projectIdWidth + "s|%" + idWidth + "s|%" + gridMsWidth + "s|%" + gridWidth + "s%n";

        writer.write("\nServices:\n");
        writer.write(String.format(format, "projectId", "id", "gridMs", "grid"));
        for (var row : rows) {
            writer.write(String.format(format, row.projectId(), row.id(), row.gridMs(), row.grid()));
        }
    }

    /**
     * Groups the estimated grids by service.
     *
     * @return for each service, the distinct estimated grids of its shards in ascending order
     *     (never empty)
     */
    private Map<ServiceKey, long[]> gridByService() {
        Map<ServiceKey, Set<Long>> gridsByService = new HashMap<>();
        for (var entry : gridByNumId.entrySet()) {
            var shard = shardsByNumId.get(entry.getKey());
            gridsByService.computeIfAbsent(ServiceKey.of(shard), ignore -> new HashSet<>())
                    .add(entry.getValue());
        }

        Map<ServiceKey, long[]> result = new HashMap<>();
        for (var entry : gridsByService.entrySet()) {
            long[] grids = entry.getValue().stream()
                    .mapToLong(Long::longValue)
                    .sorted()
                    .toArray();
            result.put(entry.getKey(), grids);
        }
        return result;
    }

    /** Identifies a service by its projectId and service id. */
    private static record ServiceKey(String projectId, String id) {
        public static ServiceKey of(Service service) {
            return new ServiceKey(service.getProjectId(), service.getId());
        }

        public static ServiceKey of(Shard shard) {
            return new ServiceKey(shard.getProjectId(), shard.getServiceId());
        }
    }
}
