package ru.yandex.stockpile.server.shard;

import java.util.EnumSet;
import java.util.List;
import java.util.stream.Collectors;

import javax.annotation.Nonnull;
import javax.annotation.WillClose;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.solomon.codec.archive.MetricArchiveImmutable;
import ru.yandex.solomon.codec.archive.header.DeleteBeforeField;
import ru.yandex.solomon.codec.archive.header.MetricHeader;
import ru.yandex.solomon.codec.serializer.StockpileFormat;
import ru.yandex.solomon.memory.layout.MemMeasurable;
import ru.yandex.solomon.memory.layout.MemoryCounter;
import ru.yandex.solomon.model.timeseries.decim.DecimPoliciesPredefined;
import ru.yandex.solomon.model.timeseries.decim.DecimPolicy;
import ru.yandex.solomon.util.CloseableUtils;
import ru.yandex.stockpile.api.EDecimPolicy;
import ru.yandex.stockpile.api.EProjectId;
import ru.yandex.stockpile.client.shard.StockpileMetricId;
import ru.yandex.stockpile.memState.MetricIdAndData;
import ru.yandex.stockpile.server.shard.merge.CompressCollector;
import ru.yandex.stockpile.server.shard.merge.DecimIterator;
import ru.yandex.stockpile.server.shard.merge.Iterator;
import ru.yandex.stockpile.server.shard.merge.SplitCollector;

/**
 * @author Vladimir Gordiychuk
 */
public class MergeTask implements MemMeasurable {
    private static final long SELF_SIZE = MemoryCounter.objectSelfSizeLayout(MergeTask.class);
    private static final Logger logger = LoggerFactory.getLogger(MergeTask.class);
    /**
     * Total size (in bytes) of the source archives at or above which a metric without a decimation
     * policy is forced to use one (see {@link #forceDecimCheck(MetricHeader)}).
     */
    private static final int ARCHIVE_BYTES_SIZE_TO_FORCE_DECIM = 100 << 20; // 100 MiB

    private final int shardId;
    /** Archives to merge; the result is attributed to the local id of the first element. */
    private final List<MetricIdAndData> sources;
    private final long nowMillis;
    /** Alignment step of the split point; {@code 0} disables splitting. */
    private final long splitDelayMillis;
    private final boolean allowDecim;
    private final MergeProcessMetrics.MergeTaskMetrics taskMetrics;

    /**
     * Creates a merge task. Ownership of {@code sources} is transferred to the task ({@code @WillClose}).
     *
     * @param shardId          shard the metric belongs to (used for ids in logs and error messages)
     * @param sources          archives of the metric to merge; expected to be non-empty
     * @param nowMillis        current time used for decimation, split point and the result's timestamps
     * @param splitDelayMillis alignment step of the split point, {@code 0} to produce a single archive
     * @param allowDecim       whether decimation may be applied during the merge
     * @param taskMetrics      sink for per-task merge metrics
     */
    public MergeTask(int shardId, @WillClose List<MetricIdAndData> sources, long nowMillis, long splitDelayMillis, boolean allowDecim, MergeProcessMetrics.MergeTaskMetrics taskMetrics) {
        this.shardId = shardId;
        this.sources = sources;
        this.nowMillis = nowMillis;
        this.splitDelayMillis = splitDelayMillis;
        this.allowDecim = allowDecim;
        this.taskMetrics = taskMetrics;
    }

