package ru.yandex.bannerstorage.messaging.utils;

import java.io.IOException;
import java.io.StringReader;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;

import org.apache.commons.lang3.StringEscapeUtils;
import org.jetbrains.annotations.NotNull;

/**
 * Converts messaging payloads to and from XML text.
 *
 * <p>{@link String} payloads are XML-escaped/unescaped directly. Every other type goes through
 * JAXB, using one {@link JAXBContext} per payload class that is built lazily and cached.
 *
 * <p>JAXB failures are rethrown as unchecked {@link RuntimeException}s whose cause is the original
 * {@link JAXBException}.
 *
 * @author egorovmv
 */
public final class MessageSerializer {
    /**
     * JAXB contexts keyed by payload class. Building a context is expensive, so it is done once
     * per class and the result is reused (see {@link #getContext(Class)}).
     */
    private static final Map<Class<?>, JAXBContext> CONTEXT_CACHE = new ConcurrentHashMap<>();

    private MessageSerializer() {
    }

    /**
     * Returns the cached JAXB context for {@code clazz}, creating and caching it on first use.
     *
     * <p>If several threads race on the same class, each may build a context, but all of them
     * return the single instance that was published first.
     *
     * @param clazz the payload class the context is bound to
     * @return the shared context for {@code clazz}
     * @throws JAXBException if a context for {@code clazz} cannot be created
     */
    private static JAXBContext getContext(Class<?> clazz) throws JAXBException {
        JAXBContext cached = CONTEXT_CACHE.get(clazz);
        if (cached != null) {
            return cached;
        }
        JAXBContext created = JAXBContext.newInstance(clazz);
        // putIfAbsent is atomic on ConcurrentHashMap: keep the first published context instead of
        // overwriting it, so every caller ends up sharing the same instance.
        JAXBContext existing = CONTEXT_CACHE.putIfAbsent(clazz, created);
        return existing != null ? existing : created;
    }

    /**
     * Serializes {@code object} to XML text.
     *
     * <p>A {@link String} is returned escaped with {@code StringEscapeUtils.escapeXml10}, without
     * JAXB. Any other object is marshalled by JAXB in fragment mode ({@link Marshaller#JAXB_FRAGMENT}).
     *
     * @param object the payload to serialize; must not be {@code null}
     * @return the XML representation of {@code object}
     * @throws NullPointerException if {@code object} is {@code null}
     * @throws RuntimeException wrapping the {@link JAXBException} if JAXB marshalling fails
     */
    public static String marshal(@NotNull Object object) {
        Objects.requireNonNull(object, "object");
        Class<?> clazz = object.getClass();
        if (clazz == String.class) {
            return StringEscapeUtils.escapeXml10((String) object);
        }
        try {
            StringWriter writer = new StringWriter();
            Marshaller marshaller = getContext(clazz).createMarshaller();
            // Fragment mode: do not emit document-level events (such as the XML declaration).
            marshaller.setProperty(Marshaller.JAXB_FRAGMENT, Boolean.TRUE);
            marshaller.marshal(object, writer);
            return writer.toString();
        } catch (JAXBException e) {
            throw new RuntimeException("Failed to marshal message of type " + clazz.getName(), e);
        }
    }

    /**
     * Deserializes {@code message} into an instance of {@code clazz}.
     *
     * <p>For {@code String.class} the message is XML-unescaped with
     * {@code StringEscapeUtils.unescapeXml} and returned. For any other class the message is
     * parsed by JAXB. A single leading byte-order mark (U+FEFF) is skipped first, so the parser
     * does not see it as content before the XML prolog.
     *
     * @param message the XML text to parse; must not be {@code null}
     * @param clazz the expected payload class; must not be {@code null}
     * @param <T> the payload type
     * @return the deserialized payload
     * @throws NullPointerException if {@code message} or {@code clazz} is {@code null}
     * @throws ClassCastException if JAXB produces an object that is not an instance of {@code clazz}
     * @throws RuntimeException wrapping the {@link JAXBException} if JAXB unmarshalling fails
     */
    public static <T> T unmarshal(@NotNull String message, @NotNull Class<T> clazz) {
        Objects.requireNonNull(message, "message");
        Objects.requireNonNull(clazz, "clazz");
        if (clazz == String.class) {
            return clazz.cast(StringEscapeUtils.unescapeXml(message));
        }
        try {
            StringReader reader = new StringReader(message);
            if (message.startsWith("\uFEFF")) {
                try {
                    // Consume the BOM character so parsing starts at the real XML content.
                    reader.read();
                } catch (IOException e) {
                    // StringReader only throws once closed, which cannot happen here.
                    throw new UncheckedIOException(e);
                }
            }
            // NOTE(review): if messages can come from untrusted producers, consider parsing via an
            // XMLStreamReader with DTD/external-entity support disabled (XXE hardening); the
            // default parser configuration depends on the JAXB implementation in use.
            Object result = getContext(clazz).createUnmarshaller().unmarshal(reader);
            // Checked cast: a type mismatch fails here instead of leaking through an unchecked cast.
            return clazz.cast(result);
        } catch (JAXBException e) {
            throw new RuntimeException("Failed to unmarshal message into " + clazz.getName(), e);
        }
    }
}
