from abc import ABC, abstractmethod
from typing import Dict

import tensorflow as tf

from .box_ops import compute_iou
from .standard_fields import BoxField


class Similarity(ABC):
    """Base class for pairwise similarity measures between two sets of inputs.

    Instances are called directly, ``similarity(inputs1, inputs2)``, which
    forwards to :meth:`call`. Subclasses must override :meth:`call`.

    The class derives from :class:`abc.ABC` so that ``@abstractmethod`` is
    actually enforced: instantiating ``Similarity`` itself, or a subclass that
    does not override :meth:`call`, raises ``TypeError``. Without the
    ``ABCMeta`` metaclass the decorator would be silently ignored.
    """

    def __call__(
        self, inputs1: Dict[str, tf.Tensor], inputs2: Dict[str, tf.Tensor]
    ) -> tf.Tensor:
        """Compute the similarity between ``inputs1`` and ``inputs2``.

        Args:
            inputs1: First set of inputs, a mapping from field name to tensor.
            inputs2: Second set of inputs, a mapping from field name to tensor.

        Returns:
            The tensor produced by the subclass's :meth:`call`.
        """
        return self.call(inputs1, inputs2)

    @abstractmethod
    def call(self, inputs1, inputs2) -> tf.Tensor:
        """Compute the similarity; must be overridden by subclasses.

        Raises:
            NotImplementedError: if reached, e.g. through ``super().call``.
                Returning ``None`` silently would hide the missing override.
        """
        raise NotImplementedError


class IoUSimilarity(Similarity):
    """Similarity measure given by the pairwise intersection-over-union of boxes."""

    def call(
        self, y_true: Dict[str, tf.Tensor], anchors: Dict[str, tf.Tensor]
    ) -> tf.Tensor:
        """Compute pairwise intersection-over-union between two sets of boxes.

        Args:
            y_true: Mapping holding the reference boxes under
                ``BoxField.BOXES``.
            anchors: Mapping holding the boxes to compare against under
                ``BoxField.BOXES``.

        Returns:
            The tensor returned by ``compute_iou`` for the two box tensors.
            This is documented as a 3-D float32 tensor of shape
            [batch_size, N, M] of pairwise similarity scores, as defined in
            DETR. NOTE(review): shape/dtype depend on ``compute_iou`` and the
            input layout; confirm against ``box_ops``.
        """
        return compute_iou(y_true[BoxField.BOXES], anchors[BoxField.BOXES])
