Handle message updates properly instead of dropping in-progress messages

This commit is contained in:
Tulir Asokan 2023-08-09 17:53:11 +03:00
parent 3df8296d9f
commit d1ba596504
4 changed files with 258 additions and 95 deletions

View file

@ -62,7 +62,7 @@ func (portal *Portal) missedForwardBackfill(user *User, lastMessageTS time.Time,
Str("latest_message_id", lastMessageID).
Logger()
ctx := log.WithContext(context.TODO())
if !lastMessageTS.IsZero() && time.Since(lastMessageTS) < 5*time.Minute && portal.lastMessageTS.Before(lastMessageTS) {
if portal.hasSyncedThisRun && !lastMessageTS.IsZero() && time.Since(lastMessageTS) < 5*time.Minute && portal.lastMessageTS.Before(lastMessageTS) {
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
prev := portal.pendingRecentBackfill.Swap(&pendingBackfill{cancel: cancel, lastMessageID: lastMessageID, lastMessageTS: lastMessageTS})
@ -80,6 +80,7 @@ func (portal *Portal) missedForwardBackfill(user *User, lastMessageTS time.Time,
portal.forwardBackfillLock.Lock()
defer portal.forwardBackfillLock.Unlock()
portal.hasSyncedThisRun = true
if !lastMessageTS.IsZero() {
if portal.lastMessageTS.IsZero() {
lastMsg, err := portal.bridge.DB.Message.GetLastInChat(ctx, portal.Key)
@ -192,6 +193,9 @@ func (portal *Portal) backfillSendBatch(ctx context.Context, converted []*Conver
dbm.ID = msg.ID
dbm.Sender = msg.SenderID
dbm.Timestamp = msg.Timestamp
dbm.Status.Type = msg.Status
dbm.Status.MediaStatus = msg.MediaStatus
dbm.Status.MediaParts = make(map[string]database.MediaPart, len(msg.Parts))
for i, part := range msg.Parts {
content := event.Content{
@ -217,6 +221,14 @@ func (portal *Portal) backfillSendBatch(ctx context.Context, converted []*Conver
events = append(events, evt)
if dbm.MXID == "" {
dbm.MXID = evt.ID
if part.PendingMedia {
dbm.Status.MediaParts[""] = database.MediaPart{PendingMedia: true}
}
} else {
dbm.Status.MediaParts[part.ID] = database.MediaPart{
EventID: evt.ID,
PendingMedia: part.PendingMedia,
}
}
}
if dbm.MXID != "" {
@ -243,7 +255,11 @@ func (portal *Portal) backfillSendLegacy(ctx context.Context, converted []*Conve
var lastEventID id.EventID
eventIDs := make(map[string]id.EventID)
for _, msg := range converted {
var eventID id.EventID
if len(msg.Parts) == 0 {
continue
}
var msgFirstEventID id.EventID
mediaParts := make(map[string]database.MediaPart, len(msg.Parts)-1)
for i, part := range msg.Parts {
if msg.ReplyTo != "" && part.Content.RelatesTo == nil {
replyToEvent, ok := eventIDs[msg.ReplyTo]
@ -256,16 +272,21 @@ func (portal *Portal) backfillSendLegacy(ctx context.Context, converted []*Conve
resp, err := portal.sendMessage(msg.Intent, event.EventMessage, part.Content, part.Extra, msg.Timestamp.UnixMilli())
if err != nil {
log.Err(err).Str("message_id", msg.ID).Int("part", i).Msg("Failed to send message")
} else if eventID == "" {
eventID = resp.EventID
} else {
if msgFirstEventID == "" {
msgFirstEventID = resp.EventID
eventIDs[msg.ID] = resp.EventID
} else {
mediaParts[part.ID] = database.MediaPart{
EventID: resp.EventID,
PendingMedia: part.PendingMedia,
}
}
if resp != nil {
lastEventID = resp.EventID
}
}
if eventID != "" {
portal.markHandled(msg, eventID, false)
if msgFirstEventID != "" {
portal.markHandled(msg, msgFirstEventID, mediaParts, false)
}
}
return lastEventID

View file

@ -72,13 +72,30 @@ func (mq *MessageQuery) GetLastInChat(ctx context.Context, chat Key) (*Message,
return get[*Message](mq, ctx, getLastMessageInChatQuery, chat.ID, chat.Receiver)
}
type MessageStatus struct {
Type gmproto.MessageStatusType
type MediaPart struct {
EventID id.EventID `json:"mxid,omitempty"`
PendingMedia bool `json:"pending_media,omitempty"`
}
MSSSent bool
MSSFailSent bool
MSSDeliverySent bool
ReadReceiptSent bool
type MessageStatus struct {
Type gmproto.MessageStatusType `json:"type,omitempty"`
MediaStatus string `json:"media_status,omitempty"`
MediaParts map[string]MediaPart `json:"media_parts,omitempty"`
MSSSent bool `json:"mss_sent,omitempty"`
MSSFailSent bool `json:"mss_fail_sent,omitempty"`
MSSDeliverySent bool `json:"mss_delivery_sent,omitempty"`
ReadReceiptSent bool `json:"read_receipt_sent,omitempty"`
}
func (ms *MessageStatus) HasPendingMediaParts() bool {
for _, part := range ms.MediaParts {
if part.PendingMedia {
return true
}
}
return false
}
type Message struct {
@ -145,7 +162,7 @@ func (mq *MessageQuery) MassInsert(ctx context.Context, messages []*Message) err
}
func (msg *Message) UpdateStatus(ctx context.Context) error {
_, err := msg.db.Conn(ctx).ExecContext(ctx, "UPDATE message SET status=$1 WHERE conv_id=$2 AND conv_receiver=$3 AND id=$4", dbutil.JSON{Data: &msg.Status}, msg.Chat.ID, msg.Chat.Receiver, msg.ID)
_, err := msg.db.Conn(ctx).ExecContext(ctx, "UPDATE message SET status=$1, timestamp=$2 WHERE conv_id=$3 AND conv_receiver=$4 AND id=$5", dbutil.JSON{Data: &msg.Status}, msg.Timestamp.UnixMicro(), msg.Chat.ID, msg.Chat.Receiver, msg.ID)
return err
}

218
portal.go
View file

@ -239,6 +239,8 @@ type Portal struct {
forwardBackfillLock sync.Mutex
lastMessageTS time.Time
lastUserReadID string
hasSyncedThisRun bool
pendingRecentBackfill atomic.Pointer[pendingBackfill]
@ -323,10 +325,11 @@ func (portal *Portal) isOutgoingMessage(msg *gmproto.Message) *database.Message
ID: msg.MessageID,
Timestamp: time.UnixMicro(msg.GetTimestamp()),
SenderID: msg.ParticipantID,
}, out.ID, true)
}, out.ID, nil, true)
}
return nil
}
func hasInProgressMedia(msg *gmproto.Message) bool {
for _, part := range msg.MessageInfo {
media, ok := part.GetData().(*gmproto.MessageInfo_MediaContent)
@ -346,6 +349,28 @@ func isSuccessfullySentStatus(status gmproto.MessageStatusType) bool {
}
}
func downloadPendingStatusMessage(status gmproto.MessageStatusType) string {
switch status {
case gmproto.MessageStatusType_INCOMING_YET_TO_MANUAL_DOWNLOAD:
return "Attachment message (auto-download is disabled)"
case gmproto.MessageStatusType_INCOMING_MANUAL_DOWNLOADING,
gmproto.MessageStatusType_INCOMING_AUTO_DOWNLOADING,
gmproto.MessageStatusType_INCOMING_RETRYING_MANUAL_DOWNLOAD,
gmproto.MessageStatusType_INCOMING_RETRYING_AUTO_DOWNLOAD:
return "Downloading message..."
case gmproto.MessageStatusType_INCOMING_DOWNLOAD_FAILED:
return "Message download failed"
case gmproto.MessageStatusType_INCOMING_DOWNLOAD_FAILED_TOO_LARGE:
return "Message download failed (too large)"
case gmproto.MessageStatusType_INCOMING_DOWNLOAD_FAILED_SIM_HAS_NO_DATA:
return "Message download failed (no mobile data connection)"
case gmproto.MessageStatusType_INCOMING_DOWNLOAD_CANCELED:
return "Message download canceled"
default:
return ""
}
}
func isFailSendStatus(status gmproto.MessageStatusType) bool {
switch status {
case gmproto.MessageStatusType_OUTGOING_FAILED_GENERIC,
@ -367,9 +392,62 @@ func (portal *Portal) handleExistingMessageUpdate(ctx context.Context, source *U
log := zerolog.Ctx(ctx)
portal.syncReactions(ctx, source, dbMsg, evt.Reactions)
newStatus := evt.GetMessageStatus().GetStatus()
if dbMsg.Status.Type != newStatus {
if dbMsg.Status.Type == newStatus && !(dbMsg.Status.HasPendingMediaParts() && !hasInProgressMedia(evt)) {
return
}
log.Debug().Str("old_status", dbMsg.Status.Type.String()).Msg("Message status changed")
switch {
case newStatus == gmproto.MessageStatusType_MESSAGE_DELETED:
for partID, part := range dbMsg.Status.MediaParts {
if part.EventID != "" {
if _, err := portal.MainIntent().RedactEvent(portal.MXID, part.EventID); err != nil {
log.Err(err).Str("part_iD", partID).Msg("Failed to redact part of deleted message")
}
}
}
if _, err := portal.MainIntent().RedactEvent(portal.MXID, dbMsg.MXID); err != nil {
log.Err(err).Msg("Failed to redact deleted message")
} else if err = dbMsg.Delete(ctx); err != nil {
log.Err(err).Msg("Failed to delete message from database")
} else {
log.Debug().Msg("Handled message deletion")
}
return
case dbMsg.Status.MediaStatus != downloadPendingStatusMessage(newStatus),
dbMsg.Status.HasPendingMediaParts() && !hasInProgressMedia(evt):
converted := portal.convertGoogleMessage(ctx, source, evt, false)
dbMsg.Status.MediaStatus = converted.MediaStatus
if dbMsg.Status.MediaParts == nil {
dbMsg.Status.MediaParts = make(map[string]database.MediaPart)
}
eventIDs := make([]id.EventID, 0, len(converted.Parts))
for i, part := range converted.Parts {
isEdit := true
ts := time.Now().UnixMilli()
if i == 0 {
part.Content.SetEdit(dbMsg.MXID)
} else if existingPart, ok := dbMsg.Status.MediaParts[part.ID]; ok {
part.Content.SetEdit(existingPart.EventID)
} else {
ts = converted.Timestamp.UnixMilli()
isEdit = false
}
resp, err := portal.sendMessage(converted.Intent, event.EventMessage, part.Content, part.Extra, ts)
if err != nil {
log.Err(err).Msg("Failed to send message")
} else {
eventIDs = append(eventIDs, resp.EventID)
}
if i == 0 {
dbMsg.Status.MediaParts[""] = database.MediaPart{PendingMedia: part.PendingMedia}
} else if !isEdit {
dbMsg.Status.MediaParts[part.ID] = database.MediaPart{EventID: resp.EventID, PendingMedia: part.PendingMedia}
}
}
if len(eventIDs) > 0 {
portal.sendDeliveryReceipt(eventIDs[len(eventIDs)-1])
log.Debug().Interface("event_ids", eventIDs).Msg("Handled update to message")
}
case !dbMsg.Status.ReadReceiptSent && portal.IsPrivateChat() && newStatus == gmproto.MessageStatusType_OUTGOING_DISPLAYED:
dbMsg.Status.ReadReceiptSent = true
if !dbMsg.Status.MSSDeliverySent {
@ -397,14 +475,15 @@ func (portal *Portal) handleExistingMessageUpdate(ctx context.Context, source *U
go portal.sendStatusEvent(dbMsg.MXID, "", OutgoingStatusError(newStatus), nil)
// TODO error notice
default:
log.Debug().Msg("Ignored message update")
// TODO do something?
}
dbMsg.Status.Type = newStatus
dbMsg.Timestamp = time.UnixMicro(evt.GetTimestamp())
err := dbMsg.UpdateStatus(ctx)
if err != nil {
log.Warn().Err(err).Msg("Failed to save updated message status to database")
}
}
}
func (portal *Portal) handleMessage(source *User, evt *gmproto.Message) {
@ -423,18 +502,10 @@ func (portal *Portal) handleMessage(source *User, evt *gmproto.Message) {
Str("action", "handle google message").
Logger()
ctx := log.WithContext(context.TODO())
switch evt.GetMessageStatus().GetStatus() {
case gmproto.MessageStatusType_INCOMING_AUTO_DOWNLOADING, gmproto.MessageStatusType_INCOMING_RETRYING_AUTO_DOWNLOAD:
log.Debug().Msg("Not handling incoming message that is auto downloading")
return
case gmproto.MessageStatusType_MESSAGE_DELETED:
portal.handleGoogleDeletion(ctx, evt.MessageID)
return
}
if hasInProgressMedia(evt) {
log.Debug().Msg("Not handling incoming message that doesn't have full media yet")
return
}
//if hasInProgressMedia(evt) {
// log.Debug().Msg("Not handling incoming message that doesn't have full media yet")
// return
//}
if existingMsg := portal.isOutgoingMessage(evt); existingMsg != nil {
log.Debug().Str("event_id", existingMsg.MXID.String()).Msg("Got echo for outgoing message")
portal.handleExistingMessageUpdate(ctx, source, existingMsg, evt)
@ -444,48 +515,45 @@ func (portal *Portal) handleMessage(source *User, evt *gmproto.Message) {
if err != nil {
log.Err(err).Msg("Failed to check if message is duplicate")
} else if existingMsg != nil {
log.Debug().Msg("Not handling duplicate message")
portal.handleExistingMessageUpdate(ctx, source, existingMsg, evt)
return
}
if evt.GetMessageStatus().GetStatus() == gmproto.MessageStatusType_MESSAGE_DELETED {
log.Debug().Msg("Not handling unknown deleted message")
return
}
converted := portal.convertGoogleMessage(ctx, source, evt, false)
if converted == nil {
return
} else if len(converted.Parts) == 0 {
log.Debug().Msg("Didn't get any converted parts from message")
return
}
eventIDs := make([]id.EventID, 0, len(converted.Parts))
mediaParts := make(map[string]database.MediaPart, len(converted.Parts)-1)
for _, part := range converted.Parts {
resp, err := portal.sendMessage(converted.Intent, event.EventMessage, part.Content, part.Extra, converted.Timestamp.UnixMilli())
if err != nil {
log.Err(err).Msg("Failed to send message")
} else {
eventIDs = append(eventIDs, resp.EventID)
if len(eventIDs) > 1 {
mediaParts[part.ID] = database.MediaPart{
EventID: resp.EventID,
PendingMedia: part.PendingMedia,
}
} else if part.PendingMedia {
mediaParts[""] = database.MediaPart{PendingMedia: true}
}
}
portal.markHandled(converted, eventIDs[0], true)
}
portal.markHandled(converted, eventIDs[0], mediaParts, true)
portal.sendDeliveryReceipt(eventIDs[len(eventIDs)-1])
log.Debug().Interface("event_ids", eventIDs).Msg("Handled message")
}
func (portal *Portal) handleGoogleDeletion(ctx context.Context, messageID string) {
log := zerolog.Ctx(ctx)
msg, err := portal.bridge.DB.Message.GetByID(ctx, portal.Key, messageID)
if err != nil {
log.Err(err).Msg("Failed to get deleted message from database")
} else if msg == nil {
log.Debug().Msg("Didn't find deleted message in database")
} else {
if _, err = portal.MainIntent().RedactEvent(portal.MXID, msg.MXID); err != nil {
log.Err(err).Msg("Faield to redact deleted message")
}
if err = msg.Delete(ctx); err != nil {
log.Err(err).Msg("Failed to delete message from database")
}
log.Debug().Msg("Handled message deletion")
}
}
func (portal *Portal) syncReactions(ctx context.Context, source *User, message *database.Message, reactions []*gmproto.ReactionEntry) {
log := zerolog.Ctx(ctx)
existing, err := portal.bridge.DB.Reaction.GetAllByMessage(ctx, portal.Key, message.ID)
@ -554,6 +622,8 @@ func (portal *Portal) syncReactions(ctx context.Context, source *User, message *
}
type ConvertedMessagePart struct {
ID string
PendingMedia bool
Content *event.MessageEventContent
Extra map[string]any
}
@ -566,6 +636,9 @@ type ConvertedMessage struct {
Timestamp time.Time
ReplyTo string
Parts []ConvertedMessagePart
Status gmproto.MessageStatusType
MediaStatus string
}
func (portal *Portal) getIntent(ctx context.Context, source *User, participant string) *appservice.IntentAPI {
@ -586,15 +659,28 @@ func (portal *Portal) getIntent(ctx context.Context, source *User, participant s
}
}
func addSubject(content *event.MessageEventContent, subject string) {
content.Format = event.FormatHTML
content.FormattedBody = fmt.Sprintf("<strong>%s</strong><br>%s", event.TextToHTML(subject), event.TextToHTML(content.Body))
content.Body = fmt.Sprintf("**%s**\n%s", subject, content.Body)
}
func addDownloadStatus(content *event.MessageEventContent, status string) {
content.Body = fmt.Sprintf("%s\n\n%s", content.Body, status)
if content.Format == event.FormatHTML {
content.FormattedBody = fmt.Sprintf("<p>%s</p><p>%s</p>", content.FormattedBody, event.TextToHTML(status))
}
}
func (portal *Portal) convertGoogleMessage(ctx context.Context, source *User, evt *gmproto.Message, backfill bool) *ConvertedMessage {
log := zerolog.Ctx(ctx)
var cm ConvertedMessage
cm.Status = evt.GetMessageStatus().GetStatus()
cm.SenderID = evt.ParticipantID
cm.ID = evt.MessageID
cm.Timestamp = time.UnixMicro(evt.Timestamp)
msgStatus := evt.GetMessageStatus().GetStatus()
if msgStatus >= 200 && msgStatus < 300 {
if cm.Status >= 200 && cm.Status < 300 {
cm.Intent = portal.bridge.Bot
if !portal.Encrypted && portal.IsPrivateChat() {
cm.Intent = portal.MainIntent()
@ -624,8 +710,11 @@ func (portal *Portal) convertGoogleMessage(ctx context.Context, source *User, ev
}
subject := evt.GetSubject()
downloadStatus := downloadPendingStatusMessage(evt.GetMessageStatus().GetStatus())
cm.MediaStatus = downloadStatus
for _, part := range evt.MessageInfo {
var content event.MessageEventContent
pendingMedia := false
switch data := part.GetData().(type) {
case *gmproto.MessageInfo_MessageContent:
content = event.MessageEventContent{
@ -633,14 +722,22 @@ func (portal *Portal) convertGoogleMessage(ctx context.Context, source *User, ev
Body: data.MessageContent.GetContent(),
}
if subject != "" {
content.Format = event.FormatHTML
content.FormattedBody = fmt.Sprintf("<strong>%s</strong><br>%s", event.TextToHTML(subject), event.TextToHTML(content.Body))
content.Body = fmt.Sprintf("**%s**\n%s", subject, content.Body)
addSubject(&content, subject)
subject = ""
}
if downloadStatus != "" {
addDownloadStatus(&content, downloadStatus)
downloadStatus = ""
}
case *gmproto.MessageInfo_MediaContent:
contentPtr, err := portal.convertGoogleMedia(source, cm.Intent, data.MediaContent)
if err != nil {
if data.MediaContent.MediaID == "" {
pendingMedia = true
content = event.MessageEventContent{
MsgType: event.MsgNotice,
Body: fmt.Sprintf("Waiting for attachment %s", data.MediaContent.GetMediaName()),
}
} else if contentPtr, err := portal.convertGoogleMedia(source, cm.Intent, data.MediaContent); err != nil {
pendingMedia = true
log.Err(err).Msg("Failed to copy attachment")
content = event.MessageEventContent{
MsgType: event.MsgNotice,
@ -649,14 +746,29 @@ func (portal *Portal) convertGoogleMessage(ctx context.Context, source *User, ev
} else {
content = *contentPtr
}
default:
continue
}
if replyTo != "" {
content.RelatesTo = &event.RelatesTo{InReplyTo: &event.InReplyTo{EventID: replyTo}}
}
cm.Parts = append(cm.Parts, ConvertedMessagePart{
ID: part.GetActionMessageID(),
PendingMedia: pendingMedia,
Content: &content,
})
}
if downloadStatus != "" {
content := event.MessageEventContent{
MsgType: event.MsgText,
Body: downloadStatus,
}
if subject != "" {
addSubject(&content, subject)
subject = ""
}
cm.Parts = append(cm.Parts, ConvertedMessagePart{Content: &content})
}
if subject != "" {
cm.Parts = append(cm.Parts, ConvertedMessagePart{
Content: &event.MessageEventContent{
@ -690,14 +802,23 @@ func (msg *ConvertedMessage) MergeCaption() {
}
switch filePart.Content.MsgType {
case event.MsgImage, event.MsgVideo, event.MsgAudio, event.MsgFile:
default:
return
}
filePart.Content.FileName = filePart.Content.Body
filePart.Content.Body = textPart.Content.Body
filePart.Content.Format = textPart.Content.Format
filePart.Content.FormattedBody = textPart.Content.FormattedBody
case event.MsgNotice: // If it's a notice, the media failed or is pending
if textPart.Content.Format == event.FormatHTML {
filePart.Content.Format = event.FormatHTML
if !strings.HasPrefix(textPart.Content.FormattedBody, "<p>") {
textPart.Content.FormattedBody = fmt.Sprintf("<p>%s</p>", textPart.Content.FormattedBody)
}
filePart.Content.FormattedBody = fmt.Sprintf("<p>%s</p>%s", event.TextToHTML(filePart.Content.Body), textPart.Content.FormattedBody)
}
filePart.Content.Body = fmt.Sprintf("%s\n\n%s", filePart.Content.Body, textPart.Content.Body)
filePart.Content.MsgType = event.MsgText
default:
return
}
msg.Parts = []ConvertedMessagePart{filePart}
}
@ -744,13 +865,16 @@ func (portal *Portal) isRecentlyHandled(id string) bool {
return false
}
func (portal *Portal) markHandled(cm *ConvertedMessage, eventID id.EventID, recent bool) *database.Message {
func (portal *Portal) markHandled(cm *ConvertedMessage, eventID id.EventID, mediaParts map[string]database.MediaPart, recent bool) *database.Message {
msg := portal.bridge.DB.Message.New()
msg.Chat = portal.Key
msg.ID = cm.ID
msg.MXID = eventID
msg.Timestamp = cm.Timestamp
msg.Sender = cm.SenderID
msg.Status.Type = cm.Status
msg.Status.MediaStatus = cm.MediaStatus
msg.Status.MediaParts = mediaParts
err := msg.Insert(context.TODO())
if err != nil {
portal.zlog.Err(err).Str("message_id", cm.ID).Msg("Failed to insert message to database")

View file

@ -914,7 +914,7 @@ type CustomReadMarkers struct {
}
func (user *User) markSelfReadFull(portal *Portal, lastMessageID string) {
if user.DoublePuppetIntent == nil {
if user.DoublePuppetIntent == nil || portal.lastUserReadID == lastMessageID {
return
}
ctx := context.TODO()
@ -925,7 +925,7 @@ func (user *User) markSelfReadFull(portal *Portal, lastMessageID string) {
if err != nil {
user.zlog.Warn().Err(err).Msg("Failed to get last message in chat to mark it as read")
return
} else if lastMessage == nil {
} else if lastMessage == nil || portal.lastUserReadID == lastMessage.ID {
return
}
log := user.zlog.With().
@ -946,6 +946,7 @@ func (user *User) markSelfReadFull(portal *Portal, lastMessageID string) {
log.Warn().Err(err).Msg("Failed to mark last message in chat as read")
} else {
log.Debug().Msg("Marked last message in chat as read")
portal.lastUserReadID = lastMessage.ID
}
}