Remove manual json unmarshaling step in pblite
This commit is contained in:
parent
e086846574
commit
08cbe12181
3 changed files with 116 additions and 95 deletions
|
@ -372,16 +372,8 @@ func (c *Client) refreshAuthToken() error {
|
||||||
return readErr
|
return readErr
|
||||||
}
|
}
|
||||||
|
|
||||||
var deserialized []interface{}
|
|
||||||
|
|
||||||
marshalErr := json.Unmarshal(responseBody, &deserialized)
|
|
||||||
if marshalErr != nil {
|
|
||||||
return marshalErr
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := &binary.RegisterRefreshResponse{}
|
resp := &binary.RegisterRefreshResponse{}
|
||||||
|
deserializeErr := pblite.Unmarshal(responseBody, resp)
|
||||||
deserializeErr := pblite.Deserialize(deserialized, resp.ProtoReflect())
|
|
||||||
if deserializeErr != nil {
|
if deserializeErr != nil {
|
||||||
return deserializeErr
|
return deserializeErr
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,90 +2,113 @@ package pblite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
"google.golang.org/protobuf/reflect/protoreflect"
|
"google.golang.org/protobuf/reflect/protoreflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Deserialize(data []any, m protoreflect.Message) error {
|
func Unmarshal(data []byte, m proto.Message) error {
|
||||||
for i := 0; i < m.Descriptor().Fields().Len(); i++ {
|
var anyData any
|
||||||
fieldDescriptor := m.Descriptor().Fields().Get(i)
|
if err := json.Unmarshal(data, &anyData); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
anyDataArr, ok := anyData.([]any)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("expected array in JSON, got %T", anyData)
|
||||||
|
}
|
||||||
|
return deserializeFromSlice(anyDataArr, m.ProtoReflect())
|
||||||
|
}
|
||||||
|
|
||||||
|
func deserializeOne(val any, index int, ref protoreflect.Message, fieldDescriptor protoreflect.FieldDescriptor) (protoreflect.Value, error) {
|
||||||
|
var num float64
|
||||||
|
var expectedKind, str string
|
||||||
|
var boolean, ok bool
|
||||||
|
var outputVal protoreflect.Value
|
||||||
|
switch fieldDescriptor.Kind() {
|
||||||
|
case protoreflect.MessageKind:
|
||||||
|
ok = true
|
||||||
|
nestedData, ok := val.([]any)
|
||||||
|
if !ok {
|
||||||
|
return outputVal, fmt.Errorf("expected untyped array at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val)
|
||||||
|
}
|
||||||
|
nestedMessage := ref.NewField(fieldDescriptor).Message()
|
||||||
|
if err := deserializeFromSlice(nestedData, nestedMessage); err != nil {
|
||||||
|
return outputVal, err
|
||||||
|
}
|
||||||
|
outputVal = protoreflect.ValueOfMessage(nestedMessage)
|
||||||
|
case protoreflect.BytesKind:
|
||||||
|
ok = true
|
||||||
|
bytesBase64, ok := val.(string)
|
||||||
|
if !ok {
|
||||||
|
return outputVal, fmt.Errorf("expected string at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val)
|
||||||
|
}
|
||||||
|
bytes, err := base64.StdEncoding.DecodeString(bytesBase64)
|
||||||
|
if err != nil {
|
||||||
|
return outputVal, fmt.Errorf("failed to decode base64 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
outputVal = protoreflect.ValueOfBytes(bytes)
|
||||||
|
case protoreflect.EnumKind:
|
||||||
|
num, ok = val.(float64)
|
||||||
|
expectedKind = "float64"
|
||||||
|
outputVal = protoreflect.ValueOfEnum(protoreflect.EnumNumber(int32(num)))
|
||||||
|
case protoreflect.Int32Kind:
|
||||||
|
num, ok = val.(float64)
|
||||||
|
expectedKind = "float64"
|
||||||
|
outputVal = protoreflect.ValueOfInt32(int32(num))
|
||||||
|
case protoreflect.Int64Kind:
|
||||||
|
num, ok = val.(float64)
|
||||||
|
expectedKind = "float64"
|
||||||
|
outputVal = protoreflect.ValueOfInt64(int64(num))
|
||||||
|
case protoreflect.Uint32Kind:
|
||||||
|
num, ok = val.(float64)
|
||||||
|
expectedKind = "float64"
|
||||||
|
outputVal = protoreflect.ValueOfUint32(uint32(num))
|
||||||
|
case protoreflect.Uint64Kind:
|
||||||
|
num, ok = val.(float64)
|
||||||
|
expectedKind = "float64"
|
||||||
|
outputVal = protoreflect.ValueOfUint64(uint64(num))
|
||||||
|
case protoreflect.FloatKind:
|
||||||
|
num, ok = val.(float64)
|
||||||
|
expectedKind = "float64"
|
||||||
|
outputVal = protoreflect.ValueOfFloat32(float32(num))
|
||||||
|
case protoreflect.DoubleKind:
|
||||||
|
num, ok = val.(float64)
|
||||||
|
expectedKind = "float64"
|
||||||
|
outputVal = protoreflect.ValueOfFloat64(num)
|
||||||
|
case protoreflect.StringKind:
|
||||||
|
str, ok = val.(string)
|
||||||
|
expectedKind = "string"
|
||||||
|
outputVal = protoreflect.ValueOfString(str)
|
||||||
|
case protoreflect.BoolKind:
|
||||||
|
boolean, ok = val.(bool)
|
||||||
|
expectedKind = "bool"
|
||||||
|
outputVal = protoreflect.ValueOfBool(boolean)
|
||||||
|
default:
|
||||||
|
return outputVal, fmt.Errorf("unsupported field type %s in %s", fieldDescriptor.Kind(), fieldDescriptor.FullName())
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return outputVal, fmt.Errorf("expected %s at index %d for field %s, got %T", expectedKind, index, fieldDescriptor.FullName(), val)
|
||||||
|
}
|
||||||
|
return outputVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func deserializeFromSlice(data []any, ref protoreflect.Message) error {
|
||||||
|
for i := 0; i < ref.Descriptor().Fields().Len(); i++ {
|
||||||
|
fieldDescriptor := ref.Descriptor().Fields().Get(i)
|
||||||
index := int(fieldDescriptor.Number()) - 1
|
index := int(fieldDescriptor.Number()) - 1
|
||||||
if index < 0 || index >= len(data) || data[index] == nil {
|
if index < 0 || index >= len(data) || data[index] == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
val := data[index]
|
val := data[index]
|
||||||
|
outputVal, err := deserializeOne(val, index, ref, fieldDescriptor)
|
||||||
var num float64
|
if err != nil {
|
||||||
var expectedKind, str string
|
|
||||||
var boolean, ok bool
|
|
||||||
switch fieldDescriptor.Kind() {
|
|
||||||
case protoreflect.MessageKind:
|
|
||||||
ok = true
|
|
||||||
nestedData, ok := val.([]any)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("expected untyped array at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val)
|
|
||||||
}
|
|
||||||
nestedMessage := m.NewField(fieldDescriptor).Message()
|
|
||||||
if err := Deserialize(nestedData, nestedMessage); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
m.Set(fieldDescriptor, protoreflect.ValueOfMessage(nestedMessage))
|
ref.Set(fieldDescriptor, outputVal)
|
||||||
case protoreflect.BytesKind:
|
|
||||||
ok = true
|
|
||||||
bytesBase64, ok := val.(string)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("expected string at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val)
|
|
||||||
}
|
|
||||||
bytes, err := base64.StdEncoding.DecodeString(bytesBase64)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to decode base64 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.Set(fieldDescriptor, protoreflect.ValueOfBytes(bytes))
|
|
||||||
case protoreflect.EnumKind:
|
|
||||||
num, ok = val.(float64)
|
|
||||||
expectedKind = "float64"
|
|
||||||
m.Set(fieldDescriptor, protoreflect.ValueOfEnum(protoreflect.EnumNumber(int32(num))))
|
|
||||||
case protoreflect.Int32Kind:
|
|
||||||
num, ok = val.(float64)
|
|
||||||
expectedKind = "float64"
|
|
||||||
m.Set(fieldDescriptor, protoreflect.ValueOfInt32(int32(num)))
|
|
||||||
case protoreflect.Int64Kind:
|
|
||||||
num, ok = val.(float64)
|
|
||||||
expectedKind = "float64"
|
|
||||||
m.Set(fieldDescriptor, protoreflect.ValueOfInt64(int64(num)))
|
|
||||||
case protoreflect.Uint32Kind:
|
|
||||||
num, ok = val.(float64)
|
|
||||||
expectedKind = "float64"
|
|
||||||
m.Set(fieldDescriptor, protoreflect.ValueOfUint32(uint32(num)))
|
|
||||||
case protoreflect.Uint64Kind:
|
|
||||||
num, ok = val.(float64)
|
|
||||||
expectedKind = "float64"
|
|
||||||
m.Set(fieldDescriptor, protoreflect.ValueOfUint64(uint64(num)))
|
|
||||||
case protoreflect.FloatKind:
|
|
||||||
num, ok = val.(float64)
|
|
||||||
expectedKind = "float64"
|
|
||||||
m.Set(fieldDescriptor, protoreflect.ValueOfFloat32(float32(num)))
|
|
||||||
case protoreflect.DoubleKind:
|
|
||||||
num, ok = val.(float64)
|
|
||||||
expectedKind = "float64"
|
|
||||||
m.Set(fieldDescriptor, protoreflect.ValueOfFloat64(num))
|
|
||||||
case protoreflect.StringKind:
|
|
||||||
str, ok = val.(string)
|
|
||||||
expectedKind = "string"
|
|
||||||
m.Set(fieldDescriptor, protoreflect.ValueOfString(str))
|
|
||||||
case protoreflect.BoolKind:
|
|
||||||
boolean, ok = val.(bool)
|
|
||||||
expectedKind = "bool"
|
|
||||||
m.Set(fieldDescriptor, protoreflect.ValueOfBool(boolean))
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unsupported field type %s in %s", fieldDescriptor.Kind(), fieldDescriptor.FullName())
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("expected %s at index %d for field %s, got %T", expectedKind, index, fieldDescriptor.FullName(), val)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
34
libgm/rpc.go
34
libgm/rpc.go
|
@ -3,14 +3,16 @@ package libgm
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
"go.mau.fi/mautrix-gmessages/libgm/events"
|
"go.mau.fi/mautrix-gmessages/libgm/events"
|
||||||
"go.mau.fi/mautrix-gmessages/libgm/pblite"
|
"go.mau.fi/mautrix-gmessages/libgm/pblite"
|
||||||
|
|
||||||
|
@ -86,6 +88,7 @@ func (r *RPC) ListenReceiveMessages(payload []byte) {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
r.startReadingData(resp.Body)
|
r.startReadingData(resp.Body)
|
||||||
|
r.conn = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -114,21 +117,27 @@ func (r *RPC) startReadingData(rc io.ReadCloser) {
|
||||||
r.client.Logger.Err(err).Msg("Opening is not [[")
|
r.client.Logger.Err(err).Msg("Opening is not [[")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
var expectEOF bool
|
||||||
for {
|
for {
|
||||||
n, err = reader.Read(buf)
|
n, err = reader.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, os.ErrClosed) {
|
var logEvt *zerolog.Event
|
||||||
r.client.Logger.Err(err).Msg("Closed body from server")
|
if errors.Is(err, io.EOF) && expectEOF {
|
||||||
r.conn = nil
|
logEvt = r.client.Logger.Debug()
|
||||||
return
|
} else {
|
||||||
|
logEvt = r.client.Logger.Warn()
|
||||||
}
|
}
|
||||||
r.client.Logger.Err(err).Msg("Stopped reading data from server")
|
logEvt.Err(err).Msg("Stopped reading data from server")
|
||||||
return
|
return
|
||||||
|
} else if expectEOF {
|
||||||
|
r.client.Logger.Warn().Msg("Didn't get EOF after stream end marker")
|
||||||
}
|
}
|
||||||
chunk := buf[:n]
|
chunk := buf[:n]
|
||||||
if len(accumulatedData) == 0 {
|
if len(accumulatedData) == 0 {
|
||||||
if len(chunk) == 2 && string(chunk) == "]]" {
|
if len(chunk) == 2 && string(chunk) == "]]" {
|
||||||
r.client.Logger.Debug().Msg("Got stream end marker")
|
r.client.Logger.Debug().Msg("Got stream end marker")
|
||||||
|
expectEOF = true
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
chunk = bytes.TrimPrefix(chunk, []byte{','})
|
chunk = bytes.TrimPrefix(chunk, []byte{','})
|
||||||
}
|
}
|
||||||
|
@ -137,15 +146,10 @@ func (r *RPC) startReadingData(rc io.ReadCloser) {
|
||||||
r.client.Logger.Debug().Str("data", string(chunk)).Msg("Invalid JSON")
|
r.client.Logger.Debug().Str("data", string(chunk)).Msg("Invalid JSON")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var msgArr []any
|
currentBlock := accumulatedData
|
||||||
err = json.Unmarshal(accumulatedData, &msgArr)
|
|
||||||
if err != nil {
|
|
||||||
r.client.Logger.Err(err).Msg("Error unmarshalling json")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
accumulatedData = accumulatedData[:0]
|
accumulatedData = accumulatedData[:0]
|
||||||
msg := &binary.InternalMessage{}
|
msg := &binary.InternalMessage{}
|
||||||
err = pblite.Deserialize(msgArr, msg.ProtoReflect())
|
err = pblite.Unmarshal(currentBlock, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.client.Logger.Err(err).Msg("Error deserializing pblite message")
|
r.client.Logger.Err(err).Msg("Error deserializing pblite message")
|
||||||
continue
|
continue
|
||||||
|
@ -161,7 +165,9 @@ func (r *RPC) startReadingData(rc io.ReadCloser) {
|
||||||
case msg.GetHeartbeat() != nil:
|
case msg.GetHeartbeat() != nil:
|
||||||
r.client.Logger.Trace().Msg("Got heartbeat message")
|
r.client.Logger.Trace().Msg("Got heartbeat message")
|
||||||
default:
|
default:
|
||||||
r.client.Logger.Warn().Interface("data", msgArr).Msg("Got unknown message")
|
r.client.Logger.Warn().
|
||||||
|
Str("data", base64.StdEncoding.EncodeToString(currentBlock)).
|
||||||
|
Msg("Got unknown message")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue