# coding: utf-8

import os
import types
import codecs
import inspect
import argparse
import datetime
from copy import deepcopy
from collections import defaultdict
from nile.api.v1 import (
    clusters,
    statface as ns,
)

from .streams import TableStream, RedirLogBaseStream, UserSessionsStream, RAW_LOGS
from .utils import DATE_FORMAT, FLOW_NAME_SEPARATOR
from .digraph import get_plotter_vis_html, is_output_table, get_table_path


class Plot(object):
    """Base class for a group of plot nodes.

    Stores the job, the date range, the extra-days window and the Statface
    token it is constructed with. Results are routed either to Statface or,
    when ``stat_output_path`` is given (debug mode), into append-only YT tables
    under that path.
    """

    def __init__(self, job, date, dateend, additional_days, stat_token, stat_output_path):
        self.job = job
        self.date = date
        self.dateend = dateend
        self.stat_token = stat_token
        self.additional_days = additional_days
        # Trailing slashes are stripped so the '{}/{}' joins below never produce '//'.
        self.stat_output_path = (
            None if stat_output_path is None else stat_output_path.rstrip('/')
        )

    def get_checkpoint_name(self, name):
        """Return the checkpoint name ``'<ClassName>.<name>'`` for this plot."""
        return '{}.{}'.format(type(self).__name__, name)

    def get_statface_report(self, path):
        """Return a daily Statface report at ``path`` bound to a production client."""
        statface_client = ns.StatfaceProductionClient(token=self.stat_token)
        report = ns.StatfaceReport().path(path)
        report = report.scale('daily')
        return report.client(statface_client)

    def _debug_table_path(self, name):
        """Flatten ``name`` ('/' -> '_') into a table path under ``stat_output_path``."""
        return '{}/{}'.format(self.stat_output_path, name.replace('/', '_'))

    def publish(self, stream, statface_report, *args, **kwargs):
        """Publish ``stream`` to ``statface_report``.

        In debug mode (``stat_output_path`` set), append the stream to a YT table
        named after the report path instead; extra args are ignored then.
        """
        if self.stat_output_path is None:
            return stream.publish(statface_report, *args, **kwargs)
        destination = self._debug_table_path(statface_report._path)  # noqa
        return stream.put(destination, append=True, allow_override=False)

    def put(self, stream, table_name, *args, **kwargs):
        """Write ``stream`` to ``table_name``.

        In debug mode (``stat_output_path`` set), append it to
        ``<stat_output_path>/yt_<flattened table_name>`` instead; extra args are
        ignored then.
        """
        if self.stat_output_path is None:
            return stream.put(table_name, *args, **kwargs)
        return stream.put(self._debug_table_path('yt_' + table_name), append=True)


