Add list support to pblite deserializer

This commit is contained in:
Tulir Asokan 2023-09-04 14:25:00 +03:00
parent d757ced271
commit 88ba4b12b6

View file

@ -21,11 +21,27 @@ func Unmarshal(data []byte, m proto.Message) error {
return deserializeFromSlice(anyDataArr, m.ProtoReflect())
}
func deserializeOne(val any, index int, ref protoreflect.Message, fieldDescriptor protoreflect.FieldDescriptor) (protoreflect.Value, error) {
func deserializeOne(val any, index int, ref protoreflect.Message, insideList protoreflect.List, fieldDescriptor protoreflect.FieldDescriptor) (protoreflect.Value, error) {
var num float64
var expectedKind, str string
var boolean, ok bool
var outputVal protoreflect.Value
if fieldDescriptor.IsList() && insideList == nil {
nestedData, ok := val.([]any)
if !ok {
return outputVal, fmt.Errorf("expected untyped array at index %d for repeated field %s, got %T", index, fieldDescriptor.FullName(), val)
}
list := ref.NewField(fieldDescriptor).List()
list.NewElement()
for i, nestedVal := range nestedData {
nestedParsed, err := deserializeOne(nestedVal, i, ref, list, fieldDescriptor)
if err != nil {
return outputVal, err
}
list.Append(nestedParsed)
}
return protoreflect.ValueOfList(list), nil
}
switch fieldDescriptor.Kind() {
case protoreflect.MessageKind:
ok = true
@ -33,7 +49,12 @@ func deserializeOne(val any, index int, ref protoreflect.Message, fieldDescripto
if !ok {
return outputVal, fmt.Errorf("expected untyped array at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val)
}
nestedMessage := ref.NewField(fieldDescriptor).Message()
var nestedMessage protoreflect.Message
if insideList != nil {
nestedMessage = insideList.NewElement().Message()
} else {
nestedMessage = ref.NewField(fieldDescriptor).Message()
}
if err := deserializeFromSlice(nestedData, nestedMessage); err != nil {
return outputVal, err
}
@ -104,7 +125,7 @@ func deserializeFromSlice(data []any, ref protoreflect.Message) error {
}
val := data[index]
outputVal, err := deserializeOne(val, index, ref, fieldDescriptor)
outputVal, err := deserializeOne(val, index, ref, nil, fieldDescriptor)
if err != nil {
return err
}