# -*- coding: utf-8 -*-
"""
=======================================
Tool for evaluate semantic segmentation
=======================================

Usage: python semseg_compare.py --label_map <path> --list <path> --gt_folder <path> --test_folder <path>
    - label_map - txt file with class names and labels. Format of one line: <class name> <label>
    - mode      - type of data for compare:
                    1) list   - compare all files in list
                    2) folder - compare all files in folders
                    3) file   - compare two files
    - list      - path tot txt file with list of files for compare (for mode - 'list')
    - gt        - path to one groundtruth annotation (mode 'file') or to folder with groundtruth annotations
    - test      - path to one annotation for compare (mode 'file') or to folder with annotations for compare
"""
from __future__ import print_function
import os
import re
import argparse
import numpy as np
import cv2 as cv
from string import Template
from collections import namedtuple


CompareItem = namedtuple("CompareItem", ["name", "gt_path", "test_path"])


class MetricsTable:
    """Класс для вывода результатов сравнения.
        Результат выводится в таблицу, которая состоит из столбцов двух типов:
            1) Имя файла, для которого произоводилось сравнение
            2) Метрики качества, посчитанные для данного файла
    """
    def __init__(self, compare_items):
        """Констуртор класса.

        Args:
             compare_items (list): Список сравниваемых файлов (CompareItem). Используется для получения ширины
                столбца 'filename'
        """
        # is_empty равен True, если еще ни одной строчки не было выведено
        self.is_empty = True
        self.max_name_length = max([len(item.name) for item in compare_items])

    def add_row(self, filename, metrics):
        """Выводит строку таблицы в консоль.
            Если выводимая строка первая, то выводит также заголовок таблицы и создает шаблон для всех строк в таблице.

        Args:
            filename (str): Имя сравниваемого файла
            metrics (dict): Словарь с посчитанными метриками качества (см readme.md)
        """
        if self.is_empty:
            self.is_empty = False
            metrics_list = list(sorted(metrics.keys()))
            self.header = self.create_header(metrics_list)
            self.template = self.create_row_template(metrics_list)
            print("-" * len(self.header))
            print(self.header)
            print("-" * len(self.header))
        if filename != "Total":
            self.print_row(filename, metrics)
        else:
            self.print_total(metrics)

    def create_header(self, metrics):
        """Создает заголовок таблицы с колонками двух типов:
            1) Имя файла 'filename'
            2) По одному столбцу на каждую метрику из metrics

        Args:
            metrics (list): Список имен метрик качества

        Return:
            str: Заголовочная строка
        """
        header = "| {column:^{width}} |".format(column="filename", width=self.max_name_length)
        for metric in metrics:
            header += " {column:^{width}} |".format(column=metric, width=len(metric))
        return header

    def create_row_template(self, metrics):
        """Создает шаблон для одной строки таблицы.

        Args:
             metrics (list): Список имен метрик качества

        Return:
            str: Шаблон для строки в таблице
        """
        template = Template("| {filename:^$width} |").substitute({"width": self.max_name_length})
        for metric in metrics:
            template += Template(" {$metric:^$width.2f} |").substitute({"metric": metric, "width": len(metric)})
        return template

    def print_row(self, filename, metrics):
        """Выводит информацию о результатх сравнения согласно заданному шаблону

        Args:
            filename (str): Имя сравниваемого файла
            metrics (dist): Словарь с метриками качества
        """
        print(self.template.format(filename=filename, **metrics))

    def print_total(self, metrics):
        """Выводит результат сравнения по всему набору.
            Строчка с общим результатом отделяется строкой из '-'.

        Args:
            metrics (dist): Словарь с метриками качества
        """
        print("-" * len(self.header))
        print(self.template.format(filename="Total", **metrics))

    def finalize(self):
        """Заканчивает таблицу строкой из '-'"""
        print("-" * len(self.header))


class ClassMap:
    """Класс для получения индекса или метки класса по его имени и наоборот."""
    def __init__(self, path):
        """Констуктор класса.

        Args:
            path (str): Путь до тестового файла с именами и метками классов.
                Файл должен иметь следующий формат:
                    <class name 1> <label 1>
                    <class name 2> <label 2>
                    ...
        """
        self.class_to_label = {}
        self.class_to_index = {}
        self.label_to_class = {}
        self.index_to_class = {}

        class_file = open(path, "r")
        for index, line in enumerate(class_file):
            items = line.strip().split()
            class_name = items[0]
            label = int(items[1])
            # Fill dict
            self.class_to_label[class_name] = label
            self.class_to_index[class_name] = index
            self.label_to_class[label] = class_name
            self.index_to_class[index] = class_name

    def get_label(self, class_name):
        return self.class_to_label[class_name]

    def get_index(self, class_name):
        return self.class_to_index[class_name]

    def get_name_by_label(self, label):
        return self.label_to_class[label]

    def get_name_by_index(self, index):
        return self.index_to_class[index]

    def get_index_by_label(self, label):
        class_name = self.label_to_class[label]
        return self.class_to_index[class_name]

    def get_label_by_index(self, index):
        class_name = self.index_to_class[index]
        return self.class_to_label[class_name]

    def classes_cnt(self):
        return len(self.class_to_label)

    def get_classes(self):
        return self.class_to_label.keys()


def calc_metrics(confusion_matrix, classes_map):
    """По матрице ошибок (confusion matrix, https://en.wikipedia.org/wiki/Confusion_matrix) расчитывает следующие
        метрики качества (см readme.md):
            1) total_px_acc      - общее попиксельное качество классификации
            2) label_px_acc      - попиксельное качество классификации для каждого класса из classes
            3) mean_px_acc       - усредненное по всем классам попиксельное качество классификации
            4) mean_iou          - усредненная по всем классам метрика intersection over union
                                    (IoU, https://en.wikipedia.org/wiki/Jaccard_index)
            5) freq_weighted_iou - взвешенная по классам метрика intersection over union

    Args:
        confusion_matrix (numpy.array): Матрица ошибок
        classes_map (ClassMap): Объект для получения имени класса по его индексу

    Returns:
        dict: Словарь со всеми посчитанными метриками качества
    """
    epsilon = 0.000001
    # Count pixels of each class
    pixel_cnt = np.sum(confusion_matrix, axis=1)

    metrics = {}
    # Total pixel accuracy
    metrics["total_px_acc"] = float(np.trace(confusion_matrix)) / float(np.sum(pixel_cnt)) * 100.
    # Pixel accuracy per class
    for name in classes_map.get_classes():
        index = classes_map.get_index(name)
        if pixel_cnt[index] == 0:
            metrics["{}_px_acc".format(name)] = 100.
        else:
            metrics["{}_px_acc".format(name)] = confusion_matrix[index, index] / float(pixel_cnt[index]) * 100.
    # Mean pixel accuracy
    metrics["mean_px_acc"] = 0
    for name in classes_map.get_classes():
        metrics["mean_px_acc"] += metrics["{}_px_acc".format(name)]
    metrics["mean_px_acc"] /= float(classes_map.classes_cnt())
    # Mean IoU
    intersection = np.diagonal(confusion_matrix).astype(np.float32) + epsilon
    union = pixel_cnt + np.sum(confusion_matrix, axis=0) - np.diagonal(confusion_matrix) + epsilon
    metrics["mean_iou"] = np.sum(intersection / union) / classes_map.classes_cnt() * 100.
    # Freq weighted IoU
    intersection = np.multiply(np.diagonal(confusion_matrix).astype(np.float32), pixel_cnt) + epsilon
    union = pixel_cnt + np.sum(confusion_matrix, axis=0) - np.diagonal(confusion_matrix) + epsilon
    metrics["freq_weighted_iou"] = np.sum(intersection / union) / np.sum(pixel_cnt).astype(np.float32) * 100.
    return metrics


def compare_ann(gt_ann, test_ann, classes_map):
    """Производит сравнение двух изображений.
        Создает матрицу ошибок (confusion matrix, https://en.wikipedia.org/wiki/Confusion_matrix) размера
        [classes_cnt, classes_cnt], по которой потом расчитываются различные метрики качества.

    Args:
        gt_ann (numpy.array): Эталонная разметка
        test_ann (numpy.array): Результат сегментации изображения
        classes_map (ClassMap): Объект для получения индекса класса по его метке

    Returns:
        numpy.array: Матрица ошибок (confusion matrix)
    """
    if (gt_ann.size != test_ann.size):
        raise ValueError("ground truth and test images have different amount of pixels")
    confusion_matrix = np.zeros([classes_map.classes_cnt(), classes_map.classes_cnt()], np.int64)

    test_ann_scale = 256
    joined_ann = gt_ann.reshape(gt_ann.size) + test_ann_scale * test_ann.reshape(test_ann.size).astype(np.int32)
    unique, counts = np.unique(joined_ann, return_counts=True, axis=0)
    for item in zip(unique, counts):
        true_label = item[0] % test_ann_scale
        true_index = classes_map.get_index_by_label(true_label)
        test_label = item[0] // test_ann_scale
        test_index = classes_map.get_index_by_label(test_label)
        pixel_cnt = item[1]
        confusion_matrix[true_index, test_index] = pixel_cnt
    return confusion_matrix


def get_compare_items_from_list(filenames_list_path, gt_folder, test_folder):
    """Создает список сравниваемых файлов на основе списка файлов.
        Считывает имена файлов из текстового файла и добавляет в список для сравнения только те пары файлов,
        у которых оба файла (с эталонной и сравниваемой) существуют.

    Args:
        filenames_list_path (str): пусть до текстового файла с именами сравниваемых файлов
        gt_folder (str): путь до папки с эталонной разметкой
        test_folder (str): путь до папки со сравниваемой сегментацией

    Return:
        list: Список сравниваемых файлов (CompareItem)
    """
    compare_items = []
    filenames_file = open(filenames_list_path, "r")
    for line in filenames_file:
        filename = line.strip()
        gt_path = os.path.join(gt_folder, filename)
        test_path = os.path.join(test_folder, filename)
        # Check that files exist
        if os.path.exists(gt_path) is not True:
            print("{} doesn't exist".format(gt_path))
            continue
        elif os.path.exists(test_path) is not True:
            print("{} doesn't exist".format(test_path))
            continue
        else:
            compare_items.append(CompareItem(filename, gt_path, test_path))
    compare_items = list(sorted(compare_items, key=lambda x: x.name))
    return compare_items


def get_compare_items_from_folders(gt_folder, test_folder, filename_mask=".*\.png"):
    """Создает список сравниваемых файлов, основываясь на содержимом двух папок.
        Добавляет в список сравниваемых файлов только те файлы, которые одновременной существуют в двух папках.

    Args:
        gt_folder (str): путь до папки с эталонной разметкой
        test_folder (str): путь до папки со сравниваемой сегментацией
        filename_mask (str): шаблон, которому должны соответсвовать имена сравниваемых файлов. По умолчанию используются
            только png файлы

    Return:
        list: Список сравниваемых файлов (CompareItem)
    """
    compare_items = []
    gt_filenames = [name for name in os.listdir(gt_folder) if re.match(filename_mask, name)]

    for filename in gt_filenames:
        gt_path = os.path.join(gt_folder, filename)
        test_path = os.path.join(test_folder, filename)
        # Check that test file exists
        if os.path.exists(test_path) is not True:
            print("{} doesn't exist".format(test_path))
            continue
        else:
            compare_items.append(CompareItem(filename, gt_path, test_path))
    compare_items = list(sorted(compare_items, key=lambda x: x.name))
    return compare_items


def get_compare_item_from_files(gt_path, test_path):
    """Создает список сравниваемых файлов, состоящий из одного элемента.
        Создание одноэлементного списка используется для единообразного представления со всеми остальными методами
        загрузки сравниваемых файлов. Если какой-то из двух файлов не существует, то возвращается пустой список.

    Args:
        gt_path (str): Путь до файла с эталонной размекой
        test_path (str): Путь до сравниваемого файла

    Return:
        list: Список, состоящий не более чем из одного элемента (CompareItem)
    """
    if os.path.exists(gt_path) is not True:
        print("{} doesn't exist".format(gt_path))
        return []
    elif os.path.exists(test_path) is not True:
        print("{} doesn't exist".format(test_path))
        return []
    else:
        filename = gt_path.split("/")[-1]
        return [CompareItem(filename, gt_path, test_path)]


def main():
    parser = argparse.ArgumentParser(description="Tool for evaluate semantic segmentation")
    parser.add_argument("--label_map", required=True, help="Map from class name to label")
    parser.add_argument("--mode", choices=["list", "folder", "file"], default="folder")
    parser.add_argument("--list", required=False, help="List of filenames for compare")
    parser.add_argument("--gt", required=True, help="Path to folder with ground truth files")
    parser.add_argument("--test", required=True, help="Path to folder with files for compare")
    args = parser.parse_args()

    classes_map = ClassMap(args.label_map)

    if "list" == args.mode:
        compare_items = get_compare_items_from_list(args.list, args.gt, args.test)
    elif "folder" == args.mode:
        compare_items = get_compare_items_from_folders(args.gt, args.test)
    else:
        compare_items = get_compare_item_from_files(args.gt, args.test)

    if len(compare_items) == 0:
        print("List of files for compare is empty!")
        return

    table = MetricsTable(compare_items)

    if len(compare_items) > 1:
        total_confusion_matrix = np.zeros([classes_map.classes_cnt(), classes_map.classes_cnt()])

    for item in compare_items:
        gt_ann = cv.imread(item.gt_path, cv.IMREAD_GRAYSCALE)
        test_ann = cv.imread(item.test_path, cv.IMREAD_GRAYSCALE)
        confusion_matrix = compare_ann(gt_ann, test_ann, classes_map)
        metrics = calc_metrics(confusion_matrix, classes_map)
        table.add_row(item.name, metrics)
        if len(compare_items) > 1:
            total_confusion_matrix += confusion_matrix

    if len(compare_items) > 1:
        # Total result
        total_metrics = calc_metrics(total_confusion_matrix, classes_map)
        table.add_row("Total", total_metrics)
    table.finalize()


if __name__ == "__main__":
    main()