class Plotter(object):
    """Builds (and optionally runs) one nile/YQL job from a set of ``Plot`` classes.

    Every ``Plot`` method decorated with ``require`` becomes a node named
    ``'<PlotClassName>.<method_name>'``. A node's dependencies can be raw log names
    (members of ``RAW_LOGS``), YT table paths (``'//...'``, optionally written as
    ``'hahn://...'``/``'arnold://...'``) or other stream names.
    ``build_flow_graph`` resolves the nodes in dependency order; ``run`` executes
    the resulting job.
    """

    def __init__(
            self,
            date,
            dateend=None,
            cluster='hahn',
            yt_proxy=None,
            yql_proxy=None,
            pool=None,
            token=None,
            stat_token=None,
            require_plots=None,
            store_checkpoints=None,
            checkpoints=None,
            udf_path=None,
            layers=None,
            force_checkpoints=False,
            use_1p=False,
            additional_days=0,
            plot_classes=None,
            stat_output_path=None,
            ):
        self.cluster = cluster
        self.yt_proxy = yt_proxy
        self.yql_proxy = yql_proxy
        self.date = date
        self.dateend = dateend if dateend else date  # dateend is optional; when omitted it equals date
        self.additional_days = additional_days
        self.stat_token = stat_token
        self.pool = pool
        self.force_checkpoints = force_checkpoints
        if token:
            self.token = token
        else:
            self.token = self._get_token()

        self.store_checkpoints = store_checkpoints
        self.checkpoints = checkpoints if checkpoints else []
        self.udf_path = udf_path
        self.layers = layers if layers else []
        self.use_1p = use_1p
        self.stat_output_path = stat_output_path

        self.plots = []
        self.nodes = {}                 # node name -> (plot instance, node function, require kwargs)
        self.deps = defaultdict(set)    # node name -> set of dependency names
        self.streams = {}               # stream name -> nile stream
        self.log_streams = {}           # raw log name -> log stream wrapper
        self.output_tables = {}         # output table path -> (plot, function, stream name), for the visualisation
        self.require_plots = require_plots if require_plots else []

        self.init_nile()
        self.init_plots(plot_classes)
        self.init_raw_log_streams()
        self.init_nodes()
        self.sift_nodes()

    def _get_token(self):
        """Return the YT token from the ``YT_TOKEN`` environment variable (``KeyError`` if unset)."""
        return os.environ['YT_TOKEN']

    def _get_title(self):
        """Return the job title ``'Plotter | <date range> | <plots and layers, or default>'``."""
        daterange = self.date.strftime(DATE_FORMAT)
        if self.dateend > self.date:
            daterange += ' - ' + self.dateend.strftime(DATE_FORMAT)

        nodes = self.require_plots + self.layers
        if not nodes:
            nodes = ['default']

        return 'Plotter | {dates} | {nodes}'.format(
            dates=daterange,
            nodes=', '.join(nodes),
        )

    def init_nile(self):
        """Create the YQL cluster environment and store a fresh job in ``self.job``.

        The YT proxy defaults to ``'<cluster>.yt.yandex.net'``. The YQL proxies
        default to ``yql.yandex.net`` / ``yql.yandex-team.ru``.
        """
        env_templates = {'title': self._get_title()}
        if self.store_checkpoints:
            env_templates['checkpoints_root'] = self.store_checkpoints

        cluster_params = dict(
            yql_token=self.token,
            yql_token_for={'yt'},
            pool=self.pool,
        )

        if self.yt_proxy:
            cluster_params['proxy'] = self.yt_proxy
        else:
            cluster_params['proxy'] = '{}.yt.yandex.net'.format(self.cluster)

        if self.yql_proxy:
            cluster_params['yql_proxy'] = self.yql_proxy
        else:
            cluster_params['yql_proxy'] = 'yql.yandex.net'
            cluster_params['yql_web_proxy'] = 'yql.yandex-team.ru'

        cluster = clusters.yql.YQL(**cluster_params).env(
            templates=env_templates,
            yt_spec_defaults=dict(
                pool_trees=["physical"],
                tentative_pool_trees=["cloud"]
            ),
            compression_level=dict(
                final_tables='heaviest',
                tmp_tables='lightest',
            ),
            yql_python_udf_path=self.udf_path
        )
        self.job = cluster.job()
        return self

    def init_plots(self, plot_classes):
        """Instantiate every class in ``plot_classes`` and append it to ``self.plots``.

        ``plot_classes`` may be ``None`` (the ``__init__`` default), which means
        "no plots" instead of a ``TypeError`` on iteration.
        """
        for plot_class in plot_classes or ():
            self.plots.append(
                plot_class(self.job, self.date, self.dateend, self.additional_days, self.stat_token,
                           self.stat_output_path)
            )
        return self

    def init_nodes(self):
        """Register every ``require``-decorated method of every plot as a graph node.

        A method is recognised by its ``*require_params`` varargs, which is the
        signature of the wrapper produced by ``require``. Calling it returns
        ``(deps, fn_kwargs, function)``. The node is stored in ``self.nodes`` as
        ``(plot, function, fn_kwargs)``, and its cluster-filtered deps are stored
        in ``self.deps``.
        """
        # inspect.getargspec was removed in Python 3.11; getfullargspec exposes the same ``varargs``.
        getargspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
        for plot in self.plots:
            print('    for plot: {} {}'.format(plot, id(plot)))
            plot_name = plot.__class__.__name__
            for fn_name, fn in inspect.getmembers(plot, predicate=inspect.ismethod):
                argspec = getargspec(fn)
                if not (argspec.varargs and 'require_params' in argspec.varargs):
                    continue
                deps, fn_kwargs, function = fn()
                deps = self._fix_clusters_path(deps)
                print('        save extractor {}.{}:  {}'.format(plot_name, fn_name, deps))
                if fn_kwargs:
                    print('            fn_kwargs: {} '.format(fn_kwargs))
                node_name = plot_name + '.' + fn_name
                self.nodes[node_name] = (plot, function, fn_kwargs)
                self.deps[node_name].update(deps)
        return self

    def _fix_clusters_path(self, deps):
        """Resolve cluster-specific table paths for the current cluster.

        A dep written as ``'hahn://path'`` or ``'arnold://path'`` is kept as
        ``'//path'`` only when ``self.cluster`` matches its prefix, and is dropped
        otherwise. All other deps are returned unchanged, in their original order.
        """
        new_deps = []
        for dep in deps:
            for cluster_name in ('hahn', 'arnold'):
                if dep.startswith(cluster_name + '://'):
                    if self.cluster == cluster_name:
                        # Strip only the leading '<cluster>:'; str.replace would also
                        # rewrite any later occurrence inside the path itself.
                        new_deps.append(dep[len(cluster_name) + 1:])
                    break
            else:
                new_deps.append(dep)
        return new_deps

    def init_raw_log_streams(self):
        """Create a stream wrapper for every ``RedirLogBaseStream`` subclass plus user sessions."""
        for TStream in RedirLogBaseStream.__subclasses__():
            # TODO: do not initialise every log up front; load lazily in resolve_raw_logs_deps
            stream = TStream(self.job, self.date, self.dateend, additional_days=self.additional_days)
            self.log_streams[stream.log] = stream
        self.log_streams['user_sessions'] = UserSessionsStream(self.job, self.date, self.dateend, use_1p=self.use_1p)
        return self

    def sift_nodes(self):
        """Drop every node that the requested plots/layers do not (transitively) need.

        The roots are ``self.require_plots`` plus every node whose ``layer`` kwarg
        is in ``self.layers``. When neither is given, the roots are every node of
        the ``'default'`` layer. The dependency closure of the roots is kept; all
        other nodes are removed from ``self.nodes`` and ``self.deps``.
        """
        require_plots = set()
        if self.require_plots:
            require_plots.update(self.require_plots)
        if self.layers:
            for plot_name, (plot, function, fn_kwargs) in self.nodes.items():
                if fn_kwargs.get('layer', 'default') in self.layers:
                    require_plots.add(plot_name)
        elif not self.require_plots:
            # neither --plots nor --layers given: use every node of the default layer
            for plot_name, (plot, function, fn_kwargs) in self.nodes.items():
                if fn_kwargs.get('layer', 'default') == 'default':
                    require_plots.add(plot_name)

        print('sift_nodes, keep: {}'.format(require_plots))
        require_nodes = set()

        # TODO: raise an error on a typo in -p (unknown names are silently ignored now)
        while require_plots:
            plot = require_plots.pop()
            # Skip non-node names (raw logs, tables, typos) and already-visited nodes.
            # The visited check also prevents an endless loop on cyclic dependencies.
            if plot not in self.deps or plot in require_nodes:
                continue
            require_nodes.add(plot)
            require_plots.update(self.deps[plot])

        # iterate over a snapshot: entries are deleted inside the loop
        for node in list(self.nodes):
            if node not in require_nodes:
                print('    delete extractor {}:  {}'.format(node, self.deps[node]))
                del self.nodes[node]
                del self.deps[node]

    def save_checkpoints(self):
        """With ``force_checkpoints``, checkpoint every stream requested via ``--plots``."""
        if self.force_checkpoints:
            for plot in self.require_plots:
                if plot in self.streams:
                    stream = self.streams[plot]
                    stream.checkpoint(self.nodes[plot][0].get_checkpoint_name(plot))

    def save_stream(self, stream, name, plot_class=None, plot_fn=None):
        """Label ``stream`` as ``name`` and register it in ``self.streams``.

        For output tables, also remember which plot class and method produced
        them, for the visualisation.
        """
        if is_output_table(stream._parent):
            self.output_tables[get_table_path(stream._parent)] = (plot_class, plot_fn, name)

        stream.label(name)
        self.streams[name] = stream

    def get_raw_logs_deps(self, batch):
        """Map each raw log used by nodes in ``batch`` to its ``(plot, fn, kwargs)`` extractors."""
        raw_logs_deps = defaultdict(list)
        for node in batch:
            plot_class, plot_fn, plot_fn_kwargs = self.nodes[node]
            for dep in self.deps[node]:
                if dep in RAW_LOGS:
                    raw_logs_deps[dep].append((plot_class, plot_fn, plot_fn_kwargs))
        return raw_logs_deps

    def resolve_raw_logs_deps(self, raw_logs_deps):
        """Attach extractors to each raw log stream and save the nile streams it yields."""
        for log, extractors in raw_logs_deps.items():
            for plot_class, plot_fn, plot_fn_kwargs in extractors:
                self.log_streams[log].add_extractor(plot_class, plot_fn, plot_fn_kwargs)
            self.log_streams[log].fulfill()

            for plot_class, plot_fn, nile_stream in self.log_streams[log].get_nile_streams():
                stream_name = plot_class.__class__.__name__ + '.' + plot_fn.__name__
                self.save_stream(nile_stream, stream_name, plot_class, plot_fn)

    def get_raw_tables_deps(self, batch):
        """Map each ``'//'`` table path used by nodes in ``batch`` to its ``(plot, fn)`` users."""
        raw_tables_deps = defaultdict(list)
        for node in batch:
            plot_class, plot_fn, plot_fn_kwargs = self.nodes[node]
            for dep in self.deps[node]:
                if dep.startswith('//'):
                    raw_tables_deps[dep].append((plot_class, plot_fn))
        return raw_tables_deps

    def resolve_raw_tables_deps(self, raw_tables_deps):
        """Read every table in ``raw_tables_deps`` and save its stream under the table path."""
        for table_name in raw_tables_deps.keys():
            stream = TableStream(self.job, table_name).get_stream()
            self.save_stream(stream, table_name)

    def can_resolve_deps(self, deps):
        """Return True if every dep is a raw log, a ``'//'`` table or an existing stream."""
        for d in deps:
            if d in RAW_LOGS:
                continue
            if d.startswith('//'):
                continue
            if d in self.streams:
                continue
            return False
        return True

    def get_resolvable_deps(self, current_dependencies):
        """Return the names in ``current_dependencies`` whose deps can all be resolved now."""
        return {name
                for name, deps in current_dependencies.items()
                if self.can_resolve_deps(deps)}

    def build_flow_graph(self):  # noqa
        """Resolve all nodes in dependency order and attach their streams to ``self.job``.

        Each round takes the batch of nodes whose dependencies are all available
        (raw logs, ``'//'`` tables or already-built streams). It feeds raw logs and
        tables into the job, then calls each remaining node function with a
        ``{dep_name: stream}`` dict. A node function may return a single stream, a
        generator of streams or ``(suffix, stream)`` pairs, or a falsy value (no
        output).

        Raises:
            ValueError: if a node depends on a stream that was never produced, or
                if some nodes can never be resolved (e.g. cyclic or unknown deps).
        """
        current_dependencies = deepcopy(self.deps)
        batch = self.get_resolvable_deps(current_dependencies)
        while batch:
            raw_logs_deps = self.get_raw_logs_deps(batch)
            self.resolve_raw_logs_deps(raw_logs_deps)

            raw_tables_deps = self.get_raw_tables_deps(batch)
            self.resolve_raw_tables_deps(raw_tables_deps)

            print('batch: {}, raw_logs_deps: {}, raw_tables_deps: {}'.format(
                batch, list(raw_logs_deps.keys()), list(raw_tables_deps.keys())
            ))

            for node in batch:
                if node in self.streams:
                    # already produced, e.g. by a raw-log extractor above
                    continue

                deps = self.deps[node]
                # .get() lets a missing dependency (e.g. from a node that returned
                # nothing) reach the descriptive error below instead of a bare KeyError.
                deps_streams = {x: self.streams.get(x) for x in deps}
                if None in deps_streams.values():
                    msg = 'NONE DEPS FOUND: {} in streams: {}'.format(deps, self.streams)
                    raise ValueError(msg)

                plot_class, plot_fn, plot_fn_kwargs = self.nodes[node]
                result = plot_fn(plot_class, deps_streams)
                if not result:
                    continue

                # a node may return a single nile stream or be a generator of several streams
                if isinstance(result, types.GeneratorType):
                    for nile_stream in result:
                        stream_name = node
                        if isinstance(nile_stream, tuple):
                            # generator items are either a bare stream or a (name, stream) pair
                            stream_name += FLOW_NAME_SEPARATOR + nile_stream[0]
                            nile_stream = nile_stream[1]
                        self.save_stream(nile_stream, stream_name, plot_class, plot_fn)
                else:
                    self.save_stream(result, node, plot_class, plot_fn)

            # the batch is resolved: remove it from the pending graph
            for name in batch:
                del current_dependencies[name]
            for deps in current_dependencies.values():
                deps.difference_update(batch)
            batch = self.get_resolvable_deps(current_dependencies)

        if current_dependencies:
            msg = 'There are not resolved \n\n deps: {},\n\n streams: {}'.format(current_dependencies, self.streams)
            raise ValueError(msg)

        if self.force_checkpoints:
            self.save_checkpoints()

    def get_flow_graph(self):
        """Return the job's nile flow graph."""
        return self.job.flow_graph

    def get_optimized_flow_graph(self):
        """Placeholder; always returns None.

        ``self.job.optimized_flow_graph()`` currently fails, apparently because of
        the ``checkpoint``/``label`` calls on nile streams.
        """
        pass

    def save_plotter_vis(self, filename):
        """Write the HTML visualisation of this plotter to ``filename`` (UTF-8)."""
        with codecs.open(filename, 'w', encoding='utf-8') as f:
            f.write(get_plotter_vis_html(self))

    def run(self):
        """Execute the job, optionally storing or starting from nile checkpoints."""
        self.job.run(store_checkpoints=bool(self.store_checkpoints), checkpoints=self.checkpoints)


