import io
import os.path
os.environ['TF_CPP_MIN_LOG_LEVEL']='3'
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
import numpy as np
import cv2
import urllib.request
import tarfile

import argparse


# (detection-score cut-off, per-pixel mask cut-off).
# SCORE_THR: detections are kept only when their score is strictly greater.
# MASK_THR: resized mask pixels at or above this value are marked as foreground.
SCORE_THR, MASK_THR = 0.5, 0.5


class MaskRCNNSimple(object):
    """Thin TF1 wrapper around a frozen Mask R-CNN ``GraphDef``.

    The frozen graph is imported into the default graph with its
    ``image_tensor`` input remapped to a uint8 placeholder of shape
    ``[1, H, W, 3]``. Detections whose score exceeds a threshold (``SCORE_THR``
    unless overridden per run) are exposed as pixel-coordinate boxes, scores,
    class labels and per-instance mask tensors.
    """

    def __init__(self, gdef_path):
        """Import the frozen graph stored at *gdef_path* and build the output tensors."""
        # One image per run; height and width are left free.
        self._images_ph = tf.placeholder(tf.uint8, shape=[1, None, None, 3])
        # Defaults to SCORE_THR but can be overridden through feed_dict.
        self._score_thr_ph = tf.placeholder_with_default(SCORE_THR, shape=[], name="score_thr")
        self._build_inference_graph(gdef_path)

    def _build_inference_graph(self, gdef_path):
        """Import the frozen graph and build score-filtered detection tensors.

        Sets ``_boxes_tensor`` (int32 pixel ``[ymin, xmin, ymax, xmax]``),
        ``_scores_tensor``, ``_labels_tensor`` (int64) and ``_masks_tensor``,
        each restricted to detections scoring above ``_score_thr_ph``.
        """
        graph_def = graph_pb2.GraphDef()
        with open(gdef_path, "rb") as f:
            graph_def.ParseFromString(f.read())

        # name='' keeps the original op names so the outputs can be fetched by name.
        tf.import_graph_def(graph_def, name='', input_map={'image_tensor': self._images_ph})
        g = tf.get_default_graph()

        num_detections = tf.cast(
            tf.squeeze(g.get_tensor_by_name('num_detections:0'), 0), tf.int32)

        def first_image_output(tensor_name):
            # Drop the size-1 batch axis and keep only the first num_detections entries.
            return tf.squeeze(g.get_tensor_by_name(tensor_name), 0)[:num_detections]

        boxes = first_image_output('detection_boxes:0')
        scores = first_image_output('detection_scores:0')
        labels = tf.cast(first_image_output('detection_classes:0'), tf.int64)
        masks = first_image_output('detection_masks:0')

        # Scale by (H, W, H, W) into pixel coordinates. This assumes the graph
        # emits normalized [ymin, xmin, ymax, xmax] boxes.
        image_shape = tf.to_float(tf.shape(self._images_ph))
        scale = tf.stack([image_shape[1], image_shape[2], image_shape[1], image_shape[2]])
        boxes = tf.to_int32(boxes * scale)

        # tf.where yields [k, 1] indices, so gather_nd selects whole rows / masks.
        keep = tf.where(scores > self._score_thr_ph)
        self._boxes_tensor = tf.gather_nd(boxes, keep)
        self._scores_tensor = tf.gather_nd(scores, keep)
        self._labels_tensor = tf.gather_nd(labels, keep)
        self._masks_tensor = tf.gather_nd(masks, keep)

    def eval_mask(self, sess, image, score_thr=None, mask_thr=None):
        """Run the detector on one image and return the union of instance masks.

        Args:
            sess: ``tf.Session`` whose graph contains this model.
            image: ``H x W x 3`` array in BGR channel order (as produced by
                ``cv2.imread``). It is fed to a uint8 placeholder.
            score_thr: optional detection-score cut-off for this call. When
                ``None``, the placeholder default (``SCORE_THR``) is used.
            mask_thr: optional per-pixel cut-off for the resized masks. When
                ``None``, ``MASK_THR`` is used.

        Returns:
            ``H x W`` uint8 array that is 1 where any kept instance's mask is
            at or above the mask threshold, and 0 elsewhere.
        """
        if mask_thr is None:
            mask_thr = MASK_THR
        rows, cols = image.shape[:2]

        # BGR to RGB, plus a leading batch axis of size 1.
        batch = np.ascontiguousarray(image[:, :, ::-1])[np.newaxis]
        feed_dict = {self._images_ph: batch}
        if score_thr is not None:
            feed_dict[self._score_thr_ph] = score_thr

        boxes, scores, labels, masks = sess.run(
            [self._boxes_tensor, self._scores_tensor, self._labels_tensor, self._masks_tensor],
            feed_dict=feed_dict)

        assert boxes.shape[0] == scores.shape[0], "different amount of detected boxes and scores"
        assert boxes.shape[0] == labels.shape[0], "different amount of detected boxes and labels"
        assert boxes.shape[0] == masks.shape[0],  "different amount of detected boxes and masks"

        image_mask = np.zeros((rows, cols), dtype=np.uint8)
        for box, instance_mask in zip(boxes, masks):
            y_min, x_min, y_max, x_max = (int(v) for v in box)
            box_w = x_max - x_min
            box_h = y_max - y_min
            if box_w <= 0 or box_h <= 0:
                # cv2.resize rejects empty target sizes; a degenerate box covers nothing.
                continue

            # Resize the mask to the full box first, then crop it to the image,
            # so boxes that cross the border are cropped rather than squashed.
            resized = cv2.resize(instance_mask, (box_w, box_h))

            # Slice ends are exclusive, so rows/cols are valid upper bounds.
            top, left = max(y_min, 0), max(x_min, 0)
            bottom, right = min(y_max, rows), min(x_max, cols)
            if bottom <= top or right <= left:
                continue  # box lies entirely outside the image

            crop = resized[top - y_min:bottom - y_min, left - x_min:right - x_min]
            region = image_mask[top:bottom, left:right]
            np.maximum(region, (crop >= mask_thr).astype(np.uint8), out=region)
        return image_mask


