import os
import json

import tensorflow as tf

from ..io_wrappers.io_wrapper import IOWrapper
from .data_preprocessor import DataPreprocessor

from ...faster_rcnn_impl.core.standard_fields import BoxField, DatasetField
from ...faster_rcnn_impl.core.constants import BATCH_IMAGE_HEIGHT, BATCH_IMAGE_WIDTH
from ...faster_rcnn_impl.utils.ops import resize_and_pad_image


# Keys of the features stored in each serialized tf.Example record of the
# input TFRecord files (used by the parsing specs in FasterRCNNPreprocessor).
IMAGE_HEIGHT = 'image/height'
IMAGE_WIDTH = 'image/width'
IMAGE = 'image/encoded'
# Bounding-box borders, one value per object in the image.
BBOX_XMIN = 'image/object/bbox/xmin'
BBOX_XMAX = 'image/object/bbox/xmax'
BBOX_YMIN = 'image/object/bbox/ymin'
BBOX_YMAX = 'image/object/bbox/ymax'
# Per-object class: numeric label and its human-readable name.
CLASS_LABEL = 'image/object/class/label'
CLASS_TEXT = 'image/object/class/text'


def get_lines(f, res, num_lines):
    """Replace the contents of ``res`` with up to ``num_lines`` decoded JSON values.

    Each value is one line of ``f`` parsed with ``json.loads``. Reading stops
    early as soon as ``f.readline()`` returns an empty string (end of file).

    :param f: readable stream positioned at the first line to read
    :param res: list that is cleared and then filled in place
    :param num_lines: maximum number of lines to consume
    """
    res.clear()
    for _ in range(num_lines):
        raw_line = f.readline()
        if raw_line != "":
            res.append(json.loads(raw_line))
        else:
            break


