import numpy as np
import json
import argparse


class Cluster(object):
    """A group of objects joined together by pairwise matches.

    An object is the pair ``(feature_id, object_id)``; according to the
    original notes ``object_id`` is unique within one photo. Objects are
    stored in insertion order and matches refer to them by position.

    Most consistency checks in this class are ``assert`` statements, so they
    raise ``AssertionError`` and are skipped entirely under ``python -O``.
    """

    # Largest value one cell of the (uint8) connection matrix can hold.
    _MAX_LINK_COUNT = np.iinfo(np.uint8).max

    def __init__(self, feature_id1, object_id1, feature_id2, object_id2):
        """Create a cluster holding two objects and the match between them."""
        # Objects of the cluster, each stored as (feature_id, object_id).
        self.feature_objects = []

        # Maps (feature_id, object_id) to the position of that object in
        # self.feature_objects; that position is the object's index inside
        # the cluster.
        self.feature_objects_indx = {}

        # Matches used to build this cluster. Every match is a pair of
        # indexes into self.feature_objects.
        self.matches = []

        first = self._add_object(feature_id1, object_id1)
        second = self._add_object(feature_id2, object_id2)
        self.matches.append((first, second))

    def _add_object(self, feature_id, object_id):
        """Register an object and return its index inside the cluster."""
        key = (feature_id, object_id)
        idx = len(self.feature_objects)
        self.feature_objects.append(key)
        self.feature_objects_indx[key] = idx
        return idx

    def _get_connection_matrix(self):
        """Return a symmetric uint8 matrix counting the matches per object pair.

        Counts saturate at ``_MAX_LINK_COUNT`` instead of wrapping around, so
        a link that is repeated 256 times is still seen as "more than one"
        rather than as "absent".
        """
        objs_cnt = len(self.feature_objects)
        connect = np.zeros((objs_cnt, objs_cnt), np.uint8)
        for first, second in self.matches:
            # Count the match in both directions to keep the matrix symmetric.
            for row, col in ((first, second), (second, first)):
                if connect[row, col] < self._MAX_LINK_COUNT:
                    connect[row, col] += 1
        return connect

    @staticmethod
    def _unlinked_pairs(connect):
        """Return index pairs ``(i, j)`` with ``i < j`` that have no match."""
        objs_cnt = connect.shape[0]
        return [(idx1, idx2)
                for idx1 in range(objs_cnt - 1)
                for idx2 in range(idx1 + 1, objs_cnt)
                if connect[idx1, idx2] == 0]

    def add_match(self, feature_id1, object_id1, feature_id2, object_id2):
        """Record a match between an object of the cluster and another object.

        The first object must already belong to the cluster; the second one
        is added when it is not present yet.

        Raises:
            AssertionError: the first object is not in the cluster, or the
                same ``(first, second)`` index pair was already recorded.
                Only this orientation is checked; a reversed duplicate is
                accepted here and shows up in ``validate`` as a double match.
        """
        id1 = self.feature_objects_indx.get((feature_id1, object_id1))
        assert id1 is not None, \
            "Unable to add match, there is no object with feature_id = " + str(feature_id1) + " and object_id = " + str(object_id1) + " in cluster"
        id2 = self.feature_objects_indx.get((feature_id2, object_id2))
        if id2 is None:
            id2 = self._add_object(feature_id2, object_id2)
        assert not (id1, id2) in self.matches
        self.matches.append((id1, id2))

    def append(self, feature_id, object_id, other, feature_id_other, object_id_other):
        """Merge cluster ``other`` into this one through a connecting match.

        ``(feature_id, object_id)`` must belong to this cluster and
        ``(feature_id_other, object_id_other)`` to ``other``. All objects and
        matches of ``other`` are copied (with re-mapped indexes) and the
        connecting match is added last. When ``other`` is this very cluster
        the call just records the match.

        Raises:
            AssertionError: both objects share a ``feature_id``, or one of
                the membership preconditions above does not hold.
        """
        assert (feature_id != feature_id_other)

        if (self == other):
            self.add_match(feature_id, object_id, feature_id_other, object_id_other)
            return

        own_key = (feature_id, object_id)
        other_key = (feature_id_other, object_id_other)

        indx = self.feature_objects_indx.get(own_key)
        assert indx is not None, \
            "There are not object with feature_id = " + str(feature_id) + " and object_id = " + str(object_id) + " in main cluster"
        assert self.feature_objects_indx.get(other_key) is None, \
            "Object with feature_id = " + str(feature_id_other) + " and object_id = " + str(object_id_other) + " already in main cluster"

        assert other.feature_objects_indx.get(other_key) is not None, \
            "There is no object with feature_id = " + str(feature_id_other) + " and object_id = " + str(object_id_other) + " in second cluster"
        assert other.feature_objects_indx.get(own_key) is None, \
            "Object with feature_id = " + str(feature_id) + " and object_id = " + str(object_id) + " already in second cluster"

        # Translate indexes of ``other`` into indexes of this cluster,
        # registering objects that are not known here yet.
        map_indexes = {}
        for feature_object, foreign_idx in other.feature_objects_indx.items():
            local_idx = self.feature_objects_indx.get(feature_object)
            if local_idx is None:
                local_idx = self._add_object(feature_object[0], feature_object[1])
            map_indexes[foreign_idx] = local_idx
        for match in other.matches:
            mapped = (map_indexes[match[0]], map_indexes[match[1]])
            assert not mapped in self.matches
            self.matches.append(mapped)

        # Finally add the match that connects the two former clusters.
        indx_other = self.feature_objects_indx.get(other_key)
        assert not (indx, indx_other) in self.matches, \
            "Link from object with feature_id = " + str(feature_id) + " and object_id = " + str(object_id) + " to object with feature_id = " + \
            str(feature_id_other) + " and object_id = " + str(object_id_other) + " already exists in cluster"
        self.matches.append((indx, indx_other))

    def has_object(self, feature_id, object_id):
        """Return True when the object belongs to this cluster."""
        return (feature_id, object_id) in self.feature_objects_indx

    def add_absent_links(self, dry_run):
        """Find object pairs without a match and, unless ``dry_run``, link them.

        Returns:
            A list of ``[object1, object2]`` entries, one per missing link,
            where each object is a ``(feature_id, object_id)`` tuple.
        """
        missing = self._unlinked_pairs(self._get_connection_matrix())
        if not dry_run:
            self.matches.extend(missing)
        return [[self.feature_objects[idx1], self.feature_objects[idx2]]
                for idx1, idx2 in missing]

    def validate(self):
        """Check the cluster and return the problems that were found.

        Returns:
            A dict with three lists:
            ``'double_match'``  - object pairs linked more than once,
            ``'multi_objects'`` - ``{'feature_id', 'object_ids'}`` dicts for
                                  photos contributing several objects,
            ``'absent_links'``  - object pairs that are not linked at all.

        Raises:
            AssertionError: a match links an object to itself.
        """
        for match in self.matches:
            assert match[0] != match[1]
        objs_cnt = len(self.feature_objects)
        connect = self._get_connection_matrix()
        objects = self.feature_objects

        # Make sure there are no double links (the original note wonders
        # whether this should rather be rejected when the match is added).
        double_match = [(objects[idx1], objects[idx2])
                        for idx1 in range(objs_cnt - 1)
                        for idx2 in range(idx1 + 1, objs_cnt)
                        if connect[idx1, idx2] > 1]

        # Make sure the cluster has no two objects from the same photo.
        ids_by_feature = {}
        for feature_id, object_id in objects:
            ids_by_feature.setdefault(feature_id, []).append(object_id)
        multi_objects = [{'feature_id': feature_id, 'object_ids': object_ids}
                         for feature_id, object_ids in ids_by_feature.items()
                         if len(object_ids) > 1]

        # Make sure all objects of the cluster are linked with each other.
        absent_links = [(objects[idx1], objects[idx2])
                        for idx1, idx2 in self._unlinked_pairs(connect)]

        return {'double_match': double_match,
                'multi_objects': multi_objects,
                'absent_links': absent_links}


