import logging
import os
import yt.wrapper


class StatisticsCounter:
    """
    Describes one count to be computed by B2BComparison

    B2BComparison:
        If both {unfold_statistics_reducer} and {unfold_reduce_fields} are given, unfolds the statistics table
        with a separate map-reduce ({unfold_statistics_mapper} is passed as its mapper),
        result is in {work_dir}/comparison/{name}/statistics

        Otherwise {unfold_statistics_mapper} is passed as the mapper of the count operation itself

        Counts rows by {count_reduce_fields}, result is in {work_dir}/comparison/{name}/count

    Parameters:
        name                        directory to store result
        unfold_statistics_mapper    mapper applied to the statistics table
        unfold_statistics_reducer   reducer applied to the statistics table
        unfold_reduce_fields        reduce fields used together with unfold_statistics_reducer
        count_reduce_fields         fields to count rows by
    """

    def __init__(self, name, unfold_statistics_mapper, unfold_statistics_reducer, unfold_reduce_fields, count_reduce_fields):
        # Plain settings holder: every argument is stored as-is, nothing is validated here.
        self.name = name
        self.count_reduce_fields = count_reduce_fields
        self.unfold_reduce_fields = unfold_reduce_fields
        self.unfold_statistics_reducer = unfold_statistics_reducer
        self.unfold_statistics_mapper = unfold_statistics_mapper

    @staticmethod
    def _count_reducer(reduce_key, rows):
        """
        Yields a single row: the fields of reduce_key plus 'count', the number of rows in the group
        """
        counted_row = dict(reduce_key)

        # Walk the iterator instead of materializing it, a group may be large.
        rows_seen = 0
        for _ in rows:
            rows_seen += 1

        counted_row['count'] = rows_seen
        yield counted_row


class B2BComparison():
    """
    Reduces {work_dir}/{branch} into {work_dir}/comparison/statistics
    Counts statistics with statistics_counters

    Parameters:
        branches                    branches to compare, each one names an input table {work_dir}/{branch}
        work_dir                    directory with the branch tables, results go to {work_dir}/comparison
        yt_client                   client used to check/create nodes and to run map-reduce operations
        satistics_reduce_fields     reduce fields for branch tables
                                    (the misspelled name is part of the interface and is kept as is)
        statistics_reducer          reducer for branch tables, a falsy value selects the default reducer
        statistics_counters         see StatisticsCounter, None is treated as "no counters"
    """

    def _default_statistics_reducer(self, reduce_key, rows):
        """
        Merges all rows of one reduce key into a single row

        The result contains the reduce key fields plus every other field of each branch row,
        stored as '{branch}_{field}'. That includes the 'branch' field itself
        unless it is part of the reduce key.

        Rows must carry a 'branch' field (it is set by _set_branch_mapper).
        If a branch has several rows for the key, a warning is logged and the last row wins.
        If a branch has no row for the key, a warning is logged and the branch contributes no fields.
        """
        result_row = dict(reduce_key)

        rows_by_branch = {}
        for row in rows:
            branch = row['branch']
            if branch in rows_by_branch:
                logging.warning('Duplicating rows from branch "{branch}" for key {reduce_key}'.format(branch=branch, reduce_key=reduce_key))
            rows_by_branch[branch] = row

        # Iterate self.branches (not rows_by_branch) to keep a stable field order in the result.
        for branch in self.branches:
            branch_row = rows_by_branch.get(branch)
            if branch_row is None:
                # A key present in only some of the branches must not fail the whole operation.
                logging.warning('No rows from branch "{branch}" for key {reduce_key}'.format(branch=branch, reduce_key=reduce_key))
                continue
            for field, value in branch_row.items():
                if field not in reduce_key:
                    result_row['{branch}_{field}'.format(branch=branch, field=field)] = value

        yield result_row

    def __init__(self, branches, work_dir, yt_client, satistics_reduce_fields, statistics_reducer, statistics_counters):
        self.branches = branches
        self.work_dir = work_dir
        self.yt_client = yt_client
        self.satistics_reduce_fields = satistics_reduce_fields
        # A falsy reducer (e.g. None) means "use the default one".
        self.statistics_reducer = statistics_reducer or self._default_statistics_reducer
        self.statistics_counters = statistics_counters

    def _create_dir(self, base_path, dir_name):
        """
        Makes sure the map_node {base_path}/{dir_name} exists and returns its path
        """
        dir_path = os.path.join(base_path, dir_name)
        if not self.yt_client.exists(dir_path):
            logging.info('Creating {dir_path} map_node'.format(dir_path=dir_path))
            self.yt_client.create(path=dir_path, type='map_node', recursive=True)
        else:
            logging.info('Node with path {dir_path} already exists'.format(dir_path=dir_path))
        return dir_path

    @yt.wrapper.with_context
    def _set_branch_mapper(self, row, context):
        """
        Tags the row with the branch of its input table

        Input tables are passed to the operation in the order of self.branches,
        so context.table_index is used as an index into self.branches.
        """
        row['branch'] = self.branches[context.table_index]
        yield row

    def run(self):
        """
        Builds {work_dir}/comparison/statistics and then the tables of every statistics counter

        For each counter:
            if both unfold_reduce_fields and unfold_statistics_reducer are set, the statistics table is unfolded
            into {work_dir}/comparison/{name}/statistics and rows of that table are counted,
            otherwise unfold_statistics_mapper is used as the mapper of the count operation
            over the statistics table itself.
            The count goes to {work_dir}/comparison/{name}/count.
        """
        with yt.wrapper.OperationsTracker(print_progress=True) as yt_tracker:
            comparison_dir = self._create_dir(self.work_dir, 'comparison')
            statistics_table = os.path.join(comparison_dir, 'statistics')
            # No sync=False here: every counter below reads statistics_table.
            # NOTE(review): presumably the call blocks until the operation is done
            # (yt.wrapper default) - confirm.
            yt_tracker.add(
                self.yt_client.run_map_reduce(
                    self._set_branch_mapper,
                    self.statistics_reducer,
                    [os.path.join(self.work_dir, branch) for branch in self.branches],
                    statistics_table,
                    reduce_by=self.satistics_reduce_fields,
                )
            )

            for statistics_counter in self.statistics_counters or ():
                count_dir = self._create_dir(comparison_dir, statistics_counter.name)
                count_table = os.path.join(count_dir, 'count')

                count_mapper = statistics_counter.unfold_statistics_mapper
                unfolded_statistics_table = statistics_table
                if statistics_counter.unfold_reduce_fields and statistics_counter.unfold_statistics_reducer:
                    unfolded_statistics_table = os.path.join(count_dir, 'statistics')
                    # Also without sync=False: the count operation below reads this table.
                    yt_tracker.add(
                        self.yt_client.run_map_reduce(
                            statistics_counter.unfold_statistics_mapper,
                            statistics_counter.unfold_statistics_reducer,
                            source_table=statistics_table,
                            destination_table=unfolded_statistics_table,
                            reduce_by=statistics_counter.unfold_reduce_fields,
                        )
                    )
                    # The mapper has already been applied by the unfold operation.
                    count_mapper = None

                # Count operations do not depend on each other, so they are started
                # with sync=False and handed over to the tracker.
                yt_tracker.add(
                    self.yt_client.run_map_reduce(
                        count_mapper,
                        statistics_counter._count_reducer,
                        source_table=unfolded_statistics_table,
                        destination_table=count_table,
                        reduce_by=statistics_counter.count_reduce_fields,
                        sync=False
                    )
                )
