package ru.yandex.webmaster3.storage.nca;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.security.cert.Certificate;
import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPublicKey;
import java.util.Base64;
import java.util.List;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.NoArgsConstructor;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;

import ru.yandex.webmaster3.storage.nca.data.LogEntryType;
import ru.yandex.webmaster3.storage.nca.data.ParsedCertificateState;
import ru.yandex.webmaster3.storage.nca.data.ParsedLogEntry;
import ru.yandex.webmaster3.storage.nca.data.UnparsedEntry;

/**
 * Deserializer for Certificate Transparency log entries.
 * <p>
 * The main ideas are taken from
 * https://github.com/google/certificate-transparency-java/blob/master/src/main/java/org/certificatetransparency/ctlog/serialization/Deserializer.java
 * with a bug fixed.
 *
 * @author kravchenko99
 * @date 4/26/22
 */
@Slf4j
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class DeserializerCertificateTransparencyLogs {
    /** The only MerkleTreeLeaf version accepted by {@link #parseMerkleTreeLeaf}. */
    private static final int VERSION_V1 = 0;
    /** The only MerkleTreeLeaf leaf type accepted by {@link #parseMerkleTreeLeaf}. */
    private static final int TIMESTAMPED_ENTRY_LEAF_TYPE = 0;

    private static final int ONE_BYTE = 1;
    private static final int TWO_BYTES = 2;
    private static final int THREE_BYTES = 3;
    /** Width in bytes of the timestamp field that precedes the entry type. */
    private static final int TIMESTAMP_LENGTH = 8;

    private static final String BEGIN_CERT = "-----BEGIN CERTIFICATE-----";
    private static final String END_CERT = "-----END CERTIFICATE-----";
    private static final String BEGIN_KEY = "-----BEGIN PUBLIC KEY-----";
    private static final String END_KEY = "-----END PUBLIC KEY-----";
    private static final String LINE_SEPARATOR = System.getProperty("line.separator");

    private static final char[] HEX_ARRAY = "0123456789abcdef".toCharArray();

    /**
     * Encodes bytes as a lowercase hexadecimal string, two characters per byte.
     *
     * @param bytes data to encode.
     * @return hex string of length {@code 2 * bytes.length}.
     */
    private static String bytesToHex(byte[] bytes) {
        char[] hexChars = new char[bytes.length * 2];
        for (int j = 0; j < bytes.length; j++) {
            int v = bytes[j] & 0xFF;
            hexChars[j * 2] = HEX_ARRAY[v >>> 4];
            hexChars[j * 2 + 1] = HEX_ARRAY[v & 0x0F];
        }
        return new String(hexChars);
    }

    /**
     * Renders the certificate's encoded form as a PEM "CERTIFICATE" block: base64 wrapped at
     * 64 characters per line, framed by BEGIN/END markers, lines joined with the platform
     * line separator.
     *
     * @param certificate certificate to render.
     * @return PEM text.
     * @throws CertificateEncodingException if the certificate cannot be encoded.
     */
    private static String formatCrtFileContents(final Certificate certificate) throws CertificateEncodingException {
        final Base64.Encoder encoder = Base64.getMimeEncoder(64, LINE_SEPARATOR.getBytes());

        final byte[] rawCrtText = certificate.getEncoded();
        final String encodedCertText = new String(encoder.encode(rawCrtText));
        return String.join(LINE_SEPARATOR, BEGIN_CERT, encodedCertText, END_CERT);
    }

    /**
     * Renders the encoded form of the certificate's public key as a PEM "PUBLIC KEY" block
     * (base64 wrapped at 64 characters per line, platform line separator).
     *
     * @param certificate certificate whose public key is rendered.
     * @return PEM text.
     */
    private static String formatKeyFileContents(final Certificate certificate) {
        final Base64.Encoder encoder = Base64.getMimeEncoder(64, LINE_SEPARATOR.getBytes());

        final byte[] rawKey = certificate.getPublicKey().getEncoded();
        final String encodedKeyText = new String(encoder.encode(rawKey));
        return String.join(LINE_SEPARATOR, BEGIN_KEY, encodedKeyText, END_KEY);
    }

    /**
     * Decodes an X.509 certificate and extracts the fields stored in {@link ParsedCertificateState}.
     *
     * @param logEntryType type of the CT log entry the certificate came from; copied into the result.
     * @param certBytes    encoded certificate (see {@link #parseCertificates(byte[])}).
     * @return the extracted certificate state.
     * @throws RuntimeException wrapping any failure while decoding or extracting fields.
     */
    public static ParsedCertificateState parseCertificate(LogEntryType logEntryType, byte[] certBytes) {
        try {
            X509Certificate cert = parseCertificates(certBytes);
            // Read the SAN extension once; it is null when the certificate has no SAN extension.
            var subjectAltNames = cert.getSubjectAlternativeNames();
            List<String> domains = subjectAltNames == null
                    ? List.of()
                    : subjectAltNames.stream()
                            // each SAN entry is a two-element list: [Integer nameType, value]
                            .map(entry -> entry.get(1))
                            // some name types (e.g. otherName) are returned as DER byte[]; their
                            // toString() is an identity hash like "[B@1a2b3c", so keep Strings only
                            .filter(String.class::isInstance)
                            .map(String.class::cast)
                            .toList();
            String algorithmPublicKey = cert.getPublicKey().getAlgorithm();
            if (cert.getPublicKey() instanceof RSAPublicKey rsaPublicKey) {
                // for RSA also report the key size, e.g. "RSA 2048-bit"
                algorithmPublicKey += " " + rsaPublicKey.getModulus().bitLength() + "-bit";
            }
            return new ParsedCertificateState(
                    domains,
                    cert.getSubjectX500Principal().toString(),
                    cert.getIssuerX500Principal().toString(),
                    cert.getNotBefore(),
                    cert.getNotAfter(),
                    logEntryType,
                    cert.getSerialNumber().toString(),
                    bytesToHex(cert.getSignature()),
                    cert.getSigAlgName(),
                    formatCrtFileContents(cert),
                    algorithmPublicKey,
                    formatKeyFileContents(cert)
            );
        } catch (Exception e) {
            log.error("Error was occurred while parsing certificate - {}", e.getMessage());
            throw new RuntimeException(e);
        }
    }

    /**
     * Decodes a single X.509 certificate.
     *
     * @param certBytes DER-encoded certificate ({@link CertificateFactory} also accepts the
     *                  base64/PEM-armoured form).
     * @return the decoded certificate.
     * @throws RuntimeException if the X.509 factory is unavailable, the bytes are not a valid
     *                          certificate, or the decoded certificate is not an X.509 one.
     */
    public static X509Certificate parseCertificates(byte[] certBytes) {
        CertificateFactory factory;
        try {
            factory = CertificateFactory.getInstance("X.509");
        } catch (CertificateException e) {
            throw new RuntimeException("Failure getting X.509 factory", e);
        }

        try {
            Certificate cert = factory.generateCertificate(new ByteArrayInputStream(certBytes));
            if (cert instanceof X509Certificate x509Cert) {
                return x509Cert;
            } else {
                throw new RuntimeException("Can't cast certificate to X.509 ");
            }
        } catch (CertificateException e) {
            throw new RuntimeException("Not a valid PEM stream", e);
        }
    }

    /**
     * Parses a raw CT log entry (base64 {@code leafInput} and {@code extraData}) into its entry type
     * and the encoded leaf certificate (for X509_ENTRY) or pre-certificate (for PRECERT_ENTRY).
     *
     * @param unparsedEntry raw entry; both fields are expected to be base64 strings.
     * @return the entry type together with the certificate bytes.
     * @throws RuntimeException if the leaf is malformed or of an unsupported version/type.
     */
    public static ParsedLogEntry parseLogEntry(UnparsedEntry unparsedEntry) {
        var merkleTreeLeaf = new ByteArrayInputStream(Base64.getDecoder().decode(unparsedEntry.getLeafInput()));
        TimestampedEntry timestampedEntry = parseMerkleTreeLeaf(merkleTreeLeaf);
        LogEntryType entryType = timestampedEntry.getLogEntryType();

        if (entryType == LogEntryType.X509_ENTRY) {
            // The full chain is not extracted here, so extraData does not need to be decoded at all;
            // the leaf certificate is already inside leafInput.
            return new ParsedLogEntry(entryType, timestampedEntry.getX509());
        } else if (entryType == LogEntryType.PRECERT_ENTRY) {
            // The certificate we need is the first element of the PrecertChainEntry in extraData.
            var extraData = new ByteArrayInputStream(Base64.getDecoder().decode(unparsedEntry.getExtraData()));
            byte[] preCertX509 = parsePrecertChainEntry(extraData);
            return new ParsedLogEntry(entryType, preCertX509);
        } else {
            throw new RuntimeException(String.format("Unknown entry type: %s", entryType));
        }
    }

    /**
     * Reads the leading pre-certificate of a PrecertChainEntry structure. The remainder of the
     * structure (the chain) is left unread.
     *
     * @param in PrecertChainEntry structure, byte stream of binary encoding.
     * @return {@code byte[]} of the pre-certificate x509 (still containing the poison extension).
     */
    private static byte[] parsePrecertChainEntry(InputStream in) {
        int length = (int) readNumber(in, THREE_BYTES);
        return readFixedLength(in, length);
    }

    /**
     * Parses a MerkleTreeLeaf header (version, leaf type) and the {@link TimestampedEntry} that follows it.
     *
     * @param in byte stream of binary encoding.
     * @return Built {@link TimestampedEntry}.
     * @throws RuntimeException if the data stream is too short or the version / leaf type is unsupported.
     */
    private static TimestampedEntry parseMerkleTreeLeaf(InputStream in) {
        int version = (int) readNumber(in, ONE_BYTE);

        // Only one version is supported so far, so no enum was introduced for it.
        if (version != VERSION_V1) {
            throw new RuntimeException(String.format("Unknown version: %d", version));
        }

        int leafType = (int) readNumber(in, ONE_BYTE);
        // Only one leaf type is supported so far, so no enum was introduced for it.
        if (leafType != TIMESTAMPED_ENTRY_LEAF_TYPE) {
            throw new RuntimeException(String.format("Unknown leaf type: %d", leafType));
        }

        return parseTimestampedEntry(in);
    }

    /**
     * Parses a {@link TimestampedEntry} from binary encoding. The timestamp is read and discarded;
     * for X509_ENTRY the length-prefixed leaf certificate is read, for PRECERT_ENTRY nothing
     * further is read.
     *
     * @param in byte stream of binary encoding.
     * @return Built {@link TimestampedEntry}.
     * @throws RuntimeException if the stream is too short or the entry type is unknown.
     */
    private static TimestampedEntry parseTimestampedEntry(InputStream in) {
        TimestampedEntry.TimestampedEntryBuilder builder = TimestampedEntry.builder();
        // the timestamp is not used by callers; consume it to advance the stream
        readNumber(in, TIMESTAMP_LENGTH);

        int entryType = (int) readNumber(in, TWO_BYTES);
        builder.logEntryType(LogEntryType.R.fromValue(entryType));

        if (entryType == LogEntryType.X509_ENTRY.value()) {
            int length = (int) readNumber(in, THREE_BYTES);
            builder.x509(readFixedLength(in, length));
        } else if (entryType == LogEntryType.PRECERT_ENTRY.value()) {
            // The contents stored here are not needed; the x509 itself is in extraData.
        } else {
            throw new RuntimeException(String.format("Unknown entry type: %d", entryType));
        }
        return builder.build();
    }

    /**
     * Reads exactly {@code dataLength} bytes.
     *
     * @param inputStream byte stream of binary encoding.
     * @param dataLength  exact data length.
     * @return read byte array.
     * @throws RuntimeException if the stream ends before {@code dataLength} bytes were read.
     */
    private static byte[] readFixedLength(InputStream inputStream, int dataLength) {
        byte[] toReturn = new byte[dataLength];
        try {
            // readNBytes keeps reading until the buffer is full or EOF; plain read(byte[]) may
            // legally return fewer bytes than requested.
            int bytesRead = inputStream.readNBytes(toReturn, 0, dataLength);
            if (bytesRead < dataLength) {
                throw new RuntimeException(
                        String.format("Not enough bytes: Expected %d, got %d.", dataLength, bytesRead));
            }
            return toReturn;
        } catch (IOException e) {
            throw new RuntimeException("Error while reading fixed-length buffer", e);
        }
    }

    /**
     * Reads a big-endian (MSB first) unsigned number of {@code numBytes} bytes.
     *
     * @param inputStream byte stream of binary encoding.
     * @param numBytes    exact number of bytes representing this number.
     * @return the number; below 2^(8 * numBytes) when {@code numBytes < 8}.
     * @throws RuntimeException if the stream ends before {@code numBytes} bytes were read.
     */
    private static long readNumber(InputStream inputStream, int numBytes) {

        long toReturn = 0;
        try {
            for (int i = 0; i < numBytes; i++) {
                int valRead = inputStream.read();
                if (valRead < 0) {
                    throw new RuntimeException(
                            String.format("Missing length bytes: Expected %d, got %d.", numBytes, i));
                }
                toReturn = (toReturn << 8) | valRead;
            }
            return toReturn;
        } catch (IOException e) {
            throw new RuntimeException("IO Error when reading number", e);
        }
    }

    /**
     * The parts of a TimestampedEntry used by this class: the entry type and, for X509_ENTRY only,
     * the encoded leaf certificate ({@code null} for PRECERT_ENTRY, where it is never set).
     */
    @Value
    @Builder
    public static class TimestampedEntry {
        LogEntryType logEntryType;
        byte[] x509;
    }
}
