package ru.yandex.direct.interruption;

import java.util.function.Supplier;

import javax.annotation.Nullable;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Проверки {@code interrupt}-статуса потока с поддержкой защищенных секций,
 * в которых проверки игнорируются.
 * <p>
 * Для того, чтобы сгенерировать исключение, если выставлен {@code interrupt}-статус
 * текущего потока, нужно вызвать {@link #checkInterruption()} или
 * {@link #checkInterruption(String)}.
 * <p>
 * Генерируемое исключение: {@link ThreadInterruptedException}. При генерации исключения
 * {@code interrupt}-статус потока очищается.
 * <p>
 * Можно предотвратить генерацию исключения этими методами в определенной секции кода,
 * если перед этим вызвать метод входа в "защищенную секцию" {@link #enterProtectedSection()}.
 * Каждая защищенная секция должна быть закрыта вызовом метода {@link #leaveProtectedSection()}.
 * Защищенные секции могут быть вложенными, если в определенном потоке выполнения метод
 * {@link #enterProtectedSection()} вызван многократно. При этом для выхода из всех защищенных
 * секций для каждого вызова метода входа необходимо вызвать метод выхода.
 * <p>
 * Очень важно в каждом потоке завершать все защищенные секции. Если этого не сделать,
 * то при переиспользовании потока в трэд-пуле поведение будет некорректным.
 * Так же это необходимо для очистки {@link ThreadLocal}-переменных.
 * <p>
 * При выходе из последней защищенной секции будет проведена проверка на установку флага
 * и при положительном результате будет сгенерировано исключение.
 * <p>
 * Пример:
 *
 * <pre> {@code
 * public class Task implements Runnable {
 *
 *     @Override
 *     public void run() {
 *         mustNotBeInterrupted();
 *     }
 *
 *     // в этом методе мы хотим, чтобы oneMethod() и anotherMethod() выполнились без прерываний
 *     public void mustNotBeInterrupted() {
 *         // внутри этого метода, по окончании выполнения переданной функции,
 *         // будет проведена проверка, не выставлен ли interrupted-статус потока,
 *         // и если выставлен, будет сгенерировано исключение
 *         executeProtected(() -> {
 *             oneMethod();
 *             anotherMethod();
 *         });
 *     }
 *
 *     // может быть прерван исключением только в случае, если выполняется вне защищенной секции
 *     public void oneMethod() {
 *         for (int i = 0; i < 10; i++) {
 *             checkInterruption();
 *             // выполнение ресурсоемкой работы
 *         }
 *     }
 *
 *     // может быть прерван исключением только в случае, если выполняется вне защищенной секции
 *     public void anotherMethod() {
 *         for (int i = 0; i < 10; i++) {
 *             checkInterruption();
 *             // выполнение ресурсоемкой работы
 *         }
 *     }
 * }
 * }</pre>
 */
public class InterruptionChecker {

    private static final Logger logger = LoggerFactory.getLogger(InterruptionChecker.class);
    private static final ThreadLocal<Integer> protectionDepth = new ThreadLocal<>();

    private InterruptionChecker() {
    }

    public static void checkInterruption() {
        checkInterruptionInternal(null);
    }

    public static void checkInterruption(String message) {
        checkInterruptionInternal(message);
    }

    public static void enterProtectedSection() {
        int newDepth = getProtectionDepth() + 1;
        setProtectionDepth(newDepth);
        logger.debug("Entered protected section. Current protection depth: {}", newDepth);
    }

    public static void leaveProtectedSection() {
        int newDepth = getProtectionDepth() - 1;
        setProtectionDepth(newDepth);
        logger.debug("Left protected section. Current protection level: {}", newDepth);

        if (newDepth == 0 && Thread.interrupted()) {
            logger.debug("Thread has interrupted flag set when left outer protected session."
                    + "Throwing exception...", newDepth);
            throw new ThreadInterruptedException(
                    "Thread has interrupted flag set when left outer protected session.");
        }
    }

    public static <T> T executeProtected(Supplier<T> supplier) {
        enterProtectedSection();
        try {
            return supplier.get();
        } finally {
            leaveProtectedSection();
        }
    }

    public static void executeProtected(Runnable task) {
        enterProtectedSection();
        try {
            task.run();
        } finally {
            leaveProtectedSection();
        }
    }

    private static void checkInterruptionInternal(@Nullable String message) {
        if (Thread.currentThread().isInterrupted()) {
            if (getProtectionDepth() == 0) {
                Thread.interrupted();
                throw message != null ?
                        new ThreadInterruptedException(message) :
                        new ThreadInterruptedException();
            } else {
                logger.debug("Ignore interruption check: called at protected section and thread is interrupted");
            }
        }
    }

    private static int getProtectionDepth() {
        Integer depth = protectionDepth.get();
        return depth == null ? 0 : depth;
    }

    private static void setProtectionDepth(int depth) {
        if (depth < 0) {
            logger.error("Invalid usage: thread left more protected sections then entered!");
            return;
        }
        if (depth == 0) {
            protectionDepth.remove();
        } else {
            protectionDepth.set(depth);
        }
    }

    static void clear() {
        protectionDepth.remove();
    }
}
