import crypta.lib.python.bt.conf.conf as conf

import crypta.lib.python.bt.workflow as workflow
from crypta.lib.python.bt.workflow.targets.table import (
    HasAttribute,
    Exists,
)
from crypta.lib.python.logging import log_context
from yt.wrapper import (
    YtClient,
    TablePath,
    ypath_join,
)
from yt.wrapper.errors import (
    YtOperationFailedError,
)
from library.python.svn_version import svn_revision
from crypta.lib.python.yql.client import (
    create_yql_client,
    create_yql_embedded_client,
)
from crypta.lib.python.native_yt import (
    run_native_map,
    run_native_map_reduce,
    run_native_map_reduce_with_combiner,
    run_native_reduce,
    run_native_join_reduce,
)
from cached_property import cached_property

import abc

import contextlib

import pwd
import os
import socket
import logging
import six


logger = logging.getLogger(__name__)


def _log_yt_operation_fail(operation):
    """Log every stderr entry returned by ``operation.get_stderrs()`` at ERROR level.

    Each entry must provide a ``'host'`` key; a missing ``'stderr'`` key is
    reported as ``'No stderr'``.
    """
    for job in operation.get_stderrs():
        host = job['host']
        stderr = job.get('stderr', 'No stderr')
        logger.error('Job failed (%s): %s', host, stderr)


def _collect_destination_info(ytclient, destination):
    def get_destination_info(path):
        return {'row_count': ytclient.get_attribute(path, 'row_count')}

    destination = destination if isinstance(destination, (list, tuple)) else [destination]

    return {str(each): get_destination_info(each) for each in destination}


def _log_outcome(operation):
    """Log job statistics (DEBUG) and output row counts (INFO) of an operation.

    Relies on the ``_ytclient`` and ``_destination`` attributes that
    ``_setup_operation`` attaches to the operation object.
    """
    client = operation._ytclient

    logger.debug('Operation %s (transaction %s) stats: %s',
                 operation.id, client.COMMAND_PARAMS["transaction_id"],
                 operation.get_job_statistics())

    outputs = _collect_destination_info(client, operation._destination)
    logger.info('Operation %s (transaction %s) got following outputs: %s',
                operation.id, client.COMMAND_PARAMS["transaction_id"], outputs)


def wait_operation(operation):
    """Block until ``operation`` finishes, log its outcome and return it.

    The operation is returned on success so that callers (for example
    ``wait_if_sync``) receive the same kind of object whether or not they
    waited; previously the synchronous path yielded ``None``.

    Raises:
        YtOperationFailedError: re-raised after the stderr of the failed
            jobs has been logged.
    """
    try:
        operation.wait()
        # Kept inside the ``try`` so that a failure surfacing while stats are
        # collected is logged the same way as a failure of ``wait()``.
        _log_outcome(operation)
    except YtOperationFailedError:
        _log_yt_operation_fail(operation)
        raise
    return operation


def wait_if_sync(operation, sync):
    """Wait for ``operation`` when ``sync`` is truthy, otherwise hand it back.

    In the asynchronous case the operation object itself is returned
    untouched; in the synchronous case the result of ``wait_operation`` is
    returned.
    """
    if not sync:
        return operation
    return wait_operation(operation)


def _check_operations(operations):
    operation_states = []
    # If couldn't check operation state with
    # yt_wrapper.config['proxy']['request_retry_count'] reties
    # exception raises
    for operation in operations:
        state = operation.get_state()
        operation_states.append((operation, state))
    return operation_states


def _handle_failed_operations(operation_states):
    failed_operations = [operation for operation, state in operation_states
                         if state.is_unsuccessfully_finished()]
    if failed_operations:
        # Raises an exception
        try:
            for failed_operation in failed_operations:
                # simple way to raise YtOperationFailedError properly
                failed_operation.wait()
        except YtOperationFailedError:
            _log_yt_operation_fail(failed_operation)
            raise


def _put_default(dict_, key, value):
    if key not in dict_:
        dict_[key] = value


