# -*- coding: utf-8 -*-
from __future__ import unicode_literals

# noinspection PyUnresolvedReferences
import travel.proto.cpa.proto_options.options_pb2 as options_pb2

from itertools import tee
from typing import Any, Callable, Type
import base64
import hashlib

from google.protobuf.descriptor import FieldDescriptor
from google.protobuf.message import Message


# Maps protobuf scalar field types to the Python callable used by protobuf_to_dict
# to convert a single field value. TYPE_MESSAGE (and TYPE_GROUP) are intentionally
# absent: message-typed fields are handled separately, and a field type missing from
# this table makes the lookup raise KeyError.
TYPE_MAP = {
    FieldDescriptor.TYPE_DOUBLE: float,
    FieldDescriptor.TYPE_FLOAT: float,
    FieldDescriptor.TYPE_INT32: int,
    FieldDescriptor.TYPE_INT64: int,
    FieldDescriptor.TYPE_UINT32: int,
    FieldDescriptor.TYPE_UINT64: int,
    FieldDescriptor.TYPE_SINT32: int,
    FieldDescriptor.TYPE_SINT64: int,
    FieldDescriptor.TYPE_FIXED32: int,
    FieldDescriptor.TYPE_FIXED64: int,
    FieldDescriptor.TYPE_SFIXED32: int,
    FieldDescriptor.TYPE_SFIXED64: int,
    FieldDescriptor.TYPE_BOOL: bool,
    FieldDescriptor.TYPE_STRING: str,
    # `bytes`, not `bin`: bin() formats an integer as a '0b...' string and raises
    # TypeError when given a bytes value.
    FieldDescriptor.TYPE_BYTES: bytes,
    # Enums are kept as their numeric value; name substitution happens in
    # protobuf_to_dict when requested.
    FieldDescriptor.TYPE_ENUM: int,
}


def camel_to_snake(s: str) -> str:
    """Convert a camelCase identifier to snake_case.

    An underscore is inserted only where a lowercase character is immediately
    followed by an uppercase one; the whole result is then lower-cased. Runs of
    capitals and digits are not split, e.g. ``getHTTPResponse`` becomes
    ``get_httpresponse`` and ``field1Name`` becomes ``field1name``.
    """
    pieces = []
    for index, char in enumerate(s):
        # Look back at the previous character instead of ahead at the next one.
        if index and s[index - 1].islower() and char.isupper():
            pieces.append('_')
        pieces.append(char)
    return ''.join(pieces).lower()


def repeated(type_converter: Callable[[Any], Any]) -> Callable[[list[Any]], list[Any]]:
    """Lift a single-value converter into one that converts every element of a sequence.

    The returned function builds and returns a new list; its argument is only iterated.
    """
    def convert_each(value_list):
        return list(map(type_converter, value_list))

    return convert_each


def protobuf_to_dict(message: Message, convert_case: bool = False, convert_enum: bool = False) -> dict[str, Any]:
    """Convert a protobuf message into a plain ``dict``.

    Only the fields returned by ``message.ListFields()`` are included; extension
    fields are skipped. Message-typed fields are converted recursively with the same
    flags, and repeated fields (scalar, enum or message) become lists.

    :param message: the message to convert.
    :param convert_case: if True, every key at every nesting level is passed through
        ``camel_to_snake``.
    :param convert_enum: if True, enum values are replaced by their value names. An
        enum field whose options set the ``UseEnumNameInYt`` extension is converted
        to names even when this flag is False.
    :return: a dict mapping field names to converted values.
    :raises KeyError: for a field type missing from ``TYPE_MAP``, or for an enum
        number with no matching enum value when names are requested.
    """
    result = {}
    for field, value in message.ListFields():
        if field.is_extension:
            continue
        # Convert the key before branching so nested-message keys get the same
        # treatment as scalar keys.
        field_name = camel_to_snake(field.name) if convert_case else field.name
        if field.type == FieldDescriptor.TYPE_MESSAGE:
            # NOTE(review): map fields are not special-cased here — confirm they are
            # not passed to this function.
            def convert_item(item):
                return protobuf_to_dict(item, convert_case, convert_enum)
        else:
            convert_item = _scalar_converter(field, convert_enum)
        if field.label == FieldDescriptor.LABEL_REPEATED:
            # Convert element-wise; this covers repeated messages and repeated enums,
            # which previously failed on the container as a whole.
            convert_item = repeated(convert_item)
        result[field_name] = convert_item(value)
    return result


def _scalar_converter(field: FieldDescriptor, convert_enum: bool) -> Callable[[Any], Any]:
    """Return a function converting one value of the non-message *field*.

    For enum fields, the function returns the enum value name when ``convert_enum``
    is True or the field's ``UseEnumNameInYt`` option is set; otherwise it returns
    the ``TYPE_MAP`` converter for the field type.
    """
    type_converter = TYPE_MAP[field.type]
    if field.type != FieldDescriptor.TYPE_ENUM:
        return type_converter
    # The field option is consulted only for enum fields, where it can force
    # name conversion for this single field.
    if not (convert_enum or field.GetOptions().Extensions[options_pb2.UseEnumNameInYt]):
        return type_converter
    values_by_number = field.enum_type.values_by_number
    return lambda item: values_by_number[int(type_converter(item))].name


def proto_from_base64_str(proto_class: Type[Message], proto_str: str) -> Message:
    """Parse a URL-safe base64 string into a new ``proto_class`` instance.

    Trailing '=' padding may be omitted from ``proto_str``; it is restored before
    decoding, because ``base64.urlsafe_b64decode`` raises ``binascii.Error`` on
    incorrectly padded input.

    :param proto_class: message class to instantiate and parse into.
    :param proto_str: URL-safe base64 text, with or without padding.
    :return: the parsed message.
    """
    encoded = bytes(proto_str, encoding='utf8')
    # -n % 4 is the number of characters needed to reach a multiple of 4 (0 if aligned).
    encoded += b'=' * (-len(encoded) % 4)
    message = proto_class()
    payload = base64.urlsafe_b64decode(encoded)
    message.ParseFromString(payload)
    return message


def get_proto_hash(proto: Message) -> str:
    """Return a hex MD5 digest of the field values set on *proto*.

    The bytes fed into the digest are produced by ``_update_message_hash``.
    """
    hasher = hashlib.md5()
    _update_message_hash(proto, hasher)
    digest = hasher.hexdigest()  # already a str
    return digest


def _update_message_hash(proto: Message, hash_obj) -> None:
    """Feed the set, non-extension field values of *proto* into *hash_obj*.

    Fields are visited in ``ListFields()`` order. Message-typed values are walked
    recursively. Every other value contributes ``repr(value).encode()`` via
    ``hash_obj.update``. Repeated fields contribute each element in order.

    NOTE(review): only values are hashed — field names/numbers and message
    boundaries are not — so messages setting different fields to equal values
    produce the same digest. Left as-is because changing it alters every existing
    hash.
    """
    for field, value in proto.ListFields():
        if field.is_extension:
            continue
        items = value if field.label == FieldDescriptor.LABEL_REPEATED else (value,)
        is_message = field.type == FieldDescriptor.TYPE_MESSAGE
        for item in items:
            if is_message:
                _update_message_hash(item, hash_obj)
            else:
                hash_obj.update(repr(item).encode())
