package ru.yandex.mail.so.factors.hnsw;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Consumer;
import java.util.logging.Level;

import com.github.jelmerk.knn.SearchResult;
import com.github.jelmerk.knn.hnsw.HnswIndex;
import org.apache.http.concurrent.Cancellable;
import org.apache.http.concurrent.FutureCallback;
import org.apache.http.protocol.HttpContext;

import ru.yandex.concurrent.SameThreadExecutor;
import ru.yandex.http.util.SynchronizedHttpContext;
import ru.yandex.http.util.nio.client.EmptyRequestsListener;
import ru.yandex.http.util.nio.client.RequestsListener;
import ru.yandex.json.dom.JsonList;
import ru.yandex.json.dom.JsonMap;
import ru.yandex.json.dom.JsonObject;
import ru.yandex.json.parser.JsonException;
import ru.yandex.logger.PrefixedLogger;
import ru.yandex.mail.so.factors.BasicSoFunctionInputs;
import ru.yandex.mail.so.factors.FactorsAccessViolationHandler;
import ru.yandex.mail.so.factors.LoggingFactorsAccessViolationHandler;
import ru.yandex.mail.so.factors.SoFactor;
import ru.yandex.mail.so.factors.SoFunctionInputs;
import ru.yandex.mail.so.factors.extractors.SoFactorsExtractor;
import ru.yandex.mail.so.factors.extractors.SoFactorsExtractorContext;
import ru.yandex.mail.so.factors.extractors.SoFactorsExtractorFactoryContext;
import ru.yandex.mail.so.factors.extractors.SoFactorsExtractorsRegistry;
import ru.yandex.mail.so.factors.extractors.TextPartExtractor;
import ru.yandex.mail.so.factors.fasttext.FastTextEmbedding;
import ru.yandex.mail.so.factors.fasttext.FastTextEmbeddingSoFactorType;
import ru.yandex.mail.so.factors.samples.SamplesSubscriber;
import ru.yandex.mail.so.factors.samples.SamplesSubscriberException;
import ru.yandex.mail.so.factors.types.LongSoFactorType;
import ru.yandex.mail.so.factors.types.SoFactorType;
import ru.yandex.mail.so.factors.types.StringSoFactorType;
import ru.yandex.parser.config.ConfigException;
import ru.yandex.parser.config.IniConfig;
import ru.yandex.parser.mail.errors.ErrorInfo;
import ru.yandex.parser.string.NonEmptyValidator;
import ru.yandex.search.document.mail.MailMetaInfo;

/**
 * Factors extractor backed by an in-memory HNSW index of fast text
 * embeddings.
 *
 * <p>As a {@link SamplesSubscriber} it receives samples, converts their text
 * to embeddings using the configured extractor and stores them in the index.
 * As a {@link SoFactorsExtractor} it looks up the single nearest indexed
 * sample for an input embedding and reports its distance, id, word count,
 * labels and sender domain.
 */
public class HnswExtractor implements SamplesSubscriber, SoFactorsExtractor {
    // Nearest neighbour distance below which a new sample having the same
    // labels and sender domain is considered a duplicate and is not indexed
    private static final double MIN_DISTANCE_DIFF = 0.001;
    private static final List<SoFactorType<?>> INPUTS =
        Collections.singletonList(
            FastTextEmbeddingSoFactorType.FAST_TEXT_EMBEDDING);
    // Order here matches the order in which extract() fills its result:
    // distance, neighbour id, neighbour word count, labels, sender domain
    private static final List<SoFactorType<?>> OUTPUTS =
        Arrays.asList(
            HnswSoFactorType.WMD_DISTANCE,
            HnswNeighbourIdFactorType.WMD_NEIGHBOUR_ID,
            LongSoFactorType.LONG,
            HnswNeighbourLabelsFactorType.NEIGHBOUR_LABELS,
            StringSoFactorType.STRING);
    // Expected inputs of the wrapped embedding extractor: sample text
    private static final List<SoFactorType<?>> EXTRACTOR_INPUTS =
        Collections.singletonList(StringSoFactorType.STRING);
    // One null per entry of OUTPUTS, reported when there is nothing to return
    private static final List<SoFactor<?>> NULL_RESULT =
        Arrays.asList(null, null, null, null, null);

