# coding: utf-8
import io
import io

import logging
import operator
import re
import traceback
import unicodecsv
from collections import defaultdict

import openpyxl
from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE

import yenv
from django.core.exceptions import ValidationError
from ylog import context as log_context

from review.core import (
    const,
    models,
)
from review.lib import errors
from review.lib.file_properties import XlsFileProperties


log = logging.getLogger(__name__)


class CSVSerializer(object):
    """Write lists of dicts as CSV bytes and read CSV uploads back into dicts."""

    @staticmethod
    def serialize(data, **csv_params):
        """Render *data* (a list of dicts) as CSV bytes.

        The keys of the first row define the header and column order; any
        extra keyword arguments are forwarded to ``unicodecsv.DictWriter``.
        Empty (falsy) *data* yields ``b''``.
        """
        output = io.BytesIO()
        if data:
            columns = list(data[0].keys())
            writer = unicodecsv.DictWriter(output, fieldnames=columns, **csv_params)
            writer.writeheader()
            writer.writerows(data)
        return output.getvalue()

    @classmethod
    def parse(cls, file, validation_form=None):
        """Read a CSV file into a list of row dicts keyed by its header.

        The dialect is sniffed from the whole file first. When
        *validation_form* is given the rows are validated with it (see
        ``_validate_file_rows``) and their cleaned data is returned instead.
        """
        dialect = cls._get_dialect(file)
        rows = list(unicodecsv.DictReader(file, dialect=dialect))
        return _validate_file_rows(rows, validation_form)

    @staticmethod
    def _get_dialect(file):
        """Sniff the CSV dialect of *file* and rewind it to the start.

        Reads the entire content; bytes are decoded as UTF-8 before sniffing.
        """
        content = file.read()
        file.seek(0)
        text = content.decode('utf-8') if isinstance(content, bytes) else content
        return unicodecsv.Sniffer().sniff(text)


class XLSSerializer(object):
    """Write lists of dicts to an Excel workbook (openpyxl) and read them back (xlrd)."""

    @staticmethod
    def get_buffer(data, file_properties=None):
        """Write *data* into a workbook and return it as a rewound ``BytesIO``.

        :param data: list of dicts; the first row's keys become the header
            written to row 1, and each dict fills one following row.
        :param file_properties: ``XlsFileProperties``; when its ``template``
            (a file-like workbook) is set, data is written into the sheet
            named ``sheet_name`` of a copy of it, otherwise a new workbook
            with a single sheet of that name is created.
        :returns: ``io.BytesIO`` holding the saved workbook, positioned at 0.
        """
        fp = file_properties or XlsFileProperties()
        if fp.template:
            fp.template.seek(0)
            wb = openpyxl.load_workbook(io.BytesIO(fp.template.read()))
            # ``wb[name]`` is the supported spelling of the deprecated
            # ``get_sheet_by_name``.
            sheet = wb[fp.sheet_name]
        else:
            wb = openpyxl.Workbook()
            sheet = wb.active
            sheet.title = fp.sheet_name

        if data:
            header = list(data[0].keys())
            for column_index, title in enumerate(header, 1):
                sheet.cell(row=1, column=column_index).value = title
            for row_index, row in enumerate(data, 2):
                for column_index, field_name in enumerate(header, 1):
                    elem = row.get(field_name)
                    if isinstance(elem, str):
                        elem = XLSSerializer._replace_illegal_characters(elem)
                    sheet.cell(row=row_index, column=column_index).value = elem

        buff = io.BytesIO()
        wb.save(buff)
        # Rewind so callers can read() the buffer directly; getvalue() is
        # unaffected by the position.
        buff.seek(0)
        return buff

    @staticmethod
    def _replace_illegal_characters(text):
        """Replace characters openpyxl cannot store in a cell with a marker.

        The marker is returned from a callable so ``re.sub`` inserts it
        verbatim. Passed as a template string, its ``\\w`` would be parsed
        as a regex escape, which ``re.sub`` rejects ("bad escape") on
        Python 3.7+ even when nothing matches.
        """
        return ILLEGAL_CHARACTERS_RE.sub(lambda _match: ' \\wrong_symbol\\ ', text)

    @staticmethod
    def serialize(data, file_properties=None):
        """Return *data* rendered as workbook bytes (see ``get_buffer``)."""
        buff = XLSSerializer.get_buffer(data, file_properties)
        data = buff.getvalue()
        return data

    @staticmethod
    def parse(file, validation_form=None):
        """Read the first sheet of a workbook into a list of row dicts.

        Row 0 supplies the keys; every following row becomes a dict mapping
        those keys to the cell values reported by xlrd. When
        *validation_form* is given the rows are validated with it (see
        ``_validate_file_rows``).
        """
        import xlrd
        rb = xlrd.open_workbook(file_contents=file.read())
        sheet = rb.sheet_by_index(0)
        # read header values into the list
        keys = [sheet.cell(0, col_index).value for col_index in range(sheet.ncols)]

        result = []
        for row_index in range(1, sheet.nrows):
            row = {
                keys[col_index]: sheet.cell(row_index, col_index).value
                for col_index in range(sheet.ncols)
            }
            result.append(row)

        return _validate_file_rows(result, validation_form)


