package ru.yandex.crypta.graph2.dao.yt.proto;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.List;
import java.util.stream.Collectors;

import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;
import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;

import ru.yandex.inside.yt.kosher.common.ProtobufYtFormat;
import ru.yandex.inside.yt.kosher.cypress.YPath;
import ru.yandex.inside.yt.kosher.operations.Yield;
import ru.yandex.inside.yt.kosher.tables.CloseableIterator;
import ru.yandex.inside.yt.kosher.tables.RetriableIterator;
import ru.yandex.inside.yt.kosher.tables.SimpleRetriableIterator;
import ru.yandex.inside.yt.kosher.tables.YTableEntryType;
import ru.yandex.inside.yt.kosher.ytree.YTreeStringNode;
import ru.yandex.misc.ExceptionUtils;
import ru.yandex.misc.lang.Check;

/**
 * One-of protobuf rec format.
 * in which several types of recs are represented with single proto message with one-of field
 */
/**
 * One-of protobuf record format, in which several types of records are represented by a single
 * wrapper proto message with exactly one {@code oneof} group.
 *
 * <p>The fields of the {@code oneof} map to tables by position: the {@code i}-th field of the group
 * describes rows of the table with index {@code i}. On read, every row is wrapped into the wrapper
 * message with the matching {@code oneof} field set. On write, the set field is unwrapped and
 * serialized to the output stream of the corresponding table.
 *
 * <p>Decoding reuses the builders held by this instance (they are cleared and refilled for every
 * row). Iterating several streams concurrently with the same instance is therefore not safe.
 */
public class NativeProtobufOneOfMessageEntryType<T extends Message> implements YTableEntryType<T> {

    // Control attributes of the native protobuf stream: a negative length prefix announces
    // metadata instead of a row. Payloads: table index -> fixed32, key switch -> none,
    // range index -> fixed32, row index -> fixed64 (see the reader below).
    private static final int CONTROL_ATTR_TABLE_INDEX = -1;
    private static final int CONTROL_ATTR_KEY_SWITCH = -2;
    private static final int CONTROL_ATTR_RANGE_INDEX = -3;
    private static final int CONTROL_ATTR_ROW_INDEX = -4;

    private final Message.Builder builder;
    private final YTreeStringNode format;
    private final Descriptors.OneofDescriptor oneofDescriptor;
    private final List<Message.Builder> fieldsBuilders;

    /**
     * Creates an entry type for the given wrapper message.
     *
     * @param builder        builder of the wrapper message; its type must declare exactly one
     *                       {@code oneof}, and every field of that {@code oneof} must be
     *                       message-typed
     * @param enumsAsStrings forwarded to {@link ProtobufYtFormat#fromDescriptors} when building
     *                       the format spec
     */
    public NativeProtobufOneOfMessageEntryType(Message.Builder builder, boolean enumsAsStrings) {
        this.builder = builder;

        Descriptors.Descriptor descriptor = builder.getDescriptorForType();
        List<Descriptors.OneofDescriptor> oneofs = descriptor.getOneofs();
        Check.equals(1, oneofs.size(), "Single one-of should contain all tables messsages");

        this.oneofDescriptor = oneofs.get(0);
        List<Descriptors.FieldDescriptor> oneofFields = this.oneofDescriptor.getFields();

        // The order of fields inside the one-of defines the table order of the format.
        List<Descriptors.Descriptor> subMessages = oneofFields.stream()
                .map(Descriptors.FieldDescriptor::getMessageType)
                .collect(Collectors.toList());

        // Pre-compute builders for performance: they are reused for every decoded row.
        this.fieldsBuilders = oneofFields.stream()
                .map(builder::getFieldBuilder)
                .collect(Collectors.toList());

        this.format = ProtobufYtFormat.fromDescriptors(subMessages, enumsAsStrings).spec();
    }

    @Override
    public YTreeStringNode format() {
        return this.format;
    }

