import argparse
import cv2
import json
import base64
import copy
import numpy as np


# Key codes returned by cv2.waitKeyEx(). The values coincide with the X11
# keysyms XK_Shift_L (0xFFE1), XK_Right (0xFF53) and XK_Left (0xFF51), which is
# what the tool sees on Ubuntu; other HighGUI backends may report other codes.
SHIFT = 0xFFE1
RIGHT_ARROW = 0xFF53
LEFT_ARROW = 0xFF51


class Rect:
    """Axis-aligned rectangle stored as a pair of ``(x, y)`` corner tuples."""

    def __init__(self, coords):
        # Only the first two components of the first two corners are kept.
        corner_a = coords[0]
        corner_b = coords[1]
        self.coords = (
            (corner_a[0], corner_a[1]),
            (corner_b[0], corner_b[1]),
        )

    def draw(self, image, color, line_width):
        """Draw the rectangle outline onto ``image`` in place."""
        first_corner, second_corner = self.coords
        cv2.rectangle(image, first_corner, second_corner, color, line_width)

    def center(self):
        """Return the integer (floor-divided) midpoint of the two corners."""
        (ax, ay), (bx, by) = self.coords
        return (ax + bx) // 2, (ay + by) // 2

    def resize(self, scale):
        """Return a new Rect with every coordinate multiplied by ``scale`` and truncated to int."""
        (ax, ay), (bx, by) = self.coords
        scaled = [[int(scale * ax), int(scale * ay)], [int(scale * bx), int(scale * by)]]
        return Rect(scaled)

    def shift(self, pos):
        """Return a new Rect translated by ``(pos[0], pos[1])``."""
        dx = pos[0]
        dy = pos[1]
        (ax, ay), (bx, by) = self.coords
        return Rect([[ax + dx, ay + dy], [bx + dx, by + dy]])


class Object:
    """A labeled object: its ``id``, its ``type`` and its bounding ``box`` (a Rect)."""

    def __init__(self, data):
        # ``data`` is one entry of a ``features_objects[*].objects`` list in the objects JSON.
        self.id, self.type = data['object_id'], data['type']
        self.box = Rect(data['bbox'])


def parse_base64_image(base64_image):
    """Decode a base64-encoded image file into a 3-channel OpenCV image.

    Args:
        base64_image: base64 text of an encoded image file (any format
            ``cv2.imdecode`` understands).

    Returns:
        The decoded image as returned by ``cv2.imdecode`` with ``IMREAD_COLOR``.

    Raises:
        ValueError: if OpenCV cannot decode the bytes as an image.
    """
    data = np.frombuffer(base64.b64decode(base64_image), dtype=np.uint8)
    image = cv2.imdecode(data, cv2.IMREAD_COLOR)
    if image is None:
        # imdecode reports failure by returning None instead of raising; fail
        # here with a clear message rather than later on ``image.shape``.
        raise ValueError('Failed to decode base64 image')
    return image


class Frame:
    """An image scaled to a fixed height and placed at ``pos`` on a shared canvas."""

    def __init__(self, image, height, pos):
        """
        Args:
            image: source image array; axis 0 is treated as height, axis 1 as width.
            height: target height in pixels; the width is scaled proportionally.
            pos: ``(x, y)`` of the frame's top-left corner on the canvas.
        """
        source_height = image.shape[0]
        source_width = image.shape[1]

        self.height = height
        self.pos = pos
        self.scale = float(height) / source_height
        self.width = int(self.scale * source_width)

        self.resized_image = cv2.resize(image, (self.width, self.height))

        # Not read anywhere in this class.
        self.labeler = None

    def transform(self, object):
        """Return a deep copy of ``object`` whose box is mapped into canvas coordinates."""
        canvas_box = object.box.resize(self.scale).shift(self.pos)
        frame_object = copy.deepcopy(object)
        frame_object.box = canvas_box
        return frame_object

    def draw(self, canvas):
        """Paste the resized image into ``canvas`` at this frame's position."""
        left, top = self.pos[0], self.pos[1]
        canvas[top:top + self.height, left:left + self.width, :] = self.resized_image

    def borders(self):
        """Return the Rect covering this frame on the canvas."""
        left, top = self.pos[0], self.pos[1]
        return Rect([[left, top], [left + self.width, top + self.height]])


def shift(value, shift_value, max_value):
    """Move ``value`` one step forward or backward, wrapping within ``[0, max_value)``.

    Args:
        value: current index.
        shift_value: ``1`` (next) or ``-1`` (previous).
        max_value: number of items to cycle through; must be non-zero.

    Returns:
        The wrapped index ``(value + shift_value) % max_value``.

    Raises:
        ValueError: if ``shift_value`` is not ``-1`` or ``1``.
        ZeroDivisionError: if ``max_value`` is ``0``.
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    if shift_value not in (-1, 1):
        raise ValueError('shift_value must be -1 or 1, got {!r}'.format(shift_value))
    # Python's ``%`` already yields a non-negative result for a positive
    # modulus, so 0 - 1 wraps to max_value - 1 without adding max_value first.
    return (value + shift_value) % max_value


class Dataset:
    """Images and their objects, both keyed by ``feature_id``.

    Images are kept as base64 strings and decoded on demand by :meth:`image`.
    """

    def __init__(self, images_path, objects_path):
        """Load images and objects from two JSON files.

        Args:
            images_path: JSON file with a ``features_image`` list whose entries
                have ``feature_id`` and base64 ``image`` fields.
            objects_path: JSON file with a ``features_objects`` list whose
                entries have ``feature_id`` and an ``objects`` list; each object
                entry is passed to ``Object``.

        Raises:
            ValueError: if the objects file mentions a ``feature_id`` that has
                no image, or some image has no objects entry.
        """
        with open(images_path, encoding='utf-8') as f:
            data = json.load(f)
        self.base64_image_by_feature_id = {
            feature_image['feature_id']: feature_image['image']
            for feature_image in data['features_image']
        }

        # For every image feature_id this stores a pair:
        # 1) the list of object ids in the image (in file order);
        # 2) a dict mapping object id -> Object.
        self.objects_by_feature_id = {}
        with open(objects_path, encoding='utf-8') as f:
            data = json.load(f)
        for feature_objects in data['features_objects']:
            feature_id = feature_objects['feature_id']
            # Explicit checks instead of ``assert`` (stripped under ``python -O``).
            if feature_id not in self.base64_image_by_feature_id:
                raise ValueError('Objects given for unknown feature_id {!r}'.format(feature_id))

            objects = [Object(item) for item in feature_objects['objects']]
            object_ids = [obj.id for obj in objects]
            object_by_id = {obj.id: obj for obj in objects}
            self.objects_by_feature_id[feature_id] = (object_ids, object_by_id)

        missing = [
            feature_id for feature_id in self.base64_image_by_feature_id
            if feature_id not in self.objects_by_feature_id
        ]
        if missing:
            raise ValueError('No objects entry for feature_ids: {}'.format(missing))
        self.dataset_size = len(self.base64_image_by_feature_id)

    def size(self):
        """Return the number of images."""
        return self.dataset_size

    def has(self, feature_id):
        """Return True if an image with ``feature_id`` exists."""
        return feature_id in self.base64_image_by_feature_id

    def image(self, feature_id):
        """Decode and return the image for ``feature_id``."""
        return parse_base64_image(self.base64_image_by_feature_id[feature_id])

    def objects(self, feature_id):
        """Return ``(object_ids, object_by_id)`` for ``feature_id``."""
        return self.objects_by_feature_id[feature_id]


class Labeler:
    """Interactive OpenCV tool for linking objects of a master frame to a slave frame.

    The master frame is drawn on top and the slave frame below it. For the
    currently selected master object the candidate slave objects are either
    its already-linked slave object, or every not-yet-linked slave object of
    the same ``type``. Matches are kept per pair in ``self.matches`` as
    ``{(master_feature_id, slave_feature_id): {master_object_id: slave_object_id}}``
    and written to ``self.output`` by :meth:`save_results`.
    """

    def __init__(self, dataset, matches, pairs, init_pair_indx, output):
        """
        Args:
            dataset: object providing ``size()``, ``objects(feature_id)`` and
                ``image(feature_id)``.
            matches: previously loaded matches keyed by feature-id pair; this
                dict is updated in place as pairs are edited.
            pairs: non-empty list of ``(master_feature_id, slave_feature_id)``.
            init_pair_indx: index into ``pairs`` to start from; negative values
                count from the end.
            output: path of the JSON file written by :meth:`save_results`.

        Raises:
            ValueError: if the dataset or ``pairs`` is empty, or
                ``init_pair_indx`` is out of range.
        """
        # Explicit checks instead of ``assert`` (stripped under ``python -O``).
        if dataset.size() <= 0:
            raise ValueError('Dataset is empty')
        if not pairs:
            raise ValueError('No pairs to label')
        if not -len(pairs) <= init_pair_indx < len(pairs):
            raise ValueError('init_pair_indx {} is out of range for {} pairs'.format(
                init_pair_indx, len(pairs)))

        self.dataset = dataset
        self.matches = matches
        self.pairs = pairs
        # Normalize negative indices so printing and navigation agree.
        self.pair_indx = init_pair_indx % len(pairs)
        self.output = output

        self.cur_matches = {}
        self.reset()

    def dump_matches(self):
        """Store the current pair's matches back into ``self.matches``.

        Empty match sets are stored as well, so clearing a pair that had
        loaded matches actually removes them; :meth:`save_results` skips
        empty pairs when writing.
        """
        self.matches[self.pairs[self.pair_indx]] = dict(self.cur_matches)

    def save_results(self):
        """Write every pair with at least one match to ``self.output`` as JSON."""
        print('Saving results...')
        self.dump_matches()

        features_pairs = []
        for (feature_id_1, feature_id_2), pair_matches in self.matches.items():
            if not pair_matches:
                continue
            features_pairs.append({
                'feature_id_1': feature_id_1,
                'feature_id_2': feature_id_2,
                'matches': [
                    {'object_id_1': object_id_1, 'object_id_2': object_id_2}
                    for object_id_1, object_id_2 in pair_matches.items()
                ],
            })

        with open(self.output, 'w') as f:
            json.dump({'features_pairs': features_pairs}, f, indent=2)
        print('Done')

    def reset(self):
        """Load the pair at ``self.pair_indx``: its matches, objects and canvas."""
        print('Pair index: {}'.format(self.pair_indx))

        pair = self.pairs[self.pair_indx]
        # Work on a copy; edits reach self.matches only through dump_matches().
        self.cur_matches = dict(self.matches.get(pair, {}))

        master_frame_id, slave_frame_id = pair

        self.master_object_ids, self.master_object_by_id = self.dataset.objects(master_frame_id)
        self.slave_object_ids, self.slave_object_by_id = self.dataset.objects(slave_frame_id)

        self.create_canvas()

        self.is_master_active = True

        # None when the master frame has no objects, so nothing indexes an empty list.
        self.master_object_indx = 0 if self.master_object_ids else None
        self.update_selected_slave_object_ids()

    def update_selected_slave_object_ids(self):
        """Recompute the slave candidates for the current master object.

        If the master object is already linked, the only candidate is its
        linked slave object; otherwise the candidates are all slave objects of
        the same ``type`` that are not linked to any master object. The first
        candidate (if any) becomes selected.
        """
        self.selected_slave_object_ids = []
        self.selected_slave_object_indx = None
        if self.master_object_indx is None:
            return

        master_object_id = self.master_object_ids[self.master_object_indx]
        if master_object_id in self.cur_matches:
            self.selected_slave_object_ids.append(self.cur_matches[master_object_id])
        else:
            master_type = self.master_object_by_id[master_object_id].type
            used_slave_ids = set(self.cur_matches.values())
            self.selected_slave_object_ids = [
                slave_object.id
                for slave_object in self.slave_object_by_id.values()
                if slave_object.type == master_type and slave_object.id not in used_slave_ids
            ]
        if self.selected_slave_object_ids:
            self.selected_slave_object_indx = 0

    def create_canvas(self):
        """Build both frames (master on top, slave below a small gap) and a black canvas fitting them."""
        height = 500
        pad = 2

        master_frame_id, slave_frame_id = self.pairs[self.pair_indx]

        self.master_frame = Frame(self.dataset.image(master_frame_id), height, (0, 0))
        self.slave_frame = Frame(self.dataset.image(slave_frame_id), height, (0, height + pad))

        canvas_height = self.master_frame.height + self.slave_frame.height + pad
        canvas_width = max(self.master_frame.width, self.slave_frame.width)
        self.canvas = np.zeros([canvas_height, canvas_width, 3], dtype=np.uint8)

    def _move_selection(self, step):
        """Cycle the active frame's selection by ``step``; no-op if there is nothing to select."""
        if self.is_master_active:
            if not self.master_object_ids:
                return
            self.master_object_indx = shift(self.master_object_indx, step, len(self.master_object_ids))
            self.update_selected_slave_object_ids()
        elif self.selected_slave_object_ids:
            self.selected_slave_object_indx = shift(
                self.selected_slave_object_indx, step, len(self.selected_slave_object_ids))

    def next_object(self):
        """Select the next object in the active frame (wrapping)."""
        self._move_selection(1)

    def prev_object(self):
        """Select the previous object in the active frame (wrapping)."""
        self._move_selection(-1)

    def switch_frame(self):
        """Toggle which frame (master or slave) the arrow keys navigate."""
        self.is_master_active = not self.is_master_active

    def create_link(self):
        """Link the selected master object to the selected slave candidate."""
        if self.selected_slave_object_indx is None:
            return
        master_object_id = self.master_object_ids[self.master_object_indx]
        slave_object_id = self.selected_slave_object_ids[self.selected_slave_object_indx]
        self.cur_matches[master_object_id] = slave_object_id
        self.update_selected_slave_object_ids()

    def delete_link(self):
        """Remove the link between the selected master and slave objects, if it exists."""
        if self.selected_slave_object_indx is None:
            return
        master_object_id = self.master_object_ids[self.master_object_indx]
        slave_object_id = self.selected_slave_object_ids[self.selected_slave_object_indx]

        if master_object_id in self.cur_matches and self.cur_matches[master_object_id] == slave_object_id:
            del self.cur_matches[master_object_id]
            self.update_selected_slave_object_ids()

    def clear_matches(self):
        """Remove all links of the current pair and refresh the slave candidates."""
        self.cur_matches = {}
        # Without this the selection would still show only the previously linked slave.
        self.update_selected_slave_object_ids()

    def render(self):
        """Draw both frames, their boxes and the current link, then show the canvas.

        The active frame gets a red border; all master objects and the slave
        candidates are green; the selected master/slave objects are red, or
        yellow when linked, in which case a yellow line joins their centers.
        """
        selected = (0, 0, 255)  # red
        regular = (0, 255, 0)   # green
        used = (0, 255, 255)    # yellow

        line_width = 2
        border_width = 5

        # Clear first: the thick active-frame border extends past the frame
        # images (into the gap between frames and any unused width), so it
        # would otherwise stay visible after switch_frame().
        self.canvas[:] = 0
        self.master_frame.draw(self.canvas)
        self.slave_frame.draw(self.canvas)

        active_frame = self.master_frame if self.is_master_active else self.slave_frame
        active_frame.borders().draw(self.canvas, selected, border_width)

        for obj in self.master_object_by_id.values():
            self.master_frame.transform(obj).box.draw(self.canvas, regular, line_width)

        master_object = None
        if self.master_object_indx is not None:
            master_object = self.master_object_by_id[self.master_object_ids[self.master_object_indx]]
            color = used if master_object.id in self.cur_matches else selected
            self.master_frame.transform(master_object).box.draw(self.canvas, color, line_width)

        for slave_object_id in self.selected_slave_object_ids:
            slave_object = self.slave_object_by_id[slave_object_id]
            self.slave_frame.transform(slave_object).box.draw(self.canvas, regular, line_width)

        # A selected slave candidate implies a selected master object.
        if self.selected_slave_object_indx is not None:
            slave_object = self.slave_object_by_id[
                self.selected_slave_object_ids[self.selected_slave_object_indx]]
            is_linked = (master_object.id in self.cur_matches
                         and self.cur_matches[master_object.id] == slave_object.id)
            color = used if is_linked else selected
            slave_box = self.slave_frame.transform(slave_object).box
            slave_box.draw(self.canvas, color, line_width)
            if is_linked:
                master_center = self.master_frame.transform(master_object).box.center()
                cv2.line(self.canvas, master_center, slave_box.center(), used, line_width)

        cv2.imshow('Label tool', self.canvas)

    def _go_to_pair(self, step):
        """Store the current pair's matches and load the pair ``step`` away (wrapping)."""
        new_pair_indx = shift(self.pair_indx, step, len(self.pairs))
        if new_pair_indx != self.pair_indx:
            self.dump_matches()
            self.pair_indx = new_pair_indx
            self.reset()

    def next_pair(self):
        """Move to the next pair."""
        self._go_to_pair(1)

    def prev_pair(self):
        """Move to the previous pair."""
        self._go_to_pair(-1)

    def exit(self):
        """Save results, close the OpenCV windows and terminate the program."""
        self.save_results()
        cv2.destroyAllWindows()
        # The bare builtin exit() is injected by the site module and is not
        # guaranteed to exist (e.g. ``python -S``); SystemExit always is.
        raise SystemExit


