import argparse
import json
import itertools


import yt.wrapper as yt

from maps.wikimap.mapspro.services.mrc.eye.experiments.signs_map.pylibs.feature import Feature
from maps.wikimap.mapspro.services.mrc.eye.experiments.signs_map.pylibs.detection import load_detections
from maps.wikimap.mapspro.services.mrc.eye.experiments.signs_map.pylibs.util import (
    distance,
    heading_abs_diff,
)


def intersects(first_detections, second_detections, min_box_size):
    """Tell whether two detection collections share a detection type.

    Only detections whose ``box.max_size()`` is at least ``min_box_size`` take
    part in the comparison. Returns True as soon as a qualifying detection from
    ``first_detections`` has the same ``type`` as a qualifying detection from
    ``second_detections``; False otherwise (including when either side is empty).
    """

    def _large_enough(detection):
        return detection.box.max_size() >= min_box_size

    for left in first_detections:
        if not _large_enough(left):
            continue
        # second_detections is re-scanned for every qualifying left detection.
        for right in second_detections:
            if _large_enough(right) and left.type == right.type:
                return True
    return False


def filter_pairs(
    features, detections_by_feature_id, max_distance_meters, max_heading_diff_degrees, min_box_size, ignore_detections
):
    """Select unordered feature pairs that pass the proximity and detection filters.

    Every pair from ``itertools.combinations(features, 2)`` is kept when:
      * ``distance(first.pos, second.pos)`` is at most ``max_distance_meters``;
      * ``heading_abs_diff(first.heading, second.heading)`` is at most
        ``max_heading_diff_degrees``;
      * unless ``ignore_detections`` is true, ``intersects`` reports a common
        detection type among detections of at least ``min_box_size``.

    Args:
        features: iterable of features exposing ``id``, ``pos`` and ``heading``.
        detections_by_feature_id: mapping from feature id to that feature's
            detections. Not accessed (may be None) when ``ignore_detections``
            is true. A feature id missing from the mapping is treated as having
            no detections, so such a feature never passes the detection filter.
        max_distance_meters: maximum allowed distance between the two features.
        max_heading_diff_degrees: maximum allowed absolute heading difference.
        min_box_size: minimum ``box.max_size()`` for a detection to be compared.
        ignore_detections: if true, skip the detection filter entirely.

    Returns:
        list of ``(first, second)`` tuples in ``itertools.combinations`` order.
    """

    def _detections_of(feature):
        # Features without an entry previously aborted the run with KeyError;
        # an absent entry now simply means "nothing to match against".
        try:
            return detections_by_feature_id[feature.id]
        except KeyError:
            return ()

    pairs = []
    for first, second in itertools.combinations(features, 2):
        if distance(first.pos, second.pos) > max_distance_meters:
            continue

        if heading_abs_diff(first.heading, second.heading) > max_heading_diff_degrees:
            continue

        if not ignore_detections and not intersects(_detections_of(first), _detections_of(second), min_box_size):
            continue

        pairs.append((first, second))

    return pairs


def write_pairs(pair_file, pairs):
    """Dump ``pairs`` to ``pair_file`` as indented JSON.

    The written object has two keys:
      * ``'pairs'`` -- list of ``[first.id, second.id]`` entries, one per pair,
        in input order;
      * ``'features'`` -- ``{feature.id: feature.to_dict()}`` for every feature
        that appears in some pair. ``to_dict()`` is called once per id, at the
        first occurrence of that id.

    The whole payload is built before the file is opened.
    """
    id_pairs = []
    features_by_id = {}

    for first, second in pairs:
        id_pairs.append((first.id, second.id))
        for feature in (first, second):
            if feature.id in features_by_id:
                continue
            features_by_id[feature.id] = feature.to_dict()

    data = {'pairs': id_pairs, 'features': features_by_id}

    with open(pair_file, 'w') as out:
        json.dump(data, out, indent=2)


def get_args(argv=None):
    """Parse command-line options for pair candidate generation.

    Args:
        argv: list of argument strings to parse. None (the default) makes
            argparse read ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        argparse.Namespace with ``yt_proxy``, ``feature_table``, ``object_file``,
        ``distance``, ``heading_diff``, ``min_box_size``, ``pair_file`` and
        ``ignore_detections``.

    ``--object-file`` is only read when detections are used, so it may be
    omitted together with ``--ignore-detections``. Without that flag, a missing
    ``--object-file`` is a usage error (argparse exits with status 2).
    """
    parser = argparse.ArgumentParser(description='Generate pair candidates')

    parser.add_argument('--yt-proxy', dest='yt_proxy', default='hahn', help='name of yt proxy')

    parser.add_argument('--feature-table', dest='feature_table', type=str, required=True, help="input feature table")

    parser.add_argument(
        '--object-file',
        dest='object_file',
        type=str,
        default=None,
        help='input yt json file with detections (required unless --ignore-detections is given)',
    )

    parser.add_argument('--distance', type=float, required=True, help='max distance between features in meters')

    parser.add_argument(
        '--heading-diff', dest='heading_diff', type=float, required=True, help='max heading diff of features in degrees'
    )

    parser.add_argument(
        '--min-box-size', dest='min_box_size', type=int, default=0, help='min matching box size in pixels'
    )

    parser.add_argument('--pair-file', dest='pair_file', type=str, required=True, help='output pair file')

    parser.add_argument(
        '--ignore-detections',
        dest='ignore_detections',
        action='store_true',
        help='do not consider detections on photos',
    )

    args = parser.parse_args(argv)

    # Detections are loaded from --object-file only when they are not ignored.
    if args.object_file is None and not args.ignore_detections:
        parser.error('--object-file is required unless --ignore-detections is set')

    return args


def main():
    """Script entry point: read features (and detections) from YT, write pair candidates.

    Steps, in order:
      1. parse command-line options;
      2. read every row of ``--feature-table`` into a ``Feature``, indexed by id
         (a later row with the same id replaces the earlier one);
      3. unless ``--ignore-detections`` is set, load detections from
         ``--object-file`` via ``load_detections``;
      4. filter candidate pairs with ``filter_pairs`` and write them to
         ``--pair-file`` with ``write_pairs``.
    """
    args = get_args()

    client = yt.YtClient(proxy=args.yt_proxy)

    features = (Feature(row) for row in client.read_table(args.feature_table))
    feature_by_id = {feature.id: feature for feature in features}

    if args.ignore_detections:
        detections_by_feature_id = None
    else:
        object_stream = client.read_file(args.object_file)
        detections_by_feature_id = load_detections(object_stream, feature_by_id)

    candidate_pairs = filter_pairs(
        feature_by_id.values(),
        detections_by_feature_id,
        max_distance_meters=args.distance,
        max_heading_diff_degrees=args.heading_diff,
        min_box_size=args.min_box_size,
        ignore_detections=args.ignore_detections,
    )

    write_pairs(args.pair_file, candidate_pairs)


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
