import logging
from collections import defaultdict
import json
from sandbox import sdk2
import sandbox.projects.common.environments as env
from os import path

# Maximum number of directory levels below the root that parse_tables will
# descend before raising; a deeper tree usually means root_path points too high.
MAX_RECURSION_DEPTH = 6


def get_res_table_name(table_name):
    """Return the result-table path for a table stored in a ``tmp`` directory.

    The parent directory of the table (the ``tmp`` level produced by
    ``parse_tables``) is dropped and every ``_tmp`` occurrence is removed from
    the table name, e.g. ``'//r/a/tmp/2020-01-01_tmp'`` -> ``'//r/a/2020-01-01'``.

    Only the last path component is cleaned: directories whose names happen to
    contain ``_tmp`` are preserved instead of being silently rewritten.

    Args:
        table_name: '/'-separated table path.

    Returns:
        The path of the sibling result table one level above the table's
        parent directory.
    """
    parts = table_name.split('/')
    # Strip the rename suffix(es) from the table name only, never from directories.
    name = parts[-1].replace('_tmp', '')
    return '/'.join(parts[:-2] + [name])


def get_dates(tables):
    """Group table paths that refer to the same logical table.

    A table that was renamed with one or more ``_tmp`` suffixes (e.g. left
    behind when a run did not reach its cleanup step) is grouped with the
    original name. The grouping key is the full path with ``_tmp`` removed from
    the table name (last component) only. Unrelated directories that differ
    only by ``_tmp`` are therefore kept apart rather than merged together.

    Args:
        tables: Iterable of '/'-separated table paths.

    Returns:
        ``defaultdict(list)`` mapping each normalized path to the input paths
        that normalize to it, in input order.
    """
    dates = defaultdict(list)
    for table in tables:
        directory, sep, name = table.rpartition('/')
        dates[directory + sep + name.replace('_tmp', '')].append(table)

    return dates


def parse_tables(tree, max_depth=None):
    """Collect the relative paths of all entries stored in ``tmp`` directories.

    Walks a nested mapping (as produced by ``json.loads`` on a YT ``get`` of the
    root directory). For every key named ``tmp`` whose value is a mapping, it
    emits ``<dir>/tmp/<child>`` for each child of that node. Other entries whose
    value is not ``None`` are treated as directories and descended into.

    Args:
        tree: Nested dict describing the directory tree. ``None`` values are
            treated as leaves (presumably tables/files in YT's JSON output;
            confirm against the cluster's representation).
        max_depth: Maximum number of directory levels to descend below the
            root. Defaults to ``MAX_RECURSION_DEPTH``.

    Returns:
        List of '/'-joined paths relative to the root of ``tree``, e.g.
        ``'a/b/tmp/2020-01-01'``, or ``'tmp/2020-01-01'`` for a ``tmp``
        directory directly under the root.

    Raises:
        Exception: if the tree is nested deeper than ``max_depth``, which
            usually means the wrong root path was chosen.
    """
    limit = MAX_RECURSION_DEPTH if max_depth is None else max_depth
    tables = []

    def walk(node, parts):
        if len(parts) > limit:
            raise Exception(
                "Max recursion depth reached at {}. You've probably chosen the wrong root path".format('/'.join(parts)))
        for name, value in node.items():
            if value is None:
                # Leaf: nothing to descend into. This also skips a leaf that
                # happens to be named 'tmp' instead of crashing on it.
                continue
            if name == 'tmp':
                # Join from components so a top-level 'tmp' yields 'tmp/<t>'
                # rather than '/tmp/<t>' (which the caller would turn into an
                # invalid '<root>//tmp/<t>').
                tables.extend('/'.join(parts + ['tmp', table]) for table in value)
            else:
                walk(value, parts + [name])

    walk(tree, [])
    return tables


def get_map_reducer_builder(source_tables, res_table, reducer, is_historical, pool_name=None):
    """Build a YT map-reduce spec that merges ``source_tables`` into ``res_table``.

    Rows are reduced by ``_k50_query_id`` (plus ``_k50_date`` for historical
    tables). Inside each group they are sorted by ``_k50_uploaded_at``
    descending, and the group is handed to ``python reducer.py`` in JSON format.

    Args:
        source_tables: Input table paths.
        res_table: Output table path.
        reducer: File attached to the reducer job; the command runs it as
            ``reducer.py``.
        is_historical: Whether ``_k50_date`` is part of the reduce/sort key.
        pool_name: Optional YT pool; ignored when ``None`` or empty.

    Returns:
        The configured ``MapReduceSpecBuilder``.
    """
    from yt.wrapper.spec_builders import MapReduceSpecBuilder

    if is_historical:
        key_columns = ['_k50_query_id', '_k50_date']
    else:
        key_columns = ['_k50_query_id']
    # Same key prefix as the reduce key, then newest upload first.
    sort_columns = key_columns + [{'name': '_k50_uploaded_at', 'sort_order': 'descending'}]

    spec = MapReduceSpecBuilder()
    spec.input_table_paths(source_tables)
    spec.output_table_paths([res_table])
    if pool_name not in (None, ''):
        spec.pool(pool_name)
    spec.reduce_by(key_columns)
    spec.sort_by(sort_columns)
    (spec.begin_reducer()
         .command('python reducer.py')
         .format('json')
         .add_file_path(reducer)
         .end_reducer())

    return spec