def _add_spec_defaults(kwargs, title, pool, script_name=None):
    """Fill missing operation spec entries in ``kwargs['spec']`` in place.

    Defaults set when absent: ``pool`` (``pool`` or ``conf.yt.pool``),
    ``title``, ``enable_legacy_live_preview`` (False) and, when
    ``script_name`` is truthy, ``annotations['script_name']``.
    Values already present in the spec are left untouched.
    """
    _put_default(kwargs, 'spec', {})
    spec = kwargs['spec']

    _put_default(spec, 'pool', pool or conf.yt.pool)
    _put_default(spec, 'title', title)
    _put_default(spec, 'enable_legacy_live_preview', False)

    if not script_name:
        return
    _put_default(spec, 'annotations', {})
    _put_default(spec['annotations'], 'script_name', script_name)


def _add_native_spec_defaults(spec, script_name=None):
    if script_name:
        _put_default(spec, 'annotations', {})
        _put_default(spec['annotations'], 'script_name', script_name)


# Attribute presets keyed by storage mode name; looked up by
# ``_storage_attributes``.
STORAGE_MODES = {
    # No extra attributes.
    'default': {},
    # Erasure coding plus brotli compression.
    'compact': {
        'erasure_codec': 'lrc_12_2_2',
        'compression_codec': 'brotli_3',
    },
    # Compression switched off.
    'fast': {
        'compression_codec': 'none',
    },
}


def _storage_attributes(storage_mode):
    """Return the ``STORAGE_MODES`` entry for ``storage_mode``.

    Raises:
        ValueError: if ``storage_mode`` is not a key of ``STORAGE_MODES``.
    """
    if storage_mode not in STORAGE_MODES:
        raise ValueError('Unknown storage mode {}'.format(storage_mode))
    return STORAGE_MODES[storage_mode]


def _schema(schema):
    if schema == 'default':
        return {}
    return {'schema': schema}


def _setup_operation(operation, ytclient, source, destination):
    operation._source = source
    operation._destination = destination
    operation._ytclient = ytclient


class ConfigurableYtProxyTask(six.with_metaclass(abc.ABCMeta, workflow.Task)):
    """Task base class that builds its YT client from ``conf.yt`` with optional overrides.

    The ``proxy`` and ``pool`` parameters, when set, override the
    corresponding configuration entries and are propagated to dependencies
    that do not specify their own.
    """

    # Optional YT proxy URL; replaces config['proxy']['url'] when truthy.
    proxy = workflow.Parameter(default=None)

    # Optional pool name; replaces config['pool'] when truthy.
    pool = workflow.Parameter(default=None)

    @cached_property
    def yt_config(self):
        """Return the YT configuration dict with this task's overrides applied.

        Computed once per task instance (``cached_property``).
        """
        config = conf.yt.to_dict()
        if self.proxy:
            config['proxy']['url'] = self.proxy
        if self.pool:
            config['pool'] = self.pool

        config['operation_tracker'] = {'progress_logging_level': 'DEBUG'}

        # NOTE(review): the whole config is logged; if it carries the YT token
        # this ends up in debug logs - confirm that is acceptable.
        logger.debug("Using YT config: %s", config)
        return config

    def process_dependency(self, dep):
        """ Inherit non-default parameters unless they are specified. """
        dep = super(ConfigurableYtProxyTask, self).process_dependency(dep)
        if isinstance(dep, ConfigurableYtProxyTask):
            dep.proxy = dep.proxy or self.proxy
            dep.pool = dep.pool or self.pool
        return dep

    @cached_property
    def yt(self):
        """Return the YT client of this task, created on first access."""
        return self._init_yt()

    def _init_yt(self, proxy=None):
        """Create a new ``YtClient`` configured from ``yt_config``.

        Args:
            proxy: optional proxy URL used for this client only.

        Returns:
            A ``YtClient`` with ``TablePath`` attached as an attribute.
        """
        yt_config = dict(self.yt_config)
        if proxy is not None:
            # ``dict(...)`` above is a shallow copy, so the nested 'proxy'
            # mapping is still the one held by the cached ``yt_config``.
            # Copy it before overriding the URL, otherwise the override would
            # leak into every later reader of ``self.yt_config``.
            yt_config['proxy'] = dict(yt_config['proxy'], url=proxy)
        ytclient = YtClient()
        ytclient.TablePath = TablePath
        conf.setup_yt(ytclient, yt_config)
        return ytclient

    def ypath_join(self, *args):
        """Join path components with ``yt.wrapper.ypath_join``."""
        return ypath_join(*args)