def _validate_file_rows(rows, form):
    if form is None or not rows:
        return rows
    person_review_id = rows[0].get('id')
    if person_review_id is None:
        review = None
    else:
        review = models.Review.objects.select_related(
            'scale'
        ).get(person_review__id=person_review_id)
    validated = [form(data=r, review=review) for r in rows]
    err_row_to_fields = [
        (i, r.errors) for i, r in enumerate(validated)
        if not r.is_valid()
    ]
    if not err_row_to_fields:
        return [r.cleaned_data for r in validated]

    field_to_rows = defaultdict(dict)
    for row_num, errs in err_row_to_fields:
        for field, err in errs.items():
            field_to_rows[field][row_num] = err[0]
    raise errors.FileIncorrect('incorrect file format', params=field_to_rows)


def serialize_file_sheet(file_, form=None):
    """Parse an uploaded tabular file, trying CSV first and then XLS.

    :param file_: seekable file-like object with the uploaded content.
    :param form: optional row form forwarded to the parsers for validation.
    :returns: list of row dicts (cleaned data when *form* is given).
    :raises ValidationError: propagated unchanged from either parser.
    :raises errors.CantParseFile: when neither parser can read the file; both
        tracebacks are logged first.
    """
    try:
        return CSVSerializer.parse(file_, form)
    except ValidationError:
        raise
    except Exception as err:
        trace = traceback.format_exc()
        # The CSV attempt may have moved the file cursor.
        file_.seek(0)
        try:
            return XLSSerializer.parse(file_, form)
        except ValidationError:
            raise
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt and
            # SystemExit are not turned into CantParseFile.
            log.error(err)
            log.error(trace)
            log.error(traceback.format_exc())
            raise errors.CantParseFile()


class SerializerField(object):
    """Declaration of one output field of a ``Serializer``.

    :param target: output key; also the field's identity.
    :param source: where the value comes from (see
        ``Serializer._get_field_source``); ``None`` means attribute ``target``.
    :param complex: nested ``Serializer`` class used to serialize the value.
    :param many: the value is an iterable serialized item by item.
    :param verbose: mapping used to translate values (via ``.get``).
    :param cast_type: callable applied to values when ``verbose`` is unset.
    :param complex_fields: optional whitelist of nested field names.

    Equality and hashing use ``target`` only, and a field also equals a plain
    string equal to its target, so declared fields can be overridden or
    subtracted by name in sets.
    """

    def __init__(self, target, source=None, complex=None, many=False,
                 verbose=None, cast_type=None, complex_fields=None):
        # Assignment order fixes the attribute order shown by __repr__.
        self.target = target
        self.source = source
        self.complex = complex
        self.many = many
        self.verbose = verbose
        self.cast_type = cast_type
        self.complex_fields = complex_fields

    def __eq__(self, other):
        other_target = other.target if isinstance(other, SerializerField) else other
        return self.target == other_target

    def __hash__(self):
        return hash(self.target)

    def __repr__(self):
        attributes = str(self.__dict__)
        return '<%s: %s>' % (self.__class__.__name__, attributes)


# Short alias used in serializer ``fields`` declarations.
F = SerializerField


class ProxySerializableObject(object):
    """Expose the entries of a dict as read-only attributes.

    Missing names raise ``AttributeError``, as the attribute protocol
    requires, so ``hasattr``/``getattr(obj, name, default)`` work on it.
    """

    def __init__(self, object_dict):
        self.object_dict = object_dict

    def __getattr__(self, item):
        # Read the backing dict from the instance __dict__ directly: plain
        # ``self.object_dict`` would re-enter __getattr__ (infinite recursion)
        # on an instance whose __init__ never ran, e.g. one built by copy/pickle.
        try:
            return self.__dict__['object_dict'][item]
        except KeyError:
            raise AttributeError(
                '%r object has no attribute %r' % (type(self).__name__, item)
            ) from None


