#!/usr/bin/python3
# -*- coding: utf-8 -*-

from maps.wikimap.mapspro.services.mrc.eye.experiments.signs_map.pylibs.detection import Box, Detection
from maps.wikimap.mapspro.services.mrc.eye.experiments.signs_map.pylibs.feature import Feature, MdsLoader
from maps.wikimap.mapspro.services.mrc.eye.experiments.signs_map.pylibs.util import download, mask_from_png_base64
import maps.wikimap.mapspro.services.mrc.eye.experiments.signs_map.matcher.pylibs.matcher_consts as mconst

import tensorflow as tf
from tensorflow.core.framework import graph_pb2

import yt.wrapper as yt

import argparse
import functools
import numpy as np
import io
import tarfile
import json

# Default detection score threshold: it is the default value of the "score_thr"
# placeholder, and only detections with a score strictly greater than it are kept.
SCORE_THR = 0.7


class FasterRCNNSimple(object):
    """Object detector built on top of a serialized TensorFlow graph

    The graph is imported into the default TensorFlow graph and its detection
    outputs are post-processed (cut to the reported number of detections,
    scaled to pixels, filtered by score) by additional graph operations.
    """

    def __init__(self, gdef_file):
        """Creates the input placeholders and builds the inference graph
        Args:
        gdef_file: binary file-like object with a serialized GraphDef,
            it is read completely during construction
        """
        self._images_ph = tf.placeholder(tf.uint8, shape=[1, None, None, 3])
        # The threshold can be overridden per run by feeding "score_thr"
        self._score_thr_ph = tf.placeholder_with_default(SCORE_THR, shape=[], name="score_thr")
        self._build_inference_graph(gdef_file)

    def _build_inference_graph(self, gdef_file):
        """Loads the inference graph and connects it to the image placeholder
        Args:
        gdef_file: binary file-like object with a serialized GraphDef
        Returns:
        Nothing, the results are stored in attributes of the instance:
        _boxes_tensor: Detected boxes in pixels, Int32 tensor,
            shape=[num_detections, 4]
        _scores_tensor: Detected scores, Float tensor,
            shape=[num_detections]
        _labels_tensor: Detected labels, Int64 tensor,
            shape=[num_detections]
        _supported_classes: list of class names read from
            the 'class_names' tensor of the graph
        """
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(gdef_file.read())

        # Replace the graph's own 'image_tensor' input with our placeholder
        tf.import_graph_def(graph_def, name='', input_map={'image_tensor': self._images_ph})

        g = tf.get_default_graph()

        # All outputs carry a leading batch dimension of size 1, drop it
        num_detections_tensor = tf.squeeze(g.get_tensor_by_name('num_detections:0'), 0)
        num_detections_tensor = tf.cast(num_detections_tensor, tf.int32)

        # Only the first num_detections entries of every output are used
        detected_boxes_tensor = tf.squeeze(g.get_tensor_by_name('detection_boxes:0'), 0)
        detected_boxes_tensor = detected_boxes_tensor[:num_detections_tensor]

        detected_scores_tensor = tf.squeeze(g.get_tensor_by_name('detection_scores:0'), 0)
        detected_scores_tensor = detected_scores_tensor[:num_detections_tensor]

        detected_labels_tensor = tf.squeeze(g.get_tensor_by_name('detection_classes:0'), 0)
        detected_labels_tensor = tf.cast(detected_labels_tensor, tf.int64)
        detected_labels_tensor = detected_labels_tensor[:num_detections_tensor]

        # Scale boxes by [height, width, height, width] of the fed image
        # (presumably the graph outputs them normalized to [0, 1])
        image_shape = tf.to_float(tf.shape(self._images_ph))
        detected_boxes_tensor = tf.to_int32(
            detected_boxes_tensor * tf.stack([image_shape[1], image_shape[2], image_shape[1], image_shape[2]])
        )

        # Keep only detections with a score strictly above the threshold
        val_ind = tf.where(detected_scores_tensor > self._score_thr_ph)
        self._boxes_tensor = tf.gather_nd(detected_boxes_tensor, val_ind)
        self._scores_tensor = tf.gather_nd(detected_scores_tensor, val_ind)
        self._labels_tensor = tf.gather_nd(detected_labels_tensor, val_ind)

        # Class names are stored in the graph itself, evaluate them once
        with tf.Session() as sess:
            self._supported_classes = [name.decode('utf-8') for name in sess.run(g.get_tensor_by_name('class_names:0'))]

    @staticmethod
    def _clamp(value, size):
        """Limits a pixel coordinate to the range [0, size - 1]"""
        return min(max(int(value), 0), size - 1)

    def eval_objects(self, sess, image):
        """Detected objects for image
        Args:
        sess: tensorflow session to run graph
        image: input BGR image, uint8 array, shape=[height, width, 3],
            the batch dimension is added by this method
        Returns:
        objects: list, list of detected objects, every box is limited
            to the image bounds
        """

        rows, cols = image.shape[0], image.shape[1]
        # BGR to RGB
        image = image[:, :, ::-1]
        image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2])

        boxes, scores, labels = sess.run(
            [self._boxes_tensor, self._scores_tensor, self._labels_tensor], feed_dict={self._images_ph: image}
        )

        assert boxes.shape[0] == scores.shape[0], "different amount of detected boxes and scores"
        assert boxes.shape[0] == labels.shape[0], "different amount of detected boxes and labels"

        objects = []
        for index, (box, label) in enumerate(zip(boxes, labels)):
            # Box layout is [y_min, x_min, y_max, x_max].
            # Every coordinate is limited from both sides, so that a box can
            # neither get a negative coordinate nor leave the image.
            x_min = self._clamp(box[1], cols)
            y_min = self._clamp(box[0], rows)
            x_max = self._clamp(box[3], cols)
            y_max = self._clamp(box[2], rows)

            bbox = Box(x_min, y_min, x_max, y_max)
            # Labels are treated as 1-based indices into the class names
            sign_type = self._supported_classes[label - 1]

            objects.append(Detection(index, bbox, sign_type))

        return objects


