# -*- coding: utf-8 -*-

import mock
from passport.backend.core.builders.base.faker.fake_builder import BaseFakeBuilder
from passport.backend.core.builders.tensornet.tensornet import TensorNet
from passport.backend.core.test.test_utils import single_entrant_patch
from passport.backend.utils.math import logit
from quality.tsnet.protos import ts_calc_pb2


class FakeTensorNet(BaseFakeBuilder):
    """
    Тестовая обертка для API TensorNet
    """
    def __init__(self):
        super(FakeTensorNet, self).__init__(TensorNet)

        self.set_tensornet_response_value = self.set_response_value_without_method
        self.set_tensornet_response_side_effect = self.set_response_side_effect_without_method


def tensornet_eval_response(estimate):
    """
    Ответ вызова eval API TensorNet. Для получения невалидного ответа
    (пустой список Target) необходимо передать значение None.
    """
    response = ts_calc_pb2.TMasterResponse()
    if estimate is not None:
        target = logit(estimate)
        response.Target.append(target)
    return response.SerializeToString()


@single_entrant_patch
class FakeLocalTensorNet(object):
    """
    Тестовая обертка для локального вызова tsnet2
    """
    def __init__(self, class_path):
        self._mock = mock.Mock()
        self._patch = mock.patch(
            class_path,
            mock.Mock(return_value=self._mock),
        )
        self.set_predict_return_value(1.0)

    def start(self):
        self._patch.start()

    def stop(self):
        self._patch.stop()

    def set_predict_side_effect(self, side_effect):
        self._mock.predict.side_effect = side_effect

    def set_predict_return_value(self, value):
        self._mock.predict.return_value = value
        self._mock.predict.side_effect = None

    @property
    def predict_call_count(self):
        return self._mock.predict.call_count
