Refactor incoming RPC data parsing to remove useless structs

This commit is contained in:
Tulir Asokan 2023-07-18 02:01:06 +03:00
parent 604aa19a46
commit 4599f3f0e5
17 changed files with 246 additions and 338 deletions

View file

@ -65,7 +65,7 @@ func NewAuthData() *AuthData {
func NewClient(authData *AuthData, logger zerolog.Logger) *Client { func NewClient(authData *AuthData, logger zerolog.Logger) *Client {
sessionHandler := &SessionHandler{ sessionHandler := &SessionHandler{
responseWaiters: make(map[string]chan<- *pblite.Response), responseWaiters: make(map[string]chan<- *IncomingRPCMessage),
responseTimeout: time.Duration(5000) * time.Millisecond, responseTimeout: time.Duration(5000) * time.Millisecond,
} }
cli := &Client{ cli := &Client{

View file

@ -20,9 +20,9 @@ func (c *Client) ListConversations(count int64, folder gmproto.ListConversations
return nil, err return nil, err
} }
res, ok := response.Data.Decrypted.(*gmproto.Conversations) res, ok := response.DecryptedMessage.(*gmproto.Conversations)
if !ok { if !ok {
return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.Conversations", response.Data.Decrypted) return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.Conversations", response.DecryptedMessage)
} }
return res, nil return res, nil
@ -41,9 +41,9 @@ func (c *Client) ListContacts() (*gmproto.ListContactsResponse, error) {
return nil, err return nil, err
} }
res, ok := response.Data.Decrypted.(*gmproto.ListContactsResponse) res, ok := response.DecryptedMessage.(*gmproto.ListContactsResponse)
if !ok { if !ok {
return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.ListContactsResponse", response.Data.Decrypted) return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.ListContactsResponse", response.DecryptedMessage)
} }
return res, nil return res, nil
@ -60,9 +60,9 @@ func (c *Client) ListTopContacts() (*gmproto.ListTopContactsResponse, error) {
return nil, err return nil, err
} }
res, ok := response.Data.Decrypted.(*gmproto.ListTopContactsResponse) res, ok := response.DecryptedMessage.(*gmproto.ListTopContactsResponse)
if !ok { if !ok {
return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.ListTopContactsResponse", response.Data.Decrypted) return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.ListTopContactsResponse", response.DecryptedMessage)
} }
return res, nil return res, nil
@ -76,9 +76,9 @@ func (c *Client) GetOrCreateConversation(req *gmproto.GetOrCreateConversationPay
return nil, err return nil, err
} }
res, ok := response.Data.Decrypted.(*gmproto.GetOrCreateConversationResponse) res, ok := response.DecryptedMessage.(*gmproto.GetOrCreateConversationResponse)
if !ok { if !ok {
return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.GetOrCreateConversationResponse", response.Data.Decrypted) return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.GetOrCreateConversationResponse", response.DecryptedMessage)
} }
return res, nil return res, nil
@ -93,9 +93,9 @@ func (c *Client) GetConversationType(conversationID string) (*gmproto.GetConvers
return nil, err return nil, err
} }
res, ok := response.Data.Decrypted.(*gmproto.GetConversationTypeResponse) res, ok := response.DecryptedMessage.(*gmproto.GetConversationTypeResponse)
if !ok { if !ok {
return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.GetConversationTypeResponse", response.Data.Decrypted) return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.GetConversationTypeResponse", response.DecryptedMessage)
} }
return res, nil return res, nil
@ -114,9 +114,9 @@ func (c *Client) FetchMessages(conversationID string, count int64, cursor *gmpro
return nil, err return nil, err
} }
res, ok := response.Data.Decrypted.(*gmproto.FetchMessagesResponse) res, ok := response.DecryptedMessage.(*gmproto.FetchMessagesResponse)
if !ok { if !ok {
return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.FetchMessagesResponse", response.Data.Decrypted) return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.FetchMessagesResponse", response.DecryptedMessage)
} }
return res, nil return res, nil
@ -130,9 +130,9 @@ func (c *Client) SendMessage(payload *gmproto.SendMessagePayload) (*gmproto.Send
return nil, err return nil, err
} }
res, ok := response.Data.Decrypted.(*gmproto.SendMessageResponse) res, ok := response.DecryptedMessage.(*gmproto.SendMessageResponse)
if !ok { if !ok {
return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.SendMessageResponse", response.Data.Decrypted) return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.SendMessageResponse", response.DecryptedMessage)
} }
return res, nil return res, nil
@ -147,9 +147,9 @@ func (c *Client) GetParticipantThumbnail(convID string) (*gmproto.ParticipantThu
return nil, err return nil, err
} }
res, ok := response.Data.Decrypted.(*gmproto.ParticipantThumbnail) res, ok := response.DecryptedMessage.(*gmproto.ParticipantThumbnail)
if !ok { if !ok {
return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.ParticipantThumbnail", response.Data.Decrypted) return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.ParticipantThumbnail", response.DecryptedMessage)
} }
return res, nil return res, nil
@ -170,9 +170,9 @@ func (c *Client) UpdateConversation(convBuilder *ConversationBuilder) (*gmproto.
return nil, err return nil, err
} }
res, ok := response.Data.Decrypted.(*gmproto.UpdateConversationResponse) res, ok := response.DecryptedMessage.(*gmproto.UpdateConversationResponse)
if !ok { if !ok {
return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.UpdateConversationResponse", response.Data.Decrypted) return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.UpdateConversationResponse", response.DecryptedMessage)
} }
return res, nil return res, nil

View file

