"""
Utilities to convert Protobuf messages to and from Python dict.
"""
from __future__ import unicode_literals

import calendar
import datetime
import re

import six

import google.protobuf.json_format
import google.protobuf.message

# Shorthand for the descriptor class whose TYPE_* / LABEL_* constants are used below.
# NOTE(review): google.protobuf.descriptor is not imported explicitly; presumably it is
# loaded as a side effect of the google.protobuf imports above — confirm.
FD = google.protobuf.descriptor.FieldDescriptor

if six.PY2:
    long_int = long  # noqa
else:
    long_int = int

# Python converters for scalar field types, applied to document values in doc_to_pb and to
# protobuf values in pb_to_doc. Enum and message fields are handled separately there.
_types = {
    FD.TYPE_DOUBLE: float,
    FD.TYPE_FLOAT: float,
    FD.TYPE_INT64: long_int,
    FD.TYPE_UINT64: long_int,
    FD.TYPE_INT32: int,
    # fixed32/fixed64/sfixed32/sfixed64 are integer wire encodings: converting them to
    # float loses precision for large 64-bit values, and protobuf refuses float values
    # for integer fields.
    FD.TYPE_FIXED64: long_int,
    FD.TYPE_FIXED32: int,
    FD.TYPE_BOOL: bool,
    FD.TYPE_STRING: six.text_type,
    # NOTE(review): the 'string_escape' codec only exists on Python 2; bytes fields will
    # fail to convert on Python 3.
    FD.TYPE_BYTES: lambda x: x.decode('string_escape'),
    FD.TYPE_UINT32: int,
    FD.TYPE_SFIXED32: int,
    FD.TYPE_SFIXED64: long_int,
    FD.TYPE_SINT32: int,
    FD.TYPE_SINT64: long_int,
}

# Field type groups consulted by validate_pb_schema to decide which schema options apply.
# Integer-valued field types (signed, unsigned and fixed-width encodings).
_INT_TYPES = (
    FD.TYPE_INT32,
    FD.TYPE_INT64,
    FD.TYPE_UINT32,
    FD.TYPE_UINT64,
    FD.TYPE_SINT32,
    FD.TYPE_SINT64,
    FD.TYPE_FIXED32,
    FD.TYPE_FIXED64,
    FD.TYPE_SFIXED32,
    FD.TYPE_SFIXED64,
)
_FLOAT_TYPES = (
    FD.TYPE_FLOAT,
    FD.TYPE_DOUBLE
)
# 'minimum' / 'maximum' options apply to these.
_NUM_TYPES = _INT_TYPES + _FLOAT_TYPES
# 'minLength' / 'maxLength' / 'pattern' options apply to these.
_STRING_TYPES = (
    FD.TYPE_STRING,
)
_BOOL_TYPES = (
    FD.TYPE_BOOL,
)
# Sub-messages are validated recursively.
_STRUCT_TYPES = (
    FD.TYPE_MESSAGE,
)

LABEL_REPEATED = FD.LABEL_REPEATED


def pb_to_jsondict(message, including_default_value_fields=True, preserve_field_names=False,
                   pb_json_printer_cls=google.protobuf.json_format._Printer):
    """
    Converts a protobuf message into a JSON-compatible dict.

    :param message: protobuf message to convert
    :param including_default_value_fields: forwarded to the printer as including_default_value_fields
    :param preserve_field_names: forwarded to the printer as preserving_proto_field_name
    :param pb_json_printer_cls: printer class to instantiate (defaults to json_format._Printer)
    :return: the printer's _MessageToJsonObject(message) result

    NOTE(review): relies on the private json_format API (_Printer, _MessageToJsonObject);
    the printer's keyword arguments have changed across protobuf releases — confirm
    against the pinned protobuf version.
    """
    printer_options = dict(
        including_default_value_fields=including_default_value_fields,
        preserving_proto_field_name=preserve_field_names,
    )
    printer = pb_json_printer_cls(**printer_options)
    return printer._MessageToJsonObject(message)


def jsondict_to_pb(d, message):
    """
    Populates message from a JSON-style dict via json_format.ParseDict.

    Keys in d that do not correspond to fields of message are ignored rather than
    raising (ignore_unknown_fields=True).

    :param d: dict in protobuf JSON form
    :param message: protobuf message to populate
    :return: whatever ParseDict returns for message
    """
    json_format = google.protobuf.json_format
    return json_format.ParseDict(d, message, ignore_unknown_fields=True)


