Remove manual json unmarshaling step in pblite

This commit is contained in:
Tulir Asokan 2023-07-15 16:11:36 +03:00
parent e086846574
commit 08cbe12181
3 changed files with 116 additions and 95 deletions

View file

@ -372,16 +372,8 @@ func (c *Client) refreshAuthToken() error {
return readErr return readErr
} }
var deserialized []interface{}
marshalErr := json.Unmarshal(responseBody, &deserialized)
if marshalErr != nil {
return marshalErr
}
resp := &binary.RegisterRefreshResponse{} resp := &binary.RegisterRefreshResponse{}
deserializeErr := pblite.Unmarshal(responseBody, resp)
deserializeErr := pblite.Deserialize(deserialized, resp.ProtoReflect())
if deserializeErr != nil { if deserializeErr != nil {
return deserializeErr return deserializeErr
} }

View file

@ -2,90 +2,113 @@ package pblite
import ( import (
"encoding/base64" "encoding/base64"
"encoding/json"
"fmt" "fmt"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoreflect"
) )
func Deserialize(data []any, m protoreflect.Message) error { func Unmarshal(data []byte, m proto.Message) error {
for i := 0; i < m.Descriptor().Fields().Len(); i++ { var anyData any
fieldDescriptor := m.Descriptor().Fields().Get(i) if err := json.Unmarshal(data, &anyData); err != nil {
return err
}
anyDataArr, ok := anyData.([]any)
if !ok {
return fmt.Errorf("expected array in JSON, got %T", anyData)
}
return deserializeFromSlice(anyDataArr, m.ProtoReflect())
}
func deserializeOne(val any, index int, ref protoreflect.Message, fieldDescriptor protoreflect.FieldDescriptor) (protoreflect.Value, error) {
var num float64
var expectedKind, str string
var boolean, ok bool
var outputVal protoreflect.Value
switch fieldDescriptor.Kind() {
case protoreflect.MessageKind:
ok = true
nestedData, ok := val.([]any)
if !ok {
return outputVal, fmt.Errorf("expected untyped array at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val)
}
nestedMessage := ref.NewField(fieldDescriptor).Message()
if err := deserializeFromSlice(nestedData, nestedMessage); err != nil {
return outputVal, err
}
outputVal = protoreflect.ValueOfMessage(nestedMessage)
case protoreflect.BytesKind:
ok = true
bytesBase64, ok := val.(string)
if !ok {
return outputVal, fmt.Errorf("expected string at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val)
}
bytes, err := base64.StdEncoding.DecodeString(bytesBase64)
if err != nil {
return outputVal, fmt.Errorf("failed to decode base64 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err)
}
outputVal = protoreflect.ValueOfBytes(bytes)
case protoreflect.EnumKind:
num, ok = val.(float64)
expectedKind = "float64"
outputVal = protoreflect.ValueOfEnum(protoreflect.EnumNumber(int32(num)))
case protoreflect.Int32Kind:
num, ok = val.(float64)
expectedKind = "float64"
outputVal = protoreflect.ValueOfInt32(int32(num))
case protoreflect.Int64Kind:
num, ok = val.(float64)
expectedKind = "float64"
outputVal = protoreflect.ValueOfInt64(int64(num))
case protoreflect.Uint32Kind:
num, ok = val.(float64)
expectedKind = "float64"
outputVal = protoreflect.ValueOfUint32(uint32(num))
case protoreflect.Uint64Kind:
num, ok = val.(float64)
expectedKind = "float64"
outputVal = protoreflect.ValueOfUint64(uint64(num))
case protoreflect.FloatKind:
num, ok = val.(float64)
expectedKind = "float64"
outputVal = protoreflect.ValueOfFloat32(float32(num))
case protoreflect.DoubleKind:
num, ok = val.(float64)
expectedKind = "float64"
outputVal = protoreflect.ValueOfFloat64(num)
case protoreflect.StringKind:
str, ok = val.(string)
expectedKind = "string"
outputVal = protoreflect.ValueOfString(str)
case protoreflect.BoolKind:
boolean, ok = val.(bool)
expectedKind = "bool"
outputVal = protoreflect.ValueOfBool(boolean)
default:
return outputVal, fmt.Errorf("unsupported field type %s in %s", fieldDescriptor.Kind(), fieldDescriptor.FullName())
}
if !ok {
return outputVal, fmt.Errorf("expected %s at index %d for field %s, got %T", expectedKind, index, fieldDescriptor.FullName(), val)
}
return outputVal, nil
}
func deserializeFromSlice(data []any, ref protoreflect.Message) error {
for i := 0; i < ref.Descriptor().Fields().Len(); i++ {
fieldDescriptor := ref.Descriptor().Fields().Get(i)
index := int(fieldDescriptor.Number()) - 1 index := int(fieldDescriptor.Number()) - 1
if index < 0 || index >= len(data) || data[index] == nil { if index < 0 || index >= len(data) || data[index] == nil {
continue continue
} }
val := data[index] val := data[index]
outputVal, err := deserializeOne(val, index, ref, fieldDescriptor)
var num float64 if err != nil {
var expectedKind, str string
var boolean, ok bool
switch fieldDescriptor.Kind() {
case protoreflect.MessageKind:
ok = true
nestedData, ok := val.([]any)
if !ok {
return fmt.Errorf("expected untyped array at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val)
}
nestedMessage := m.NewField(fieldDescriptor).Message()
if err := Deserialize(nestedData, nestedMessage); err != nil {
return err return err
} }
m.Set(fieldDescriptor, protoreflect.ValueOfMessage(nestedMessage)) ref.Set(fieldDescriptor, outputVal)
case protoreflect.BytesKind:
ok = true
bytesBase64, ok := val.(string)
if !ok {
return fmt.Errorf("expected string at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val)
}
bytes, err := base64.StdEncoding.DecodeString(bytesBase64)
if err != nil {
return fmt.Errorf("failed to decode base64 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err)
}
m.Set(fieldDescriptor, protoreflect.ValueOfBytes(bytes))
case protoreflect.EnumKind:
num, ok = val.(float64)
expectedKind = "float64"
m.Set(fieldDescriptor, protoreflect.ValueOfEnum(protoreflect.EnumNumber(int32(num))))
case protoreflect.Int32Kind:
num, ok = val.(float64)
expectedKind = "float64"
m.Set(fieldDescriptor, protoreflect.ValueOfInt32(int32(num)))
case protoreflect.Int64Kind:
num, ok = val.(float64)
expectedKind = "float64"
m.Set(fieldDescriptor, protoreflect.ValueOfInt64(int64(num)))
case protoreflect.Uint32Kind:
num, ok = val.(float64)
expectedKind = "float64"
m.Set(fieldDescriptor, protoreflect.ValueOfUint32(uint32(num)))
case protoreflect.Uint64Kind:
num, ok = val.(float64)
expectedKind = "float64"
m.Set(fieldDescriptor, protoreflect.ValueOfUint64(uint64(num)))
case protoreflect.FloatKind:
num, ok = val.(float64)
expectedKind = "float64"
m.Set(fieldDescriptor, protoreflect.ValueOfFloat32(float32(num)))
case protoreflect.DoubleKind:
num, ok = val.(float64)
expectedKind = "float64"
m.Set(fieldDescriptor, protoreflect.ValueOfFloat64(num))
case protoreflect.StringKind:
str, ok = val.(string)
expectedKind = "string"
m.Set(fieldDescriptor, protoreflect.ValueOfString(str))
case protoreflect.BoolKind:
boolean, ok = val.(bool)
expectedKind = "bool"
m.Set(fieldDescriptor, protoreflect.ValueOfBool(boolean))
default:
return fmt.Errorf("unsupported field type %s in %s", fieldDescriptor.Kind(), fieldDescriptor.FullName())
}
if !ok {
return fmt.Errorf("expected %s at index %d for field %s, got %T", expectedKind, index, fieldDescriptor.FullName(), val)
}
} }
return nil return nil
} }

View file

@ -3,14 +3,16 @@ package libgm
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"os"
"time" "time"
"github.com/rs/zerolog"
"go.mau.fi/mautrix-gmessages/libgm/events" "go.mau.fi/mautrix-gmessages/libgm/events"
"go.mau.fi/mautrix-gmessages/libgm/pblite" "go.mau.fi/mautrix-gmessages/libgm/pblite"
@ -86,6 +88,7 @@ func (r *RPC) ListenReceiveMessages(payload []byte) {
}() }()
} }
r.startReadingData(resp.Body) r.startReadingData(resp.Body)
r.conn = nil
} }
} }
@ -114,21 +117,27 @@ func (r *RPC) startReadingData(rc io.ReadCloser) {
r.client.Logger.Err(err).Msg("Opening is not [[") r.client.Logger.Err(err).Msg("Opening is not [[")
return return
} }
var expectEOF bool
for { for {
n, err = reader.Read(buf) n, err = reader.Read(buf)
if err != nil { if err != nil {
if errors.Is(err, os.ErrClosed) { var logEvt *zerolog.Event
r.client.Logger.Err(err).Msg("Closed body from server") if errors.Is(err, io.EOF) && expectEOF {
r.conn = nil logEvt = r.client.Logger.Debug()
return } else {
logEvt = r.client.Logger.Warn()
} }
r.client.Logger.Err(err).Msg("Stopped reading data from server") logEvt.Err(err).Msg("Stopped reading data from server")
return return
} else if expectEOF {
r.client.Logger.Warn().Msg("Didn't get EOF after stream end marker")
} }
chunk := buf[:n] chunk := buf[:n]
if len(accumulatedData) == 0 { if len(accumulatedData) == 0 {
if len(chunk) == 2 && string(chunk) == "]]" { if len(chunk) == 2 && string(chunk) == "]]" {
r.client.Logger.Debug().Msg("Got stream end marker") r.client.Logger.Debug().Msg("Got stream end marker")
expectEOF = true
continue
} }
chunk = bytes.TrimPrefix(chunk, []byte{','}) chunk = bytes.TrimPrefix(chunk, []byte{','})
} }
@ -137,15 +146,10 @@ func (r *RPC) startReadingData(rc io.ReadCloser) {
r.client.Logger.Debug().Str("data", string(chunk)).Msg("Invalid JSON") r.client.Logger.Debug().Str("data", string(chunk)).Msg("Invalid JSON")
continue continue
} }
var msgArr []any currentBlock := accumulatedData
err = json.Unmarshal(accumulatedData, &msgArr)
if err != nil {
r.client.Logger.Err(err).Msg("Error unmarshalling json")
continue
}
accumulatedData = accumulatedData[:0] accumulatedData = accumulatedData[:0]
msg := &binary.InternalMessage{} msg := &binary.InternalMessage{}
err = pblite.Deserialize(msgArr, msg.ProtoReflect()) err = pblite.Unmarshal(currentBlock, msg)
if err != nil { if err != nil {
r.client.Logger.Err(err).Msg("Error deserializing pblite message") r.client.Logger.Err(err).Msg("Error deserializing pblite message")
continue continue
@ -161,7 +165,9 @@ func (r *RPC) startReadingData(rc io.ReadCloser) {
case msg.GetHeartbeat() != nil: case msg.GetHeartbeat() != nil:
r.client.Logger.Trace().Msg("Got heartbeat message") r.client.Logger.Trace().Msg("Got heartbeat message")
default: default:
r.client.Logger.Warn().Interface("data", msgArr).Msg("Got unknown message") r.client.Logger.Warn().
Str("data", base64.StdEncoding.EncodeToString(currentBlock)).
Msg("Got unknown message")
} }
} }
} }