diff --git a/libgm/json_proto/deseralize.go b/libgm/json_proto/deseralize.go deleted file mode 100644 index 61b7368..0000000 --- a/libgm/json_proto/deseralize.go +++ /dev/null @@ -1,53 +0,0 @@ -package json_proto - -import ( - "fmt" - - "google.golang.org/protobuf/reflect/protoreflect" -) - -func Deserialize(data []interface{}, m protoreflect.Message) error { - for i := 0; i < m.Descriptor().Fields().Len(); i++ { - fieldDescriptor := m.Descriptor().Fields().Get(i) - index := int(fieldDescriptor.Number()) - 1 - if index < 0 || index >= len(data) || data[index] == nil { - continue - } - - val := data[index] - - switch fieldDescriptor.Kind() { - case protoreflect.MessageKind: - nestedData, ok := val.([]interface{}) - if !ok { - return fmt.Errorf("expected slice, got %T", val) - } - nestedMessage := m.NewField(fieldDescriptor).Message() - if err := Deserialize(nestedData, nestedMessage); err != nil { - return err - } - m.Set(fieldDescriptor, protoreflect.ValueOfMessage(nestedMessage)) - case protoreflect.BytesKind: - bytes, ok := val.([]byte) - if !ok { - return fmt.Errorf("expected bytes, got %T", val) - } - m.Set(fieldDescriptor, protoreflect.ValueOfBytes(bytes)) - case protoreflect.Int32Kind, protoreflect.Int64Kind: - num, ok := val.(float64) - if !ok { - return fmt.Errorf("expected number, got %T", val) - } - m.Set(fieldDescriptor, protoreflect.ValueOf(int64(num))) - case protoreflect.StringKind: - str, ok := val.(string) - if !ok { - return fmt.Errorf("expected string, got %T", val) - } - m.Set(fieldDescriptor, protoreflect.ValueOf(str)) - default: - // ignore fields of other types - } - } - return nil -} diff --git a/libgm/msg_handler.go b/libgm/msg_handler.go index 3351a33..1836f6d 100644 --- a/libgm/msg_handler.go +++ b/libgm/msg_handler.go @@ -6,7 +6,7 @@ import ( "log" "go.mau.fi/mautrix-gmessages/libgm/binary" - "go.mau.fi/mautrix-gmessages/libgm/json_proto" + "go.mau.fi/mautrix-gmessages/libgm/pblite" ) func (r *RPC) HandleRPCMsg(msgArr []interface{}) { @@ -24,7 +24,7 @@ func (r *RPC) HandleRPCMsg(msgArr []interface{}) { } */ response := &binary.RPCResponse{} - deserializeErr := json_proto.Deserialize(msgArr, response.ProtoReflect()) + deserializeErr := pblite.Deserialize(msgArr, response.ProtoReflect()) if deserializeErr != nil { r.client.Logger.Error().Err(fmt.Errorf("failed to deserialize response %s", msgArr)).Msg("rpc deserialize msg err") return diff --git a/libgm/pblite/deseralize.go b/libgm/pblite/deseralize.go new file mode 100644 index 0000000..377d7c7 --- /dev/null +++ b/libgm/pblite/deseralize.go @@ -0,0 +1,85 @@ +package pblite + +import ( + "encoding/base64" + "fmt" + + "google.golang.org/protobuf/reflect/protoreflect" +) + +func Deserialize(data []any, m protoreflect.Message) error { + for i := 0; i < m.Descriptor().Fields().Len(); i++ { + fieldDescriptor := m.Descriptor().Fields().Get(i) + index := int(fieldDescriptor.Number()) - 1 + if index < 0 || index >= len(data) || data[index] == nil { + continue + } + + val := data[index] + + var num float64 + var expectedKind, str string + var boolean, ok bool + switch fieldDescriptor.Kind() { + case protoreflect.MessageKind: + nestedData, ok := val.([]any) + if !ok { + return fmt.Errorf("expected untyped array at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val) + } + nestedMessage := m.NewField(fieldDescriptor).Message() + if err := Deserialize(nestedData, nestedMessage); err != nil { + return err + } + m.Set(fieldDescriptor, protoreflect.ValueOfMessage(nestedMessage)) + case protoreflect.BytesKind: + bytesBase64, ok := val.(string) + if !ok { + return fmt.Errorf("expected string at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val) + } + bytes, err := base64.StdEncoding.DecodeString(bytesBase64) + if err != nil { + return fmt.Errorf("failed to decode base64 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err) + } + + m.Set(fieldDescriptor, protoreflect.ValueOfBytes(bytes)) + case protoreflect.Int32Kind: + num, ok = val.(float64) + expectedKind = "float64" + m.Set(fieldDescriptor, protoreflect.ValueOfInt32(int32(num))) + case protoreflect.Int64Kind: + num, ok = val.(float64) + expectedKind = "float64" + m.Set(fieldDescriptor, protoreflect.ValueOfInt64(int64(num))) + case protoreflect.Uint32Kind: + num, ok = val.(float64) + expectedKind = "float64" + m.Set(fieldDescriptor, protoreflect.ValueOfUint32(uint32(num))) + case protoreflect.Uint64Kind: + num, ok = val.(float64) + expectedKind = "float64" + m.Set(fieldDescriptor, protoreflect.ValueOfUint64(uint64(num))) + case protoreflect.FloatKind: + num, ok = val.(float64) + expectedKind = "float64" + m.Set(fieldDescriptor, protoreflect.ValueOfFloat32(float32(num))) + case protoreflect.DoubleKind: + num, ok = val.(float64) + expectedKind = "float64" + m.Set(fieldDescriptor, protoreflect.ValueOfFloat64(num)) + case protoreflect.StringKind: + str, ok = val.(string) + expectedKind = "string" + m.Set(fieldDescriptor, protoreflect.ValueOfString(str)) + case protoreflect.BoolKind: + boolean, ok = val.(bool) + expectedKind = "bool" + m.Set(fieldDescriptor, protoreflect.ValueOfBool(boolean)) + default: + return fmt.Errorf("unsupported field type %s in %s", fieldDescriptor.Kind(), fieldDescriptor.FullName()) + } + if !ok { + return fmt.Errorf("expected %s at index %d for field %s, got %T", expectedKind, index, fieldDescriptor.FullName(), val) + } + } + return nil +} diff --git a/libgm/json_proto/serialize.go b/libgm/pblite/serialize.go similarity index 62% rename from libgm/json_proto/serialize.go rename to libgm/pblite/serialize.go index 72804da..56b26d5 100644 --- a/libgm/json_proto/serialize.go +++ b/libgm/pblite/serialize.go @@ -1,4 +1,4 @@ -package json_proto +package pblite /* in protobuf, a message looks like this: @@ -53,10 +53,13 @@ This means that any slice inside of the current slice, indicates another message */ import ( + "encoding/base64" + "fmt" + "google.golang.org/protobuf/reflect/protoreflect" ) -func Serialize(m protoreflect.Message) ([]interface{}, error) { +func Serialize(m protoreflect.Message) ([]any, error) { maxFieldNumber := 0 for i := 0; i < m.Descriptor().Fields().Len(); i++ { fieldNumber := int(m.Descriptor().Fields().Get(i).Number()) @@ -65,34 +68,37 @@ func Serialize(m protoreflect.Message) ([]interface{}, error) { } } - serialized := make([]interface{}, maxFieldNumber) + serialized := make([]any, maxFieldNumber) for i := 0; i < m.Descriptor().Fields().Len(); i++ { fieldDescriptor := m.Descriptor().Fields().Get(i) fieldValue := m.Get(fieldDescriptor) fieldNumber := int(fieldDescriptor.Number()) + if !m.Has(fieldDescriptor) { + continue + } switch fieldDescriptor.Kind() { case protoreflect.MessageKind: - if m.Has(fieldDescriptor) { - serializedMsg, err := Serialize(fieldValue.Message().Interface().ProtoReflect()) - if err != nil { - return nil, err - } - serialized[fieldNumber-1] = serializedMsg + serializedMsg, err := Serialize(fieldValue.Message().Interface().ProtoReflect()) + if err != nil { + return nil, err } + serialized[fieldNumber-1] = serializedMsg case protoreflect.BytesKind: - if m.Has(fieldDescriptor) { - serialized[fieldNumber-1] = fieldValue.Bytes() - } + serialized[fieldNumber-1] = base64.StdEncoding.EncodeToString(fieldValue.Bytes()) case protoreflect.Int32Kind, protoreflect.Int64Kind: - if m.Has(fieldDescriptor) { - serialized[fieldNumber-1] = fieldValue.Int() - } + serialized[fieldNumber-1] = fieldValue.Int() + case protoreflect.Uint32Kind, protoreflect.Uint64Kind: + serialized[fieldNumber-1] = fieldValue.Uint() + case protoreflect.FloatKind, protoreflect.DoubleKind: + serialized[fieldNumber-1] = fieldValue.Float() + case protoreflect.EnumKind: + serialized[fieldNumber-1] = int(fieldValue.Enum()) + case protoreflect.BoolKind: + serialized[fieldNumber-1] = fieldValue.Bool() case protoreflect.StringKind: - if m.Has(fieldDescriptor) { - serialized[fieldNumber-1] = fieldValue.String() - } + serialized[fieldNumber-1] = fieldValue.String() default: - // ignore fields of other types + return nil, fmt.Errorf("unsupported field type %s in %s", fieldDescriptor.Kind(), fieldDescriptor.FullName()) } } diff --git a/libgm/session_handler.go b/libgm/session_handler.go index c914380..885362d 100644 --- a/libgm/session_handler.go +++ b/libgm/session_handler.go @@ -12,8 +12,8 @@ import ( "go.mau.fi/mautrix-gmessages/libgm/binary" "go.mau.fi/mautrix-gmessages/libgm/crypto" - "go.mau.fi/mautrix-gmessages/libgm/json_proto" "go.mau.fi/mautrix-gmessages/libgm/payload" + "go.mau.fi/mautrix-gmessages/libgm/pblite" "go.mau.fi/mautrix-gmessages/libgm/util" ) @@ -100,7 +100,7 @@ func (s *SessionHandler) completeSendMessage(requestId string, opCode int64, msg } func (s *SessionHandler) toJSON(message protoreflect.Message) ([]byte, error) { - interfaceArr, err := json_proto.Serialize(message) + interfaceArr, err := pblite.Serialize(message) if err != nil { return nil, err } @@ -145,14 +145,14 @@ func (s *SessionHandler) sendAckRequest() { EmptyArr: &binary.EmptyArr{}, NoClue: nil, } - dataArray, err := json_proto.Serialize(ackMessagePayload.ProtoReflect()) + dataArray, err := pblite.Serialize(ackMessagePayload.ProtoReflect()) if err != nil { log.Fatal(err) } ackMessages := make([][]interface{}, 0) for _, reqId := range s.ackMap { ackMessageData := &binary.AckMessageData{RequestId: reqId, Device: s.client.devicePair.Browser} - ackMessageDataArr, err := json_proto.Serialize(ackMessageData.ProtoReflect()) + ackMessageDataArr, err := pblite.Serialize(ackMessageData.ProtoReflect()) if err != nil { log.Fatal(err) }