package ru.yandex.crypta.lab.yt;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;

import javax.inject.Inject;

import com.google.common.collect.Lists;

import ru.yandex.crypta.clients.pgaas.PostgresClient;
import ru.yandex.crypta.common.exception.Exceptions;
import ru.yandex.crypta.lab.Site2VecService;
import ru.yandex.crypta.lab.base.BaseYtService;
import ru.yandex.crypta.lib.proto.EEnvironment;
import ru.yandex.crypta.lib.yt.YtReadingUtils;
import ru.yandex.crypta.lib.yt.YtService;
import ru.yandex.inside.yt.kosher.cypress.YPath;
import ru.yandex.inside.yt.kosher.impl.ytree.builder.YTree;
import ru.yandex.inside.yt.kosher.tables.YTableEntryTypes;
import ru.yandex.inside.yt.kosher.ytree.YTreeMapNode;
import ru.yandex.inside.yt.kosher.ytree.YTreeNode;
import ru.yandex.misc.lang.number.UnsignedLong;

/**
 * {@link Site2VecService} that reads vectors from YT tables under
 * {@code //home/crypta/production/profiles/export/vectors}.
 *
 * <p>Host vectors come from the {@code site2vec} table, looked up by the host string. Yandexuid
 * vectors come from the {@code yandexuid_vectors_monthly} table, looked up by the yandexuid
 * parsed as an unsigned 64-bit integer. Each row's {@code vector} column holds little-endian
 * float32 values.
 *
 * <p>NOTE(review): the root path is hard-coded to {@code production} and does not depend on
 * {@code environment()} — confirm this is intended for non-production deployments.
 */
public class DefaultSite2VecService extends BaseYtService<Site2VecService> implements Site2VecService {

    @Inject
    DefaultSite2VecService(EEnvironment environment, YtService yt, PostgresClient sql) {
        super(environment, sql, yt);
    }

    /** Returns the path of the table holding per-host vectors. */
    private YPath hostsPath() {
        return site2VecRootPath().child("site2vec");
    }

    /** Returns the path of the table holding per-yandexuid vectors. */
    private YPath yandexuidsPath() {
        return site2VecRootPath().child("yandexuid_vectors_monthly");
    }

    /** Returns the common Cypress directory containing both vector tables. */
    private YPath site2VecRootPath() {
        return YPath.cypressRoot()
                .child("home")
                .child("crypta")
                .child("production")
                .child("profiles")
                .child("export")
                .child("vectors");
    }

    /**
     * Reads rows at {@code path} and decodes the {@code vector} column of the first one.
     *
     * <p>If no row is read, throws the exception produced by {@code Exceptions.notFound()}.
     *
     * @param path table path, narrowed by the callers to a single key via {@code withExact}
     * @return the vector stored in the first row read
     */
    private Vector readVector(YPath path) {
        List<YTreeMapNode> records = new ArrayList<>(1);
        yt().tables().read(path, YTableEntryTypes.YSON, (Consumer<YTreeMapNode>) records::add);
        if (records.isEmpty()) {
            throw Exceptions.notFound();
        }
        return bytesToVector(records.get(0).getBytes("vector"));
    }

    /**
     * Decodes a packed array of little-endian 32-bit floats.
     *
     * <p>Trailing bytes that do not form a whole float (when {@code bytes.length} is not a multiple
     * of 4) are ignored.
     *
     * @param bytes raw column value
     * @return a vector with {@code bytes.length / 4} components
     */
    private Vector bytesToVector(byte[] bytes) {
        float[] results = new float[bytes.length / 4];
        ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(results);
        return new Vector(results);
    }

    /**
     * Parses a decimal yandexuid into an unsigned 64-bit value.
     *
     * @throws NumberFormatException if {@code yandexuid} is not a valid unsigned 64-bit decimal
     */
    private UnsignedLong yandexuidAsUint(String yandexuid) {
        return UnsignedLong.valueOf(Long.parseUnsignedLong(yandexuid));
    }

    /**
     * Computes the cosine similarity of two vectors.
     *
     * <p>Sums are accumulated in {@code double} to limit rounding error on long vectors. If either
     * vector has zero norm, the similarity is undefined; 0 is returned instead of NaN so the
     * result stays a finite number.
     *
     * @throws IllegalArgumentException if the vectors have different dimensions
     */
    private float cosine(Vector lhs, Vector rhs) {
        float[] left = lhs.getValues();
        float[] right = rhs.getValues();
        if (left.length != right.length) {
            throw new IllegalArgumentException(
                    "Vector dimensions differ: " + left.length + " vs " + right.length);
        }
        double dot = 0.0;
        double normLeft = 0.0;
        double normRight = 0.0;
        for (int i = 0; i < left.length; i++) {
            dot += (double) left[i] * right[i];
            normLeft += (double) left[i] * left[i];
            normRight += (double) right[i] * right[i];
        }
        if (normLeft == 0.0 || normRight == 0.0) {
            return 0.0f;
        }
        return (float) (dot / (Math.sqrt(normLeft) * Math.sqrt(normRight)));
    }

    /**
     * Looks up the vector of a yandexuid.
     *
     * @throws NumberFormatException if {@code yandexuid} is not a valid unsigned 64-bit decimal
     */
    @Override
    public Vector getVectorByYandexuid(String yandexuid) {
        YTreeNode key = YTree.unsignedIntegerNode(yandexuidAsUint(yandexuid).longValue());
        return readVector(yandexuidsPath().withExact(YtReadingUtils.exact(key)));
    }

    /** Looks up the vector of a host. */
    @Override
    public Vector getVectorByHost(String host) {
        YTreeNode key = YTree.stringNode(host);
        return readVector(hostsPath().withExact(YtReadingUtils.exact(key)));
    }

    /**
     * Computes the cosine similarity for every (yandexuid, host) pair.
     *
     * <p>Results are ordered like {@code Lists.cartesianProduct(yandexuids, hosts)}: for each
     * yandexuid, all hosts in order. Each distinct yandexuid and host vector is read from YT once.
     * Without this, the same vector would be re-read for every pair it takes part in.
     *
     * @return a mutable list of {@code yandexuids.size() * hosts.size()} similarities
     */
    @SuppressWarnings("unchecked") // Lists.cartesianProduct takes generic varargs
    @Override
    public List<Similarity> computePairwiseSimilarities(List<String> yandexuids, List<String> hosts) {
        if (yandexuids.isEmpty() || hosts.isEmpty()) {
            // The product is empty: skip all YT reads.
            return new ArrayList<>();
        }
        Map<String, Vector> yandexuidVectors = yandexuids.stream()
                .distinct()
                .collect(Collectors.toMap(Function.identity(), this::getVectorByYandexuid));
        Map<String, Vector> hostVectors = hosts.stream()
                .distinct()
                .collect(Collectors.toMap(Function.identity(), this::getVectorByHost));
        return Lists.cartesianProduct(yandexuids, hosts).stream()
                .map(pair -> {
                    String yandexuid = pair.get(0);
                    String host = pair.get(1);
                    float similarity = cosine(yandexuidVectors.get(yandexuid), hostVectors.get(host));
                    return new Similarity(yandexuid, host, similarity);
                })
                .collect(Collectors.toList());
    }

    /** Creates a new instance with the same environment, YT service and SQL client. */
    @Override
    public DefaultSite2VecService clone() {
        return new DefaultSite2VecService(environment(), ytService(), sql());
    }
}