    private final PrefixedLogger logger;
    private final SoFactorsExtractor fastTextEmbeddingExtractor;
    private final String sampleFieldName;
    private final String requiredLabel;
    private final boolean allParts;
    private final HnswIndex<
        String,
        FastTextEmbedding,
        BasicItem<FastTextEmbedding>,
        Float>
        index;
    private final LongAdder violationsCounter;

    /**
     * Creates extractor from configuration.
     *
     * <p>Recognized config parameters: {@code extractor} (name of the
     * registered extractor used to build embeddings),
     * {@code sample-field-name}, {@code required-label} (optional),
     * {@code all-parts} (default {@code false}), {@code distance-type}
     * (default {@code RELAXED}), {@code index-size} (default 1000),
     * {@code m} (default 32), {@code ef} (default 100) and
     * {@code ef-construction} (default 100).
     *
     * @param name extractor name, used as logger prefix
     * @param context factory context providing logger, extractors registry
     *     and violations counter
     * @param config extractor configuration section
     * @throws ConfigException if the referenced extractor is not registered
     *     or config access fails
     */
    public HnswExtractor(
        final String name,
        final SoFactorsExtractorFactoryContext context,
        final IniConfig config)
        throws ConfigException
    {
        logger = context.logger().addPrefix(name);
        String extractorName = config.getString("extractor");
        fastTextEmbeddingExtractor =
            context.registry().getExtractor(extractorName);
        if (fastTextEmbeddingExtractor == null) {
            throw new ConfigException(
                "Extractor <" + extractorName + "> not found");
        }
        sampleFieldName =
            config.get("sample-field-name", NonEmptyValidator.TRIMMED);
        requiredLabel =
            config.get("required-label", null, NonEmptyValidator.TRIMMED);
        allParts = config.getBoolean("all-parts", false);
        index =
            HnswIndex.newBuilder(
                config.getEnum(
                    WmdDistance.class,
                    "distance-type",
                    WmdDistance.RELAXED),
                config.getInt("index-size", 1000))
                .withM(config.getInt("m", 32))
                .withEf(config.getInt("ef", 100))
                .withEfConstruction(config.getInt("ef-construction", 100))
                .withRemoveEnabled()
                .build();
        violationsCounter = context.violationsCounter();

        // NOTE(review): presumably these checks ensure that the wrapped
        // extractor accepts a single string and produces a single embedding;
        // confirm against SoFactorsExtractor.forceOutputTypes
        SoFactorsExtractor.forceOutputTypes(
            EXTRACTOR_INPUTS,
            fastTextEmbeddingExtractor.inputs());
        SoFactorsExtractor.forceOutputTypes(
            INPUTS,
            fastTextEmbeddingExtractor.outputs());
    }

    /**
     * Does nothing, this class doesn't release anything on close.
     */
    @Override
    public void close() {
    }

    /**
     * @return single-element list with the fast text embedding factor type
     */
    @Override
    public List<SoFactorType<?>> inputs() {
        return INPUTS;
    }

    /**
     * @return output factor types: distance, neighbour id, neighbour word
     *     count, neighbour labels and neighbour sender domain
     */
    @Override
    public List<SoFactorType<?>> outputs() {
        return OUTPUTS;
    }

    /**
     * Registers internals of the factory of this extractor and of the
     * wrapped embedding extractor.
     *
     * @param registry registry to register internals in
     * @throws ConfigException if registration fails
     */
    @Override
    public void registerInternals(final SoFactorsExtractorsRegistry registry)
        throws ConfigException
    {
        HnswExtractorFactory.INSTANCE.registerInternals(registry);
        fastTextEmbeddingExtractor.registerInternals(registry);
    }

    /**
     * @return current index size as reported by the underlying index
     */
    public int indexSize() {
        return index.size();
    }

