From 10affb59b17528bd7d7a76d837bff9ecf8dd9034 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 16 Jul 2023 01:45:57 +0300 Subject: [PATCH] Make response waiting less hacky --- libgm/client.go | 2 +- libgm/conversations.go | 54 +++-------------- libgm/event_handler.go | 49 ++++++++-------- libgm/messages.go | 27 ++------- libgm/pblite/internal.go | 10 ++-- libgm/response_handler.go | 115 +++++++++++-------------------------- libgm/session.go | 30 ++-------- libgm/session_handler.go | 83 +++++++++++++------------- libgm/useralert_handler.go | 4 +- 9 files changed, 126 insertions(+), 248 deletions(-) diff --git a/libgm/client.go b/libgm/client.go index 1545272..e382499 100644 --- a/libgm/client.go +++ b/libgm/client.go @@ -47,7 +47,7 @@ type Client struct { func NewClient(authData *AuthData, logger zerolog.Logger) *Client { sessionHandler := &SessionHandler{ - requests: make(map[string]map[binary.ActionType]*ResponseChan), + responseWaiters: make(map[string]chan<- *pblite.Response), responseTimeout: time.Duration(5000) * time.Millisecond, } if authData == nil { diff --git a/libgm/conversations.go b/libgm/conversations.go index 961eff2..f18ea0d 100644 --- a/libgm/conversations.go +++ b/libgm/conversations.go @@ -15,12 +15,7 @@ func (c *Client) ListConversations(count int64, folder binary.ListConversationsP //} else { actionType := binary.ActionType_LIST_CONVERSATIONS - sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload) - if sendErr != nil { - return nil, sendErr - } - - response, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType) + response, err := c.sessionHandler.sendMessage(actionType, payload) if err != nil { return nil, err } @@ -37,12 +32,7 @@ func (c *Client) GetConversationType(conversationID string) (*binary.GetConversa payload := &binary.ConversationTypePayload{ConversationID: conversationID} actionType := binary.ActionType_GET_CONVERSATION_TYPE - sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload) - if sendErr != nil { - return nil, sendErr - } - - response, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType) + response, err := c.sessionHandler.sendMessage(actionType, payload) if err != nil { return nil, err } @@ -63,12 +53,7 @@ func (c *Client) FetchMessages(conversationID string, count int64, cursor *binar actionType := binary.ActionType_LIST_MESSAGES - sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload) - if sendErr != nil { - return nil, sendErr - } - - response, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType) + response, err := c.sessionHandler.sendMessage(actionType, payload) if err != nil { return nil, err } @@ -84,12 +69,7 @@ func (c *Client) FetchMessages(conversationID string, count int64, cursor *binar func (c *Client) SendMessage(payload *binary.SendMessagePayload) (*binary.SendMessageResponse, error) { actionType := binary.ActionType_SEND_MESSAGE - sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload) - if sendErr != nil { - return nil, sendErr - } - - response, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType) + response, err := c.sessionHandler.sendMessage(actionType, payload) if err != nil { return nil, err } @@ -106,12 +86,7 @@ func (c *Client) GetParticipantThumbnail(convID string) (*binary.ParticipantThum payload := &binary.GetParticipantThumbnailPayload{ConversationID: convID} actionType := binary.ActionType_GET_PARTICIPANTS_THUMBNAIL - sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload) - if sendErr != nil { - return nil, sendErr - } - - response, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType) + response, err := c.sessionHandler.sendMessage(actionType, payload) if err != nil { return nil, err } @@ -134,12 +109,7 @@ func (c *Client) UpdateConversation(convBuilder *ConversationBuilder) (*binary.U actionType := binary.ActionType_UPDATE_CONVERSATION - sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload) - if sendErr != nil { - return nil, sendErr - } - - response, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType) + response, err := c.sessionHandler.sendMessage(actionType, payload) if err != nil { return nil, err } @@ -156,14 +126,6 @@ func (c *Client) SetTyping(convID string) error { payload := &binary.TypingUpdatePayload{Data: &binary.SetTypingIn{ConversationID: convID, Typing: true}} actionType := binary.ActionType_TYPING_UPDATES - sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload) - if sendErr != nil { - return sendErr - } - - _, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType) - if err != nil { - return err - } - return nil + _, err := c.sessionHandler.sendMessage(actionType, payload) + return err } diff --git a/libgm/event_handler.go b/libgm/event_handler.go index a96a0d9..9852e0a 100644 --- a/libgm/event_handler.go +++ b/libgm/event_handler.go @@ -52,34 +52,33 @@ func (r *RPC) HandleRPCMsg(msg *binary.InternalMessage) { r.client.Logger.Error().Msg("nil response in rpc handler") return } - _, waitingForResponse := r.client.sessionHandler.requests[response.Data.RequestId] - r.client.sessionHandler.addResponseAck(response.ResponseId) - if waitingForResponse { - r.client.sessionHandler.respondToRequestChannel(response) - } else { - switch response.BugleRoute { - case binary.BugleRoute_PairEvent: - go r.client.handlePairingEvent(response) - case binary.BugleRoute_DataEvent: - if r.skipCount > 0 { - r.skipCount-- - r.client.Logger.Debug(). - Any("action", response.Data.Action). - Any("toSkip", r.skipCount). - Msg("Skipped DataEvent") - if response.Data.Decrypted != nil { - r.client.Logger.Trace(). - Str("proto_name", string(response.Data.Decrypted.ProtoReflect().Descriptor().FullName())). - Str("data", base64.StdEncoding.EncodeToString(response.Data.RawDecrypted)). - Msg("Skipped event data") - } - return + r.client.sessionHandler.queueMessageAck(response.ResponseID) + if r.client.sessionHandler.receiveResponse(response) { + r.client.Logger.Debug().Str("request_id", response.Data.RequestID).Msg("Received response") + return + } + switch response.BugleRoute { + case binary.BugleRoute_PairEvent: + go r.client.handlePairingEvent(response) + case binary.BugleRoute_DataEvent: + if r.skipCount > 0 { + r.skipCount-- + r.client.Logger.Debug(). + Any("action", response.Data.Action). + Int("remaining_skip_count", r.skipCount). + Msg("Skipped DataEvent") + if response.Data.Decrypted != nil { + r.client.Logger.Trace(). + Str("proto_name", string(response.Data.Decrypted.ProtoReflect().Descriptor().FullName())). + Str("data", base64.StdEncoding.EncodeToString(response.Data.RawDecrypted)). + Msg("Skipped event data") } - r.client.handleUpdatesEvent(response) - default: - r.client.Logger.Debug().Any("res", response).Msg("Got unknown bugleroute") + return } + r.client.handleUpdatesEvent(response) + default: + r.client.Logger.Debug().Any("res", response).Msg("Got unknown bugleroute") } } diff --git a/libgm/messages.go b/libgm/messages.go index 28c9e6b..bd506a0 100644 --- a/libgm/messages.go +++ b/libgm/messages.go @@ -9,12 +9,7 @@ import ( func (c *Client) SendReaction(payload *binary.SendReactionPayload) (*binary.SendReactionResponse, error) { actionType := binary.ActionType_SEND_REACTION - sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload) - if sendErr != nil { - return nil, sendErr - } - - response, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType) + response, err := c.sessionHandler.sendMessage(actionType, payload) if err != nil { return nil, err } @@ -31,12 +26,7 @@ func (c *Client) DeleteMessage(messageID string) (*binary.DeleteMessageResponse, payload := &binary.DeleteMessagePayload{MessageID: messageID} actionType := binary.ActionType_DELETE_MESSAGE - sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload) - if sendErr != nil { - return nil, sendErr - } - - response, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType) + response, err := c.sessionHandler.sendMessage(actionType, payload) if err != nil { return nil, err } @@ -53,15 +43,6 @@ func (c *Client) MarkRead(conversationID, messageID string) error { payload := &binary.MessageReadPayload{ConversationID: conversationID, MessageID: messageID} actionType := binary.ActionType_MESSAGE_READ - sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload) - if sendErr != nil { - return sendErr - } - - _, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType) - if err != nil { - return err - } - - return nil + _, err := c.sessionHandler.sendMessage(actionType, payload) + return err } diff --git a/libgm/pblite/internal.go b/libgm/pblite/internal.go index 7592780..910a5ee 100644 --- a/libgm/pblite/internal.go +++ b/libgm/pblite/internal.go @@ -15,7 +15,7 @@ type DevicePair struct { } type RequestData struct { - RequestId string `json:"requestId,omitempty"` + RequestID string `json:"requestId,omitempty"` Timestamp int64 `json:"timestamp,omitempty"` Action binary.ActionType `json:"action,omitempty"` Bool1 bool `json:"bool1,omitempty"` @@ -27,7 +27,7 @@ type RequestData struct { } type Response struct { - ResponseId string `json:"responseId,omitempty"` + ResponseID string `json:"responseId,omitempty"` BugleRoute binary.BugleRoute `json:"bugleRoute,omitempty"` StartExecute string `json:"startExecute,omitempty"` MessageType binary.MessageType `json:"eventType,omitempty"` @@ -76,7 +76,7 @@ func DecryptInternalMessage(internalMessage *binary.InternalMessage, cryptor *cr func newResponseFromPairEvent(internalMsg *binary.InternalMessageData, data *binary.PairEvents) *Response { resp := &Response{ - ResponseId: internalMsg.GetResponseID(), + ResponseID: internalMsg.GetResponseID(), BugleRoute: internalMsg.GetBugleRoute(), StartExecute: internalMsg.GetStartExecute(), MessageType: internalMsg.GetMessageType(), @@ -98,7 +98,7 @@ func newResponseFromPairEvent(internalMsg *binary.InternalMessageData, data *bin func newResponseFromDataEvent(internalMsg *binary.InternalMessageData, internalRequestData *binary.InternalRequestData, rawData []byte, decrypted protoreflect.ProtoMessage) *Response { resp := &Response{ - ResponseId: internalMsg.GetResponseID(), + ResponseID: internalMsg.GetResponseID(), BugleRoute: internalMsg.GetBugleRoute(), StartExecute: internalMsg.GetStartExecute(), MessageType: internalMsg.GetMessageType(), @@ -109,7 +109,7 @@ func newResponseFromDataEvent(internalMsg *binary.InternalMessageData, internalR Browser: internalMsg.GetBrowser(), }, Data: RequestData{ - RequestId: internalRequestData.GetSessionID(), + RequestID: internalRequestData.GetSessionID(), Timestamp: internalRequestData.GetTimestamp(), Action: internalRequestData.GetAction(), Bool1: internalRequestData.GetBool1(), diff --git a/libgm/response_handler.go b/libgm/response_handler.go index ae55887..1129667 100644 --- a/libgm/response_handler.go +++ b/libgm/response_handler.go @@ -1,99 +1,48 @@ package libgm import ( + "encoding/base64" "fmt" - "sync" "go.mau.fi/mautrix-gmessages/libgm/pblite" - - "go.mau.fi/mautrix-gmessages/libgm/binary" - "go.mau.fi/mautrix-gmessages/libgm/routes" ) -type ResponseChan struct { - response *pblite.Response - wg sync.WaitGroup - mu sync.Mutex +func (s *SessionHandler) waitResponse(requestID string) chan *pblite.Response { + ch := make(chan *pblite.Response, 1) + s.responseWaitersLock.Lock() + // DEBUG + if _, ok := s.responseWaiters[requestID]; ok { + panic(fmt.Errorf("request %s already has a response waiter", requestID)) + } + // END DEBUG + s.responseWaiters[requestID] = ch + s.responseWaitersLock.Unlock() + return ch } -func (s *SessionHandler) addRequestToChannel(requestId string, actionType binary.ActionType) { - _, notOk := routes.Routes[actionType] - if !notOk { - panic(fmt.Errorf("missing action type: %v", actionType)) - } - if msgMap, ok := s.requests[requestId]; ok { - responseChan := &ResponseChan{ - response: &pblite.Response{}, - wg: sync.WaitGroup{}, - mu: sync.Mutex{}, - } - responseChan.wg.Add(1) - responseChan.mu.Lock() - msgMap[actionType] = responseChan - } else { - s.requests[requestId] = make(map[binary.ActionType]*ResponseChan) - responseChan := &ResponseChan{ - response: &pblite.Response{}, - wg: sync.WaitGroup{}, - mu: sync.Mutex{}, - } - responseChan.wg.Add(1) - responseChan.mu.Lock() - s.requests[requestId][actionType] = responseChan - } +func (s *SessionHandler) cancelResponse(requestID string, ch chan *pblite.Response) { + s.responseWaitersLock.Lock() + close(ch) + delete(s.responseWaiters, requestID) + s.responseWaitersLock.Unlock() } -func (s *SessionHandler) respondToRequestChannel(res *pblite.Response) { - requestId := res.Data.RequestId - reqChannel, ok := s.requests[requestId] - actionType := res.Data.Action +func (s *SessionHandler) receiveResponse(resp *pblite.Response) bool { + s.responseWaitersLock.Lock() + ch, ok := s.responseWaiters[resp.Data.RequestID] if !ok { - s.client.Logger.Debug().Any("actionType", actionType).Any("requestId", requestId).Msg("Did not expect response for this requestId") - return + s.responseWaitersLock.Unlock() + return false } - actionResponseChan, ok2 := reqChannel[actionType] - if !ok2 { - s.client.Logger.Debug().Any("actionType", actionType).Any("requestId", requestId).Msg("Did not expect response for this actionType") - return + delete(s.responseWaiters, resp.Data.RequestID) + s.responseWaitersLock.Unlock() + evt := s.client.Logger.Trace(). + Str("request_id", resp.Data.RequestID) + if evt.Enabled() && resp.Data.Decrypted != nil { + evt.Str("proto_name", string(resp.Data.Decrypted.ProtoReflect().Descriptor().FullName())). + Str("data", base64.StdEncoding.EncodeToString(resp.Data.RawDecrypted)) } - actionResponseChan.mu.Lock() - actionResponseChan, ok2 = reqChannel[actionType] - if !ok2 { - s.client.Logger.Debug().Any("actionType", actionType).Any("requestId", requestId).Msg("Ignoring request for action...") - return - } - s.client.Logger.Debug().Any("actionType", actionType).Any("requestId", requestId).Msg("responding to request") - s.client.rpc.logContent(res) - actionResponseChan.response = res - actionResponseChan.wg.Done() - - delete(reqChannel, actionType) - if len(reqChannel) == 0 { - delete(s.requests, requestId) - } - - actionResponseChan.mu.Unlock() -} - -func (s *SessionHandler) WaitForResponse(requestId string, actionType binary.ActionType) (*pblite.Response, error) { - requestResponses, ok := s.requests[requestId] - if !ok { - return nil, fmt.Errorf("no response channel found for request ID: %s (actionType: %v)", requestId, actionType) - } - - routeInfo, notFound := routes.Routes[actionType] - if !notFound { - return nil, fmt.Errorf("no action exists for actionType: %v (requestId: %s)", actionType, requestId) - } - - responseChan, ok2 := requestResponses[routeInfo.Action] - if !ok2 { - return nil, fmt.Errorf("no response channel found for actionType: %v (requestId: %s)", routeInfo.Action, requestId) - } - - responseChan.mu.Unlock() - - responseChan.wg.Wait() - - return responseChan.response, nil + evt.Msg("Received response") + ch <- resp + return true } diff --git a/libgm/session.go b/libgm/session.go index 2d156b9..dc8c58e 100644 --- a/libgm/session.go +++ b/libgm/session.go @@ -7,26 +7,17 @@ import ( ) func (c *Client) SetActiveSession() error { - c.sessionHandler.ResetSessionId() - + c.sessionHandler.ResetSessionID() actionType := binary.ActionType_GET_UPDATES - _, sendErr := c.sessionHandler.completeSendMessage(actionType, false, nil) - if sendErr != nil { - return sendErr - } - return nil + return c.sessionHandler.sendMessageNoResponse(actionType, nil) } func (c *Client) IsBugleDefault() (*binary.IsBugleDefaultResponse, error) { - c.sessionHandler.ResetSessionId() + c.sessionHandler.ResetSessionID() actionType := binary.ActionType_IS_BUGLE_DEFAULT - sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, nil) - if sendErr != nil { - return nil, sendErr - } - response, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType) + response, err := c.sessionHandler.sendMessage(actionType, nil) if err != nil { return nil, err } @@ -43,15 +34,6 @@ func (c *Client) NotifyDittoActivity() error { payload := &binary.NotifyDittoActivityPayload{Success: true} actionType := binary.ActionType_NOTIFY_DITTO_ACTIVITY - sentRequestId, sendErr := c.sessionHandler.completeSendMessage(actionType, true, payload) - if sendErr != nil { - return sendErr - } - - _, err := c.sessionHandler.WaitForResponse(sentRequestId, actionType) - if err != nil { - return err - } - - return nil + _, err := c.sessionHandler.sendMessage(actionType, payload) + return err } diff --git a/libgm/session_handler.go b/libgm/session_handler.go index 8312901..2f20ebf 100644 --- a/libgm/session_handler.go +++ b/libgm/session_handler.go @@ -16,60 +16,64 @@ import ( "go.mau.fi/mautrix-gmessages/libgm/util" ) -/* -type Response struct { - client *Client - ResponseId string - RoutingOpCode int64 - Data *binary.EncodedResponse // base64 encoded (decode -> protomessage) - - StartExecute string - FinishExecute string - DevicePair *pblite.DevicePair -} -*/ - type SessionHandler struct { - client *Client - requests map[string]map[binary.ActionType]*ResponseChan + client *Client + + responseWaiters map[string]chan<- *pblite.Response + responseWaitersLock sync.Mutex ackMapLock sync.Mutex ackMap []string ackTicker *time.Ticker - sessionId string + sessionID string responseTimeout time.Duration } -func (s *SessionHandler) SetResponseTimeout(milliSeconds int) { - s.responseTimeout = time.Duration(milliSeconds) * time.Millisecond +func (s *SessionHandler) ResetSessionID() { + s.sessionID = util.RandomUUIDv4() } -func (s *SessionHandler) ResetSessionId() { - s.sessionId = util.RandomUUIDv4() +func (s *SessionHandler) sendMessageNoResponse(actionType binary.ActionType, encryptedData proto.Message) error { + _, payload, _, err := s.buildMessage(actionType, encryptedData) + if err != nil { + return err + } + + _, err = s.client.rpc.sendMessageRequest(util.SEND_MESSAGE, payload) + return err } -func (s *SessionHandler) completeSendMessage(actionType binary.ActionType, addToChannel bool, encryptedData proto.Message) (string, error) { - requestId, payload, action, buildErr := s.buildMessage(actionType, encryptedData) +func (s *SessionHandler) sendAsyncMessage(actionType binary.ActionType, encryptedData proto.Message) (<-chan *pblite.Response, error) { + requestID, payload, _, buildErr := s.buildMessage(actionType, encryptedData) if buildErr != nil { - return "", buildErr + return nil, buildErr } - if addToChannel { - s.addRequestToChannel(requestId, action) - } + ch := s.waitResponse(requestID) _, reqErr := s.client.rpc.sendMessageRequest(util.SEND_MESSAGE, payload) if reqErr != nil { - return "", reqErr + s.cancelResponse(requestID, ch) + return nil, reqErr } - return requestId, nil + return ch, nil +} + +func (s *SessionHandler) sendMessage(actionType binary.ActionType, encryptedData proto.Message) (*pblite.Response, error) { + ch, err := s.sendAsyncMessage(actionType, encryptedData) + if err != nil { + return nil, err + } + + // TODO add timeout + return <-ch, nil } func (s *SessionHandler) buildMessage(actionType binary.ActionType, encryptedData proto.Message) (string, []byte, binary.ActionType, error) { - var requestId string + var requestID string pairedDevice := s.client.authData.DevicePair.Mobile - sessionId := s.client.sessionHandler.sessionId + sessionId := s.client.sessionHandler.sessionID token := s.client.authData.TachyonAuthToken routeInfo, ok := routes.Routes[actionType] @@ -78,12 +82,12 @@ func (s *SessionHandler) buildMessage(actionType binary.ActionType, encryptedDat } if routeInfo.UseSessionID { - requestId = s.sessionId + requestID = s.sessionID } else { - requestId = util.RandomUUIDv4() + requestID = util.RandomUUIDv4() } - tmpMessage := payload.NewSendMessageBuilder(token, pairedDevice, requestId, sessionId).SetRoute(routeInfo.Action).SetSessionId(s.sessionId) + tmpMessage := payload.NewSendMessageBuilder(token, pairedDevice, requestID, sessionId).SetRoute(routeInfo.Action).SetSessionId(s.sessionID) if encryptedData != nil { tmpMessage.SetEncryptedProtoMessage(encryptedData, s.client.authData.Cryptor) @@ -98,16 +102,17 @@ func (s *SessionHandler) buildMessage(actionType binary.ActionType, encryptedDat return "", nil, 0, buildErr } - return requestId, message, routeInfo.Action, nil + return requestID, message, routeInfo.Action, nil } -func (s *SessionHandler) addResponseAck(responseId string) { - s.client.Logger.Debug().Any("responseId", responseId).Msg("Added to ack map") +func (s *SessionHandler) queueMessageAck(messageID string) { s.ackMapLock.Lock() defer s.ackMapLock.Unlock() - hasResponseId := slices.Contains(s.ackMap, responseId) - if !hasResponseId { - s.ackMap = append(s.ackMap, responseId) + if !slices.Contains(s.ackMap, messageID) { + s.ackMap = append(s.ackMap, messageID) + s.client.Logger.Trace().Any("message_id", messageID).Msg("Queued ack for message") + } else { + s.client.Logger.Trace().Any("message_id", messageID).Msg("Ack for message was already queued") } } diff --git a/libgm/useralert_handler.go b/libgm/useralert_handler.go index 8e333f1..e29d129 100644 --- a/libgm/useralert_handler.go +++ b/libgm/useralert_handler.go @@ -26,9 +26,9 @@ func (c *Client) handleUserAlertEvent(res *pblite.Response, data *binary.UserAle alertType := data.AlertType switch alertType { case binary.AlertType_BROWSER_ACTIVE: - newSessionId := res.Data.RequestId + newSessionId := res.Data.RequestID c.Logger.Info().Any("sessionId", newSessionId).Msg("[NEW_BROWSER_ACTIVE] Opened new browser connection") - if newSessionId != c.sessionHandler.sessionId { + if newSessionId != c.sessionHandler.sessionID { evt := events.NewBrowserActive(newSessionId) c.triggerEvent(evt) } else {