package ru.yandex.qe.bus.exception;

import java.lang.reflect.Method;
import java.util.function.Function;

import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.ext.ExceptionMapper;

import org.apache.cxf.jaxrs.impl.WebApplicationExceptionMapper;
import org.apache.cxf.jaxrs.utils.JAXRSUtils;
import org.apache.cxf.message.Message;
import org.apache.cxf.phase.PhaseInterceptorChain;

/**
 * JAX-RS {@link ExceptionMapper} that converts any {@link Throwable} escaping a resource method into a
 * {@link Response}.
 * <p>
 * If the class declaring the currently invoked resource method is annotated with {@link ExceptionTranslator},
 * a fresh instance of the translator named by that annotation converts the exception into a
 * {@link WebApplicationException}. That exception's response is then re-emitted with a JSON entity built by
 * {@link JsonException} and the {@value #EXCEPTION_TRANSFER_HEADER} header. In every other case the exception
 * is delegated to the configured default mappers ({@link #setDefaultWaeMapper},
 * {@link #setDefaultServerExceptionMapper}).
 *
 * @author rurikk
 */
public class ExceptionToResponseMapper implements ExceptionMapper<Throwable> {
    /** Response header added to translated error responses. */
    public static final String EXCEPTION_TRANSFER_HEADER = "exception-transfer";
    /** Value written into {@link #EXCEPTION_TRANSFER_HEADER}. */
    public static final String EXCEPTION_TRANSFER_VERSION = "1";

    /** Key of the CXF message property under which the invoked resource {@link Method} is looked up. */
    private static final String RESOURCE_METHOD_KEY = "org.apache.cxf.resource.method";

    /** Mapper used for {@link WebApplicationException}s when no translator applies. */
    private WebApplicationExceptionMapper defaultWaeMapper = new WebApplicationExceptionMapper();
    /** Mapper used for all other throwables when no translator applies. */
    private ServerExceptionMapper defaultServerExceptionMapper = new ServerExceptionMapper();

    /**
     * Maps {@code exception} to a response.
     *
     * @param exception the throwable raised while serving the request
     * @return the translated JSON error response, or the default mappers' response when no translator is
     *         configured for the invoked resource class or the translator returns {@code null}
     * @throws IllegalStateException if the configured translator class cannot be instantiated
     */
    @Override
    public Response toResponse(Throwable exception) {
        Function<Throwable, WebApplicationException> translator = getExceptionTranslator();
        if (translator == null) {
            return fallbackToOldBehavior(exception);
        }
        WebApplicationException wae = translator.apply(exception);
        // A translator that declines to translate must not turn error handling itself into an NPE.
        return wae != null ? waeToResponse(wae) : fallbackToOldBehavior(exception);
    }

    /**
     * Builds the outgoing response from the translated exception.
     * <p>
     * Starts from a builder initialised from {@code wae.getResponse()} (via {@link JAXRSUtils#fromResponse}).
     * It then adds the transfer header, replaces the entity with the JSON serialization of {@code wae}, and
     * sets the content type to {@code application/json}.
     */
    private static Response waeToResponse(WebApplicationException wae) {
        return JAXRSUtils.fromResponse(wae.getResponse())
                .header(EXCEPTION_TRANSFER_HEADER, EXCEPTION_TRANSFER_VERSION)
                .entity(JsonException.fromThrowable(wae).asString())
                .type(MediaType.APPLICATION_JSON_TYPE)
                .build();
    }

    /**
     * Creates the translator configured for the currently invoked resource class.
     *
     * @return a new translator instance, or {@code null} if the invoked method is unknown or its declaring
     *         class carries no {@link ExceptionTranslator} annotation
     * @throws IllegalStateException if the translator class cannot be instantiated reflectively
     */
    private static Function<Throwable, WebApplicationException> getExceptionTranslator() {
        Method invokedMethod = getInvokedMethod();
        if (invokedMethod == null) {
            return null;
        }
        ExceptionTranslator annotation = invokedMethod.getDeclaringClass().getAnnotation(ExceptionTranslator.class);
        if (annotation == null) {
            return null;
        }
        try {
            // Constructor#newInstance wraps constructor failures instead of sneaky-throwing checked exceptions
            // the way the deprecated Class#newInstance does.
            return annotation.value().getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(
                    "Cannot instantiate exception translator " + annotation.value().getName(), e);
        }
    }

    /**
     * Looks up the resource method being invoked for the current CXF message.
     *
     * @return the invoked method, or {@code null} when there is no current message (e.g. the mapper is called
     *         outside an interceptor chain) or the property is absent or not a {@link Method}
     */
    private static Method getInvokedMethod() {
        Message msg = PhaseInterceptorChain.getCurrentMessage();
        if (msg == null) {
            return null;
        }
        Object method = msg.get(RESOURCE_METHOD_KEY);
        return method instanceof Method ? (Method) method : null;
    }

    /** Delegates to the default mappers: the WAE mapper for {@link WebApplicationException}s, else the server mapper. */
    private Response fallbackToOldBehavior(Throwable exception) {
        return exception instanceof WebApplicationException
                ? defaultWaeMapper.toResponse((WebApplicationException) exception)
                : defaultServerExceptionMapper.toResponse(exception);
    }

    /** Replaces the mapper used for untranslated {@link WebApplicationException}s. */
    public void setDefaultWaeMapper(final WebApplicationExceptionMapper defaultWaeMapper) {
        this.defaultWaeMapper = defaultWaeMapper;
    }

    /** Replaces the mapper used for untranslated non-{@link WebApplicationException} throwables. */
    public void setDefaultServerExceptionMapper(final ServerExceptionMapper defaultServerExceptionMapper) {
        this.defaultServerExceptionMapper = defaultServerExceptionMapper;
    }
}
