package ru.yandex.qe.http.certificates;

import java.io.IOException;
import java.io.StringReader;
import java.security.PrivateKey;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo;
import org.bouncycastle.cert.X509CertificateHolder;
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.openssl.PEMException;
import org.bouncycastle.openssl.PEMKeyPair;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;

import static java.util.stream.Collectors.toList;

/**
 * @author rurikk
 */
public class PemParser {
    private static final JcaPEMKeyConverter keyConverter = new JcaPEMKeyConverter()
            .setProvider(new BouncyCastleProvider());
    private static final JcaX509CertificateConverter certificateConverter = new JcaX509CertificateConverter()
            .setProvider(new BouncyCastleProvider());


    public static List<X509Certificate> parseX509s(String certPem) throws IllegalArgumentException {
        try {
            return PemParser.<X509CertificateHolder>parse(certPem).stream()
                    .map(PemParser::convert)
                    .collect(toList());
        } catch (Exception e) {
            throw new IllegalArgumentException("Cannot parse X.509 PEM", e);
        }
    }

    public static List<Object> parseAll(String certPem)  throws IllegalArgumentException {
        try {
            return PemParser.parse(certPem).stream()
                    .map(PemParser::convertAny)
                    .collect(Collectors.toList());
        } catch (Exception e) {
            throw new IllegalArgumentException("Cannot parse pem: " +e.getMessage(), e);
        }
    }

    public static X509Certificate parseX509(String certPem) throws IllegalArgumentException {
        try {
            return convert((X509CertificateHolder) exactOne(parse(certPem)));
        } catch (Exception e) {
            throw new IllegalArgumentException("Cannot parse X.509 PEM", e);
        }
    }

    public static PrivateKey parseRsa(String keyPem) throws IllegalArgumentException {
        try {
            return convertKey(exactOne(parse(keyPem)));
        } catch (Exception e) {
            throw new IllegalArgumentException("Cannot parse key PEM", e);
        }
    }

    private static PrivateKey convertKey(Object key) {
        try {
            if (key instanceof PEMKeyPair) {
                return keyConverter.getKeyPair((PEMKeyPair) key).getPrivate();
            }
            if (key instanceof PrivateKeyInfo) {
                return keyConverter.getPrivateKey((PrivateKeyInfo) key);
            }
            throw new UnsupportedOperationException("Unsupported key class: " + key.getClass());
        } catch (PEMException e) {
            throw Throwables.propagate(e);
        }
    }

    private static X509Certificate convert(X509CertificateHolder h) {
        try {
            return certificateConverter.getCertificate(h);
        } catch (CertificateException e) {
            throw Throwables.propagate(e);
        }
    }

    private static Object convertAny(Object source) {
        try {
            return repository.get(source.getClass()).convert(source);
        } catch (CertificateException|IOException e) {
            throw new IllegalArgumentException("Cannot parse pem", e);
        }
    }

    @SuppressWarnings("unchecked")
    private static <T> List<T> parse(String rsaKeyPem) throws IOException {
        PEMParser parser = new PEMParser(new StringReader(rsaKeyPem));
        ArrayList<T> res = new ArrayList<>();
        for (Object o; (o = parser.readObject()) != null; ) {
            res.add((T) o);
        }
        return res;
    }

    private static <T> T exactOne(List<T> c) {
        Preconditions.checkState(c.size() == 1);
        return c.get(0);
    }

    @FunctionalInterface
    private interface Convertor {
        Object convert(Object source) throws CertificateException, IOException;
    }

    private static class ConvertorRepository {
        private final Map<Class<?>, Convertor> convertorMap = new HashMap<>();

        public ConvertorRepository() {
        }

        public <T> ConvertorRepository put(Class<T> clazz, Convertor convertor) {
            convertorMap.put(clazz, convertor);
            return this;
        }
        @SuppressWarnings("unchecked")
        public Convertor get(Class<?> clazz) {
            if (!convertorMap.containsKey(clazz)) {
                throw new IllegalArgumentException("Not supported source: " + clazz.getName());
            }
            return convertorMap.get(clazz);
        }
    }

    private final static ConvertorRepository repository = new ConvertorRepository()
            .put(X509CertificateHolder.class, h -> certificateConverter.getCertificate((X509CertificateHolder)h))
            .put(PrivateKeyInfo.class, p -> keyConverter.getPrivateKey((PrivateKeyInfo)p))
            .put(PEMKeyPair.class, p -> keyConverter.getKeyPair((PEMKeyPair)p).getPrivate())
            .put(SubjectPublicKeyInfo.class, p-> keyConverter.getPublicKey((SubjectPublicKeyInfo)p));

}
