# coding: utf-8
# cython: wraparound=False

from libcpp.vector cimport vector

from util.generic.string cimport TString
from util.generic.vector cimport TVector
from util.system.types cimport ui8

import logging
import six


logger = logging.getLogger("nn_applier")


# Exception translator referenced by the `except +raise_py_error` clauses in
# this file: Cython calls it when a wrapped C++ call throws, and it is expected
# to set the corresponding Python error.
cdef extern from "bindings/python/dssm_nn_applier_lib/error.h":
    cdef void raise_py_error()

cdef extern from "util/memory/blob.h":
    # Declared opaquely: no members are exposed to Cython. Values are only
    # produced by the factory functions declared for this header and then
    # passed on to C++ (see TModel.Load).
    cdef cppclass TBlob:
        pass

# Static factory methods of TBlob, exposed as free functions through explicit
# C names (the quoted "TBlob::..." strings).
cdef extern from "util/memory/blob.h":
    # C++ exceptions thrown here are translated through raise_py_error.
    cdef TBlob TBlob_FromFile "TBlob::FromFile"(TString& filename) except +raise_py_error
    # No exception translation is declared for this one.
    # NOTE(review): judging by the name, the blob presumably references `data`
    # without copying it, so the caller must keep the buffer alive - confirm
    # against util/memory/blob.h.
    cdef TBlob TBlob_NoCopy "TBlob::NoCopy"(const void* data, size_t length)

cdef extern from "util/generic/ptr.h":
    # Minimal declaration of the smart-pointer template: only the constructor
    # from a raw pointer is exposed, and it may be called without the GIL.
    # NOTE(review): presumably takes ownership of `t` - confirm against
    # util/generic/ptr.h.
    cdef cppclass TAtomicSharedPtr[TYPE]:
        TAtomicSharedPtr(TYPE* t) nogil

cdef extern from "kernel/dssm_applier/nn_applier/lib/states.h" namespace "NNeuralNetApplier":
    # Opaque base type; TModel.Apply accepts a shared pointer to it.
    cdef cppclass ISample:
        pass
    # Concrete sample built from two parallel string vectors.
    # NOTE(review): presumably annotations[i] names the input whose value is
    # variables[i] - confirm against states.h.
    cdef cppclass TSample(ISample):
        TSample(TVector[TString] annotations, TVector[TString] variables)

cdef extern from "kernel/dssm_applier/nn_applier/lib/layers.h" namespace "NNeuralNetApplier":
    # Subset of the C++ model interface used by this module. Every method
    # translates C++ exceptions through raise_py_error.
    cdef cppclass TModel:
        # Two overloads that differ only in the element type of `result`
        # (float vs. ui8); both may be called without the GIL.
        void Apply(TAtomicSharedPtr[ISample] sample, TVector[TString]& variables, TVector[float]& result) nogil except +raise_py_error
        void Apply(TAtomicSharedPtr[ISample] sample, TVector[TString]& variables, TVector[ui8]& result) nogil except +raise_py_error
        # Called in this order (Load, then Init) by Model.__cinit__.
        void Load(const TBlob& blob) except +raise_py_error
        void Init() except +raise_py_error

cdef TString _to_string(s):
    """Convert a Python text/bytes object into a C++ ``TString``.

    ``None`` yields an empty string.  A value that ``six.ensure_binary``
    rejects with ``TypeError`` is logged as a warning and also yields an
    empty string instead of propagating the error.
    """
    # Default-constructed (empty) string; returned unchanged on every
    # fallback path.
    cdef TString converted
    if s is not None:
        try:
            raw = six.ensure_binary(s)
            # Explicit length so embedded NUL bytes are not truncated.
            converted = TString(<const char*>raw, len(raw))
        except TypeError:
            logger.warning("Error while converting %s to bytes", s)
    return converted