def download_tf_model(output_path, *, url="https://proxy.sandbox.yandex-team.ru/626013642",
                      retry_count=10, retry_delay=1.0, timeout=60.0):
    """Download the model tarball and extract ``./tf_model.gdef`` into *output_path*.

    Args:
        output_path: directory that the ``./tf_model.gdef`` member is extracted into.
        url: location of the tar archive.
        retry_count: maximum number of download attempts (at least one is made).
        retry_delay: seconds to wait between failed attempts.
        timeout: socket timeout in seconds for each attempt.

    Raises:
        RuntimeError: if every download attempt fails (the last error is chained).
        KeyError: if the archive has no ``./tf_model.gdef`` member.
        tarfile.TarError: if the downloaded bytes are not a readable tar archive.
    """
    import http.client
    import time

    attempts = max(1, retry_count)
    last_error = None
    for attempt in range(attempts):
        if attempt:
            time.sleep(retry_delay)
        try:
            with urllib.request.urlopen(url, timeout=timeout) as response:
                data = response.read()
            break
        except (OSError, http.client.HTTPException) as err:
            # URLError/HTTPError, socket timeouts and resets are all OSError
            # subclasses; IncompleteRead is an HTTPException. All are retryable.
            last_error = err
    else:
        raise RuntimeError(
            "failed to download model from {} after {} attempt(s)".format(url, attempts)
        ) from last_error

    member = "./tf_model.gdef"
    with tarfile.open(fileobj=io.BytesIO(data), mode='r') as tar:
        if hasattr(tarfile, "data_filter"):
            # The archive comes from the network: refuse members or links that
            # would escape output_path (Python 3.12+ / security backports).
            tar.extract(member, output_path, filter="data")
        else:
            tar.extract(member, output_path)


def main():
    """Command-line entry point: compute and save the instance mask for one image.

    Parses ``--input-image`` and ``--output-mask``. Downloads ``./tf_model.gdef``
    into the current directory if it is missing, runs ``MaskRCNNSimple`` on the
    image and writes the resulting mask. Mask pixels are 0/1 values, not 0/255.
    """
    parser = argparse.ArgumentParser(description="Car mask generator on python")

    parser.add_argument('--input-image', required=True,
                        help='Path to input image file')
    parser.add_argument('--output-mask', required=True,
                        help='Path to output png image with mask')

    args = parser.parse_args()

    # Read the input first so a bad path fails fast, before any download or graph build.
    # cv2.imread signals a missing/unreadable file by returning None instead of raising.
    image = cv2.imread(args.input_image)
    if image is None:
        parser.error("cannot read input image: {}".format(args.input_image))

    graph_def_path = os.path.abspath("./tf_model.gdef")
    if not tf.gfile.Exists(graph_def_path):
        download_tf_model(os.path.abspath('.'))
    rcnn = MaskRCNNSimple(graph_def_path)
    with tf.Session() as sess:
        image_mask = rcnn.eval_mask(sess, image)

    # cv2.imwrite reports failure (e.g. a missing directory) through its return value.
    if not cv2.imwrite(args.output_mask, image_mask):
        raise SystemExit("failed to write output mask: {}".format(args.output_mask))


if __name__ == '__main__':
    main()
