From 9a45e6a534fe2de645f819aafe281f2c30ba4ff9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 8 Aug 2023 18:45:48 +0300 Subject: [PATCH] Refactor message status handling and bridge read receipts --- ROADMAP.md | 3 +- backfill.go | 5 +- database/message.go | 5 ++ messagetracking.go | 8 ++-- portal.go | 113 ++++++++++++++++++++++++++++++++------------ 5 files changed, 97 insertions(+), 37 deletions(-) diff --git a/ROADMAP.md b/ROADMAP.md index 5544669..0c446eb 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -15,7 +15,8 @@ * [x] Replies (RCS) * [x] Reactions (RCS) * [ ] Typing notifications (RCS) - * [ ] Read receipts (RCS) + * [x] Read receipts in 1:1 chats (RCS) + * [ ] Read receipts in groups (RCS) * [x] Message deletions (own device only) * Misc * [x] Automatic portal creation diff --git a/backfill.go b/backfill.go index 86dff9a..0bdc734 100644 --- a/backfill.go +++ b/backfill.go @@ -135,8 +135,9 @@ func (portal *Portal) forwardBackfill(ctx context.Context, user *User, after tim for i := len(resp.Messages) - 1; i >= 0; i-- { evt := resp.Messages[i] // TODO this should check the database too - if evtID := portal.isOutgoingMessage(evt); evtID != "" { - log.Debug().Str("event_id", evtID.String()).Msg("Got echo for outgoing message in backfill batch") + if dbMsg := portal.isOutgoingMessage(evt); dbMsg != nil { + log.Debug().Str("event_id", dbMsg.MXID.String()).Msg("Got echo for outgoing message in backfill batch") + portal.handleExistingMessageUpdate(ctx, user, dbMsg, evt) continue } else if !time.UnixMicro(evt.Timestamp).After(after) { continue diff --git a/database/message.go b/database/message.go index 311c11f..faa14d7 100644 --- a/database/message.go +++ b/database/message.go @@ -74,6 +74,11 @@ func (mq *MessageQuery) GetLastInChat(ctx context.Context, chat Key) (*Message, type MessageStatus struct { Type gmproto.MessageStatusType + + MSSSent bool + MSSFailSent bool + MSSDeliverySent bool + ReadReceiptSent bool } type Message struct { diff --git a/messagetracking.go b/messagetracking.go index dd6a512..06f960b 100644 --- a/messagetracking.go +++ b/messagetracking.go @@ -115,7 +115,7 @@ func (portal *Portal) sendErrorMessage(evt *event.Event, err error, msgType stri return resp.EventID } -func (portal *Portal) sendStatusEvent(evtID, lastRetry id.EventID, err error) { +func (portal *Portal) sendStatusEvent(evtID, lastRetry id.EventID, err error, deliveredTo *[]id.UserID) { if !portal.bridge.Config.Bridge.MessageStatusEvents { return } @@ -134,6 +134,8 @@ func (portal *Portal) sendStatusEvent(evtID, lastRetry id.EventID, err error) { EventID: evtID, }, LastRetry: lastRetry, + + DeliveredToUsers: deliveredTo, } if err == nil { content.Status = event.MessageStatusSuccess @@ -189,7 +191,7 @@ func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part strin if sendNotice { ms.setNoticeID(portal.sendErrorMessage(evt, err, msgType, isCertain, ms.getNoticeID())) } - portal.sendStatusEvent(origEvtID, evt.ID, err) + portal.sendStatusEvent(origEvtID, evt.ID, err, nil) } else { portal.zlog.Debug(). Str("event_id", evt.ID.String()). @@ -198,7 +200,7 @@ func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part strin portal.sendDeliveryReceipt(evt.ID) portal.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepRemote, ms.getRetryNum()) if msgType != "message" { - portal.sendStatusEvent(origEvtID, evt.ID, nil) + portal.sendStatusEvent(origEvtID, evt.ID, nil, nil) } if prevNotice := ms.popNoticeID(); prevNotice != "" { _, _ = portal.MainIntent().RedactEvent(portal.MXID, prevNotice, mautrix.ReqRedact{ diff --git a/portal.go b/portal.go index 9089daa..acfd647 100644 --- a/portal.go +++ b/portal.go @@ -313,39 +313,19 @@ func (portal *Portal) handleMessageLoop() { } } -func (portal *Portal) isOutgoingMessage(msg *gmproto.Message) id.EventID { +func (portal *Portal) isOutgoingMessage(msg *gmproto.Message) *database.Message { portal.outgoingMessagesLock.Lock() defer portal.outgoingMessagesLock.Unlock() out, ok := portal.outgoingMessages[msg.TmpID] if ok { - if !out.Saved { - portal.markHandled(&ConvertedMessage{ - ID: msg.MessageID, - Timestamp: time.UnixMicro(msg.GetTimestamp()), - SenderID: msg.ParticipantID, - }, out.ID, true) - out.Saved = true - } - switch msg.GetMessageStatus().GetStatus() { - case gmproto.MessageStatusType_OUTGOING_DELIVERED, gmproto.MessageStatusType_OUTGOING_COMPLETE, gmproto.MessageStatusType_OUTGOING_DISPLAYED: - delete(portal.outgoingMessages, msg.TmpID) - go portal.sendStatusEvent(out.ID, "", nil) - case gmproto.MessageStatusType_OUTGOING_FAILED_GENERIC, - gmproto.MessageStatusType_OUTGOING_FAILED_EMERGENCY_NUMBER, - gmproto.MessageStatusType_OUTGOING_CANCELED, - gmproto.MessageStatusType_OUTGOING_FAILED_TOO_LARGE, - gmproto.MessageStatusType_OUTGOING_FAILED_RECIPIENT_LOST_RCS, - gmproto.MessageStatusType_OUTGOING_FAILED_NO_RETRY_NO_FALLBACK, - gmproto.MessageStatusType_OUTGOING_FAILED_RECIPIENT_DID_NOT_DECRYPT, - gmproto.MessageStatusType_OUTGOING_FAILED_RECIPIENT_LOST_ENCRYPTION, - gmproto.MessageStatusType_OUTGOING_FAILED_RECIPIENT_DID_NOT_DECRYPT_NO_MORE_RETRY: - err := OutgoingStatusError(msg.GetMessageStatus().GetStatus()) - go portal.sendStatusEvent(out.ID, "", err) - // TODO error notice - } - return out.ID + delete(portal.outgoingMessages, msg.TmpID) + return portal.markHandled(&ConvertedMessage{ + ID: msg.MessageID, + Timestamp: time.UnixMicro(msg.GetTimestamp()), + SenderID: msg.ParticipantID, + }, out.ID, true) } - return "" + return nil } func hasInProgressMedia(msg *gmproto.Message) bool { for _, part := range msg.MessageInfo { @@ -357,6 +337,76 @@ func hasInProgressMedia(msg *gmproto.Message) bool { return false } +func isSuccessfullySentStatus(status gmproto.MessageStatusType) bool { + switch status { + case gmproto.MessageStatusType_OUTGOING_DELIVERED, gmproto.MessageStatusType_OUTGOING_COMPLETE, gmproto.MessageStatusType_OUTGOING_DISPLAYED: + return true + default: + return false + } +} + +func isFailSendStatus(status gmproto.MessageStatusType) bool { + switch status { + case gmproto.MessageStatusType_OUTGOING_FAILED_GENERIC, + gmproto.MessageStatusType_OUTGOING_FAILED_EMERGENCY_NUMBER, + gmproto.MessageStatusType_OUTGOING_CANCELED, + gmproto.MessageStatusType_OUTGOING_FAILED_TOO_LARGE, + gmproto.MessageStatusType_OUTGOING_FAILED_RECIPIENT_LOST_RCS, + gmproto.MessageStatusType_OUTGOING_FAILED_NO_RETRY_NO_FALLBACK, + gmproto.MessageStatusType_OUTGOING_FAILED_RECIPIENT_DID_NOT_DECRYPT, + gmproto.MessageStatusType_OUTGOING_FAILED_RECIPIENT_LOST_ENCRYPTION, + gmproto.MessageStatusType_OUTGOING_FAILED_RECIPIENT_DID_NOT_DECRYPT_NO_MORE_RETRY: + return true + default: + return false + } +} + +func (portal *Portal) handleExistingMessageUpdate(ctx context.Context, source *User, dbMsg *database.Message, evt *gmproto.Message) { + log := zerolog.Ctx(ctx) + portal.syncReactions(ctx, source, dbMsg, evt.Reactions) + newStatus := evt.GetMessageStatus().GetStatus() + if dbMsg.Status.Type != newStatus { + log.Debug().Str("old_status", dbMsg.Status.Type.String()).Msg("Message status changed") + switch { + case !dbMsg.Status.ReadReceiptSent && portal.IsPrivateChat() && newStatus == gmproto.MessageStatusType_OUTGOING_DISPLAYED: + dbMsg.Status.ReadReceiptSent = true + if !dbMsg.Status.MSSDeliverySent { + dbMsg.Status.MSSDeliverySent = true + dbMsg.Status.MSSSent = true + go portal.sendStatusEvent(dbMsg.MXID, "", nil, &[]id.UserID{portal.MainIntent().UserID}) + } + err := portal.MainIntent().MarkRead(portal.MXID, dbMsg.MXID) + if err != nil { + log.Warn().Err(err).Msg("Failed to mark message as read") + } + case !dbMsg.Status.MSSDeliverySent && portal.IsPrivateChat() && newStatus == gmproto.MessageStatusType_OUTGOING_DELIVERED: + dbMsg.Status.MSSDeliverySent = true + dbMsg.Status.MSSSent = true + go portal.sendStatusEvent(dbMsg.MXID, "", nil, &[]id.UserID{portal.MainIntent().UserID}) + case !dbMsg.Status.MSSSent && isSuccessfullySentStatus(newStatus): + dbMsg.Status.MSSSent = true + var deliveredTo *[]id.UserID + // TODO SMSes can enable delivery receipts too, but can it be detected? + if portal.IsPrivateChat() && portal.Type == gmproto.ConversationType_RCS { + deliveredTo = &[]id.UserID{} + } + go portal.sendStatusEvent(dbMsg.MXID, "", nil, deliveredTo) + case !dbMsg.Status.MSSFailSent && !dbMsg.Status.MSSSent && isFailSendStatus(newStatus): + go portal.sendStatusEvent(dbMsg.MXID, "", OutgoingStatusError(newStatus), nil) + // TODO error notice + default: + // TODO do something? + } + dbMsg.Status.Type = newStatus + err := dbMsg.UpdateStatus(ctx) + if err != nil { + log.Warn().Err(err).Msg("Failed to save updated message status to database") + } + } +} + func (portal *Portal) handleMessage(source *User, evt *gmproto.Message) { if len(portal.MXID) == 0 { portal.zlog.Warn().Msg("handleMessage called even though portal.MXID is empty") @@ -385,8 +435,9 @@ func (portal *Portal) handleMessage(source *User, evt *gmproto.Message) { log.Debug().Msg("Not handling incoming message that doesn't have full media yet") return } - if evtID := portal.isOutgoingMessage(evt); evtID != "" { - log.Debug().Str("event_id", evtID.String()).Msg("Got echo for outgoing message") + if existingMsg := portal.isOutgoingMessage(evt); existingMsg != nil { + log.Debug().Str("event_id", existingMsg.MXID.String()).Msg("Got echo for outgoing message") + portal.handleExistingMessageUpdate(ctx, source, existingMsg, evt) return } existingMsg, err := portal.bridge.DB.Message.GetByID(ctx, portal.Key, evt.MessageID) @@ -394,7 +445,7 @@ func (portal *Portal) handleMessage(source *User, evt *gmproto.Message) { log.Err(err).Msg("Failed to check if message is duplicate") } else if existingMsg != nil { log.Debug().Msg("Not handling duplicate message") - portal.syncReactions(ctx, source, existingMsg, evt.Reactions) + portal.handleExistingMessageUpdate(ctx, source, existingMsg, evt) return }