cdef class Model:
    """Python wrapper around a C++ ``NNeuralNetApplier::TModel``.

    A model is loaded either from a file path or from an in-memory bytes
    object; exactly one of the two sources must be supplied.  Prefer the
    ``from_file`` / ``from_bytes`` helpers for readability.
    """
    cdef TModel* model
    cdef object _model_bytes  # Used to prevent disposing (in case of loading from memory)

    def __cinit__(self, model_file = None, model_bytes = None):
        # Exactly one source is allowed: both missing or both given is an error.
        if (model_file is None) == (model_bytes is None):
            raise ValueError("Using of either `model_file` or `model_bytes` is required")

        cdef TBlob blob
        if model_file is not None:
            blob = TBlob_FromFile(_to_string(model_file))
        else:
            # The blob is built with TBlob::NoCopy over the caller's buffer,
            # so a reference to the bytes object is kept on the instance.
            blob = TBlob_NoCopy(<const char*> model_bytes, len(model_bytes))
            self._model_bytes = model_bytes  # Used to prevent disposing

        # If Load/Init raise, __dealloc__ still releases the allocated model.
        self.model = new TModel()
        self.model.Load(blob)
        self.model.Init()

    def __dealloc__(self):
        # Deleting a NULL pointer is a no-op, so this is safe even when
        # __cinit__ failed before the model was allocated.
        del self.model

    @staticmethod
    def from_file(model_file):
        """Create a model by loading it from the file at ``model_file``."""
        return Model(model_file=model_file)

    @staticmethod
    def from_bytes(model_bytes):
        """Create a model from the in-memory bytes object ``model_bytes``."""
        return Model(model_bytes=model_bytes)

    cdef _predict(self, inputs_dict, output_variables, output_type):
        """Apply the model to one sample and return the outputs as a list.

        ``inputs_dict`` maps annotation names to their values,
        ``output_variables`` is an iterable of output variable names and
        ``output_type`` selects the ``float`` or ``ui8`` overload of
        ``TModel::Apply`` ("float" or "byte").

        Raises ``ValueError`` for any other ``output_type``; C++ errors from
        ``Apply`` surface as Python exceptions via ``raise_py_error``.
        """
        cdef TVector[TString] anns
        cdef TVector[TString] strs
        cdef TVector[TString] varis
        cdef TVector[float] float_results
        cdef TVector[ui8] bytes_results
        cdef TSample* sample

        # six.iteritems rather than the Python-2-only `.iteritems()` method,
        # so dict-like inputs also work under Python 3.
        for (a, s) in six.iteritems(inputs_dict):
            anns.push_back(_to_string(a))
            strs.push_back(_to_string(s))

        for v in output_variables:
            varis.push_back(_to_string(v))

        # Validate BEFORE allocating the sample: on this error path no smart
        # pointer would ever receive the raw pointer, so allocating first
        # would leak the TSample.
        if output_type != "float" and output_type != "byte":
            raise ValueError("unknown output type")

        # From here on the raw pointer is always handed to a TAtomicSharedPtr.
        sample = new TSample(anns, strs)
        if output_type == "float":
            with nogil:
                self.model.Apply(TAtomicSharedPtr[ISample](sample), varis, float_results)
            return [float_results.at(i) for i in range(float_results.size())]

        with nogil:
            self.model.Apply(TAtomicSharedPtr[ISample](sample), varis, bytes_results)
        return [bytes_results.at(i) for i in range(bytes_results.size())]

    def predict(self, annotations_variables_dict, outputs, output_type="float"):
        """Run the model on a single sample.

        ``annotations_variables_dict`` maps annotation names to input values,
        ``outputs`` lists the output variables to compute, and ``output_type``
        is ``"float"`` (default) or ``"byte"``.  Returns a flat list of the
        resulting values; raises ``ValueError`` for an unknown ``output_type``.
        """
        return self._predict(annotations_variables_dict, outputs, output_type)

# vim: set ts=4 et sw=4 ai:
