#!/usr/bin/env python
# -*- coding: utf-8 -*-
from collections import Counter

from yt.wrapper import aggregator

from crypta.lib.python.yt import (
    schema_utils,
    yt_helpers,
)


# Table attribute presets keyed by the ``compression`` argument of
# ``Yt.create_empty_table``, which copies the selected preset into the
# attributes of the table it creates.
STORAGE_MODES = {
    None: {'compression_codec': 'none', 'replication_factor': 6},
    'default': {'compression_codec': 'brotli_3'},
    'compact': {'compression_codec': 'brotli_8'},
}


class Yt(yt_helpers.AnnotatedYtClient):
    def create_directory(self, directory):
        if not self.exists(directory):
            self.mkdir(directory, recursive=True)

    def sort_if_needed(self, table, sort_by):
        """Delegate to ``yt_helpers.sort_if_needed`` with this client; returns nothing."""
        yt_helpers.sort_if_needed(self, table, sort_by)

    def get_table_attributes(self, table):
        attributes = self.get('{table}/@'.format(table=table))
        return attributes

    def unique(self, source_table, destination_table, unique_by, spec=None):
        def reducer(key, rows):
            yield key

        final_spec = {'title': 'Unique table'}
        if spec is not None:
            final_spec.update(spec)

        if isinstance(source_table, basestring) and yt_helpers.is_sorted(self, source_table, unique_by):
            self.run_reduce(
                reducer,
                source_table,
                destination_table,
                reduce_by=unique_by,
                spec=final_spec
            )
        else:
            self.run_map_reduce(
                None,
                reducer,
                source_table,
                destination_table,
                reduce_by=unique_by,
                spec=final_spec,
            )

    def create_empty_table(self, path, compression='default', schema=None, additional_attributes=None,
                           erasure=True, force=True, optimize_for='scan'):
        """Create a table at ``path``, optionally replacing an existing node.

        Attributes are assembled in this order, later steps overriding
        earlier ones: the ``STORAGE_MODES[compression]`` preset, ``schema``,
        ``additional_attributes``, then ``erasure_codec`` (when ``erasure``
        is true) and ``optimize_for``.

        :param path: path of the table to create; missing parents are
            created (``recursive=True``).
        :param compression: key of ``STORAGE_MODES``.
        :param schema: optional schema. A dict is converted with
            ``schema_utils.yt_schema_from_dict``; any other truthy value is
            stored as is.
        :param additional_attributes: optional dict of extra attributes.
        :param erasure: when true, set ``erasure_codec`` to ``'lrc_12_2_2'``.
        :param force: when true, an existing node at ``path`` is removed
            first; when false, an existing node is an error.
        :param optimize_for: value of the ``optimize_for`` attribute.
        :raises KeyError: if ``compression`` is not a key of ``STORAGE_MODES``.
        :raises ValueError: if ``path`` exists and ``force`` is false.
        """
        # Build (and thereby validate) the attributes before touching the
        # existing node, so a bad ``compression`` or ``schema`` cannot leave
        # the old table removed with nothing created in its place.
        attributes = dict(STORAGE_MODES[compression])
        if schema:
            if isinstance(schema, dict):
                schema = schema_utils.yt_schema_from_dict(schema)
            attributes['schema'] = schema
        if additional_attributes:
            attributes.update(additional_attributes)
        if erasure:
            attributes['erasure_codec'] = 'lrc_12_2_2'
        attributes['optimize_for'] = optimize_for

        if self.exists(path):
            if not force:
                raise ValueError('Table already exists. If you want to overwrite it set force=True')
            self.remove(path)

        self.create('table', path, recursive=True, attributes=attributes)

    def get_yt_schema_dict_from_table(self, table):
        schema_dict = {}
        for field_desc in self.get_attribute(table, 'schema'):
            schema_dict[field_desc['name']] = field_desc['type']
        return schema_dict

    def unique_count(self, source_table, destination_table, unique_by, result_field='count', spec=None):
        """Count rows of ``source_table`` per distinct ``unique_by`` key.

        Each output row holds the ``unique_by`` columns plus ``result_field``
        with the number of input rows sharing that key. ``result_field`` is
        assigned last, so it overwrites a key column of the same name.

        :param source_table: input passed to ``run_map_reduce``.
        :param destination_table: output table.
        :param unique_by: sequence of key column names (it is iterated and
            zipped with the key values, so a bare string is not suitable).
        :param result_field: name of the output column holding the count.
        :param spec: optional operation spec forwarded to ``run_map_reduce``;
            when omitted, no ``spec`` argument is passed at all.
        """
        @aggregator
        def map_count_agg(rows):
            # Pre-aggregate inside the mapper: emit one partial count per key
            # instead of one row per input row.
            counter = Counter(tuple(row[field] for field in unique_by) for row in rows)
            for key, partial_count in counter.items():
                out_row = dict(zip(unique_by, key))
                out_row[result_field] = partial_count
                yield out_row

        def reduce_count(key, rows):
            # Sum the partial counts emitted by the mappers for this key.
            total = sum(row[result_field] for row in rows)
            # ``items()`` exists on both Python 2 and 3 (``iteritems`` does not).
            out_row = dict(key.items())
            out_row[result_field] = total
            yield out_row

        # Only forward ``spec`` when the caller supplied one, so the call is
        # unchanged for existing callers.
        extra_kwargs = {} if spec is None else {'spec': spec}
        self.run_map_reduce(
            map_count_agg,
            reduce_count,
            source_table,
            destination_table,
            reduce_by=unique_by,
            **extra_kwargs
        )