def enum_value_to_name(enum_type, value):
    """
    Returns the name of the enum value whose number equals value.

    Values are scanned in declaration order (enum_type.values), so when several names
    share a number the first declared one is returned. This avoids dict.iteritems(),
    which does not exist on Python 3.

    :param enum_type: enum descriptor
    :param value: enum number to look up
    :return: the matching enum value name
    :raises ValueError: if enum_type has no value with that number
    """
    for enum_value_desc in enum_type.values:
        if enum_value_desc.number == value:
            return enum_value_desc.name
    raise ValueError('{} enum does not contain value {}'.format(enum_type.full_name, value))


def enum_value_to_name_optional(enum_type, value):
    """
    Like enum_value_to_name, but passes None through unchanged.

    Values are scanned in declaration order (enum_type.values), so when several names
    share a number the first declared one is returned. This avoids dict.iteritems(),
    which does not exist on Python 3.

    :param enum_type: enum descriptor
    :param value: enum number to look up, or None
    :return: the matching enum value name, or None if value is None
    :raises ValueError: if value is not None and enum_type has no value with that number
    """
    if value is None:
        return None
    for enum_value_desc in enum_type.values:
        if enum_value_desc.number == value:
            return enum_value_desc.name
    raise ValueError('{} enum does not contain value {}'.format(enum_type.full_name, value))


def enum_name_to_value(enum_type, name):
    """
    Returns the number of the enum value called name.

    :param enum_type: enum descriptor
    :param name: enum value name
    :return: the value's number
    :raises ValueError: if enum_type has no value with that name
    """
    value_desc = enum_type.values_by_name.get(name)
    if value_desc is not None:
        return value_desc.number
    raise ValueError('{} enum does not contain name {}'.format(enum_type.full_name, name))


def default_get_doc_key(field_desc):
    """
    Default document-key policy: use the protobuf field name unchanged.

    :param field_desc: field descriptor
    :return: field_desc.name
    """
    doc_key = field_desc.name
    return doc_key


def _doc_value_to_message(message_pb, doc_value, get_doc_key):
    """
    Initializes a (sub)message from a document value.

    google.protobuf.Timestamp accepts a datetime or a number of milliseconds since the
    epoch; every other message type is filled recursively by doc_to_pb.
    """
    if message_pb.DESCRIPTOR.full_name != 'google.protobuf.Timestamp':
        doc_to_pb(message_pb, doc_value, get_doc_key=get_doc_key)
    elif isinstance(doc_value, datetime.datetime):
        offset = doc_value.utcoffset()
        if offset is not None:
            # Aware datetime: shift it to UTC before dropping tzinfo, otherwise a non-UTC
            # value would be stored off by its UTC offset.
            doc_value = doc_value - offset
        message_pb.FromDatetime(doc_value.replace(tzinfo=None))
    else:
        message_pb.FromMilliseconds(doc_value)


def doc_to_pb(pb, doc, get_doc_key=default_get_doc_key):
    """
    Recursively initializes provided protobuf object (pb) with values from MongoDB document (doc).

    Document keys that are missing, or whose value is None, leave the field untouched.
    Enum values are given by name; google.protobuf.Timestamp values as a datetime or as
    milliseconds since the epoch. Repeated fields are read from lists and map fields from
    dicts; other value shapes for repeated fields are ignored.

    :param pb: protobuf object to initialize
    :param doc: dictionary to get values from
    :param get_doc_key: callable that receives a field descriptor and returns its document key
    :return: pb, modified in place
    :raises ValueError: if a field present in doc has an unsupported type or an enum name is unknown
    """
    if not doc:
        return pb
    for field in pb.DESCRIPTOR.fields:
        doc_key = get_doc_key(field)
        if doc_key not in doc:
            # Skip unknown fields
            continue

        if field.type == FD.TYPE_MESSAGE:
            # Messages go through _doc_value_to_message instead of a scalar converter.
            field_type = None
        elif field.type == FD.TYPE_ENUM:
            # Bind the enum type as a default so the converter does not depend on the loop variable.
            field_type = lambda v, enum_type=field.enum_type: enum_name_to_value(enum_type, v)
        elif field.type in _types:
            field_type = _types[field.type]
        else:
            raise ValueError("Field {}.{} of type '{}' is not supported".format(
                pb.__class__.__name__, field.name, field.type))

        doc_value = doc[doc_key]
        if doc_value is None:
            # A null in the document means "not set": converting it would raise, or store
            # the literal text 'None' for string fields.
            continue

        if field.label == FD.LABEL_REPEATED:
            pb_value = getattr(pb, field.name)
            if isinstance(doc_value, list):
                for v in doc_value:
                    if field_type is None:
                        _doc_value_to_message(pb_value.add(), v, get_doc_key)
                    else:
                        pb_value.append(field_type(v))
            elif isinstance(doc_value, dict) and field_type is None:
                # Map fields are repeated entry messages carrying the map_entry option;
                # entry field 2 holds the value.
                entry_desc = field.message_type
                message_values = (entry_desc.has_options and entry_desc.GetOptions().map_entry and
                                  entry_desc.fields_by_number[2].type == FD.TYPE_MESSAGE)
                for k, v in doc_value.items():
                    if message_values:
                        # Message-valued maps reject item assignment; fill the entry that
                        # indexing creates instead.
                        _doc_value_to_message(pb_value[k], v, get_doc_key)
                    else:
                        pb_value[k] = v
        elif field_type is None:
            if field.is_extension:
                nested_pb = pb.Extensions[field]
            else:
                nested_pb = getattr(pb, field.name)
            _doc_value_to_message(nested_pb, doc_value, get_doc_key)
        else:
            setattr(pb, field.name, field_type(doc_value))
    return pb