class ProxySerializableObjectGetter(object):
    """Callable that snapshots several (possibly dotted) attributes of an object.

    ``field_map`` maps output names to attribute paths accepted by
    ``operator.attrgetter``. Calling the getter on an object returns a
    ``ProxySerializableObject`` exposing each fetched value under its output
    name. An empty map produces an empty proxy.
    """

    def __init__(self, field_map):
        pairs = list(field_map.items())
        self.field_names = [name for name, _ in pairs]
        paths = [path for _, path in pairs]
        if paths:
            self.getter = operator.attrgetter(*paths)
        else:
            # operator.attrgetter() requires at least one path.
            self.getter = lambda obj: ()

    def __call__(self, obj, *args, **kwargs):
        field_values = self.getter(obj)
        if len(self.field_names) == 1:
            # attrgetter with a single path returns the bare value, not a tuple.
            field_values = (field_values,)
        return ProxySerializableObject(dict(zip(self.field_names, field_values)))


class SerializableObjectCallableGetter(object):
    """Getter that calls the zero-argument method ``field_source`` of an object.

    Raises ``RuntimeWarning`` when the object has no such attribute or the
    attribute is not callable.
    """

    def __init__(self, field_source):
        self.field_source = field_source

    def __call__(self, obj, *args, **kwargs):
        method_name = self.field_source
        method = getattr(obj, method_name) if hasattr(obj, method_name) else None
        if not callable(method):
            raise RuntimeWarning("No such callable in object: {}".format(method_name))
        return method()


def get_complex_field_prefix(field):
    """Return the prefix prepended to nested field names of a complex field.

    ``target + '.'`` when there is no source, ``source + '.'`` for a plain
    source, and the source without its trailing ``'*'`` for star sources.
    In that last case nested names are appended with no separator.
    """
    source = field.source
    if source is None:
        return field.target + '.'
    if source.endswith('*'):
        return source[:-1]
    return source + '.'


def flat_field(field):
    """Return the set of output names a field declaration can produce.

    A plain string yields itself. A ``SerializerField`` yields its target plus,
    for complex fields, every nested field name (restricted to
    ``complex_fields`` when set) with the prefix from
    ``get_complex_field_prefix``. Any other value yields ``None``.
    """
    if isinstance(field, str):
        return {field}
    if not isinstance(field, SerializerField):
        return None

    names = {field.target}
    if field.complex is None:
        return names

    prefix = get_complex_field_prefix(field)
    allowed = field.complex_fields
    names.update(
        prefix + nested_name
        for nested_name in field.complex.get_all_fields()
        if not allowed or nested_name in allowed
    )
    return names


