package ru.yandex.mail.cerberus.worker;

import io.micrometer.core.instrument.LongTaskTimer;
import io.micrometer.core.instrument.MeterRegistry;
import io.micronaut.context.BeanDefinitionRegistry;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import one.util.streamex.EntryStream;
import one.util.streamex.StreamEx;
import ru.yandex.mail.cerberus.TaskType;
import ru.yandex.mail.cerberus.worker.api.Processor;
import ru.yandex.mail.cerberus.worker.api.TaskConfiguration;
import ru.yandex.mail.cerberus.worker.api.TaskProcessor;
import ru.yandex.mail.cerberus.worker.exception.InvalidTaskProcessorBeanException;

import javax.inject.Inject;
import javax.inject.Singleton;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;

import static java.util.function.Function.identity;

/**
 * Registry that pairs every {@link TaskConfiguration} with the {@link TaskProcessor} bean declared
 * for the same {@link TaskType} and with the Micrometer meters created for that task.
 *
 * <p>The registry is built once in the constructor. Construction fails if two configurations or
 * two processors share a task type, or if a configuration has no processor of the same type.
 */
@Slf4j
@Singleton
public class TaskRegistry {
    /**
     * Everything the registry holds for a single task type.
     */
    @Value
    public static class Record {
        /** Configuration the record was resolved from. */
        TaskConfiguration configuration;
        /** Processor whose {@code @Processor} value matches the configuration's task type. */
        @SuppressWarnings("rawtypes") TaskProcessor processor;
        /** Long task timer registered under the name {@code task_<taskType>}. */
        LongTaskTimer meter;
        /** Flag backing the {@code task_<taskType>_error} gauge: the gauge reads 1 when it is true, 0 otherwise. */
        AtomicBoolean errorMeter;
    }

    private final Map<TaskType, Record> records;

    /**
     * Derives the task type of a processor from the {@link Processor} annotation on its bean definition.
     *
     * @param beanDefinitionRegistry registry used to look up the bean registration of {@code processor}
     * @param processor              processor bean instance
     * @return task type built from the string value of the {@code @Processor} annotation
     * @throws IllegalStateException             if no bean registration is found for the processor
     * @throws InvalidTaskProcessorBeanException if the bean definition carries no {@code @Processor} annotation
     */
    private static TaskType findTaskType(BeanDefinitionRegistry beanDefinitionRegistry,
                                         @SuppressWarnings("rawtypes") TaskProcessor processor) {
        val processorClass = processor.getClass();
        val beanInfo = beanDefinitionRegistry.findBeanRegistration(processor)
            .orElseThrow(() -> new IllegalStateException("Task processor " + processorClass + " bean registration not found"));

        val processorAnnotation = beanInfo.getBeanDefinition().getAnnotation(Processor.class);
        if (processorAnnotation == null) {
            throw new InvalidTaskProcessorBeanException("Task processor " + processorClass +
                " should be annotated with @Processor annotation");
        }

        return new TaskType(processorAnnotation.getRequiredValue(String.class));
    }

    /**
     * Ensures that no task type occurs more than once in {@code types}.
     *
     * @param what  name of the validated entity; used only in the error message
     * @param types task types to check, in any order
     * @throws IllegalStateException if at least one task type occurs two or more times
     */
    private static void validate(String what, StreamEx<TaskType> types) {
        // distinct(2) keeps each element that occurs at least twice anywhere in the stream.
        // runLengths() only collapses ADJACENT equal elements, so it missed duplicates that
        // were separated by other task types in the (unsorted) input.
        val duplicates = types
            .distinct(2)
            .joining(", ");

        if (!duplicates.isEmpty()) {
            throw new IllegalStateException(String.format("Duplicate %s for tasks [%s] found.", what, duplicates));
        }
    }

