Make response waiting less hacky

This commit is contained in:
Tulir Asokan 2023-07-16 01:45:57 +03:00
parent 50c2d45316
commit 10affb59b1
9 changed files with 126 additions and 248 deletions

View file

@ -47,7 +47,7 @@ type Client struct {
func NewClient(authData *AuthData, logger zerolog.Logger) *Client {
sessionHandler := &SessionHandler{
requests: make(map[string]map[binary.ActionType]*ResponseChan),
responseWaiters: make(map[string]chan<- *pblite.Response),
responseTimeout: time.Duration(5000) * time.Millisecond,
}
if authData == nil {

View file

@ -15,12 +15,7 @@ func (c *Client) ListConversations(count int64, folder binary.ListConversationsP
//} else {
actionType := binary.ActionType_LIST_CONVERSATIONS
sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload)
if sendErr != nil {
return nil, sendErr
}
response, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType)
response, err := c.sessionHandler.sendMessage(actionType, payload)
if err != nil {
return nil, err
}
@ -37,12 +32,7 @@ func (c *Client) GetConversationType(conversationID string) (*binary.GetConversa
payload := &binary.ConversationTypePayload{ConversationID: conversationID}
actionType := binary.ActionType_GET_CONVERSATION_TYPE
sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload)
if sendErr != nil {
return nil, sendErr
}
response, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType)
response, err := c.sessionHandler.sendMessage(actionType, payload)
if err != nil {
return nil, err
}
@ -63,12 +53,7 @@ func (c *Client) FetchMessages(conversationID string, count int64, cursor *binar
actionType := binary.ActionType_LIST_MESSAGES
sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload)
if sendErr != nil {
return nil, sendErr
}
response, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType)
response, err := c.sessionHandler.sendMessage(actionType, payload)
if err != nil {
return nil, err
}
@ -84,12 +69,7 @@ func (c *Client) FetchMessages(conversationID string, count int64, cursor *binar
func (c *Client) SendMessage(payload *binary.SendMessagePayload) (*binary.SendMessageResponse, error) {
actionType := binary.ActionType_SEND_MESSAGE
sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload)
if sendErr != nil {
return nil, sendErr
}
response, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType)
response, err := c.sessionHandler.sendMessage(actionType, payload)
if err != nil {
return nil, err
}
@ -106,12 +86,7 @@ func (c *Client) GetParticipantThumbnail(convID string) (*binary.ParticipantThum
payload := &binary.GetParticipantThumbnailPayload{ConversationID: convID}
actionType := binary.ActionType_GET_PARTICIPANTS_THUMBNAIL
sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload)
if sendErr != nil {
return nil, sendErr
}
response, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType)
response, err := c.sessionHandler.sendMessage(actionType, payload)
if err != nil {
return nil, err
}
@ -134,12 +109,7 @@ func (c *Client) UpdateConversation(convBuilder *ConversationBuilder) (*binary.U
actionType := binary.ActionType_UPDATE_CONVERSATION
sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload)
if sendErr != nil {
return nil, sendErr
}
response, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType)
response, err := c.sessionHandler.sendMessage(actionType, payload)
if err != nil {
return nil, err
}
@ -156,14 +126,6 @@ func (c *Client) SetTyping(convID string) error {
payload := &binary.TypingUpdatePayload{Data: &binary.SetTypingIn{ConversationID: convID, Typing: true}}
actionType := binary.ActionType_TYPING_UPDATES
sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload)
if sendErr != nil {
return sendErr
}
_, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType)
if err != nil {
return err
}
return nil
_, err := c.sessionHandler.sendMessage(actionType, payload)
return err
}

View file

@ -52,34 +52,33 @@ func (r *RPC) HandleRPCMsg(msg *binary.InternalMessage) {
r.client.Logger.Error().Msg("nil response in rpc handler")
return
}
_, waitingForResponse := r.client.sessionHandler.requests[response.Data.RequestId]
r.client.sessionHandler.addResponseAck(response.ResponseId)
if waitingForResponse {
r.client.sessionHandler.respondToRequestChannel(response)
} else {
switch response.BugleRoute {
case binary.BugleRoute_PairEvent:
go r.client.handlePairingEvent(response)
case binary.BugleRoute_DataEvent:
if r.skipCount > 0 {
r.skipCount--
r.client.Logger.Debug().
Any("action", response.Data.Action).
Any("toSkip", r.skipCount).
Msg("Skipped DataEvent")
if response.Data.Decrypted != nil {
r.client.Logger.Trace().
Str("proto_name", string(response.Data.Decrypted.ProtoReflect().Descriptor().FullName())).
Str("data", base64.StdEncoding.EncodeToString(response.Data.RawDecrypted)).
Msg("Skipped event data")
}
return
r.client.sessionHandler.queueMessageAck(response.ResponseID)
if r.client.sessionHandler.receiveResponse(response) {
r.client.Logger.Debug().Str("request_id", response.Data.RequestID).Msg("Received response")
return
}
switch response.BugleRoute {
case binary.BugleRoute_PairEvent:
go r.client.handlePairingEvent(response)
case binary.BugleRoute_DataEvent:
if r.skipCount > 0 {
r.skipCount--
r.client.Logger.Debug().
Any("action", response.Data.Action).
Int("remaining_skip_count", r.skipCount).
Msg("Skipped DataEvent")
if response.Data.Decrypted != nil {
r.client.Logger.Trace().
Str("proto_name", string(response.Data.Decrypted.ProtoReflect().Descriptor().FullName())).
Str("data", base64.StdEncoding.EncodeToString(response.Data.RawDecrypted)).
Msg("Skipped event data")
}
r.client.handleUpdatesEvent(response)
default:
r.client.Logger.Debug().Any("res", response).Msg("Got unknown bugleroute")
return
}
r.client.handleUpdatesEvent(response)
default:
r.client.Logger.Debug().Any("res", response).Msg("Got unknown bugleroute")
}
}

View file

@ -9,12 +9,7 @@ import (
func (c *Client) SendReaction(payload *binary.SendReactionPayload) (*binary.SendReactionResponse, error) {
actionType := binary.ActionType_SEND_REACTION
sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload)
if sendErr != nil {
return nil, sendErr
}
response, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType)
response, err := c.sessionHandler.sendMessage(actionType, payload)
if err != nil {
return nil, err
}
@ -31,12 +26,7 @@ func (c *Client) DeleteMessage(messageID string) (*binary.DeleteMessageResponse,
payload := &binary.DeleteMessagePayload{MessageID: messageID}
actionType := binary.ActionType_DELETE_MESSAGE
sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload)
if sendErr != nil {
return nil, sendErr
}
response, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType)
response, err := c.sessionHandler.sendMessage(actionType, payload)
if err != nil {
return nil, err
}
@ -53,15 +43,6 @@ func (c *Client) MarkRead(conversationID, messageID string) error {
payload := &binary.MessageReadPayload{ConversationID: conversationID, MessageID: messageID}
actionType := binary.ActionType_MESSAGE_READ
sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload)
if sendErr != nil {
return sendErr
}
_, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType)
if err != nil {
return err
}
return nil
_, err := c.sessionHandler.sendMessage(actionType, payload)
return err
}

View file

