Update pblite package to handle more types
This commit is contained in:
parent
cdf9b1e4a0
commit
5e5344742e
5 changed files with 116 additions and 78 deletions
|
@ -1,53 +0,0 @@
|
||||||
package json_proto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"google.golang.org/protobuf/reflect/protoreflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Deserialize(data []interface{}, m protoreflect.Message) error {
|
|
||||||
for i := 0; i < m.Descriptor().Fields().Len(); i++ {
|
|
||||||
fieldDescriptor := m.Descriptor().Fields().Get(i)
|
|
||||||
index := int(fieldDescriptor.Number()) - 1
|
|
||||||
if index < 0 || index >= len(data) || data[index] == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
val := data[index]
|
|
||||||
|
|
||||||
switch fieldDescriptor.Kind() {
|
|
||||||
case protoreflect.MessageKind:
|
|
||||||
nestedData, ok := val.([]interface{})
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("expected slice, got %T", val)
|
|
||||||
}
|
|
||||||
nestedMessage := m.NewField(fieldDescriptor).Message()
|
|
||||||
if err := Deserialize(nestedData, nestedMessage); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.Set(fieldDescriptor, protoreflect.ValueOfMessage(nestedMessage))
|
|
||||||
case protoreflect.BytesKind:
|
|
||||||
bytes, ok := val.([]byte)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("expected bytes, got %T", val)
|
|
||||||
}
|
|
||||||
m.Set(fieldDescriptor, protoreflect.ValueOfBytes(bytes))
|
|
||||||
case protoreflect.Int32Kind, protoreflect.Int64Kind:
|
|
||||||
num, ok := val.(float64)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("expected number, got %T", val)
|
|
||||||
}
|
|
||||||
m.Set(fieldDescriptor, protoreflect.ValueOf(int64(num)))
|
|
||||||
case protoreflect.StringKind:
|
|
||||||
str, ok := val.(string)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("expected string, got %T", val)
|
|
||||||
}
|
|
||||||
m.Set(fieldDescriptor, protoreflect.ValueOf(str))
|
|
||||||
default:
|
|
||||||
// ignore fields of other types
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"go.mau.fi/mautrix-gmessages/libgm/binary"
|
"go.mau.fi/mautrix-gmessages/libgm/binary"
|
||||||
"go.mau.fi/mautrix-gmessages/libgm/json_proto"
|
"go.mau.fi/mautrix-gmessages/libgm/pblite"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *RPC) HandleRPCMsg(msgArr []interface{}) {
|
func (r *RPC) HandleRPCMsg(msgArr []interface{}) {
|
||||||
|
@ -24,7 +24,7 @@ func (r *RPC) HandleRPCMsg(msgArr []interface{}) {
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
response := &binary.RPCResponse{}
|
response := &binary.RPCResponse{}
|
||||||
deserializeErr := json_proto.Deserialize(msgArr, response.ProtoReflect())
|
deserializeErr := pblite.Deserialize(msgArr, response.ProtoReflect())
|
||||||
if deserializeErr != nil {
|
if deserializeErr != nil {
|
||||||
r.client.Logger.Error().Err(fmt.Errorf("failed to deserialize response %s", msgArr)).Msg("rpc deserialize msg err")
|
r.client.Logger.Error().Err(fmt.Errorf("failed to deserialize response %s", msgArr)).Msg("rpc deserialize msg err")
|
||||||
return
|
return
|
||||||
|
|
85
libgm/pblite/deseralize.go
Normal file
85
libgm/pblite/deseralize.go
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
package pblite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"google.golang.org/protobuf/reflect/protoreflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Deserialize(data []any, m protoreflect.Message) error {
|
||||||
|
for i := 0; i < m.Descriptor().Fields().Len(); i++ {
|
||||||
|
fieldDescriptor := m.Descriptor().Fields().Get(i)
|
||||||
|
index := int(fieldDescriptor.Number()) - 1
|
||||||
|
if index < 0 || index >= len(data) || data[index] == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
val := data[index]
|
||||||
|
|
||||||
|
var num float64
|
||||||
|
var expectedKind, str string
|
||||||
|
var boolean, ok bool
|
||||||
|
switch fieldDescriptor.Kind() {
|
||||||
|
case protoreflect.MessageKind:
|
||||||
|
nestedData, ok := val.([]any)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("expected untyped array at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val)
|
||||||
|
}
|
||||||
|
nestedMessage := m.NewField(fieldDescriptor).Message()
|
||||||
|
if err := Deserialize(nestedData, nestedMessage); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.Set(fieldDescriptor, protoreflect.ValueOfMessage(nestedMessage))
|
||||||
|
case protoreflect.BytesKind:
|
||||||
|
bytesBase64, ok := val.(string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("expected string at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val)
|
||||||
|
}
|
||||||
|
bytes, err := base64.StdEncoding.DecodeString(bytesBase64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to decode base64 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Set(fieldDescriptor, protoreflect.ValueOfBytes(bytes))
|
||||||
|
case protoreflect.Int32Kind:
|
||||||
|
num, ok = val.(float64)
|
||||||
|
expectedKind = "float64"
|
||||||
|
m.Set(fieldDescriptor, protoreflect.ValueOfInt32(int32(num)))
|
||||||
|
case protoreflect.Int64Kind:
|
||||||
|
num, ok = val.(float64)
|
||||||
|
expectedKind = "float64"
|
||||||
|
m.Set(fieldDescriptor, protoreflect.ValueOfInt64(int64(num)))
|
||||||
|
case protoreflect.Uint32Kind:
|
||||||
|
num, ok = val.(float64)
|
||||||
|
expectedKind = "float64"
|
||||||
|
m.Set(fieldDescriptor, protoreflect.ValueOfUint32(uint32(num)))
|
||||||
|
case protoreflect.Uint64Kind:
|
||||||
|
num, ok = val.(float64)
|
||||||
|
expectedKind = "float64"
|
||||||
|
m.Set(fieldDescriptor, protoreflect.ValueOfUint64(uint64(num)))
|
||||||
|
case protoreflect.FloatKind:
|
||||||
|
num, ok = val.(float64)
|
||||||
|
expectedKind = "float64"
|
||||||
|
m.Set(fieldDescriptor, protoreflect.ValueOfFloat32(float32(num)))
|
||||||
|
case protoreflect.DoubleKind:
|
||||||
|
num, ok = val.(float64)
|
||||||
|
expectedKind = "float64"
|
||||||
|
m.Set(fieldDescriptor, protoreflect.ValueOfFloat64(num))
|
||||||
|
case protoreflect.StringKind:
|
||||||
|
str, ok = val.(string)
|
||||||
|
expectedKind = "string"
|
||||||
|
m.Set(fieldDescriptor, protoreflect.ValueOfString(str))
|
||||||
|
case protoreflect.BoolKind:
|
||||||
|
boolean, ok = val.(bool)
|
||||||
|
expectedKind = "bool"
|
||||||
|
m.Set(fieldDescriptor, protoreflect.ValueOfBool(boolean))
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported field type %s in %s", fieldDescriptor.Kind(), fieldDescriptor.FullName())
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("expected %s at index %d for field %s, got %T", expectedKind, index, fieldDescriptor.FullName(), val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package json_proto
|
package pblite
|
||||||
|
|
||||||
/*
|
/*
|
||||||
in protobuf, a message looks like this:
|
in protobuf, a message looks like this:
|
||||||
|
@ -53,10 +53,13 @@ This means that any slice inside of the current slice, indicates another message
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"google.golang.org/protobuf/reflect/protoreflect"
|
"google.golang.org/protobuf/reflect/protoreflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Serialize(m protoreflect.Message) ([]interface{}, error) {
|
func Serialize(m protoreflect.Message) ([]any, error) {
|
||||||
maxFieldNumber := 0
|
maxFieldNumber := 0
|
||||||
for i := 0; i < m.Descriptor().Fields().Len(); i++ {
|
for i := 0; i < m.Descriptor().Fields().Len(); i++ {
|
||||||
fieldNumber := int(m.Descriptor().Fields().Get(i).Number())
|
fieldNumber := int(m.Descriptor().Fields().Get(i).Number())
|
||||||
|
@ -65,34 +68,37 @@ func Serialize(m protoreflect.Message) ([]interface{}, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
serialized := make([]interface{}, maxFieldNumber)
|
serialized := make([]any, maxFieldNumber)
|
||||||
for i := 0; i < m.Descriptor().Fields().Len(); i++ {
|
for i := 0; i < m.Descriptor().Fields().Len(); i++ {
|
||||||
fieldDescriptor := m.Descriptor().Fields().Get(i)
|
fieldDescriptor := m.Descriptor().Fields().Get(i)
|
||||||
fieldValue := m.Get(fieldDescriptor)
|
fieldValue := m.Get(fieldDescriptor)
|
||||||
fieldNumber := int(fieldDescriptor.Number())
|
fieldNumber := int(fieldDescriptor.Number())
|
||||||
|
if !m.Has(fieldDescriptor) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
switch fieldDescriptor.Kind() {
|
switch fieldDescriptor.Kind() {
|
||||||
case protoreflect.MessageKind:
|
case protoreflect.MessageKind:
|
||||||
if m.Has(fieldDescriptor) {
|
|
||||||
serializedMsg, err := Serialize(fieldValue.Message().Interface().ProtoReflect())
|
serializedMsg, err := Serialize(fieldValue.Message().Interface().ProtoReflect())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
serialized[fieldNumber-1] = serializedMsg
|
serialized[fieldNumber-1] = serializedMsg
|
||||||
}
|
|
||||||
case protoreflect.BytesKind:
|
case protoreflect.BytesKind:
|
||||||
if m.Has(fieldDescriptor) {
|
serialized[fieldNumber-1] = base64.StdEncoding.EncodeToString(fieldValue.Bytes())
|
||||||
serialized[fieldNumber-1] = fieldValue.Bytes()
|
|
||||||
}
|
|
||||||
case protoreflect.Int32Kind, protoreflect.Int64Kind:
|
case protoreflect.Int32Kind, protoreflect.Int64Kind:
|
||||||
if m.Has(fieldDescriptor) {
|
|
||||||
serialized[fieldNumber-1] = fieldValue.Int()
|
serialized[fieldNumber-1] = fieldValue.Int()
|
||||||
}
|
case protoreflect.Uint32Kind, protoreflect.Uint64Kind:
|
||||||
|
serialized[fieldNumber-1] = fieldValue.Uint()
|
||||||
|
case protoreflect.FloatKind, protoreflect.DoubleKind:
|
||||||
|
serialized[fieldNumber-1] = fieldValue.Float()
|
||||||
|
case protoreflect.EnumKind:
|
||||||
|
serialized[fieldNumber-1] = int(fieldValue.Enum())
|
||||||
|
case protoreflect.BoolKind:
|
||||||
|
serialized[fieldNumber-1] = fieldValue.Bool()
|
||||||
case protoreflect.StringKind:
|
case protoreflect.StringKind:
|
||||||
if m.Has(fieldDescriptor) {
|
|
||||||
serialized[fieldNumber-1] = fieldValue.String()
|
serialized[fieldNumber-1] = fieldValue.String()
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
// ignore fields of other types
|
return nil, fmt.Errorf("unsupported field type %s in %s", fieldDescriptor.Kind(), fieldDescriptor.FullName())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,8 +12,8 @@ import (
|
||||||
|
|
||||||
"go.mau.fi/mautrix-gmessages/libgm/binary"
|
"go.mau.fi/mautrix-gmessages/libgm/binary"
|
||||||
"go.mau.fi/mautrix-gmessages/libgm/crypto"
|
"go.mau.fi/mautrix-gmessages/libgm/crypto"
|
||||||
"go.mau.fi/mautrix-gmessages/libgm/json_proto"
|
|
||||||
"go.mau.fi/mautrix-gmessages/libgm/payload"
|
"go.mau.fi/mautrix-gmessages/libgm/payload"
|
||||||
|
"go.mau.fi/mautrix-gmessages/libgm/pblite"
|
||||||
"go.mau.fi/mautrix-gmessages/libgm/util"
|
"go.mau.fi/mautrix-gmessages/libgm/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -100,7 +100,7 @@ func (s *SessionHandler) completeSendMessage(requestId string, opCode int64, msg
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SessionHandler) toJSON(message protoreflect.Message) ([]byte, error) {
|
func (s *SessionHandler) toJSON(message protoreflect.Message) ([]byte, error) {
|
||||||
interfaceArr, err := json_proto.Serialize(message)
|
interfaceArr, err := pblite.Serialize(message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -145,14 +145,14 @@ func (s *SessionHandler) sendAckRequest() {
|
||||||
EmptyArr: &binary.EmptyArr{},
|
EmptyArr: &binary.EmptyArr{},
|
||||||
NoClue: nil,
|
NoClue: nil,
|
||||||
}
|
}
|
||||||
dataArray, err := json_proto.Serialize(ackMessagePayload.ProtoReflect())
|
dataArray, err := pblite.Serialize(ackMessagePayload.ProtoReflect())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
ackMessages := make([][]interface{}, 0)
|
ackMessages := make([][]interface{}, 0)
|
||||||
for _, reqId := range s.ackMap {
|
for _, reqId := range s.ackMap {
|
||||||
ackMessageData := &binary.AckMessageData{RequestId: reqId, Device: s.client.devicePair.Browser}
|
ackMessageData := &binary.AckMessageData{RequestId: reqId, Device: s.client.devicePair.Browser}
|
||||||
ackMessageDataArr, err := json_proto.Serialize(ackMessageData.ProtoReflect())
|
ackMessageDataArr, err := pblite.Serialize(ackMessageData.ProtoReflect())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue