import inspect
import logging
import six
from functools import wraps
from collections import Counter

from sandbox import sdk2
from sandbox.sandboxsdk.task import SandboxTask
from sandbox.common.errors import Wait
from sandbox.common.types.misc import NotExists


# Logger shared by every decorator and helper in this module.
logger = logging.getLogger(name=__name__)

# Context key that receives the name of the stage method which raised.
FAILED_STAGE = '_failed_stage'


class StageRequirementsUnsatisfied(Exception):
    """Raised when context fields a stage requires are not filled yet."""


class DuplicatedResultFields(Exception):
    """Raised when a stage's ``provides`` declaration lists a field twice."""


class StageMissingResults(Exception):
    """Raised when a stage returns fewer results than it declared."""


class InconsistentRequirements(Exception):
    """Raised when declared ``requires`` disagree with the method signature."""


def _get_context(task):
    """Return the context object of ``task``.

    Legacy ``SandboxTask`` instances keep it in ``ctx``; any other task is
    expected to expose it as ``Context``.
    """
    return task.ctx if isinstance(task, SandboxTask) else task.Context


def _set_context_field(context, key, value):
    if isinstance(context, dict):
        context[key] = value
    else:
        setattr(context, key, value)


def _get_context_field(context, key, default=None):
    if isinstance(context, dict):
        return context.get(key, default)
    return getattr(context, key, default)


def get_non_filled_fields_from_context(context, fields):
    """Return, in order, the names from ``fields`` whose context value is unset.

    A field is unset when it is absent (read back as ``None``) or holds
    ``None`` or the Sandbox ``NotExists`` marker.
    """
    unset_markers = (None, NotExists)
    return [name for name in fields if _get_context_field(context, name, None) in unset_markers]


def get_results_from_context(context, provides, result_is_dict):
    """Read previously stored stage results for ``provides`` from ``context``.

    The returned shape is one that ``stage`` accepts from a stage method:

    * ``result_is_dict`` -> ``{field: value}`` for every name in ``provides``;
    * exactly one field -> that field's bare value;
    * any other count (including zero) -> a tuple of values in ``provides``
      order, i.e. ``()`` when nothing is declared.

    :param context: task context (``dict`` or attribute-based object).
    :param provides: sequence of context field names.
    :param result_is_dict: return a dict instead of a tuple / bare value.
    :raises RuntimeError: if any name in ``provides`` is unset
        (missing, ``None`` or ``NotExists``).
    """
    wishlist = get_non_filled_fields_from_context(context, provides)
    if wishlist:
        raise RuntimeError('Task does not have filled context fields: {}'.format(', '.join(wishlist)))
    if result_is_dict:
        return {item: _get_context_field(context, item) for item in provides}
    if len(provides) == 1:
        return _get_context_field(context, provides[0])
    # Zero or several fields -> tuple. An empty ``provides`` used to fall
    # through to ``provides[0]`` and raise IndexError.
    return tuple(_get_context_field(context, item) for item in provides)