@ -3,12 +3,60 @@ package libgm
import ( import (
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"fmt"
"go.mau.fi/mautrix-gmessages/libgm/pblite" "google.golang.org/protobuf/proto"
"go.mau.fi/mautrix-gmessages/libgm/routes"
"go.mau.fi/mautrix-gmessages/libgm/gmproto" "go.mau.fi/mautrix-gmessages/libgm/gmproto"
) )
type IncomingRPCMessage struct {
*gmproto.IncomingRPCMessage
Pair *gmproto.RPCPairData
Message *gmproto.RPCMessageData
DecryptedData []byte
DecryptedMessage proto.Message
}
func (r *RPC) decryptInternalMessage(data *gmproto.IncomingRPCMessage) (*IncomingRPCMessage, error) {
msg := &IncomingRPCMessage{
IncomingRPCMessage: data,
}
switch data.BugleRoute {
case gmproto.BugleRoute_PairEvent:
msg.Pair = &gmproto.RPCPairData{}
err := proto.Unmarshal(data.GetMessageData(), msg.Pair)
if err != nil {
return nil, err
}
case gmproto.BugleRoute_DataEvent:
msg.Message = &gmproto.RPCMessageData{}
err := proto.Unmarshal(data.GetMessageData(), msg.Message)
if err != nil {
return nil, err
}
if msg.Message.EncryptedData != nil {
msg.DecryptedData, err = r.client.AuthData.RequestCrypto.Decrypt(msg.Message.EncryptedData)
if err != nil {
return nil, err
}
responseStruct := routes.Routes[msg.Message.GetAction()].ResponseStruct
msg.DecryptedMessage = responseStruct.ProtoReflect().New().Interface()
err = proto.Unmarshal(msg.DecryptedData, msg.DecryptedMessage)
if err != nil {
return nil, err
}
}
default:
return nil, fmt.Errorf("unknown bugle route %d", data.BugleRoute)
}
return msg, nil
}
func (r *RPC) deduplicateHash(hash [32]byte) bool { func (r *RPC) deduplicateHash(hash [32]byte) bool {
const recentUpdatesLen = len(r.recentUpdates) const recentUpdatesLen = len(r.recentUpdates)
for i := r.recentUpdatesPtr + recentUpdatesLen - 1; i >= r.recentUpdatesPtr; i-- { for i := r.recentUpdatesPtr + recentUpdatesLen - 1; i >= r.recentUpdatesPtr; i-- {
@ -21,62 +69,62 @@ func (r *RPC) deduplicateHash(hash [32]byte) bool {
return false return false
} }
func (r *RPC) logContent(res *pblite.Response) { func (r *RPC) logContent(res *IncomingRPCMessage) {
if r.client.Logger.Trace().Enabled() && res.Data.Decrypted != nil { if r.client.Logger.Trace().Enabled() && res.DecryptedData != nil {
r.client.Logger.Trace(). evt := r.client.Logger.Trace()
Str("proto_name", string(res.Data.Decrypted.ProtoReflect().Descriptor().FullName())). if res.DecryptedMessage != nil {
Str("data", base64.StdEncoding.EncodeToString(res.Data.RawDecrypted)). evt.Str("proto_name", string(res.DecryptedMessage.ProtoReflect().Descriptor().FullName()))
Msg("Got event") }
if res.DecryptedData != nil {
evt.Str("data", base64.StdEncoding.EncodeToString(res.DecryptedData))
} else {
evt.Str("data", "<null>")
}
evt.Msg("Got event")
} }
} }
func (r *RPC) deduplicateUpdate(response *pblite.Response) bool { func (r *RPC) deduplicateUpdate(msg *IncomingRPCMessage) bool {
if response.Data.RawDecrypted != nil { if msg.DecryptedData != nil {
contentHash := sha256.Sum256(response.Data.RawDecrypted) contentHash := sha256.Sum256(msg.DecryptedData)
if r.deduplicateHash(contentHash) { if r.deduplicateHash(contentHash) {
r.client.Logger.Trace().Hex("data_hash", contentHash[:]).Msg("Ignoring duplicate update") r.client.Logger.Trace().Hex("data_hash", contentHash[:]).Msg("Ignoring duplicate update")
return true return true
} }
r.logContent(response) r.logContent(msg)
} }
return false return false
} }
func (r *RPC) HandleRPCMsg(msg *gmproto.InternalMessage) { func (r *RPC) HandleRPCMsg(rawMsg *gmproto.IncomingRPCMessage) {
response, decodeErr := pblite.DecryptInternalMessage(msg, r.client.AuthData.RequestCrypto) msg, err := r.decryptInternalMessage(rawMsg)
if decodeErr != nil { if err != nil {
r.client.Logger.Error().Err(decodeErr).Msg("rpc decrypt msg err") r.client.Logger.Err(err).Msg("Failed to decode incoming RPC message")
return
}
if response == nil {
r.client.Logger.Error().Msg("nil response in rpc handler")
return return
} }
r.client.sessionHandler.queueMessageAck(response.ResponseID) r.client.sessionHandler.queueMessageAck(msg.ResponseID)
if r.client.sessionHandler.receiveResponse(response) { if r.client.sessionHandler.receiveResponse(msg) {
return return
} }
switch response.BugleRoute { switch msg.BugleRoute {
case gmproto.BugleRoute_PairEvent: case gmproto.BugleRoute_PairEvent:
go r.client.handlePairingEvent(response) go r.client.handlePairingEvent(msg)
case gmproto.BugleRoute_DataEvent: case gmproto.BugleRoute_DataEvent:
if r.skipCount > 0 { if r.skipCount > 0 {
r.skipCount-- r.skipCount--
r.client.Logger.Debug(). r.client.Logger.Debug().
Any("action", response.Data.Action). Any("action", msg.Message.GetAction()).
Int("remaining_skip_count", r.skipCount). Int("remaining_skip_count", r.skipCount).
Msg("Skipped DataEvent") Msg("Skipped DataEvent")
if response.Data.Decrypted != nil { if msg.DecryptedMessage != nil {
r.client.Logger.Trace(). r.client.Logger.Trace().
Str("proto_name", string(response.Data.Decrypted.ProtoReflect().Descriptor().FullName())). Str("proto_name", string(msg.DecryptedMessage.ProtoReflect().Descriptor().FullName())).
Str("data", base64.StdEncoding.EncodeToString(response.Data.RawDecrypted)). Str("data", base64.StdEncoding.EncodeToString(msg.DecryptedData)).
Msg("Skipped event data") Msg("Skipped event data")
} }
return return
} }
r.client.handleUpdatesEvent(response) r.client.handleUpdatesEvent(msg)
default:
r.client.Logger.Debug().Any("res", response).Msg("Got unknown bugle route")
} }
} }

View file

@ -682,20 +682,20 @@ func (x *User) GetNumber() string {
return "" return ""
} }
type PairEvents struct { type RPCPairData struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
// Types that are assignable to Event: // Types that are assignable to Event:
// //
// *PairEvents_Paired // *RPCPairData_Paired
// *PairEvents_Revoked // *RPCPairData_Revoked
Event isPairEvents_Event `protobuf_oneof:"event"` Event isRPCPairData_Event `protobuf_oneof:"event"`
} }
func (x *PairEvents) Reset() { func (x *RPCPairData) Reset() {
*x = PairEvents{} *x = RPCPairData{}
if protoimpl.UnsafeEnabled { if protoimpl.UnsafeEnabled {
mi := &file_events_proto_msgTypes[7] mi := &file_events_proto_msgTypes[7]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
@ -703,13 +703,13 @@ func (x *PairEvents) Reset() {
} }
} }
func (x *PairEvents) String() string { func (x *RPCPairData) String() string {
return protoimpl.X.MessageStringOf(x) return protoimpl.X.MessageStringOf(x)
} }
func (*PairEvents) ProtoMessage() {} func (*RPCPairData) ProtoMessage() {}
func (x *PairEvents) ProtoReflect() protoreflect.Message { func (x *RPCPairData) ProtoReflect() protoreflect.Message {
mi := &file_events_proto_msgTypes[7] mi := &file_events_proto_msgTypes[7]
if protoimpl.UnsafeEnabled && x != nil { if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
@ -721,47 +721,47 @@ func (x *PairEvents) ProtoReflect() protoreflect.Message {
return mi.MessageOf(x) return mi.MessageOf(x)
} }
// Deprecated: Use PairEvents.ProtoReflect.Descriptor instead. // Deprecated: Use RPCPairData.ProtoReflect.Descriptor instead.
func (*PairEvents) Descriptor() ([]byte, []int) { func (*RPCPairData) Descriptor() ([]byte, []int) {
return file_events_proto_rawDescGZIP(), []int{7} return file_events_proto_rawDescGZIP(), []int{7}
} }
func (m *PairEvents) GetEvent() isPairEvents_Event { func (m *RPCPairData) GetEvent() isRPCPairData_Event {
if m != nil { if m != nil {
return m.Event return m.Event
} }
return nil return nil
} }
func (x *PairEvents) GetPaired() *PairedData { func (x *RPCPairData) GetPaired() *PairedData {
if x, ok := x.GetEvent().(*PairEvents_Paired); ok { if x, ok := x.GetEvent().(*RPCPairData_Paired); ok {
return x.Paired return x.Paired
} }
return nil return nil
} }
func (x *PairEvents) GetRevoked() *RevokePairData { func (x *RPCPairData) GetRevoked() *RevokePairData {
if x, ok := x.GetEvent().(*PairEvents_Revoked); ok { if x, ok := x.GetEvent().(*RPCPairData_Revoked); ok {
return x.Revoked return x.Revoked
} }
return nil return nil
} }
type isPairEvents_Event interface { type isRPCPairData_Event interface {
isPairEvents_Event() isRPCPairData_Event()
} }
type PairEvents_Paired struct { type RPCPairData_Paired struct {
Paired *PairedData `protobuf:"bytes,4,opt,name=paired,proto3,oneof"` Paired *PairedData `protobuf:"bytes,4,opt,name=paired,proto3,oneof"`
} }
type PairEvents_Revoked struct { type RPCPairData_Revoked struct {
Revoked *RevokePairData `protobuf:"bytes,5,opt,name=revoked,proto3,oneof"` Revoked *RevokePairData `protobuf:"bytes,5,opt,name=revoked,proto3,oneof"`
} }
func (*PairEvents_Paired) isPairEvents_Event() {} func (*RPCPairData_Paired) isRPCPairData_Event() {}
func (*PairEvents_Revoked) isPairEvents_Event() {} func (*RPCPairData_Revoked) isRPCPairData_Event() {}
var File_events_proto protoreflect.FileDescriptor var File_events_proto protoreflect.FileDescriptor
@ -793,7 +793,7 @@ var file_events_proto_goTypes = []interface{}{
(*UserAlertEvent)(nil), // 7: events.UserAlertEvent (*UserAlertEvent)(nil), // 7: events.UserAlertEvent
(*TypingData)(nil), // 8: events.TypingData (*TypingData)(nil), // 8: events.TypingData
(*User)(nil), // 9: events.User (*User)(nil), // 9: events.User
(*PairEvents)(nil), // 10: events.PairEvents (*RPCPairData)(nil), // 10: events.RPCPairData
(*Settings)(nil), // 11: settings.Settings (*Settings)(nil), // 11: settings.Settings
(*Conversation)(nil), // 12: conversations.Conversation (*Conversation)(nil), // 12: conversations.Conversation
(*Message)(nil), // 13: conversations.Message (*Message)(nil), // 13: conversations.Message
@ -812,8 +812,8 @@ var file_events_proto_depIdxs = []int32{
0, // 8: events.UserAlertEvent.alertType:type_name -> events.AlertType 0, // 8: events.UserAlertEvent.alertType:type_name -> events.AlertType
9, // 9: events.TypingData.user:type_name -> events.User 9, // 9: events.TypingData.user:type_name -> events.User
2, // 10: events.TypingData.type:type_name -> events.TypingTypes 2, // 10: events.TypingData.type:type_name -> events.TypingTypes
14, // 11: events.PairEvents.paired:type_name -> authentication.PairedData 14, // 11: events.RPCPairData.paired:type_name -> authentication.PairedData
15, // 12: events.PairEvents.revoked:type_name -> authentication.RevokePairData 15, // 12: events.RPCPairData.revoked:type_name -> authentication.RevokePairData
13, // [13:13] is the sub-list for method output_type 13, // [13:13] is the sub-list for method output_type
13, // [13:13] is the sub-list for method input_type 13, // [13:13] is the sub-list for method input_type
13, // [13:13] is the sub-list for extension type_name 13, // [13:13] is the sub-list for extension type_name
@ -915,7 +915,7 @@ func file_events_proto_init() {
} }
} }
file_events_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { file_events_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*PairEvents); i { switch v := v.(*RPCPairData); i {
case 0: case 0:
return &v.state return &v.state
case 1: case 1:
@ -935,8 +935,8 @@ func file_events_proto_init() {
(*UpdateEvents_UserAlertEvent)(nil), (*UpdateEvents_UserAlertEvent)(nil),
} }
file_events_proto_msgTypes[7].OneofWrappers = []interface{}{ file_events_proto_msgTypes[7].OneofWrappers = []interface{}{
(*PairEvents_Paired)(nil), (*RPCPairData_Paired)(nil),
(*PairEvents_Revoked)(nil), (*RPCPairData_Revoked)(nil),
} }
type x struct{} type x struct{}
out := protoimpl.TypeBuilder{ out := protoimpl.TypeBuilder{

Binary file not shown.

View file

@ -61,7 +61,7 @@ message User {
string number = 2; string number = 2;
} }
message PairEvents { message RPCPairData {
oneof event { oneof event {
authentication.PairedData paired = 4; authentication.PairedData paired = 4;
authentication.RevokePairData revoked = 5; authentication.RevokePairData revoked = 5;

View file

@ -479,19 +479,19 @@ func (x *StartAckMessage) GetCount() int32 {
return 0 return 0
} }
type InternalMessage struct { type LongPollingPayload struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
Data *InternalMessageData `protobuf:"bytes,2,opt,name=data,proto3,oneof" json:"data,omitempty"` Data *IncomingRPCMessage `protobuf:"bytes,2,opt,name=data,proto3,oneof" json:"data,omitempty"`
Heartbeat *EmptyArr `protobuf:"bytes,3,opt,name=heartbeat,proto3,oneof" json:"heartbeat,omitempty"` Heartbeat *EmptyArr `protobuf:"bytes,3,opt,name=heartbeat,proto3,oneof" json:"heartbeat,omitempty"`
Ack *StartAckMessage `protobuf:"bytes,4,opt,name=ack,proto3,oneof" json:"ack,omitempty"` Ack *StartAckMessage `protobuf:"bytes,4,opt,name=ack,proto3,oneof" json:"ack,omitempty"`
StartRead *EmptyArr `protobuf:"bytes,5,opt,name=startRead,proto3,oneof" json:"startRead,omitempty"` StartRead *EmptyArr `protobuf:"bytes,5,opt,name=startRead,proto3,oneof" json:"startRead,omitempty"`
} }
func (x *InternalMessage) Reset() { func (x *LongPollingPayload) Reset() {
*x = InternalMessage{} *x = LongPollingPayload{}
if protoimpl.UnsafeEnabled { if protoimpl.UnsafeEnabled {
mi := &file_messages_proto_msgTypes[3] mi := &file_messages_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
@ -499,13 +499,13 @@ func (x *InternalMessage) Reset() {
} }
} }
func (x *InternalMessage) String() string { func (x *LongPollingPayload) String() string {
return protoimpl.X.MessageStringOf(x) return protoimpl.X.MessageStringOf(x)
} }
func (*InternalMessage) ProtoMessage() {} func (*LongPollingPayload) ProtoMessage() {}
func (x *InternalMessage) ProtoReflect() protoreflect.Message { func (x *LongPollingPayload) ProtoReflect() protoreflect.Message {
mi := &file_messages_proto_msgTypes[3] mi := &file_messages_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil { if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
@ -517,40 +517,40 @@ func (x *InternalMessage) ProtoReflect() protoreflect.Message {
return mi.MessageOf(x) return mi.MessageOf(x)
} }
// Deprecated: Use InternalMessage.ProtoReflect.Descriptor instead. // Deprecated: Use LongPollingPayload.ProtoReflect.Descriptor instead.
func (*InternalMessage) Descriptor() ([]byte, []int) { func (*LongPollingPayload) Descriptor() ([]byte, []int) {
return file_messages_proto_rawDescGZIP(), []int{3} return file_messages_proto_rawDescGZIP(), []int{3}
} }
func (x *InternalMessage) GetData() *InternalMessageData { func (x *LongPollingPayload) GetData() *IncomingRPCMessage {
if x != nil { if x != nil {
return x.Data return x.Data
} }
return nil return nil
} }
func (x *InternalMessage) GetHeartbeat() *EmptyArr { func (x *LongPollingPayload) GetHeartbeat() *EmptyArr {
if x != nil { if x != nil {
return x.Heartbeat return x.Heartbeat
} }
return nil return nil
} }
func (x *InternalMessage) GetAck() *StartAckMessage { func (x *LongPollingPayload) GetAck() *StartAckMessage {
if x != nil { if x != nil {
return x.Ack return x.Ack
} }
return nil return nil
} }
func (x *InternalMessage) GetStartRead() *EmptyArr { func (x *LongPollingPayload) GetStartRead() *EmptyArr {
if x != nil { if x != nil {
return x.StartRead return x.StartRead
} }
return nil return nil
} }
type InternalMessageData struct { type IncomingRPCMessage struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
@ -563,13 +563,14 @@ type InternalMessageData struct {
MillisecondsTaken string `protobuf:"bytes,7,opt,name=millisecondsTaken,proto3" json:"millisecondsTaken,omitempty"` MillisecondsTaken string `protobuf:"bytes,7,opt,name=millisecondsTaken,proto3" json:"millisecondsTaken,omitempty"`
Mobile *Device `protobuf:"bytes,8,opt,name=mobile,proto3" json:"mobile,omitempty"` Mobile *Device `protobuf:"bytes,8,opt,name=mobile,proto3" json:"mobile,omitempty"`
Browser *Device `protobuf:"bytes,9,opt,name=browser,proto3" json:"browser,omitempty"` Browser *Device `protobuf:"bytes,9,opt,name=browser,proto3" json:"browser,omitempty"`
ProtobufData []byte `protobuf:"bytes,12,opt,name=protobufData,proto3" json:"protobufData,omitempty"` // Either a RPCMessageData or a RPCPairData encoded as bytes
SignatureID string `protobuf:"bytes,17,opt,name=signatureID,proto3" json:"signatureID,omitempty"` MessageData []byte `protobuf:"bytes,12,opt,name=messageData,proto3" json:"messageData,omitempty"`
Timestamp string `protobuf:"bytes,21,opt,name=timestamp,proto3" json:"timestamp,omitempty"` SignatureID string `protobuf:"bytes,17,opt,name=signatureID,proto3" json:"signatureID,omitempty"`
Timestamp string `protobuf:"bytes,21,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
} }
func (x *InternalMessageData) Reset() { func (x *IncomingRPCMessage) Reset() {
*x = InternalMessageData{} *x = IncomingRPCMessage{}
if protoimpl.UnsafeEnabled { if protoimpl.UnsafeEnabled {
mi := &file_messages_proto_msgTypes[4] mi := &file_messages_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
@ -577,13 +578,13 @@ func (x *InternalMessageData) Reset() {
} }
} }
func (x *InternalMessageData) String() string { func (x *IncomingRPCMessage) String() string {
return protoimpl.X.MessageStringOf(x) return protoimpl.X.MessageStringOf(x)
} }
func (*InternalMessageData) ProtoMessage() {} func (*IncomingRPCMessage) ProtoMessage() {}
func (x *InternalMessageData) ProtoReflect() protoreflect.Message { func (x *IncomingRPCMessage) ProtoReflect() protoreflect.Message {
mi := &file_messages_proto_msgTypes[4] mi := &file_messages_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil { if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
@ -595,89 +596,89 @@ func (x *InternalMessageData) ProtoReflect() protoreflect.Message {
return mi.MessageOf(x) return mi.MessageOf(x)
} }
// Deprecated: Use InternalMessageData.ProtoReflect.Descriptor instead. // Deprecated: Use IncomingRPCMessage.ProtoReflect.Descriptor instead.
func (*InternalMessageData) Descriptor() ([]byte, []int) { func (*IncomingRPCMessage) Descriptor() ([]byte, []int) {
return file_messages_proto_rawDescGZIP(), []int{4} return file_messages_proto_rawDescGZIP(), []int{4}
} }
func (x *InternalMessageData) GetResponseID() string { func (x *IncomingRPCMessage) GetResponseID() string {
if x != nil { if x != nil {
return x.ResponseID return x.ResponseID
} }
return "" return ""
} }
func (x *InternalMessageData) GetBugleRoute() BugleRoute { func (x *IncomingRPCMessage) GetBugleRoute() BugleRoute {
if x != nil { if x != nil {
return x.BugleRoute return x.BugleRoute
} }
return BugleRoute_UNKNOWN_BUGLE_ROUTE return BugleRoute_UNKNOWN_BUGLE_ROUTE
} }
func (x *InternalMessageData) GetStartExecute() string { func (x *IncomingRPCMessage) GetStartExecute() string {
if x != nil { if x != nil {
return x.StartExecute return x.StartExecute
} }
return "" return ""
} }
func (x *InternalMessageData) GetMessageType() MessageType { func (x *IncomingRPCMessage) GetMessageType() MessageType {
if x != nil { if x != nil {
return x.MessageType return x.MessageType
} }
return MessageType_UNKNOWN_MESSAGE_TYPE return MessageType_UNKNOWN_MESSAGE_TYPE
} }
func (x *InternalMessageData) GetFinishExecute() string { func (x *IncomingRPCMessage) GetFinishExecute() string {
if x != nil { if x != nil {
return x.FinishExecute return x.FinishExecute
} }
return "" return ""
} }
func (x *InternalMessageData) GetMillisecondsTaken() string { func (x *IncomingRPCMessage) GetMillisecondsTaken() string {
if x != nil { if x != nil {
return x.MillisecondsTaken return x.MillisecondsTaken
} }
return "" return ""
} }
func (x *InternalMessageData) GetMobile() *Device { func (x *IncomingRPCMessage) GetMobile() *Device {
if x != nil { if x != nil {
return x.Mobile return x.Mobile
} }
return nil return nil
} }
func (x *InternalMessageData) GetBrowser() *Device { func (x *IncomingRPCMessage) GetBrowser() *Device {
if x != nil { if x != nil {
return x.Browser return x.Browser
} }
return nil return nil
} }
func (x *InternalMessageData) GetProtobufData() []byte { func (x *IncomingRPCMessage) GetMessageData() []byte {
if x != nil { if x != nil {
return x.ProtobufData return x.MessageData
} }
return nil return nil
} }
func (x *InternalMessageData) GetSignatureID() string { func (x *IncomingRPCMessage) GetSignatureID() string {
if x != nil { if x != nil {
return x.SignatureID return x.SignatureID
} }
return "" return ""
} }
func (x *InternalMessageData) GetTimestamp() string { func (x *IncomingRPCMessage) GetTimestamp() string {
if x != nil { if x != nil {
return x.Timestamp return x.Timestamp
} }
return "" return ""
} }
type InternalRequestData struct { type RPCMessageData struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
@ -691,8 +692,8 @@ type InternalRequestData struct {
Bool3 bool `protobuf:"varint,9,opt,name=bool3,proto3" json:"bool3,omitempty"` Bool3 bool `protobuf:"varint,9,opt,name=bool3,proto3" json:"bool3,omitempty"`
} }
func (x *InternalRequestData) Reset() { func (x *RPCMessageData) Reset() {
*x = InternalRequestData{} *x = RPCMessageData{}
if protoimpl.UnsafeEnabled { if protoimpl.UnsafeEnabled {
mi := &file_messages_proto_msgTypes[5] mi := &file_messages_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
@ -700,13 +701,13 @@ func (x *InternalRequestData) Reset() {
} }
} }
func (x *InternalRequestData) String() string { func (x *RPCMessageData) String() string {
return protoimpl.X.MessageStringOf(x) return protoimpl.X.MessageStringOf(x)
} }
func (*InternalRequestData) ProtoMessage() {} func (*RPCMessageData) ProtoMessage() {}
func (x *InternalRequestData) ProtoReflect() protoreflect.Message { func (x *RPCMessageData) ProtoReflect() protoreflect.Message {
mi := &file_messages_proto_msgTypes[5] mi := &file_messages_proto_msgTypes[5]
if protoimpl.UnsafeEnabled && x != nil { if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
@ -718,54 +719,54 @@ func (x *InternalRequestData) ProtoReflect() protoreflect.Message {
return mi.MessageOf(x) return mi.MessageOf(x)
} }
// Deprecated: Use InternalRequestData.ProtoReflect.Descriptor instead. // Deprecated: Use RPCMessageData.ProtoReflect.Descriptor instead.
func (*InternalRequestData) Descriptor() ([]byte, []int) { func (*RPCMessageData) Descriptor() ([]byte, []int) {
return file_messages_proto_rawDescGZIP(), []int{5} return file_messages_proto_rawDescGZIP(), []int{5}
} }
func (x *InternalRequestData) GetSessionID() string { func (x *RPCMessageData) GetSessionID() string {
if x != nil { if x != nil {
return x.SessionID return x.SessionID
} }
return "" return ""
} }
func (x *InternalRequestData) GetTimestamp() int64 { func (x *RPCMessageData) GetTimestamp() int64 {
if x != nil { if x != nil {
return x.Timestamp return x.Timestamp
} }
return 0 return 0
} }
func (x *InternalRequestData) GetAction() ActionType { func (x *RPCMessageData) GetAction() ActionType {
if x != nil { if x != nil {
return x.Action return x.Action
} }
return ActionType_UNSPECIFIED return ActionType_UNSPECIFIED
} }
func (x *InternalRequestData) GetBool1() bool { func (x *RPCMessageData) GetBool1() bool {
if x != nil { if x != nil {
return x.Bool1 return x.Bool1
} }
return false return false
} }
func (x *InternalRequestData) GetBool2() bool { func (x *RPCMessageData) GetBool2() bool {
if x != nil { if x != nil {
return x.Bool2 return x.Bool2
} }
return false return false
} }
func (x *InternalRequestData) GetEncryptedData() []byte { func (x *RPCMessageData) GetEncryptedData() []byte {
if x != nil { if x != nil {
return x.EncryptedData return x.EncryptedData
} }
return nil return nil
} }
func (x *InternalRequestData) GetBool3() bool { func (x *RPCMessageData) GetBool3() bool {
if x != nil { if x != nil {
return x.Bool3 return x.Bool3
} }
@ -1645,9 +1646,9 @@ var file_messages_proto_goTypes = []interface{}{
(*RegisterRefreshPayload)(nil), // 3: messages.RegisterRefreshPayload (*RegisterRefreshPayload)(nil), // 3: messages.RegisterRefreshPayload
(*EmptyRefreshArr)(nil), // 4: messages.EmptyRefreshArr (*EmptyRefreshArr)(nil), // 4: messages.EmptyRefreshArr
(*StartAckMessage)(nil), // 5: messages.StartAckMessage (*StartAckMessage)(nil), // 5: messages.StartAckMessage
(*InternalMessage)(nil), // 6: messages.InternalMessage (*LongPollingPayload)(nil), // 6: messages.LongPollingPayload
(*InternalMessageData)(nil), // 7: messages.InternalMessageData (*IncomingRPCMessage)(nil), // 7: messages.IncomingRPCMessage
(*InternalRequestData)(nil), // 8: messages.InternalRequestData (*RPCMessageData)(nil), // 8: messages.RPCMessageData
(*RevokeRelayPairing)(nil), // 9: messages.RevokeRelayPairing (*RevokeRelayPairing)(nil), // 9: messages.RevokeRelayPairing
(*SendMessage)(nil), // 10: messages.SendMessage (*SendMessage)(nil), // 10: messages.SendMessage
(*SendMessageAuth)(nil), // 11: messages.SendMessageAuth (*SendMessageAuth)(nil), // 11: messages.SendMessageAuth
@ -1668,15 +1669,15 @@ var file_messages_proto_depIdxs = []int32{
19, // 1: messages.RegisterRefreshPayload.currBrowserDevice:type_name -> messages.Device 19, // 1: messages.RegisterRefreshPayload.currBrowserDevice:type_name -> messages.Device
4, // 2: messages.RegisterRefreshPayload.emptyRefreshArr:type_name -> messages.EmptyRefreshArr 4, // 2: messages.RegisterRefreshPayload.emptyRefreshArr:type_name -> messages.EmptyRefreshArr
15, // 3: messages.EmptyRefreshArr.emptyArr:type_name -> messages.EmptyArr 15, // 3: messages.EmptyRefreshArr.emptyArr:type_name -> messages.EmptyArr
7, // 4: messages.InternalMessage.data:type_name -> messages.InternalMessageData 7, // 4: messages.LongPollingPayload.data:type_name -> messages.IncomingRPCMessage
15, // 5: messages.InternalMessage.heartbeat:type_name -> messages.EmptyArr 15, // 5: messages.LongPollingPayload.heartbeat:type_name -> messages.EmptyArr
5, // 6: messages.InternalMessage.ack:type_name -> messages.StartAckMessage 5, // 6: messages.LongPollingPayload.ack:type_name -> messages.StartAckMessage
15, // 7: messages.InternalMessage.startRead:type_name -> messages.EmptyArr 15, // 7: messages.LongPollingPayload.startRead:type_name -> messages.EmptyArr
0, // 8: messages.InternalMessageData.bugleRoute:type_name -> messages.BugleRoute 0, // 8: messages.IncomingRPCMessage.bugleRoute:type_name -> messages.BugleRoute
2, // 9: messages.InternalMessageData.messageType:type_name -> messages.MessageType 2, // 9: messages.IncomingRPCMessage.messageType:type_name -> messages.MessageType
19, // 10: messages.InternalMessageData.mobile:type_name -> messages.Device 19, // 10: messages.IncomingRPCMessage.mobile:type_name -> messages.Device
19, // 11: messages.InternalMessageData.browser:type_name -> messages.Device 19, // 11: messages.IncomingRPCMessage.browser:type_name -> messages.Device
1, // 12: messages.InternalRequestData.action:type_name -> messages.ActionType 1, // 12: messages.RPCMessageData.action:type_name -> messages.ActionType
16, // 13: messages.RevokeRelayPairing.authMessage:type_name -> messages.AuthMessage 16, // 13: messages.RevokeRelayPairing.authMessage:type_name -> messages.AuthMessage
19, // 14: messages.RevokeRelayPairing.browser:type_name -> messages.Device 19, // 14: messages.RevokeRelayPairing.browser:type_name -> messages.Device
19, // 15: messages.SendMessage.mobile:type_name -> messages.Device 19, // 15: messages.SendMessage.mobile:type_name -> messages.Device
@ -1744,7 +1745,7 @@ func file_messages_proto_init() {
} }
} }
file_messages_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { file_messages_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*InternalMessage); i { switch v := v.(*LongPollingPayload); i {
case 0: case 0:
return &v.state return &v.state
case 1: case 1:
@ -1756,7 +1757,7 @@ func file_messages_proto_init() {
} }
} }
file_messages_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { file_messages_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*InternalMessageData); i { switch v := v.(*IncomingRPCMessage); i {
case 0: case 0:
return &v.state return &v.state
case 1: case 1:
@ -1768,7 +1769,7 @@ func file_messages_proto_init() {
} }
} }
file_messages_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { file_messages_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*InternalRequestData); i { switch v := v.(*RPCMessageData); i {
case 0: case 0:
return &v.state return &v.state
case 1: case 1:

Binary file not shown.

View file

@ -20,14 +20,14 @@ message StartAckMessage {
optional int32 count = 1; optional int32 count = 1;
} }
message InternalMessage { message LongPollingPayload {
optional InternalMessageData data = 2; optional IncomingRPCMessage data = 2;
optional EmptyArr heartbeat = 3; optional EmptyArr heartbeat = 3;
optional StartAckMessage ack = 4; optional StartAckMessage ack = 4;
optional EmptyArr startRead = 5; optional EmptyArr startRead = 5;
} }
message InternalMessageData { message IncomingRPCMessage {
string responseID = 1; string responseID = 1;
BugleRoute bugleRoute = 2; BugleRoute bugleRoute = 2;
string startExecute = 3; string startExecute = 3;
@ -38,14 +38,15 @@ message InternalMessageData {
Device mobile = 8; Device mobile = 8;
Device browser = 9; Device browser = 9;
bytes protobufData = 12; // Either a RPCMessageData or a RPCPairData encoded as bytes
bytes messageData = 12;
string signatureID = 17; string signatureID = 17;
string timestamp = 21; string timestamp = 21;
} }
message InternalRequestData { message RPCMessageData {
string sessionID = 1; string sessionID = 1;
int64 timestamp = 3; int64 timestamp = 3;
ActionType action = 4; ActionType action = 4;

View file

@ -14,9 +14,9 @@ func (c *Client) SendReaction(payload *gmproto.SendReactionPayload) (*gmproto.Se
return nil, err return nil, err
} }
res, ok := response.Data.Decrypted.(*gmproto.SendReactionResponse) res, ok := response.DecryptedMessage.(*gmproto.SendReactionResponse)
if !ok { if !ok {
return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.SendReactionResponse", response.Data.Decrypted) return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.SendReactionResponse", response.DecryptedMessage)
} }
return res, nil return res, nil
@ -31,9 +31,9 @@ func (c *Client) DeleteMessage(messageID string) (*gmproto.DeleteMessageResponse
return nil, err return nil, err
} }
res, ok := response.Data.Decrypted.(*gmproto.DeleteMessageResponse) res, ok := response.DecryptedMessage.(*gmproto.DeleteMessageResponse)
if !ok { if !ok {
return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.DeleteMessagesResponse", response.Data.Decrypted) return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.DeleteMessagesResponse", response.DecryptedMessage)
} }
return res, nil return res, nil

View file

@ -4,22 +4,14 @@ import (
"fmt" "fmt"
"go.mau.fi/mautrix-gmessages/libgm/events" "go.mau.fi/mautrix-gmessages/libgm/events"
"go.mau.fi/mautrix-gmessages/libgm/pblite"
"go.mau.fi/mautrix-gmessages/libgm/gmproto" "go.mau.fi/mautrix-gmessages/libgm/gmproto"
) )
func (c *Client) handlePairingEvent(response *pblite.Response) { func (c *Client) handlePairingEvent(msg *IncomingRPCMessage) {
pairEventData, ok := response.Data.Decrypted.(*gmproto.PairEvents) switch evt := msg.Pair.Event.(type) {
if !ok { case *gmproto.RPCPairData_Paired:
c.Logger.Error().Type("decrypted_type", response.Data.Decrypted).Msg("Unexpected data type in pair event")
return
}
switch evt := pairEventData.Event.(type) {
case *gmproto.PairEvents_Paired:
c.completePairing(evt.Paired) c.completePairing(evt.Paired)
case *gmproto.PairEvents_Revoked: case *gmproto.RPCPairData_Revoked:
c.triggerEvent(evt.Revoked) c.triggerEvent(evt.Revoked)
default: default:
c.Logger.Debug().Any("evt", evt).Msg("Unknown pair event type") c.Logger.Debug().Any("evt", evt).Msg("Unknown pair event type")

View file

@ -1,127 +0,0 @@
package pblite
import (
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"go.mau.fi/mautrix-gmessages/libgm/crypto"
"go.mau.fi/mautrix-gmessages/libgm/gmproto"
"go.mau.fi/mautrix-gmessages/libgm/routes"
)
type DevicePair struct {
Mobile *gmproto.Device `json:"mobile,omitempty"`
Browser *gmproto.Device `json:"browser,omitempty"`
}
type RequestData struct {
RequestID string `json:"requestId,omitempty"`
Timestamp int64 `json:"timestamp,omitempty"`
Action gmproto.ActionType `json:"action,omitempty"`
Bool1 bool `json:"bool1,omitempty"`
Bool2 bool `json:"bool2,omitempty"`
EncryptedData []byte `json:"requestData,omitempty"`
RawDecrypted []byte `json:"-,omitempty"`
Decrypted proto.Message `json:"decrypted,omitempty"`
Bool3 bool `json:"bool3,omitempty"`
}
type Response struct {
ResponseID string `json:"responseId,omitempty"`
BugleRoute gmproto.BugleRoute `json:"bugleRoute,omitempty"`
StartExecute string `json:"startExecute,omitempty"`
MessageType gmproto.MessageType `json:"eventType,omitempty"`
FinishExecute string `json:"finishExecute,omitempty"`
MillisecondsTaken string `json:"millisecondsTaken,omitempty"`
Devices *DevicePair `json:"devices,omitempty"`
Data RequestData `json:"data,omitempty"`
SignatureId string `json:"signatureId,omitempty"`
Timestamp string `json:"timestamp"`
}
func DecryptInternalMessage(internalMessage *gmproto.InternalMessage, cryptor *crypto.AESCTRHelper) (*Response, error) {
var resp *Response
switch internalMessage.Data.BugleRoute {
case gmproto.BugleRoute_PairEvent:
decodedData := &gmproto.PairEvents{}
decodeErr := proto.Unmarshal(internalMessage.Data.ProtobufData, decodedData)
if decodeErr != nil {
return nil, decodeErr
}
resp = newResponseFromPairEvent(internalMessage.GetData(), decodedData)
case gmproto.BugleRoute_DataEvent:
internalRequestData := &gmproto.InternalRequestData{}
decodeErr := proto.Unmarshal(internalMessage.Data.ProtobufData, internalRequestData)
if decodeErr != nil {
return nil, decodeErr
}
if internalRequestData.EncryptedData != nil {
decryptedBytes, err := cryptor.Decrypt(internalRequestData.EncryptedData)
if err != nil {
return nil, err
}
responseStruct := routes.Routes[internalRequestData.GetAction()].ResponseStruct
deserializedData := responseStruct.ProtoReflect().New().Interface()
err = proto.Unmarshal(decryptedBytes, deserializedData)
if err != nil {
return nil, err
}
resp = newResponseFromDataEvent(internalMessage.GetData(), internalRequestData, decryptedBytes, deserializedData)
} else {
resp = newResponseFromDataEvent(internalMessage.GetData(), internalRequestData, nil, nil)
}
}
return resp, nil
}
func newResponseFromPairEvent(internalMsg *gmproto.InternalMessageData, data *gmproto.PairEvents) *Response {
resp := &Response{
ResponseID: internalMsg.GetResponseID(),
BugleRoute: internalMsg.GetBugleRoute(),
StartExecute: internalMsg.GetStartExecute(),
MessageType: internalMsg.GetMessageType(),
FinishExecute: internalMsg.GetFinishExecute(),
MillisecondsTaken: internalMsg.GetMillisecondsTaken(),
Devices: &DevicePair{
Mobile: internalMsg.GetMobile(),
Browser: internalMsg.GetBrowser(),
},
Data: RequestData{
Decrypted: data,
},
Timestamp: internalMsg.GetTimestamp(),
SignatureId: internalMsg.GetSignatureID(),
}
return resp
}
func newResponseFromDataEvent(internalMsg *gmproto.InternalMessageData, internalRequestData *gmproto.InternalRequestData, rawData []byte, decrypted protoreflect.ProtoMessage) *Response {
resp := &Response{
ResponseID: internalMsg.GetResponseID(),
BugleRoute: internalMsg.GetBugleRoute(),
StartExecute: internalMsg.GetStartExecute(),
MessageType: internalMsg.GetMessageType(),
FinishExecute: internalMsg.GetFinishExecute(),
MillisecondsTaken: internalMsg.GetMillisecondsTaken(),
Devices: &DevicePair{
Mobile: internalMsg.GetMobile(),
Browser: internalMsg.GetBrowser(),
},
Data: RequestData{
RequestID: internalRequestData.GetSessionID(),
Timestamp: internalRequestData.GetTimestamp(),
Action: internalRequestData.GetAction(),
Bool1: internalRequestData.GetBool1(),
Bool2: internalRequestData.GetBool2(),
EncryptedData: internalRequestData.GetEncryptedData(),
Decrypted: decrypted,
RawDecrypted: rawData,
Bool3: internalRequestData.GetBool3(),
},
SignatureId: internalMsg.GetSignatureID(),
Timestamp: internalMsg.GetTimestamp(),
}
return resp
}

View file

@ -2,49 +2,44 @@ package libgm
import ( import (
"encoding/base64" "encoding/base64"
"fmt"
"go.mau.fi/mautrix-gmessages/libgm/pblite"
) )
func (s *SessionHandler) waitResponse(requestID string) chan *pblite.Response { func (s *SessionHandler) waitResponse(requestID string) chan *IncomingRPCMessage {
ch := make(chan *pblite.Response, 1) ch := make(chan *IncomingRPCMessage, 1)
s.responseWaitersLock.Lock() s.responseWaitersLock.Lock()
// DEBUG
if _, ok := s.responseWaiters[requestID]; ok {
panic(fmt.Errorf("request %s already has a response waiter", requestID))
}
// END DEBUG
s.responseWaiters[requestID] = ch s.responseWaiters[requestID] = ch
s.responseWaitersLock.Unlock() s.responseWaitersLock.Unlock()
return ch return ch
} }
func (s *SessionHandler) cancelResponse(requestID string, ch chan *pblite.Response) { func (s *SessionHandler) cancelResponse(requestID string, ch chan *IncomingRPCMessage) {
s.responseWaitersLock.Lock() s.responseWaitersLock.Lock()
close(ch) close(ch)
delete(s.responseWaiters, requestID) delete(s.responseWaiters, requestID)
s.responseWaitersLock.Unlock() s.responseWaitersLock.Unlock()
} }
func (s *SessionHandler) receiveResponse(resp *pblite.Response) bool { func (s *SessionHandler) receiveResponse(msg *IncomingRPCMessage) bool {
requestID := msg.Message.SessionID
s.responseWaitersLock.Lock() s.responseWaitersLock.Lock()
ch, ok := s.responseWaiters[resp.Data.RequestID] ch, ok := s.responseWaiters[requestID]
if !ok { if !ok {
s.responseWaitersLock.Unlock() s.responseWaitersLock.Unlock()
return false return false
} }
delete(s.responseWaiters, resp.Data.RequestID) delete(s.responseWaiters, requestID)
s.responseWaitersLock.Unlock() s.responseWaitersLock.Unlock()
evt := s.client.Logger.Trace(). evt := s.client.Logger.Trace().
Str("request_id", resp.Data.RequestID) Str("request_id", requestID)
if evt.Enabled() && resp.Data.Decrypted != nil { if evt.Enabled() {
evt.Str("proto_name", string(resp.Data.Decrypted.ProtoReflect().Descriptor().FullName())). if msg.DecryptedData != nil {
Str("data", base64.StdEncoding.EncodeToString(resp.Data.RawDecrypted)) evt.Str("data", base64.StdEncoding.EncodeToString(msg.DecryptedData))
} else if resp.Data.RawDecrypted != nil { }
evt.Str("unrecognized_data", base64.StdEncoding.EncodeToString(resp.Data.RawDecrypted)) if msg.DecryptedMessage != nil {
evt.Str("proto_name", string(msg.DecryptedMessage.ProtoReflect().Descriptor().FullName()))
}
} }
evt.Msg("Received response") evt.Msg("Received response")
ch <- resp ch <- msg
return true return true
} }

View file

@ -160,7 +160,7 @@ func (r *RPC) startReadingData(rc io.ReadCloser) {
} }
currentBlock := accumulatedData currentBlock := accumulatedData
accumulatedData = accumulatedData[:0] accumulatedData = accumulatedData[:0]
msg := &gmproto.InternalMessage{} msg := &gmproto.LongPollingPayload{}
err = pblite.Unmarshal(currentBlock, msg) err = pblite.Unmarshal(currentBlock, msg)
if err != nil { if err != nil {
r.client.Logger.Err(err).Msg("Error deserializing pblite message") r.client.Logger.Err(err).Msg("Error deserializing pblite message")
@ -168,7 +168,7 @@ func (r *RPC) startReadingData(rc io.ReadCloser) {
} }
switch { switch {
case msg.GetData() != nil: case msg.GetData() != nil:
r.HandleRPCMsg(msg) r.HandleRPCMsg(msg.GetData())
case msg.GetAck() != nil: case msg.GetAck() != nil:
r.client.Logger.Debug().Int32("count", msg.GetAck().GetCount()).Msg("Got startup ack count message") r.client.Logger.Debug().Int32("count", msg.GetAck().GetCount()).Msg("Got startup ack count message")
r.skipCount = int(msg.GetAck().GetCount()) r.skipCount = int(msg.GetAck().GetCount())

View file

@ -22,9 +22,9 @@ func (c *Client) IsBugleDefault() (*gmproto.IsBugleDefaultResponse, error) {
return nil, err return nil, err
} }
res, ok := response.Data.Decrypted.(*gmproto.IsBugleDefaultResponse) res, ok := response.DecryptedMessage.(*gmproto.IsBugleDefaultResponse)
if !ok { if !ok {
return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.IsBugleDefaultResponse", response.Data.Decrypted) return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.IsBugleDefaultResponse", response.DecryptedMessage)
} }
return res, nil return res, nil

View file

@ -20,7 +20,7 @@ import (
type SessionHandler struct { type SessionHandler struct {
client *Client client *Client
responseWaiters map[string]chan<- *pblite.Response responseWaiters map[string]chan<- *IncomingRPCMessage
responseWaitersLock sync.Mutex responseWaitersLock sync.Mutex
ackMapLock sync.Mutex ackMapLock sync.Mutex
@ -46,7 +46,7 @@ func (s *SessionHandler) sendMessageNoResponse(actionType gmproto.ActionType, en
return err return err
} }
func (s *SessionHandler) sendAsyncMessage(actionType gmproto.ActionType, encryptedData proto.Message) (<-chan *pblite.Response, error) { func (s *SessionHandler) sendAsyncMessage(actionType gmproto.ActionType, encryptedData proto.Message) (<-chan *IncomingRPCMessage, error) {
requestID, payload, _, buildErr := s.buildMessage(actionType, encryptedData) requestID, payload, _, buildErr := s.buildMessage(actionType, encryptedData)
if buildErr != nil { if buildErr != nil {
return nil, buildErr return nil, buildErr
@ -61,7 +61,7 @@ func (s *SessionHandler) sendAsyncMessage(actionType gmproto.ActionType, encrypt
return ch, nil return ch, nil
} }
func (s *SessionHandler) sendMessage(actionType gmproto.ActionType, encryptedData proto.Message) (*pblite.Response, error) { func (s *SessionHandler) sendMessage(actionType gmproto.ActionType, encryptedData proto.Message) (*IncomingRPCMessage, error) {
ch, err := s.sendAsyncMessage(actionType, encryptedData) ch, err := s.sendAsyncMessage(actionType, encryptedData)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -2,50 +2,48 @@ package libgm
import ( import (
"go.mau.fi/mautrix-gmessages/libgm/events" "go.mau.fi/mautrix-gmessages/libgm/events"
"go.mau.fi/mautrix-gmessages/libgm/pblite"
"go.mau.fi/mautrix-gmessages/libgm/gmproto" "go.mau.fi/mautrix-gmessages/libgm/gmproto"
) )
func (c *Client) handleUpdatesEvent(res *pblite.Response) { func (c *Client) handleUpdatesEvent(msg *IncomingRPCMessage) {
switch res.Data.Action { switch msg.Message.Action {
case gmproto.ActionType_GET_UPDATES: case gmproto.ActionType_GET_UPDATES:
data, ok := res.Data.Decrypted.(*gmproto.UpdateEvents) data, ok := msg.DecryptedMessage.(*gmproto.UpdateEvents)
if !ok { if !ok {
c.Logger.Error().Type("data_type", res.Data.Decrypted).Msg("Unexpected data type in GET_UPDATES event") c.Logger.Error().Type("data_type", msg.DecryptedMessage).Msg("Unexpected data type in GET_UPDATES event")
return return
} }
switch evt := data.Event.(type) { switch evt := data.Event.(type) {
case *gmproto.UpdateEvents_UserAlertEvent: case *gmproto.UpdateEvents_UserAlertEvent:
c.rpc.logContent(res) c.rpc.logContent(msg)
c.handleUserAlertEvent(res, evt.UserAlertEvent) c.handleUserAlertEvent(msg, evt.UserAlertEvent)
case *gmproto.UpdateEvents_SettingsEvent: case *gmproto.UpdateEvents_SettingsEvent:
c.rpc.logContent(res) c.rpc.logContent(msg)
c.triggerEvent(evt.SettingsEvent) c.triggerEvent(evt.SettingsEvent)
case *gmproto.UpdateEvents_ConversationEvent: case *gmproto.UpdateEvents_ConversationEvent:
if c.rpc.deduplicateUpdate(res) { if c.rpc.deduplicateUpdate(msg) {
return return
} }
c.triggerEvent(evt.ConversationEvent.GetData()) c.triggerEvent(evt.ConversationEvent.GetData())
case *gmproto.UpdateEvents_MessageEvent: case *gmproto.UpdateEvents_MessageEvent:
if c.rpc.deduplicateUpdate(res) { if c.rpc.deduplicateUpdate(msg) {
return return
} }
c.triggerEvent(evt.MessageEvent.GetData()) c.triggerEvent(evt.MessageEvent.GetData())
case *gmproto.UpdateEvents_TypingEvent: case *gmproto.UpdateEvents_TypingEvent:
c.rpc.logContent(res) c.rpc.logContent(msg)
c.triggerEvent(evt.TypingEvent.GetData()) c.triggerEvent(evt.TypingEvent.GetData())
default: default:
c.Logger.Trace().Any("evt", evt).Msg("Got unknown event type") c.Logger.Trace().Any("evt", evt).Msg("Got unknown event type")
} }
default: default:
c.Logger.Trace().Any("response", res).Msg("Got unexpected response") c.Logger.Trace().Any("response", msg).Msg("Got unexpected response")
} }
} }
@ -62,11 +60,11 @@ func (c *Client) handleClientReady(newSessionId string) {
c.triggerEvent(readyEvt) c.triggerEvent(readyEvt)
} }
func (c *Client) handleUserAlertEvent(res *pblite.Response, data *gmproto.UserAlertEvent) { func (c *Client) handleUserAlertEvent(msg *IncomingRPCMessage, data *gmproto.UserAlertEvent) {
alertType := data.AlertType alertType := data.AlertType
switch alertType { switch alertType {
case gmproto.AlertType_BROWSER_ACTIVE: case gmproto.AlertType_BROWSER_ACTIVE:
newSessionID := res.Data.RequestID newSessionID := msg.Message.SessionID
c.Logger.Debug().Any("session_id", newSessionID).Msg("Got browser active notification") c.Logger.Debug().Any("session_id", newSessionID).Msg("Got browser active notification")
if newSessionID != c.sessionHandler.sessionID { if newSessionID != c.sessionHandler.sessionID {
evt := events.NewBrowserActive(newSessionID) evt := events.NewBrowserActive(newSessionID)