#!/usr/bin/python3
# -*- coding: utf-8 -*-

from maps.wikimap.mapspro.services.mrc.eye.experiments.signs_map.pylibs.util import image_to_png_base64
from maps.wikimap.mapspro.services.mrc.eye.experiments.signs_map.pylibs.feature import Feature, MdsLoader
import maps.wikimap.mapspro.services.mrc.eye.experiments.signs_map.matcher.pylibs.matcher_consts as mconst

import tensorflow as tf
from tensorflow.core.framework import graph_pb2

import cv2 as cv
import numpy as np
import yt.wrapper as yt

import argparse
import functools
import os

from library.python import resource

# Detections whose score is not strictly greater than this are dropped
# (default value of the `score_thr` placeholder in MaskRCNNSimple).
SCORE_THR = 0.5
# Resized instance-mask values strictly greater than this mark a pixel as
# covered by a detection (it becomes 0 in the output mask of eval_mask).
MASK_THR = 0.5


class MaskRCNNSimple(object):
    """Thin wrapper around a frozen Mask R-CNN inference graph.

    The graph is imported into the current default TensorFlow graph. Detections
    whose score does not exceed the score threshold are dropped, and the
    remaining instance masks are rasterised into one keep/discard image mask.
    """

    def __init__(self, gdef):
        """Build the inference graph.

        Args:
        gdef: serialized ``GraphDef`` bytes of the frozen detection model.
        """
        self._images_ph = tf.placeholder(tf.uint8, shape=[1, None, None, 3])
        # Defaults to SCORE_THR; being a placeholder_with_default it can be
        # overridden through feed_dict.
        self._score_thr_ph = tf.placeholder_with_default(SCORE_THR, shape=[], name="score_thr")
        self._build_inference_graph(gdef)

    def _build_inference_graph(self, gdef):
        """Import the frozen graph and build the post-processing ops.

        The graph's ``image_tensor`` input is remapped to ``self._images_ph``.

        Args:
        gdef: serialized ``GraphDef`` bytes.

        Sets (nothing is returned):
        self._boxes_tensor: int32, shape=[k, 4], pixel boxes ordered
            (y_min, x_min, y_max, x_max). The raw ``detection_boxes`` are
            multiplied by (height, width, height, width), so they are presumably
            normalized coordinates.
        self._scores_tensor: float, shape=[k]
        self._labels_tensor: int64, shape=[k]
        self._masks_tensor: per-detection masks, shape=[k, mask_height, mask_width]
        where k is the number of detections scoring above ``self._score_thr_ph``.
        """
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(gdef)

        tf.import_graph_def(graph_def, name='', input_map={'image_tensor': self._images_ph})

        g = tf.get_default_graph()

        num_detections_tensor = tf.squeeze(g.get_tensor_by_name('num_detections:0'), 0)
        num_detections_tensor = tf.cast(num_detections_tensor, tf.int32)

        # Outputs are padded to a fixed size; keep only the valid detections.
        detected_boxes_tensor = tf.squeeze(g.get_tensor_by_name('detection_boxes:0'), 0)
        detected_boxes_tensor = detected_boxes_tensor[:num_detections_tensor]

        detected_scores_tensor = tf.squeeze(g.get_tensor_by_name('detection_scores:0'), 0)
        detected_scores_tensor = detected_scores_tensor[:num_detections_tensor]

        detected_labels_tensor = tf.squeeze(g.get_tensor_by_name('detection_classes:0'), 0)
        detected_labels_tensor = tf.cast(detected_labels_tensor, tf.int64)
        detected_labels_tensor = detected_labels_tensor[:num_detections_tensor]

        detected_masks_tensor = tf.squeeze(g.get_tensor_by_name('detection_masks:0'), 0)
        detected_masks_tensor = detected_masks_tensor[:num_detections_tensor]

        # Scale boxes to pixels: dim 1 of the input is height, dim 2 is width.
        image_shape = tf.to_float(tf.shape(self._images_ph))
        detected_boxes_tensor = tf.to_int32(
            detected_boxes_tensor * tf.stack([image_shape[1], image_shape[2], image_shape[1], image_shape[2]])
        )

        val_ind = tf.where(detected_scores_tensor > self._score_thr_ph)
        self._boxes_tensor = tf.gather_nd(detected_boxes_tensor, val_ind)
        self._scores_tensor = tf.gather_nd(detected_scores_tensor, val_ind)
        self._labels_tensor = tf.gather_nd(detected_labels_tensor, val_ind)
        self._masks_tensor = tf.gather_nd(detected_masks_tensor, val_ind)

    @staticmethod
    def _clip_box(box, rows, cols):
        """Clip a detected pixel box to the image and reorder its coordinates.

        Args:
        box: sequence whose first four items are (y_min, x_min, y_max, x_max) in pixels
        rows: image height
        cols: image width
        Returns:
        (x_min, y_min, x_max, y_max) usable as half-open slice bounds, with x
        clamped to [0, cols] and y to [0, rows]; None if the clipped box is empty.
        """
        y_min, x_min, y_max, x_max = (int(v) for v in box[:4])
        x_min = min(max(x_min, 0), cols)
        x_max = min(max(x_max, 0), cols)
        y_min = min(max(y_min, 0), rows)
        y_max = min(max(y_max, 0), rows)
        if x_max <= x_min or y_max <= y_min:
            return None
        return x_min, y_min, x_max, y_max

    def eval_mask(self, sess, image):
        """Make car masking for image.

        Args:
        sess: tensorflow session to run the graph in
        image: input BGR image, uint8 array, shape=[height, width, 3]
        Returns:
        mask: uint8 array, shape=[height, width]
        * mask[i, j] = 0: discarded pixel (covered by a detected instance mask)
        * mask[i, j] = 1: useful pixel
        """
        rows, cols = image.shape[:2]
        # BGR -> RGB, then add the batch dimension expected by the placeholder.
        rgb = image[:, :, ::-1]
        batch = rgb.reshape((1,) + rgb.shape)

        boxes, scores, labels, masks = sess.run(
            [self._boxes_tensor, self._scores_tensor, self._labels_tensor, self._masks_tensor],
            feed_dict={self._images_ph: batch},
        )

        assert boxes.shape[0] == scores.shape[0], "different amount of detected boxes and scores"
        assert boxes.shape[0] == labels.shape[0], "different amount of detected boxes and labels"
        assert boxes.shape[0] == masks.shape[0], "different amount of detected boxes and masks"

        image_mask = np.ones((rows, cols), dtype=np.uint8)
        for box, instance_mask in zip(boxes, masks):
            clipped = self._clip_box(box, rows, cols)
            if clipped is None:
                # Empty, inverted or off-image box: cv.resize rejects a
                # zero-sized target and there is nothing to paint anyway.
                continue
            x_min, y_min, x_max, y_max = clipped
            # cv.resize takes dsize as (width, height).
            resized = cv.resize(instance_mask, (x_max - x_min, y_max - y_min))
            keep = np.where(resized > MASK_THR, 0, 1).astype(np.uint8)
            # Basic slicing yields a view, so this updates image_mask in place.
            region = image_mask[y_min:y_max, x_min:x_max]
            np.minimum(region, keep, out=region)

        return image_mask