def _pb_value_converter(pb, field, include_empty_fields, get_doc_key):
    """
    Returns a callable converting one (non-repeated) value of field to its document form.

    Timestamps become milliseconds since the epoch, other messages nested documents
    (via pb_to_doc), enums their names, and scalars go through _types.

    :raises ValueError: if field has an unsupported type (pb is used only for the message)
    """
    if field.type == FD.TYPE_MESSAGE:
        if field.message_type.full_name == 'google.protobuf.Timestamp':
            return lambda v: v.ToMilliseconds()
        return lambda v: pb_to_doc(v, include_empty_fields=include_empty_fields, get_doc_key=get_doc_key)
    if field.type == FD.TYPE_ENUM:
        return lambda v: enum_value_to_name(field.enum_type, v)
    if field.type in _types:
        return _types[field.type]
    raise ValueError("Field {}.{} of type '{}' is not supported".format(
        pb.__class__.__name__, field.name, field.type))


def pb_to_doc(pb, include_empty_fields=False, get_doc_key=default_get_doc_key):
    """
    Creates a MongoDB document with values from protobuf object.

    Enums are stored by name, google.protobuf.Timestamp as milliseconds since the epoch,
    sub-messages as nested documents, repeated fields as lists and map fields as dicts.
    Unset singular sub-messages (other than Timestamps) are omitted.

    :param pb: protobuf object
    :param include_empty_fields: if true, visit every declared field; otherwise only the
        fields reported by ListFields() plus all enum fields
    :param get_doc_key: callable that receives a field descriptor and returns its document key
    :return: dict keyed by get_doc_key(field)
    :raises ValueError: if a field has an unsupported type or an enum holds an unknown number
    """
    if include_empty_fields:
        fields = list(pb.DESCRIPTOR.fields)
    else:
        fields = [item[0] for item in pb.ListFields()]
        # In dictionary we need all set fields and enums
        # Even if we explicitly set enum to value which happens to be zero
        # this field will not be present in ListFields() result.
        # So we manually add them, skipping the ones already listed.
        listed = set(field.number for field in fields)
        fields.extend(field for field in pb.DESCRIPTOR.fields
                      if field.type == FD.TYPE_ENUM and field.number not in listed)
    doc = {}
    for field in fields:
        pb_value = getattr(pb, field.name)
        if (field.type == FD.TYPE_MESSAGE and field.label != FD.LABEL_REPEATED and
                field.message_type.full_name != 'google.protobuf.Timestamp' and
                not pb.HasField(field.name)):
            # Don't set empty message fields in resulting dictionary. Timestamps are not
            # skipped, so an unset one is written as 0 when include_empty_fields is set.
            continue
        entry_desc = field.message_type
        if (field.label == FD.LABEL_REPEATED and entry_desc is not None and
                entry_desc.has_options and entry_desc.GetOptions().map_entry):
            # Map field: iterating the container yields only keys, so convert key/value
            # pairs using the entry's value field (field number 2).
            convert = _pb_value_converter(pb, entry_desc.fields_by_number[2], include_empty_fields, get_doc_key)
            doc_value = {k: convert(v) for k, v in pb_value.items()}
        else:
            convert = _pb_value_converter(pb, field, include_empty_fields, get_doc_key)
            if field.label == FD.LABEL_REPEATED:
                doc_value = [convert(v) for v in pb_value]
            else:
                doc_value = convert(pb_value)
        doc[get_doc_key(field)] = doc_value
    return doc


def _format_attr_path(prefix, field_name, index=None):
    if not prefix:
        if index is None:
            return field_name
        return '{}[{}]'.format(field_name, index)
    rv = '.'.join((prefix, field_name))
    if index is not None:
        rv += '[{}]'.format(index)
    return rv