    /**
     * Extracts sample text from message documents.
     *
     * @param docs message parts
     * @param hid hid of the part to take the text from, or {@code null} to
     *     let {@link TextPartExtractor} select the part
     * @return value of the configured sample field of the selected part, or
     *     {@code null} if the value is absent or consists of whitespace only
     * @throws JsonException if documents have unexpected structure
     */
    public String extractText(final JsonList docs, final String hid)
        throws JsonException
    {
        JsonMap textPart;
        if (hid == null) {
            textPart = TextPartExtractor.extractTextPart(docs);
        } else {
            // Falls back to an empty map if no part has the requested hid
            textPart = JsonMap.EMPTY;
            for (JsonObject docObject: docs) {
                JsonMap doc = docObject.asMap();
                String partHid = doc.getOrNull(MailMetaInfo.HID);
                if (hid.equals(partHid)) {
                    textPart = doc;
                    break;
                }
            }
        }
        String text = textPart.get(sampleFieldName).asStringOrNull();
        if (text != null && text.trim().isEmpty()) {
            text = null;
        }
        return text;
    }

    /**
     * Passes sample text to the wrapped embedding extractor. The resulting
     * embedding is handled by {@link Callback}, which stores it in the
     * index under the given id.
     *
     * @param id id to store the sample under
     * @param labels raw sample labels, possibly {@code null}
     * @param fromDomain sample sender domain, possibly {@code null}
     * @param text non-empty sample text
     */
    private void submitSample(
        final String id,
        final String labels,
        final String fromDomain,
        final String text)
    {
        SoFactorsExtractorContext context =
            new FakeSoFactorsExtractorContext(violationsCounter, logger);
        fastTextEmbeddingExtractor.extract(
            context,
            new BasicSoFunctionInputs(
                context.accessViolationHandler(),
                StringSoFactorType.STRING.createFactor(text)),
            new Callback(id, labels, fromDomain));
    }

    /**
     * Adds sample to the index.
     *
     * <p>Samples lacking the configured required label are skipped. In
     * all-parts mode every part with non-empty text is indexed under id
     * {@code id + '/' + hid}. Otherwise one entry is submitted per
     * {@code hid_<hid>} label, or, if there are no such labels, for the part
     * selected by {@link TextPartExtractor}; all of them use the plain
     * sample id.
     *
     * @param id sample id
     * @param labels newline separated labels, or {@code null}
     * @param fromDomain sample sender domain, possibly {@code null}
     * @param docs message parts
     * @throws SamplesSubscriberException if documents have unexpected
     *     structure
     */
    @Override
    public void addSample(
        final String id,
        final String labels,
        final String fromDomain,
        final JsonList docs)
        throws SamplesSubscriberException
    {
        try {
            Set<String> labelsSet;
            if (labels == null) {
                labelsSet = Collections.emptySet();
            } else {
                labelsSet =
                    new HashSet<>(Arrays.asList(labels.split("\n")));
            }
            if (requiredLabel != null) {
                if (!labelsSet.contains(requiredLabel)) {
                    logger.info(
                        "Skipping " + id
                        + ", because required label <" + requiredLabel
                        + "> wasn't found in " + labelsSet);
                    return;
                }
            }
            if (allParts) {
                for (JsonObject doc: docs) {
                    String text = doc.get(sampleFieldName).asStringOrNull();
                    String docId =
                        id + '/' + doc.get("hid").asStringOrNull();
                    if (text == null || text.trim().isEmpty()) {
                        // Report the part id, not just the message id, so
                        // that it is clear which part was skipped
                        logger.info(
                            "Skipping " + docId
                            + ", because it has empty text");
                    } else {
                        submitSample(docId, labels, fromDomain, text);
                    }
                }
            } else {
                boolean hidFound = false;
                for (String label: labelsSet) {
                    if (label.startsWith("hid_")) {
                        hidFound = true;
                        // Strip "hid_" prefix
                        String hid = label.substring(4);
                        String text = extractText(docs, hid);
                        if (text == null) {
                            logger.warning(
                                "Skipping " + id + '/' + hid
                                + ", because it has empty text or"
                                + " part wasn't found");
                        } else {
                            submitSample(id, labels, fromDomain, text);
                        }
                    }
                }
                if (!hidFound) {
                    String text = extractText(docs, null);
                    if (text == null) {
                        logger.warning(
                            "Skipping " + id
                            + ", because it has empty text or"
                            + " text part wasn't found");
                    } else {
                        submitSample(id, labels, fromDomain, text);
                    }
                }
            }
        } catch (JsonException e) {
            throw new SamplesSubscriberException(e);
        }
    }