class TransactionalYtTask(six.with_metaclass(abc.ABCMeta, ConfigurableYtProxyTask)):
    """Task whose run context is wrapped in a single YT transaction."""

    @contextlib.contextmanager
    def run_context(self):
        """Open a YT transaction and yield it as ``{'yt_transaction': ...}``.

        The transaction id is stored in ``self.transaction_id`` and added to
        the logging context for the duration of the block. The second log
        line is emitted only when the wrapped block finishes without raising.
        """
        with self.yt.Transaction(timeout=conf.yt.main_transaction_timeout) as transaction:
            self.transaction_id = transaction.transaction_id
            with log_context(yt_transaction_id=self.transaction_id):
                logger.info('Running %s within YT transaction %s',
                            str(self), self.transaction_id)
                yield {'yt_transaction': transaction}
                logger.info('Task %s is going to commit YT transaction %s',
                            str(self), self.transaction_id)


class YtTask(TransactionalYtTask):
    """Base class for YT tasks.

    Wraps the YT client's map/reduce/sort/merge operations and the native
    ``run_native_*`` helpers. Each wrapper fills in spec defaults (pool,
    title, script name) and logs the operation together with the
    transaction id.
    """

    def _tracing(self):
        """Return a dict identifying the host, user, task, revision, pid and transaction."""
        return dict(
            hostname=socket.getfqdn(),
            username=pwd.getpwuid(os.getuid())[0],
            task=str(self),
            revision=svn_revision(),
            pid=os.getpid(),
            yt_transaction=self.transaction_id,
        )

    def _create_empty_table(self, yt, table):
        """Create ``table`` (recursively) through ``yt`` unless it already exists."""
        path = self.yt.TablePath(table, attributes={
            "crypta_tracing": self._tracing(),
        })
        if not yt.exists(path):
            logger.info('Creating table %s', path)
            yt.create('table', path, recursive=True)

    def _create_destination(self, yt, destination):
        """Make sure every destination table exists.

        ``destination`` is a single table or a list/tuple of tables; tuples
        are accepted for consistency with ``_collect_destination_info``.
        """
        if isinstance(destination, (list, tuple)):
            for table in destination:
                self._create_empty_table(yt, table)
        else:
            self._create_empty_table(yt, destination)

    def _native_kwargs(self, spec=None, sync=True, **kwargs):
        """Build the keyword arguments shared by the ``run_native_*`` calls.

        Adds token, pool, title, spec and sync to ``kwargs`` and drops every
        entry whose value is None. A non-empty ``spec`` passed by the caller
        is updated in place with the script-name annotation.
        """
        spec = spec or {}
        _add_native_spec_defaults(spec, script_name=self.full_name())
        kwargs = dict(
            token=conf.yt.token,
            pool=(self.pool or conf.yt.pool),
            title=self.operations_title,
            spec=spec,
            sync=sync,
            **kwargs
        )
        return {k: v for (k, v) in kwargs.items() if v is not None}

    def _table_paths(self, paths):
        """Return a list of ``TablePath`` objects for a path or a list/tuple of paths."""
        paths = paths if isinstance(paths, (list, tuple)) else [paths]
        return [self.yt.TablePath(path) for path in paths]

    def list(self, path):
        """List ``path`` with absolute names, requesting up to 100000 entries."""
        return self.yt.list(path, max_size=100000, absolute=True)

    @property
    def operations_title(self):
        """Operation title of the form ``ShortName(param=value;...)``."""
        parameters = ('{}={}'.format(k, v)
                      for (k, v) in self._parameters_with_values())
        return '{}({})'.format(self.short_name(), ';'.join(parameters))

    def reduce(self, reducer, source, destination, sync=True, *args, **kwargs):
        """Start a reduce operation and wait for it when ``sync`` is truthy.

        Extra positional and keyword arguments are forwarded to
        ``self.yt.run_reduce``. Returns the result of ``wait_if_sync``.
        """
        _add_spec_defaults(kwargs, title=self.operations_title, pool=self.pool, script_name=self.full_name())
        self._create_destination(self.yt, destination)
        # Always started asynchronously; waiting is done by wait_if_sync so
        # that the outcome is logged.
        operation = self.yt.run_reduce(reducer, source, destination, sync=False, *args, **kwargs)
        _setup_operation(operation, self.yt, source, destination)
        logger.info('Reducing [%s] into [%s] in %s (transaction %s)',
                    source, destination, operation.id, self.transaction_id)
        return wait_if_sync(operation, sync)

    def map(self, mapper, source, destination, sync=True, *args, **kwargs):
        """Start a map operation and wait for it when ``sync`` is truthy.

        Extra positional and keyword arguments are forwarded to
        ``self.yt.run_map``. Returns the result of ``wait_if_sync``.
        """
        _add_spec_defaults(kwargs, title=self.operations_title, pool=self.pool, script_name=self.full_name())
        self._create_destination(self.yt, destination)
        operation = self.yt.run_map(mapper, source, destination, sync=False, *args, **kwargs)
        _setup_operation(operation, self.yt, source, destination)
        logger.info('Mapping [%s] into [%s] in %s (transaction %s)',
                    source, destination, operation.id, self.transaction_id)
        return wait_if_sync(operation, sync)

    def native_map(self, mapper_name, source, destination, state=None, spec=None, files=None, sync=True):
        """Run a native map via ``run_native_map`` inside the client's current transaction.

        Returns whatever ``run_native_map`` returns. The destination row
        counts are logged right after the call returns.
        """
        transaction = self.yt.COMMAND_PARAMS['transaction_id']
        proxy = self.yt.config['proxy']['url']
        logger.info('Mapping [%s] into [%s] using native %s operation (transaction %s)',
                    source, destination, mapper_name, transaction)
        operation = run_native_map(mapper_name, self._table_paths(source), destination, proxy, str(transaction),
                                   **self._native_kwargs(state=state, mapper_files=files or [], spec=spec, sync=sync))
        logger.info('Operation %s is done (transaction %s), outcomes are %s',
                    operation, transaction, _collect_destination_info(self.yt, destination))
        return operation

    def native_reduce(self, reducer_name, source, destination, reduce_by, sort_by=None, state=None, spec=None, files=None, sync=True):
        """Run a native reduce via ``run_native_reduce`` inside the client's current transaction.

        Returns whatever ``run_native_reduce`` returns. The destination row
        counts are logged right after the call returns.
        """
        transaction = self.yt.COMMAND_PARAMS['transaction_id']
        proxy = self.yt.config['proxy']['url']
        logger.info('Reducing [%s] into [%s] using native %s operation (transaction %s)',
                    source, destination, reducer_name, transaction)
        operation = run_native_reduce(reducer_name, self._table_paths(source), destination, reduce_by,
                                      proxy, str(transaction),
                                      sort_by=sort_by,
                                      **self._native_kwargs(reducer_state=state, reducer_files=files or [], spec=spec, sync=sync))
        logger.info('Operation %s is done (transaction %s), outcomes are %s',
                    operation, transaction, _collect_destination_info(self.yt, destination))
        return operation

    def native_map_reduce(self, mapper_name, reducer_name, source, destination, reduce_by,
                          sort_by=None, mapper_state=None, reducer_state=None, spec=None,
                          mapper_files=None, reducer_files=None, sync=True):
        """Run a native map-reduce via ``run_native_map_reduce``.

        ``sort_by`` falls back to ``reduce_by`` when it is falsy. Returns
        whatever ``run_native_map_reduce`` returns.
        """
        sort_by = sort_by or reduce_by
        transaction = self.yt.COMMAND_PARAMS['transaction_id']
        proxy = self.yt.config['proxy']['url']
        logger.info('Map-Reducing [%s] into [%s] using native map %s, reduce %s operation (transaction %s)',
                    source, destination, mapper_name, reducer_name, transaction)
        operation = run_native_map_reduce(mapper_name, reducer_name, self._table_paths(source), destination, reduce_by, sort_by,
                                          proxy, str(transaction),
                                          **self._native_kwargs(mapper_state=mapper_state, reducer_state=reducer_state, spec=spec,
                                                                mapper_files=mapper_files or [], reducer_files=reducer_files or [], sync=sync))
        logger.info('Operation %s is done (transaction %s), outcomes are %s',
                    operation, transaction, _collect_destination_info(self.yt, destination))
        return operation

    def native_map_reduce_with_combiner(self, mapper_name, combiner_name, reducer_name, source, destination, reduce_by,
                                        sort_by=None, mapper_state=None, reducer_state=None, combiner_state=None, spec=None,
                                        mapper_files=None, reducer_files=None, sync=True):
        """Run a native map-reduce with a combiner via ``run_native_map_reduce_with_combiner``.

        ``sort_by`` falls back to ``reduce_by`` when it is falsy. Unlike
        ``native_map_reduce``, file lists left as None are not replaced by
        empty lists; ``_native_kwargs`` omits them from the call instead.
        """
        sort_by = sort_by or reduce_by
        transaction = self.yt.COMMAND_PARAMS['transaction_id']
        proxy = self.yt.config['proxy']['url']
        logger.info('Map-Reducing [%s] into [%s] using native map %s, reduce %s, combiner %s operation (transaction %s)',
                    source, destination, mapper_name, reducer_name, combiner_name, transaction)
        operation = run_native_map_reduce_with_combiner(
            mapper_name, combiner_name, reducer_name, self._table_paths(source), destination, reduce_by, sort_by, proxy, str(transaction),
            **self._native_kwargs(mapper_state=mapper_state, reducer_state=reducer_state, combiner_state=combiner_state, spec=spec,
                                  mapper_files=mapper_files, reducer_files=reducer_files, sync=sync))
        logger.info('Operation %s is done (transaction %s), outcomes are %s',
                    operation, transaction, _collect_destination_info(self.yt, destination))
        return operation

    def native_join_reduce(self, reducer_name, source, destination, join_by, state=None, spec=None, files=None, sync=True):
        """Run a native join-reduce via ``run_native_join_reduce``.

        Returns whatever ``run_native_join_reduce`` returns. The destination
        row counts are logged right after the call returns.
        """
        transaction = self.yt.COMMAND_PARAMS['transaction_id']
        proxy = self.yt.config['proxy']['url']
        logger.info('Join-reducing [%s] into [%s] using native %s operation (transaction %s)',
                    source, destination, reducer_name, transaction)
        operation = run_native_join_reduce(reducer_name, self._table_paths(source), destination, join_by,
                                           proxy, str(transaction),
                                           **self._native_kwargs(reducer_state=state, spec=spec, reducer_files=files or [], sync=sync))
        logger.info('Operation %s is done (transaction %s), outcomes are %s',
                    operation, transaction, _collect_destination_info(self.yt, destination))
        return operation

    def map_reduce(self, mapper, reducer, source, destination, sync=True, *args, **kwargs):
        """Start a map-reduce operation and wait for it when ``sync`` is truthy.

        Extra positional and keyword arguments are forwarded to
        ``self.yt.run_map_reduce``. Returns the result of ``wait_if_sync``.
        """
        _add_spec_defaults(kwargs, title=self.operations_title, pool=self.pool, script_name=self.full_name())
        self._create_destination(self.yt, destination)
        operation = self.yt.run_map_reduce(mapper, reducer, source, destination, sync=False, *args, **kwargs)
        _setup_operation(operation, self.yt, source, destination)
        logger.info('Map-reducing [%s] into [%s] in %s (transaction %s)',
                    source, destination, operation.id, self.transaction_id)
        return wait_if_sync(operation, sync=sync)

    def sort(self, source, destination, sync=True, *args, **kwargs):
        """Start a sort operation and wait for it when ``sync`` is truthy.

        Extra positional and keyword arguments are forwarded to
        ``self.yt.run_sort``. Returns the result of ``wait_if_sync``.
        """
        _add_spec_defaults(kwargs, title=self.operations_title, pool=self.pool, script_name=self.full_name())
        self._create_destination(self.yt, destination)
        operation = self.yt.run_sort(source, destination, sync=False, *args, **kwargs)
        _setup_operation(operation, self.yt, source, destination)
        # ``sort_by`` may be absent from kwargs (passed positionally or left
        # out). The operation is already running here, so the log line must
        # not raise KeyError and leave it un-waited.
        logger.info('Sorting [%s] into [%s] by [%s] in %s (transaction %s)',
                    source, destination, kwargs.get('sort_by'), operation.id, self.transaction_id)

        return wait_if_sync(operation, sync=sync)

    def merge(self, source, destination, sync=True, *args, **kwargs):
        """Start a merge operation and wait for it when ``sync`` is truthy.

        Extra positional and keyword arguments are forwarded to
        ``self.yt.run_merge``. Returns the result of ``wait_if_sync``.
        """
        _add_spec_defaults(kwargs, title=self.operations_title, pool=self.pool, script_name=self.full_name())
        self._create_destination(self.yt, destination)
        operation = self.yt.run_merge(source, destination, sync=False, *args, **kwargs)
        _setup_operation(operation, self.yt, source, destination)
        logger.info('Merging [%s] into [%s] in %s (transaction %s)',
                    source, destination, operation.id, self.transaction_id)
        return wait_if_sync(operation, sync=sync)

    def join_reduce(self, join_reducer, source, destination, sync=True, *args, **kwargs):
        """Start a join-reduce operation and wait for it when ``sync`` is truthy.

        Extra positional and keyword arguments are forwarded to
        ``self.yt.run_join_reduce``. Returns the result of ``wait_if_sync``.
        """
        _add_spec_defaults(kwargs, title=self.operations_title, pool=self.pool, script_name=self.full_name())
        self._create_destination(self.yt, destination)

        operation = self.yt.run_join_reduce(join_reducer, source, destination, sync=False, *args, **kwargs)
        _setup_operation(operation, self.yt, source, destination)
        logger.info('Running join_reduce [%s] into [%s] in %s (transaction %s)',
                    source, destination, operation.id, self.transaction_id)
        return wait_if_sync(operation, sync=sync)

    def table(self, *args, **kwargs):
        """Construct a ``TablePath`` from the given arguments."""
        return TablePath(*args, **kwargs)


