package ru.yandex.mail.so.factors.hnsw;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Consumer;
import java.util.logging.Level;

import com.github.jelmerk.knn.DistanceFunction;
import com.github.jelmerk.knn.SearchResult;
import com.github.jelmerk.knn.hnsw.HnswIndex;
import org.apache.http.concurrent.Cancellable;
import org.apache.http.concurrent.FutureCallback;
import org.apache.http.protocol.HttpContext;

import ru.yandex.concurrent.SameThreadExecutor;
import ru.yandex.http.util.SynchronizedHttpContext;
import ru.yandex.http.util.nio.client.EmptyRequestsListener;
import ru.yandex.http.util.nio.client.RequestsListener;
import ru.yandex.json.dom.JsonList;
import ru.yandex.logger.PrefixedLogger;
import ru.yandex.mail.so.factors.BasicSoFunctionInputs;
import ru.yandex.mail.so.factors.FactorsAccessViolationHandler;
import ru.yandex.mail.so.factors.LoggingFactorsAccessViolationHandler;
import ru.yandex.mail.so.factors.SoFactor;
import ru.yandex.mail.so.factors.SoFunctionInputs;
import ru.yandex.mail.so.factors.dssm.DssmEmbeddingSoFactorType;
import ru.yandex.mail.so.factors.extractors.SoFactorsExtractor;
import ru.yandex.mail.so.factors.extractors.SoFactorsExtractorContext;
import ru.yandex.mail.so.factors.extractors.SoFactorsExtractorFactoryContext;
import ru.yandex.mail.so.factors.extractors.SoFactorsExtractorsRegistry;
import ru.yandex.mail.so.factors.samples.SamplesSubscriber;
import ru.yandex.mail.so.factors.samples.SamplesSubscriberException;
import ru.yandex.mail.so.factors.types.SoFactorType;
import ru.yandex.mail.so.factors.types.StringSoFactorType;
import ru.yandex.mail.so.factors.types.TikaiteDocsSoFactorType;
import ru.yandex.parser.config.ConfigException;
import ru.yandex.parser.config.IniConfig;
import ru.yandex.parser.mail.errors.ErrorInfo;