    /**
     * Removes sample with the given id from the index and logs the outcome.
     *
     * @param id sample id
     */
    @Override
    public void removeSample(final String id) {
        // NOTE(review): in all-parts mode samples are indexed under
        // id + '/' + hid, so removal by plain id presumably won't match
        // them; confirm whether this is intended
        if (index.remove(id, 0L)) {
            logger.info("Spam sample removed: " + id);
        } else {
            logger.info("Spam sample not found for removal: " + id);
        }
    }

    /**
     * Finds nearest indexed samples for the embedding.
     *
     * @param embedding embedding to search neighbours for
     * @param count maximum number of neighbours requested
     * @return search results as returned by the underlying index
     */
    public List<SearchResult<BasicItem<FastTextEmbedding>, Float>> neighbours(
        final FastTextEmbedding embedding,
        final int count)
    {
        return index.findNearest(embedding, count);
    }

    /**
     * Looks up the nearest indexed sample for the input embedding and
     * completes the callback with five factors in {@link #outputs()} order.
     * Completes the callback with all-null result if the input embedding is
     * {@code null} or the index has returned no neighbours. Labels and
     * sender domain factors are {@code null} if the neighbour lacks them.
     *
     * @param context extraction context, used here for logging only
     * @param inputs function inputs, embedding is expected at index 0
     * @param callback callback to complete with the result
     */
    @Override
    public void extract(
        final SoFactorsExtractorContext context,
        final SoFunctionInputs inputs,
        final FutureCallback<? super List<SoFactor<?>>> callback)
    {
        FastTextEmbedding embedding =
            inputs.get(0, FastTextEmbeddingSoFactorType.FAST_TEXT_EMBEDDING);
        if (embedding == null) {
            callback.completed(NULL_RESULT);
            return;
        }
        List<SearchResult<BasicItem<FastTextEmbedding>, Float>> neighbours =
            neighbours(embedding, 1);

        if (neighbours.size() < 1) {
            context.logger().info("Zero neighbours");
            callback.completed(NULL_RESULT);
            return;
        }

        SearchResult<BasicItem<FastTextEmbedding>, Float> neighbour =
            neighbours.get(0);
        BasicItem<FastTextEmbedding> item = neighbour.item();

        List<SoFactor<?>> factors = new ArrayList<>(5);

        factors.add(
            HnswSoFactorType.WMD_DISTANCE.createFactor(
                (double) neighbour.distance()));
        factors.add(
            HnswNeighbourIdFactorType.WMD_NEIGHBOUR_ID.createFactor(
                item.id()));
        factors.add(
            LongSoFactorType.LONG.createFactor(
                (long) item.vector().wordCount()));
        String labels = item.labels();
        if (labels == null) {
            factors.add(null);
        } else {
            factors.add(
                HnswNeighbourLabelsFactorType.NEIGHBOUR_LABELS.createFactor(
                    labels));
        }

        String fromDomain = item.fromDomain();
        if (fromDomain == null) {
            factors.add(null);
        } else {
            factors.add(StringSoFactorType.STRING.createFactor(fromDomain));
        }

        callback.completed(factors);
    }

    /**
     * Receives embedding extraction result for a sample and stores it in the
     * index. Only the first of completed/failed/cancelled notifications is
     * processed, the rest are ignored.
     */
    private class Callback implements FutureCallback<List<SoFactor<?>>> {
        // Set by the first notification received
        private final AtomicBoolean done = new AtomicBoolean();
        private final String id;
        private final String labels;
        private final String fromDomain;

        Callback(
            final String id,
            final String labels,
            final String fromDomain)
        {
            this.id = id;
            this.labels = labels;
            this.fromDomain = fromDomain;
        }

        @Override
        public void cancelled() {
            // Same guard as in completed() and failed(), so that a late
            // cancellation is not reported for an already processed sample
            if (done.compareAndSet(false, true)) {
                logger.warning(
                    "Fast text embedding extraction cancelled for id " + id);
            }
        }