@functools.lru_cache(1)
def get_detector_cnn():
    """Downloads the model archive and builds the detector

    The result is cached, so the archive is downloaded and the graph is built
    only on the first call; later calls return the same detector.

    Returns:
    FasterRCNNSimple built from the './tf_model.gdef' member of the archive
    Raises:
    ValueError: if './tf_model.gdef' is present in the archive
        but is not a regular file
    """
    data = download('https://proxy.sandbox.yandex-team.ru/1945419455')
    # The detector reads the graph file completely in its constructor, so both
    # the archive and the extracted member can be closed right after it is built
    with tarfile.open(fileobj=io.BytesIO(data), mode='r') as tar:
        gdef_file = tar.extractfile('./tf_model.gdef')
        if gdef_file is None:
            raise ValueError("'./tf_model.gdef' is not a regular file in the model archive")
        with gdef_file:
            return FasterRCNNSimple(gdef_file)


def get_args():
    """Parses command line arguments of the script

    Returns:
    argparse.Namespace with attributes mds_host, yt_proxy, porto_layer,
        feature_mask_table, features_objects_yt_json and gpu
    """
    # Every entry is (option name, keyword arguments of add_argument),
    # options are registered in the listed order
    option_specs = (
        (
            '--mds-host',
            dict(dest='mds_host', default='storage-int.mds.yandex.net', help='mds host to load images'),
        ),
        (
            '--yt-proxy',
            dict(dest='yt_proxy', default='hahn', help='name of yt proxy'),
        ),
        (
            '--porto-layer',
            dict(
                dest='porto_layer',
                default=mconst.YT_JOB_PORTO_LAYER,
                help='yt path to file with porto container',
            ),
        ),
        (
            '--feature-mask-table',
            dict(
                dest='feature_mask_table',
                type=str,
                required=True,
                help='input table with feature and mask pairs',
            ),
        ),
        (
            '--features-objects-yt-json',
            dict(
                dest='features_objects_yt_json',
                type=str,
                required=True,
                help='output json for features and detected objects',
            ),
        ),
        (
            '--gpu',
            dict(
                dest='gpu',
                type=str,
                choices=['none', 'gpu_geforce_1080ti', 'gpu_tesla_v100', 'gpu_tesla_a100'],
                default='none',
                help='gpu name',
            ),
        ),
    )

    parser = argparse.ArgumentParser(description='Find objects')
    for option_name, option_kwargs in option_specs:
        parser.add_argument(option_name, **option_kwargs)

    return parser.parse_args()