class YQLTask(six.with_metaclass(abc.ABCMeta, TransactionalYtTask)):
    """Synchronous YQL task: runs ``query`` through a YQL client and logs the results."""

    @abc.abstractproperty
    def syntax_version(self):
        """YQL syntax version handed to the client's ``execute`` call."""
        return None

    @abc.abstractproperty
    def query(self):
        """Complete YQL query, without any boilerplate or parameters."""
        return None

    @property
    def attached_files(self):
        """Value passed as ``attached_files`` to ``execute``; None by default."""
        return None

    @property
    def udfs(self):
        """Value passed as ``udfs`` to ``execute``; None by default."""
        return None

    @property
    def yql_libs(self):
        """Value passed as ``yql_libs`` to ``execute``; None by default."""
        return None

    @cached_property
    def yql_client(self):
        """Create the YQL client bound to this task's proxy, pool and transaction.

        The token comes from ``conf.proto.Yql.Token`` when ``conf`` has a
        ``proto`` attribute, otherwise from ``conf.service.yandex.yql.token``.
        """
        if not hasattr(conf, 'proto'):
            token = conf.service.yandex.yql.token
        elif hasattr(conf.proto, 'Yql'):
            token = conf.proto.Yql.Token
        else:
            raise Exception('Failed to obtain YQL token')

        return create_yql_client(
            yt_proxy=self.yt_config['proxy']['url'],
            token=token,
            pool=self.pool or conf.yt.pool,
            transaction=self.transaction_id,
            additional_attributes={"script_name": self.full_name()}
        )

    def run_query(self):
        """Execute ``query`` and yield each result produced by the client."""
        results = self.yql_client.execute(
            self.query,
            syntax_version=self.syntax_version,
            attached_files=self.attached_files,
            udfs=self.udfs,
            yql_libs=self.yql_libs,
        )
        for item in results:
            yield item

    def run(self, **kwargs):
        """Run the query and log every result at DEBUG level."""
        for item in self.run_query():
            logger.debug('YQL result:\n%s', item)


