#!/usr/bin/env python
# coding: utf-8

from six import text_type, binary_type
from future.moves.itertools import zip_longest

from functools import partial, wraps
from operator import methodcaller
import logging

from pymdb.types import User, ContactsUser

from mail.pypg.pypg.common import get_connection
from ora2pg.pg_get import get_user

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Name of the zero-argument method that cmpable_item() looks up on an item
# and calls to obtain that item's sort key.
COMPARATOR = 'cmpable'


def cmpable_item(item):
    """Return the sort key for *item*.

    If *item* has an attribute named by ``COMPARATOR`` it is called with no
    arguments and its result is returned.  Items without that attribute all
    share the same constant key, 42.
    """
    try:
        make_key = getattr(item, COMPARATOR)
    except AttributeError:
        # No comparator on this item: every such item gets the same key.
        return 42
    return make_key()


def is_seq(s):
    """Tell whether *s* is a sequence or generator that is not a string.

    Returns ``True`` for ``collections.abc.Sequence`` instances and for
    generator objects, except binary and text strings, which are rejected
    even though they are sequences.
    """
    from collections.abc import Sequence
    from inspect import isgenerator

    if not isinstance(s, Sequence) and not isgenerator(s):
        return False
    # Strings are sequences too, but must be compared as scalars.
    return not isinstance(s, (binary_type, text_type))


def get_volatile_keys(obj):
    """Return the contents of ``obj._volatile`` as a set.

    An empty set is returned when an ``AttributeError`` is raised while
    reading the attribute or building the set from it.
    """
    try:
        volatile = set(obj._volatile)
    except AttributeError:
        volatile = set()
    return volatile


def add_default_volatiles(getter):
    """Wrap *getter* so that some object types get fixed volatile keys.

    The returned function behaves as follows:

    * ``User``: the keys ``here_since``, ``is_here`` and ``purge_date``
      united with whatever ``getter(obj)`` yields;
    * ``ContactsUser``: exactly ``here_since`` and ``is_here`` (the wrapped
      getter is not consulted);
    * anything else: the result of ``getter(obj)`` unchanged.
    """
    @wraps(getter)
    def impl(obj):
        if isinstance(obj, User):
            user_defaults = {'here_since', 'is_here', 'purge_date'}
            return user_defaults | set(getter(obj))
        if isinstance(obj, ContactsUser):
            return {'here_since', 'is_here'}
        return getter(obj)

    return impl


class AreEqual(object):
    """Callable that recursively compares two values for equality.

    Calling an instance as ``instance(l, r, name)`` works as follows:

    * if both values expose the ``serializer`` method (``as_dict``), they are
      serialized and compared key by key, skipping keys reported as volatile
      by ``volatile_getter``;
    * if neither does and both are sequences/generators (see ``is_seq``),
      both are passed through ``sorter`` and compared element by element;
    * if neither does otherwise, they are compared with ``!=``;
    * if exactly one of them is serializable, ``AssertionError`` is raised.

    Mismatches are logged with ``log.error`` and reported by returning
    ``False``; ``True`` means no difference was found.
    """

    # Name of the method that turns an object into a mapping of its fields.
    serializer = 'as_dict'
    serialize_object = methodcaller(serializer)
    # Unique fill value for zip_longest, so a missing element can never be
    # confused with a real one (the default fill, None, is a legal element).
    _MISSING = object()

    def __init__(self, sorter=partial(sorted, key=cmpable_item), volatile_getter=get_volatile_keys):
        """Store the comparison helpers.

        :param sorter: callable applied to each sequence before pairwise
            comparison; by default ``sorted`` keyed by ``cmpable_item``.
        :param volatile_getter: callable returning the keys of an object
            that must be ignored; it is wrapped with
            ``add_default_volatiles`` before use.
        """
        self.sorter = sorter
        self.volatile_getter = add_default_volatiles(volatile_getter)

    def is_serializable(self, o):
        """Return True if *o* has the method named by ``serializer``."""
        return hasattr(o, self.serializer)

    def __call__(self, l, r, name):
        """Compare *l* and *r*; *name* labels them in log messages.

        :returns: ``True`` if the values are considered equal, else ``False``.
        :raises AssertionError: if exactly one value is serializable, or if
            every key of two serializable values is volatile.
        """
        log.debug('Comparing %s', name)
        l_serializable = self.is_serializable(l)
        r_serializable = self.is_serializable(r)
        if l_serializable != r_serializable:
            raise AssertionError(
                '%r\nand\n%r\nare uncomparable' % (l, r)
            )
        if l_serializable:
            return self._serialized_are_equal(l, r, name)
        if is_seq(l) and is_seq(r):
            return self._sequences_are_equal(l, r, name)
        if l != r:
            log.error('%s:\n%r\nis not equal to\n%r', name, l, r)
            return False
        return True

    def _sequences_are_equal(self, l, r, name):
        """Compare two sequences pairwise after passing both through sorter."""
        lseq, rseq = self.sorter(l), self.sorter(r)
        for a, b in zip_longest(lseq, rseq, fillvalue=self._MISSING):
            if a is self._MISSING or b is self._MISSING:
                # One side ran out first: the lengths differ, so the
                # sequences cannot be equal whatever the extra items are.
                log.error('%s:\n%r\nis not equal to\n%r', name, lseq, rseq)
                return False
            if not self(a, b, name + '[_]'):
                return False
        return True

    def _serialized_are_equal(self, l, r, name):
        """Compare two serializable objects key by key, skipping volatiles."""
        l_data = self.serialize_object(l)
        r_data = self.serialize_object(r)
        volatile_names = self.volatile_getter(l) | self.volatile_getter(r)
        all_names = set(l_data) | set(r_data)
        names = all_names - volatile_names
        # Raised explicitly (not via ``assert``) so the check survives -O.
        if not names:
            raise AssertionError(
                'Nothing to compare '
                'all keys are volatile? '
                'all_names %r, volatile_names: %r' % (all_names, volatile_names)
            )

        for k in names:
            for o_data in (l_data, r_data):
                if k not in o_data:
                    log.error('%s\nnot found in\n%r', k, o_data)
                    return False
            if not self(l_data[k], r_data[k], name + '.' + k):
                log.error('So %s:\n%r\nis not equal to\n%r', name, l, r)
                return False
        return True


def check_users_are_equal(first_uid, second_uid, maildb):
    """Fetch two users over one connection and compare them with AreEqual.

    A connection is obtained from ``get_connection(maildb)``; both users are
    loaded through it with ``get_user`` and the result of the ``AreEqual``
    comparison (labelled ``'User'``) is returned.
    """
    with get_connection(maildb) as conn:
        are_equal = AreEqual()
        first_user = get_user(first_uid, conn=conn)
        second_user = get_user(second_uid, conn=conn)
        return are_equal(first_user, second_user, name='User')