def require(*logs, **kwargs):
    """Decorator that turns a ``Plot`` method into a flow-graph node.

    ``logs`` are the node's dependency names: raw logs, ``'//'`` tables,
    cluster-prefixed tables, or other node/stream names. ``kwargs`` are node
    options such as ``layer``.

    The decorated attribute becomes a wrapper that takes ``*require_params``.
    ``Plotter.init_nodes`` detects nodes by that varargs name, so it must not be
    renamed. Calling the wrapper returns ``(logs, kwargs, original_function)``.
    """
    def decorate(node_function):
        def inner(self, *require_params):
            return logs, kwargs, node_function
        return inner
    return decorate


def parse_args():
    """Parse the Plotter command line from ``sys.argv``.

    ``--date`` and ``--dateend`` are converted to ``datetime.datetime`` using
    ``DATE_FORMAT``. The other options are strings, an int, lists or flags.
    Returns the resulting ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Plotter')
    add = parser.add_argument

    # date range
    add('--date', required=True,
        type=lambda d: datetime.datetime.strptime(d, DATE_FORMAT),
        help='Дата [начала], за которую обрабатывать логи в формате YYYY-MM-DD')
    add('--dateend',
        type=lambda d: datetime.datetime.strptime(d, DATE_FORMAT),
        help='Дата окончания диапазона дат [date, dateend] в формате YYYY-MM-DD, необязательный параметр')
    add('--additional-days', type=int, default=0,
        help='Дополнительные дни в redir-log, считаемые относительно dateend. Нужно для некоторых графиков 30 дней')

    # credentials and cluster
    add('--token', help='YT token')
    add('--pool', help='YT pool')
    add('--stat-token', help='Statface token')
    add('--cluster', default='hahn',
        help='YT Cluster: hahn или arnold, по-умолчанию: hahn')
    add('--udf-path', default='udf/libplotter.so',
        help='Путь к so-библиотеке. По-умолчанию: относительный путь udf/libplotter.so')

    # execution and debugging
    add('--dry-run', '-d', action='store_true',
        help='Не запускать операции в YQL. Удобно для отладки или визуализации')
    add('--save-flowchart', nargs='?', const='plotter-vis.html',
        help='Задаёт путь для сохранения html вьюера графа. По-умолчанию не пишется. При указании без параметров: plotter-vis.html')

    # node selection and checkpoints
    add('-p', '--plots', nargs='*',
        help='Классы графиков, которые пересчитать. По-умолчанию — все из default layer')
    add('-s', '--store-checkpoints', required=False,
        help='Путь в YT без слеша на конце, куда сохранять данные чекпоинтов')
    add('-c', '--checkpoints', nargs='*',
        help='Включает режим nile — начать расчёт с заданных чекпоинтов')
    add('--force-checkpoints', '-fc', action='store_true',
        help='Нужно ли все вычисляемые ноды сложить в YT в виде чекпоинтов')
    add('--layers', '--layer', '-l', nargs='*',
        help='Слои для расчёта только тех узлов (графиков), которые помечены декоратором layer=NNN в require. По-умолчанию - всё, что без декоратора')

    # data sources and output
    add('--1p', '-1p', action='store_true', dest='use_1p',
        help='Использовать ли однопроцентные юзерсессии (в отладочных целях)')
    add('--stat-output-path', required=False,
        help='Включает дебаг режим, в котором выход пишется в таблички на yt, а не в отчеты на стате. '
             'Путь в YT без слеша на конце, куда сохранять данные вместо отчетов на стате')

    return parser.parse_args()


def run_plotter(plot_classes):
    """CLI entry point: build the flow graph for ``plot_classes`` and run it.

    Command-line options come from ``parse_args``. With ``--dry-run`` the graph
    is built but the job is not run. With ``--save-flowchart`` the HTML
    visualisation is written after the (optional) run. The Statface token falls
    back to the YT token when ``--stat-token`` is not given.
    """
    args = parse_args()

    plotter_options = dict(
        date=args.date,
        dateend=args.dateend,
        cluster=args.cluster,
        pool=args.pool,
        token=args.token,
        stat_token=args.stat_token or args.token,
        require_plots=args.plots,
        store_checkpoints=args.store_checkpoints,
        checkpoints=args.checkpoints,
        layers=args.layers,
        udf_path=args.udf_path,
        force_checkpoints=args.force_checkpoints,
        use_1p=args.use_1p,
        additional_days=args.additional_days,
        plot_classes=plot_classes,
        stat_output_path=args.stat_output_path,
    )
    plotter = Plotter(**plotter_options)
    plotter.build_flow_graph()

    if not args.dry_run:
        plotter.run()

    if args.save_flowchart:
        plotter.save_plotter_vis(args.save_flowchart)
        # plotter.get_flow_graph().save_to_file('flow_graph.png')
