diff --git a/libgm/client.go b/libgm/client.go index 9a85b51..6cdab26 100644 --- a/libgm/client.go +++ b/libgm/client.go @@ -372,16 +372,8 @@ func (c *Client) refreshAuthToken() error { return readErr } - var deserialized []interface{} - - marshalErr := json.Unmarshal(responseBody, &deserialized) - if marshalErr != nil { - return marshalErr - } - resp := &binary.RegisterRefreshResponse{} - - deserializeErr := pblite.Deserialize(deserialized, resp.ProtoReflect()) + deserializeErr := pblite.Unmarshal(responseBody, resp) if deserializeErr != nil { return deserializeErr } diff --git a/libgm/pblite/deserialize.go b/libgm/pblite/deserialize.go index b4dc2d4..df986ef 100644 --- a/libgm/pblite/deserialize.go +++ b/libgm/pblite/deserialize.go @@ -2,90 +2,113 @@ package pblite import ( "encoding/base64" + "encoding/json" "fmt" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" ) -func Deserialize(data []any, m protoreflect.Message) error { - for i := 0; i < m.Descriptor().Fields().Len(); i++ { - fieldDescriptor := m.Descriptor().Fields().Get(i) +func Unmarshal(data []byte, m proto.Message) error { + var anyData any + if err := json.Unmarshal(data, &anyData); err != nil { + return err + } + anyDataArr, ok := anyData.([]any) + if !ok { + return fmt.Errorf("expected array in JSON, got %T", anyData) + } + return deserializeFromSlice(anyDataArr, m.ProtoReflect()) +} + +func deserializeOne(val any, index int, ref protoreflect.Message, fieldDescriptor protoreflect.FieldDescriptor) (protoreflect.Value, error) { + var num float64 + var expectedKind, str string + var boolean, ok bool + var outputVal protoreflect.Value + switch fieldDescriptor.Kind() { + case protoreflect.MessageKind: + ok = true + nestedData, ok := val.([]any) + if !ok { + return outputVal, fmt.Errorf("expected untyped array at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val) + } + nestedMessage := ref.NewField(fieldDescriptor).Message() + if err := deserializeFromSlice(nestedData, nestedMessage); err != nil { + return outputVal, err + } + outputVal = protoreflect.ValueOfMessage(nestedMessage) + case protoreflect.BytesKind: + ok = true + bytesBase64, ok := val.(string) + if !ok { + return outputVal, fmt.Errorf("expected string at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val) + } + bytes, err := base64.StdEncoding.DecodeString(bytesBase64) + if err != nil { + return outputVal, fmt.Errorf("failed to decode base64 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err) + } + + outputVal = protoreflect.ValueOfBytes(bytes) + case protoreflect.EnumKind: + num, ok = val.(float64) + expectedKind = "float64" + outputVal = protoreflect.ValueOfEnum(protoreflect.EnumNumber(int32(num))) + case protoreflect.Int32Kind: + num, ok = val.(float64) + expectedKind = "float64" + outputVal = protoreflect.ValueOfInt32(int32(num)) + case protoreflect.Int64Kind: + num, ok = val.(float64) + expectedKind = "float64" + outputVal = protoreflect.ValueOfInt64(int64(num)) + case protoreflect.Uint32Kind: + num, ok = val.(float64) + expectedKind = "float64" + outputVal = protoreflect.ValueOfUint32(uint32(num)) + case protoreflect.Uint64Kind: + num, ok = val.(float64) + expectedKind = "float64" + outputVal = protoreflect.ValueOfUint64(uint64(num)) + case protoreflect.FloatKind: + num, ok = val.(float64) + expectedKind = "float64" + outputVal = protoreflect.ValueOfFloat32(float32(num)) + case protoreflect.DoubleKind: + num, ok = val.(float64) + expectedKind = "float64" + outputVal = protoreflect.ValueOfFloat64(num) + case protoreflect.StringKind: + str, ok = val.(string) + expectedKind = "string" + outputVal = protoreflect.ValueOfString(str) + case protoreflect.BoolKind: + boolean, ok = val.(bool) + expectedKind = "bool" + outputVal = protoreflect.ValueOfBool(boolean) + default: + return outputVal, fmt.Errorf("unsupported field type %s in %s", fieldDescriptor.Kind(), fieldDescriptor.FullName()) + } + if !ok { + return outputVal, fmt.Errorf("expected %s at index %d for field %s, got %T", expectedKind, index, fieldDescriptor.FullName(), val) + } + return outputVal, nil +} + +func deserializeFromSlice(data []any, ref protoreflect.Message) error { + for i := 0; i < ref.Descriptor().Fields().Len(); i++ { + fieldDescriptor := ref.Descriptor().Fields().Get(i) index := int(fieldDescriptor.Number()) - 1 if index < 0 || index >= len(data) || data[index] == nil { continue } val := data[index] - - var num float64 - var expectedKind, str string - var boolean, ok bool - switch fieldDescriptor.Kind() { - case protoreflect.MessageKind: - ok = true - nestedData, ok := val.([]any) - if !ok { - return fmt.Errorf("expected untyped array at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val) - } - nestedMessage := m.NewField(fieldDescriptor).Message() - if err := Deserialize(nestedData, nestedMessage); err != nil { - return err - } - m.Set(fieldDescriptor, protoreflect.ValueOfMessage(nestedMessage)) - case protoreflect.BytesKind: - ok = true - bytesBase64, ok := val.(string) - if !ok { - return fmt.Errorf("expected string at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val) - } - bytes, err := base64.StdEncoding.DecodeString(bytesBase64) - if err != nil { - return fmt.Errorf("failed to decode base64 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err) - } - - m.Set(fieldDescriptor, protoreflect.ValueOfBytes(bytes)) - case protoreflect.EnumKind: - num, ok = val.(float64) - expectedKind = "float64" - m.Set(fieldDescriptor, protoreflect.ValueOfEnum(protoreflect.EnumNumber(int32(num)))) - case protoreflect.Int32Kind: - num, ok = val.(float64) - expectedKind = "float64" - m.Set(fieldDescriptor, protoreflect.ValueOfInt32(int32(num))) - case protoreflect.Int64Kind: - num, ok = val.(float64) - expectedKind = "float64" - m.Set(fieldDescriptor, protoreflect.ValueOfInt64(int64(num))) - case protoreflect.Uint32Kind: - num, ok = val.(float64) - expectedKind = "float64" - m.Set(fieldDescriptor, protoreflect.ValueOfUint32(uint32(num))) - case protoreflect.Uint64Kind: - num, ok = val.(float64) - expectedKind = "float64" - m.Set(fieldDescriptor, protoreflect.ValueOfUint64(uint64(num))) - case protoreflect.FloatKind: - num, ok = val.(float64) - expectedKind = "float64" - m.Set(fieldDescriptor, protoreflect.ValueOfFloat32(float32(num))) - case protoreflect.DoubleKind: - num, ok = val.(float64) - expectedKind = "float64" - m.Set(fieldDescriptor, protoreflect.ValueOfFloat64(num)) - case protoreflect.StringKind: - str, ok = val.(string) - expectedKind = "string" - m.Set(fieldDescriptor, protoreflect.ValueOfString(str)) - case protoreflect.BoolKind: - boolean, ok = val.(bool) - expectedKind = "bool" - m.Set(fieldDescriptor, protoreflect.ValueOfBool(boolean)) - default: - return fmt.Errorf("unsupported field type %s in %s", fieldDescriptor.Kind(), fieldDescriptor.FullName()) - } - if !ok { - return fmt.Errorf("expected %s at index %d for field %s, got %T", expectedKind, index, fieldDescriptor.FullName(), val) + outputVal, err := deserializeOne(val, index, ref, fieldDescriptor) + if err != nil { + return err } + ref.Set(fieldDescriptor, outputVal) } return nil } diff --git a/libgm/rpc.go b/libgm/rpc.go index 52416ef..e87b429 100644 --- a/libgm/rpc.go +++ b/libgm/rpc.go @@ -3,14 +3,16 @@ package libgm import ( "bufio" "bytes" + "encoding/base64" "encoding/json" "errors" "fmt" "io" "net/http" - "os" "time" + "github.com/rs/zerolog" + "go.mau.fi/mautrix-gmessages/libgm/events" "go.mau.fi/mautrix-gmessages/libgm/pblite" @@ -86,6 +88,7 @@ func (r *RPC) ListenReceiveMessages(payload []byte) { }() } r.startReadingData(resp.Body) + r.conn = nil } } @@ -114,21 +117,27 @@ func (r *RPC) startReadingData(rc io.ReadCloser) { r.client.Logger.Err(err).Msg("Opening is not [[") return } + var expectEOF bool for { n, err = reader.Read(buf) if err != nil { - if errors.Is(err, os.ErrClosed) { - r.client.Logger.Err(err).Msg("Closed body from server") - r.conn = nil - return + var logEvt *zerolog.Event + if errors.Is(err, io.EOF) && expectEOF { + logEvt = r.client.Logger.Debug() + } else { + logEvt = r.client.Logger.Warn() } - r.client.Logger.Err(err).Msg("Stopped reading data from server") + logEvt.Err(err).Msg("Stopped reading data from server") return + } else if expectEOF { + r.client.Logger.Warn().Msg("Didn't get EOF after stream end marker") } chunk := buf[:n] if len(accumulatedData) == 0 { if len(chunk) == 2 && string(chunk) == "]]" { r.client.Logger.Debug().Msg("Got stream end marker") + expectEOF = true + continue } chunk = bytes.TrimPrefix(chunk, []byte{','}) } @@ -137,15 +146,10 @@ func (r *RPC) startReadingData(rc io.ReadCloser) { r.client.Logger.Debug().Str("data", string(chunk)).Msg("Invalid JSON") continue } - var msgArr []any - err = json.Unmarshal(accumulatedData, &msgArr) - if err != nil { - r.client.Logger.Err(err).Msg("Error unmarshalling json") - continue - } + currentBlock := accumulatedData accumulatedData = accumulatedData[:0] msg := &binary.InternalMessage{} - err = pblite.Deserialize(msgArr, msg.ProtoReflect()) + err = pblite.Unmarshal(currentBlock, msg) if err != nil { r.client.Logger.Err(err).Msg("Error deserializing pblite message") continue @@ -161,7 +165,9 @@ func (r *RPC) startReadingData(rc io.ReadCloser) { case msg.GetHeartbeat() != nil: r.client.Logger.Trace().Msg("Got heartbeat message") default: - r.client.Logger.Warn().Interface("data", msgArr).Msg("Got unknown message") + r.client.Logger.Warn(). + Str("data", base64.StdEncoding.EncodeToString(currentBlock)). + Msg("Got unknown message") } } }