class DetectorMapper(object):
    """Mapper which detects objects on the image of every input row

    For every row it loads the feature image, runs the detector, drops
    detections which lie mostly on zero pixels of the row mask and yields
    the remaining ones.
    """

    def __init__(self, mds_host):
        """
        Args:
        mds_host: host passed to MdsLoader, which is used to load images
        """
        self.mds_loader = MdsLoader(mds_host)
        # The session is created on first use, see get_tf_session
        self.tf_session = None

    def get_tf_session(self):
        """Returns the tensorflow session of the mapper, creating it on first use

        The session is entered and never exited here, so it stays open
        as long as the mapper lives.
        """
        if self.tf_session is None:
            self.tf_session = tf.Session()
            self.tf_session.__enter__()
        return self.tf_session

    def __filter_objects(self, objects, mask):
        """Drops objects whose box is at least half covered by zero mask pixels
        Args:
        objects: iterable of detections, every one has a 'box' attribute
            with min_x, min_y, max_x, max_y
        mask: 2D array indexed as [y, x]
        Returns:
        list of kept objects in the original order
        """
        result = []
        for detection in objects:
            box = detection.box
            # max_y / max_x are exclusive bounds of the slice
            object_mask = mask[box.min_y : box.max_y, box.min_x : box.max_x]
            box_area = object_mask.shape[0] * object_mask.shape[1]
            if box_area == 0:
                # An empty box has no pixels to compare. It is kept, as it was
                # before this case was made explicit (0 / 0 gave NaN, and NaN
                # never reaches the 0.5 limit), but without the invalid division.
                result.append(detection)
                continue
            mask_area = np.sum(object_mask == 0)
            ratio = mask_area / float(box_area)
            if ratio >= 0.5:
                continue
            result.append(detection)
        return result

    def __call__(self, row):
        """Processes one input row
        Args:
        row: input row, it is used to construct Feature and contains
            the mask under the key mconst.MASK_PNG_BASE64
        Yields:
        dict with 'feature_id', 'orientation' (exif value) and 'objects'
            (list of dicts of detections reverted by feature orientation)
        """
        feature = Feature(row)
        mask = mask_from_png_base64(row[mconst.MASK_PNG_BASE64])
        image = self.mds_loader(feature)
        # The detector is requested before the session, as it was originally:
        # building the detector adds operations to the default graph
        detector = get_detector_cnn()
        objects = detector.eval_objects(self.get_tf_session(), image)
        objects = self.__filter_objects(objects, mask)
        reverted_objects = [
            detection.revert_by_feature_orientation(feature).to_dict() for detection in objects
        ]

        yield {'feature_id': feature.id, 'orientation': feature.orientation.to_exif(), 'objects': reverted_objects}


def main():
    """Runs the detector over the input table and stores the result as json

    The mapper is run as a YT map operation with a temporary output table,
    then the rows of this table are collected into one json document which
    is written to the YT file given by --features-objects-yt-json.
    """
    args = get_args()

    yt_client = yt.YtClient(proxy=args.yt_proxy)
    row_count = yt_client.row_count(args.feature_mask_table)

    # About 30 rows per job, but at least one job
    job_count = max(1, int(row_count / 30))
    spec = {
        'title': "Detector::Detect",
        'job_count': job_count,
        'mapper': {
            "memory_limit": 8 * (1024 ** 3),
            'layer_paths': [args.porto_layer],
        },
    }

    gpu_name = args.gpu
    if gpu_name != 'none':
        # The gpu name is used as the name of the pool tree
        spec['pool_trees'] = [gpu_name]
        spec['scheduling_options_per_pool_tree'] = {gpu_name: {'pool': 'research_gpu'}}
        spec['mapper']['gpu_limit'] = 1

    detection_mapper = DetectorMapper(args.mds_host)

    with yt_client.TempTable() as tmp_table:
        yt_client.run_map(
            detection_mapper,
            args.feature_mask_table,
            yt_client.TablePath(tmp_table),
            spec=spec,
        )

        # Only these three fields of every output row get into the json
        output_rows = yt_client.read_table(yt_client.TablePath(tmp_table))
        features_objects = [
            {
                'feature_id': row['feature_id'],
                'orientation': row['orientation'],
                'objects': row['objects'],
            }
            for row in output_rows
        ]
        serialized = json.dumps({'features_objects': features_objects}, indent=4)
        yt_client.write_file(args.features_objects_yt_json, serialized.encode('utf-8'))


# Run the detection only when the file is executed as a script, not on import
if __name__ == "__main__":
    main()