def _limit_len(v, limit=64):
    if len(v) > limit:
        return v[:limit] + '...'
    return v


def _validate_field_option(option_name, option_value, field_type, value, path, field_name, index=None):
    """
    Raises ValueError if a single value of a field violates one schema option.

    Numeric field types honour 'minimum'/'maximum'; string field types honour
    'minLength', 'maxLength' and 'pattern' (checked with re.match, i.e. anchored at the
    start of the value only). Other option/type combinations are ignored.
    """
    if field_type in _NUM_TYPES:
        if option_name == 'minimum' and value < option_value:
            raise ValueError('{}: got: {}, min: {}'.format(
                _format_attr_path(path, field_name, index), value, option_value))
        if option_name == 'maximum' and value > option_value:
            raise ValueError('{}: got: {}, max: {}'.format(
                _format_attr_path(path, field_name, index), value, option_value))
    elif field_type in _STRING_TYPES:
        if option_name == 'minLength' and len(value) < option_value:
            raise ValueError('{} length {} is too short: min {}, got "{}"'.format(
                _format_attr_path(path, field_name, index), len(value), option_value, _limit_len(value)))
        if option_name == 'maxLength' and len(value) > option_value:
            raise ValueError('{} length {} is too long: max {}, got "{}"'.format(
                _format_attr_path(path, field_name, index), len(value), option_value, _limit_len(value)))
        if option_name == 'pattern' and re.match(option_value, value) is None:
            # Report the offending value itself (truncated like the length errors), not
            # the whole repeated container.
            raise ValueError('{} must match pattern "{}", got "{}"'.format(
                _format_attr_path(path, field_name, index), option_value, _limit_len(value)))


def validate_pb_schema(pb_object, path=''):
    """
    Validates protobuf object according to json schema style field annotations.
    Raises ValueError in case of constraint violation.

    Supported field options: 'minimum'/'maximum' on numeric fields and
    'minLength'/'maxLength'/'pattern' on string fields; repeated fields are checked per
    element. Sub-messages, repeated sub-messages and message-valued map entries are
    validated recursively. Members of a oneof are skipped when another member is set.

    :param pb_object: protobuf message object to validate
    :param path: optional path to prepend violated field name in error message (if object is a sub message).
    """
    for field_desc in pb_object.DESCRIPTOR.fields:
        field_name = field_desc.name
        if field_desc.containing_oneof:
            active_member = pb_object.WhichOneof(field_desc.containing_oneof.name)
            if active_member is not None and active_member != field_name:
                # Another member of this oneof is set; this one is not checked.
                continue
        field_value = getattr(pb_object, field_name)
        repeated = field_desc.label == LABEL_REPEATED

        if field_desc.type in _NUM_TYPES or field_desc.type in _STRING_TYPES:
            for option_field, option_value in field_desc.GetOptions().ListFields():
                if option_field.name not in ('minimum', 'maximum', 'minLength', 'maxLength', 'pattern'):
                    continue
                if repeated:
                    indexed_values = enumerate(field_value)
                else:
                    indexed_values = ((None, field_value),)
                for index, value in indexed_values:
                    _validate_field_option(option_field.name, option_value, field_desc.type,
                                           value, path, field_name, index)

        if field_desc.type in _STRUCT_TYPES:
            message_type = field_desc.message_type
            if not repeated:
                validate_pb_schema(field_value, path=_format_attr_path(path, field_name))
            elif message_type.has_options and message_type.GetOptions().map_entry:
                # Map field: only message-typed values (entry field 2) are validated recursively.
                if message_type.fields_by_number[2].type in _STRUCT_TYPES:
                    for key, item in field_value.items():
                        # Use the map key, not the value message, as the path index.
                        validate_pb_schema(item, path=_format_attr_path(path, field_name, key))
            else:
                # repeated message
                for index, item in enumerate(field_value):
                    validate_pb_schema(item, path=_format_attr_path(path, field_name, index))


def datetime_to_timestamp(timestamp, dt):
    """
    Initializes protobuf timestamp object from provided datetime.
    Default .FromDateTime() won't work on timezone aware datetime objects.

    Naive datetimes are taken as UTC (utctimetuple() applies no offset to them); aware
    ones are converted to UTC. Microseconds are kept in timestamp.nanos instead of being
    truncated away.

    :param timestamp: protobuf timestamp object, modified in place (seconds and nanos)
    :param dt: datetime
    """
    # timegm() of the time tuple gives whole seconds rounded down, so the microsecond
    # remainder is always non-negative, even for dates before 1970.
    timestamp.seconds = calendar.timegm(dt.utctimetuple())
    timestamp.nanos = dt.microsecond * 1000