class Serializer(object):
    """Declarative serializer that turns objects into ``root_cls`` mappings.

    Subclasses declare ``fields`` as a set of plain attribute names and/or
    ``SerializerField`` instances, then call ``serialize``/``serialize_many``
    with the subset of names they want. A field backed by a nested serializer
    (``SerializerField.complex``) is requested either by its own target (all
    of its nested fields) or by prefixed nested names (see
    ``get_complex_field_prefix``).

    Resolved field lists and field getters are memoized in two class-level
    dicts shared by every subclass. Cache keys include the concrete class, so
    subclasses never read each other's entries.
    """
    fields = set()
    default_fields = None  # names emitted when none are requested; None means all
    root_cls = dict  # mapping type built for each serialized object
    log_instance_name = ''  # label for failure logs; falls back to obj.__class__
    log_instance_id_attr = 'pk'  # attribute of obj logged on failure
    _serialize_fields_cache = {}
    _serialize_fields_source_cache = {}

    @classmethod
    def combine_fields(cls, fields):
        """Return ``cls.fields`` plus *fields*, the latter replacing same-target entries.

        ``SerializerField`` compares equal by ``target`` (also to plain
        strings), so an entry of *fields* overrides the declared field of the
        same name.
        """
        fields = set(fields)
        return (cls.fields - fields) | fields

    @classmethod
    def get_all_fields(cls):
        """Return every field name this serializer can emit.

        Includes the prefixed names of nested serializer fields (see
        ``flat_field``).
        """
        flatten_fields = set()
        for field in cls.fields:
            flatten_fields |= flat_field(field)
        return flatten_fields

    @classmethod
    def _get_fields_to_serialize(cls, requested_fields):
        """Select the declared fields matching *requested_fields*.

        :returns: list of ``(SerializerField, nested_fields)`` pairs, where
            ``nested_fields`` is the set of nested field names to emit for a
            complex field and ``None`` for a plain one. A complex field
            requested by its own target gets all its (allowed) nested fields;
            otherwise only the nested names requested with its prefix.
        """
        requested_fields = set(requested_fields)

        result = []
        for field in cls.fields:
            if not isinstance(field, SerializerField):
                field = F(field)

            nested_fields = None
            if field.complex:
                nested_fields = field.complex.get_all_fields()
                if field.complex_fields:
                    nested_fields = {f for f in nested_fields if f in field.complex_fields}

                if field.target not in requested_fields:
                    prefix = get_complex_field_prefix(field)
                    # Build a new set rather than mutating the one returned
                    # by get_all_fields().
                    nested_fields = {
                        nested_field for nested_field in nested_fields
                        if prefix + nested_field in requested_fields
                    }

            if field.target in requested_fields or nested_fields:
                result.append((field, nested_fields if field.complex else None))

        return result

    @classmethod
    def get_fields_to_serialize(cls, requested_fields):
        """Memoized ``_get_fields_to_serialize``.

        *requested_fields* is materialized once into a frozenset. Equal field
        collections therefore share one cache entry whatever their iteration
        order, and a one-shot iterable is not exhausted before it is used.
        The key holds the class itself rather than ``id(cls)``, which may be
        reused after a class is garbage-collected.
        """
        requested_fields = frozenset(requested_fields)
        requested_fields_key = (requested_fields, cls)
        if requested_fields_key not in cls._serialize_fields_cache:
            cls._serialize_fields_cache[requested_fields_key] = cls._get_fields_to_serialize(requested_fields)
        return cls._serialize_fields_cache[requested_fields_key]

    @classmethod
    def _get_field_source(cls, field, requested_fields, context):
        """Build the getter that extracts *field*'s raw value from an object.

        ``field.source`` syntax:

        * ``None`` -- attribute named ``field.target``;
        * a callable -- used as the getter itself;
        * ``"name()"`` -- ``cls.name`` (called with the object) when the
          serializer defines it, else the object's own ``name()`` method;
        * ``"[key]"`` -- ``obj[key]``;
        * ``"prefix*"`` -- a ``ProxySerializableObject`` holding attributes
          ``prefix + name`` for every name in *requested_fields*;
        * ``"prefix?ctx_key?"`` -- attribute ``prefix + context[ctx_key]``;
        * anything else -- a (dotted) attribute path.

        :raises Exception: on a malformed ``?`` source or a context key that
            is missing from *context*.
        """
        if field.source is None:
            return operator.attrgetter(field.target)
        elif callable(field.source):
            return field.source
        elif field.source.endswith('()'):
            field_source = field.source[:-len('()')]
            if hasattr(cls, field_source):
                return getattr(cls, field_source)
            else:
                return SerializableObjectCallableGetter(field_source)
        elif field.source.startswith('[') and field.source.endswith(']'):
            return operator.itemgetter(field.source[len('['):-len(']')])
        elif field.source.endswith('*'):
            field_source = field.source[:-len('*')]
            field_map = {
                field_name: field_source + field_name
                for field_name in requested_fields
            }
            return ProxySerializableObjectGetter(field_map)
        elif field.source.endswith('?'):
            CONTEXT_KEY_REGEX = r'(.*)\?([\w_]+)\?'
            match = re.match(CONTEXT_KEY_REGEX, field.source)
            if not match:
                raise Exception('Bad source format')
            field_prefix, context_key = match.groups()
            if context_key not in context:
                # Format eagerly: Exception, unlike logging, does not
                # interpolate %-style arguments into its message.
                raise Exception('Key %s not found in context' % context_key)
            return operator.attrgetter(field_prefix + context[context_key])
        else:
            return operator.attrgetter(field.source)

    @classmethod
    def get_field_source(cls, field, requested_fields, context=None):
        """Memoized ``_get_field_source``.

        *requested_fields* (``None`` or any iterable, including an empty one)
        and the *context* items are keyed as frozensets. The keys are then
        hashable and independent of iteration order. The key holds the
        class itself rather than ``cls.__name__``, so same-named serializer
        classes from different modules do not share ``name()`` getters.
        """
        context = context or {}
        if requested_fields is not None:
            requested_fields = frozenset(requested_fields)
        context_key = frozenset(context.items())
        field_source_key = (field, requested_fields, context_key, cls)
        if field_source_key not in cls._serialize_fields_source_cache:
            cls._serialize_fields_source_cache[field_source_key] = cls._get_field_source(field, requested_fields, context)
        return cls._serialize_fields_source_cache[field_source_key]

    @classmethod
    def field_has_special_value(cls, field_value):
        """Hook: return True for values emitted via ``handle_special_value``."""
        return False

    @classmethod
    def handle_special_value(cls, field_value):
        """Hook: serialized form of a special value (see ``field_has_special_value``)."""
        return NotImplemented

    @classmethod
    def _serialize_field(cls, field, field_value):
        """Convert a plain field's raw value; ``many`` fields convert each item."""
        if cls.field_has_special_value(field_value):
            return cls.handle_special_value(field_value)

        if field.many:
            return [cls._serialize_one_field_item(field, value) for value in field_value]
        else:
            return cls._serialize_one_field_item(field, field_value)

    @classmethod
    def _serialize_one_field_item(cls, field, field_value):
        """Translate via ``field.verbose`` (``None`` if unmapped), else apply ``cast_type``."""
        if field.verbose is not None:
            return field.verbose.get(field_value)
        elif field.cast_type is not None:
            return field.cast_type(field_value)
        else:
            return field_value

    @classmethod
    def serialize(cls, obj, fields_requested=None, context=None):
        """Serialize one object into a ``root_cls`` mapping keyed by field target.

        :param obj: object to read field values from.
        :param fields_requested: iterable of field names to emit; defaults to
            ``default_fields`` or, when that is ``None``, every field.
        :param context: optional dict consulted by ``?`` sources and passed on
            to nested serializers.
        :returns: the mapping. Complex fields whose nested result is falsy
            are omitted. A plain field whose getter or conversion raises is
            logged and omitted, except when ``yenv.type == 'development'``,
            where the error is re-raised.
        """
        result = cls.root_cls()

        if fields_requested is None:
            fields_requested = cls.default_fields if cls.default_fields is not None else cls.get_all_fields()

        fields = cls.get_fields_to_serialize(fields_requested)
        for field, available_fields in fields:
            if field.complex:
                obj_getter = cls.get_field_source(field, available_fields, context)
                serializer_obj = obj_getter(obj)
                func = field.complex.serialize_many if field.many else field.complex.serialize
                serialized = func(serializer_obj, available_fields, context)
                if not serialized:
                    continue
            else:
                try:
                    field_value = cls.get_field_source(field, available_fields, context)(obj)
                    serialized = cls._serialize_field(field, field_value)
                except Exception:
                    instance_str = '%s:%s' % (
                        cls.log_instance_name or obj.__class__,
                        getattr(obj, cls.log_instance_id_attr, '?'),
                    )
                    with log_context.LogContext(instance=instance_str):
                        log.exception(
                            'serialization failed for field `%s` in %s',
                            field.target,
                            cls.__name__,
                        )
                    if yenv.type == 'development':
                        raise
                    continue
            result[field.target] = serialized

        return result

    @classmethod
    def serialize_many(cls, objects, fields_requested=None, context=None):
        """Serialize each of *objects*; a special value is returned via ``handle_special_value``."""
        if fields_requested is None:
            fields_requested = cls.default_fields if cls.default_fields is not None else cls.get_all_fields()
        if cls.field_has_special_value(objects):
            return cls.handle_special_value(objects)
        return [cls.serialize(obj, fields_requested, context) for obj in objects]


class BaseSerializer(Serializer):
    """Serializer that passes the project's special marker values through as-is.

    A field value (or a whole ``serialize_many`` input) equal to
    ``const.NO_ACCESS``, ``const.NOT_SET``, ``const.DISABLED`` or
    ``const.NOT_AVAILABLE`` is emitted unchanged instead of being serialized.
    """

    @classmethod
    def field_has_special_value(cls, field_value):
        markers = (
            const.NO_ACCESS,
            const.NOT_SET,
            const.DISABLED,
            const.NOT_AVAILABLE,
        )
        return field_value in markers

    @classmethod
    def handle_special_value(cls, field_value):
        # Markers serialize as themselves.
        return field_value
