Remove manual json marshaling step in pblite

This commit is contained in:
Tulir Asokan 2023-07-15 15:56:55 +03:00
parent 25236fffa9
commit 605d84c485
5 changed files with 15 additions and 32 deletions

View file

@ -1,8 +1,6 @@
package payload package payload
import ( import (
"encoding/json"
"github.com/google/uuid" "github.com/google/uuid"
"go.mau.fi/mautrix-gmessages/libgm/binary" "go.mau.fi/mautrix-gmessages/libgm/binary"
@ -20,11 +18,7 @@ func ReceiveMessages(rpcKey []byte) ([]byte, string, error) {
Unknown: &binary.ReceiveMessagesRequest_UnknownEmptyObject1{}, Unknown: &binary.ReceiveMessagesRequest_UnknownEmptyObject1{},
}, },
} }
data, err := pblite.Serialize(payload.ProtoReflect()) jsonData, err := pblite.Marshal(payload)
if err != nil {
return nil, "", err
}
jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View file

@ -1,8 +1,6 @@
package payload package payload
import ( import (
"encoding/json"
"go.mau.fi/mautrix-gmessages/libgm/binary" "go.mau.fi/mautrix-gmessages/libgm/binary"
"go.mau.fi/mautrix-gmessages/libgm/pblite" "go.mau.fi/mautrix-gmessages/libgm/pblite"
) )
@ -21,15 +19,10 @@ func RegisterRefresh(sig []byte, requestID string, timestamp int64, browser *bin
MessageType: 2, // hmm MessageType: 2, // hmm
} }
serialized, serializeErr := pblite.Serialize(payload.ProtoReflect()) jsonMessage, serializeErr := pblite.Marshal(payload)
if serializeErr != nil { if serializeErr != nil {
return nil, serializeErr return nil, serializeErr
} }
jsonMessage, marshalErr := json.Marshal(serialized)
if marshalErr != nil {
return nil, marshalErr
}
return jsonMessage, nil return jsonMessage, nil
} }

View file

@ -1,7 +1,6 @@
package payload package payload
import ( import (
"encoding/json"
"fmt" "fmt"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
@ -113,15 +112,11 @@ func (sm *SendMessageBuilder) Build() ([]byte, error) {
} }
sm.message.MessageData.ProtobufData = encodedMessage sm.message.MessageData.ProtobufData = encodedMessage
messageProtoJSON, serializeErr := pblite.Serialize(sm.message.ProtoReflect()) protoJSONBytes, serializeErr := pblite.Marshal(sm.message)
if serializeErr != nil { if serializeErr != nil {
panic(serializeErr) panic(serializeErr)
return nil, serializeErr return nil, serializeErr
} }
protoJSONBytes, marshalErr := json.Marshal(messageProtoJSON)
if marshalErr != nil {
return nil, marshalErr
}
return protoJSONBytes, nil return protoJSONBytes, nil
} }

View file

@ -82,7 +82,7 @@ func serializeOneOrList(fieldDescriptor protoreflect.FieldDescriptor, fieldValue
func serializeOne(fieldDescriptor protoreflect.FieldDescriptor, fieldValue protoreflect.Value) (any, error) { func serializeOne(fieldDescriptor protoreflect.FieldDescriptor, fieldValue protoreflect.Value) (any, error) {
switch fieldDescriptor.Kind() { switch fieldDescriptor.Kind() {
case protoreflect.MessageKind: case protoreflect.MessageKind:
serializedMsg, err := Serialize(fieldValue.Message().Interface().ProtoReflect()) serializedMsg, err := SerializeToSlice(fieldValue.Message().Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -106,21 +106,22 @@ func serializeOne(fieldDescriptor protoreflect.FieldDescriptor, fieldValue proto
} }
} }
func Serialize(m protoreflect.Message) ([]any, error) { func SerializeToSlice(msg proto.Message) ([]any, error) {
ref := msg.ProtoReflect()
maxFieldNumber := 0 maxFieldNumber := 0
for i := 0; i < m.Descriptor().Fields().Len(); i++ { for i := 0; i < ref.Descriptor().Fields().Len(); i++ {
fieldNumber := int(m.Descriptor().Fields().Get(i).Number()) fieldNumber := int(ref.Descriptor().Fields().Get(i).Number())
if fieldNumber > maxFieldNumber { if fieldNumber > maxFieldNumber {
maxFieldNumber = fieldNumber maxFieldNumber = fieldNumber
} }
} }
serialized := make([]any, maxFieldNumber) serialized := make([]any, maxFieldNumber)
for i := 0; i < m.Descriptor().Fields().Len(); i++ { for i := 0; i < ref.Descriptor().Fields().Len(); i++ {
fieldDescriptor := m.Descriptor().Fields().Get(i) fieldDescriptor := ref.Descriptor().Fields().Get(i)
fieldValue := m.Get(fieldDescriptor) fieldValue := ref.Get(fieldDescriptor)
fieldNumber := int(fieldDescriptor.Number()) fieldNumber := int(fieldDescriptor.Number())
if !m.Has(fieldDescriptor) { if !ref.Has(fieldDescriptor) {
continue continue
} }
serializedVal, err := serializeOneOrList(fieldDescriptor, fieldValue) serializedVal, err := serializeOneOrList(fieldDescriptor, fieldValue)
@ -133,8 +134,8 @@ func Serialize(m protoreflect.Message) ([]any, error) {
return serialized, nil return serialized, nil
} }
func SerializeToJSON(m proto.Message) ([]byte, error) { func Marshal(m proto.Message) ([]byte, error) {
serialized, err := Serialize(m.ProtoReflect()) serialized, err := SerializeToSlice(m)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -142,7 +142,7 @@ func (s *SessionHandler) sendAckRequest() {
EmptyArr: &binary.EmptyArr{}, EmptyArr: &binary.EmptyArr{},
Acks: ackMessages, Acks: ackMessages,
} }
jsonData, err := pblite.SerializeToJSON(ackMessagePayload) jsonData, err := pblite.Marshal(ackMessagePayload)
if err != nil { if err != nil {
panic(err) panic(err)
} }