class YQLTaskV0(YQLTask):

    """YQL task whose ``syntax_version`` is 0."""

    @property
    def syntax_version(self):
        """Return 0, the syntax version this task runs its query with."""
        return 0


class YQLTaskV1(YQLTask):

    """YQL task whose ``syntax_version`` is 1."""

    @property
    def syntax_version(self):
        """Return 1, the syntax version this task runs its query with."""
        return 1


class EmbeddedYQLTask(six.with_metaclass(abc.ABCMeta, TransactionalYtTask)):
    """Synchronous YQL task executed through the embedded YQL client."""

    @abc.abstractproperty
    def syntax_version(self):
        """YQL syntax version; passed to ``execute`` only when it is not None."""
        return None

    @abc.abstractproperty
    def query(self):
        """Complete YQL query, without any boilerplate or parameters."""
        return None

    @cached_property
    def yql_embedded_client(self):
        """Create the embedded YQL client for this task.

        ``mrjob_binary``, ``udf_resolver_binary`` and ``udfs_dir`` are taken
        from attributes of the same name when the task defines them, and
        otherwise fall back to the defaults below.
        """
        mrjob_binary = getattr(self, 'mrjob_binary', 'yql/tools/mrjob/mrjob')
        udf_resolver_binary = getattr(self, 'udf_resolver_binary', 'yql/tools/udf_resolver/udf_resolver')
        udfs_dir = getattr(self, 'udfs_dir', 'yql/udfs')

        return create_yql_embedded_client(
            yt_proxy=self.yt_config['proxy']['url'],
            token=conf.yt.token,
            pool=self.pool or conf.yt.pool,
            transaction=self.transaction_id,
            mrjob_binary=mrjob_binary,
            udf_resolver_binary=udf_resolver_binary,
            udfs_dir=udfs_dir,
            title=self.title()
        )

    def run_query(self):
        """Execute ``query`` and return the client's result."""
        # The syntax version is forwarded only when the subclass provides one.
        options = {} if self.syntax_version is None else {'syntax_version': self.syntax_version}
        return self.yql_embedded_client.execute(self.query, **options)

    def run(self, **kwargs):
        """Run the query and log its result at DEBUG level."""
        outcome = self.run_query()
        logger.debug('YQL result:\n%s', outcome)


