#!/usr/bin/python3
# -*- coding: utf-8 -*-

from warnings import simplefilter

simplefilter(action='ignore', category=FutureWarning)

import argparse
import json


import yt.wrapper as yt

from maps.wikimap.mapspro.services.mrc.eye.experiments.signs_map.pylibs.feature import Feature, MdsLoader
from maps.wikimap.mapspro.services.mrc.eye.experiments.signs_map.pylibs.detection import load_detections
from maps.wikimap.mapspro.services.mrc.eye.experiments.signs_map.pylibs.util import mask_from_png_base64
import maps.wikimap.mapspro.services.mrc.eye.experiments.signs_map.matcher.pylibs.matcher_consts as mconst
from maps.wikimap.mapspro.services.mrc.eye.experiments.signs_map.matcher.pylibs.descriptor_algorithm import options as descriptor_options
from maps.wikimap.mapspro.services.mrc.eye.experiments.signs_map.matcher.pylibs.match_algorithm import options as match_options, MatchContext


def write_pairs(yt_client, pair_table, pairs):
    """Write one row per feature pair into the YT table ``pair_table``.

    ``pairs`` is a mapping with two entries:

    * ``'pairs'`` -- iterable of two-element ``(first_id, second_id)`` items;
    * ``'features'`` -- mapping from the *stringified* id to the feature
      payload.

    Each produced row has the shape ``{'first': <payload>, 'second': <payload>}``.
    Rows are generated lazily, so lookups in ``pairs`` happen while
    ``yt_client.write_table`` consumes the stream.
    """
    def rows():
        for id_pair in pairs['pairs']:
            left_id, right_id = id_pair
            payload_by_id = pairs['features']
            # Ids are converted with str() because the feature mapping is
            # indexed by string keys.
            yield {
                'first': payload_by_id[str(left_id)],
                'second': payload_by_id[str(right_id)],
            }

    yt_client.write_table(pair_table, rows())


def get_args():
    """Parse the command-line arguments of the matching script.

    Returns:
        argparse.Namespace with the attributes ``mds_host``, ``yt_proxy``,
        ``porto_layer``, ``pair_file``, ``mask_table``, ``object_file``,
        ``match_table``, ``gpu``, ``descriptor_type`` and ``match_type``.

    ``--descriptor-type`` and ``--match-type`` are mandatory: their values are
    used as keys into ``descriptor_options`` / ``match_options``, and there is
    no default to fall back on. Requiring them here makes a missing option
    fail at argument parsing instead of later, inside the map job.
    """
    parser = argparse.ArgumentParser(description='Make matches')

    parser.add_argument(
        '--mds-host', dest='mds_host', default='storage-int.mds.yandex.net', help='mds host to load images'
    )

    parser.add_argument('--yt-proxy', dest='yt_proxy', default='hahn', help='name of yt proxy')

    parser.add_argument(
        '--porto-layer',
        dest='porto_layer',
        default=mconst.YT_JOB_PORTO_LAYER,
        help='yt path to file with porto container',
    )

    parser.add_argument('--pair-file', dest='pair_file', type=str, required=True, help='input file with feature pairs')

    parser.add_argument('--mask-table', dest='mask_table', type=str, required=True, help='input table with masks')

    parser.add_argument(
        '--object-file', dest='object_file', type=str, required=True, help='input yt json file with detections'
    )

    parser.add_argument('--match-table', dest='match_table', type=str, required=True, help='output table with result')

    parser.add_argument(
        '--gpu',
        dest='gpu',
        type=str,
        choices=['none', 'gpu_geforce_1080ti', 'gpu_tesla_v100', 'gpu_tesla_a100'],
        default='none',
        help='gpu name',
    )

    parser.add_argument(
        '--descriptor-type',
        dest='descriptor_type',
        type=str,
        choices=list(descriptor_options),
        required=True,
        help='descriptors type',
    )

    parser.add_argument(
        '--match-type',
        dest='match_type',
        type=str,
        choices=list(match_options),
        required=True,
        help='matching type',
    )

    return parser.parse_args()


def basename(path):
    """Return the part of ``path`` after its last ``'/'``.

    A path without any slash is returned unchanged; a path ending in a slash
    yields an empty string.
    """
    pieces = path.rsplit('/', 1)
    return pieces[-1]


def load_feature_pairs(feature_file):
    """Read feature pairs from an iterable of JSON lines.

    Every line must decode to an object with ``'first'`` and ``'second'``
    entries; each entry is wrapped into a ``Feature``.

    Returns:
        list of ``(Feature, Feature)`` tuples in input order.
    """
    # Lazy decoding keeps the per-line order: parse a line, build both
    # features, then move on to the next line.
    decoded = (json.loads(raw_line) for raw_line in feature_file)
    return [(Feature(record['first']), Feature(record['second'])) for record in decoded]


def load_masks(mask_file, feature_ids):
    """Decode masks for the requested features from an iterable of JSON lines.

    Args:
        mask_file: iterable of JSON-encoded rows; each row carries the feature
            id under ``mconst.FEATURE_ID`` and the mask under
            ``mconst.MASK_PNG_BASE64``.
        feature_ids: container of feature ids to keep; rows for other ids are
            skipped without decoding their mask.

    Returns:
        dict mapping feature id to the value returned by
        ``mask_from_png_base64``. If an id occurs in several rows, the last
        row wins.
    """
    masks = {}

    for raw_line in mask_file:
        record = json.loads(raw_line)
        key = record[mconst.FEATURE_ID]

        if key in feature_ids:
            masks[key] = mask_from_png_base64(record[mconst.MASK_PNG_BASE64])

    return masks


def make_match_context(feature, descriptors_by_feature_id, detections_by_feature_id):
    """Assemble the ``MatchContext`` for a single feature.

    Args:
        feature: object exposing an ``id`` attribute.
        descriptors_by_feature_id: mapping from feature id to a
            ``(points, descriptors, confidence)`` triple.
        detections_by_feature_id: mapping from feature id to its detections.

    Raises:
        KeyError: if ``feature.id`` is missing from either mapping.
    """
    keypoints, vectors, score = descriptors_by_feature_id[feature.id]
    return MatchContext(
        feature,
        keypoints,
        vectors,
        score,
        detections_by_feature_id[feature.id],
    )


class MatcherMapper(object):
    """Mapper that buffers feature pairs and matches them all in ``finish``.

    ``__call__`` only records the incoming pair and emits nothing. The actual
    work (loading masks and detections, computing descriptors once per
    distinct feature, matching every pair) happens in ``finish``.
    NOTE(review): presumably the YT job runner calls ``finish`` after the last
    input row -- confirm against the yt.wrapper documentation.
    """

    def __init__(self, mds_host, descriptor_type, match_type, mask_filename, object_filename):
        """Store the configuration; algorithm lookup is deferred.

        Args:
            mds_host: host passed to ``MdsLoader`` for image loading.
            descriptor_type: key into ``descriptor_options``.
            match_type: key into ``match_options``.
            mask_filename: local file with JSON rows of masks.
            object_filename: local file with detections.
        """
        self.mds_loader = MdsLoader(mds_host)
        self.descriptor_type, self.match_type = descriptor_type, match_type
        self.mask_filename, self.object_filename = mask_filename, object_filename
        self.feature_pairs = []
        self.__initialized = False

    def __initialize_if_needed(self):
        """Resolve the descriptor and match functions on first use."""
        if not self.__initialized:
            self.descriptor_func = descriptor_options[self.descriptor_type]
            self.match_func = match_options[self.match_type]
            self.__initialized = True

    def __collect_feature_by_id(self):
        """Return a dict of every buffered feature keyed by its ``id``."""
        unique_features = {}

        for pair in self.feature_pairs:
            for feature in pair:
                unique_features[feature.id] = feature

        return unique_features

    def __call__(self, row):
        """Buffer the pair described by ``row['first']`` / ``row['second']``."""
        self.feature_pairs.append((Feature(row['first']), Feature(row['second'])))

    def finish(self):
        """Match all buffered pairs and yield one row per pair with matches.

        Yields:
            dict with ``feature_id_1``, ``feature_id_2`` and ``matches`` -- a
            list of ``{'object_id_1', 'object_id_2', 'confidence'}`` dicts.
            Pairs whose match result is empty are skipped.

        Raises:
            KeyError: if a buffered feature has no mask in the mask file.
        """
        self.__initialize_if_needed()
        features = self.__collect_feature_by_id()

        with open(self.mask_filename) as mask_lines:
            masks = load_masks(mask_lines, set(features))

        with open(self.object_filename) as object_lines:
            detections = load_detections(object_lines, features)

        # Descriptors are computed once per distinct feature, even when the
        # feature takes part in several pairs.
        computed = {}
        for feature_id, feature in features.items():
            mask = masks[feature_id]
            image = self.mds_loader(feature)
            points, descriptors, confidence = self.descriptor_func(image, mask)
            computed[feature_id] = (points, descriptors, confidence)

        for first, second in self.feature_pairs:
            first_context = make_match_context(first, computed, detections)
            second_context = make_match_context(second, computed, detections)
            matches = self.match_func(first_context, second_context)

            if len(matches) > 0:
                matched_objects = [
                    {'object_id_1': object_id_1, 'object_id_2': object_id_2, 'confidence': match_confidence}
                    for object_id_1, object_id_2, match_confidence in matches
                ]
                yield {
                    'feature_id_1': first.id,
                    'feature_id_2': second.id,
                    'matches': matched_objects,
                }


def main():
    """Upload the feature pairs to YT and run the matching map operation.

    Steps:
      1. Parse arguments and make sure the output table exists.
      2. Load the local JSON pair file and write it to a temporary YT table.
      3. Run ``MatcherMapper`` over that table, attaching the mask table and
         the object file to every job, and write results to the match table.
    """
    args = get_args()

    yt_client = yt.YtClient(proxy=args.yt_proxy)
    yt_client.create('table', args.match_table, recursive=True, ignore_existing=True)

    # Upload the pair file to YT. The context manager closes the local file
    # as soon as it has been parsed.
    with open(args.pair_file) as pair_file:
        pairs = json.load(pair_file)
    pair_table = yt_client.create_temp_table()
    write_pairs(yt_client, pair_table, pairs)

    size = yt_client.row_count(pair_table)

    # The mapper opens the attached files by the last component of their
    # YT paths.
    mask_filename = basename(args.mask_table)
    object_filename = basename(args.object_file)

    mapper = MatcherMapper(args.mds_host, args.descriptor_type, args.match_type, mask_filename, object_filename)

    spec = {
        'title': "Matches",
        # One job per 200 pairs, clamped to the range [1, 200].
        'job_count': min(200, max(1, size // 200)),
        'mapper': {
            "memory_limit": 6 * (1024 ** 3),
            'layer_paths': [args.porto_layer],
        },
    }

    if args.gpu != 'none':
        # Run in the pool tree named after the GPU model and request one GPU
        # per job.
        spec['pool_trees'] = [args.gpu]
        spec['scheduling_options_per_pool_tree'] = {args.gpu: {'pool': 'research_gpu'}}
        spec['mapper']['gpu_limit'] = 1

    yt_client.run_map(
        mapper,
        pair_table,
        args.match_table,
        yt_files=['<format=json>' + args.mask_table, args.object_file],
        spec=spec,
    )


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