def _get_requirements(method, provides, requires):
    """Derive and validate the context fields a stage method consumes.

    Requirements are the method's positional parameters after ``self`` that
    have no default value. Parameters with defaults are not requirements.
    On Python 3, keyword-only parameters are not requirements either.

    :param method: the undecorated stage method.
    :param provides: declared output field names; must be strings and unique.
    :param requires: declared input field names; when non-empty, they must
        match the derived requirements as a set.
    :returns: list of requirement names in signature order.
    :raises InconsistentRequirements: ``requires`` disagrees with the signature.
    :raises ValueError: an item of ``provides`` is not a string.
    :raises DuplicatedResultFields: ``provides`` contains duplicate names.
    """
    # ``inspect.getargspec`` is deprecated on Python 3 and rejects annotated or
    # keyword-only signatures. It was removed in Python 3.11. Use it only
    # where ``getfullargspec`` does not exist (Python 2).
    get_arg_spec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    args_spec = get_arg_spec(method)
    requirements = args_spec.args[1:len(args_spec.args) - len(args_spec.defaults or ())]  # cut off 'self' and possible kwargs

    if requires and set(requirements) != set(requires):
        missing_declaration = list(set(requirements) - set(requires))
        extra_declaration = list(set(requires) - set(requirements))
        logger.error(
            'Stage \'%s\', declared requirements in \'requires\' decorator parameter: %s, actual requirements derived from parameters: %s.',
            method.__name__,
            requires,
            requirements,
        )
        msg = ''
        if missing_declaration:
            msg += 'Missing declaration of items: {}. '.format(missing_declaration)
        if extra_declaration:
            msg += 'Extra declared items: {}. '.format(extra_declaration)
        raise InconsistentRequirements('Declared requirements are inconsistent with actual requirements for stage \'{}\': {}'.format(method.__name__, msg))

    non_string_items = [item for item in provides if not isinstance(item, six.string_types)]
    if non_string_items:
        # Arguments were previously swapped: the stage name was reported as the
        # offending items and the items as the stage name.
        raise ValueError("Following declared 'provides' items {} for stage '{}' have invalid types, expected: {}".format(non_string_items, method.__name__, six.string_types))

    duplicated = [k for k, count in Counter(provides).items() if count > 1]
    if duplicated:
        raise DuplicatedResultFields("Fields {} are duplicated in 'provides' declaration of stage '{}'".format(duplicated, method.__name__))

    return requirements


def _check_requirements(method, context, requirements):
    """Raise ``StageRequirementsUnsatisfied`` if any required field is unset."""
    missing = get_non_filled_fields_from_context(context, requirements)
    if not missing:
        return
    raise StageRequirementsUnsatisfied(
        "Requirements {} for stage '{}' aren't satisfied".format(missing, method.__name__)
    )


def _save_context(task, context):
    """Persist ``context`` for ``task``.

    For a legacy ``SandboxTask`` the ``ctx`` dict is updated from ``context``.
    For any other task, ``context.save()`` is called.
    """
    # NOTE(review): when ``context`` is ``task.ctx`` itself, the update below
    # rewrites each key with its own value; confirm legacy tasks persist
    # ``ctx`` by some other mechanism.
    if not isinstance(task, SandboxTask):
        context.save()
        return
    task.ctx.update(**context)


def memoize_stage(requires=(), max_entries=1):
    """Decorate a task method so that it completes at most ``max_entries`` times.

    The count of successful runs is kept in the task context under
    ``'__<method name>_count'``. Once it reaches ``max_entries``, the wrapper
    logs and returns ``None`` without calling the method.

    Required fields (the method's non-default positional parameters after
    ``self``) are read from the context and passed positionally. Extra keyword
    arguments are forwarded unchanged.

    If the method raises anything other than Sandbox's ``Wait``, its name is
    stored under ``FAILED_STAGE`` and the exception propagates. The counter is
    incremented, and the context saved, only after a successful call.

    :param requires: optional explicit declaration of required fields; a single
        string is accepted, as in ``stage``.
    :param max_entries: maximum number of successful executions.
    """
    if isinstance(requires, six.string_types):
        # A bare string would otherwise be compared as a set of characters by
        # ``_get_requirements`` and always look inconsistent.
        requires = (requires, )

    def inner(method):
        requirements = _get_requirements(method, (), requires)
        counter_key = '__{}_count'.format(method.__name__)

        @wraps(method)
        def method_wrapper(task, **kwargs):
            context = _get_context(task)
            # ``or 0`` maps falsy stored values (e.g. None) to zero.
            entries = int(_get_context_field(context, counter_key, 0) or 0)
            if entries >= max_entries:
                logger.info('Method %s was visited %s times already, exiting', method.__name__, entries)
                return

            args = [_get_context_field(context, arg_name) for arg_name in requirements]
            try:
                result = method(task, *args, **kwargs)
            except Wait:
                # Not treated as a failure; re-raised untouched.
                raise
            except Exception:
                _set_context_field(context, FAILED_STAGE, method.__name__)
                raise

            _set_context_field(context, counter_key, entries + 1)
            _save_context(task, context)

            return result

        return method_wrapper

    return inner


