Add list support to pblite deserializer

This commit is contained in:
Tulir Asokan 2023-09-04 14:25:00 +03:00
parent d757ced271
commit 88ba4b12b6

View file

@ -21,11 +21,27 @@ func Unmarshal(data []byte, m proto.Message) error {
return deserializeFromSlice(anyDataArr, m.ProtoReflect()) return deserializeFromSlice(anyDataArr, m.ProtoReflect())
} }
func deserializeOne(val any, index int, ref protoreflect.Message, fieldDescriptor protoreflect.FieldDescriptor) (protoreflect.Value, error) { func deserializeOne(val any, index int, ref protoreflect.Message, insideList protoreflect.List, fieldDescriptor protoreflect.FieldDescriptor) (protoreflect.Value, error) {
var num float64 var num float64
var expectedKind, str string var expectedKind, str string
var boolean, ok bool var boolean, ok bool
var outputVal protoreflect.Value var outputVal protoreflect.Value
if fieldDescriptor.IsList() && insideList == nil {
nestedData, ok := val.([]any)
if !ok {
return outputVal, fmt.Errorf("expected untyped array at index %d for repeated field %s, got %T", index, fieldDescriptor.FullName(), val)
}
list := ref.NewField(fieldDescriptor).List()
list.NewElement()
for i, nestedVal := range nestedData {
nestedParsed, err := deserializeOne(nestedVal, i, ref, list, fieldDescriptor)
if err != nil {
return outputVal, err
}
list.Append(nestedParsed)
}
return protoreflect.ValueOfList(list), nil
}
switch fieldDescriptor.Kind() { switch fieldDescriptor.Kind() {
case protoreflect.MessageKind: case protoreflect.MessageKind:
ok = true ok = true
@ -33,7 +49,12 @@ func deserializeOne(val any, index int, ref protoreflect.Message, fieldDescripto
if !ok { if !ok {
return outputVal, fmt.Errorf("expected untyped array at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val) return outputVal, fmt.Errorf("expected untyped array at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val)
} }
nestedMessage := ref.NewField(fieldDescriptor).Message() var nestedMessage protoreflect.Message
if insideList != nil {
nestedMessage = insideList.NewElement().Message()
} else {
nestedMessage = ref.NewField(fieldDescriptor).Message()
}
if err := deserializeFromSlice(nestedData, nestedMessage); err != nil { if err := deserializeFromSlice(nestedData, nestedMessage); err != nil {
return outputVal, err return outputVal, err
} }
@ -104,7 +125,7 @@ func deserializeFromSlice(data []any, ref protoreflect.Message) error {
} }
val := data[index] val := data[index]
outputVal, err := deserializeOne(val, index, ref, fieldDescriptor) outputVal, err := deserializeOne(val, index, ref, nil, fieldDescriptor)
if err != nil { if err != nil {
return err return err
} }