    /**
     * Builds the registry record for one configuration and registers the task's meters.
     *
     * @param configuration task configuration to resolve
     * @param meterRegistry registry in which the task's timer and error gauge are created
     * @param processors    processors keyed by task type
     * @return record combining the configuration, its processor and the created meters
     * @throws IllegalStateException if {@code processors} has no entry for the configuration's task type
     */
    private Record resolveTask(TaskConfiguration configuration, MeterRegistry meterRegistry,
                               @SuppressWarnings("rawtypes") Map<TaskType, TaskProcessor> processors) {
        val taskType = configuration.getTaskType();
        val processor = processors.get(taskType);
        if (processor == null) {
            throw new IllegalStateException("Processor for task " + taskType + " not found");
        }

        val meterBaseName = "task_" + taskType.toString();
        val meter = meterRegistry.more().longTaskTimer(meterBaseName);
        // gauge(...) hands back the AtomicBoolean passed in; it is stored in the Record so that
        // the flag can be flipped later through Record#getErrorMeter().
        val errorMeter = meterRegistry.gauge(meterBaseName + "_error", new AtomicBoolean(false), flag -> flag.get() ? 1. : 0.);

        log.info("Task {} registered successfully", taskType);
        return new Record(configuration, processor, meter, errorMeter);
    }

    /**
     * Builds the registry from the injected configurations and processors.
     *
     * @param beanDefinitionRegistry registry used to read the {@code @Processor} annotation of each processor bean
     * @param tasksConfiguration     all task configurations
     * @param processors             all task processor beans
     * @param meterRegistry          registry in which per-task meters are created
     * @throws IllegalStateException             if configurations or processors contain duplicate task types,
     *                                           if a processor has no bean registration, or if a configuration
     *                                           has no processor of the same task type
     * @throws InvalidTaskProcessorBeanException if a processor bean is not annotated with {@code @Processor}
     */
    @Inject
    public TaskRegistry(BeanDefinitionRegistry beanDefinitionRegistry, TaskConfiguration[] tasksConfiguration,
                        @SuppressWarnings("rawtypes") TaskProcessor[] processors,
                        MeterRegistry meterRegistry) {
        val processorsWithType = StreamEx.of(processors)
            .mapToEntry(processor -> findTaskType(beanDefinitionRegistry, processor), identity())
            .toImmutableList();

        // Validate both sides up front so that duplicates are reported with a descriptive message
        // before any meter is registered and before toMap(...) can fail on a duplicate key.
        validate("configuration", StreamEx.of(tasksConfiguration).map(TaskConfiguration::getTaskType));
        validate("processor", StreamEx.of(processorsWithType).map(Map.Entry::getKey));

        var processorByType = StreamEx.of(processorsWithType)
            .toMap(Map.Entry::getKey, Map.Entry::getValue);

        this.records = StreamEx.of(tasksConfiguration)
            .toMap(
                TaskConfiguration::getTaskType,
                config -> resolveTask(config, meterRegistry, processorByType)
            );
    }

    /**
     * Looks up the full record of a task.
     *
     * @param taskType task type to look up
     * @return the record, or an empty {@link Optional} if no configuration with this task type was registered
     */
    public Optional<Record> findTaskRecord(TaskType taskType) {
        return Optional.ofNullable(records.get(taskType));
    }

    /**
     * Looks up the processor of a task.
     *
     * @param taskType task type to look up
     * @return the processor, or an empty {@link Optional} if the task type is not registered
     */
    @SuppressWarnings("rawtypes")
    public Optional<TaskProcessor> findTaskProcessor(TaskType taskType) {
        return findTaskRecord(taskType).map(Record::getProcessor);
    }

    /**
     * Looks up the configuration of a task.
     *
     * @param taskType task type to look up
     * @return the configuration, or an empty {@link Optional} if the task type is not registered
     */
    public Optional<TaskConfiguration> findTaskConfiguration(TaskType taskType) {
        return findTaskRecord(taskType).map(Record::getConfiguration);
    }

    /**
     * Streams the configurations of all registered tasks.
     *
     * @return a new stream over the configurations held by the registry
     */
    public StreamEx<TaskConfiguration> configurations() {
        return EntryStream.of(records)
            .mapValues(Record::getConfiguration)
            .values();
    }
}
