Refactor all protobuf HTTP request sending into shared functions

This commit is contained in:
Tulir Asokan 2023-07-19 13:58:53 +03:00
parent 3e2348447a
commit d99da61869
9 changed files with 293 additions and 197 deletions

View file

@ -17,7 +17,6 @@ import (
"go.mau.fi/mautrix-gmessages/libgm/crypto" "go.mau.fi/mautrix-gmessages/libgm/crypto"
"go.mau.fi/mautrix-gmessages/libgm/events" "go.mau.fi/mautrix-gmessages/libgm/events"
"go.mau.fi/mautrix-gmessages/libgm/gmproto" "go.mau.fi/mautrix-gmessages/libgm/gmproto"
"go.mau.fi/mautrix-gmessages/libgm/pblite"
"go.mau.fi/mautrix-gmessages/libgm/util" "go.mau.fi/mautrix-gmessages/libgm/util"
) )
@ -75,8 +74,7 @@ func NewClient(authData *AuthData, logger zerolog.Logger) *Client {
http: &http.Client{}, http: &http.Client{},
} }
sessionHandler.client = cli sessionHandler.client = cli
rpc := &RPC{client: cli, http: &http.Client{Transport: &http.Transport{Proxy: cli.proxy}}} cli.rpc = &RPC{client: cli}
cli.rpc = rpc
cli.FetchConfigVersion() cli.FetchConfigVersion()
return cli return cli
} }
@ -234,7 +232,7 @@ func (c *Client) refreshAuthToken() error {
return err return err
} }
payload, err := pblite.Marshal(&gmproto.RegisterRefreshRequest{ payload := &gmproto.RegisterRefreshRequest{
MessageAuth: &gmproto.AuthMessage{ MessageAuth: &gmproto.AuthMessage{
RequestID: requestID, RequestID: requestID,
TachyonAuthToken: c.AuthData.TachyonAuthToken, TachyonAuthToken: c.AuthData.TachyonAuthToken,
@ -245,37 +243,18 @@ func (c *Client) refreshAuthToken() error {
Signature: sig, Signature: sig,
EmptyRefreshArr: &gmproto.RegisterRefreshRequest_NestedEmptyArr{EmptyArr: &gmproto.EmptyArr{}}, EmptyRefreshArr: &gmproto.RegisterRefreshRequest_NestedEmptyArr{EmptyArr: &gmproto.EmptyArr{}},
MessageType: 2, // hmm MessageType: 2, // hmm
}) }
resp, err := typedHTTPResponse[*gmproto.RegisterRefreshResponse](
c.makeProtobufHTTPRequest(util.RegisterRefreshURL, payload, ContentTypePBLite),
)
if err != nil { if err != nil {
return err return err
} }
refreshResponse, requestErr := c.rpc.sendMessageRequest(util.RegisterRefreshURL, payload)
if requestErr != nil {
return requestErr
}
if refreshResponse.StatusCode == 401 {
return fmt.Errorf("failed to refresh auth token: unauthorized (try reauthenticating through qr code)")
}
if refreshResponse.StatusCode == 400 {
return fmt.Errorf("failed to refresh auth token: signature failed")
}
responseBody, readErr := io.ReadAll(refreshResponse.Body)
if readErr != nil {
return readErr
}
resp := &gmproto.RegisterRefreshResponse{}
deserializeErr := pblite.Unmarshal(responseBody, resp)
if deserializeErr != nil {
return deserializeErr
}
token := resp.GetTokenData().GetTachyonAuthToken() token := resp.GetTokenData().GetTachyonAuthToken()
if token == nil { if token == nil {
return fmt.Errorf("failed to refresh auth token: something happened") return fmt.Errorf("no tachyon auth token in refresh response")
} }
validFor, _ := strconv.ParseInt(resp.GetTokenData().GetValidFor(), 10, 64) validFor, _ := strconv.ParseInt(resp.GetTokenData().GetValidFor(), 10, 64)

View file

@ -27,6 +27,9 @@ type HTTPError struct {
} }
func (he HTTPError) Error() string { func (he HTTPError) Error() string {
if he.Action == "" {
return fmt.Sprintf("unexpected http %d", he.Resp.StatusCode)
}
return fmt.Sprintf("http %d while %s", he.Resp.StatusCode, he.Action) return fmt.Sprintf("http %d while %s", he.Resp.StatusCode, he.Action)
} }

View file

@ -789,6 +789,54 @@ func (x *OutgoingRPCData) GetSessionID() string {
return "" return ""
} }
type OutgoingRPCResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// This is not present for AckMessage responses, only for SendMessage
Timestamp *string `protobuf:"bytes,2,opt,name=timestamp,proto3,oneof" json:"timestamp,omitempty"`
}
func (x *OutgoingRPCResponse) Reset() {
*x = OutgoingRPCResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_rpc_proto_msgTypes[6]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *OutgoingRPCResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*OutgoingRPCResponse) ProtoMessage() {}
func (x *OutgoingRPCResponse) ProtoReflect() protoreflect.Message {
mi := &file_rpc_proto_msgTypes[6]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use OutgoingRPCResponse.ProtoReflect.Descriptor instead.
func (*OutgoingRPCResponse) Descriptor() ([]byte, []int) {
return file_rpc_proto_rawDescGZIP(), []int{6}
}
func (x *OutgoingRPCResponse) GetTimestamp() string {
if x != nil && x.Timestamp != nil {
return *x.Timestamp
}
return ""
}
type OutgoingRPCMessage_Auth struct { type OutgoingRPCMessage_Auth struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
@ -802,7 +850,7 @@ type OutgoingRPCMessage_Auth struct {
func (x *OutgoingRPCMessage_Auth) Reset() { func (x *OutgoingRPCMessage_Auth) Reset() {
*x = OutgoingRPCMessage_Auth{} *x = OutgoingRPCMessage_Auth{}
if protoimpl.UnsafeEnabled { if protoimpl.UnsafeEnabled {
mi := &file_rpc_proto_msgTypes[6] mi := &file_rpc_proto_msgTypes[7]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi) ms.StoreMessageInfo(mi)
} }
@ -815,7 +863,7 @@ func (x *OutgoingRPCMessage_Auth) String() string {
func (*OutgoingRPCMessage_Auth) ProtoMessage() {} func (*OutgoingRPCMessage_Auth) ProtoMessage() {}
func (x *OutgoingRPCMessage_Auth) ProtoReflect() protoreflect.Message { func (x *OutgoingRPCMessage_Auth) ProtoReflect() protoreflect.Message {
mi := &file_rpc_proto_msgTypes[6] mi := &file_rpc_proto_msgTypes[7]
if protoimpl.UnsafeEnabled && x != nil { if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil { if ms.LoadMessageInfo() == nil {
@ -867,7 +915,7 @@ type OutgoingRPCMessage_Data struct {
func (x *OutgoingRPCMessage_Data) Reset() { func (x *OutgoingRPCMessage_Data) Reset() {
*x = OutgoingRPCMessage_Data{} *x = OutgoingRPCMessage_Data{}
if protoimpl.UnsafeEnabled { if protoimpl.UnsafeEnabled {
mi := &file_rpc_proto_msgTypes[7] mi := &file_rpc_proto_msgTypes[8]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi) ms.StoreMessageInfo(mi)
} }
@ -880,7 +928,7 @@ func (x *OutgoingRPCMessage_Data) String() string {
func (*OutgoingRPCMessage_Data) ProtoMessage() {} func (*OutgoingRPCMessage_Data) ProtoMessage() {}
func (x *OutgoingRPCMessage_Data) ProtoReflect() protoreflect.Message { func (x *OutgoingRPCMessage_Data) ProtoReflect() protoreflect.Message {
mi := &file_rpc_proto_msgTypes[7] mi := &file_rpc_proto_msgTypes[8]
if protoimpl.UnsafeEnabled && x != nil { if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil { if ms.LoadMessageInfo() == nil {
@ -936,7 +984,7 @@ type OutgoingRPCMessage_Data_Type struct {
func (x *OutgoingRPCMessage_Data_Type) Reset() { func (x *OutgoingRPCMessage_Data_Type) Reset() {
*x = OutgoingRPCMessage_Data_Type{} *x = OutgoingRPCMessage_Data_Type{}
if protoimpl.UnsafeEnabled { if protoimpl.UnsafeEnabled {
mi := &file_rpc_proto_msgTypes[8] mi := &file_rpc_proto_msgTypes[9]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi) ms.StoreMessageInfo(mi)
} }
@ -949,7 +997,7 @@ func (x *OutgoingRPCMessage_Data_Type) String() string {
func (*OutgoingRPCMessage_Data_Type) ProtoMessage() {} func (*OutgoingRPCMessage_Data_Type) ProtoMessage() {}
func (x *OutgoingRPCMessage_Data_Type) ProtoReflect() protoreflect.Message { func (x *OutgoingRPCMessage_Data_Type) ProtoReflect() protoreflect.Message {
mi := &file_rpc_proto_msgTypes[8] mi := &file_rpc_proto_msgTypes[9]
if protoimpl.UnsafeEnabled && x != nil { if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil { if ms.LoadMessageInfo() == nil {
@ -979,6 +1027,54 @@ func (x *OutgoingRPCMessage_Data_Type) GetMessageType() MessageType {
return MessageType_UNKNOWN_MESSAGE_TYPE return MessageType_UNKNOWN_MESSAGE_TYPE
} }
type OutgoingRPCResponse_SomeIdentifier struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// 1 -> unknown
SomeNumber string `protobuf:"bytes,2,opt,name=someNumber,proto3" json:"someNumber,omitempty"`
}
func (x *OutgoingRPCResponse_SomeIdentifier) Reset() {
*x = OutgoingRPCResponse_SomeIdentifier{}
if protoimpl.UnsafeEnabled {
mi := &file_rpc_proto_msgTypes[10]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *OutgoingRPCResponse_SomeIdentifier) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*OutgoingRPCResponse_SomeIdentifier) ProtoMessage() {}
func (x *OutgoingRPCResponse_SomeIdentifier) ProtoReflect() protoreflect.Message {
mi := &file_rpc_proto_msgTypes[10]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use OutgoingRPCResponse_SomeIdentifier.ProtoReflect.Descriptor instead.
func (*OutgoingRPCResponse_SomeIdentifier) Descriptor() ([]byte, []int) {
return file_rpc_proto_rawDescGZIP(), []int{6, 0}
}
func (x *OutgoingRPCResponse_SomeIdentifier) GetSomeNumber() string {
if x != nil {
return x.SomeNumber
}
return ""
}
var File_rpc_proto protoreflect.FileDescriptor var File_rpc_proto protoreflect.FileDescriptor
//go:embed rpc.pb.raw //go:embed rpc.pb.raw
@ -997,43 +1093,45 @@ func file_rpc_proto_rawDescGZIP() []byte {
} }
var file_rpc_proto_enumTypes = make([]protoimpl.EnumInfo, 3) var file_rpc_proto_enumTypes = make([]protoimpl.EnumInfo, 3)
var file_rpc_proto_msgTypes = make([]protoimpl.MessageInfo, 9) var file_rpc_proto_msgTypes = make([]protoimpl.MessageInfo, 11)
var file_rpc_proto_goTypes = []interface{}{ var file_rpc_proto_goTypes = []interface{}{
(BugleRoute)(0), // 0: rpc.BugleRoute (BugleRoute)(0), // 0: rpc.BugleRoute
(ActionType)(0), // 1: rpc.ActionType (ActionType)(0), // 1: rpc.ActionType
(MessageType)(0), // 2: rpc.MessageType (MessageType)(0), // 2: rpc.MessageType
(*StartAckMessage)(nil), // 3: rpc.StartAckMessage (*StartAckMessage)(nil), // 3: rpc.StartAckMessage
(*LongPollingPayload)(nil), // 4: rpc.LongPollingPayload (*LongPollingPayload)(nil), // 4: rpc.LongPollingPayload
(*IncomingRPCMessage)(nil), // 5: rpc.IncomingRPCMessage (*IncomingRPCMessage)(nil), // 5: rpc.IncomingRPCMessage
(*RPCMessageData)(nil), // 6: rpc.RPCMessageData (*RPCMessageData)(nil), // 6: rpc.RPCMessageData
(*OutgoingRPCMessage)(nil), // 7: rpc.OutgoingRPCMessage (*OutgoingRPCMessage)(nil), // 7: rpc.OutgoingRPCMessage
(*OutgoingRPCData)(nil), // 8: rpc.OutgoingRPCData (*OutgoingRPCData)(nil), // 8: rpc.OutgoingRPCData
(*OutgoingRPCMessage_Auth)(nil), // 9: rpc.OutgoingRPCMessage.Auth (*OutgoingRPCResponse)(nil), // 9: rpc.OutgoingRPCResponse
(*OutgoingRPCMessage_Data)(nil), // 10: rpc.OutgoingRPCMessage.Data (*OutgoingRPCMessage_Auth)(nil), // 10: rpc.OutgoingRPCMessage.Auth
(*OutgoingRPCMessage_Data_Type)(nil), // 11: rpc.OutgoingRPCMessage.Data.Type (*OutgoingRPCMessage_Data)(nil), // 11: rpc.OutgoingRPCMessage.Data
(*EmptyArr)(nil), // 12: util.EmptyArr (*OutgoingRPCMessage_Data_Type)(nil), // 12: rpc.OutgoingRPCMessage.Data.Type
(*Device)(nil), // 13: authentication.Device (*OutgoingRPCResponse_SomeIdentifier)(nil), // 13: rpc.OutgoingRPCResponse.SomeIdentifier
(*ConfigVersion)(nil), // 14: authentication.ConfigVersion (*EmptyArr)(nil), // 14: util.EmptyArr
(*Device)(nil), // 15: authentication.Device
(*ConfigVersion)(nil), // 16: authentication.ConfigVersion
} }
var file_rpc_proto_depIdxs = []int32{ var file_rpc_proto_depIdxs = []int32{
5, // 0: rpc.LongPollingPayload.data:type_name -> rpc.IncomingRPCMessage 5, // 0: rpc.LongPollingPayload.data:type_name -> rpc.IncomingRPCMessage
12, // 1: rpc.LongPollingPayload.heartbeat:type_name -> util.EmptyArr 14, // 1: rpc.LongPollingPayload.heartbeat:type_name -> util.EmptyArr
3, // 2: rpc.LongPollingPayload.ack:type_name -> rpc.StartAckMessage 3, // 2: rpc.LongPollingPayload.ack:type_name -> rpc.StartAckMessage
12, // 3: rpc.LongPollingPayload.startRead:type_name -> util.EmptyArr 14, // 3: rpc.LongPollingPayload.startRead:type_name -> util.EmptyArr
0, // 4: rpc.IncomingRPCMessage.bugleRoute:type_name -> rpc.BugleRoute 0, // 4: rpc.IncomingRPCMessage.bugleRoute:type_name -> rpc.BugleRoute
2, // 5: rpc.IncomingRPCMessage.messageType:type_name -> rpc.MessageType 2, // 5: rpc.IncomingRPCMessage.messageType:type_name -> rpc.MessageType
13, // 6: rpc.IncomingRPCMessage.mobile:type_name -> authentication.Device 15, // 6: rpc.IncomingRPCMessage.mobile:type_name -> authentication.Device
13, // 7: rpc.IncomingRPCMessage.browser:type_name -> authentication.Device 15, // 7: rpc.IncomingRPCMessage.browser:type_name -> authentication.Device
1, // 8: rpc.RPCMessageData.action:type_name -> rpc.ActionType 1, // 8: rpc.RPCMessageData.action:type_name -> rpc.ActionType
13, // 9: rpc.OutgoingRPCMessage.mobile:type_name -> authentication.Device 15, // 9: rpc.OutgoingRPCMessage.mobile:type_name -> authentication.Device
10, // 10: rpc.OutgoingRPCMessage.data:type_name -> rpc.OutgoingRPCMessage.Data 11, // 10: rpc.OutgoingRPCMessage.data:type_name -> rpc.OutgoingRPCMessage.Data
9, // 11: rpc.OutgoingRPCMessage.auth:type_name -> rpc.OutgoingRPCMessage.Auth 10, // 11: rpc.OutgoingRPCMessage.auth:type_name -> rpc.OutgoingRPCMessage.Auth
12, // 12: rpc.OutgoingRPCMessage.emptyArr:type_name -> util.EmptyArr 14, // 12: rpc.OutgoingRPCMessage.emptyArr:type_name -> util.EmptyArr
1, // 13: rpc.OutgoingRPCData.action:type_name -> rpc.ActionType 1, // 13: rpc.OutgoingRPCData.action:type_name -> rpc.ActionType
14, // 14: rpc.OutgoingRPCMessage.Auth.configVersion:type_name -> authentication.ConfigVersion 16, // 14: rpc.OutgoingRPCMessage.Auth.configVersion:type_name -> authentication.ConfigVersion
0, // 15: rpc.OutgoingRPCMessage.Data.bugleRoute:type_name -> rpc.BugleRoute 0, // 15: rpc.OutgoingRPCMessage.Data.bugleRoute:type_name -> rpc.BugleRoute
11, // 16: rpc.OutgoingRPCMessage.Data.messageTypeData:type_name -> rpc.OutgoingRPCMessage.Data.Type 12, // 16: rpc.OutgoingRPCMessage.Data.messageTypeData:type_name -> rpc.OutgoingRPCMessage.Data.Type
12, // 17: rpc.OutgoingRPCMessage.Data.Type.emptyArr:type_name -> util.EmptyArr 14, // 17: rpc.OutgoingRPCMessage.Data.Type.emptyArr:type_name -> util.EmptyArr
2, // 18: rpc.OutgoingRPCMessage.Data.Type.messageType:type_name -> rpc.MessageType 2, // 18: rpc.OutgoingRPCMessage.Data.Type.messageType:type_name -> rpc.MessageType
19, // [19:19] is the sub-list for method output_type 19, // [19:19] is the sub-list for method output_type
19, // [19:19] is the sub-list for method input_type 19, // [19:19] is the sub-list for method input_type
@ -1123,7 +1221,7 @@ func file_rpc_proto_init() {
} }
} }
file_rpc_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { file_rpc_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*OutgoingRPCMessage_Auth); i { switch v := v.(*OutgoingRPCResponse); i {
case 0: case 0:
return &v.state return &v.state
case 1: case 1:
@ -1135,7 +1233,7 @@ func file_rpc_proto_init() {
} }
} }
file_rpc_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { file_rpc_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*OutgoingRPCMessage_Data); i { switch v := v.(*OutgoingRPCMessage_Auth); i {
case 0: case 0:
return &v.state return &v.state
case 1: case 1:
@ -1147,6 +1245,18 @@ func file_rpc_proto_init() {
} }
} }
file_rpc_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { file_rpc_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*OutgoingRPCMessage_Data); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_rpc_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*OutgoingRPCMessage_Data_Type); i { switch v := v.(*OutgoingRPCMessage_Data_Type); i {
case 0: case 0:
return &v.state return &v.state
@ -1158,16 +1268,29 @@ func file_rpc_proto_init() {
return nil return nil
} }
} }
file_rpc_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*OutgoingRPCResponse_SomeIdentifier); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
} }
file_rpc_proto_msgTypes[0].OneofWrappers = []interface{}{} file_rpc_proto_msgTypes[0].OneofWrappers = []interface{}{}
file_rpc_proto_msgTypes[1].OneofWrappers = []interface{}{} file_rpc_proto_msgTypes[1].OneofWrappers = []interface{}{}
file_rpc_proto_msgTypes[6].OneofWrappers = []interface{}{}
type x struct{} type x struct{}
out := protoimpl.TypeBuilder{ out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{ File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_rpc_proto_rawDesc, RawDescriptor: file_rpc_proto_rawDesc,
NumEnums: 3, NumEnums: 3,
NumMessages: 9, NumMessages: 11,
NumExtensions: 0, NumExtensions: 0,
NumServices: 0, NumServices: 0,
}, },

Binary file not shown.

View file

@ -85,6 +85,16 @@ message OutgoingRPCData {
string sessionID = 6; string sessionID = 6;
} }
message OutgoingRPCResponse {
message SomeIdentifier {
// 1 -> unknown
string someNumber = 2;
}
// This is not present for AckMessage responses, only for SendMessage
optional string timestamp = 2;
}
enum BugleRoute { enum BugleRoute {
Unknown = 0; Unknown = 0;
DataEvent = 19; DataEvent = 19;

76
libgm/http.go Normal file
View file

@ -0,0 +1,76 @@
package libgm
import (
"bytes"
"fmt"
"io"
"mime"
"net/http"
"google.golang.org/protobuf/proto"
"go.mau.fi/mautrix-gmessages/libgm/events"
"go.mau.fi/mautrix-gmessages/libgm/pblite"
"go.mau.fi/mautrix-gmessages/libgm/util"
)
const ContentTypeProtobuf = "application/x-protobuf"
const ContentTypePBLite = "application/json+protobuf"
func (c *Client) makeProtobufHTTPRequest(url string, data proto.Message, contentType string) (*http.Response, error) {
var body []byte
var err error
switch contentType {
case ContentTypeProtobuf:
body, err = proto.Marshal(data)
case ContentTypePBLite:
body, err = pblite.Marshal(data)
default:
return nil, fmt.Errorf("unknown request content type %s", contentType)
}
if err != nil {
return nil, err
}
req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return nil, err
}
util.BuildRelayHeaders(req, contentType, "*/*")
res, reqErr := c.http.Do(req)
if reqErr != nil {
return res, reqErr
}
return res, nil
}
func typedHTTPResponse[T proto.Message](resp *http.Response, err error) (parsed T, retErr error) {
if err != nil {
retErr = err
return
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
retErr = events.HTTPError{Resp: resp}
return
}
body, err := io.ReadAll(resp.Body)
if err != nil {
retErr = fmt.Errorf("failed to read response body: %w", err)
return
}
contentType, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
if err != nil {
retErr = fmt.Errorf("failed to parse content-type: %w", err)
return
}
parsed = parsed.ProtoReflect().New().Interface().(T)
switch contentType {
case ContentTypeProtobuf:
retErr = proto.Unmarshal(body, parsed)
case ContentTypePBLite:
retErr = pblite.Unmarshal(body, parsed)
default:
retErr = fmt.Errorf("unknown content type %s in response", contentType)
}
return
}

View file

@ -1,12 +1,9 @@
package libgm package libgm
import ( import (
"bytes"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io"
"net/http"
"github.com/google/uuid" "github.com/google/uuid"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
@ -68,26 +65,13 @@ func (c *Client) completePairing(data *gmproto.PairedData) {
} }
} }
func (c *Client) makeRelayRequest(url string, body []byte) (*http.Response, error) {
req, err := http.NewRequest("POST", url, bytes.NewReader(body))
if err != nil {
return nil, err
}
util.BuildRelayHeaders(req, "application/x-protobuf", "*/*")
res, reqErr := c.http.Do(req)
if reqErr != nil {
return res, reqErr
}
return res, nil
}
func (c *Client) RegisterPhoneRelay() (*gmproto.RegisterPhoneRelayResponse, error) { func (c *Client) RegisterPhoneRelay() (*gmproto.RegisterPhoneRelayResponse, error) {
key, err := x509.MarshalPKIXPublicKey(c.AuthData.RefreshKey.GetPublicKey()) key, err := x509.MarshalPKIXPublicKey(c.AuthData.RefreshKey.GetPublicKey())
if err != nil { if err != nil {
return nil, err return nil, err
} }
body, err := proto.Marshal(&gmproto.AuthenticationContainer{ payload := &gmproto.AuthenticationContainer{
AuthMessage: &gmproto.AuthMessage{ AuthMessage: &gmproto.AuthMessage{
RequestID: uuid.NewString(), RequestID: uuid.NewString(),
Network: &util.Network, Network: &util.Network,
@ -102,50 +86,24 @@ func (c *Client) RegisterPhoneRelay() (*gmproto.RegisterPhoneRelayResponse, erro
}, },
}, },
}, },
})
if err != nil {
return nil, err
} }
relayResponse, reqErr := c.makeRelayRequest(util.RegisterPhoneRelayURL, body) return typedHTTPResponse[*gmproto.RegisterPhoneRelayResponse](
if reqErr != nil { c.makeProtobufHTTPRequest(util.RegisterPhoneRelayURL, payload, ContentTypeProtobuf),
return nil, err )
}
responseBody, err := io.ReadAll(relayResponse.Body)
if err != nil {
return nil, err
}
relayResponse.Body.Close()
res := &gmproto.RegisterPhoneRelayResponse{}
err = proto.Unmarshal(responseBody, res)
if err != nil {
return nil, err
}
return res, err
} }
func (c *Client) RefreshPhoneRelay() (string, error) { func (c *Client) RefreshPhoneRelay() (string, error) {
body, err := proto.Marshal(&gmproto.AuthenticationContainer{ payload := &gmproto.AuthenticationContainer{
AuthMessage: &gmproto.AuthMessage{ AuthMessage: &gmproto.AuthMessage{
RequestID: uuid.NewString(), RequestID: uuid.NewString(),
Network: &util.Network, Network: &util.Network,
TachyonAuthToken: c.AuthData.TachyonAuthToken, TachyonAuthToken: c.AuthData.TachyonAuthToken,
ConfigVersion: util.ConfigMessage, ConfigVersion: util.ConfigMessage,
}, },
})
if err != nil {
return "", err
} }
relayResponse, err := c.makeRelayRequest(util.RefreshPhoneRelayURL, body) res, err := typedHTTPResponse[*gmproto.RefreshPhoneRelayResponse](
if err != nil { c.makeProtobufHTTPRequest(util.RefreshPhoneRelayURL, payload, ContentTypeProtobuf),
return "", err )
}
responseBody, err := io.ReadAll(relayResponse.Body)
defer relayResponse.Body.Close()
if err != nil {
return "", err
}
res := &gmproto.RefreshPhoneRelayResponse{}
err = proto.Unmarshal(responseBody, res)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -157,61 +115,31 @@ func (c *Client) RefreshPhoneRelay() (string, error) {
} }
func (c *Client) GetWebEncryptionKey() (*gmproto.WebEncryptionKeyResponse, error) { func (c *Client) GetWebEncryptionKey() (*gmproto.WebEncryptionKeyResponse, error) {
body, err := proto.Marshal(&gmproto.AuthenticationContainer{ payload := &gmproto.AuthenticationContainer{
AuthMessage: &gmproto.AuthMessage{ AuthMessage: &gmproto.AuthMessage{
RequestID: uuid.NewString(), RequestID: uuid.NewString(),
TachyonAuthToken: c.AuthData.TachyonAuthToken, TachyonAuthToken: c.AuthData.TachyonAuthToken,
ConfigVersion: util.ConfigMessage, ConfigVersion: util.ConfigMessage,
}, },
})
if err != nil {
return nil, err
} }
webKeyResponse, err := c.makeRelayRequest(util.GetWebEncryptionKeyURL, body) return typedHTTPResponse[*gmproto.WebEncryptionKeyResponse](
if err != nil { c.makeProtobufHTTPRequest(util.GetWebEncryptionKeyURL, payload, ContentTypeProtobuf),
return nil, err )
}
responseBody, err := io.ReadAll(webKeyResponse.Body)
defer webKeyResponse.Body.Close()
if err != nil {
return nil, err
}
parsedResponse := &gmproto.WebEncryptionKeyResponse{}
err = proto.Unmarshal(responseBody, parsedResponse)
if err != nil {
return nil, err
}
return parsedResponse, nil
} }
func (c *Client) Unpair() (*gmproto.RevokeRelayPairingResponse, error) { func (c *Client) Unpair() (*gmproto.RevokeRelayPairingResponse, error) {
if c.AuthData.TachyonAuthToken == nil || c.AuthData.Browser == nil { if c.AuthData.TachyonAuthToken == nil || c.AuthData.Browser == nil {
return nil, nil return nil, nil
} }
payload, err := proto.Marshal(&gmproto.RevokeRelayPairingRequest{ payload := &gmproto.RevokeRelayPairingRequest{
AuthMessage: &gmproto.AuthMessage{ AuthMessage: &gmproto.AuthMessage{
RequestID: uuid.NewString(), RequestID: uuid.NewString(),
TachyonAuthToken: c.AuthData.TachyonAuthToken, TachyonAuthToken: c.AuthData.TachyonAuthToken,
ConfigVersion: util.ConfigMessage, ConfigVersion: util.ConfigMessage,
}, },
Browser: c.AuthData.Browser, Browser: c.AuthData.Browser,
})
if err != nil {
return nil, err
} }
revokeResp, err := c.makeRelayRequest(util.RevokeRelayPairingURL, payload) return typedHTTPResponse[*gmproto.RevokeRelayPairingResponse](
if err != nil { c.makeProtobufHTTPRequest(util.RevokeRelayPairingURL, payload, ContentTypeProtobuf),
return nil, err )
}
responseBody, err := io.ReadAll(revokeResp.Body)
defer revokeResp.Body.Close()
if err != nil {
return nil, err
}
parsedResponse := &gmproto.RevokeRelayPairingResponse{}
err = proto.Unmarshal(responseBody, parsedResponse)
if err != nil {
return nil, err
}
return parsedResponse, nil
} }

View file

@ -8,7 +8,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net/http"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
@ -23,7 +22,6 @@ import (
type RPC struct { type RPC struct {
client *Client client *Client
http *http.Client
conn io.ReadCloser conn io.ReadCloser
stopping bool stopping bool
listenID int listenID int
@ -49,7 +47,7 @@ func (r *RPC) ListenReceiveMessages(loggedIn bool) {
return return
} }
r.client.Logger.Debug().Msg("Starting new long-polling request") r.client.Logger.Debug().Msg("Starting new long-polling request")
receivePayload, err := pblite.Marshal(&gmproto.ReceiveMessagesRequest{ payload := &gmproto.ReceiveMessagesRequest{
Auth: &gmproto.AuthMessage{ Auth: &gmproto.AuthMessage{
RequestID: listenReqID, RequestID: listenReqID,
TachyonAuthToken: r.client.AuthData.TachyonAuthToken, TachyonAuthToken: r.client.AuthData.TachyonAuthToken,
@ -58,19 +56,11 @@ func (r *RPC) ListenReceiveMessages(loggedIn bool) {
Unknown: &gmproto.ReceiveMessagesRequest_UnknownEmptyObject2{ Unknown: &gmproto.ReceiveMessagesRequest_UnknownEmptyObject2{
Unknown: &gmproto.ReceiveMessagesRequest_UnknownEmptyObject1{}, Unknown: &gmproto.ReceiveMessagesRequest_UnknownEmptyObject1{},
}, },
})
if err != nil {
panic(fmt.Errorf("Error marshaling request: %v", err))
} }
req, err := http.NewRequest(http.MethodPost, util.ReceiveMessagesURL, bytes.NewReader(receivePayload)) resp, err := r.client.makeProtobufHTTPRequest(util.ReceiveMessagesURL, payload, ContentTypePBLite)
if err != nil { if err != nil {
panic(fmt.Errorf("Error creating request: %v", err))
}
util.BuildRelayHeaders(req, "application/json+protobuf", "*/*")
resp, reqErr := r.http.Do(req)
if reqErr != nil {
if loggedIn { if loggedIn {
r.client.triggerEvent(&events.ListenTemporaryError{Error: reqErr}) r.client.triggerEvent(&events.ListenTemporaryError{Error: err})
} }
errored = true errored = true
r.client.Logger.Err(err).Msg("Error making listen request, retrying in 5 seconds") r.client.Logger.Err(err).Msg("Error making listen request, retrying in 5 seconds")
@ -203,16 +193,3 @@ func (r *RPC) CloseConnection() {
r.conn = nil r.conn = nil
} }
} }
func (r *RPC) sendMessageRequest(url string, payload []byte) (*http.Response, error) {
req, err := http.NewRequest("POST", url, bytes.NewReader(payload))
if err != nil {
return nil, fmt.Errorf("error creating request: %w", err)
}
util.BuildRelayHeaders(req, "application/json+protobuf", "*/*")
resp, reqErr := r.client.http.Do(req)
if reqErr != nil {
return nil, fmt.Errorf("error making request: %w", err)
}
return resp, reqErr
}

View file

@ -10,8 +10,6 @@ import (
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"go.mau.fi/mautrix-gmessages/libgm/pblite"
"go.mau.fi/mautrix-gmessages/libgm/gmproto" "go.mau.fi/mautrix-gmessages/libgm/gmproto"
"go.mau.fi/mautrix-gmessages/libgm/util" "go.mau.fi/mautrix-gmessages/libgm/util"
) )
@ -41,7 +39,9 @@ func (s *SessionHandler) sendMessageNoResponse(params SendMessageParams) error {
return err return err
} }
_, err = s.client.rpc.sendMessageRequest(util.SendMessageURL, payload) _, err = typedHTTPResponse[*gmproto.OutgoingRPCResponse](
s.client.makeProtobufHTTPRequest(util.SendMessageURL, payload, ContentTypePBLite),
)
return err return err
} }
@ -52,10 +52,12 @@ func (s *SessionHandler) sendAsyncMessage(params SendMessageParams) (<-chan *Inc
} }
ch := s.waitResponse(requestID) ch := s.waitResponse(requestID)
_, reqErr := s.client.rpc.sendMessageRequest(util.SendMessageURL, payload) _, err = typedHTTPResponse[*gmproto.OutgoingRPCResponse](
if reqErr != nil { s.client.makeProtobufHTTPRequest(util.SendMessageURL, payload, ContentTypePBLite),
)
if err != nil {
s.cancelResponse(requestID, ch) s.cancelResponse(requestID, ch)
return nil, reqErr return nil, err
} }
return ch, nil return ch, nil
} }
@ -142,7 +144,7 @@ type SendMessageParams struct {
MessageType gmproto.MessageType MessageType gmproto.MessageType
} }
func (s *SessionHandler) buildMessage(params SendMessageParams) (string, []byte, error) { func (s *SessionHandler) buildMessage(params SendMessageParams) (string, proto.Message, error) {
var requestID string var requestID string
var err error var err error
sessionID := s.client.sessionHandler.sessionID sessionID := s.client.sessionHandler.sessionID
@ -199,9 +201,7 @@ func (s *SessionHandler) buildMessage(params SendMessageParams) (string, []byte,
return "", nil, err return "", nil, err
} }
var marshaledMessage []byte return requestID, message, err
marshaledMessage, err = pblite.Marshal(message)
return requestID, marshaledMessage, err
} }
func (s *SessionHandler) queueMessageAck(messageID string) { func (s *SessionHandler) queueMessageAck(messageID string) {
@ -243,7 +243,7 @@ func (s *SessionHandler) sendAckRequest() {
Device: s.client.AuthData.Browser, Device: s.client.AuthData.Browser,
} }
} }
ackMessagePayload := &gmproto.AckMessageRequest{ payload := &gmproto.AckMessageRequest{
AuthData: &gmproto.AuthMessage{ AuthData: &gmproto.AuthMessage{
RequestID: uuid.NewString(), RequestID: uuid.NewString(),
TachyonAuthToken: s.client.AuthData.TachyonAuthToken, TachyonAuthToken: s.client.AuthData.TachyonAuthToken,
@ -252,13 +252,13 @@ func (s *SessionHandler) sendAckRequest() {
EmptyArr: &gmproto.EmptyArr{}, EmptyArr: &gmproto.EmptyArr{},
Acks: ackMessages, Acks: ackMessages,
} }
jsonData, err := pblite.Marshal(ackMessagePayload) _, err := typedHTTPResponse[*gmproto.OutgoingRPCResponse](
s.client.makeProtobufHTTPRequest(util.AckMessagesURL, payload, ContentTypePBLite),
)
if err != nil { if err != nil {
panic(err) // TODO retry?
s.client.Logger.Err(err).Strs("message_ids", dataToAck).Msg("Failed to send acks")
} else {
s.client.Logger.Debug().Strs("message_ids", dataToAck).Msg("Sent acks")
} }
_, err = s.client.rpc.sendMessageRequest(util.AckMessagesURL, jsonData)
if err != nil {
panic(err)
}
s.client.Logger.Debug().Strs("message_ids", dataToAck).Msg("Sent acks")
} }