From a6b91da574a5a0518385a5c7a246827cd189b6f8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 18 Jul 2023 02:11:43 +0300 Subject: [PATCH] Use generics for casting responses --- libgm/conversations.go | 118 +++------------------------------------ libgm/messages.go | 27 +-------- libgm/session.go | 16 +----- libgm/session_handler.go | 13 +++++ 4 files changed, 25 insertions(+), 149 deletions(-) diff --git a/libgm/conversations.go b/libgm/conversations.go index cf908a2..1996274 100644 --- a/libgm/conversations.go +++ b/libgm/conversations.go @@ -1,8 +1,6 @@ package libgm import ( - "fmt" - "go.mau.fi/mautrix-gmessages/libgm/gmproto" ) @@ -15,17 +13,7 @@ func (c *Client) ListConversations(count int64, folder gmproto.ListConversations //} else { actionType := gmproto.ActionType_LIST_CONVERSATIONS - response, err := c.sessionHandler.sendMessage(actionType, payload) - if err != nil { - return nil, err - } - - res, ok := response.DecryptedMessage.(*gmproto.Conversations) - if !ok { - return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.Conversations", response.DecryptedMessage) - } - - return res, nil + return typedResponse[*gmproto.Conversations](c.sessionHandler.sendMessage(actionType, payload)) } func (c *Client) ListContacts() (*gmproto.ListContactsResponse, error) { @@ -35,18 +23,7 @@ func (c *Client) ListContacts() (*gmproto.ListContactsResponse, error) { I3: 50, } actionType := gmproto.ActionType_LIST_CONTACTS - - response, err := c.sessionHandler.sendMessage(actionType, payload) - if err != nil { - return nil, err - } - - res, ok := response.DecryptedMessage.(*gmproto.ListContactsResponse) - if !ok { - return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.ListContactsResponse", response.DecryptedMessage) - } - - return res, nil + return typedResponse[*gmproto.ListContactsResponse](c.sessionHandler.sendMessage(actionType, payload)) } func (c *Client) ListTopContacts() (*gmproto.ListTopContactsResponse, error) { @@ -54,51 +31,18 @@ func (c *Client) ListTopContacts() (*gmproto.ListTopContactsResponse, error) { Count: 8, } actionType := gmproto.ActionType_LIST_TOP_CONTACTS - - response, err := c.sessionHandler.sendMessage(actionType, payload) - if err != nil { - return nil, err - } - - res, ok := response.DecryptedMessage.(*gmproto.ListTopContactsResponse) - if !ok { - return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.ListTopContactsResponse", response.DecryptedMessage) - } - - return res, nil + return typedResponse[*gmproto.ListTopContactsResponse](c.sessionHandler.sendMessage(actionType, payload)) } func (c *Client) GetOrCreateConversation(req *gmproto.GetOrCreateConversationPayload) (*gmproto.GetOrCreateConversationResponse, error) { actionType := gmproto.ActionType_GET_OR_CREATE_CONVERSATION - - response, err := c.sessionHandler.sendMessage(actionType, req) - if err != nil { - return nil, err - } - - res, ok := response.DecryptedMessage.(*gmproto.GetOrCreateConversationResponse) - if !ok { - return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.GetOrCreateConversationResponse", response.DecryptedMessage) - } - - return res, nil + return typedResponse[*gmproto.GetOrCreateConversationResponse](c.sessionHandler.sendMessage(actionType, req)) } func (c *Client) GetConversationType(conversationID string) (*gmproto.GetConversationTypeResponse, error) { payload := &gmproto.ConversationTypePayload{ConversationID: conversationID} actionType := gmproto.ActionType_GET_CONVERSATION_TYPE - - response, err := c.sessionHandler.sendMessage(actionType, payload) - if err != nil { - return nil, err - } - - res, ok := response.DecryptedMessage.(*gmproto.GetConversationTypeResponse) - if !ok { - return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.GetConversationTypeResponse", response.DecryptedMessage) - } - - return res, nil + return typedResponse[*gmproto.GetConversationTypeResponse](c.sessionHandler.sendMessage(actionType, payload)) } func (c *Client) FetchMessages(conversationID string, count int64, cursor *gmproto.Cursor) (*gmproto.FetchMessagesResponse, error) { @@ -106,53 +50,19 @@ func (c *Client) FetchMessages(conversationID string, count int64, cursor *gmpro if cursor != nil { payload.Cursor = cursor } - actionType := gmproto.ActionType_LIST_MESSAGES - - response, err := c.sessionHandler.sendMessage(actionType, payload) - if err != nil { - return nil, err - } - - res, ok := response.DecryptedMessage.(*gmproto.FetchMessagesResponse) - if !ok { - return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.FetchMessagesResponse", response.DecryptedMessage) - } - - return res, nil + return typedResponse[*gmproto.FetchMessagesResponse](c.sessionHandler.sendMessage(actionType, payload)) } func (c *Client) SendMessage(payload *gmproto.SendMessagePayload) (*gmproto.SendMessageResponse, error) { actionType := gmproto.ActionType_SEND_MESSAGE - - response, err := c.sessionHandler.sendMessage(actionType, payload) - if err != nil { - return nil, err - } - - res, ok := response.DecryptedMessage.(*gmproto.SendMessageResponse) - if !ok { - return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.SendMessageResponse", response.DecryptedMessage) - } - - return res, nil + return typedResponse[*gmproto.SendMessageResponse](c.sessionHandler.sendMessage(actionType, payload)) } func (c *Client) GetParticipantThumbnail(convID string) (*gmproto.ParticipantThumbnail, error) { payload := &gmproto.GetParticipantThumbnailPayload{ConversationID: convID} actionType := gmproto.ActionType_GET_PARTICIPANTS_THUMBNAIL - - response, err := c.sessionHandler.sendMessage(actionType, payload) - if err != nil { - return nil, err - } - - res, ok := response.DecryptedMessage.(*gmproto.ParticipantThumbnail) - if !ok { - return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.ParticipantThumbnail", response.DecryptedMessage) - } - - return res, nil + return typedResponse[*gmproto.ParticipantThumbnail](c.sessionHandler.sendMessage(actionType, payload)) } func (c *Client) UpdateConversation(convBuilder *ConversationBuilder) (*gmproto.UpdateConversationResponse, error) { @@ -165,17 +75,7 @@ func (c *Client) UpdateConversation(convBuilder *ConversationBuilder) (*gmproto. actionType := gmproto.ActionType_UPDATE_CONVERSATION - response, err := c.sessionHandler.sendMessage(actionType, payload) - if err != nil { - return nil, err - } - - res, ok := response.DecryptedMessage.(*gmproto.UpdateConversationResponse) - if !ok { - return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.UpdateConversationResponse", response.DecryptedMessage) - } - - return res, nil + return typedResponse[*gmproto.UpdateConversationResponse](c.sessionHandler.sendMessage(actionType, payload)) } func (c *Client) SetTyping(convID string) error { diff --git a/libgm/messages.go b/libgm/messages.go index fc2c498..719bc9a 100644 --- a/libgm/messages.go +++ b/libgm/messages.go @@ -1,42 +1,19 @@ package libgm import ( - "fmt" - "go.mau.fi/mautrix-gmessages/libgm/gmproto" ) func (c *Client) SendReaction(payload *gmproto.SendReactionPayload) (*gmproto.SendReactionResponse, error) { actionType := gmproto.ActionType_SEND_REACTION - - response, err := c.sessionHandler.sendMessage(actionType, payload) - if err != nil { - return nil, err - } - - res, ok := response.DecryptedMessage.(*gmproto.SendReactionResponse) - if !ok { - return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.SendReactionResponse", response.DecryptedMessage) - } - - return res, nil + return typedResponse[*gmproto.SendReactionResponse](c.sessionHandler.sendMessage(actionType, payload)) } func (c *Client) DeleteMessage(messageID string) (*gmproto.DeleteMessageResponse, error) { payload := &gmproto.DeleteMessagePayload{MessageID: messageID} actionType := gmproto.ActionType_DELETE_MESSAGE - response, err := c.sessionHandler.sendMessage(actionType, payload) - if err != nil { - return nil, err - } - - res, ok := response.DecryptedMessage.(*gmproto.DeleteMessageResponse) - if !ok { - return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.DeleteMessagesResponse", response.DecryptedMessage) - } - - return res, nil + return typedResponse[*gmproto.DeleteMessageResponse](c.sessionHandler.sendMessage(actionType, payload)) } func (c *Client) MarkRead(conversationID, messageID string) error { diff --git a/libgm/session.go b/libgm/session.go index 6d068f5..5ae9d3e 100644 --- a/libgm/session.go +++ b/libgm/session.go @@ -1,8 +1,6 @@ package libgm import ( - "fmt" - "go.mau.fi/mautrix-gmessages/libgm/gmproto" ) @@ -14,20 +12,8 @@ func (c *Client) SetActiveSession() error { func (c *Client) IsBugleDefault() (*gmproto.IsBugleDefaultResponse, error) { c.sessionHandler.ResetSessionID() - actionType := gmproto.ActionType_IS_BUGLE_DEFAULT - - response, err := c.sessionHandler.sendMessage(actionType, nil) - if err != nil { - return nil, err - } - - res, ok := response.DecryptedMessage.(*gmproto.IsBugleDefaultResponse) - if !ok { - return nil, fmt.Errorf("unexpected response type %T, expected *gmproto.IsBugleDefaultResponse", response.DecryptedMessage) - } - - return res, nil + return typedResponse[*gmproto.IsBugleDefaultResponse](c.sessionHandler.sendMessage(actionType, nil)) } func (c *Client) NotifyDittoActivity() error { diff --git a/libgm/session_handler.go b/libgm/session_handler.go index d9a743b..28d7c4f 100644 --- a/libgm/session_handler.go +++ b/libgm/session_handler.go @@ -61,6 +61,19 @@ func (s *SessionHandler) sendAsyncMessage(actionType gmproto.ActionType, encrypt return ch, nil } +func typedResponse[T proto.Message](resp *IncomingRPCMessage, err error) (casted T, retErr error) { + if err != nil { + retErr = err + return + } + var ok bool + casted, ok = resp.DecryptedMessage.(T) + if !ok { + retErr = fmt.Errorf("unexpected response type %T, expected %T", resp.DecryptedMessage, casted) + } + return +} + func (s *SessionHandler) sendMessage(actionType gmproto.ActionType, encryptedData proto.Message) (*IncomingRPCMessage, error) { ch, err := s.sendAsyncMessage(actionType, encryptedData) if err != nil {