#!/usr/bin/python3
# -*- coding: utf-8 -*-

import functools
from functools import lru_cache

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.core.framework import graph_pb2

from library.python import resource


def sift(image, mask):
    """Detect SIFT keypoints on a BGR image and compute their descriptors.

    Args:
        image: BGR image; it is converted to grayscale before detection.
        mask: detection mask forwarded as-is to ``detectAndCompute``
            (``None`` means no mask).

    Returns:
        A tuple ``(points, descriptors, scores)`` of float32 arrays:
        ``points`` holds one ``(x, y)`` row per keypoint, ``descriptors`` one
        row per keypoint, and ``scores`` is all ones (SIFT provides no score
        here). When nothing is detected the arrays are empty but keep their
        rank: ``(0, 2)``, ``(0, descriptor_size)`` and ``(0,)``.
    """
    # At most 2000 features are retained by the detector.
    detector = cv2.SIFT_create(2000)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    key_points, descriptors = detector.detectAndCompute(gray, mask)

    # OpenCV returns ``None`` descriptors when no keypoint was found; without
    # this guard ``descriptors.shape`` would raise AttributeError.
    if descriptors is None or len(key_points) == 0:
        return (
            np.zeros((0, 2), dtype=np.float32),
            np.zeros((0, detector.descriptorSize()), dtype=np.float32),
            np.ones((0,), dtype=np.float32),
        )

    size = descriptors.shape[0]

    return (
        np.float32([kp.pt for kp in key_points]),
        np.float32(descriptors),
        np.ones((size,), dtype=np.float32),
    )


# Keypoint cap of 2500; not referenced anywhere in the visible part of this
# module (NOTE(review): possibly used by importers - confirm before removing).
SUPERPOINTS_MAX_PTS_CNT = 2500
# Keypoint cap used by the default 'superpoints_tf' entry of the options table.
SUPERPOINTS_TF_MAX_PTS_CNT = 2000


def superpoints_impl(spts, max_pts_cnt, mask):
    """Filter keypoints by a validity mask and keep the best-scoring ones.

    Args:
        spts: dict with ``'keypoints'`` (N rows of ``(x, y)``),
            ``'descriptors'`` (second axis of length N) and ``'scores'``
            (length N).
        max_pts_cnt: maximum number of keypoints to return; values ``<= 0``
            yield no keypoints.
        mask: array indexed as ``mask[y, x]`` where values ``> 0`` mark the
            valid area, or ``None`` to skip mask filtering entirely.

    Returns:
        A tuple ``(keypoints, descriptors, scores)``. Descriptors are
        transposed so that their first axis matches the keypoints. When more
        than ``max_pts_cnt`` keypoints survive the mask, only the
        highest-scoring ones are kept, ordered by ascending score.

    Raises:
        AssertionError: if the three input arrays disagree on N.
    """
    MASK_DILATE_KERNEL_SZ = 5

    keypts = spts['keypoints']
    descs = spts['descriptors']
    scores = spts['scores']
    keypts_size = keypts.shape[0]
    assert keypts_size == descs.shape[1] and keypts_size == scores.shape[0]

    # Skip filtering when there is no mask or nothing to filter; the latter
    # also avoids indexing with an empty, non-boolean selector.
    if mask is not None and keypts_size > 0:
        kernel = np.ones((MASK_DILATE_KERNEL_SZ, MASK_DILATE_KERNEL_SZ), np.uint8)

        # 1 where the mask is invalid; dilating grows the invalid area so that
        # keypoints close to its border are rejected as well.
        invalid = 1 - (mask > 0).astype(np.uint8)
        invalid = cv2.dilate(invalid, kernel, iterations=1)

        # Truncate coordinates toward zero, exactly like int() would.
        cols = keypts[:, 0].astype(np.intp)
        rows = keypts[:, 1].astype(np.intp)
        keep = invalid[rows, cols] == 0

        keypts = keypts[keep]
        descs = descs[:, keep]
        scores = scores[keep]
        keypts_size = keypts.shape[0]

    if keypts_size > max_pts_cnt:
        # Slice from an explicit start index: ``[-max_pts_cnt:]`` would select
        # everything when max_pts_cnt is 0.
        best = scores.argsort()[keypts_size - max(max_pts_cnt, 0):]
        keypts = keypts[best]
        descs = descs[:, best]
        scores = scores[best]

    return keypts, np.transpose(descs, axes=(1, 0)), scores


