package ru.yandex.travel.tracing;

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.stream.Collectors;

import io.grpc.Metadata;
import io.opentracing.propagation.TextMap;

/**
 * Read-only {@link TextMap} carrier that exposes the ASCII headers of a gRPC {@link Metadata}
 * object, for use with {@code Tracer.extract()}.
 *
 * <p>The headers are copied into an internal map once, in the constructor; changes made to the
 * {@link Metadata} instance afterwards are not reflected by this adapter.
 */
public class GrpcMetadataExtractAdapter implements TextMap {
    /** Snapshot of the ASCII headers: header name to all values present for that name. */
    private final Map<String, List<String>> headers;

    /**
     * Creates an adapter over a snapshot of the given metadata.
     *
     * <p>Note: this constructor calls the overridable {@link #metadataHeadersToMultiMap(Metadata)};
     * an override must not depend on subclass fields, which are not yet initialised at that point.
     *
     * @param metadata gRPC metadata to read headers from
     */
    public GrpcMetadataExtractAdapter(Metadata metadata) {
        this.headers = metadataHeadersToMultiMap(metadata);
    }

    /**
     * Returns an iterator over flat (name, value) pairs; a header with several values yields
     * one pair per value.
     */
    @Override
    public Iterator<Map.Entry<String, String>> iterator() {
        return new MultivaluedMapFlatIterator<>(headers.entrySet());
    }

    /**
     * Always fails: this carrier is extract-only.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void put(String key, String value) {
        throw new UnsupportedOperationException("This class should be used only with Tracer.extract()!");
    }

    /**
     * Copies the ASCII headers of {@code metadata} into a multi-valued map.
     *
     * <p>Headers whose name ends with {@link Metadata#BINARY_HEADER_SUFFIX} are skipped: they
     * cannot be addressed through an ASCII key ({@code Metadata.Key.of} rejects such names with
     * an {@link IllegalArgumentException}) and they do not carry text values.
     *
     * @param metadata gRPC metadata to read headers from
     * @return a new mutable map from header name to the list of its values
     */
    protected Map<String, List<String>> metadataHeadersToMultiMap(Metadata metadata) {
        Map<String, List<String>> headersResult = new HashMap<>();

        for (String headerName : metadata.keys()) {
            if (headerName.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
                // Building an ASCII key for a "-bin" header would throw; binary headers are not text.
                continue;
            }

            Metadata.Key<String> headerKey = Metadata.Key.of(headerName, Metadata.ASCII_STRING_MARSHALLER);
            List<String> valuesList = new ArrayList<>(1);
            Iterable<String> valuesIt = metadata.getAll(headerKey);
            if (valuesIt != null) {
                valuesIt.forEach(valuesList::add);
            }
            headersResult.put(headerKey.originalName(), valuesList);
        }

        return headersResult;
    }

    /**
     * Flattens the entries of a multi-valued map into single (key, value) entries.
     *
     * <p>A key mapped to several values produces one entry per value. A key mapped to an empty
     * list produces exactly one entry whose value is {@code null}.
     *
     * @param <K> key type
     * @param <V> value type
     */
    public static final class MultivaluedMapFlatIterator<K, V> implements Iterator<Map.Entry<K, V>> {

        private final Iterator<Map.Entry<K, List<V>>> mapIterator;
        /** Map entry whose values are currently being emitted; {@code null} before the first call to {@link #next()}. */
        private Map.Entry<K, List<V>> mapEntry;
        /** Iterator over the values of {@link #mapEntry}; {@code null} before the first call to {@link #next()}. */
        private Iterator<V> listIterator;

        /**
         * @param multiValuesEntrySet entry set of the multi-valued map to flatten
         */
        public MultivaluedMapFlatIterator(Set<Map.Entry<K, List<V>>> multiValuesEntrySet) {
            this.mapIterator = multiValuesEntrySet.iterator();
        }

        /**
         * @return {@code true} if the current key still has values left or further keys remain
         */
        @Override
        public boolean hasNext() {
            if (listIterator != null && listIterator.hasNext()) {
                return true;
            }

            return mapIterator.hasNext();
        }

        /**
         * Returns the next flat entry.
         *
         * @return an immutable (key, value) entry; the value is {@code null} for a key whose
         *         value list is empty
         * @throws NoSuchElementException if the iteration has no more elements
         */
        @Override
        public Map.Entry<K, V> next() {
            if (!hasNext()) {
                // Honour the Iterator contract instead of re-emitting the last key with a null value.
                throw new NoSuchElementException();
            }

            if (listIterator == null || !listIterator.hasNext()) {
                // Current key is exhausted (or nothing was read yet); hasNext() guarantees another key exists.
                mapEntry = mapIterator.next();
                listIterator = mapEntry.getValue().iterator();
            }

            if (listIterator.hasNext()) {
                return new AbstractMap.SimpleImmutableEntry<>(mapEntry.getKey(), listIterator.next());
            } else {
                // Empty value list: report the key once with a null value.
                return new AbstractMap.SimpleImmutableEntry<>(mapEntry.getKey(), null);
            }
        }

        /**
         * Removal is not supported.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    }
}
