import argparse
import json
import random

import yt.wrapper as yt
from feature import Feature, Size, Orientation
from detection import Box, Detection
from utils import load_features, load_detections, load_clusters
from task import Task

def generate_tasks(feature_by_id, detection_by_feature_detection_id, clusters, pairs_count, not_pairs_count):
    """Build labelled comparison tasks from clustered detections.

    "yes" pairs are two distinct members of one cluster.  "no" pairs take one
    member from each of two different clusters whose first detections share
    the same ``sign_type``.  Every pair is drawn independently, so the same
    pair may occur more than once.  All "yes" tasks come first, then all "no"
    tasks.  Randomness comes from the module-level ``random`` state; seed it
    beforehand for reproducible output.

    Args:
        feature_by_id: mapping from feature id to the feature object handed
            to ``Task``.
        detection_by_feature_detection_id: mapping from a
            ``(feature_id, detection_id)`` key to a detection object that
            exposes ``sign_type``.
        clusters: sequence of clusters.  Each cluster is a sequence of
            ``(feature_id, detection_id)`` keys.
        pairs_count: number of "yes" pairs to generate.
        not_pairs_count: number of "no" pairs to generate.

    Returns:
        A list of ``Task(...).to_json()`` results, one per generated pair.

    Raises:
        ValueError: if "yes" pairs are requested but no cluster has at least
            2 members, or if "no" pairs are requested but no sign type spans
            at least 2 clusters.
    """
    pairs = []

    # --- "yes" pairs: two distinct members of the same cluster ---
    multi_member_clusters = [cluster for cluster in clusters if len(cluster) > 1]
    if pairs_count > 0 and not multi_member_clusters:
        raise ValueError('Cannot generate "yes" pairs: no cluster has at least 2 detections')
    for _ in range(pairs_count):
        cluster = random.choice(multi_member_clusters)
        feature_detection_id1, feature_detection_id2 = random.sample(cluster, k=2)
        pairs.append((feature_detection_id1, feature_detection_id2, 'yes'))

    # --- "no" pairs: members of two different clusters with the same sign type ---
    # A cluster's sign type is taken from its first detection only.
    # NOTE(review): assumes all members of a cluster share that sign type — confirm.
    cluster_indices_by_type = {}
    for index, cluster in enumerate(clusters):
        if not cluster:
            # An empty cluster has no first detection and cannot contribute a member.
            continue
        sign_type = detection_by_feature_detection_id[cluster[0]].sign_type
        cluster_indices_by_type.setdefault(sign_type, []).append(index)

    # Only sign types spanning >= 2 clusters can yield a cross-cluster pair.
    # A list (not a dict keys view) is required: random.sample/choice need a
    # real sequence, and keys views are rejected by random.sample on Python 3.11+.
    eligible_sign_types = [
        sign_type
        for sign_type, indices in cluster_indices_by_type.items()
        if len(indices) >= 2
    ]
    if not_pairs_count > 0 and not eligible_sign_types:
        raise ValueError('Cannot generate "no" pairs: no sign type has at least 2 clusters')
    for _ in range(not_pairs_count):
        sign_type = random.choice(eligible_sign_types)
        cluster_index1, cluster_index2 = random.sample(cluster_indices_by_type[sign_type], k=2)
        feature_detection_id1 = random.choice(clusters[cluster_index1])
        feature_detection_id2 = random.choice(clusters[cluster_index2])
        pairs.append((feature_detection_id1, feature_detection_id2, 'no'))

    # --- materialize tasks ---
    tasks = []
    for feature_detection_id1, feature_detection_id2, status in pairs:
        # Keys are (feature_id, detection_id); only the feature id is needed here.
        feature_id1, _ = feature_detection_id1
        detection1 = detection_by_feature_detection_id[feature_detection_id1]
        feature1 = feature_by_id[feature_id1]

        feature_id2, _ = feature_detection_id2
        detection2 = detection_by_feature_detection_id[feature_detection_id2]
        feature2 = feature_by_id[feature_id2]

        tasks.append(Task(detection1, feature1, detection2, feature2, status).to_json())

    return tasks

def _build_arg_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--yt-proxy', dest='yt_proxy', default='hahn', help='name of yt proxy')
    parser.add_argument('--feature-table', dest='feature_table', required=True)
    parser.add_argument('--detections', required=True)
    parser.add_argument('--clusters', required=True)
    parser.add_argument('--pairs-count', dest='pairs_count', type=int, required=True)
    parser.add_argument('--not-pairs-count', dest='not_pairs_count', type=int, required=True)
    parser.add_argument('--tasks', required=True)
    parser.add_argument('--seed', type=int, default=42)
    return parser


def main():
    """Script entry point: load inputs, generate tasks, and write them as JSON.

    Steps:
    1. Seed ``random`` with ``--seed``.
    2. Load features from the YT table ``--feature-table`` through a client
       for ``--yt-proxy``.
    3. Load detections and clusters from the paths given on the command line.
    4. Write the tasks from ``generate_tasks`` to ``--tasks`` as JSON
       indented by 2 spaces.
    """
    options = _build_arg_parser().parse_args()

    # Seed first so that task sampling is reproducible for a given --seed.
    random.seed(options.seed)
    client = yt.YtClient(proxy=options.yt_proxy)

    features = load_features(client, options.feature_table)
    print('Loaded {} features from {}'.format(len(features), options.feature_table))

    detections = load_detections(options.detections)
    print('Loaded {} detections from {}'.format(len(detections), options.detections))

    cluster_list = load_clusters(options.clusters)
    print('Loaded {} clusters from {}'.format(len(cluster_list), options.clusters))

    generated = generate_tasks(
        features,
        detections,
        cluster_list,
        options.pairs_count,
        options.not_pairs_count,
    )
    print('Generated {} tasks'.format(len(generated)))

    with open(options.tasks, 'w') as out_file:
        json.dump(generated, out_file, indent=2)


if __name__ == '__main__':
    main()