public class HnswDssmDistanceExtractor
    implements SamplesSubscriber, SoFactorsExtractor
{
    /** Factor types returned by {@code inputs()}: a single DSSM embedding. */
    private static final List<SoFactorType<?>> INPUTS =
        Collections.singletonList(DssmEmbeddingSoFactorType.DSSM_EMBEDDING);
    /**
     * Factor types returned by {@code outputs()}. {@code extract} fills its
     * result in the same order: distance, neighbour id, neighbour labels and
     * a string factor carrying the neighbour's from-domain.
     */
    private static final List<SoFactorType<?>> OUTPUTS =
        Arrays.asList(
            HnswSoFactorType.WMD_DISTANCE,
            HnswNeighbourIdFactorType.WMD_NEIGHBOUR_ID,
            HnswNeighbourLabelsFactorType.NEIGHBOUR_LABELS,
            StringSoFactorType.STRING);
    /**
     * Types passed to {@code SoFactorsExtractor.forceOutputTypes} in the
     * constructor together with the wrapped extractor's {@code inputs()}.
     */
    private static final List<SoFactorType<?>> EXTRACTOR_INPUTS =
        Collections.singletonList(TikaiteDocsSoFactorType.TIKAITE_DOCS);
    /**
     * Result with one {@code null} slot per entry of {@link #OUTPUTS}, handed
     * to the callback when there is no embedding or no neighbour.
     * NOTE(review): this single instance is shared between calls and
     * {@code Arrays.asList} permits {@code set()}, so it looks like consumers
     * must not modify it -- confirm.
     */
    private static final List<SoFactor<?>> NULL_RESULT =
        Arrays.asList(null, null, null, null);

    /** Context logger with this extractor's name added as a prefix. */
    private final PrefixedLogger logger;
    /**
     * Extractor resolved from the registry by the {@code extractor} config
     * key; {@code addSample} runs it to obtain sample embeddings.
     */
    private final SoFactorsExtractor dssmEmbeddingExtractor;
    /** HNSW index holding sample embeddings as {@code BasicItem}s. */
    private final HnswIndex<String, float[], BasicItem<float[]>, Float> index;
    /**
     * Counter taken from the factory context and passed to the access
     * violation handler created for every {@code addSample} call.
     */
    private final LongAdder violationsCounter;

    /**
     * Builds the extractor from its configuration section.
     *
     * <p>The mandatory {@code extractor} key names the embedding extractor to
     * look up in the registry. The optional keys {@code index-size}
     * (default 1000), {@code m} (default 32), {@code ef} (default 100) and
     * {@code ef-construction} (default 100) configure the HNSW index, which
     * is built with removal enabled.
     *
     * @param name prefix added to the context logger
     * @param context factory context supplying the logger, the extractors
     *     registry and the violations counter
     * @param config configuration section to read the keys above from
     * @throws ConfigException if the named extractor is not registered;
     *     presumably also raised by the config lookups and by
     *     {@code forceOutputTypes} -- confirm against their declarations
     */
    public HnswDssmDistanceExtractor(
        final String name,
        final SoFactorsExtractorFactoryContext context,
        final IniConfig config)
        throws ConfigException
    {
        this.logger = context.logger().addPrefix(name);

        // Resolve the wrapped extractor into a local first so that the final
        // field is only ever assigned a non-null value.
        final String embeddingExtractorName = config.getString("extractor");
        final SoFactorsExtractor embeddingExtractor =
            context.registry().getExtractor(embeddingExtractorName);
        if (embeddingExtractor == null) {
            throw new ConfigException(
                "Extractor <" + embeddingExtractorName + "> not found");
        }
        this.dssmEmbeddingExtractor = embeddingExtractor;

        this.index =
            HnswIndex.newBuilder(
                Distance.INSTANCE,
                config.getInt("index-size", 1000))
                .withM(config.getInt("m", 32))
                .withEf(config.getInt("ef", 100))
                .withEfConstruction(config.getInt("ef-construction", 100))
                .withRemoveEnabled()
                .build();

        this.violationsCounter = context.violationsCounter();

        // Check the wrapped extractor's signature: first its inputs, then
        // its outputs.
        SoFactorsExtractor.forceOutputTypes(
            EXTRACTOR_INPUTS,
            embeddingExtractor.inputs());
        SoFactorsExtractor.forceOutputTypes(
            INPUTS,
            embeddingExtractor.outputs());
    }

    /** Does nothing: the body is empty, so nothing is released here. */
    @Override
    public void close() {
    }

    /**
     * Returns the shared {@code INPUTS} list, which holds the single DSSM
     * embedding factor type.
     */
    @Override
    public List<SoFactorType<?>> inputs() {
        return INPUTS;
    }

    /**
     * Returns the shared {@code OUTPUTS} list: distance, neighbour id,
     * neighbour labels and string factor types, in that order.
     */
    @Override
    public List<SoFactorType<?>> outputs() {
        return OUTPUTS;
    }

    /**
     * Delegates registration to the factory singleton of this extractor and
     * then to the wrapped embedding extractor, in that order.
     *
     * @param registry registry handed unchanged to both delegates
     * @throws ConfigException propagated from either delegate
     */
    @Override
    public void registerInternals(final SoFactorsExtractorsRegistry registry)
        throws ConfigException
    {
        HnswDssmDistanceExtractorFactory.INSTANCE.registerInternals(registry);
        dssmEmbeddingExtractor.registerInternals(registry);
    }

    /**
     * Starts embedding extraction for a sample; the embedding, once
     * delivered, is stored in the index by the {@code Callback} created here.
     *
     * <p>A fresh {@code FakeSoFactorsExtractorContext} is created for every
     * call and the documents are wrapped into a {@code TIKAITE_DOCS} factor
     * that forms the only input of the wrapped extractor.
     *
     * @param id sample id stored with the indexed item
     * @param labels labels stored with the indexed item
     * @param fromDomain sender domain stored with the indexed item
     * @param docs documents the embedding is extracted from
     * @throws SamplesSubscriberException declared by the interface; this
     *     body contains no statement throwing it directly
     */
    @Override
    public void addSample(
        final String id,
        final String labels,
        final String fromDomain,
        final JsonList docs)
        throws SamplesSubscriberException
    {
        final SoFactorsExtractorContext sampleContext =
            new FakeSoFactorsExtractorContext(violationsCounter, logger);
        final FactorsAccessViolationHandler violationHandler =
            sampleContext.accessViolationHandler();
        final BasicSoFunctionInputs extractorInputs =
            new BasicSoFunctionInputs(
                violationHandler,
                TikaiteDocsSoFactorType.TIKAITE_DOCS.createFactor(docs));
        final Callback indexingCallback =
            new Callback(id, labels, fromDomain);
        dssmEmbeddingExtractor.extract(
            sampleContext,
            extractorInputs,
            indexingCallback);
    }

    /**
     * Removes the sample with the given id from the index, passing
     * {@code 0L} as the version argument, and logs whether the index
     * reported a removal.
     *
     * @param id id of the sample to remove
     */
    @Override
    public void removeSample(final String id) {
        final boolean removed = index.remove(id, 0L);
        if (!removed) {
            logger.info("Spam sample not found for removal: " + id);
            return;
        }
        logger.info("Spam sample removed: " + id);
    }

    /**
     * Finds the nearest indexed sample for the DSSM embedding at input
     * position 0 and reports it through {@code callback}.
     *
     * <p>The callback is completed with {@code NULL_RESULT} when the
     * embedding is absent or the index returns no neighbour. Otherwise it is
     * completed with a new four-element list: distance, neighbour id,
     * neighbour labels and neighbour from-domain, where the last two slots
     * are {@code null} if the indexed item carries no such value.
     *
     * @param context supplies the logger used for the "Zero neighbours"
     *     message
     * @param inputs function inputs holding the embedding at position 0
     * @param callback receiver of the resulting factors
     */
    @Override
    public void extract(
        final SoFactorsExtractorContext context,
        final SoFunctionInputs inputs,
        final FutureCallback<? super List<SoFactor<?>>> callback)
    {
        final float[] embedding =
            inputs.get(0, DssmEmbeddingSoFactorType.DSSM_EMBEDDING);
        if (embedding == null) {
            callback.completed(NULL_RESULT);
            return;
        }

        final List<SearchResult<BasicItem<float[]>, Float>> nearest =
            index.findNearest(embedding, 1);
        if (nearest.isEmpty()) {
            context.logger().info("Zero neighbours");
            callback.completed(NULL_RESULT);
            return;
        }

        final SearchResult<BasicItem<float[]>, Float> closest =
            nearest.get(0);
        final BasicItem<float[]> sample = closest.item();

        // Slots are filled in the same order as OUTPUTS
        final List<SoFactor<?>> result = new ArrayList<>(4);

        final double distance = closest.distance();
        result.add(HnswSoFactorType.WMD_DISTANCE.createFactor(distance));
        result.add(
            HnswNeighbourIdFactorType.WMD_NEIGHBOUR_ID.createFactor(
                sample.id()));

        final String neighbourLabels = sample.labels();
        result.add(
            neighbourLabels == null
                ? null
                : HnswNeighbourLabelsFactorType.NEIGHBOUR_LABELS
                    .createFactor(neighbourLabels));

        final String neighbourDomain = sample.fromDomain();
        result.add(
            neighbourDomain == null
                ? null
                : StringSoFactorType.STRING.createFactor(neighbourDomain));

        callback.completed(result);
    }

    /**
     * Distance function used by the index: {@code (1 - dot(lhs, rhs)) / 2},
     * clamped to the range {@code [0, 1]}.
     *
     * <p>Vectors of different length are treated as maximally distant.
     * NOTE(review): presumably embeddings are unit-length, which would make
     * this a rescaled cosine distance -- confirm against the embedding
     * producer.
     */
    private enum Distance implements DistanceFunction<float[], Float> {
        INSTANCE;

        /**
         * @param lhs first vector
         * @param rhs second vector
         * @return {@code 1} if the lengths differ, otherwise half of one
         *     minus the dot product, limited to {@code [0, 1]}
         */
        @Override
        public Float distance(final float[] lhs, final float[] rhs) {
            final int dimensions = lhs.length;
            if (dimensions != rhs.length) {
                return 1f;
            }

            // Subtract each product from 1 in turn; the accumulation order
            // is kept so that float rounding stays the same.
            float remainder = 1f;
            for (int pos = 0; pos < dimensions; ++pos) {
                remainder -= lhs[pos] * rhs[pos];
            }

            final float halved = remainder * 0.5f;
            if (halved < 0f) {
                return 0f;
            }
            if (halved > 1f) {
                return 1f;
            }
            return halved;
        }
    }

    /**
     * Receives the outcome of the embedding extraction started by
     * {@code addSample} and, on success, stores the embedding in the index
     * together with the sample id, labels and from-domain.
     *
     * <p>Only the first notification (completed, failed or cancelled) is
     * acted upon; any later one is ignored without logging.
     */
    private class Callback implements FutureCallback<List<SoFactor<?>>> {
        // Flipped by whichever notification arrives first
        private final AtomicBoolean done = new AtomicBoolean();
        private final String id;
        private final String labels;
        private final String fromDomain;

        /**
         * @param id sample id used as the index key and in log messages
         * @param labels labels stored with the indexed item
         * @param fromDomain sender domain stored with the indexed item
         */
        Callback(
            final String id,
            final String labels,
            final String fromDomain)
        {
            this.id = id;
            this.labels = labels;
            this.fromDomain = fromDomain;
        }

        /**
         * Logs the cancellation, unless another notification has already
         * been handled.
         */
        @Override
        public void cancelled() {
            // Same guard as completed()/failed(): a cancel arriving after
            // the sample was indexed must not be reported as a cancellation.
            if (done.compareAndSet(false, true)) {
                logger.warning(
                    "Dssm embedding extraction cancelled for id " + id);
            }
        }

        /**
         * Adds the first DSSM embedding factor found in {@code factors} to
         * the index; logs a warning if there is none.
         *
         * @param factors extractor output; {@code null} entries are skipped
         *     and a {@code null} list is treated as containing no embedding
         */
        @Override
        public void completed(final List<SoFactor<?>> factors) {
            if (!done.compareAndSet(false, true)) {
                return;
            }
            if (factors != null) {
                for (SoFactor<?> factor: factors) {
                    if (factor == null) {
                        continue;
                    }
                    SoFactorType<?> type = factor.type();
                    if (type == DssmEmbeddingSoFactorType.DSSM_EMBEDDING) {
                        index.add(
                            new BasicItem<>(
                                id,
                                labels,
                                fromDomain,
                                DssmEmbeddingSoFactorType.DSSM_EMBEDDING
                                    .cast(factor.value())));
                        logger.info(
                            "Dssm embedding extracted for id " + id);
                        return;
                    }
                }
            }
            logger.warning("No dssm embedding extracted for id " + id);
        }

        /**
         * Logs the failure with its cause, unless another notification has
         * already been handled.
         *
         * @param e cause of the failure
         */
        @Override
        public void failed(final Exception e) {
            if (done.compareAndSet(false, true)) {
                logger.log(
                    Level.WARNING,
                    "Dssm embedding extraction failed for id " + id,
                    e);
            }
        }
    }

    /**
     * Minimal extractor context created by {@code addSample} for each sample.
     *
     * <p>It carries a logger, a logging access violation handler and its own
     * synchronized HTTP context. It is never reported as cancelled, ignores
     * cancellation subscriptions, exposes {@code SameThreadExecutor.INSTANCE}
     * as its executor, has extractor debugging switched off, and acts as its
     * own errors consumer by logging every error as a warning.
     */
    private static class FakeSoFactorsExtractorContext
        implements Consumer<ErrorInfo>, SoFactorsExtractorContext
    {
        private final HttpContext httpContext;
        private final PrefixedLogger logger;
        private final FactorsAccessViolationHandler accessViolationHandler;

        /**
         * @param violationsCounter counter handed to the access violation
         *     handler
         * @param logger logger used by this context and by the handler
         */
        FakeSoFactorsExtractorContext(
            final LongAdder violationsCounter,
            final PrefixedLogger logger)
        {
            this.httpContext = new SynchronizedHttpContext();
            this.logger = logger;
            this.accessViolationHandler =
                new LoggingFactorsAccessViolationHandler(
                    violationsCounter,
                    logger);
        }

        /** Returns the handler created in the constructor. */
        @Override
        public FactorsAccessViolationHandler accessViolationHandler() {
            return this.accessViolationHandler;
        }

        /** Returns the logger supplied to the constructor. */
        @Override
        public PrefixedLogger logger() {
            return this.logger;
        }

        /** Returns the HTTP context owned by this instance. */
        @Override
        public HttpContext httpContext() {
            return this.httpContext;
        }

        /** Returns the shared {@code EmptyRequestsListener} instance. */
        @Override
        public RequestsListener requestsListener() {
            return EmptyRequestsListener.INSTANCE;
        }

        /** Returns this object, whose {@code accept} logs the errors. */
        @Override
        public Consumer<ErrorInfo> errorsConsumer() {
            return this;
        }

        /** Returns the shared {@code SameThreadExecutor} instance. */
        @Override
        public Executor executor() {
            return SameThreadExecutor.INSTANCE;
        }

        /** Extractor debugging is always off in this context. */
        @Override
        public boolean debugExtractors() {
            return false;
        }

        /** No debug flags are ever set in this context. */
        @Override
        public Set<String> debugFlags() {
            return Collections.emptySet();
        }

        /** Logs the string form of the error at warning level. */
        @Override
        public void accept(final ErrorInfo errorInfo) {
            this.logger.warning(errorInfo.toString());
        }

        // CancellationSubscriber implementation: this context is never
        // cancelled, so subscriptions are accepted and ignored.
        @Override
        public boolean cancelled() {
            return false;
        }

        @Override
        public void subscribeForCancellation(final Cancellable callback) {
            // nothing to track
        }

        @Override
        public void subscribeForCancellation(final Future<?> callback) {
            // nothing to track
        }

        @Override
        public void unsubscribeFromCancellation(final Cancellable callback) {
            // nothing to track
        }

        @Override
        public void unsubscribeFromCancellation(final Future<?> callback) {
            // nothing to track
        }
    }
}