class FasterRCNNPreprocessor(DataPreprocessor):
    """Preprocessor that feeds object-detection TFRecords to a Faster R-CNN model.

    The data provided to the preprocessor is a TFRecord file of serialized
    ``tf.train.Example`` records with the following features:
    --- image/encoded: 3-channel encoded image; decoded here with tf.io.decode_jpeg(channels=3)
    --- image/height: image height in pixels
    --- image/width: image width in pixels
    --- image/object/bbox/y(x)min(max): bounding box borders in range [0..1]
    --- image/object/class/label: numerical representation of the class in range [1...n]
    --- image/object/class/text: class name
    """

    def __init__(
        self, params: dict = None, io_wrapper: IOWrapper = None
    ) -> None:
        # 0-based class index -> class name, and the inverse mapping; both are
        # filled by extract_data().
        self.ind_to_str = dict()
        self.str_to_ind = dict()
        # Parsing spec used by preprocess_data() (class text is not needed there).
        self.feature = {
            IMAGE_HEIGHT: tf.io.FixedLenFeature((), tf.int64, default_value=1),
            IMAGE_WIDTH: tf.io.FixedLenFeature((), tf.int64, default_value=1),
            IMAGE: tf.io.FixedLenFeature((), tf.string, default_value=''),
            BBOX_XMIN: tf.io.VarLenFeature(tf.float32),
            BBOX_XMAX: tf.io.VarLenFeature(tf.float32),
            BBOX_YMIN: tf.io.VarLenFeature(tf.float32),
            BBOX_YMAX: tf.io.VarLenFeature(tf.float32),
            CLASS_LABEL: tf.io.VarLenFeature(tf.int64),
        }

        # NOTE(review): the attributes above are set before the base constructor
        # runs, presumably because it calls extract_data() — confirm.
        super().__init__(params=params, io_wrapper=io_wrapper)

    def preprocess_data(self, batch):
        """Convert a batch of serialized examples into Faster R-CNN inputs and targets.

        :param batch: 1-D string tensor of serialized ``tf.train.Example`` records,
            as produced by batching the TFRecord dataset.
        :return: ``(x, y)`` dictionaries:
            x[DatasetField.IMAGES]: decoded images passed through ``resize_and_pad_image``.
            x[DatasetField.IMAGES_INFO]: float32 ``[batch, 2]`` (height, width) of each
                image scaled by ``min(BATCH_IMAGE_HEIGHT / height, BATCH_IMAGE_WIDTH / width)``.
            y[BoxField.LABELS]: int32 ``[batch, max_boxes]`` 0-based class ids; padding is -1.
            y[BoxField.WEIGHTS]: float32 ``[batch, max_boxes]``; 1 for real boxes, 0 for padding.
            y[BoxField.BOXES]: float32 ``[batch, max_boxes, 4]`` as (ymin, xmin, ymax, xmax)
                in pixels of the scaled image.
            y[BoxField.NUM_BOXES]: int32 ``[batch, 1]`` number of boxes in each example.
        """
        # The records are plain tf.train.Example protos, so parse them as such.
        # parse_sequence_example only happened to work because SequenceExample's
        # context field shares field number 1 with Example.features.
        fs = tf.io.parse_example(batch, self.feature)

        def presence_mask(sparse):
            # int32 dense tensor: 1 where `sparse` has a stored entry, 0 in the
            # padding that tf.sparse.to_dense() adds for shorter examples.
            ones = tf.ones_like(sparse.values, dtype=tf.int32)
            return tf.sparse.to_dense(tf.SparseTensor(sparse.indices, ones, sparse.dense_shape))

        def decode_and_resize(encoded):
            return resize_and_pad_image(tf.io.decode_jpeg(encoded, channels=3))

        x = dict()
        x[DatasetField.IMAGES] = tf.map_fn(decode_and_resize, fs[IMAGE], fn_output_signature=tf.float32)

        original_hw = tf.cast(tf.stack([fs[IMAGE_HEIGHT], fs[IMAGE_WIDTH]], axis=1), tf.float32)
        target_hw = tf.convert_to_tensor([[BATCH_IMAGE_HEIGHT, BATCH_IMAGE_WIDTH]], dtype=tf.float32)
        # One factor per image, applied to both sides, so the scaled image fits
        # inside BATCH_IMAGE_HEIGHT x BATCH_IMAGE_WIDTH with its aspect ratio kept.
        scale = tf.reduce_min(target_hw / original_hw, axis=1, keepdims=True)
        x[DatasetField.IMAGES_INFO] = original_hw * scale

        # [batch, 1] columns, broadcast against the [batch, max_boxes] coordinates.
        image_h = x[DatasetField.IMAGES_INFO][:, :1]
        image_w = x[DatasetField.IMAGES_INFO][:, 1:]

        y = dict()
        # Labels are stored 1-based; shift to 0-based (dense padding 0 becomes -1).
        y[BoxField.LABELS] = tf.cast(tf.sparse.to_dense(fs[CLASS_LABEL]) - 1, tf.int32)
        # Padded slots get weight 0 so they do not count as boxes.
        y[BoxField.WEIGHTS] = tf.cast(presence_mask(fs[CLASS_LABEL]), tf.float32)
        y[BoxField.BOXES] = tf.cast(tf.stack([
            tf.sparse.to_dense(fs[BBOX_YMIN]) * image_h,
            tf.sparse.to_dense(fs[BBOX_XMIN]) * image_w,
            tf.sparse.to_dense(fs[BBOX_YMAX]) * image_h,
            tf.sparse.to_dense(fs[BBOX_XMAX]) * image_w], axis=2), tf.float32)
        # Count stored entries rather than non-zero values, so a box whose
        # ymax is exactly 0 is still counted.
        y[BoxField.NUM_BOXES] = tf.reduce_sum(presence_mask(fs[BBOX_YMAX]), axis=1, keepdims=True)

        return (x, y)

    def extract_data(self, data_type: str):
        """Scan ``<input_dir>/<data_type>_tfrecord`` and build the class-name tables.

        Every record is parsed once. This fills ``self.ind_to_str`` /
        ``self.str_to_ind`` (0-based class index <-> class name) and counts the
        records; the count is stored in ``self.params[data_type + "_size"]``.

        :param data_type: split name used as the file-name prefix
            (presumably "train" / "valid" — confirm against the base class).
        :return: ``(dataset, None)`` where ``dataset`` is the raw
            ``tf.data.TFRecordDataset`` over that file.
        """

        def record_classes(class_names, class_ids):
            for class_id, class_name in zip(class_ids.values.numpy(), class_names.values.numpy()):
                # TFRecord labels are numbered from 1, so shift them to 0-based
                # indices. Plain int keeps the tables JSON-serializable.
                index = int(class_id) - 1
                # UTF-8 is a superset of ASCII and accepts non-English class names.
                name = class_name.decode("utf-8")
                self.ind_to_str[index] = name
                self.str_to_ind[name] = index

        print("Preprocess start")

        data_path = os.path.join(self.io_wrapper.get_input_dir(), data_type + "_tfrecord")
        data = tf.data.TFRecordDataset([data_path])

        feature = {
            CLASS_LABEL: tf.io.VarLenFeature(tf.int64),
            CLASS_TEXT: tf.io.VarLenFeature(tf.string),
            IMAGE_HEIGHT: tf.io.FixedLenFeature((), tf.int64, default_value=1),
            IMAGE_WIDTH: tf.io.FixedLenFeature((), tf.int64, default_value=1)
        }
        num_records = 0
        for serialized in data:
            num_records += 1
            fs = tf.io.parse_single_example(serialized, feature)
            record_classes(fs[CLASS_TEXT], fs[CLASS_LABEL])

        self.params[data_type + "_size"] = num_records
        print("Preprocess ended")

        return data, None

    def get_classes(self):
        """Return class names ordered by their 0-based index.

        Raises KeyError if the collected indices are not contiguous from 0.
        """
        return [self.ind_to_str[i] for i in range(len(self.ind_to_str))]

    def get_train_generator(self):
        """Training pipeline: repeat for ``epochs``, batch, then ``preprocess_data``.

        ``self.train_features`` is not set in this class (presumably by the base
        class from extract_data's result — confirm).
        """
        return (self.train_features
                .repeat(self.params["epochs"])
                .batch(self.params["batch_size"])
                .map(self.preprocess_data))

    def get_valid_generator(self):
        """Validation pipeline built the same way as get_train_generator()."""
        return (self.valid_features
                .repeat(self.params["epochs"])
                .batch(self.params["batch_size"])
                .map(self.preprocess_data))