@functools.lru_cache(maxsize=1)
def get_mask_cnn():
    """Return the MaskRCNNSimple built from the bundled model resource.

    Memoized, so the graph is imported into the default TF graph at most
    once per process.
    """
    return MaskRCNNSimple(resource.find("/maps/mrc/carsegm/models/tf_model.gdef"))


class ComputeMaskMapper(object):
    """YT mapper that adds a PNG/base64-encoded mask to every input row."""

    def __init__(self, mds_host):
        """Args:
        mds_host: host passed to MdsLoader for downloading images.
        """
        self.mds_loader = MdsLoader(mds_host)
        # Created lazily on first use (see get_tf_session).
        self.tf_session = None

    def get_tf_session(self):
        """Return the TF session, creating and entering it on the first call.

        The session is entered (made the default session) and is never
        exited by this class.
        """
        session = self.tf_session
        if session is None:
            session = self.tf_session = tf.Session()
            session.__enter__()
        return session

    def __call__(self, row):
        """Compute the mask for one row and yield that same row.

        The row is updated in place: the mask is stored under
        mconst.MASK_PNG_BASE64.
        """
        image = self.mds_loader(Feature(row))
        mask = get_mask_cnn().eval_mask(self.get_tf_session(), image)
        row[mconst.MASK_PNG_BASE64] = image_to_png_base64(mask)
        yield row


def get_args():
    """Parse command-line options for the mask computation launcher."""
    parser = argparse.ArgumentParser(description='Find masks')
    add = parser.add_argument

    add('--mds-host', dest='mds_host', default='storage-int.mds.yandex.net', help='mds host to load images')
    add('--yt-proxy', dest='yt_proxy', default='hahn', help='name of yt proxy')
    add(
        '--porto-layer',
        dest='porto_layer',
        default=mconst.YT_JOB_PORTO_LAYER,
        help='yt path to file with porto container',
    )
    add('--feature-table', dest='feature_table', type=str, required=True, help='input table with features')
    add(
        '--feature-mask-table',
        dest='feature_mask_table',
        type=str,
        required=True,
        help='output table for feature and mask pairs',
    )

    # 'none' disables GPU scheduling; other values are YT pool tree names.
    gpu_choices = ['none', 'gpu_geforce_1080ti', 'gpu_tesla_v100', 'gpu_tesla_a100']
    add('--gpu', dest='gpu', type=str, choices=gpu_choices, default='none', help='gpu name')

    return parser.parse_args()


def main():
    """Launch the YT map operation that computes a mask for every feature row.

    The job count is derived from the input row count (about `rows_per_job`
    rows per job, at least one job). When --gpu is not 'none', the operation
    is scheduled on that pool tree in the 'research_gpu' pool with one GPU per job.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    args = get_args()

    yt_client = yt.YtClient(proxy=args.yt_proxy)
    size = yt_client.row_count(args.feature_table)

    rows_per_job = 30
    spec = {
        'title': "Matcher::Mask",
        'job_count': max(1, size // rows_per_job),
        'mapper': {
            "memory_limit": 6 * (1024 ** 3),
            'layer_paths': [args.porto_layer],
            # The os.environ assignment above only affects this launcher
            # process; TensorFlow runs inside the YT jobs, so pass it there.
            'environment': {'TF_CPP_MIN_LOG_LEVEL': '3'},
        },
    }

    if args.gpu != 'none':
        spec['pool_trees'] = [args.gpu]
        spec['scheduling_options_per_pool_tree'] = {args.gpu: {'pool': 'research_gpu'}}
        spec['mapper']['gpu_limit'] = 1

    mapper = ComputeMaskMapper(args.mds_host)

    yt_client.run_map(
        mapper,
        args.feature_table,
        args.feature_mask_table,
        spec=spec,
    )


# Script entry point: parse CLI options and submit the YT map operation.
if __name__ == "__main__":
    main()
