package ru.yandex.direct.api.v5.ws.exceptionresolver;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBElement;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import javax.xml.namespace.QName;
import javax.xml.transform.Result;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.ws.WebServiceMessage;
import org.springframework.ws.soap.SoapFault;
import org.springframework.ws.soap.SoapMessage;

import ru.yandex.direct.api.v5.ws.ApiMessage;


/**
 * {@link FaultResponseCreator} for SOAP responses: adds a client/sender fault to the body of a
 * {@link SoapMessage} and, when a fault element name and fault info object are supplied,
 * marshals the fault info into the fault detail with JAXB.
 *
 * <p>{@link JAXBContext} instances are cached per fault info class and reused across calls;
 * a new {@link Marshaller} is created for every fault response.
 *
 * @see SoapMessage
 * @see ApiExceptionResolver
 */
@Component
public class SoapFaultResponseCreator implements FaultResponseCreator {
    private static final Logger logger = LoggerFactory.getLogger(SoapFaultResponseCreator.class);

    /** Cache of JAXB contexts keyed by the runtime class of the fault info object. */
    private final ConcurrentMap<Class<?>, JAXBContext> jaxbContexts = new ConcurrentHashMap<>();

    /**
     * @param response response message to check
     * @return {@code true} if the response is a {@link SoapMessage}
     */
    @Override
    public boolean support(WebServiceMessage response) {
        return response instanceof SoapMessage;
    }

    /**
     * Adds a client/sender fault with the given message to the SOAP body, stores {@code faultInfo}
     * on the message via {@link ApiMessage#setApiFault} and, if both {@code webFaultQName} and
     * {@code faultInfo} are non-null, serializes {@code faultInfo} into the fault detail.
     *
     * <p>A serialization failure is logged and does not propagate: the fault itself has already
     * been added to the response by that point.
     *
     * @param response      response message; cast to both {@link SoapMessage} and {@link ApiMessage}
     * @param faultMessage  fault string / reason text
     * @param webFaultQName qualified name of the element wrapping the fault info; may be {@code null}
     * @param faultInfo     object to marshal into the fault detail; may be {@code null}
     */
    @Override
    public void createFaultResponse(WebServiceMessage response, String faultMessage, QName webFaultQName,
                                    Object faultInfo) {
        // The fault detail element is added unconditionally, even if nothing is marshalled into it below.
        SoapFault fault = ((SoapMessage) response).getSoapBody().addClientOrSenderFault(faultMessage, null);
        Result detailResult = fault.addFaultDetail().getResult();
        // NOTE(review): assumes every SoapMessage reaching this creator is also an ApiMessage — confirm.
        ((ApiMessage) response).setApiFault(faultInfo);

        if (webFaultQName == null || faultInfo == null) {
            return;
        }

        Class<?> faultInfoClass = faultInfo.getClass();
        JAXBElement<?> detailElement = toJaxbElement(webFaultQName, faultInfoClass, faultInfo);
        try {
            Marshaller marshaller = getJaxbContext(faultInfoClass).createMarshaller();
            marshaller.marshal(detailElement, detailResult);
        } catch (JAXBException e) {
            logger.warn("Can't serialize faultInfo", e);
        }
    }

    /**
     * Wraps {@code value} into a typed {@link JAXBElement}. Binding the wildcard class to {@code T}
     * lets the element be built without raw types or unchecked warnings.
     */
    private static <T> JAXBElement<T> toJaxbElement(QName name, Class<T> type, Object value) {
        return new JAXBElement<>(name, type, type.cast(value));
    }

    /**
     * Returns the cached {@link JAXBContext} for {@code clazz}, creating and caching it on first use.
     *
     * <p>If several threads create a context for the same class concurrently, all of them end up
     * returning the single instance that was stored in the cache.
     *
     * @throws JAXBException if the context cannot be created
     */
    private JAXBContext getJaxbContext(Class<?> clazz) throws JAXBException {
        JAXBContext cached = jaxbContexts.get(clazz);
        if (cached != null) {
            return cached;
        }
        JAXBContext created = JAXBContext.newInstance(clazz);
        // putIfAbsent returns the previously stored value when another thread won the race.
        JAXBContext stored = jaxbContexts.putIfAbsent(clazz, created);
        return stored != null ? stored : created;
    }

}