class SuperpointsTF(object):
    """Runs a serialized TensorFlow graph that produces keypoints, scores and descriptors.

    The graph is imported into the default TensorFlow graph using the
    TF1-style API (``tf.placeholder`` / ``tf.Session``). Its
    ``inference_images`` input is remapped to a uint8 placeholder of shape
    ``[None, None, None, 1]``, and the outputs are read from the tensors
    ``inference_points:0``, ``inference_scores:0`` and
    ``inference_descriptors:0``.
    """

    def __init__(self, gdef, threshold_def=0.005):
        """Import the graph and open a session.

        Args:
            gdef: serialized ``GraphDef`` bytes.
            threshold_def: accepted for interface compatibility; it is not
                used anywhere in this class.
        """
        self._images_ph = tf.placeholder(tf.uint8, shape=[None, None, None, 1])
        self._build_inference_graph(gdef)
        self._init_session()

    def _build_inference_graph(self, gdef):
        """Parse ``gdef`` and import it, wiring the image placeholder as input."""
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(gdef)

        tf.import_graph_def(graph_def, name='', input_map={'inference_images': self._images_ph})

        g = tf.get_default_graph()
        self._keypoints = g.get_tensor_by_name('inference_points:0')
        self._scores = g.get_tensor_by_name('inference_scores:0')
        self._descriptors = g.get_tensor_by_name('inference_descriptors:0')

    def _init_session(self):
        """Create the session and run the global and local variable initializers."""
        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        self.sess = tf.Session()
        self.sess.run(init_op)

    @staticmethod
    def _fit_size(width, height, max_side):
        """Return ``(width, height)`` shrunk so the longer side equals ``max_side``.

        The size is returned unchanged when ``max_side <= 0`` or when the
        longer side already fits. The shorter side is truncated to an integer
        and clamped to at least 1, so extreme aspect ratios never produce a
        zero-sized target.
        """
        if max_side <= 0:
            return width, height
        if width > height:
            if width > max_side:
                return max_side, max(1, int(height * max_side / width))
        elif height > max_side:
            return max(1, int(width * max_side / height)), max_side
        return width, height

    def forward(self, images, eq_hist, resize):
        """Run the network on one BGR image.

        Args:
            images: a single BGR image (converted to grayscale internally).
            eq_hist: when truthy, histogram equalization is applied after the
                optional downscale.
            resize: maximum length of the longer image side; values ``<= 0``
                disable downscaling. Images are never upscaled.

        Returns:
            dict with ``'keypoints'`` (columns ``x, y`` in the coordinates of
            the input image), ``'scores'`` and ``'descriptors'``, restricted
            to rows whose first keypoint column equals 0.
        """
        images = cv2.cvtColor(images, cv2.COLOR_BGR2GRAY)

        hs, ws = images.shape[:2]
        wn, hn = self._fit_size(ws, hs, int(resize))
        # Compare both dimensions: with the clamp above the width alone may be
        # unchanged while the height still shrinks.
        resized = (wn, hn) != (ws, hs)
        if resized:
            images = cv2.resize(images, (wn, hn))

        if eq_hist:
            images = cv2.equalizeHist(images)

        # Add leading and trailing singleton axes to match the 4-D placeholder.
        images = np.expand_dims(np.expand_dims(images, 0), -1)
        kpts, scores, descs = self.sess.run(
            [self._keypoints, self._scores, self._descriptors], feed_dict={self._images_ph: images}
        )
        # Keep rows whose first column is 0 (presumably the batch index -
        # confirm against the model), then take the last two columns in
        # reverse order so that column 0 is x and column 1 is y.
        idxs = kpts[:, 0] == 0
        kpts = kpts[idxs][:, -1:-3:-1]
        if resized:
            # Map coordinates back to the original image size.
            # NOTE(review): the assignment is in place, so if the keypoint
            # tensor has an integer dtype the scaled values are truncated -
            # confirm the dtype of 'inference_points:0'.
            kpts[:, 0] = kpts[:, 0] * ws / wn
            kpts[:, 1] = kpts[:, 1] * hs / hn
        spts = {'keypoints': kpts, 'scores': scores[idxs], 'descriptors': descs[idxs]}
        return spts


@lru_cache(maxsize=1)
def get_superpoint_tf_cnn():
    """Return the SuperpointsTF wrapper, building it from the bundled graph on first use.

    The result is cached, so the resource is read and the model constructed
    only once; later calls return the same instance.
    """
    graph_bytes = resource.find("/maps/mrc/keypoint/models/superpoints.gdef")
    network = SuperpointsTF(graph_bytes)
    return network


def superpoints_tf(max_pts_cnt, eq_hist, resize, image, mask):
    """Extract keypoints from ``image`` with the cached TensorFlow network.

    The raw network output is obtained via ``forward(image, eq_hist, resize)``
    and then handed to ``superpoints_impl`` together with ``max_pts_cnt`` and
    ``mask``; that function's result is returned unchanged.
    """
    raw = get_superpoint_tf_cnn().forward(image, eq_hist, resize)
    # Put the descriptors in the same layout as in PyTorch: channels along
    # the first dimension.
    raw['descriptors'] = np.transpose(raw['descriptors'], [1, 0])
    return superpoints_impl(raw, max_pts_cnt, mask)


# Table of available keypoint extractors, keyed by option name. Every value is
# callable as ``extractor(image, mask)``; the superpoints entries pre-bind the
# leading ``(max_pts_cnt, eq_hist, resize)`` arguments of ``superpoints_tf``.
options = dict(
    sift=sift,
    superpoints_tf=functools.partial(superpoints_tf, SUPERPOINTS_TF_MAX_PTS_CNT, False, -1),
    superpoints_tf_1000=functools.partial(superpoints_tf, 1000, False, -1),
    superpoints_tf_1500=functools.partial(superpoints_tf, 1500, False, -1),
    superpoints_tf_2000=functools.partial(superpoints_tf, 2000, False, -1),
    superpoints_tf_2500=functools.partial(superpoints_tf, 2500, False, -1),
    superpoints_tf_2500_eq_hist=functools.partial(superpoints_tf, 2500, True, -1),
    superpoints_tf_2500_small_sz=functools.partial(superpoints_tf, 2500, False, 1000),
)
