package ru.yandex.direct.api.v5.ws.exceptionresolver;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Optional;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.xml.namespace.QName;
import javax.xml.ws.WebFault;

import com.yandex.direct.api.v5.general.ApiExceptionMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.Ordered;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.ws.WebServiceMessage;
import org.springframework.ws.context.MessageContext;
import org.springframework.ws.server.EndpointExceptionResolver;

import ru.yandex.direct.api.v5.context.ApiContextHolder;
import ru.yandex.direct.api.v5.entity.ApiValidationException;
import ru.yandex.direct.api.v5.service.accelinfo.AccelInfoHeaderSetter;
import ru.yandex.direct.api.v5.units.ApiUnitsService;
import ru.yandex.direct.api.v5.ws.ApiMessage;
import ru.yandex.direct.core.TranslatableException;

/**
 * Exception resolver creates response message by given exception.
 * <p>
 * The need for custom exception resolver caused by two problem:
 * 1. the requirement to return all faults in the predefined format
 * 2. the difference in creation of fault message for JsonMessage and SoapMessage types
 * <p>
 * The first problem solved by {@link ExceptionTranslator} that translate any exceptions to {@link WebFault}
 * format. The second case covered by delegation of message creation to {@link FaultResponseCreator}
 * <p>
 * {@link ExceptionTranslator} and {@link ExceptionTranslator} discovered by DI.
 */
@Component
public class ApiExceptionResolver implements EndpointExceptionResolver, Ordered {
    private static final Logger logger = LoggerFactory.getLogger(ApiExceptionResolver.class);

    // Method to get faultInfo as described in JAX-WS spec
    private static final String GET_FAULT_INFO_METHOD = "getFaultInfo";

    private static final int UNKNOWN_ERROR_CODE = -1;

    private final ApiContextHolder apiContextHolder;
    private final ApiUnitsService apiUnitsService;
    private final List<ExceptionTranslator> exceptionTranslators;
    private final List<FaultResponseCreator> faultResponseCreators;
    private final AccelInfoHeaderSetter accelInfoHeaderSetter;

    @Autowired
    public ApiExceptionResolver(
            ApiContextHolder apiContextHolder,
            ApiUnitsService apiUnitsService,
            List<ExceptionTranslator> exceptionTranslators,
            List<FaultResponseCreator> faultResponseCreators,
            AccelInfoHeaderSetter accelInfoHeaderSetter) {
        this.apiContextHolder = apiContextHolder;
        this.apiUnitsService = apiUnitsService;
        this.exceptionTranslators = exceptionTranslators;
        this.faultResponseCreators = faultResponseCreators;
        this.accelInfoHeaderSetter = accelInfoHeaderSetter;
    }

    private static ApiExceptionMessage getApiExceptionMessage(WebServiceMessage response) {
        if (!(response instanceof ApiMessage)) {
            return null;
        }
        ApiMessage apiMessage = (ApiMessage) response;
        if (!(apiMessage.getApiFault() instanceof ApiExceptionMessage)) {
            return null;
        }
        return (ApiExceptionMessage) apiMessage.getApiFault();
    }

    @Override
    public int getOrder() {
        // must be invoked before default spring exception resolvers
        return Ordered.HIGHEST_PRECEDENCE;
    }

    /**
     * Определить надо ли списывать баллы за ошибку
     */
    boolean shouldChargeUnitsForError(Exception ex) {
        if (!apiContextHolder.get().shouldChargeUnitsForRequest())
            return false;

        // Если это не ошибка пользователя, то баллы не списываем. Считаем, что TranslatableException
        // это признак того, что была ошибка со стороны пользователя. В противном случае это наш баг и за
        // это мы пользователя не штрафуем
        if (!(ex instanceof TranslatableException))
            return false;

        // Если это ошибка валидации, то проверяем списаны ли уже были баллы
        return !(ex instanceof ApiValidationException) || !((ApiValidationException) ex).isUnitsSpent();
    }

