import logging
import json
from sandbox import sdk2
import sandbox.projects.common.environments as env

MAX_RECURSION_DEPTH = 6  # deepest directory nesting parse_tables will descend into


def parse_tables(tree):
    """Collect the paths of all leaf entries in a decoded directory tree.

    :param tree: nested dict (as produced by ``json.loads`` on the JSON form
        of a subtree). A key whose value is ``None`` is recorded as a table;
        any other value is treated as a nested directory and descended into,
        except for directories named ``tmp``.
    :return: list of ``/``-separated paths relative to the root of ``tree``
        (no leading slash), in dict iteration order.
    :raises Exception: if a directory deeper than ``MAX_RECURSION_DEPTH``
        levels is encountered.
    """
    tables = []

    def walk(subtree, path):
        if len(path) > MAX_RECURSION_DEPTH:
            raise Exception(
                "Max recursion depth reached at {}. You've probably chosen the wrong root path".format('/'.join(path)))
        for name, value in subtree.items():
            if value is None:
                # Join through the list so entries directly under the root come
                # out as 'name' rather than '/name' (which a caller prefixing the
                # root path would turn into 'root//name').
                tables.append('/'.join(path + [name]))
            elif name != 'tmp':
                # `path` is shared across the recursion: push before descending,
                # pop afterwards to restore it for the next sibling.
                path.append(name)
                walk(value, path)
                path.pop()

    walk(tree, [])
    return tables


def get_chunk_merge_builder(source_table, chunk_size, pool_name=None):
    """Build a merge spec that rewrites ``source_table`` into larger chunks.

    The same table is used as both input and output, so the merge happens in
    place. ``force_transform`` and ``combine_chunks`` are enabled and the
    table writer is asked for chunks of ``chunk_size``.

    :param source_table: YT path of the table to merge.
    :param chunk_size: value for ``job_io.table_writer.desired_chunk_size``
        (the caller in this file passes bytes).
    :param pool_name: optional YT pool; ``None`` or ``''`` leaves it unset.
    :return: a configured ``MergeSpecBuilder``.
    """
    from yt.wrapper.spec_builders import MergeSpecBuilder

    writer_options = {'desired_chunk_size': chunk_size}
    extra_spec = {
        'force_transform': True,
        'combine_chunks': True,
        'job_io': {'table_writer': writer_options},
    }

    merge_builder = MergeSpecBuilder()
    merge_builder.input_table_paths(source_table)
    merge_builder.output_table_path(source_table)
    if pool_name not in (None, ''):
        merge_builder.pool(pool_name)
    merge_builder.spec(extra_spec)
    return merge_builder


class SowChunkMerge(sdk2.Task):
    """Sandbox task that merges fragmented YT tables found under a root path.

    Every leaf returned by ``parse_tables`` for ``root_path`` whose
    ``chunk_count`` attribute exceeds ``chunk_count_limit`` gets a merge
    operation built by ``get_chunk_merge_builder``; operations are submitted
    through an ``OperationsTrackerPool`` of size ``max_req``.
    """

    class Parameters(sdk2.Task.Parameters):
        description = 'A task for running map reduce on sow reports'
        cluster = sdk2.parameters.String('Yt cluster', required=True)
        root_path = sdk2.parameters.String('Yt root path for reports', required=True)
        tokens = sdk2.parameters.YavSecret('Yt token secret', required=True)
        max_req = sdk2.parameters.Integer('Maximum number of simultaneous merge requests', default=10)
        chunk_count_limit = sdk2.parameters.Integer('Chunk limit past which table merge should be run', default=1)
        pool_name = sdk2.parameters.String('Yt pool name', default='k50')
        chunk_size = sdk2.parameters.Integer('Desired chunk size in MB', default=1024)

    class Requirements(sdk2.Task.Requirements):
        environments = [
            env.PipEnvironment('yandex-yt', version='0.11.1'),
        ]

    def get_tables(self, path):
        """Return ``path``-prefixed paths of every leaf found under ``path``.

        Requires ``self.client`` to have been created (see ``on_execute``).
        """
        tree = json.loads(self.client.get(path, format='json'))
        return [path + '/' + table for table in parse_tables(tree)]

    def on_execute(self):
        """Scan ``root_path`` and submit a merge for each over-fragmented table."""
        logging.info('Starting chunk merge')
        # Imported lazily: the yandex-yt package is installed by the
        # PipEnvironment declared in Requirements.
        import yt.wrapper as yt
        from yt.wrapper.operations_tracker import OperationsTrackerPool

        secret = self.Parameters.tokens
        yt_token = secret.data()[secret.default_key]
        root_path = self.Parameters.root_path
        max_requests = self.Parameters.max_req
        proxy = self.Parameters.cluster
        chunk_count_limit = self.Parameters.chunk_count_limit
        pool_name = self.Parameters.pool_name
        chunk_size = self.Parameters.chunk_size * 2 ** 20  # converting MB to bytes

        self.client = yt.YtClient(proxy=proxy, token=yt_token)

        paths = self.get_tables(root_path)

        table_count = 0

        with OperationsTrackerPool(max_requests) as pool:
            for table in paths:
                # One request with a default instead of has_attribute + get_attribute:
                # halves the RPCs per table and avoids failing when the node
                # disappears between the two calls.
                # NOTE(review): parse_tables treats every null leaf as a table; if
                # non-table nodes with a chunk_count (e.g. files) can live under
                # root_path they would be sent to merge too — confirm or filter by type.
                chunk_count = self.client.get_attribute(table, 'chunk_count', default=None)
                if chunk_count is None or chunk_count <= chunk_count_limit:
                    continue
                logging.info('%s table chunk count is %s. Merging...', table, chunk_count)
                pool.add(get_chunk_merge_builder(table, chunk_size, pool_name), client=self.client)
                table_count += 1

        logging.info('Chunk merge finished. A total of %d tables have been merged', table_count)