    /**
     * Runs the merge and records task metrics (duration, archive count, point counts).
     *
     * @return merge result, never {@code null}
     * @throws RuntimeException wrapping any failure, with the metric id, owner projects and the first
     *                          source's header in the message
     */
    @Nonnull
    public MergeTaskResult run() {
        try {
            long startTimeNanos = taskMetrics.started();
            // Sum as long: IntStream.sum() would silently wrap for very large merges.
            long recordsBefore = sources.stream()
                .mapToLong(value -> value.archive().getRecordCount())
                .sum();
            long archives = sources.size();
            MergeTaskResult result = runInner();

            taskMetrics.completed(startTimeNanos);
            taskMetrics.addCountArchive(archives);
            taskMetrics.addPointsCount(recordsBefore);
            taskMetrics.addCollapsedPoints(recordsBefore - result.getRecordCount());
            return result;
        } catch (Exception e) {
            taskMetrics.failed();
            // NOTE(review): sources are not closed on this path despite @WillClose; confirm the caller
            // releases them when a merge fails.
            throw new RuntimeException(describeFailure(e), e);
        }
    }

    /**
     * Builds the error message for a failed merge without letting diagnostics mask the original cause:
     * an empty source list is handled explicitly, and an error raised while describing the sources is
     * attached to {@code cause} as suppressed.
     */
    private String describeFailure(Exception cause) {
        if (sources.isEmpty()) {
            return "failed to merge: no sources; shardId: " + shardId;
        }

        try {
            EnumSet<EProjectId> projectIds = sources.stream()
                .map(source -> source.archive().getOwnerProjectIdOrUnknown())
                .collect(Collectors.toCollection(() -> EnumSet.noneOf(EProjectId.class)));
            return "failed to merge: " + new StockpileMetricId(shardId, sources.get(0).localId())
                + "; project: " + projectIds
                + "; headers: " + sources.get(0).archive().header();
        } catch (RuntimeException describeError) {
            cause.addSuppressed(describeError);
            return "failed to merge: " + StockpileMetricId.toString(shardId, sources.get(0).localId());
        }
    }

    /**
     * Performs the merge: either returns the single source untouched (see {@link #isSkipMerge()}) or
     * combines all sources, normalizes the header, applies decimation and collects one or two archives.
     * Sources are closed only after a successful (non-skipped) merge; on the skip path the source
     * itself is handed over in the result.
     */
    @Nonnull
    private MergeTaskResult runInner() {
        if (isSkipMerge()) {
            return skipMerge();
        }

        ArchiveCombiner.CombineResult combine = combine();
        MetricHeader header = combine.getHeader();
        if (!DecimPoliciesPredefined.isKnownId(header.getDecimPolicyId())) {
            // fix invalid decim policy id
            header = header.withDecimPolicy((short) EDecimPolicy.POLICY_KEEP_FOREVER_VALUE);
        }

        header = forceDecimCheck(header);
        var iterator = combine.getItemIterator();
        iterator = decim(header, nowMillis, iterator);

        final MergeTaskResult result;
        if (splitDelayMillis == 0) {
            result = merge(iterator, header, combine.getElapsedBytes());
        } else {
            result = splitMerge(iterator, header, combine.getElapsedBytes());
        }
        CloseableUtils.close(sources);
        return result;
    }

    /** Feeds every source archive together with its last timestamp into an {@link ArchiveCombiner}. */
    private ArchiveCombiner.CombineResult combine() {
        var combiner = new ArchiveCombiner(shardId, getLocalId());
        for (var source : sources) {
            combiner.add(source.archive(), source.lastTsMillis());
        }
        return combiner.combine();
    }

    /** Returns the single source as the result without rewriting it. */
    private MergeTaskResult skipMerge() {
        return new MergeTaskResult(sources.get(0));
    }

    /**
     * Compresses all points of {@code it} into a single archive in {@link StockpileFormat#CURRENT}
     * and records the produced frame count.
     */
    private MergeTaskResult merge(Iterator it, MetricHeader header, int elapsedBytes) {
        var format = StockpileFormat.CURRENT;
        var result = CompressCollector.collect(format, header.getType(), it.columnSetMask(), elapsedBytes, it);
        var archive = new MetricArchiveImmutable(header, format, result.mask, result.buffer, result.records);
        this.taskMetrics.addCountFrames(result.frameCount);
        MetricIdAndData currentLevel = new MetricIdAndData(getLocalId(), result.lastTsMillis, nowMillis, archive);
        return new MergeTaskResult(currentLevel);
    }