def is_table_historical(schema):
    """Return True if ``schema`` declares a ``_k50_date`` column.

    ``schema`` is iterated as a sequence of column descriptions, each indexable
    by ``'name'`` (as with the ``@schema`` attribute read in ``on_execute``).
    The scan stops at the first match.
    """
    return any(column['name'] == '_k50_date' for column in schema)


class SowMapReduce(sdk2.Task):
    """Merge SOW report tables from ``tmp`` directories into result tables.

    Every entry found in a ``tmp`` directory below ``root_path`` is renamed with
    an extra ``_tmp`` suffix. It is then reduced, together with the existing
    result table if there is one, into the table one level above the ``tmp``
    directory (``<root>/a/tmp/<name>`` -> ``<root>/a/<name>``). The renamed
    tables are removed after the operations pool has been exited.
    """

    # Task input parameters.
    class Parameters(sdk2.Task.Parameters):
        description = 'A task for running map reduce on sow reports'
        cluster = sdk2.parameters.String('Yt cluster', required=True)
        root_path = sdk2.parameters.String('Yt root path for reports', required=True)
        # The YT token is read from this secret under its default key (see on_execute).
        tokens = sdk2.parameters.YavSecret('Yt token secret', required=True)
        # Passed to OperationsTrackerPool in on_execute as its size limit.
        max_req = sdk2.parameters.Integer('Maximum number of simultaneous map reduce requests', default=10)
        pool_name = sdk2.parameters.String('Yt pool name', default='k50')

    # Pins the yandex-yt pip package; yt.wrapper is imported lazily inside
    # methods, presumably because it is only available once this environment
    # is prepared.
    class Requirements(sdk2.Task.Requirements):
        environments = [
            env.PipEnvironment('yandex-yt', version='0.11.1'),
        ]

    def get_tables(self, path):
        """Return absolute paths of all entries in ``tmp`` dirs under ``path``.

        Reads the tree of ``path`` as JSON and prefixes every relative path
        returned by ``parse_tables`` with ``path + '/'``.
        """
        tree = json.loads(self.client.get(path, format='json'))
        return list(map(lambda x: path + '/' + x, parse_tables(tree)))

    def rename_table(self, path, new_path):
        """Move the node at ``path`` to ``new_path``."""
        self.client.move(path, new_path)

    def remove_table(self, path):
        """Remove the node at ``path``."""
        self.client.remove(path)

    def exists(self, path):
        """Return whether a node exists at ``path``."""
        return self.client.exists(path)

    def on_execute(self):
        """Run one map-reduce per logical table and clean up the renamed inputs."""
        logging.info('Starting map reduce')
        import yt.wrapper as yt
        from yt.wrapper.operations_tracker import OperationsTrackerPool
        from yt.wrapper.file_commands import LocalFile

        yt_token = self.Parameters.tokens.data()[self.Parameters.tokens.default_key]
        root_path = self.Parameters.root_path
        max_requests = self.Parameters.max_req
        proxy = self.Parameters.cluster
        pool_name = self.Parameters.pool_name

        self.client = yt.YtClient(proxy=proxy, token=yt_token)

        paths = self.get_tables(root_path)
        # Groups each table with its already-renamed '_tmp' variants, if any.
        table_dates = get_dates(paths)

        tables_to_delete = []
        with OperationsTrackerPool(max_requests) as pool:
            for table_date, tables_to_merge in table_dates.items():
                logging.info('Map reduce on date {0}'.format(table_date))
                # Reverse lexicographic order puts 'x_tmp' before 'x'. Then
                # 'x_tmp' is moved to 'x_tmp_tmp' before 'x' is moved to 'x_tmp',
                # so no rename targets a path that still exists. Do not reorder.
                tables_to_merge.sort(reverse=True)
                tmp_tables = list(map(lambda x: x + '_tmp', tables_to_merge))
                # Deleted only after the pool block below has been left.
                tables_to_delete += tmp_tables
                for i, table_to_merge in enumerate(tables_to_merge):
                    self.rename_table(table_to_merge, tmp_tables[i])

                source_tables = tmp_tables[:]

                # Result table sits one level above the tmp directory. If it
                # already exists it is also an input, so its rows take part in the reduce.
                res_table = get_res_table_name(tables_to_merge[0])
                if self.exists(res_table):
                    source_tables.append(res_table)

                # The result table is created (if missing) with the schema of the
                # first renamed source table.
                schema = self.client.get(source_tables[0] + '/@schema')
                self.client.create('table', path=res_table, attributes={'schema': schema}, ignore_existing=True,
                                   force=False)

                is_historical = is_table_historical(schema)

                # The reducer script is expected next to this module.
                pool.add(
                    get_map_reducer_builder(source_tables, res_table,
                                            LocalFile(path.join(path.dirname(__file__), 'reducer.py'),
                                                      file_name='reducer.py'),
                                            is_historical,
                                            pool_name),
                    client=self.client
                )

        # NOTE(review): this assumes the pool waited for every operation (and
        # raised on failure) when the with-block exited; confirm against the
        # OperationsTrackerPool semantics of the pinned yandex-yt version.
        for tmp_table in tables_to_delete:
            self.remove_table(tmp_table)

        logging.info('Map reduce finished')
