import pytest

import nn_applier
from library.python import resource


def test_load_predict():
    """Load a DSSM model from a file path, predict once, then check error paths.

    The expected score 0.45467755 is the value this test pins for the
    ``joint_output_bigrams_softsign`` output on the fixed input below.
    """
    m = nn_applier.Model("joint_output_bigrams_softsign.dssm")
    res = m.predict({"query": "hello", "doc_url": "http://ya.ru", "doc_title": "Yandex"}, ["joint_output_bigrams_softsign"])
    assert len(res) == 1
    # pytest.approx with only `abs` applies a pure absolute tolerance (same as the
    # former hand-written abs() check) and reports both values on failure.
    assert res[0] == pytest.approx(0.45467755, abs=1e-5)

    # Only "query" is supplied (no doc_url / doc_title); predict must reject it.
    with pytest.raises(RuntimeError):
        m.predict({"query": "hello"}, ["joint_output_bigrams_softsign"])

    with pytest.raises(ValueError):
        # Not DSSM model: the file exists but is not a valid model.
        nn_applier.Model("not_dssm_model")
    with pytest.raises(RuntimeError):
        # Directory path instead of a file.
        nn_applier.Model(".")
    with pytest.raises(IOError):
        # Wrong path: the file does not exist.
        nn_applier.Model("wrong_path")
    with pytest.raises(IOError):
        # Empty path.
        nn_applier.Model("")


def test_loading_from_bytes():
    """Build a model via ``Model.from_bytes`` and check it predicts the expected score.

    The model bytes come from ``resource.find`` (presumably a resource embedded
    in the test binary; verify against the build config). The expected value
    matches the one pinned for the file-path based loader.
    """
    model_bytes = resource.find("joint_output_bigrams_softsign")
    model = nn_applier.Model.from_bytes(model_bytes)
    annotations_variables_dict = {"query": "hello", "doc_url": "http://ya.ru", "doc_title": "Yandex"}
    outputs = ["joint_output_bigrams_softsign"]
    result = model.predict(annotations_variables_dict, outputs)

    assert len(result) == 1
    # Absolute tolerance only; pytest.approx shows both values if this fails.
    assert result[0] == pytest.approx(0.45467755, abs=1e-5)


def test_loading_from_incorrect_bytes():
    """``Model.from_bytes`` must raise ValueError for bytes that are not a DSSM model."""
    with open("not_dssm_model", "rb") as handle:
        payload = handle.read()

    # The file is already closed here; only the deserialization is under test.
    with pytest.raises(ValueError):
        nn_applier.Model.from_bytes(payload)


def test_wrong_type():
    """Prediction still returns one output when an extra non-string field is present.

    Each case adds one extra entry whose value is not a string (``None`` or an
    ``int``) on top of the usual query/url/title inputs. The test only checks
    that ``predict`` does not raise and yields a single value.
    """
    loaded = nn_applier.Model.from_bytes(resource.find("joint_output_bigrams_softsign"))
    requested_outputs = ["joint_output_bigrams_softsign"]
    base_inputs = {
        "query": "hello",
        "doc_url": "http://ya.ru",
        "doc_title": "Yandex",
    }

    for extra_field in ({"none_field": None}, {"int_field": 1}):
        # Base keys first, then the extra one (same insertion order as spelled out).
        inputs = dict(base_inputs, **extra_field)
        prediction = loaded.predict(inputs, requested_outputs)
        assert len(prediction) == 1
