diff --git a/libgm/pblite/deserialize.go b/libgm/pblite/deserialize.go index df986ef..a568246 100644 --- a/libgm/pblite/deserialize.go +++ b/libgm/pblite/deserialize.go @@ -21,11 +21,27 @@ func Unmarshal(data []byte, m proto.Message) error { return deserializeFromSlice(anyDataArr, m.ProtoReflect()) } -func deserializeOne(val any, index int, ref protoreflect.Message, fieldDescriptor protoreflect.FieldDescriptor) (protoreflect.Value, error) { +func deserializeOne(val any, index int, ref protoreflect.Message, insideList protoreflect.List, fieldDescriptor protoreflect.FieldDescriptor) (protoreflect.Value, error) { var num float64 var expectedKind, str string var boolean, ok bool var outputVal protoreflect.Value + if fieldDescriptor.IsList() && insideList == nil { + nestedData, ok := val.([]any) + if !ok { + return outputVal, fmt.Errorf("expected untyped array at index %d for repeated field %s, got %T", index, fieldDescriptor.FullName(), val) + } + list := ref.NewField(fieldDescriptor).List() + list.NewElement() + for i, nestedVal := range nestedData { + nestedParsed, err := deserializeOne(nestedVal, i, ref, list, fieldDescriptor) + if err != nil { + return outputVal, err + } + list.Append(nestedParsed) + } + return protoreflect.ValueOfList(list), nil + } switch fieldDescriptor.Kind() { case protoreflect.MessageKind: ok = true @@ -33,7 +49,12 @@ func deserializeOne(val any, index int, ref protoreflect.Message, fieldDescripto if !ok { return outputVal, fmt.Errorf("expected untyped array at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val) } - nestedMessage := ref.NewField(fieldDescriptor).Message() + var nestedMessage protoreflect.Message + if insideList != nil { + nestedMessage = insideList.NewElement().Message() + } else { + nestedMessage = ref.NewField(fieldDescriptor).Message() + } if err := deserializeFromSlice(nestedData, nestedMessage); err != nil { return outputVal, err } @@ -104,7 +125,7 @@ func deserializeFromSlice(data []any, ref protoreflect.Message) error { } val := data[index] - outputVal, err := deserializeOne(val, index, ref, fieldDescriptor) + outputVal, err := deserializeOne(val, index, ref, nil, fieldDescriptor) if err != nil { return err }