def load_matches(path):
    """Load previously saved matches (the format written by ``Labeler.save_results``).

    Args:
        path: JSON file with a ``features_pairs`` list; each entry has
            ``feature_id_1``, ``feature_id_2`` and a ``matches`` list of
            ``{'object_id_1', 'object_id_2'}`` dicts.

    Returns:
        ``{(feature_id_1, feature_id_2): {object_id_1: object_id_2}}``. A pair
        listed more than once keeps its last entry.
    """
    # ``with`` closes the file; the previous json.load(open(path)) leaked it.
    with open(path, encoding='utf-8') as f:
        data = json.load(f)

    matches = {}
    for item in data['features_pairs']:
        pair = (item['feature_id_1'], item['feature_id_2'])
        matches[pair] = {
            match['object_id_1']: match['object_id_2']
            for match in item['matches']
        }
    return matches


def load_pairs(dataset, path):
    """Read ``feature_id_1 feature_id_2`` integer pairs, one per line.

    Pairs whose images are not both in ``dataset`` are skipped, as are blank
    lines (e.g. a trailing empty line).

    Args:
        dataset: object with a ``has(feature_id)`` method.
        path: text file with two whitespace-separated integers per line.

    Returns:
        List of ``(feature_id_1, feature_id_2)`` tuples in file order.

    Raises:
        ValueError: on a non-blank line that is not exactly two integers.
    """
    pairs = []
    with open(path, encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            feature_id_1, feature_id_2 = map(int, fields)
            if dataset.has(feature_id_1) and dataset.has(feature_id_2):
                pairs.append((feature_id_1, feature_id_2))
    return pairs


def main():
    """Parse the command line, load the data and run the interactive labeling loop.

    Exits with an error message (rather than an assertion traceback) when
    none of the pairs from ``--pairs`` is present in the dataset.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--images', required=True)
    parser.add_argument('--objects', required=True)
    parser.add_argument('--matches', default=None)
    parser.add_argument('--pairs', required=True)
    parser.add_argument('--init_pair_indx', type=int, default=0)
    parser.add_argument('--output', required=True)
    args = parser.parse_args()

    print('Loading dataset...')
    dataset = Dataset(args.images, args.objects)
    print('Dataset size: {}'.format(dataset.size()))

    matches = {}
    if args.matches is not None:
        print('Loading matches...')
        matches = load_matches(args.matches)
        print('Loaded matches for {} pairs'.format(len(matches)))

    print('Loading pairs...')
    pairs = load_pairs(dataset, args.pairs)
    print('Loaded {} pairs'.format(len(pairs)))
    if not pairs:
        parser.exit(1, 'No pairs from {} are present in the dataset\n'.format(args.pairs))

    labeler = Labeler(dataset, matches, pairs, args.init_pair_indx, args.output)

    # Key codes as reported by cv2.waitKeyEx on Ubuntu.
    actions = {
        SHIFT:       labeler.switch_frame,   # Press 'Shift'
        RIGHT_ARROW: labeler.next_object,    # Press '->'
        LEFT_ARROW:  labeler.prev_object,    # Press '<-'
        65535:       labeler.clear_matches,  # Press 'Delete'
        32:          labeler.create_link,    # Press 'Space'
        8:           labeler.delete_link,    # Press 'Backspace'
        27:          labeler.exit,           # Press 'Esc'
        ord('s'):    labeler.save_results,   # Press 's'
        ord('n'):    labeler.next_pair,      # Press 'n'
        ord('p'):    labeler.prev_pair,      # Press 'p'
    }

    while True:
        labeler.render()
        action = actions.get(cv2.waitKeyEx())
        if action is not None:
            action()


if __name__ == '__main__':
    main()