class Clusterizer(object):
    """Builds clusters of objects step by step from pairwise matches."""

    def __init__(self):
        # Clusters created so far.
        self.clusters = []

    def add_match(self, feature_id1, object_id1, feature_id2, object_id2):
        """Register a match between two objects, growing or merging clusters.

        Depending on which of the two objects are already known, the match
        starts a new cluster, extends an existing one, or merges the second
        object's cluster into the first object's cluster.
        """
        owner1 = None
        owner2 = None
        # Every cluster is inspected; the last one holding an object wins.
        for candidate in self.clusters:
            if candidate.has_object(feature_id1, object_id1):
                owner1 = candidate
            if candidate.has_object(feature_id2, object_id2):
                owner2 = candidate

        if owner1 is None and owner2 is None:
            # Neither object is known yet: start a new cluster.
            self.clusters += [Cluster(feature_id1, object_id1, feature_id2, object_id2)]
            return

        if owner1 is None:
            # Only the second object is known; it has to be passed first.
            owner2.add_match(feature_id2, object_id2, feature_id1, object_id1)
            return

        if owner2 is None or owner1 == owner2:
            owner1.add_match(feature_id1, object_id1, feature_id2, object_id2)
            return

        # The objects live in different clusters: merge the second into the first.
        owner1.append(feature_id1, object_id1, owner2, feature_id2, object_id2)
        self.clusters.remove(owner2)

    def add_absent_links(self, dry_run):
        """Collect the links each cluster reports as missing.

        Inside a cluster every object is expected to be linked with every
        other one. ``dry_run`` is forwarded to each cluster's
        ``add_absent_links``; the combined list of reported links is returned.
        """
        collected = []
        for cluster in self.clusters:
            collected.extend(cluster.add_absent_links(dry_run))
        return collected

    def validate(self):
        """Return the validation result of every cluster, in cluster order."""
        return [cluster.validate() for cluster in self.clusters]