    /**
     * Splits the points of {@code it} at {@link #splitBefore(MetricHeader)} into a "before" and an
     * "after" part (as reported by {@link SplitCollector}).
     * <p>
     * The "before" part becomes {@code next} when the header changed, a delete-before marker is set,
     * it has points, or the "after" part is empty; otherwise {@code next} is {@code null}. The "after"
     * part becomes {@code current} (with delete-before reset to KEEP) only if it has points.
     */
    private MergeTaskResult splitMerge(Iterator it, MetricHeader header, int elapsedBytes) {
        var result = SplitCollector.collect(elapsedBytes, splitBefore(header), it);
        final MetricIdAndData next;
        if (isHeaderChanged(header) || hasDelete(header) || result.before.records > 0 || result.after.records == 0) {
            var archive = result.before.getArchive(header);
            next = new MetricIdAndData(getLocalId(), result.before.lastTsMillis, nowMillis, archive);
        } else {
            next = null;
        }

        final MetricIdAndData current;
        if (result.after.records > 0) {
            var archive = result.after.getArchive(header.withDeleteBefore(DeleteBeforeField.KEEP));
            current = new MetricIdAndData(getLocalId(), result.after.lastTsMillis, nowMillis, archive);
        } else {
            current = null;
        }

        return new MergeTaskResult(current, next);
    }

    /** Whether {@code actual} differs from the header of the first source. */
    private boolean isHeaderChanged(MetricHeader actual) {
        var prev = sources.get(0).archive().header();
        return !prev.equals(actual);
    }

    /** Whether the header carries a delete-before marker other than KEEP. */
    private boolean hasDelete(MetricHeader header) {
        return header.getDeleteBefore() != DeleteBeforeField.KEEP;
    }

    /**
     * Decides whether a single source can be returned as is. Never skips for multiple sources or for
     * archives not in {@link StockpileFormat#CURRENT}; otherwise the decision depends on decimation
     * being allowed, the archive's decim policy, its size and its first/last timestamps relative to
     * the split point.
     */
    private boolean isSkipMerge() {
        if (sources.size() > 1) {
            return false;
        }

        var source = sources.get(0);
        var archive = source.archive();
        var firstTsMillis = archive.getFirstTsMillis();
        if (archive.getFormat() != StockpileFormat.CURRENT) {
            return false;
        }

        if (!allowDecim && splitBefore(nowMillis) <= firstTsMillis) {
            return true;
        }

        if (!DecimPoliciesPredefined.isKnownId(archive.getDecimPolicyId())) {
            return false;
        }

        DecimPolicy policy = DecimPoliciesPredefined.policyByNumber(archive.getDecimPolicyId());
        if (policy.isEmpty()) {
            // Oversized archives without a policy must go through forceDecimCheck, so they are not skipped.
            return archive.bytesCount() <= ARCHIVE_BYTES_SIZE_TO_FORCE_DECIM && splitBefore(nowMillis) <= firstTsMillis;
        }

        long decimFrom = policy.getDecimFrom(source.decimatedAt());
        if (splitBefore(decimFrom) > firstTsMillis) {
            return true;
        }

        return decimFrom > source.lastTsMillis();
    }

    /**
     * Split point for {@code header}: derived from {@code nowMillis} when decimation is disabled or the
     * policy is empty, otherwise from the policy's decimation start relative to {@code nowMillis}.
     */
    private long splitBefore(MetricHeader header) {
        if (!allowDecim) {
            return splitBefore(nowMillis);
        }

        DecimPolicy policy = DecimPoliciesPredefined.policyByNumber(header.getDecimPolicyId());
        if (policy.isEmpty()) {
            return splitBefore(nowMillis);
        }

        return splitBefore(policy.getDecimFrom(nowMillis));
    }

