package ru.yandex.so.dssm.applier;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

import ru.yandex.function.GenericAutoCloseable;
import ru.yandex.zen.JNILoader;

/**
 * Base class for appliers backed by a native handle from the
 * {@code dssm_applier} JNI library, which is loaded when this class is
 * initialized.
 *
 * <p>Subclasses pass the native handle to the constructor and implement
 * {@link #doApply(long, String)} and {@link #doDestroy(long)}.
 *
 * <p>Thread-safety: {@code apply}/{@code applyBatch} calls share a read lock
 * and may run concurrently. {@link #close()} takes the write lock, so it
 * waits for in-flight calls, and afterwards every call fails with
 * "DssmApplier already closed". Closing more than once is a no-op.
 */
public abstract class DssmApplierBase
    implements GenericAutoCloseable<RuntimeException>
{
    static {
        JNILoader.loadLibrary("dssm_applier");
    }

    private static final String CLOSED_MESSAGE = "DssmApplier already closed";

    private final ReadWriteLock lock = new ReentrantReadWriteLock();
    private final Lock readLock = lock.readLock();
    private final Lock writeLock = lock.writeLock();
    // Native handle; 0 means "closed". Read under readLock, written under
    // writeLock.
    private long handle;

    /**
     * @param handle native handle; this instance releases it through
     *     {@link #doDestroy(long)} when {@link #close()} is first called
     */
    protected DssmApplierBase(final long handle) {
        this.handle = handle;
    }

    /**
     * Computes the output vector for one text using the native handle.
     * Invoked while the read lock is held, so it may be called from several
     * threads at once. Any RuntimeException it throws is wrapped into a
     * {@link DssmApplierException} by the callers in this class.
     */
    protected abstract float[] doApply(long handle, String text);

    /**
     * Releases the native handle. Invoked at most once per instance, while
     * the write lock is held.
     */
    protected abstract void doDestroy(long handle);

    /**
     * Applies the model to a single input.
     *
     * @param input text to process
     * @return output produced by {@link #doApply(long, String)}
     * @throws DssmApplierException if this applier is closed, or if
     *     {@code doApply} throws a RuntimeException (kept as the cause)
     */
    public float[] apply(final String input) throws DssmApplierException {
        readLock.lock();
        try {
            final long liveHandle = liveHandle();
            try {
                return doApply(liveHandle, input);
            } catch (RuntimeException e) {
                throw new DssmApplierException(
                    "Failed to process input <" + input + '>',
                    e);
            }
        } finally {
            readLock.unlock();
        }
    }

    /**
     * Applies the model to every element of {@code inputs}, in order, under a
     * single read-lock acquisition.
     *
     * @param inputs texts to process
     * @return outputs, one per input, in input order
     * @throws DssmApplierException if this applier is closed, or if
     *     processing an input fails; the message carries the 0-based index
     *     and the text of the failing input
     */
    public List<float[]> applyBatch(final List<String> inputs)
        throws DssmApplierException
    {
        return applyAll(inputs);
    }

    /**
     * Reads {@code reader} line by line until end of stream, closes it, and
     * applies the model to each line (line terminators stripped by
     * {@link BufferedReader#readLine()}).
     *
     * <p>All lines are read before the read lock is taken, so a slow or
     * blocking stream does not keep {@link #close()} (and, behind a queued
     * writer, other threads' calls) waiting.
     *
     * @param reader source of inputs, one per line; always closed
     * @return outputs, one per line, in input order
     * @throws DssmApplierException if this applier is closed, or if
     *     processing a line fails; the message carries the 0-based line
     *     index, consistent with {@link #applyBatch(List)}
     * @throws IOException if reading from {@code reader} fails
     */
    public List<float[]> applyBatch(final Reader reader)
        throws DssmApplierException, IOException
    {
        final List<String> lines = new ArrayList<>();
        try (BufferedReader bufferedReader = new BufferedReader(reader)) {
            for (String line = bufferedReader.readLine();
                line != null;
                line = bufferedReader.readLine())
            {
                lines.add(line);
            }
        }
        return applyAll(lines);
    }

    // Shared batch implementation; private so both public overloads behave
    // the same even if a subclass overrides one of them.
    private List<float[]> applyAll(final List<String> inputs)
        throws DssmApplierException
    {
        final List<float[]> outputs = new ArrayList<>(inputs.size());
        readLock.lock();
        try {
            final long liveHandle = liveHandle();
            int index = 0;
            String input = null;
            try {
                // Iterate instead of get(i): O(n) for any List implementation.
                for (String current : inputs) {
                    input = current;
                    outputs.add(doApply(liveHandle, current));
                    ++index;
                }
            } catch (RuntimeException e) {
                throw new DssmApplierException(
                    "Failed to process input #" + index + ": <" + input + '>',
                    e);
            }
        } finally {
            readLock.unlock();
        }
        return outputs;
    }

    // Returns the native handle, or fails if closed. Caller must hold a lock.
    // Kept outside the RuntimeException-wrapping catch so the "closed" error
    // is never reported as an input-processing failure.
    private long liveHandle() throws DssmApplierException {
        final long current = handle;
        if (current == 0L) {
            throw new DssmApplierException(CLOSED_MESSAGE);
        }
        return current;
    }

    /**
     * Releases the native handle if it has not been released yet. Waits for
     * in-flight {@code apply}/{@code applyBatch} calls to finish first.
     */
    @Override
    public void close() {
        writeLock.lock();
        try {
            final long current = handle;
            if (current != 0L) {
                // Mark closed before destroying so the handle can never be
                // destroyed twice, even if doDestroy throws.
                handle = 0L;
                doDestroy(current);
            }
        } finally {
            writeLock.unlock();
        }
    }

    /**
     * Safety net that releases the native handle if the owner forgot to call
     * {@link #close()}.
     */
    // NOTE(review): finalization is deprecated; java.lang.ref.Cleaner is the
    // recommended replacement if the target JDK allows it.
    @Override
    @SuppressWarnings("deprecation")
    protected void finalize() {
        close();
    }

    /**
     * Command-line demo. Reads inputs from stdin (UTF-8, one per line), then
     * prints every embedding followed by pairwise distances, all to stdout.
     *
     * @param args constructor arguments for {@code DssmApplier}; the third
     *     is converted to a boolean by comparing with "true"
     *     (NOTE(review): meaning of each argument is defined by DssmApplier)
     */
    public static void main(final String... args)
        throws DssmApplierException, IOException
    {
        try (DssmApplier applier =
                new DssmApplier(
                    args[0],
                    args[1],
                    args[2].equals("true"),
                    true);
            Reader reader =
                new InputStreamReader(System.in, StandardCharsets.UTF_8))
        {
            List<float[]> outputs = applier.applyBatch(reader);
            StringBuilder sb = new StringBuilder();
            System.out.println("Embeddings:");
            for (int i = 0; i < outputs.size(); ++i) {
                sb.setLength(0);
                sb.append(i);
                sb.append(':');
                for (float f: outputs.get(i)) {
                    sb.append(' ');
                    sb.append(f);
                }
                System.out.println(sb.toString());
            }

            System.out.println();
            System.out.println("Distances:");
            int size = outputs.size();
            for (int i = 0; i < size; ++i) {
                float[] left = outputs.get(i);
                for (int j = i + 1; j < size; ++j) {
                    float[] right = outputs.get(j);
                    double sum = 0;
                    for (int k = 0; k < left.length; ++k) {
                        sum += ((double) left[k]) * right[k];
                    }
                    // (1 - dot) / 2: presumably embeddings are L2-normalized,
                    // making this a cosine distance in [0, 1] — TODO confirm.
                    // Printed to stdout, like the "Distances:" header above.
                    System.out.println(
                        "dist(" + i + ", " + j + ") = " + (1 - sum) / 2);
                }
            }
        }
    }
}

