import os.path
from collections import defaultdict

import luigi
import luigi.task_register as register
import yt.wrapper as yt

from lib.luigi import yt_luigi
from rtcconf import config
from utils import mr_utils as mr
from utils import utils

# Module-level registry: production table path -> path of its copy inside the
# isolated directory. Entries are added by _find_dependencies_to_run() and the
# copies are performed by run_isolated().
copy_tables = {}

def _get_yt_output(luigi_task):
    """Return the YT table paths produced by ``luigi_task``.

    The task's ``output()`` is flattened with ``utils.flatten``. Only targets
    that are instances of ``yt_luigi.YtDateTarget`` or ``yt_luigi.YtTarget``
    contribute their ``table`` attribute; any other target is ignored.

    :param luigi_task: task instance whose outputs are inspected.
    :return: list of ``table`` values, in the order the targets are yielded.
    """
    return [
        target.table
        for target in utils.flatten(luigi_task.output())
        if isinstance(target, (yt_luigi.YtDateTarget, yt_luigi.YtTarget))
    ]


def _replace_dir(prod_dir, isolated_dir):
    return (
        prod_dir
        .replace('//home/crypta/production/state/', isolated_dir + 'state/')
        .replace('//crypta/production/state/', isolated_dir + 'state/')
        .replace('//crypta/production/', isolated_dir)
        .replace('//crypta/production/', isolated_dir)
        .replace('//home/crypta/testing/state/', isolated_dir + 'state/')
        .replace('//crypta/testing/state/', isolated_dir + 'state/')
        .replace('//crypta/testing/', isolated_dir)
        .replace('//home/crypta/testing/', isolated_dir)
        .replace('//statbox/', isolated_dir + 'statbox/')
    )


def _wrap_replace_path(task_class, method_name, isolated_dir, required=True):
    if not required and method_name not in task_class.__dict__:
        return

    func = task_class.__dict__[method_name]

    def _wrap(*k, **kw):
        folders = func(*k, **kw)
        if isinstance(folders, dict):
            return {k: _replace_dir(v, isolated_dir) for k, v in folders.iteritems()}
        elif isinstance(folders, dict):
            return [_replace_dir(f, isolated_dir) for f in folders]
        else:
            # single folder
            return _replace_dir(folders, isolated_dir)

    already_modified = getattr(task_class, 'already_modified', False)
    if not already_modified:
        setattr(task_class, method_name, _wrap)
        task_class.modified = True


def _find_dependencies_to_run(task, isolated_dir, changed_tasks=None, level=0):
    """Walk the dependency graph of ``task`` and redirect changed tasks to ``isolated_dir``.

    A task counts as *changed itself* when its class defines ``has_changed``
    in its own ``__dict__`` or when its class is listed in ``changed_tasks``.
    A task is *affected* when at least one task below it is changed itself.

    For every changed or affected task the class is patched with
    ``_wrap_replace_path``: ``output_folders`` always, ``workdir`` if the class
    defines it, and ``input_folders`` only when something below it changed
    (a changed leaf keeps reading its usual input).

    When a task has both changed and unchanged children, the YT output tables
    of the unchanged children are registered in the module-level
    ``copy_tables`` mapping so they can be copied into the isolated directory.

    Progress is reported with ``print``.

    :param task: task instance to analyse.
    :param isolated_dir: root of the isolated run.
    :param changed_tasks: collection of task *classes* to treat as changed;
        ``None`` means no explicitly changed classes.
    :param level: recursion depth, used only to indent the printed messages.
    :return: set of task instances in this subtree (including ``task``) that
        are changed themselves.
    :raises Exception: if a changed or affected task is not a
        ``yt_luigi.BaseYtTask`` subclass.
    """
    if changed_tasks is None:
        changed_tasks = []

    indent = '|' + '__' * level
    important_indent = '!' + '__' * level
    print('%sChecking %s...' % (indent, task))

    task_class = task.__class__
    task_class_name = task_class.__name__
    changed_itself = 'has_changed' in task_class.__dict__ or task_class in changed_tasks

    # Both mappings are keyed by child class name. Results are accumulated
    # rather than assigned so that several required instances of the same
    # class do not overwrite each other's changed set.
    changed_dependencies_per_child = defaultdict(set)
    output_tables_per_child = defaultdict(set)

    for child_task in utils.flatten(task.requires()):
        child_class_name = child_task.__class__.__name__

        deps = _find_dependencies_to_run(child_task, isolated_dir, changed_tasks, level + 1)  # recursion
        changed_dependencies_per_child[child_class_name].update(deps)
        output_tables_per_child[child_class_name].update(_get_yt_output(child_task))

    some_has_changes = [cl for cl, deps in changed_dependencies_per_child.items() if deps]
    some_has_not_changes = [cl for cl, deps in changed_dependencies_per_child.items() if not deps]
    if some_has_changes and some_has_not_changes:
        # Mixed inputs are reported as a warning (not an error): the outputs
        # of the unchanged children are scheduled for copying instead.
        print('%s%s Warn! %s are changed and %s are not. ' % (important_indent, task_class_name,
                                                              some_has_changes, some_has_not_changes))
        print('%sThe following inputs will be copied from prod to isolated dir:' % important_indent)
        for child_task_class in some_has_not_changes:
            child_output_tables = list(output_tables_per_child[child_task_class])
            print('%s%s -> %s' % (important_indent, child_task_class, child_output_tables))
            for t in child_output_tables:
                copy_tables[t] = _replace_dir(t, isolated_dir)

    changed = changed_itself or some_has_changes

    if changed:
        if not issubclass(task_class, yt_luigi.BaseYtTask):
            # TODO: consider "pythonic" duck-typing instead
            raise Exception('Smart run only supports yt_luigi.BaseYtTask, %s is not one' % task_class)

        _wrap_replace_path(task_class, 'output_folders', isolated_dir)
        _wrap_replace_path(task_class, 'workdir', isolated_dir, required=False)

        # the leaf changed dependency should take input from production
        if some_has_changes:
            _wrap_replace_path(task_class, 'input_folders', isolated_dir)

            if changed_itself:
                print('%sTask %s has been changed. It will be run in isolated dir' % (important_indent, task_class_name))
            else:
                print('%sTask %s is affected. It will be run in isolated dir' % (important_indent, task_class_name))
        else:
            if changed_itself:
                print('%sTask %s has been changed. It will be run in isolated dir, but take usual input' % (important_indent, task_class_name))

    # Distinct name on purpose: the ``changed_tasks`` parameter holds classes,
    # while the returned set holds task instances.
    changed_in_subtree = set(utils.flatten(deps for deps in changed_dependencies_per_child.values()))
    if changed_itself:
        changed_in_subtree.add(task)
    return changed_in_subtree