        /**
         * Stores embedding factor in the index unless it is a near
         * duplicate of an already indexed sample.
         *
         * @param factor non-null factor produced by the embedding extractor
         * @return {@code true} if the sample was added to the index,
         *     {@code false} if the factor is not an embedding, the embedding
         *     is {@code null} or the sample was considered a duplicate
         */
        private boolean add(final SoFactor<?> factor) {
            SoFactorType<?> type = factor.type();
            if (type == FastTextEmbeddingSoFactorType.FAST_TEXT_EMBEDDING) {
                FastTextEmbedding embedding =
                    FastTextEmbeddingSoFactorType.FAST_TEXT_EMBEDDING.cast(
                        factor.value());
                if (embedding == null) {
                    logger.warning("Null embedding");
                    return false;
                }
                List<SearchResult<BasicItem<FastTextEmbedding>, Float>>
                    neighbours = neighbours(embedding, 1);
                if (neighbours.size() >= 1
                    && neighbours.get(0).distance() < MIN_DISTANCE_DIFF)
                {
                    BasicItem<FastTextEmbedding> item =
                        neighbours.get(0).item();
                    // A close neighbour with different labels or sender
                    // domain is not a duplicate, the sample is still added
                    if (Objects.equals(labels, item.labels())
                        && Objects.equals(fromDomain, item.fromDomain()))
                    {
                        logger.info(
                            "Skipping spam sample " + id
                            + ", because it is similar to " + item.id()
                            + ", distance = " + neighbours.get(0).distance());
                        return false;
                    }
                }
                index.add(
                    new BasicItem<>(
                        id,
                        labels,
                        fromDomain,
                        embedding));
                return true;
            } else {
                return false;
            }
        }

        @Override
        public void completed(final List<SoFactor<?>> factors) {
            if (done.compareAndSet(false, true)) {
                boolean added = false;
                // Exactly one factor, the embedding, is expected
                if (factors.size() == 1) {
                    SoFactor<?> factor = factors.get(0);
                    if (factor != null) {
                        added = add(factor);
                    }
                }
                if (added) {
                    logger.info(
                        "Fast text embedding extracted for id " + id);
                } else {
                    logger.warning(
                        "No fast text embedding extracted for id " + id);
                }
            }
        }

        @Override
        public void failed(final Exception e) {
            if (done.compareAndSet(false, true)) {
                logger.log(
                    Level.WARNING,
                    "Fast text embedding extraction failed for id " + id,
                    e);
            }
        }
    }

    /**
     * Minimal extraction context used for samples indexing, where there is
     * no real request: runs tasks on the calling thread, logs reported
     * errors as warnings, is never cancelled and ignores cancellation
     * subscriptions.
     */
    private static class FakeSoFactorsExtractorContext
        implements Consumer<ErrorInfo>, SoFactorsExtractorContext
    {
        private final HttpContext httpContext = new SynchronizedHttpContext();
        private final PrefixedLogger logger;
        private final FactorsAccessViolationHandler accessViolationHandler;

        FakeSoFactorsExtractorContext(
            final LongAdder violationsCounter,
            final PrefixedLogger logger)
        {
            this.logger = logger;
            accessViolationHandler =
                new LoggingFactorsAccessViolationHandler(
                    violationsCounter,
                    logger);
        }

        @Override
        public FactorsAccessViolationHandler accessViolationHandler() {
            return accessViolationHandler;
        }

        @Override
        public PrefixedLogger logger() {
            return logger;
        }

        @Override
        public HttpContext httpContext() {
            return httpContext;
        }

        @Override
        public RequestsListener requestsListener() {
            return EmptyRequestsListener.INSTANCE;
        }

        // This context is its own errors consumer, see accept()
        @Override
        public Consumer<ErrorInfo> errorsConsumer() {
            return this;
        }

        @Override
        public Executor executor() {
            return SameThreadExecutor.INSTANCE;
        }

        @Override
        public boolean debugExtractors() {
            return false;
        }

        @Override
        public Set<String> debugFlags() {
            return Collections.emptySet();
        }

        // Errors reported during extraction are only logged
        @Override
        public void accept(final ErrorInfo errorInfo) {
            logger.warning(errorInfo.toString());
        }

        // CancellationSubscriber implementation: never cancelled, so all
        // subscriptions are no-ops
        @Override
        public boolean cancelled() {
            return false;
        }

        @Override
        public void subscribeForCancellation(final Cancellable callback) {
        }

        @Override
        public void subscribeForCancellation(final Future<?> callback) {
        }

        @Override
        public void unsubscribeFromCancellation(final Cancellable callback) {
        }

        @Override
        public void unsubscribeFromCancellation(final Future<?> callback) {
        }
    }
}