    public boolean resolveException(WebServiceMessage response, Exception ex) {
        if (!createFaultResponse(response, translateException(ex))) {
            return false;
        }

        int appErrorCode = UNKNOWN_ERROR_CODE;

        if (shouldChargeUnitsForError(ex)) {
            // Пытаемся определить код произошедшей ошибки для того чтобы списать правильное кол-во баллов
            // Делаем это на основе результирующего сообщения, потому что:
            // 1) Не хотим дублировать механизм извлечения кода ошибки
            // 2) ExceptionTranslator-ы могут подменять исключения
            // 3) Так как ExceptionResolver-ы обрабатываются специальным образом Spring-ом. В частности,
            //    ExceptionResolver-ы обрабатывают ошибки от Endpoint-ов и Interceptor-ов и упаковывают их
            //    в правильный формат. Но нам дополнительно надо обрабатывать ошибки parsing-га в Adapter-е,
            //    который расположен уровнем выше
            ApiExceptionMessage faultInfo = getApiExceptionMessage(response);

            if (faultInfo != null) {
                appErrorCode = faultInfo.getErrorCode();
            }

            apiUnitsService.withdrawForRequestError(appErrorCode == UNKNOWN_ERROR_CODE ? null : appErrorCode);
        }

        apiContextHolder.get().setAppErrorCode(appErrorCode);
        accelInfoHeaderSetter.setAccelInfoHeaderToHttpResponse();

        return true;
    }

    @Override
    public boolean resolveException(MessageContext messageContext, Object endpoint, Exception ex) {
        WebServiceMessage response = messageContext.getResponse();
        return resolveException(response, ex);
    }

    @Nonnull
    private Exception translateException(Exception ex) {
        if (ex.getClass().getAnnotation(WebFault.class) != null) {
            return ex;
        }

        for (ExceptionTranslator translator : exceptionTranslators) {
            try {
                Optional<? extends Exception> translationResult = translator.translate(ex);
                if (translationResult.isPresent()) {
                    logger.debug("Exception {} translated by {}", ex, translator);
                    return translationResult.get();
                }
            } catch (RuntimeException translatorException) {
                logger.error("{} throw an unexpected exception while translated exception {}:",
                        translator, ex, translatorException);
            }
        }
        logger.error("Exception {} not translated", ex);
        return ex;
    }

    boolean createFaultResponse(WebServiceMessage response, Exception translatedException) {
        String faultMessage = getFaultMessage(translatedException);
        QName faultQName = getFaultQName(translatedException);
        Object faultInfo = getFaultInfoBean(translatedException);
        for (FaultResponseCreator responseCreator : faultResponseCreators) {
            if (responseCreator.support(response)) {
                try {
                    logger.debug("Create fault response by {}", responseCreator);
                    responseCreator.createFaultResponse(response, faultMessage, faultQName, faultInfo);
                    return true;
                } catch (Exception creatorException) {
                    logger.error("FaultResponseCreator has thrown an unexpected exception", creatorException);
                }
            }
        }
        return false;
    }

    private String getFaultMessage(Exception ex) {
        return StringUtils.hasLength(ex.getMessage()) ? ex.getMessage() : ex.toString();
    }

    @Nullable
    private QName getFaultQName(Exception ex) {
        WebFault webFault = ex.getClass().getAnnotation(WebFault.class);
        if (webFault == null) {
            return null;
        }
        return new QName(webFault.targetNamespace(), webFault.name());
    }

    @Nullable
    private Object getFaultInfoBean(Exception webFaultEx) {
        try {
            Method getFaultInfo = webFaultEx.getClass().getMethod(GET_FAULT_INFO_METHOD);
            return getFaultInfo.invoke(webFaultEx);
        } catch (InvocationTargetException | IllegalAccessException | NoSuchMethodException e) {
            logger.warn("Exception '{}' does not conform the specification of WebFault", webFaultEx.toString());
            return null;
        }
    }
}