def stage(provides=(), requires=(), result_is_dict=False, force_run=False):
    """Decorate a task method as a memoized stage whose results live in the context.

    On each call the wrapper:

    1. raises ``StageRequirementsUnsatisfied`` if any required field (the
       method's non-default positional parameters after ``self``) is unset;
    2. unless ``force_run``, returns the stored results without running the
       method when every ``provides`` field is already filled;
    3. if the context has ``copy_of`` and the method name is listed in
       ``__reusable_stages``, tries to read the results from the prototype task
       ``sdk2.Task[copy_of]``. On any error it logs a warning and runs the method;
    4. otherwise calls the method with the required fields as positional
       arguments plus any keyword arguments. On a non-``Wait`` exception it
       stores the method name under ``FAILED_STAGE`` and re-raises;
    5. maps the result onto ``provides`` and writes the still-missing fields
       (all of them when ``force_run``) into the context, then saves it.

    :param provides: field name(s) produced by the stage; a string is accepted.
    :param requires: optional explicit declaration of required fields; a string
        is accepted; must match the derived requirements when given.
    :param result_is_dict: the method returns a dict keyed by field name, and
        cached results are returned as such a dict.
    :param force_run: run the method even if all ``provides`` are filled.
    :returns: the wrapper returns the method's original result, or the cached
        or reused results.
    :raises StageMissingResults: the result count or keys do not match
        ``provides``.
    """
    if isinstance(provides, six.string_types):
        provides = (provides, )

    if isinstance(requires, six.string_types):
        requires = (requires, )

    def inner(method):
        requirements = _get_requirements(method, provides, requires)

        @wraps(method)
        def method_wrapper(task, **kwargs):
            context = _get_context(task)

            _check_requirements(method, context, requirements)

            if force_run:
                wishlist = provides
            else:
                wishlist = get_non_filled_fields_from_context(context, provides)
                if not wishlist:
                    logger.info("All declared items %s from 'provides' parameter for stage '%s' were provided earlier, nothing to do", list(provides), method.__name__)
                    return get_results_from_context(context, provides, result_is_dict)

            args = [_get_context_field(context, arg_name) for arg_name in requirements]

            # Consult the reusable-stage list only for copied tasks. Treat an
            # unset list (None / NotExists) as empty so the ``in`` test cannot
            # raise TypeError.
            copy_of = _get_context_field(context, 'copy_of', None)
            reuse_result = False
            if copy_of:
                reusable_stages = _get_context_field(context, '__reusable_stages', [])
                reuse_result = reusable_stages not in (None, NotExists) and method.__name__ in reusable_stages

            if reuse_result:
                try:
                    prototype_task = sdk2.Task[copy_of]
                    prototype_context = _get_context(prototype_task)
                    logger.debug('Trying to copy \'%s\' stage results', method.__name__)
                    original_result = get_results_from_context(prototype_context, provides, result_is_dict)
                except Exception as e:
                    # Best effort: fall back to running the stage ourselves.
                    logger.warning('Cannot copy \'%s\' stage results from prototype task: %s', method.__name__, e)
                    reuse_result = False

            if not reuse_result:
                logger.debug('Running \'%s\' stage', method.__name__)
                try:
                    original_result = method(task, *args, **kwargs)
                except Wait:
                    # Not treated as a failure; re-raised untouched.
                    raise
                except Exception:
                    _set_context_field(context, FAILED_STAGE, method.__name__)
                    raise

            result = original_result
            if not (result_is_dict and isinstance(result, dict)):
                # Positional results: a non-tuple is a single value.
                if not isinstance(result, tuple):
                    result = (result, )

                if len(result) != len(provides):
                    logger.error("Stage '%s' results: %s, declared 'provides' parameter was %s", method.__name__, result, provides)
                    raise StageMissingResults("Stage '{}' provided {} items, expected {}".format(method.__name__, len(result), len(provides)))

                result = dict(zip(provides, result))

            missing_provided_items = set(wishlist) - set(result)
            if missing_provided_items:
                logger.error("Provided result: %s, expected items: %s", result, wishlist)
                raise StageMissingResults("Items {} were not provided by stage '{}'".format(list(missing_provided_items), method.__name__))

            # Only fields that were missing (or all, with force_run) are written.
            for provided_item in wishlist:
                _set_context_field(context, provided_item, result[provided_item])

            _save_context(task, context)

            return original_result

        return method_wrapper

    return inner