def _load_pairs(matching_path):
    with open(matching_path, 'r', encoding='utf-8') as f:
        json_data = json.load(f)
    return json_data["features_pairs"]


def validate_matching(matching_path, add_absent_link=False, print_errors=True, print_errors_ext=False):
    """Cluster the matches stored in ``matching_path`` and print a validation report.

    Args:
        matching_path: path of the JSON file with the matches.
        add_absent_link: add the missing links to every cluster before validating.
        print_errors: print a one-line verdict per cluster.
        print_errors_ext: additionally list every problem of invalid clusters
            (implies the per-cluster verdict).

    The number of valid clusters is always printed at the end.
    """
    # Keys of a cluster's validation result with their short printable names.
    error_kinds = (('double_match', 'double match'),
                   ('multi_objects', 'multi objects'),
                   ('absent_links', 'absent links'))
    pair_line = "    (feature_id = {}, object_id = {}) to (feature_id = {}, object_id = {})"
    separator = "  " + "-" * 10

    def print_pairs(title, pairs):
        # Print one titled section listing pairs of objects.
        print(title)
        for first, second in pairs:
            print(pair_line.format(first[0], first[1], second[0], second[1]))
        print(separator)

    def report(per_cluster_errors, extended):
        for number, errors in enumerate(per_cluster_errors):
            found = [label for key, label in error_kinds if len(errors[key]) > 0]
            if not found:
                print("Cluster {} is valid".format(number))
                continue
            print("Cluster {} is invalid. Errors: {}".format(number, ', '.join(found)))
            if not extended:
                continue
            if len(errors['double_match']) > 0:
                print_pairs("  Double matches:", errors['double_match'])
            if len(errors['multi_objects']) > 0:
                print("  Multi objects:")
                for entry in errors['multi_objects']:
                    print("    feature_id = {}, objects = {}".format(entry['feature_id'], entry['object_ids']))
                print(separator)
            if len(errors['absent_links']) > 0:
                print_pairs("  Absent links:", errors['absent_links'])

    clusterizer = Clusterizer()
    for pair in _load_pairs(matching_path):
        first_id = pair['feature_id_1']
        second_id = pair['feature_id_2']
        for matched in pair['matches']:
            clusterizer.add_match(first_id, matched['object_id_1'], second_id, matched['object_id_2'])

    if add_absent_link:
        clusterizer.add_absent_links(False)
    all_errors = clusterizer.validate()
    if print_errors or print_errors_ext:
        report(all_errors, print_errors_ext)

    valid_clusters_cnt = sum(
        1 for errors in all_errors
        if len(errors['double_match']) + len(errors['multi_objects']) + len(errors['absent_links']) == 0
    )
    print()
    print("Valid clusters: ", valid_clusters_cnt, " from ", len(clusterizer.clusters))


