import pytest

from tasklet.runtime.utils import convert
from tasklet.runtime.utils.ut.proto import test_pb2


# Shared fixtures: every protobuf message below is paired with the flat dict
# that the round-trip test expects for it. In the expected dicts, the path to
# a nested field is spelled by joining field names with a double underscore.

# Scalar fields only.
flat_dct = {'f1': 1, 'f2': True, 'f3': 'some str'}
flat_msg = test_pb2.FlatMsg(f1=1, f2=True, f3='some str')

# One level of nesting: ff2 carries a copy of flat_msg.
composite_dct = {'ff1': 2, 'ff2__f1': 1, 'ff2__f2': True, 'ff2__f3': 'some str'}
composite_msg = test_pb2.CompositeMsg(ff1=2, ff2=flat_msg)

# Repeated scalar fields are expected as plain lists.
repeated_dct = {'f1': [13, 432], 'f2': ['blue', 'green']}
repeated_msg = test_pb2.RepeatedMsg(f1=[13, 432], f2=['blue', 'green'])

# Two levels of nesting. fff1_f1 is a top-level field whose own name contains
# a single underscore.
super_composite_dct = {
    'fff1': 3,
    'fff1_f1': 30,
    'fff2__ff1': 2,
    'fff2__ff2__f1': 1,
    'fff2__ff2__f2': True,
    'fff2__ff2__f3': 'some str',
}
super_composite_msg = test_pb2.SuperCompositeMsg(
    fff1=3,
    fff1_f1=30,
    fff2=composite_msg,
)

# Message whose f2 field nests further messages that again expose f1 and f2.
recursive_dct = {'f1': 1, 'f2__f1': 2, 'f2__f2__f1': 3}
recursive_msg = test_pb2.RecursiveMsg(f1=1)
recursive_msg.f2.f1 = 2
recursive_msg.f2.f2.f1 = 3


@pytest.mark.parametrize('proto, expected', [
    (flat_msg, flat_dct),
    (composite_msg, composite_dct),
    (repeated_msg, repeated_dct),
    (super_composite_msg, super_composite_dct),
    (recursive_msg, recursive_dct),
])
def test_proto_flat_dict(proto, expected):
    """Round-trip a message through the flat-dict conversion helpers.

    Checks three things for every fixture pair:

    1. ``convert.proto_to_flat_dict`` yields exactly ``expected``.
    2. Feeding that dict to ``convert.flat_dict_to_proto`` with a fresh
       message of the same type rebuilds ``proto`` and returns an empty list.
    3. With three extra keys added to the dict, the rebuilt message still
       equals ``proto`` and the returned list holds exactly those extra keys
       (order is not checked).
    """
    flattened = convert.proto_to_flat_dict(proto)
    assert flattened == expected

    # Clean round trip: every key maps onto a field of the message.
    restored = type(proto)()
    unmatched = convert.flat_dict_to_proto(flattened, restored)
    assert restored == proto
    assert unmatched == []

    # Same round trip with keys that the test expects to be reported back
    # instead of being applied to the message.
    extra_keys = {'fff2__ff2__f4': 'v1', 'fff2__ff4': 'v2', 'fff4': 'v3'}
    flattened.update(extra_keys)
    restored = type(proto)()
    unmatched = convert.flat_dict_to_proto(flattened, restored)
    assert restored == proto
    assert sorted(unmatched) == sorted(extra_keys)


def test_nesting_exceeded():
    """Check the nesting-depth limit of ``convert.proto_to_flat_dict``.

    A value set eight ``f2`` levels deep is still converted, and the ``f1``
    of every intermediate level that was never assigned shows up as ``0``.
    Setting a value one level deeper (nine ``f2`` levels) must make the
    conversion raise with a message containing 'Max fields nesting depth'.
    """
    msg = test_pb2.RecursiveMsg()
    msg.f1 = 1
    # Eight levels of f2: expected to convert without error.
    msg.f2.f2.f2.f2.f2.f2.f2.f2.f1 = 2
    dct = convert.proto_to_flat_dict(msg)
    assert dct == {
        'f1': 1,
        'f2__f1': 0,
        'f2__f2__f1': 0,
        'f2__f2__f2__f1': 0,
        'f2__f2__f2__f2__f1': 0,
        'f2__f2__f2__f2__f2__f1': 0,
        'f2__f2__f2__f2__f2__f2__f1': 0,
        'f2__f2__f2__f2__f2__f2__f2__f1': 0,
        'f2__f2__f2__f2__f2__f2__f2__f2__f1': 2,
    }

    # Nine levels of f2: one deeper than the case that succeeded above.
    msg.f2.f2.f2.f2.f2.f2.f2.f2.f2.f1 = 3

    # NOTE(review): Exception is very broad; narrow it once the concrete type
    # raised by convert.proto_to_flat_dict is confirmed.
    with pytest.raises(Exception) as exc_info:
        # Only the raise matters here, so the result is not bound to a name.
        convert.proto_to_flat_dict(msg)
    assert 'Max fields nesting depth' in exc_info.value.args[0]
