From 3e2348447ac9985abda59b0bc07de16f419552fc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Jul 2023 01:20:32 +0300 Subject: [PATCH] Merge response waiting methods with sending --- libgm/response_handler.go | 48 --------------------------------------- libgm/session_handler.go | 44 +++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 48 deletions(-) delete mode 100644 libgm/response_handler.go diff --git a/libgm/response_handler.go b/libgm/response_handler.go deleted file mode 100644 index 966de1c..0000000 --- a/libgm/response_handler.go +++ /dev/null @@ -1,48 +0,0 @@ -package libgm - -import ( - "encoding/base64" -) - -func (s *SessionHandler) waitResponse(requestID string) chan *IncomingRPCMessage { - ch := make(chan *IncomingRPCMessage, 1) - s.responseWaitersLock.Lock() - s.responseWaiters[requestID] = ch - s.responseWaitersLock.Unlock() - return ch -} - -func (s *SessionHandler) cancelResponse(requestID string, ch chan *IncomingRPCMessage) { - s.responseWaitersLock.Lock() - close(ch) - delete(s.responseWaiters, requestID) - s.responseWaitersLock.Unlock() -} - -func (s *SessionHandler) receiveResponse(msg *IncomingRPCMessage) bool { - if msg.Message == nil { - return false - } - requestID := msg.Message.SessionID - s.responseWaitersLock.Lock() - ch, ok := s.responseWaiters[requestID] - if !ok { - s.responseWaitersLock.Unlock() - return false - } - delete(s.responseWaiters, requestID) - s.responseWaitersLock.Unlock() - evt := s.client.Logger.Trace(). - Str("request_id", requestID) - if evt.Enabled() { - if msg.DecryptedData != nil { - evt.Str("data", base64.StdEncoding.EncodeToString(msg.DecryptedData)) - } - if msg.DecryptedMessage != nil { - evt.Str("proto_name", string(msg.DecryptedMessage.ProtoReflect().Descriptor().FullName())) - } - } - evt.Msg("Received response") - ch <- msg - return true -} diff --git a/libgm/session_handler.go b/libgm/session_handler.go index fbcae52..8cd2e02 100644 --- a/libgm/session_handler.go +++ b/libgm/session_handler.go @@ -1,6 +1,7 @@ package libgm import ( + "encoding/base64" "fmt" "sync" "time" @@ -72,6 +73,49 @@ func typedResponse[T proto.Message](resp *IncomingRPCMessage, err error) (casted return } +func (s *SessionHandler) waitResponse(requestID string) chan *IncomingRPCMessage { + ch := make(chan *IncomingRPCMessage, 1) + s.responseWaitersLock.Lock() + s.responseWaiters[requestID] = ch + s.responseWaitersLock.Unlock() + return ch +} + +func (s *SessionHandler) cancelResponse(requestID string, ch chan *IncomingRPCMessage) { + s.responseWaitersLock.Lock() + close(ch) + delete(s.responseWaiters, requestID) + s.responseWaitersLock.Unlock() +} + +func (s *SessionHandler) receiveResponse(msg *IncomingRPCMessage) bool { + if msg.Message == nil { + return false + } + requestID := msg.Message.SessionID + s.responseWaitersLock.Lock() + ch, ok := s.responseWaiters[requestID] + if !ok { + s.responseWaitersLock.Unlock() + return false + } + delete(s.responseWaiters, requestID) + s.responseWaitersLock.Unlock() + evt := s.client.Logger.Trace(). + Str("request_id", requestID) + if evt.Enabled() { + if msg.DecryptedData != nil { + evt.Str("data", base64.StdEncoding.EncodeToString(msg.DecryptedData)) + } + if msg.DecryptedMessage != nil { + evt.Str("proto_name", string(msg.DecryptedMessage.ProtoReflect().Descriptor().FullName())) + } + } + evt.Msg("Received response") + ch <- msg + return true +} + func (s *SessionHandler) sendMessageWithParams(params SendMessageParams) (*IncomingRPCMessage, error) { ch, err := s.sendAsyncMessage(params) if err != nil {