package ru.yandex.intranet.d.web.log;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

import org.reactivestreams.Subscription;
import org.slf4j.MDC;
import org.springframework.lang.NonNull;
import reactor.core.CoreSubscriber;
import reactor.util.context.Context;

import ru.yandex.intranet.d.util.MdcTaskDecorator;

/**
 * {@link CoreSubscriber} decorator that pushes MDC values taken from the Reactor context
 * before delivering each signal to the wrapped subscriber, and restores the previous MDC
 * values once the delegate call returns or throws.
 *
 * <p>Two sources are copied into MDC for every signal:
 * <ul>
 *     <li>every entry of the {@code Map<String, String>} stored in the context under
 *     {@link #COMMON_CONTEXT_KEY};</li>
 *     <li>the value stored under {@link AccessLogAttributesProducer#LOG_ID}, which is put under
 *     {@link MdcTaskDecorator#LOG_ID_MDC_KEY}. When the context has no log id, that MDC key is
 *     removed for the duration of the signal.</li>
 * </ul>
 *
 * @author Dmitriy Timashov <dm-tim@yandex-team.ru>
 */
public class MdcContextSupplier<T> implements CoreSubscriber<T> {

    /**
     * Reactor context key under which a {@code Map<String, String>} of MDC entries is expected.
     */
    public static final String COMMON_CONTEXT_KEY = "common";

    private final CoreSubscriber<T> coreSubscriber;

    /**
     * @param coreSubscriber subscriber that receives the decorated signals; must not be {@code null}
     * @throws NullPointerException if {@code coreSubscriber} is {@code null}
     */
    public MdcContextSupplier(CoreSubscriber<T> coreSubscriber) {
        // Fail fast here instead of with an NPE on the first delivered signal.
        this.coreSubscriber = Objects.requireNonNull(coreSubscriber, "coreSubscriber");
    }

    /**
     * Delegates {@code onSubscribe} with MDC populated from the delegate's context.
     */
    @Override
    public void onSubscribe(@NonNull Subscription subscription) {
        MdcResult mdcResult = supplyMdc(coreSubscriber.currentContext());
        try {
            coreSubscriber.onSubscribe(subscription);
        } finally {
            restoreMdc(mdcResult);
        }
    }

    /**
     * Delegates {@code onNext} with MDC populated from the delegate's context.
     */
    @Override
    public void onNext(T obj) {
        MdcResult mdcResult = supplyMdc(coreSubscriber.currentContext());
        try {
            coreSubscriber.onNext(obj);
        } finally {
            restoreMdc(mdcResult);
        }
    }

    /**
     * Delegates {@code onError} with MDC populated from the delegate's context.
     */
    @Override
    public void onError(Throwable t) {
        MdcResult mdcResult = supplyMdc(coreSubscriber.currentContext());
        try {
            coreSubscriber.onError(t);
        } finally {
            restoreMdc(mdcResult);
        }
    }

    /**
     * Delegates {@code onComplete} with MDC populated from the delegate's context.
     */
    @Override
    public void onComplete() {
        MdcResult mdcResult = supplyMdc(coreSubscriber.currentContext());
        try {
            coreSubscriber.onComplete();
        } finally {
            restoreMdc(mdcResult);
        }
    }

    /**
     * Exposes the wrapped subscriber's context unchanged.
     */
    @NonNull
    @Override
    public Context currentContext() {
        return coreSubscriber.currentContext();
    }

    /**
     * Copies MDC values from {@code context} into MDC.
     *
     * @param context Reactor context of the wrapped subscriber
     * @return the MDC values that were present before this call for every key that was touched;
     * a {@code null} value means the key was absent and must be removed on restore
     */
    private MdcResult supplyMdc(Context context) {
        Map<String, String> oldMdc = new HashMap<>();
        Map<String, String> commonContext = context.getOrDefault(COMMON_CONTEXT_KEY, null);
        if (commonContext != null) {
            commonContext.forEach((key, value) -> {
                rememberOldValue(oldMdc, key, MDC.get(key));
                MDC.put(key, value);
            });
        }
        // Read after the common step on purpose: the comparison must use the value currently in MDC.
        String mdcLogId = MDC.get(MdcTaskDecorator.LOG_ID_MDC_KEY);
        if (context.hasKey(AccessLogAttributesProducer.LOG_ID)) {
            String logId = context.get(AccessLogAttributesProducer.LOG_ID);
            if (!Objects.equals(logId, mdcLogId)) {
                rememberOldValue(oldMdc, MdcTaskDecorator.LOG_ID_MDC_KEY, mdcLogId);
                MDC.put(MdcTaskDecorator.LOG_ID_MDC_KEY, logId);
            }
        } else if (mdcLogId != null) {
            rememberOldValue(oldMdc, MdcTaskDecorator.LOG_ID_MDC_KEY, mdcLogId);
            MDC.remove(MdcTaskDecorator.LOG_ID_MDC_KEY);
        }
        return new MdcResult(oldMdc);
    }

    /**
     * Records {@code oldValue} as the value to restore for {@code key}, unless a value for that key
     * was already recorded during this call. If the common map also sets the log-id MDC key,
     * overwriting here would make {@link #restoreMdc} reinstate the common-context value instead of
     * the original one. {@link Map#putIfAbsent} is deliberately not used: it treats a stored
     * {@code null} ("key was absent from MDC") as missing and would replace it.
     */
    private static void rememberOldValue(Map<String, String> oldMdc, String key, String oldValue) {
        if (!oldMdc.containsKey(key)) {
            oldMdc.put(key, oldValue);
        }
    }

    /**
     * Puts back every recorded previous MDC value, removing keys whose previous value was {@code null}.
     */
    private void restoreMdc(MdcResult mdcResult) {
        mdcResult.getOldValues().forEach((key, value) -> {
            if (value != null) {
                MDC.put(key, value);
            } else {
                MDC.remove(key);
            }
        });
    }

    /**
     * Holder for the MDC values that were in place before {@link #supplyMdc} modified MDC.
     */
    private static final class MdcResult {

        private final Map<String, String> oldValues;

        private MdcResult(Map<String, String> oldValues) {
            this.oldValues = oldValues;
        }

        public Map<String, String> getOldValues() {
            return oldValues;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            MdcResult mdcResult = (MdcResult) o;
            return Objects.equals(getOldValues(), mdcResult.getOldValues());
        }

        @Override
        public int hashCode() {
            return Objects.hash(getOldValues());
        }

        @Override
        public String toString() {
            return "MdcResult{" +
                    "oldValues=" + oldValues +
                    '}';
        }
    }

}