    /**
     * Rounds {@code tsMillis} to a multiple of {@code splitDelayMillis} (down, for non-negative
     * timestamps); returns {@code 0} when splitting is disabled.
     */
    private long splitBefore(long tsMillis) {
        if (splitDelayMillis == 0) {
            return 0;
        }

        return tsMillis - (tsMillis % splitDelayMillis);
    }

    /**
     * If the header has no decimation policy and the sources together reach
     * {@link #ARCHIVE_BYTES_SIZE_TO_FORCE_DECIM} bytes, switches the policy to
     * POLICY_5_MIN_AFTER_7_DAYS and logs a warning; otherwise returns {@code headers} unchanged.
     */
    private MetricHeader forceDecimCheck(MetricHeader headers) {
        var policy = DecimPoliciesPredefined.policyByNumber(headers.getDecimPolicyId());
        // TODO: sequentially tighten the policy to fit into ARCHIVE_BYTES_SIZE_TO_FORCE_DECIM, e.g.
        // KEEP_FOREVER -> POLICY_5_MIN_AFTER_2_MONTHS -> POLICY_5_MIN_AFTER_7_DAYS ->
        if (policy.isEmpty()) {
            long bytes = 0;
            long records = 0;
            for (var source : sources) {
                bytes += source.archive().bytesCount();
                records += source.archive().getRecordCount();
            }

            if (bytes >= ARCHIVE_BYTES_SIZE_TO_FORCE_DECIM) {
                taskMetrics.addForceDecim();
                logger.warn("Metric {} with header {} have size {} bytes and {} points, force turn on decim policy: 5 min after 7 days",
                        StockpileMetricId.toString(shardId, getLocalId()),
                        headers,
                        bytes,
                        records);
                return headers.withDecimPolicy((short) EDecimPolicy.POLICY_5_MIN_AFTER_7_DAYS_VALUE);
            }
        }

        return headers;
    }

    /** Wraps {@code iterator} in a {@link DecimIterator} for the header's policy when decimation is allowed. */
    private Iterator decim(MetricHeader headers, long now, Iterator iterator) {
        if (!allowDecim) {
            return iterator;
        }

        DecimPolicy policy = DecimPoliciesPredefined.policyByNumber(headers.getDecimPolicyId());
        return DecimIterator.of(iterator, policy, now, estimateDecimatedAt(headers.getDecimPolicyId()));
    }

    /**
     * Estimates when the merged data was last decimated with {@code policy}.
     *
     * @return {@code 0} if the first source was never decimated or any source uses a different policy;
     *         otherwise the minimum, over non-empty sources, of their decimation time (or their first
     *         timestamp when they were never decimated), bounded by the first source's decimation time
     */
    private long estimateDecimatedAt(int policy) {
        long decimatedAt = sources.get(0).decimatedAt();
        if (decimatedAt == 0) {
            return 0;
        }

        for (var source : sources) {
            if (source.archive().getDecimPolicyId() != policy) {
                return 0;
            }

            if (source.archive().isEmpty()) {
                continue;
            }

            long lastTime = source.decimatedAt();
            if (lastTime == 0) {
                lastTime = source.archive().getFirstTsMillis();
            }
            decimatedAt = Math.min(lastTime, decimatedAt);
        }
        return decimatedAt;
    }

    @Override
    public String toString() {
        return "MergeTask{" +
                "localId=" + getLocalId() +
                ", archivesToMerge=" + sources.size() +
                ", nowMillis=" + nowMillis +
                '}';
    }

    /** Local id of the merged metric, taken from the first source. */
    private long getLocalId() {
        return sources.get(0).localId();
    }

    @Override
    public long memorySizeIncludingSelf() {
        long size = SELF_SIZE;
        size += MemoryCounter.listDataSizeWithContent(sources);
        return size;
    }
}
