import argparse
import json
import copy

from assignments import get_corrections
from utils import load_assignments, load_clusters


def apply_assignments(clusters, assignments):
    """Merge clusters linked by accepted corrections and return the result.

    ``get_corrections(assignments)`` yields an ``(accepted, rejected)`` pair.
    Only ``accepted`` is used. Each accepted entry is unpacked as
    ``(master_id, slave_id)``, and the members of the cluster containing
    ``slave_id`` are appended to the cluster containing ``master_id``.
    Rejected corrections are ignored.

    Args:
        clusters: sequence of clusters, each an iterable of hashable
            feature-detection ids. The caller's clusters are NOT modified.
        assignments: raw assignments object, interpreted only through
            ``get_corrections``.

    Returns:
        A new list of non-empty clusters (each a list), ordered by the
        original position of the surviving (master) clusters.

    Raises:
        KeyError: if an accepted pair references an id that does not belong
            to any cluster.
    """
    accepted, _rejected = get_corrections(assignments)

    # Shallow-copy every cluster so merging below never mutates the caller's
    # data (extending a master list / blanking a slave slot in place).
    merged = [list(cluster) for cluster in clusters]

    # Index of the cluster that currently owns each feature-detection id.
    # If an id appears in several clusters, the last one wins.
    owner = {
        detection_id: index
        for index, cluster in enumerate(merged)
        for detection_id in cluster
    }

    for master_id, slave_id in accepted:
        master_index = owner[master_id]
        slave_index = owner[slave_id]
        if master_index == slave_index:
            continue  # already in the same cluster; nothing to merge
        moved = merged[slave_index]
        merged[master_index].extend(moved)
        # Re-point moved ids so later pairs find their new cluster.
        for detection_id in moved:
            owner[detection_id] = master_index
        # Leave an empty placeholder so existing indices stay valid.
        merged[slave_index] = []

    return [cluster for cluster in merged if cluster]


def save_clusters(clusters, path):
    """Serialize *clusters* to *path* as indented JSON.

    Each cluster is iterated as ``(feature_id, detection_id)`` pairs. The file
    contains ``{"clusters": [{"cluster_id": i, "objects": [...]}, ...]}``,
    where ``i`` is the cluster's position in *clusters* and every object is
    ``{"feature_id": ..., "object_id": ...}``.
    """
    payload = {
        'clusters': [
            {
                'cluster_id': position,
                'objects': [
                    {'feature_id': feat, 'object_id': det}
                    for feat, det in members
                ],
            }
            for position, members in enumerate(clusters)
        ]
    }

    with open(path, 'w') as out:
        json.dump(payload, out, indent=2)


def main():
    """Command-line entry point: merge clusters and write the result.

    All three flags are required:
        --clusters: input path handed to ``load_clusters``.
        --assignments: input path handed to ``load_assignments``.
        --new_clusters: output path handed to ``save_clusters``.
    """
    parser = argparse.ArgumentParser()
    for option in ('--clusters', '--assignments', '--new_clusters'):
        parser.add_argument(option, required=True)
    args = parser.parse_args()

    clusters = load_clusters(args.clusters)
    print('Loaded {} clusters from {}'.format(len(clusters), args.clusters))

    assignments = load_assignments(args.assignments)
    print('Loaded {} assignments from {}'.format(
        len(assignments), args.assignments))

    merged = apply_assignments(clusters, assignments)
    print('Generated {} new clusters'.format(len(merged)))

    save_clusters(merged, args.new_clusters)


if __name__ == '__main__':
    main()
