Unnest long polling handler

This commit is contained in:
Tulir Asokan 2023-07-19 14:12:23 +03:00
parent d99da61869
commit 01464c5cc2
5 changed files with 94 additions and 114 deletions

View file

@ -43,10 +43,17 @@ type EventHandler func(evt any)
type Client struct { type Client struct {
Logger zerolog.Logger Logger zerolog.Logger
rpc *RPC
evHandler EventHandler evHandler EventHandler
sessionHandler *SessionHandler sessionHandler *SessionHandler
longPollingConn io.Closer
listenID int
skipCount int
disconnecting bool
recentUpdates [8][32]byte
recentUpdatesPtr int
conversationsFetchedOnce bool conversationsFetchedOnce bool
AuthData *AuthData AuthData *AuthData
@ -74,7 +81,6 @@ func NewClient(authData *AuthData, logger zerolog.Logger) *Client {
http: &http.Client{}, http: &http.Client{},
} }
sessionHandler.client = cli sessionHandler.client = cli
cli.rpc = &RPC{client: cli}
cli.FetchConfigVersion() cli.FetchConfigVersion()
return cli return cli
} }
@ -114,7 +120,7 @@ func (c *Client) Connect() error {
return fmt.Errorf("failed to get web encryption key: %w", err) return fmt.Errorf("failed to get web encryption key: %w", err)
} }
c.updateWebEncryptionKey(webEncryptionKeyResponse.GetKey()) c.updateWebEncryptionKey(webEncryptionKeyResponse.GetKey())
go c.rpc.ListenReceiveMessages(true) go c.ListenReceiveMessages(true)
c.sessionHandler.startAckInterval() c.sessionHandler.startAckInterval()
bugleRes, bugleErr := c.IsBugleDefault() bugleRes, bugleErr := c.IsBugleDefault()
@ -130,12 +136,13 @@ func (c *Client) Connect() error {
} }
func (c *Client) Disconnect() { func (c *Client) Disconnect() {
c.rpc.CloseConnection() c.closeLongPolling()
c.http.CloseIdleConnections() c.http.CloseIdleConnections()
} }
func (c *Client) IsConnected() bool { func (c *Client) IsConnected() bool {
return c.rpc != nil // TODO add better check (longPollingConn is set to nil while the polling reconnects)
return c.longPollingConn != nil
} }
func (c *Client) IsLoggedIn() bool { func (c *Client) IsLoggedIn() bool {
@ -143,10 +150,7 @@ func (c *Client) IsLoggedIn() bool {
} }
func (c *Client) Reconnect() error { func (c *Client) Reconnect() error {
c.rpc.CloseConnection() c.closeLongPolling()
for c.rpc.conn != nil {
time.Sleep(time.Millisecond * 100)
}
err := c.Connect() err := c.Connect()
if err != nil { if err != nil {
c.Logger.Err(err).Msg("Failed to reconnect") c.Logger.Err(err).Msg("Failed to reconnect")

View file

@ -38,7 +38,7 @@ var responseType = map[gmproto.ActionType]proto.Message{
gmproto.ActionType_UPDATE_CONVERSATION: &gmproto.UpdateConversationResponse{}, gmproto.ActionType_UPDATE_CONVERSATION: &gmproto.UpdateConversationResponse{},
} }
func (r *RPC) decryptInternalMessage(data *gmproto.IncomingRPCMessage) (*IncomingRPCMessage, error) { func (c *Client) decryptInternalMessage(data *gmproto.IncomingRPCMessage) (*IncomingRPCMessage, error) {
msg := &IncomingRPCMessage{ msg := &IncomingRPCMessage{
IncomingRPCMessage: data, IncomingRPCMessage: data,
} }
@ -60,7 +60,7 @@ func (r *RPC) decryptInternalMessage(data *gmproto.IncomingRPCMessage) (*Incomin
msg.DecryptedMessage = responseStruct.ProtoReflect().New().Interface() msg.DecryptedMessage = responseStruct.ProtoReflect().New().Interface()
} }
if msg.Message.EncryptedData != nil { if msg.Message.EncryptedData != nil {
msg.DecryptedData, err = r.client.AuthData.RequestCrypto.Decrypt(msg.Message.EncryptedData) msg.DecryptedData, err = c.AuthData.RequestCrypto.Decrypt(msg.Message.EncryptedData)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -77,21 +77,21 @@ func (r *RPC) decryptInternalMessage(data *gmproto.IncomingRPCMessage) (*Incomin
return msg, nil return msg, nil
} }
func (r *RPC) deduplicateHash(hash [32]byte) bool { func (c *Client) deduplicateHash(hash [32]byte) bool {
const recentUpdatesLen = len(r.recentUpdates) const recentUpdatesLen = len(c.recentUpdates)
for i := r.recentUpdatesPtr + recentUpdatesLen - 1; i >= r.recentUpdatesPtr; i-- { for i := c.recentUpdatesPtr + recentUpdatesLen - 1; i >= c.recentUpdatesPtr; i-- {
if r.recentUpdates[i%recentUpdatesLen] == hash { if c.recentUpdates[i%recentUpdatesLen] == hash {
return true return true
} }
} }
r.recentUpdates[r.recentUpdatesPtr] = hash c.recentUpdates[c.recentUpdatesPtr] = hash
r.recentUpdatesPtr = (r.recentUpdatesPtr + 1) % recentUpdatesLen c.recentUpdatesPtr = (c.recentUpdatesPtr + 1) % recentUpdatesLen
return false return false
} }
func (r *RPC) logContent(res *IncomingRPCMessage) { func (c *Client) logContent(res *IncomingRPCMessage) {
if r.client.Logger.Trace().Enabled() && (res.DecryptedData != nil || res.DecryptedMessage != nil) { if c.Logger.Trace().Enabled() && (res.DecryptedData != nil || res.DecryptedMessage != nil) {
evt := r.client.Logger.Trace() evt := c.Logger.Trace()
if res.DecryptedMessage != nil { if res.DecryptedMessage != nil {
evt.Str("proto_name", string(res.DecryptedMessage.ProtoReflect().Descriptor().FullName())) evt.Str("proto_name", string(res.DecryptedMessage.ProtoReflect().Descriptor().FullName()))
} }
@ -104,47 +104,47 @@ func (r *RPC) logContent(res *IncomingRPCMessage) {
} }
} }
func (r *RPC) deduplicateUpdate(msg *IncomingRPCMessage) bool { func (c *Client) deduplicateUpdate(msg *IncomingRPCMessage) bool {
if msg.DecryptedData != nil { if msg.DecryptedData != nil {
contentHash := sha256.Sum256(msg.DecryptedData) contentHash := sha256.Sum256(msg.DecryptedData)
if r.deduplicateHash(contentHash) { if c.deduplicateHash(contentHash) {
r.client.Logger.Trace().Hex("data_hash", contentHash[:]).Msg("Ignoring duplicate update") c.Logger.Trace().Hex("data_hash", contentHash[:]).Msg("Ignoring duplicate update")
return true return true
} }
r.logContent(msg) c.logContent(msg)
} }
return false return false
} }
func (r *RPC) HandleRPCMsg(rawMsg *gmproto.IncomingRPCMessage) { func (c *Client) HandleRPCMsg(rawMsg *gmproto.IncomingRPCMessage) {
msg, err := r.decryptInternalMessage(rawMsg) msg, err := c.decryptInternalMessage(rawMsg)
if err != nil { if err != nil {
r.client.Logger.Err(err).Msg("Failed to decode incoming RPC message") c.Logger.Err(err).Msg("Failed to decode incoming RPC message")
return return
} }
r.client.sessionHandler.queueMessageAck(msg.ResponseID) c.sessionHandler.queueMessageAck(msg.ResponseID)
if r.client.sessionHandler.receiveResponse(msg) { if c.sessionHandler.receiveResponse(msg) {
return return
} }
switch msg.BugleRoute { switch msg.BugleRoute {
case gmproto.BugleRoute_PairEvent: case gmproto.BugleRoute_PairEvent:
go r.client.handlePairingEvent(msg) go c.handlePairingEvent(msg)
case gmproto.BugleRoute_DataEvent: case gmproto.BugleRoute_DataEvent:
if r.skipCount > 0 { if c.skipCount > 0 {
r.skipCount-- c.skipCount--
r.client.Logger.Debug(). c.Logger.Debug().
Any("action", msg.Message.GetAction()). Any("action", msg.Message.GetAction()).
Int("remaining_skip_count", r.skipCount). Int("remaining_skip_count", c.skipCount).
Msg("Skipped DataEvent") Msg("Skipped DataEvent")
if msg.DecryptedMessage != nil { if msg.DecryptedMessage != nil {
r.client.Logger.Trace(). c.Logger.Trace().
Str("proto_name", string(msg.DecryptedMessage.ProtoReflect().Descriptor().FullName())). Str("proto_name", string(msg.DecryptedMessage.ProtoReflect().Descriptor().FullName())).
Str("data", base64.StdEncoding.EncodeToString(msg.DecryptedData)). Str("data", base64.StdEncoding.EncodeToString(msg.DecryptedData)).
Msg("Skipped event data") Msg("Skipped event data")
} }
return return
} }
r.client.handleUpdatesEvent(msg) c.handleUpdatesEvent(msg)
} }
} }

View file

@ -19,7 +19,7 @@ func (c *Client) StartLogin() (string, error) {
return "", err return "", err
} }
c.AuthData.TachyonAuthToken = registered.AuthKeyData.TachyonAuthToken c.AuthData.TachyonAuthToken = registered.AuthKeyData.TachyonAuthToken
go c.rpc.ListenReceiveMessages(false) go c.ListenReceiveMessages(false)
qr, err := c.GenerateQRCodeData(registered.GetPairingKey()) qr, err := c.GenerateQRCodeData(registered.GetPairingKey())
if err != nil { if err != nil {
return "", fmt.Errorf("failed to generate QR code: %w", err) return "", fmt.Errorf("failed to generate QR code: %w", err)

View file

@ -20,113 +20,89 @@ import (
"go.mau.fi/mautrix-gmessages/libgm/util" "go.mau.fi/mautrix-gmessages/libgm/util"
) )
type RPC struct { func (c *Client) ListenReceiveMessages(loggedIn bool) {
client *Client c.listenID++
conn io.ReadCloser listenID := c.listenID
stopping bool
listenID int
skipCount int
recentUpdates [8][32]byte
recentUpdatesPtr int
}
func (r *RPC) ListenReceiveMessages(loggedIn bool) {
r.listenID++
listenID := r.listenID
errored := true errored := true
listenReqID := uuid.NewString() listenReqID := uuid.NewString()
for r.listenID == listenID { for c.listenID == listenID {
err := r.client.refreshAuthToken() err := c.refreshAuthToken()
if err != nil { if err != nil {
r.client.Logger.Err(err).Msg("Error refreshing auth token") c.Logger.Err(err).Msg("Error refreshing auth token")
if loggedIn { if loggedIn {
r.client.triggerEvent(&events.ListenFatalError{Error: fmt.Errorf("failed to refresh auth token: %w", err)}) c.triggerEvent(&events.ListenFatalError{Error: fmt.Errorf("failed to refresh auth token: %w", err)})
} }
return return
} }
r.client.Logger.Debug().Msg("Starting new long-polling request") c.Logger.Debug().Msg("Starting new long-polling request")
payload := &gmproto.ReceiveMessagesRequest{ payload := &gmproto.ReceiveMessagesRequest{
Auth: &gmproto.AuthMessage{ Auth: &gmproto.AuthMessage{
RequestID: listenReqID, RequestID: listenReqID,
TachyonAuthToken: r.client.AuthData.TachyonAuthToken, TachyonAuthToken: c.AuthData.TachyonAuthToken,
ConfigVersion: util.ConfigMessage, ConfigVersion: util.ConfigMessage,
}, },
Unknown: &gmproto.ReceiveMessagesRequest_UnknownEmptyObject2{ Unknown: &gmproto.ReceiveMessagesRequest_UnknownEmptyObject2{
Unknown: &gmproto.ReceiveMessagesRequest_UnknownEmptyObject1{}, Unknown: &gmproto.ReceiveMessagesRequest_UnknownEmptyObject1{},
}, },
} }
resp, err := r.client.makeProtobufHTTPRequest(util.ReceiveMessagesURL, payload, ContentTypePBLite) resp, err := c.makeProtobufHTTPRequest(util.ReceiveMessagesURL, payload, ContentTypePBLite)
if err != nil { if err != nil {
if loggedIn { if loggedIn {
r.client.triggerEvent(&events.ListenTemporaryError{Error: err}) c.triggerEvent(&events.ListenTemporaryError{Error: err})
} }
errored = true errored = true
r.client.Logger.Err(err).Msg("Error making listen request, retrying in 5 seconds") c.Logger.Err(err).Msg("Error making listen request, retrying in 5 seconds")
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
continue continue
} }
if resp.StatusCode >= 400 && resp.StatusCode < 500 { if resp.StatusCode >= 400 && resp.StatusCode < 500 {
r.client.Logger.Error().Int("status_code", resp.StatusCode).Msg("Error making listen request") c.Logger.Error().Int("status_code", resp.StatusCode).Msg("Error making listen request")
if loggedIn { if loggedIn {
r.client.triggerEvent(&events.ListenFatalError{Error: events.HTTPError{Action: "polling", Resp: resp}}) c.triggerEvent(&events.ListenFatalError{Error: events.HTTPError{Action: "polling", Resp: resp}})
} }
return return
} else if resp.StatusCode >= 500 { } else if resp.StatusCode >= 500 {
if loggedIn { if loggedIn {
r.client.triggerEvent(&events.ListenTemporaryError{Error: events.HTTPError{Action: "polling", Resp: resp}}) c.triggerEvent(&events.ListenTemporaryError{Error: events.HTTPError{Action: "polling", Resp: resp}})
} }
errored = true errored = true
r.client.Logger.Debug().Int("statusCode", resp.StatusCode).Msg("5xx error in long polling, retrying in 5 seconds") c.Logger.Debug().Int("statusCode", resp.StatusCode).Msg("5xx error in long polling, retrying in 5 seconds")
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
continue continue
} }
if errored { if errored {
errored = false errored = false
if loggedIn { if loggedIn {
r.client.triggerEvent(&events.ListenRecovered{}) c.triggerEvent(&events.ListenRecovered{})
} }
} }
r.client.Logger.Debug().Int("statusCode", resp.StatusCode).Msg("Long polling opened") c.Logger.Debug().Int("statusCode", resp.StatusCode).Msg("Long polling opened")
r.conn = resp.Body c.longPollingConn = resp.Body
if r.client.AuthData.Browser != nil { if c.AuthData.Browser != nil {
go func() { go func() {
err := r.client.NotifyDittoActivity() err := c.NotifyDittoActivity()
if err != nil { if err != nil {
r.client.Logger.Err(err).Msg("Error notifying ditto activity") c.Logger.Err(err).Msg("Error notifying ditto activity")
} }
}() }()
} }
r.startReadingData(resp.Body) c.startReadingData(resp.Body)
r.conn = nil c.longPollingConn = nil
} }
} }
/* func (c *Client) startReadingData(rc io.ReadCloser) {
The start of a message always begins with byte 44 (",")
If the message is parsable (after , has been removed) as an array of interfaces:
func (r *RPC) tryUnmarshalJSON(jsonData []byte, msgArr *[]interface{}) error {
err := json.Unmarshal(jsonData, &msgArr)
return err
}
then the message is complete and it should continue to the HandleRPCMsg function and it should also reset the buffer so that the next message can be received properly.
if it's not parsable, it should just append the received data to the buf and attempt to parse it until it's parsable. Because that would indicate that the full msg has been received
*/
func (r *RPC) startReadingData(rc io.ReadCloser) {
r.stopping = false
defer rc.Close() defer rc.Close()
c.disconnecting = false
reader := bufio.NewReader(rc) reader := bufio.NewReader(rc)
buf := make([]byte, 2621440) buf := make([]byte, 2621440)
var accumulatedData []byte var accumulatedData []byte
n, err := reader.Read(buf[:2]) n, err := reader.Read(buf[:2])
if err != nil { if err != nil {
r.client.Logger.Err(err).Msg("Error reading opening bytes") c.Logger.Err(err).Msg("Error reading opening bytes")
return return
} else if n != 2 || string(buf[:2]) != "[[" { } else if n != 2 || string(buf[:2]) != "[[" {
r.client.Logger.Err(err).Msg("Opening is not [[") c.Logger.Err(err).Msg("Opening is not [[")
return return
} }
var expectEOF bool var expectEOF bool
@ -134,20 +110,20 @@ func (r *RPC) startReadingData(rc io.ReadCloser) {
n, err = reader.Read(buf) n, err = reader.Read(buf)
if err != nil { if err != nil {
var logEvt *zerolog.Event var logEvt *zerolog.Event
if (errors.Is(err, io.EOF) && expectEOF) || r.stopping { if (errors.Is(err, io.EOF) && expectEOF) || c.disconnecting {
logEvt = r.client.Logger.Debug() logEvt = c.Logger.Debug()
} else { } else {
logEvt = r.client.Logger.Warn() logEvt = c.Logger.Warn()
} }
logEvt.Err(err).Msg("Stopped reading data from server") logEvt.Err(err).Msg("Stopped reading data from server")
return return
} else if expectEOF { } else if expectEOF {
r.client.Logger.Warn().Msg("Didn't get EOF after stream end marker") c.Logger.Warn().Msg("Didn't get EOF after stream end marker")
} }
chunk := buf[:n] chunk := buf[:n]
if len(accumulatedData) == 0 { if len(accumulatedData) == 0 {
if len(chunk) == 2 && string(chunk) == "]]" { if len(chunk) == 2 && string(chunk) == "]]" {
r.client.Logger.Debug().Msg("Got stream end marker") c.Logger.Debug().Msg("Got stream end marker")
expectEOF = true expectEOF = true
continue continue
} }
@ -155,7 +131,7 @@ func (r *RPC) startReadingData(rc io.ReadCloser) {
} }
accumulatedData = append(accumulatedData, chunk...) accumulatedData = append(accumulatedData, chunk...)
if !json.Valid(accumulatedData) { if !json.Valid(accumulatedData) {
r.client.Logger.Trace().Bytes("data", chunk).Msg("Invalid JSON, reading next chunk") c.Logger.Trace().Bytes("data", chunk).Msg("Invalid JSON, reading next chunk")
continue continue
} }
currentBlock := accumulatedData currentBlock := accumulatedData
@ -163,33 +139,33 @@ func (r *RPC) startReadingData(rc io.ReadCloser) {
msg := &gmproto.LongPollingPayload{} msg := &gmproto.LongPollingPayload{}
err = pblite.Unmarshal(currentBlock, msg) err = pblite.Unmarshal(currentBlock, msg)
if err != nil { if err != nil {
r.client.Logger.Err(err).Msg("Error deserializing pblite message") c.Logger.Err(err).Msg("Error deserializing pblite message")
continue continue
} }
switch { switch {
case msg.GetData() != nil: case msg.GetData() != nil:
r.HandleRPCMsg(msg.GetData()) c.HandleRPCMsg(msg.GetData())
case msg.GetAck() != nil: case msg.GetAck() != nil:
r.client.Logger.Debug().Int32("count", msg.GetAck().GetCount()).Msg("Got startup ack count message") c.Logger.Debug().Int32("count", msg.GetAck().GetCount()).Msg("Got startup ack count message")
r.skipCount = int(msg.GetAck().GetCount()) c.skipCount = int(msg.GetAck().GetCount())
case msg.GetStartRead() != nil: case msg.GetStartRead() != nil:
r.client.Logger.Trace().Msg("Got startRead message") c.Logger.Trace().Msg("Got startRead message")
case msg.GetHeartbeat() != nil: case msg.GetHeartbeat() != nil:
r.client.Logger.Trace().Msg("Got heartbeat message") c.Logger.Trace().Msg("Got heartbeat message")
default: default:
r.client.Logger.Warn(). c.Logger.Warn().
Str("data", base64.StdEncoding.EncodeToString(currentBlock)). Str("data", base64.StdEncoding.EncodeToString(currentBlock)).
Msg("Got unknown message") Msg("Got unknown message")
} }
} }
} }
func (r *RPC) CloseConnection() { func (c *Client) closeLongPolling() {
if r.conn != nil { if conn := c.longPollingConn; conn != nil {
r.client.Logger.Debug().Msg("Closing connection manually") c.Logger.Debug().Msg("Closing long polling connection manually")
r.listenID++ c.listenID++
r.stopping = true c.disconnecting = true
r.conn.Close() _ = conn.Close()
r.conn = nil c.longPollingConn = nil
} }
} }

View file

@ -16,27 +16,27 @@ func (c *Client) handleUpdatesEvent(msg *IncomingRPCMessage) {
switch evt := data.Event.(type) { switch evt := data.Event.(type) {
case *gmproto.UpdateEvents_UserAlertEvent: case *gmproto.UpdateEvents_UserAlertEvent:
c.rpc.logContent(msg) c.logContent(msg)
c.handleUserAlertEvent(msg, evt.UserAlertEvent) c.handleUserAlertEvent(msg, evt.UserAlertEvent)
case *gmproto.UpdateEvents_SettingsEvent: case *gmproto.UpdateEvents_SettingsEvent:
c.rpc.logContent(msg) c.logContent(msg)
c.triggerEvent(evt.SettingsEvent) c.triggerEvent(evt.SettingsEvent)
case *gmproto.UpdateEvents_ConversationEvent: case *gmproto.UpdateEvents_ConversationEvent:
if c.rpc.deduplicateUpdate(msg) { if c.deduplicateUpdate(msg) {
return return
} }
c.triggerEvent(evt.ConversationEvent.GetData()) c.triggerEvent(evt.ConversationEvent.GetData())
case *gmproto.UpdateEvents_MessageEvent: case *gmproto.UpdateEvents_MessageEvent:
if c.rpc.deduplicateUpdate(msg) { if c.deduplicateUpdate(msg) {
return return
} }
c.triggerEvent(evt.MessageEvent.GetData()) c.triggerEvent(evt.MessageEvent.GetData())
case *gmproto.UpdateEvents_TypingEvent: case *gmproto.UpdateEvents_TypingEvent:
c.rpc.logContent(msg) c.logContent(msg)
c.triggerEvent(evt.TypingEvent.GetData()) c.triggerEvent(evt.TypingEvent.GetData())
default: default: