import os
from yt.yson.yson_types import YsonEntity
from collections import defaultdict
from de2.common import Exporter, KEYS_WHITELIST, replace_extension, filter_keys


class CsvExporter(Exporter):
    """Collect rows in memory and write them out as a tab-separated file.

    The first processed row fixes the column set and order (see
    ``generate_header_row``); every later row is projected onto those
    columns, with absent keys becoming empty cells.  Despite the ``.csv``
    extension produced by ``save_data``, columns are separated by tabs.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Column names, fixed by the first row seen; None until then.
        self.header_row = None
        # The header row followed by the data rows, each a list of values.
        self.rows = []

    def generate_header_row(self, row):
        """Derive the column order from *row* and append it as the header.

        Keys kept by ``filter_keys`` are split into two sorted groups:
        keys in ``KEYS_WHITELIST`` or with string values come first, then
        all remaining keys.  Every key appears exactly once.
        """
        # Materialize once: the keys are scanned twice below.
        keys = list(filter_keys(list(row.keys()), self.config))
        leading = sorted(
            k for k in keys if k in KEYS_WHITELIST or isinstance(row[k], str)
        )
        leading_set = set(leading)
        # Take the complement of the leading group.  Selecting "non-string
        # values" instead would emit a whitelisted non-string key twice.
        trailing = sorted(k for k in keys if k not in leading_set)
        self.header_row = leading + trailing
        self.rows.append(self.header_row)

    def process(self, row):
        """Append *row* projected onto the header columns (missing -> None)."""
        # Explicit None check: an empty header (every key filtered out) is
        # valid and must not be regenerated and re-appended on every row.
        if self.header_row is None:
            self.generate_header_row(row)
        self.rows.append([row.get(k) for k in self.header_row])

    @staticmethod
    def prepare(val):
        """Render a cell value: None and YsonEntity become "", else str()."""
        if val is None or isinstance(val, YsonEntity):
            return ""
        return str(val)

    def save_data(self):
        """Write the header and all rows to disk.

        The target name is the basename of ``self.config_file`` with its
        extension replaced by ``.csv``.  It is placed in
        ``config["output_folder"]`` (made absolute), or in the current
        working directory when that option is missing or empty.  Cells are
        tab-separated, rows are newline-separated (no trailing newline), and
        the file is written as UTF-8 so the output does not depend on the
        platform locale.
        """
        filename = replace_extension(
            os.path.basename(self.config_file), ".csv"
        )
        output_folder = self.config.get("output_folder")
        if output_folder:
            wd = os.path.abspath(output_folder)
        else:
            wd = os.getcwd()
        filepath = os.path.join(wd, filename)
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(
                "\n".join(
                    "\t".join(map(self.prepare, row)) for row in self.rows
                )
            )


class BeelineCsvExporter(CsvExporter):
    """Compare one operator's values against all other operators, per group.

    Rows are grouped by ``config["grouping_keys"]``.  For each group and
    each value key, the output row holds:

    - the configured operator's own value;
    - the min and max over the other operators' values;
    - the own value's average 1-based rank among own + other values
      sorted in descending order.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # group tuple -> {"own": {key: value}, "other": {key: [values]}};
        # every level is a defaultdict, so unseen keys read as [].
        self.row_dict = defaultdict(
            lambda: defaultdict(lambda: defaultdict(list))
        )
        self.grouping_keys = self.config["grouping_keys"]
        # Every included key that is neither a grouping key nor "operator".
        self.value_keys = sorted(
            x
            for x in self.config["included_keys"]
            if x not in self.grouping_keys and x != "operator"
        )

    def generate_header_row(self, row=None):
        """Build the fixed header and append it to ``self.rows``.

        *row* is accepted and ignored.  It keeps the signature compatible
        with ``CsvExporter.generate_header_row``; here the columns depend
        only on the configured grouping and value keys.
        """
        self.header_row = list(self.grouping_keys)
        for key in self.value_keys:
            self.header_row.extend(
                [
                    key,
                    "{}_min".format(key),
                    "{}_max".format(key),
                    "{}_position".format(key),
                ]
            )
        self.rows.append(self.header_row)

    def process(self, row):
        """Bucket *row*'s value-key values into its group.

        Rows whose ``"operator"`` equals ``config["operator"]`` set the
        group's own values; a later such row overwrites an earlier one.
        Rows from any other operator are appended to per-key comparison
        lists.
        """
        group = tuple(row.get(k) for k in self.grouping_keys)
        buckets = self.row_dict[group]
        if row["operator"] == self.config["operator"]:
            for k in self.value_keys:
                buckets["own"][k] = row[k]
        else:
            for k in self.value_keys:
                buckets["other"][k].append(row[k])

    @staticmethod
    def _is_missing(value):
        """Return True for values excluded from min/max/ranking.

        These are None, lists (including the ``[]`` default for a key
        never seen) and YsonEntity.
        """
        return value is None or isinstance(value, (list, YsonEntity))

    @staticmethod
    def _group_sort_key(group):
        """Sort key for group tuples that places None components last.

        Non-None components keep their natural relative order; this only
        avoids the TypeError raised when comparing None with other values.
        """
        return tuple((v is None, v) for v in group)

    @staticmethod
    def _get_position(x, lst):
        """Return the average 1-based rank of *x* in ``[x] + lst``.

        The combined list is sorted in descending order.  Ties share the
        mean of their first and last positions.

        Raises:
            TypeError: if the values cannot be compared with each other.
        """
        try:
            merged_lst = sorted([x] + list(lst), reverse=True)
        except TypeError as err:
            raise TypeError(
                "cannot rank {!r} among {} other values".format(x, len(lst))
            ) from err
        p1 = merged_lst.index(x) + 1
        p2 = len(merged_lst) - merged_lst[::-1].index(x)
        return (p1 + p2) / 2.0

    def _make_rows(self):
        """Append the header and one output row per accumulated group.

        Groups are emitted in sorted order, with None grouping values
        sorted after all others.  Each group is removed from ``row_dict``
        as it is consumed.
        """
        self.generate_header_row()
        for group in sorted(self.row_dict, key=self._group_sort_key):
            buckets = self.row_dict.pop(group)
            own = buckets["own"]
            other = buckets["other"]
            row = dict(zip(self.grouping_keys, group))
            for key in self.value_keys:
                own_value = own[key]
                if self._is_missing(own_value):
                    own_value = None
                row[key] = own_value
                # Also drop None: mixing it into max/min/sorted with other
                # values raises TypeError on Python 3.
                other_lst = [
                    x for x in other[key] if not self._is_missing(x)
                ]
                if other_lst:
                    row["{}_max".format(key)] = max(other_lst)
                    row["{}_min".format(key)] = min(other_lst)
                    # Without an own value, _position stays absent and
                    # is written as an empty cell.
                    if own_value is not None:
                        row["{}_position".format(key)] = self._get_position(
                            own_value, other_lst
                        )
                else:
                    row["{}_max".format(key)] = None
                    row["{}_min".format(key)] = None
                    row["{}_position".format(key)] = None
            self.rows.append([row.get(k) for k in self.header_row])

    def save_data(self):
        """Build the aggregated rows, then write them via CsvExporter."""
        self._make_rows()
        super().save_data()