    /**
     * Returns an iterator that decodes rows of the native protobuf stream {@code input}. Each row is
     * returned wrapped in the one-of message, with the field matching its table set.
     * Closing the iterator closes {@code input}.
     */
    @Override
    public CloseableIterator<T> iterator(InputStream input) {
        return new CloseableIterator<>() {
            final CodedInputStream in = CodedInputStream.newInstance(input);
            // Row decoded by hasNext() and not yet handed out by next(); null when none is buffered.
            T next;
            // Table of the following rows; persists until the next table-index control attribute.
            int tableIndex = 0;

            @Override
            public boolean hasNext() {
                if (next != null) {
                    return true;
                }

                try {
                    if (in.isAtEnd()) {
                        return false;
                    }

                    // Restart the size-limit accounting for every row so a long stream is not
                    // rejected by CodedInputStream's cumulative size limit.
                    in.resetSizeCounter();

                    int size = readRowSize(in.readFixed32());
                    // size == 0 is a valid row: an empty message serializes to zero bytes, and
                    // yield() writes such rows with a zero length prefix.
                    byte[] bytes = in.readRawBytes(size);

                    next = decodeRow(tableIndex, bytes);
                    return true;
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }

            // Consumes control attributes starting from the already-read length prefix `header`
            // and returns the (non-negative) length of the row that follows. Table-index
            // attributes update tableIndex; the others are skipped.
            private int readRowSize(int header) throws IOException {
                int size = header;
                while (size < 0) {
                    switch (size) {
                        case CONTROL_ATTR_KEY_SWITCH -> size = in.readFixed32();
                        case CONTROL_ATTR_TABLE_INDEX -> {
                            tableIndex = in.readFixed32();
                            size = in.readFixed32();
                        }
                        case CONTROL_ATTR_ROW_INDEX -> {
                            in.readFixed64();
                            size = in.readFixed32();
                        }
                        case CONTROL_ATTR_RANGE_INDEX -> {
                            in.readFixed32();
                            size = in.readFixed32();
                        }
                        default -> throw new RuntimeException("broken stream");
                    }
                }
                return size;
            }

            @Override
            public T next() {
                // hasNext() only reads from the stream when no row is buffered, so calling it
                // unconditionally is safe and guarantees we never hand out null past the end.
                if (!hasNext()) {
                    throw new IllegalStateException("No more rows in the stream");
                }
                T ret = next;
                next = null;
                return ret;
            }

            @Override
            public void close() throws Exception {
                input.close();
            }
        };
    }

    @Override
    public RetriableIterator<T> iterator(YPath path, InputStream input, long startRowIndex) {
        return new SimpleRetriableIterator<>(path, iterator(input), startRowIndex);
    }

    /**
     * Returns a writer that serializes the set one-of field of each message to the output stream of
     * the corresponding table. Every row is written as a fixed32 length prefix followed by the
     * message bytes, and flushed immediately. Closing the writer closes all {@code output} streams.
     */
    @Override
    public Yield<T> yield(OutputStream[] output) {
        CodedOutputStream[] writers = new CodedOutputStream[output.length];
        for (int i = 0; i < output.length; ++i) {
            writers[i] = CodedOutputStream.newInstance(output[i]);
        }

        return new Yield<>() {
            @Override
            public void yield(int index, T oneOfMessage) {
                try {
                    Descriptors.FieldDescriptor field = oneofDescriptor.getField(index);
                    if (!oneOfMessage.hasField(field)) {
                        // TODO: use yield without index and determine index depending on one-of type
                        throw new IllegalStateException("Corresponding one-of message is not set " + index);
                    }
                    Message message = (Message) oneOfMessage.getField(field);
                    byte[] bytes = message.toByteArray();
                    CodedOutputStream writer = writers[index];
                    writer.writeFixed32NoTag(bytes.length);
                    writer.writeRawBytes(bytes);
                    // close() only closes the raw streams, so anything left in the
                    // CodedOutputStream buffer would be lost without this flush.
                    writer.flush();
                } catch (IOException e) {
                    throw ExceptionUtils.translate(e);
                }
            }

            @Override
            public void yield(T oneOfMessage) {
                Descriptors.FieldDescriptor field = oneOfMessage.getOneofFieldDescriptor(oneofDescriptor);
                if (field == null) {
                    throw new IllegalStateException("One-of field is not set, cannot determine output table");
                }
                // FieldDescriptor.getIndex() is the position within the whole message, whereas
                // yield(int, T) expects the position within the one-of group.
                int tableIndex = oneofDescriptor.getFields().indexOf(field);
                this.yield(tableIndex, oneOfMessage);
            }

            @Override
            public void close() throws IOException {
                // Close every stream even if some of them fail, then report the first failure.
                IOException failure = null;
                for (OutputStream stream : output) {
                    try {
                        stream.close();
                    } catch (IOException e) {
                        if (failure == null) {
                            failure = e;
                        } else {
                            failure.addSuppressed(e);
                        }
                    }
                }
                if (failure != null) {
                    throw failure;
                }
            }
        };
    }

    /**
     * Parses {@code bytes} as the message of the {@code tableIndex}-th one-of field and returns a
     * wrapper message with only that field set. Clears and refills the shared builders.
     */
    @SuppressWarnings("unchecked") // builder is the builder of T's message type, so build() yields a T
    private T decodeRow(int tableIndex, byte[] bytes) throws IOException {
        Descriptors.FieldDescriptor subMessageDescriptor = oneofDescriptor.getField(tableIndex);
        Message.Builder subMessageBuilder = fieldsBuilders.get(tableIndex);

        subMessageBuilder.clear();
        Message subMessage = subMessageBuilder.mergeFrom(bytes).build();

        builder.clear();
        builder.setField(subMessageDescriptor, subMessage);
        return (T) builder.build();
    }

    /** Returns the wrapper-message builder (shared with, and mutated by, row decoding). */
    public Message.Builder getBuilder() {
        return builder;
    }

    /** Returns the single one-of group of the wrapper message. */
    public Descriptors.OneofDescriptor getOneofDescriptor() {
        return oneofDescriptor;
    }

    /** Returns the per-field builders, in one-of field order (the internal list, not a copy). */
    public List<Message.Builder> getMessageBuilders() {
        return fieldsBuilders;
    }
}