@ -15,7 +15,7 @@ type DevicePair struct {
}
type RequestData struct {
RequestId string `json:"requestId,omitempty"`
RequestID string `json:"requestId,omitempty"`
Timestamp int64 `json:"timestamp,omitempty"`
Action binary.ActionType `json:"action,omitempty"`
Bool1 bool `json:"bool1,omitempty"`
@ -27,7 +27,7 @@ type RequestData struct {
}
type Response struct {
ResponseId string `json:"responseId,omitempty"`
ResponseID string `json:"responseId,omitempty"`
BugleRoute binary.BugleRoute `json:"bugleRoute,omitempty"`
StartExecute string `json:"startExecute,omitempty"`
MessageType binary.MessageType `json:"eventType,omitempty"`
@ -76,7 +76,7 @@ func DecryptInternalMessage(internalMessage *binary.InternalMessage, cryptor *cr
func newResponseFromPairEvent(internalMsg *binary.InternalMessageData, data *binary.PairEvents) *Response {
resp := &Response{
ResponseId: internalMsg.GetResponseID(),
ResponseID: internalMsg.GetResponseID(),
BugleRoute: internalMsg.GetBugleRoute(),
StartExecute: internalMsg.GetStartExecute(),
MessageType: internalMsg.GetMessageType(),
@ -98,7 +98,7 @@ func newResponseFromPairEvent(internalMsg *binary.InternalMessageData, data *bin
func newResponseFromDataEvent(internalMsg *binary.InternalMessageData, internalRequestData *binary.InternalRequestData, rawData []byte, decrypted protoreflect.ProtoMessage) *Response {
resp := &Response{
ResponseId: internalMsg.GetResponseID(),
ResponseID: internalMsg.GetResponseID(),
BugleRoute: internalMsg.GetBugleRoute(),
StartExecute: internalMsg.GetStartExecute(),
MessageType: internalMsg.GetMessageType(),
@ -109,7 +109,7 @@ func newResponseFromDataEvent(internalMsg *binary.InternalMessageData, internalR
Browser: internalMsg.GetBrowser(),
},
Data: RequestData{
RequestId: internalRequestData.GetSessionID(),
RequestID: internalRequestData.GetSessionID(),
Timestamp: internalRequestData.GetTimestamp(),
Action: internalRequestData.GetAction(),
Bool1: internalRequestData.GetBool1(),

View file

@ -1,99 +1,48 @@
package libgm
import (
"encoding/base64"
"fmt"
"sync"
"go.mau.fi/mautrix-gmessages/libgm/pblite"
"go.mau.fi/mautrix-gmessages/libgm/binary"
"go.mau.fi/mautrix-gmessages/libgm/routes"
)
type ResponseChan struct {
response *pblite.Response
wg sync.WaitGroup
mu sync.Mutex
func (s *SessionHandler) waitResponse(requestID string) chan *pblite.Response {
ch := make(chan *pblite.Response, 1)
s.responseWaitersLock.Lock()
// DEBUG
if _, ok := s.responseWaiters[requestID]; ok {
panic(fmt.Errorf("request %s already has a response waiter", requestID))
}
// END DEBUG
s.responseWaiters[requestID] = ch
s.responseWaitersLock.Unlock()
return ch
}
func (s *SessionHandler) addRequestToChannel(requestId string, actionType binary.ActionType) {
_, notOk := routes.Routes[actionType]
if !notOk {
panic(fmt.Errorf("missing action type: %v", actionType))
}
if msgMap, ok := s.requests[requestId]; ok {
responseChan := &ResponseChan{
response: &pblite.Response{},
wg: sync.WaitGroup{},
mu: sync.Mutex{},
}
responseChan.wg.Add(1)
responseChan.mu.Lock()
msgMap[actionType] = responseChan
} else {
s.requests[requestId] = make(map[binary.ActionType]*ResponseChan)
responseChan := &ResponseChan{
response: &pblite.Response{},
wg: sync.WaitGroup{},
mu: sync.Mutex{},
}
responseChan.wg.Add(1)
responseChan.mu.Lock()
s.requests[requestId][actionType] = responseChan
}
func (s *SessionHandler) cancelResponse(requestID string, ch chan *pblite.Response) {
s.responseWaitersLock.Lock()
close(ch)
delete(s.responseWaiters, requestID)
s.responseWaitersLock.Unlock()
}
func (s *SessionHandler) respondToRequestChannel(res *pblite.Response) {
requestId := res.Data.RequestId
reqChannel, ok := s.requests[requestId]
actionType := res.Data.Action
func (s *SessionHandler) receiveResponse(resp *pblite.Response) bool {
s.responseWaitersLock.Lock()
ch, ok := s.responseWaiters[resp.Data.RequestID]
if !ok {
s.client.Logger.Debug().Any("actionType", actionType).Any("requestId", requestId).Msg("Did not expect response for this requestId")
return
s.responseWaitersLock.Unlock()
return false
}
actionResponseChan, ok2 := reqChannel[actionType]
if !ok2 {
s.client.Logger.Debug().Any("actionType", actionType).Any("requestId", requestId).Msg("Did not expect response for this actionType")
return
delete(s.responseWaiters, resp.Data.RequestID)
s.responseWaitersLock.Unlock()
evt := s.client.Logger.Trace().
Str("request_id", resp.Data.RequestID)
if evt.Enabled() && resp.Data.Decrypted != nil {
evt.Str("proto_name", string(resp.Data.Decrypted.ProtoReflect().Descriptor().FullName())).
Str("data", base64.StdEncoding.EncodeToString(resp.Data.RawDecrypted))
}
actionResponseChan.mu.Lock()
actionResponseChan, ok2 = reqChannel[actionType]
if !ok2 {
s.client.Logger.Debug().Any("actionType", actionType).Any("requestId", requestId).Msg("Ignoring request for action...")
return
}
s.client.Logger.Debug().Any("actionType", actionType).Any("requestId", requestId).Msg("responding to request")
s.client.rpc.logContent(res)
actionResponseChan.response = res
actionResponseChan.wg.Done()
delete(reqChannel, actionType)
if len(reqChannel) == 0 {
delete(s.requests, requestId)
}
actionResponseChan.mu.Unlock()
}
func (s *SessionHandler) WaitForResponse(requestId string, actionType binary.ActionType) (*pblite.Response, error) {
requestResponses, ok := s.requests[requestId]
if !ok {
return nil, fmt.Errorf("no response channel found for request ID: %s (actionType: %v)", requestId, actionType)
}
routeInfo, notFound := routes.Routes[actionType]
if !notFound {
return nil, fmt.Errorf("no action exists for actionType: %v (requestId: %s)", actionType, requestId)
}
responseChan, ok2 := requestResponses[routeInfo.Action]
if !ok2 {
return nil, fmt.Errorf("no response channel found for actionType: %v (requestId: %s)", routeInfo.Action, requestId)
}
responseChan.mu.Unlock()
responseChan.wg.Wait()
return responseChan.response, nil
evt.Msg("Received response")
ch <- resp
return true
}

View file

@ -7,26 +7,17 @@ import (
)
func (c *Client) SetActiveSession() error {
c.sessionHandler.ResetSessionId()
c.sessionHandler.ResetSessionID()
actionType := binary.ActionType_GET_UPDATES
_, sendErr := c.sessionHandler.completeSendMessage(actionType, false, nil)
if sendErr != nil {
return sendErr
}
return nil
return c.sessionHandler.sendMessageNoResponse(actionType, nil)
}
func (c *Client) IsBugleDefault() (*binary.IsBugleDefaultResponse, error) {
c.sessionHandler.ResetSessionId()
c.sessionHandler.ResetSessionID()
actionType := binary.ActionType_IS_BUGLE_DEFAULT
sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, nil)
if sendErr != nil {
return nil, sendErr
}
response, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType)
response, err := c.sessionHandler.sendMessage(actionType, nil)
if err != nil {
return nil, err
}
@ -43,15 +34,6 @@ func (c *Client) NotifyDittoActivity() error {
payload := &binary.NotifyDittoActivityPayload{Success: true}
actionType := binary.ActionType_NOTIFY_DITTO_ACTIVITY
sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload)
if sendErr != nil {
return sendErr
}
_, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType)
if err != nil {
return err
}
return nil
_, err := c.sessionHandler.sendMessage(actionType, payload)
return err
}

View file

@ -16,60 +16,64 @@ import (
"go.mau.fi/mautrix-gmessages/libgm/util"
)
/*
type Response struct {
client *Client
ResponseId string
RoutingOpCode int64
Data *binary.EncodedResponse // base64 encoded (decode -> protomessage)
StartExecute string
FinishExecute string
DevicePair *pblite.DevicePair
}
*/
type SessionHandler struct {
client *Client
requests map[string]map[binary.ActionType]*ResponseChan
client *Client
responseWaiters map[string]chan<- *pblite.Response
responseWaitersLock sync.Mutex
ackMapLock sync.Mutex
ackMap []string
ackTicker *time.Ticker
sessionId string
sessionID string
responseTimeout time.Duration
}
func (s *SessionHandler) SetResponseTimeout(milliSeconds int) {
s.responseTimeout = time.Duration(milliSeconds) * time.Millisecond
func (s *SessionHandler) ResetSessionID() {
s.sessionID = util.RandomUUIDv4()
}
func (s *SessionHandler) ResetSessionId() {
s.sessionId = util.RandomUUIDv4()
func (s *SessionHandler) sendMessageNoResponse(actionType binary.ActionType, encryptedData proto.Message) error {
_, payload, _, err := s.buildMessage(actionType, encryptedData)
if err != nil {
return err
}
_, err = s.client.rpc.sendMessageRequest(util.SEND_MESSAGE, payload)
return err
}
func (s *SessionHandler) completeSendMessage(actionType binary.ActionType, addToChannel bool, encryptedData proto.Message) (string, error) {
requestId, payload, action, buildErr := s.buildMessage(actionType, encryptedData)
func (s *SessionHandler) sendAsyncMessage(actionType binary.ActionType, encryptedData proto.Message) (<-chan *pblite.Response, error) {
requestID, payload, _, buildErr := s.buildMessage(actionType, encryptedData)
if buildErr != nil {
return "", buildErr
return nil, buildErr
}
if addToChannel {
s.addRequestToChannel(requestId, action)
}
ch := s.waitResponse(requestID)
_, reqErr := s.client.rpc.sendMessageRequest(util.SEND_MESSAGE, payload)
if reqErr != nil {
return "", reqErr
s.cancelResponse(requestID, ch)
return nil, reqErr
}
return requestId, nil
return ch, nil
}
func (s *SessionHandler) sendMessage(actionType binary.ActionType, encryptedData proto.Message) (*pblite.Response, error) {
ch, err := s.sendAsyncMessage(actionType, encryptedData)
if err != nil {
return nil, err
}
// TODO add timeout
return <-ch, nil
}
func (s *SessionHandler) buildMessage(actionType binary.ActionType, encryptedData proto.Message) (string, []byte, binary.ActionType, error) {
var requestId string
var requestID string
pairedDevice := s.client.authData.DevicePair.Mobile
sessionId := s.client.sessionHandler.sessionId
sessionId := s.client.sessionHandler.sessionID
token := s.client.authData.TachyonAuthToken
routeInfo, ok := routes.Routes[actionType]
@ -78,12 +82,12 @@ func (s *SessionHandler) buildMessage(actionType binary.ActionType, encryptedDat
}
if routeInfo.UseSessionID {
requestId = s.sessionId
requestID = s.sessionID
} else {
requestId = util.RandomUUIDv4()
requestID = util.RandomUUIDv4()
}
tmpMessage := payload.NewSendMessageBuilder(token, pairedDevice, requestId, sessionId).SetRoute(routeInfo.Action).SetSessionId(s.sessionId)
tmpMessage := payload.NewSendMessageBuilder(token, pairedDevice, requestID, sessionId).SetRoute(routeInfo.Action).SetSessionId(s.sessionID)
if encryptedData != nil {
tmpMessage.SetEncryptedProtoMessage(encryptedData, s.client.authData.Cryptor)
@ -98,16 +102,17 @@ func (s *SessionHandler) buildMessage(actionType binary.ActionType, encryptedDat
return "", nil, 0, buildErr
}
return requestId, message, routeInfo.Action, nil
return requestID, message, routeInfo.Action, nil
}
func (s *SessionHandler) addResponseAck(responseId string) {
s.client.Logger.Debug().Any("responseId", responseId).Msg("Added to ack map")
func (s *SessionHandler) queueMessageAck(messageID string) {
s.ackMapLock.Lock()
defer s.ackMapLock.Unlock()
hasResponseId := slices.Contains(s.ackMap, responseId)
if !hasResponseId {
s.ackMap = append(s.ackMap, responseId)
if !slices.Contains(s.ackMap, messageID) {
s.ackMap = append(s.ackMap, messageID)
s.client.Logger.Trace().Any("message_id", messageID).Msg("Queued ack for message")
} else {
s.client.Logger.Trace().Any("message_id", messageID).Msg("Ack for message was already queued")
}
}

View file

@ -26,9 +26,9 @@ func (c *Client) handleUserAlertEvent(res *pblite.Response, data *binary.UserAle
alertType := data.AlertType
switch alertType {
case binary.AlertType_BROWSER_ACTIVE:
newSessionId := res.Data.RequestId
newSessionId := res.Data.RequestID
c.Logger.Info().Any("sessionId", newSessionId).Msg("[NEW_BROWSER_ACTIVE] Opened new browser connection")
if newSessionId != c.sessionHandler.sessionId {
if newSessionId != c.sessionHandler.sessionID {
evt := events.NewBrowserActive(newSessionId)
c.triggerEvent(evt)
} else {