def run_isolated(isolated_path, date1, changed_tasks, luigi_task_class, *task_params, **task_kv):
    """Analyse the task graph, ask for confirmation and run it in ``isolated_path``.

    Steps:

    1. Instantiate ``luigi_task_class`` and let ``_find_dependencies_to_run``
       patch the changed/affected task classes and fill ``copy_tables``.
    2. Print the tables that would be copied and ask the user to type ``yes``.
    3. On ``yes``: create ``isolated_path`` and its ``statbox`` sub-directory,
       copy every registered table (stamping it with ``date1`` through
       ``mr.set_generate_date``), clear luigi's instance cache, build a fresh
       task instance and run it with ``luigi.build``.

    :param isolated_path: root of the isolated run; ``'statbox'`` is appended
        to it directly.
    :param date1: value passed to ``mr.set_generate_date`` for every copy.
    :param changed_tasks: task classes to treat as changed (flattened with
        ``utils.flatten`` before use).
    :param luigi_task_class: class of the root task to run.
    :param task_params: positional arguments for ``luigi_task_class``.
    :param task_kv: keyword arguments for ``luigi_task_class``.
    :return: the task instance that was built, or ``None`` when the user did
        not answer ``yes``.
    """
    probe_task = luigi_task_class(*task_params, **task_kv)
    print('\nPrepare graph to run in %s' % isolated_path)
    _find_dependencies_to_run(probe_task, isolated_path, utils.flatten(changed_tasks))
    print('Graph prepared.\n')

    print('These tables will be copied to start your isolated run:')
    for source_table, target_table in copy_tables.iteritems():
        print('%s -> %s' % (source_table, target_table))

    print('\n')

    # Anything other than a literal "yes" aborts without touching YT.
    if raw_input('Are you sure you want to continue? [yes/no]\n') != 'yes':
        return None

    mr.mkdir(isolated_path)
    mr.mkdir(isolated_path + 'statbox')

    for source_table, target_table in copy_tables.iteritems():
        mr.mkdir(os.path.dirname(target_table))
        mr.copy(source_table, target_table)
        mr.set_generate_date(target_table, date1)

    # The instance cache is cleared before the root task is instantiated a
    # second time, so luigi.build() gets a newly constructed instance.
    register.Register.clear_instance_cache()

    task_to_run = luigi_task_class(*task_params, **task_kv)
    luigi.build([task_to_run], scheduler_port=8083, workers=5)
    return task_to_run


if __name__ == '__main__':
    yt.config.set_proxy(config.MR_SERVER)

    # Ad-hoc report: for each graph variant, sum count * crypta_id_size over
    # the size histogram, split at crypta_id_size < 15, and print both sums
    # together with the share of the "small" sum.
    hist_table_template = '//crypta/production/state/graph/2016-05-16/%s/stat/crypta_ids_size_hist'

    for vt in ('exact', 'exact/cluster', 'tmp'):
        small_total = 0
        large_total = 0
        for row in yt.read_table(hist_table_template % vt, raw=False):
            size = row['crypta_id_size']
            weight = row['count'] * size
            if size < 15:
                small_total += weight
            else:
                large_total += weight

        # NOTE(review): no print_function import is visible in this file, so
        # under Python 2 this presumably prints the tuple repr - confirm.
        print(vt, small_total, large_total, small_total / float(small_total + large_total))
