diff --git a/libgm/client.go b/libgm/client.go index f361da2..460dd10 100644 --- a/libgm/client.go +++ b/libgm/client.go @@ -17,7 +17,6 @@ import ( "go.mau.fi/mautrix-gmessages/libgm/crypto" "go.mau.fi/mautrix-gmessages/libgm/events" "go.mau.fi/mautrix-gmessages/libgm/gmproto" - "go.mau.fi/mautrix-gmessages/libgm/pblite" "go.mau.fi/mautrix-gmessages/libgm/util" ) @@ -75,8 +74,7 @@ func NewClient(authData *AuthData, logger zerolog.Logger) *Client { http: &http.Client{}, } sessionHandler.client = cli - rpc := &RPC{client: cli, http: &http.Client{Transport: &http.Transport{Proxy: cli.proxy}}} - cli.rpc = rpc + cli.rpc = &RPC{client: cli} cli.FetchConfigVersion() return cli } @@ -234,7 +232,7 @@ func (c *Client) refreshAuthToken() error { return err } - payload, err := pblite.Marshal(&gmproto.RegisterRefreshRequest{ + payload := &gmproto.RegisterRefreshRequest{ MessageAuth: &gmproto.AuthMessage{ RequestID: requestID, TachyonAuthToken: c.AuthData.TachyonAuthToken, @@ -245,37 +243,18 @@ func (c *Client) refreshAuthToken() error { Signature: sig, EmptyRefreshArr: &gmproto.RegisterRefreshRequest_NestedEmptyArr{EmptyArr: &gmproto.EmptyArr{}}, MessageType: 2, // hmm - }) + } + + resp, err := typedHTTPResponse[*gmproto.RegisterRefreshResponse]( + c.makeProtobufHTTPRequest(util.RegisterRefreshURL, payload, ContentTypePBLite), + ) if err != nil { return err } - refreshResponse, requestErr := c.rpc.sendMessageRequest(util.RegisterRefreshURL, payload) - if requestErr != nil { - return requestErr - } - - if refreshResponse.StatusCode == 401 { - return fmt.Errorf("failed to refresh auth token: unauthorized (try reauthenticating through qr code)") - } - - if refreshResponse.StatusCode == 400 { - return fmt.Errorf("failed to refresh auth token: signature failed") - } - responseBody, readErr := io.ReadAll(refreshResponse.Body) - if readErr != nil { - return readErr - } - - resp := &gmproto.RegisterRefreshResponse{} - deserializeErr := pblite.Unmarshal(responseBody, resp) - if deserializeErr != nil { - return deserializeErr - } - token := resp.GetTokenData().GetTachyonAuthToken() if token == nil { - return fmt.Errorf("failed to refresh auth token: something happened") + return fmt.Errorf("no tachyon auth token in refresh response") } validFor, _ := strconv.ParseInt(resp.GetTokenData().GetValidFor(), 10, 64) diff --git a/libgm/events/ready.go b/libgm/events/ready.go index c9e20ec..25d42d1 100644 --- a/libgm/events/ready.go +++ b/libgm/events/ready.go @@ -27,6 +27,9 @@ type HTTPError struct { } func (he HTTPError) Error() string { + if he.Action == "" { + return fmt.Sprintf("unexpected http %d", he.Resp.StatusCode) + } return fmt.Sprintf("http %d while %s", he.Resp.StatusCode, he.Action) } diff --git a/libgm/gmproto/rpc.pb.go b/libgm/gmproto/rpc.pb.go index e68da99..3bdb6a5 100644 --- a/libgm/gmproto/rpc.pb.go +++ b/libgm/gmproto/rpc.pb.go @@ -789,6 +789,54 @@ func (x *OutgoingRPCData) GetSessionID() string { return "" } +type OutgoingRPCResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // This is not present for AckMessage responses, only for SendMessage + Timestamp *string `protobuf:"bytes,2,opt,name=timestamp,proto3,oneof" json:"timestamp,omitempty"` +} + +func (x *OutgoingRPCResponse) Reset() { + *x = OutgoingRPCResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_rpc_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *OutgoingRPCResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*OutgoingRPCResponse) ProtoMessage() {} + +func (x *OutgoingRPCResponse) ProtoReflect() protoreflect.Message { + mi := &file_rpc_proto_msgTypes[6] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use OutgoingRPCResponse.ProtoReflect.Descriptor instead. +func (*OutgoingRPCResponse) Descriptor() ([]byte, []int) { + return file_rpc_proto_rawDescGZIP(), []int{6} +} + +func (x *OutgoingRPCResponse) GetTimestamp() string { + if x != nil && x.Timestamp != nil { + return *x.Timestamp + } + return "" +} + type OutgoingRPCMessage_Auth struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -802,7 +850,7 @@ type OutgoingRPCMessage_Auth struct { func (x *OutgoingRPCMessage_Auth) Reset() { *x = OutgoingRPCMessage_Auth{} if protoimpl.UnsafeEnabled { - mi := &file_rpc_proto_msgTypes[6] + mi := &file_rpc_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -815,7 +863,7 @@ func (x *OutgoingRPCMessage_Auth) String() string { func (*OutgoingRPCMessage_Auth) ProtoMessage() {} func (x *OutgoingRPCMessage_Auth) ProtoReflect() protoreflect.Message { - mi := &file_rpc_proto_msgTypes[6] + mi := &file_rpc_proto_msgTypes[7] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -867,7 +915,7 @@ type OutgoingRPCMessage_Data struct { func (x *OutgoingRPCMessage_Data) Reset() { *x = OutgoingRPCMessage_Data{} if protoimpl.UnsafeEnabled { - mi := &file_rpc_proto_msgTypes[7] + mi := &file_rpc_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -880,7 +928,7 @@ func (x *OutgoingRPCMessage_Data) String() string { func (*OutgoingRPCMessage_Data) ProtoMessage() {} func (x *OutgoingRPCMessage_Data) ProtoReflect() protoreflect.Message { - mi := &file_rpc_proto_msgTypes[7] + mi := &file_rpc_proto_msgTypes[8] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -936,7 +984,7 @@ type OutgoingRPCMessage_Data_Type struct { func (x *OutgoingRPCMessage_Data_Type) Reset() { *x = OutgoingRPCMessage_Data_Type{} if protoimpl.UnsafeEnabled { - mi := &file_rpc_proto_msgTypes[8] + mi := &file_rpc_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -949,7 +997,7 @@ func (x *OutgoingRPCMessage_Data_Type) String() string { func (*OutgoingRPCMessage_Data_Type) ProtoMessage() {} func (x *OutgoingRPCMessage_Data_Type) ProtoReflect() protoreflect.Message { - mi := &file_rpc_proto_msgTypes[8] + mi := &file_rpc_proto_msgTypes[9] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -979,6 +1027,54 @@ func (x *OutgoingRPCMessage_Data_Type) GetMessageType() MessageType { return MessageType_UNKNOWN_MESSAGE_TYPE } +type OutgoingRPCResponse_SomeIdentifier struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // 1 -> unknown + SomeNumber string `protobuf:"bytes,2,opt,name=someNumber,proto3" json:"someNumber,omitempty"` +} + +func (x *OutgoingRPCResponse_SomeIdentifier) Reset() { + *x = OutgoingRPCResponse_SomeIdentifier{} + if protoimpl.UnsafeEnabled { + mi := &file_rpc_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *OutgoingRPCResponse_SomeIdentifier) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*OutgoingRPCResponse_SomeIdentifier) ProtoMessage() {} + +func (x *OutgoingRPCResponse_SomeIdentifier) ProtoReflect() protoreflect.Message { + mi := &file_rpc_proto_msgTypes[10] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use OutgoingRPCResponse_SomeIdentifier.ProtoReflect.Descriptor instead. +func (*OutgoingRPCResponse_SomeIdentifier) Descriptor() ([]byte, []int) { + return file_rpc_proto_rawDescGZIP(), []int{6, 0} +} + +func (x *OutgoingRPCResponse_SomeIdentifier) GetSomeNumber() string { + if x != nil { + return x.SomeNumber + } + return "" +} + var File_rpc_proto protoreflect.FileDescriptor //go:embed rpc.pb.raw @@ -997,43 +1093,45 @@ func file_rpc_proto_rawDescGZIP() []byte { } var file_rpc_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_rpc_proto_msgTypes = make([]protoimpl.MessageInfo, 9) +var file_rpc_proto_msgTypes = make([]protoimpl.MessageInfo, 11) var file_rpc_proto_goTypes = []interface{}{ - (BugleRoute)(0), // 0: rpc.BugleRoute - (ActionType)(0), // 1: rpc.ActionType - (MessageType)(0), // 2: rpc.MessageType - (*StartAckMessage)(nil), // 3: rpc.StartAckMessage - (*LongPollingPayload)(nil), // 4: rpc.LongPollingPayload - (*IncomingRPCMessage)(nil), // 5: rpc.IncomingRPCMessage - (*RPCMessageData)(nil), // 6: rpc.RPCMessageData - (*OutgoingRPCMessage)(nil), // 7: rpc.OutgoingRPCMessage - (*OutgoingRPCData)(nil), // 8: rpc.OutgoingRPCData - (*OutgoingRPCMessage_Auth)(nil), // 9: rpc.OutgoingRPCMessage.Auth - (*OutgoingRPCMessage_Data)(nil), // 10: rpc.OutgoingRPCMessage.Data - (*OutgoingRPCMessage_Data_Type)(nil), // 11: rpc.OutgoingRPCMessage.Data.Type - (*EmptyArr)(nil), // 12: util.EmptyArr - (*Device)(nil), // 13: authentication.Device - (*ConfigVersion)(nil), // 14: authentication.ConfigVersion + (BugleRoute)(0), // 0: rpc.BugleRoute + (ActionType)(0), // 1: rpc.ActionType + (MessageType)(0), // 2: rpc.MessageType + (*StartAckMessage)(nil), // 3: rpc.StartAckMessage + (*LongPollingPayload)(nil), // 4: rpc.LongPollingPayload + (*IncomingRPCMessage)(nil), // 5: rpc.IncomingRPCMessage + (*RPCMessageData)(nil), // 6: rpc.RPCMessageData + (*OutgoingRPCMessage)(nil), // 7: rpc.OutgoingRPCMessage + (*OutgoingRPCData)(nil), // 8: rpc.OutgoingRPCData + (*OutgoingRPCResponse)(nil), // 9: rpc.OutgoingRPCResponse + (*OutgoingRPCMessage_Auth)(nil), // 10: rpc.OutgoingRPCMessage.Auth + (*OutgoingRPCMessage_Data)(nil), // 11: rpc.OutgoingRPCMessage.Data + (*OutgoingRPCMessage_Data_Type)(nil), // 12: rpc.OutgoingRPCMessage.Data.Type + (*OutgoingRPCResponse_SomeIdentifier)(nil), // 13: rpc.OutgoingRPCResponse.SomeIdentifier + (*EmptyArr)(nil), // 14: util.EmptyArr + (*Device)(nil), // 15: authentication.Device + (*ConfigVersion)(nil), // 16: authentication.ConfigVersion } var file_rpc_proto_depIdxs = []int32{ 5, // 0: rpc.LongPollingPayload.data:type_name -> rpc.IncomingRPCMessage - 12, // 1: rpc.LongPollingPayload.heartbeat:type_name -> util.EmptyArr + 14, // 1: rpc.LongPollingPayload.heartbeat:type_name -> util.EmptyArr 3, // 2: rpc.LongPollingPayload.ack:type_name -> rpc.StartAckMessage - 12, // 3: rpc.LongPollingPayload.startRead:type_name -> util.EmptyArr + 14, // 3: rpc.LongPollingPayload.startRead:type_name -> util.EmptyArr 0, // 4: rpc.IncomingRPCMessage.bugleRoute:type_name -> rpc.BugleRoute 2, // 5: rpc.IncomingRPCMessage.messageType:type_name -> rpc.MessageType - 13, // 6: rpc.IncomingRPCMessage.mobile:type_name -> authentication.Device - 13, // 7: rpc.IncomingRPCMessage.browser:type_name -> authentication.Device + 15, // 6: rpc.IncomingRPCMessage.mobile:type_name -> authentication.Device + 15, // 7: rpc.IncomingRPCMessage.browser:type_name -> authentication.Device 1, // 8: rpc.RPCMessageData.action:type_name -> rpc.ActionType - 13, // 9: rpc.OutgoingRPCMessage.mobile:type_name -> authentication.Device - 10, // 10: rpc.OutgoingRPCMessage.data:type_name -> rpc.OutgoingRPCMessage.Data - 9, // 11: rpc.OutgoingRPCMessage.auth:type_name -> rpc.OutgoingRPCMessage.Auth - 12, // 12: rpc.OutgoingRPCMessage.emptyArr:type_name -> util.EmptyArr + 15, // 9: rpc.OutgoingRPCMessage.mobile:type_name -> authentication.Device + 11, // 10: rpc.OutgoingRPCMessage.data:type_name -> rpc.OutgoingRPCMessage.Data + 10, // 11: rpc.OutgoingRPCMessage.auth:type_name -> rpc.OutgoingRPCMessage.Auth + 14, // 12: rpc.OutgoingRPCMessage.emptyArr:type_name -> util.EmptyArr 1, // 13: rpc.OutgoingRPCData.action:type_name -> rpc.ActionType - 14, // 14: rpc.OutgoingRPCMessage.Auth.configVersion:type_name -> authentication.ConfigVersion + 16, // 14: rpc.OutgoingRPCMessage.Auth.configVersion:type_name -> authentication.ConfigVersion 0, // 15: rpc.OutgoingRPCMessage.Data.bugleRoute:type_name -> rpc.BugleRoute - 11, // 16: rpc.OutgoingRPCMessage.Data.messageTypeData:type_name -> rpc.OutgoingRPCMessage.Data.Type - 12, // 17: rpc.OutgoingRPCMessage.Data.Type.emptyArr:type_name -> util.EmptyArr + 12, // 16: rpc.OutgoingRPCMessage.Data.messageTypeData:type_name -> rpc.OutgoingRPCMessage.Data.Type + 14, // 17: rpc.OutgoingRPCMessage.Data.Type.emptyArr:type_name -> util.EmptyArr 2, // 18: rpc.OutgoingRPCMessage.Data.Type.messageType:type_name -> rpc.MessageType 19, // [19:19] is the sub-list for method output_type 19, // [19:19] is the sub-list for method input_type @@ -1123,7 +1221,7 @@ func file_rpc_proto_init() { } } file_rpc_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*OutgoingRPCMessage_Auth); i { + switch v := v.(*OutgoingRPCResponse); i { case 0: return &v.state case 1: @@ -1135,7 +1233,7 @@ func file_rpc_proto_init() { } } file_rpc_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*OutgoingRPCMessage_Data); i { + switch v := v.(*OutgoingRPCMessage_Auth); i { case 0: return &v.state case 1: @@ -1147,6 +1245,18 @@ func file_rpc_proto_init() { } } file_rpc_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*OutgoingRPCMessage_Data); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_rpc_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*OutgoingRPCMessage_Data_Type); i { case 0: return &v.state @@ -1158,16 +1268,29 @@ func file_rpc_proto_init() { return nil } } + file_rpc_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*OutgoingRPCResponse_SomeIdentifier); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } file_rpc_proto_msgTypes[0].OneofWrappers = []interface{}{} file_rpc_proto_msgTypes[1].OneofWrappers = []interface{}{} + file_rpc_proto_msgTypes[6].OneofWrappers = []interface{}{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_rpc_proto_rawDesc, NumEnums: 3, - NumMessages: 9, + NumMessages: 11, NumExtensions: 0, NumServices: 0, }, diff --git a/libgm/gmproto/rpc.pb.raw b/libgm/gmproto/rpc.pb.raw index 7eb1165..51bf571 100644 Binary files a/libgm/gmproto/rpc.pb.raw and b/libgm/gmproto/rpc.pb.raw differ diff --git a/libgm/gmproto/rpc.proto b/libgm/gmproto/rpc.proto index 493f20b..f6946d9 100644 --- a/libgm/gmproto/rpc.proto +++ b/libgm/gmproto/rpc.proto @@ -85,6 +85,16 @@ message OutgoingRPCData { string sessionID = 6; } +message OutgoingRPCResponse { + message SomeIdentifier { + // 1 -> unknown + string someNumber = 2; + } + + // This is not present for AckMessage responses, only for SendMessage + optional string timestamp = 2; +} + enum BugleRoute { Unknown = 0; DataEvent = 19; diff --git a/libgm/http.go b/libgm/http.go new file mode 100644 index 0000000..b49de6d --- /dev/null +++ b/libgm/http.go @@ -0,0 +1,76 @@ +package libgm + +import ( + "bytes" + "fmt" + "io" + "mime" + "net/http" + + "google.golang.org/protobuf/proto" + + "go.mau.fi/mautrix-gmessages/libgm/events" + "go.mau.fi/mautrix-gmessages/libgm/pblite" + "go.mau.fi/mautrix-gmessages/libgm/util" +) + +const ContentTypeProtobuf = "application/x-protobuf" +const ContentTypePBLite = "application/json+protobuf" + +func (c *Client) makeProtobufHTTPRequest(url string, data proto.Message, contentType string) (*http.Response, error) { + var body []byte + var err error + switch contentType { + case ContentTypeProtobuf: + body, err = proto.Marshal(data) + case ContentTypePBLite: + body, err = pblite.Marshal(data) + default: + return nil, fmt.Errorf("unknown request content type %s", contentType) + } + if err != nil { + return nil, err + } + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + util.BuildRelayHeaders(req, contentType, "*/*") + res, reqErr := c.http.Do(req) + if reqErr != nil { + return res, reqErr + } + return res, nil +} + +func typedHTTPResponse[T proto.Message](resp *http.Response, err error) (parsed T, retErr error) { + if err != nil { + retErr = err + return + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + retErr = events.HTTPError{Resp: resp} + return + } + body, err := io.ReadAll(resp.Body) + if err != nil { + retErr = fmt.Errorf("failed to read response body: %w", err) + return + } + contentType, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + retErr = fmt.Errorf("failed to parse content-type: %w", err) + return + } + parsed = parsed.ProtoReflect().New().Interface().(T) + switch contentType { + case ContentTypeProtobuf: + retErr = proto.Unmarshal(body, parsed) + case ContentTypePBLite: + retErr = pblite.Unmarshal(body, parsed) + default: + retErr = fmt.Errorf("unknown content type %s in response", contentType) + } + return +} diff --git a/libgm/pair.go b/libgm/pair.go index 3e57189..ab09a1e 100644 --- a/libgm/pair.go +++ b/libgm/pair.go @@ -1,12 +1,9 @@ package libgm import ( - "bytes" "crypto/x509" "encoding/base64" "fmt" - "io" - "net/http" "github.com/google/uuid" "google.golang.org/protobuf/proto" @@ -68,26 +65,13 @@ func (c *Client) completePairing(data *gmproto.PairedData) { } } -func (c *Client) makeRelayRequest(url string, body []byte) (*http.Response, error) { - req, err := http.NewRequest("POST", url, bytes.NewReader(body)) - if err != nil { - return nil, err - } - util.BuildRelayHeaders(req, "application/x-protobuf", "*/*") - res, reqErr := c.http.Do(req) - if reqErr != nil { - return res, reqErr - } - return res, nil -} - func (c *Client) RegisterPhoneRelay() (*gmproto.RegisterPhoneRelayResponse, error) { key, err := x509.MarshalPKIXPublicKey(c.AuthData.RefreshKey.GetPublicKey()) if err != nil { return nil, err } - body, err := proto.Marshal(&gmproto.AuthenticationContainer{ + payload := &gmproto.AuthenticationContainer{ AuthMessage: &gmproto.AuthMessage{ RequestID: uuid.NewString(), Network: &util.Network, @@ -102,50 +86,24 @@ func (c *Client) RegisterPhoneRelay() (*gmproto.RegisterPhoneRelayResponse, erro }, }, }, - }) - if err != nil { - return nil, err } - relayResponse, reqErr := c.makeRelayRequest(util.RegisterPhoneRelayURL, body) - if reqErr != nil { - return nil, err - } - responseBody, err := io.ReadAll(relayResponse.Body) - if err != nil { - return nil, err - } - relayResponse.Body.Close() - res := &gmproto.RegisterPhoneRelayResponse{} - err = proto.Unmarshal(responseBody, res) - if err != nil { - return nil, err - } - return res, err + return typedHTTPResponse[*gmproto.RegisterPhoneRelayResponse]( + c.makeProtobufHTTPRequest(util.RegisterPhoneRelayURL, payload, ContentTypeProtobuf), + ) } func (c *Client) RefreshPhoneRelay() (string, error) { - body, err := proto.Marshal(&gmproto.AuthenticationContainer{ + payload := &gmproto.AuthenticationContainer{ AuthMessage: &gmproto.AuthMessage{ RequestID: uuid.NewString(), Network: &util.Network, TachyonAuthToken: c.AuthData.TachyonAuthToken, ConfigVersion: util.ConfigMessage, }, - }) - if err != nil { - return "", err } - relayResponse, err := c.makeRelayRequest(util.RefreshPhoneRelayURL, body) - if err != nil { - return "", err - } - responseBody, err := io.ReadAll(relayResponse.Body) - defer relayResponse.Body.Close() - if err != nil { - return "", err - } - res := &gmproto.RefreshPhoneRelayResponse{} - err = proto.Unmarshal(responseBody, res) + res, err := typedHTTPResponse[*gmproto.RefreshPhoneRelayResponse]( + c.makeProtobufHTTPRequest(util.RefreshPhoneRelayURL, payload, ContentTypeProtobuf), + ) if err != nil { return "", err } @@ -157,61 +115,31 @@ func (c *Client) RefreshPhoneRelay() (string, error) { } func (c *Client) GetWebEncryptionKey() (*gmproto.WebEncryptionKeyResponse, error) { - body, err := proto.Marshal(&gmproto.AuthenticationContainer{ + payload := &gmproto.AuthenticationContainer{ AuthMessage: &gmproto.AuthMessage{ RequestID: uuid.NewString(), TachyonAuthToken: c.AuthData.TachyonAuthToken, ConfigVersion: util.ConfigMessage, }, - }) - if err != nil { - return nil, err } - webKeyResponse, err := c.makeRelayRequest(util.GetWebEncryptionKeyURL, body) - if err != nil { - return nil, err - } - responseBody, err := io.ReadAll(webKeyResponse.Body) - defer webKeyResponse.Body.Close() - if err != nil { - return nil, err - } - parsedResponse := &gmproto.WebEncryptionKeyResponse{} - err = proto.Unmarshal(responseBody, parsedResponse) - if err != nil { - return nil, err - } - return parsedResponse, nil + return typedHTTPResponse[*gmproto.WebEncryptionKeyResponse]( + c.makeProtobufHTTPRequest(util.GetWebEncryptionKeyURL, payload, ContentTypeProtobuf), + ) } func (c *Client) Unpair() (*gmproto.RevokeRelayPairingResponse, error) { if c.AuthData.TachyonAuthToken == nil || c.AuthData.Browser == nil { return nil, nil } - payload, err := proto.Marshal(&gmproto.RevokeRelayPairingRequest{ + payload := &gmproto.RevokeRelayPairingRequest{ AuthMessage: &gmproto.AuthMessage{ RequestID: uuid.NewString(), TachyonAuthToken: c.AuthData.TachyonAuthToken, ConfigVersion: util.ConfigMessage, }, Browser: c.AuthData.Browser, - }) - if err != nil { - return nil, err } - revokeResp, err := c.makeRelayRequest(util.RevokeRelayPairingURL, payload) - if err != nil { - return nil, err - } - responseBody, err := io.ReadAll(revokeResp.Body) - defer revokeResp.Body.Close() - if err != nil { - return nil, err - } - parsedResponse := &gmproto.RevokeRelayPairingResponse{} - err = proto.Unmarshal(responseBody, parsedResponse) - if err != nil { - return nil, err - } - return parsedResponse, nil + return typedHTTPResponse[*gmproto.RevokeRelayPairingResponse]( + c.makeProtobufHTTPRequest(util.RevokeRelayPairingURL, payload, ContentTypeProtobuf), + ) } diff --git a/libgm/rpc.go b/libgm/rpc.go index 4cd1cc2..dbc90c4 100644 --- a/libgm/rpc.go +++ b/libgm/rpc.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "io" - "net/http" "time" "github.com/google/uuid" @@ -23,7 +22,6 @@ import ( type RPC struct { client *Client - http *http.Client conn io.ReadCloser stopping bool listenID int @@ -49,7 +47,7 @@ func (r *RPC) ListenReceiveMessages(loggedIn bool) { return } r.client.Logger.Debug().Msg("Starting new long-polling request") - receivePayload, err := pblite.Marshal(&gmproto.ReceiveMessagesRequest{ + payload := &gmproto.ReceiveMessagesRequest{ Auth: &gmproto.AuthMessage{ RequestID: listenReqID, TachyonAuthToken: r.client.AuthData.TachyonAuthToken, @@ -58,19 +56,11 @@ func (r *RPC) ListenReceiveMessages(loggedIn bool) { Unknown: &gmproto.ReceiveMessagesRequest_UnknownEmptyObject2{ Unknown: &gmproto.ReceiveMessagesRequest_UnknownEmptyObject1{}, }, - }) - if err != nil { - panic(fmt.Errorf("Error marshaling request: %v", err)) } - req, err := http.NewRequest(http.MethodPost, util.ReceiveMessagesURL, bytes.NewReader(receivePayload)) + resp, err := r.client.makeProtobufHTTPRequest(util.ReceiveMessagesURL, payload, ContentTypePBLite) if err != nil { - panic(fmt.Errorf("Error creating request: %v", err)) - } - util.BuildRelayHeaders(req, "application/json+protobuf", "*/*") - resp, reqErr := r.http.Do(req) - if reqErr != nil { if loggedIn { - r.client.triggerEvent(&events.ListenTemporaryError{Error: reqErr}) + r.client.triggerEvent(&events.ListenTemporaryError{Error: err}) } errored = true r.client.Logger.Err(err).Msg("Error making listen request, retrying in 5 seconds") @@ -203,16 +193,3 @@ func (r *RPC) CloseConnection() { r.conn = nil } } - -func (r *RPC) sendMessageRequest(url string, payload []byte) (*http.Response, error) { - req, err := http.NewRequest("POST", url, bytes.NewReader(payload)) - if err != nil { - return nil, fmt.Errorf("error creating request: %w", err) - } - util.BuildRelayHeaders(req, "application/json+protobuf", "*/*") - resp, reqErr := r.client.http.Do(req) - if reqErr != nil { - return nil, fmt.Errorf("error making request: %w", err) - } - return resp, reqErr -} diff --git a/libgm/session_handler.go b/libgm/session_handler.go index 8cd2e02..e03ace1 100644 --- a/libgm/session_handler.go +++ b/libgm/session_handler.go @@ -10,8 +10,6 @@ import ( "golang.org/x/exp/slices" "google.golang.org/protobuf/proto" - "go.mau.fi/mautrix-gmessages/libgm/pblite" - "go.mau.fi/mautrix-gmessages/libgm/gmproto" "go.mau.fi/mautrix-gmessages/libgm/util" ) @@ -41,7 +39,9 @@ func (s *SessionHandler) sendMessageNoResponse(params SendMessageParams) error { return err } - _, err = s.client.rpc.sendMessageRequest(util.SendMessageURL, payload) + _, err = typedHTTPResponse[*gmproto.OutgoingRPCResponse]( + s.client.makeProtobufHTTPRequest(util.SendMessageURL, payload, ContentTypePBLite), + ) return err } @@ -52,10 +52,12 @@ func (s *SessionHandler) sendAsyncMessage(params SendMessageParams) (<-chan *Inc } ch := s.waitResponse(requestID) - _, reqErr := s.client.rpc.sendMessageRequest(util.SendMessageURL, payload) - if reqErr != nil { + _, err = typedHTTPResponse[*gmproto.OutgoingRPCResponse]( + s.client.makeProtobufHTTPRequest(util.SendMessageURL, payload, ContentTypePBLite), + ) + if err != nil { s.cancelResponse(requestID, ch) - return nil, reqErr + return nil, err } return ch, nil } @@ -142,7 +144,7 @@ type SendMessageParams struct { MessageType gmproto.MessageType } -func (s *SessionHandler) buildMessage(params SendMessageParams) (string, []byte, error) { +func (s *SessionHandler) buildMessage(params SendMessageParams) (string, proto.Message, error) { var requestID string var err error sessionID := s.client.sessionHandler.sessionID @@ -199,9 +201,7 @@ func (s *SessionHandler) buildMessage(params SendMessageParams) (string, []byte, return "", nil, err } - var marshaledMessage []byte - marshaledMessage, err = pblite.Marshal(message) - return requestID, marshaledMessage, err + return requestID, message, err } func (s *SessionHandler) queueMessageAck(messageID string) { @@ -243,7 +243,7 @@ func (s *SessionHandler) sendAckRequest() { Device: s.client.AuthData.Browser, } } - ackMessagePayload := &gmproto.AckMessageRequest{ + payload := &gmproto.AckMessageRequest{ AuthData: &gmproto.AuthMessage{ RequestID: uuid.NewString(), TachyonAuthToken: s.client.AuthData.TachyonAuthToken, @@ -252,13 +252,13 @@ func (s *SessionHandler) sendAckRequest() { EmptyArr: &gmproto.EmptyArr{}, Acks: ackMessages, } - jsonData, err := pblite.Marshal(ackMessagePayload) + _, err := typedHTTPResponse[*gmproto.OutgoingRPCResponse]( + s.client.makeProtobufHTTPRequest(util.AckMessagesURL, payload, ContentTypePBLite), + ) if err != nil { - panic(err) + // TODO retry? + s.client.Logger.Err(err).Strs("message_ids", dataToAck).Msg("Failed to send acks") + } else { + s.client.Logger.Debug().Strs("message_ids", dataToAck).Msg("Sent acks") } - _, err = s.client.rpc.sendMessageRequest(util.AckMessagesURL, jsonData) - if err != nil { - panic(err) - } - s.client.Logger.Debug().Strs("message_ids", dataToAck).Msg("Sent acks") }