package ru.yandex.chemodan.queller.worker;

import java.nio.charset.StandardCharsets;
import java.util.UUID;

import org.joda.time.Instant;
import org.springframework.amqp.core.Message;
import org.springframework.amqp.core.MessageProperties;
import org.springframework.amqp.core.Queue;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.chemodan.http.YandexCloudRequestIdHolder;
import ru.yandex.chemodan.queller.celery.CeleryBender;
import ru.yandex.chemodan.queller.celery.QuellerQueues;
import ru.yandex.chemodan.queller.celery.control.CeleryApiRequest;
import ru.yandex.chemodan.queller.celery.control.CeleryControl;
import ru.yandex.chemodan.queller.celery.control.callback.replies.JavaWorkerInspectStatsReply;
import ru.yandex.chemodan.queller.celery.job.CeleryJob;
import ru.yandex.chemodan.queller.celery.job.CeleryJobContext;
import ru.yandex.chemodan.queller.celery.job.CeleryOnetimeJobConverter;
import ru.yandex.chemodan.queller.rabbit.RabbitPool;
import ru.yandex.chemodan.queller.rabbit.RoutedMessage;
import ru.yandex.chemodan.queller.support.BenderJsonMessageConverter;
import ru.yandex.chemodan.util.bender.BenderUtils;
import ru.yandex.commune.bazinga.impl.FullJobId;
import ru.yandex.commune.bazinga.impl.JobId;
import ru.yandex.commune.bazinga.impl.JobInfoValue;
import ru.yandex.commune.bazinga.impl.OnetimeJob;
import ru.yandex.commune.bazinga.impl.OnetimeUtils;
import ru.yandex.commune.bazinga.impl.TaskOverridesManager;
import ru.yandex.commune.bazinga.impl.worker.WorkerTask;
import ru.yandex.commune.bazinga.impl.worker.WorkerTaskRegistry;
import ru.yandex.commune.bazinga.impl.worker.WorkerTaskRunResult;
import ru.yandex.commune.bazinga.impl.worker.WorkerTaskRunner;
import ru.yandex.commune.bazinga.scheduler.OnetimeTask;
import ru.yandex.commune.json.JsonObject;
import ru.yandex.misc.ExceptionUtils;
import ru.yandex.misc.bender.serialize.BenderSerializer;
import ru.yandex.misc.lang.CharsetUtils;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.log.mlf.LoggerFactory;
import ru.yandex.misc.monica.core.name.MetricName;
import ru.yandex.misc.monica.util.measure.MeasureInfo;
import ru.yandex.misc.support.tl.ThreadLocalHandle;
import ru.yandex.misc.time.InstantInterval;

/**
 * Runs Celery jobs inside the JVM by executing them as Bazinga one-time tasks.
 *
 * <p>After a job runs, its {@code CeleryJobContext} is stamped with this worker's id, the execution
 * interval and any error. The job is then published back through {@code QuellerQueues.SUBMIT_EXCHANGE}:
 * to {@code QuellerQueues.COMPLETED_QUEUE} on success, or to {@code QuellerQueues.SUBMIT_QUEUE} with
 * the retry counter incremented on failure. The worker also answers the
 * {@code INSPECT_STATS_JAVA_WORKER} control request.
 *
 * @author dbrylev
 */
public class CeleryJavaWorker {
    private static final Logger logger = LoggerFactory.getLogger(CeleryJavaWorker.class);

    /** Serializes the per-worker payload of the inspect-stats reply. */
    private static final BenderSerializer<JavaWorkerInspectStatsReply> statsReplySerializer =
            CeleryBender.mapper.createSerializer(JavaWorkerInspectStatsReply.class);

    /** Worker settings; this class reads {@code workerId} and {@code queues} from it. */
    public final CeleryJavaWorkerConfiguration configuration;

    private final WorkerTaskRunner taskRunner;
    private final WorkerTaskRegistry taskRegistry;
    private final TaskOverridesManager taskOverridesManager;

    private final RabbitPool rabbitPool;
    private final BenderJsonMessageConverter<CeleryJob> messageConverter;

    private final CeleryJavaWorkerMetrics metrics;

    /**
     * Creates a worker.
     *
     * @param configuration        worker id and served queues
     * @param taskRunner           executes the Bazinga task built for each Celery job
     * @param taskRegistry         resolves a Celery task name to a registered one-time task
     * @param taskOverridesManager applies overrides to the resolved task
     * @param rabbitPool           used to resubmit finished jobs and to send control replies
     * @param messageConverter     converts a {@link CeleryJob} into an AMQP message for resubmission
     * @param metrics              invocation and error counters
     */
    public CeleryJavaWorker(
            CeleryJavaWorkerConfiguration configuration,
            WorkerTaskRunner taskRunner,
            WorkerTaskRegistry taskRegistry,
            TaskOverridesManager taskOverridesManager,
            RabbitPool rabbitPool,
            BenderJsonMessageConverter<CeleryJob> messageConverter,
            CeleryJavaWorkerMetrics metrics)
    {
        this.configuration = configuration;
        this.taskRunner = taskRunner;
        this.taskRegistry = taskRegistry;
        this.taskOverridesManager = taskOverridesManager;
        this.rabbitPool = rabbitPool;
        this.messageConverter = messageConverter;
        this.metrics = metrics;
    }

    /**
     * Runs a single Celery job and resubmits its outcome.
     *
     * <p>If the job carries a ycrid, it is installed through
     * {@code YandexCloudRequestIdHolder.setAndPushToNdc} for the duration of the run. The returned
     * handle is popped afterwards, even if the run throws.
     *
     * @param job the job to execute
     */
    public void handleJob(CeleryJob job) {
        Option<ThreadLocalHandle> handle = job.getYcrid().map(YandexCloudRequestIdHolder::setAndPushToNdc);
        try {
            runJob(job);

        } finally {
            handle.forEach(ThreadLocalHandle::popSafely);
        }
    }

    /**
     * Executes the job, rebuilds its context with the execution details and resubmits it.
     *
     * <p>Any recoverable throwable from the run is logged and treated as a task error. Unrecoverable
     * ones are rethrown by {@code ExceptionUtils.throwIfUnrecoverable}. A failed job goes back to the
     * submit queue with its retry count incremented; a successful one goes to the completed queue.
     */
    private void runJob(CeleryJob job) {
        // Placeholder interval used when run() throws before producing a result; a single timestamp
        // keeps it zero-length instead of reading the clock twice.
        Instant now = Instant.now();
        InstantInterval interval = new InstantInterval(now, now);
        Option<Throwable> error;
        try {
            WorkerTaskRunResult result = run(job);
            interval = new InstantInterval(result.context.getStart(), result.end);
            error = result.error;

        } catch (Throwable t) {
            ExceptionUtils.throwIfUnrecoverable(t);
            logger.error("Task {} failed: {}", job.forLog(), t);

            interval = interval.withEnd(Instant.now());
            error = Option.of(t);
        }

        CeleryJobContext context = job.getContext().getOrElse(CeleryJobContext.empty());

        // Keep the upstream scheduling fields; overwrite worker id, start/end and error details.
        job = job.withContext(new CeleryJobContext(
                context.globalQueued, context.created, context.scheduled,
                context.activeUid, context.ycrid, context.host,
                Option.of(configuration.workerId), Option.of(interval.getStart()), Option.of(interval.getEnd()),
                error.map(Throwable::getMessage).filterNotNull(),
                error.map(ExceptionUtils::getStackTrace), Option.empty()));

        if (error.isPresent()) {
            resubmit(job.withRetriesIncremented(), QuellerQueues.SUBMIT_QUEUE);

        } else {
            resubmit(job, QuellerQueues.COMPLETED_QUEUE);
        }
    }

    /**
     * Replies to a Celery control request asking for this Java worker's stats.
     *
     * <p>The request is ignored when any of the following holds:
     * <ul>
     *   <li>its ticket cannot be parsed;</li>
     *   <li>it targets a different control method;</li>
     *   <li>it names explicit destinations that do not include this worker.</li>
     * </ul>
     * Otherwise the reply is broadcast to the request's reply-to exchange and routing key. It is a JSON
     * object keyed by this worker's id, holding the hosts of the currently working Rabbit connections
     * and the configured queues.
     *
     * @param request the incoming control request
     */
    public void inspectStats(CeleryApiRequest request) {
        String workerId = configuration.workerId.serialize();

        CeleryControl.Method method = CeleryControl.Method.INSPECT_STATS_JAVA_WORKER;

        if (request.ticket.isNotParsable()) {
            return;
        }

        if (!request.method.equals(method.methodName)) {
            return;
        }

        if (request.destination.isPresent() && !request.destination.get().containsTs(workerId)) {
            return;
        }

        JavaWorkerInspectStatsReply reply = new JavaWorkerInspectStatsReply(
                rabbitPool.getWorkingRabbits().map(c -> c.connectionData.host.format()), configuration.queues);

        // Encode explicitly as UTF-8 so the bytes match the content encoding advertised below.
        // The no-arg getBytes() would use the platform default charset.
        byte[] bytes = BenderUtils.serialize(new JsonObject(
                Cf.map(workerId, BenderUtils.toJsonObject(statsReplySerializer, reply))))
                .getBytes(StandardCharsets.UTF_8);

        MessageProperties properties = new MessageProperties();
        properties.getHeaders().put("ticket", request.ticket.getTicket().withMethod(method.name()).serialize());

        properties.setContentType(MessageProperties.CONTENT_TYPE_JSON);
        properties.setContentEncoding(CharsetUtils.UTF8_CHARSET.name());
        properties.setContentLength(bytes.length);

        RoutedMessage replyMessage = new RoutedMessage(
                new Message(bytes, properties), request.replyTo.exchange, Option.of(request.replyTo.routingKey));

        rabbitPool.sendMessages(Cf.list(replyMessage), RabbitPool.MessageType.BROADCAST);
    }

    /**
     * Runs the job via {@link #doRun}, counting any thrown throwable as a preparation error before
     * rethrowing it unchanged.
     */
    private WorkerTaskRunResult run(CeleryJob celeryJob) {
        try {
            return doRun(celeryJob);
        } catch (Throwable t) {
            metrics.tasksPreparationErrors.inc(new MetricName(celeryJob.task.getId()));
            metrics.tasksPreparationErrorsCounter.inc();
            throw t;
        }
    }

    /**
     * Converts the Celery job into a Bazinga {@link OnetimeJob}, runs it through the
     * {@link WorkerTaskRunner} and records invocation metrics.
     *
     * <p>{@code celeryJob.id} must be a UUID string; {@code UUID.fromString} throws otherwise. When the
     * Celery job carries a context and the parsed parameters are {@code CeleryOnetimeTaskParameters},
     * several context fields are injected into the parameters before the run: ycrid, current try,
     * max retries, active uid and creation time.
     *
     * @return the runner's result; an error inside the task is reported in {@code result.error}
     */
    private WorkerTaskRunResult doRun(CeleryJob celeryJob) {
        OnetimeTask task = taskOverridesManager.getWithOverrides(taskRegistry.getOnetimeTask(celeryJob.task));

        OnetimeJob job = new OnetimeJob(
                new FullJobId(task.id(), new JobId(UUID.fromString(celeryJob.id))),
                Instant.now(), Option.empty(), CeleryOnetimeJobConverter.getParameters(celeryJob),
                JobInfoValue.ready(), CeleryOnetimeJobConverter.getAttempt(celeryJob),
                Cf.set(), Option.empty(), task.priority(), Option.empty());

        Object parameters = OnetimeUtils.parseParameters(task, job.getParameters());
        Option<CeleryJobContext> celeryContext = celeryJob.getContext();
        if (celeryContext.isPresent() && parameters instanceof CeleryOnetimeTaskParameters) {
            CeleryJobContext context = celeryContext.get();
            parameters = ((CeleryOnetimeTaskParameters) parameters)
                    .withYcrid(context.ycrid)
                    .withContext(CeleryOnetimeTaskContext.builder()
                            .currentTry(Option.of(celeryJob.retries))
                            .maxRetries(context.maxRetries)
                            .activeUid(context.activeUid)
                            .created(context.created)
                            .build());
        }
        // NOTE(review): this mutates the task object returned by taskOverridesManager. If that instance
        // is shared between concurrently running jobs, parameters could leak across them; confirm
        // getWithOverrides returns a fresh instance per call.
        task.setParameters(parameters);

        WorkerTaskRunResult result = taskRunner.run(WorkerTask.onetime(task, job));

        ListF<MetricName> metricNames = Cf.list(new MetricName(task.id().getId()));

        MeasureInfo info = new MeasureInfo(result.getDuration(), !result.error.isPresent());
        metrics.tasksInvocations.update(info, metricNames);
        metrics.tasksInvocationsCount.inc(metricNames);

        return result;
    }

    /**
     * Publishes the job to the given queue via {@link #doResubmit}. Any thrown throwable increments
     * the resubmit-error counter and is rethrown unchanged.
     */
    private void resubmit(CeleryJob job, Queue queue) {
        try {
            doResubmit(job, queue);
        } catch (Throwable t) {
            metrics.resubmitToRabbitErrorsCounter.inc();
            throw t;
        }
    }

    /**
     * Sends the job with a confirmed publish to {@code QuellerQueues.SUBMIT_EXCHANGE}, using the queue
     * name as the routing key.
     */
    private void doResubmit(CeleryJob job, Queue queue) {
        rabbitPool.sendMessageConfirmed(new RoutedMessage(
                messageConverter.toMessage(job, new MessageProperties()),
                QuellerQueues.SUBMIT_EXCHANGE.getName(), Option.of(queue.getName())));
    }
}