def extract_clusters(matching_path, clusters_path, objects_path=None, add_absent_link=True, ignore_errors=False):
    """Cluster the matches from ``matching_path`` and save the clusters as JSON.

    Args:
        matching_path: path of the JSON file with the matches.
        clusters_path: path of the JSON file to write.
        objects_path: optional path of a JSON file listing the objects of
            every feature; listed objects that belong to no cluster are saved
            as single-object clusters.
        add_absent_link: add the missing links to every cluster first.
        ignore_errors: save the clusters even when validation finds problems.

    When validation is enabled and a cluster is invalid, a message is printed
    and nothing is written.

    Raises:
        KeyError: an object of a cluster is not listed in the objects file.
    """
    clusterizer = Clusterizer()
    for pair in _load_pairs(matching_path):
        fid1 = pair['feature_id_1']
        fid2 = pair['feature_id_2']
        for objects_pair in pair['matches']:
            clusterizer.add_match(fid1, objects_pair['object_id_1'], fid2, objects_pair['object_id_2'])

    if add_absent_link:
        clusterizer.add_absent_links(False)

    if not ignore_errors:
        for idx, errors in enumerate(clusterizer.validate()):
            if 0 < (len(errors['double_match']) + len(errors['multi_objects']) + len(errors['absent_links'])):
                print("Cluster " + str(idx) + " is invalid")
                return

    clusters = [
        {'cluster_id': cluster_id,
         'objects': [{'feature_id': feature_object[0], 'object_id': feature_object[1]}
                     for feature_object in item.feature_objects]}
        for cluster_id, item in enumerate(clusterizer.clusters)
    ]

    if objects_path is not None:
        with open(objects_path, 'r', encoding='utf-8') as f:
            json_data = json.load(f)

        # Remember the listed objects both as a set (for fast removal) and in
        # file order, so that the ids of single-object clusters do not depend
        # on set iteration order, which is arbitrary and, for string ids,
        # differs between runs.
        unclustered = set()
        listed_in_order = []
        for item in json_data['features_objects']:
            fid = item['feature_id']
            for obj in item['objects']:
                feature_object = (fid, obj['object_id'])
                if feature_object not in unclustered:
                    unclustered.add(feature_object)
                    listed_in_order.append(feature_object)

        for cluster in clusterizer.clusters:
            for cluster_fobj in cluster.feature_objects:
                # Raises KeyError when a clustered object is not listed.
                unclustered.remove(cluster_fobj)

        for feature_object in listed_in_order:
            if feature_object in unclustered:
                clusters.append({
                    'cluster_id': len(clusters),
                    'objects': [{'feature_id': feature_object[0], 'object_id': feature_object[1]}],
                })

    with open(clusters_path, 'w', encoding='utf-8') as f:
        json.dump({"clusters": clusters}, f, indent=4, ensure_ascii=False)


def main():
    """Parse the command line and run the requested validate/clusterize actions.

    At least one of ``--validate`` and ``--clusterize`` is required, and
    ``--clusterize`` needs ``--output-path``. Violations are reported through
    ``parser.error``, i.e. a usage message on stderr and exit status 2,
    instead of silently finishing with status 0. When both actions are
    requested, validation still runs before the output path is checked.
    """
    parser = argparse.ArgumentParser(description="Process matching pairs of objects")
    parser.add_argument('--input-path', '-i', required=True,
                        help='Path to input json file')
    parser.add_argument('--add-absent-links', action='store_true', default=False, dest='add_absent_links',
                        help='Add absent link in clusters before validation')
    # validate
    parser.add_argument('--validate', action='store_true', default=False, dest='validate',
                        help='Validate input file with matches')
    parser.add_argument('--print-error', action='store_true', default=False, dest='print_error',
                        help='Print validation errors of input file with matches')
    parser.add_argument('--print-error-ext', action='store_true', default=False, dest='print_error_ext',
                        help='Print extended information about validation errors of input file with matches')
    # clusterize
    parser.add_argument('--clusterize', action='store_true', default=False, dest='clusterize',
                        help='Extract cluster by matches from input file')
    parser.add_argument('--output-path', '-o', default=None,
                        help='Path to output json file')
    parser.add_argument('--objects-path', default=None,
                        help='Path to json file with feature_ids with objects list')
    parser.add_argument('--ignore-errors', action='store_true', default=False, dest='ignore_errors',
                        help='Ignore errors of validation and save clusters')

    args = parser.parse_args()
    if not (args.validate or args.clusterize):
        # Without an action the script used to do nothing and report success.
        parser.error("Nothing to do: specify --validate and/or --clusterize")
    if args.validate:
        validate_matching(args.input_path, args.add_absent_links, args.print_error, args.print_error_ext)
    if args.clusterize:
        if args.output_path is None:
            parser.error("You should define --output-path options for clusterize command")
        extract_clusters(args.input_path, args.output_path, args.objects_path, args.add_absent_links, args.ignore_errors)


if __name__ == '__main__':
    main()