class EmbeddedYQLTaskV1(EmbeddedYQLTask):

    """Embedded YQL task whose ``syntax_version`` is 1."""

    @property
    def syntax_version(self):
        """Return 1, the syntax version this task runs its query with."""
        return 1


class EmbeddedYQLTaskV0(EmbeddedYQLTask):

    """Embedded YQL task whose ``syntax_version`` is 0."""

    @property
    def syntax_version(self):
        """Return 0, the syntax version this task runs its query with."""
        return 0


class ExternalTask(YtTask, workflow.IndependentTask):
    """
    Base class for references to external dependencies.

    An ExternalTask has a no-op `run`, which signifies to the framework
    that this task's :py:meth:`output` is generated outside of it.
    """
    def run(self, **kwargs):
        # Intentionally does nothing.
        pass


class ExternalInput(ExternalTask):
    """External dependency described by an ``Exists`` target on ``yt_table``."""

    # Path of the YT table handed to the ``Exists`` target.
    yt_table = workflow.Parameter()

    def targets(self):
        """Yield a single ``Exists`` target built from this task's client and ``yt_table``."""
        yield Exists(self.yt, self.yt_table)


class ExternalInputWithAttribute(ExternalTask):
    """External dependency described by a ``HasAttribute`` target on ``yt_table``."""

    # Path of the YT table handed to the ``HasAttribute`` target.
    yt_table = workflow.Parameter()
    # Passed as the last argument of ``HasAttribute``
    # (presumably the expected attribute value - verify against the target class).
    value = workflow.Parameter()
    # Attribute name handed to ``HasAttribute``.
    attribute = workflow.Parameter()

    def targets(self):
        """Yield a single ``HasAttribute`` target built from ``yt_table``, ``attribute`` and ``value``."""
        yield HasAttribute(self.yt, self.yt_table, self.attribute, self.value)
