Stop bridging protocol switch messages

This commit is contained in:
Tulir Asokan 2023-08-10 11:30:17 +03:00
parent f72cb7d7da
commit 211f000b28
4 changed files with 70 additions and 42 deletions

View file

@ -24,6 +24,7 @@ import (
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"go.mau.fi/util/random"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
@ -197,6 +198,11 @@ func (portal *Portal) backfillSendBatch(ctx context.Context, converted []*Conver
dbm.Status.PartCount = msg.PartCount dbm.Status.PartCount = msg.PartCount
dbm.Status.MediaStatus = msg.MediaStatus dbm.Status.MediaStatus = msg.MediaStatus
dbm.Status.MediaParts = make(map[string]database.MediaPart, len(msg.Parts)) dbm.Status.MediaParts = make(map[string]database.MediaPart, len(msg.Parts))
if msg.DontBridge {
dbm.MXID = id.EventID(fmt.Sprintf("$fake::%s", random.String(37)))
dbMessages = append(dbMessages, dbm)
continue
}
for i, part := range msg.Parts { for i, part := range msg.Parts {
content := event.Content{ content := event.Content{
@ -252,42 +258,16 @@ func (portal *Portal) backfillSendBatch(ctx context.Context, converted []*Conver
} }
func (portal *Portal) backfillSendLegacy(ctx context.Context, converted []*ConvertedMessage) id.EventID { func (portal *Portal) backfillSendLegacy(ctx context.Context, converted []*ConvertedMessage) id.EventID {
log := zerolog.Ctx(ctx)
var lastEventID id.EventID var lastEventID id.EventID
eventIDs := make(map[string]id.EventID) eventIDs := make(map[string]id.EventID)
for _, msg := range converted { for _, msg := range converted {
if len(msg.Parts) == 0 { if len(msg.Parts) == 0 {
continue continue
} }
var msgFirstEventID id.EventID msgEventIDs := portal.sendMessageParts(ctx, msg, eventIDs)
mediaParts := make(map[string]database.MediaPart, len(msg.Parts)-1) if len(msgEventIDs) > 0 {
for i, part := range msg.Parts { eventIDs[msg.ID] = msgEventIDs[0]
if msg.ReplyTo != "" && part.Content.RelatesTo == nil { lastEventID = msgEventIDs[len(msgEventIDs)-1]
replyToEvent, ok := eventIDs[msg.ReplyTo]
if ok {
part.Content.RelatesTo = &event.RelatesTo{
InReplyTo: &event.InReplyTo{EventID: replyToEvent},
}
}
}
resp, err := portal.sendMessage(msg.Intent, event.EventMessage, part.Content, part.Extra, msg.Timestamp.UnixMilli())
if err != nil {
log.Err(err).Str("message_id", msg.ID).Int("part", i).Msg("Failed to send message")
} else {
if msgFirstEventID == "" {
msgFirstEventID = resp.EventID
eventIDs[msg.ID] = resp.EventID
} else {
mediaParts[part.ID] = database.MediaPart{
EventID: resp.EventID,
PendingMedia: part.PendingMedia,
}
}
lastEventID = resp.EventID
}
}
if msgFirstEventID != "" {
portal.markHandled(msg, msgFirstEventID, mediaParts, false)
} }
} }
return lastEventID return lastEventID

View file

@ -54,6 +54,11 @@ const (
WHERE conv_id=$1 AND conv_receiver=$2 WHERE conv_id=$1 AND conv_receiver=$2
ORDER BY timestamp DESC LIMIT 1 ORDER BY timestamp DESC LIMIT 1
` `
getLastMessageInChatWithMXIDQuery = `
SELECT conv_id, conv_receiver, id, mxid, sender, timestamp, status FROM message
WHERE conv_id=$1 AND conv_receiver=$2 AND mxid NOT LIKE '$fake::%'
ORDER BY timestamp DESC LIMIT 1
`
getMessageByMXIDQuery = ` getMessageByMXIDQuery = `
SELECT conv_id, conv_receiver, id, mxid, sender, timestamp, status FROM message SELECT conv_id, conv_receiver, id, mxid, sender, timestamp, status FROM message
WHERE mxid=$1 WHERE mxid=$1
@ -72,6 +77,10 @@ func (mq *MessageQuery) GetLastInChat(ctx context.Context, chat Key) (*Message,
return get[*Message](mq, ctx, getLastMessageInChatQuery, chat.ID, chat.Receiver) return get[*Message](mq, ctx, getLastMessageInChatQuery, chat.ID, chat.Receiver)
} }
func (mq *MessageQuery) GetLastInChatWithMXID(ctx context.Context, chat Key) (*Message, error) {
return get[*Message](mq, ctx, getLastMessageInChatWithMXIDQuery, chat.ID, chat.Receiver)
}
type MediaPart struct { type MediaPart struct {
EventID id.EventID `json:"mxid,omitempty"` EventID id.EventID `json:"mxid,omitempty"`
PendingMedia bool `json:"pending_media,omitempty"` PendingMedia bool `json:"pending_media,omitempty"`
@ -171,3 +180,7 @@ func (msg *Message) Delete(ctx context.Context) error {
_, err := msg.db.Conn(ctx).ExecContext(ctx, "DELETE FROM message WHERE conv_id=$1 AND conv_receiver=$2 AND id=$3", msg.Chat.ID, msg.Chat.Receiver, msg.ID) _, err := msg.db.Conn(ctx).ExecContext(ctx, "DELETE FROM message WHERE conv_id=$1 AND conv_receiver=$2 AND id=$3", msg.Chat.ID, msg.Chat.Receiver, msg.ID)
return err return err
} }
func (msg *Message) IsFakeMXID() bool {
return strings.HasPrefix(msg.MXID.String(), "$fake:")
}

View file

@ -31,6 +31,7 @@ import (
"github.com/gabriel-vasile/mimetype" "github.com/gabriel-vasile/mimetype"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"go.mau.fi/util/exerrors" "go.mau.fi/util/exerrors"
"go.mau.fi/util/random"
"go.mau.fi/util/variationselector" "go.mau.fi/util/variationselector"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/appservice"
@ -523,19 +524,38 @@ func (portal *Portal) handleMessage(source *User, evt *gmproto.Message) {
} }
converted := portal.convertGoogleMessage(ctx, source, evt, false) converted := portal.convertGoogleMessage(ctx, source, evt, false)
if converted == nil { eventIDs := portal.sendMessageParts(ctx, converted, nil)
return if len(eventIDs) > 0 {
} else if len(converted.Parts) == 0 { portal.sendDeliveryReceipt(eventIDs[len(eventIDs)-1])
log.Debug().Msg("Didn't get any converted parts from message") log.Debug().Interface("event_ids", eventIDs).Msg("Handled message")
return
} }
}
func (portal *Portal) sendMessageParts(ctx context.Context, converted *ConvertedMessage, replyToMap map[string]id.EventID) []id.EventID {
if converted == nil {
return nil
} else if len(converted.Parts) == 0 {
zerolog.Ctx(ctx).Debug().Msg("Didn't get any converted parts from message")
return nil
} else if converted.DontBridge {
zerolog.Ctx(ctx).Debug().Msg("Ignored incoming tombstone message")
portal.markHandled(converted, id.EventID(fmt.Sprintf("$fake::%s", random.String(37))), nil, true)
return nil
}
eventIDs := make([]id.EventID, 0, len(converted.Parts)) eventIDs := make([]id.EventID, 0, len(converted.Parts))
mediaParts := make(map[string]database.MediaPart, len(converted.Parts)-1) mediaParts := make(map[string]database.MediaPart, len(converted.Parts)-1)
for _, part := range converted.Parts { for i, part := range converted.Parts {
if replyToMap != nil && converted.ReplyTo != "" && part.Content.RelatesTo == nil {
replyToEvent, ok := replyToMap[converted.ReplyTo]
if ok {
part.Content.RelatesTo = &event.RelatesTo{
InReplyTo: &event.InReplyTo{EventID: replyToEvent},
}
}
}
resp, err := portal.sendMessage(converted.Intent, event.EventMessage, part.Content, part.Extra, converted.Timestamp.UnixMilli()) resp, err := portal.sendMessage(converted.Intent, event.EventMessage, part.Content, part.Extra, converted.Timestamp.UnixMilli())
if err != nil { if err != nil {
log.Err(err).Msg("Failed to send message") zerolog.Ctx(ctx).Err(err).Int("part_index", i).Str("part_id", part.ID).Msg("Failed to send message")
} else { } else {
eventIDs = append(eventIDs, resp.EventID) eventIDs = append(eventIDs, resp.EventID)
if len(eventIDs) > 1 { if len(eventIDs) > 1 {
@ -549,8 +569,7 @@ func (portal *Portal) handleMessage(source *User, evt *gmproto.Message) {
} }
} }
portal.markHandled(converted, eventIDs[0], mediaParts, true) portal.markHandled(converted, eventIDs[0], mediaParts, true)
portal.sendDeliveryReceipt(eventIDs[len(eventIDs)-1]) return eventIDs
log.Debug().Interface("event_ids", eventIDs).Msg("Handled message")
} }
func (portal *Portal) syncReactions(ctx context.Context, source *User, message *database.Message, reactions []*gmproto.ReactionEntry) { func (portal *Portal) syncReactions(ctx context.Context, source *User, message *database.Message, reactions []*gmproto.ReactionEntry) {
@ -637,6 +656,8 @@ type ConvertedMessage struct {
Parts []ConvertedMessagePart Parts []ConvertedMessagePart
PartCount int PartCount int
DontBridge bool
Status gmproto.MessageStatusType Status gmproto.MessageStatusType
MediaStatus string MediaStatus string
} }
@ -672,6 +693,17 @@ func addDownloadStatus(content *event.MessageEventContent, status string) {
} }
} }
func shouldIgnoreStatus(status gmproto.MessageStatusType) bool {
switch status {
case gmproto.MessageStatusType_TOMBSTONE_PROTOCOL_SWITCH_TO_TEXT,
gmproto.MessageStatusType_TOMBSTONE_PROTOCOL_SWITCH_TO_RCS,
gmproto.MessageStatusType_TOMBSTONE_PROTOCOL_SWITCH_TO_ENCRYPTED_RCS:
return true
default:
return false
}
}
func (portal *Portal) convertGoogleMessage(ctx context.Context, source *User, evt *gmproto.Message, backfill bool) *ConvertedMessage { func (portal *Portal) convertGoogleMessage(ctx context.Context, source *User, evt *gmproto.Message, backfill bool) *ConvertedMessage {
log := zerolog.Ctx(ctx) log := zerolog.Ctx(ctx)
@ -681,6 +713,7 @@ func (portal *Portal) convertGoogleMessage(ctx context.Context, source *User, ev
cm.ID = evt.MessageID cm.ID = evt.MessageID
cm.PartCount = len(evt.GetMessageInfo()) cm.PartCount = len(evt.GetMessageInfo())
cm.Timestamp = time.UnixMicro(evt.Timestamp) cm.Timestamp = time.UnixMicro(evt.Timestamp)
cm.DontBridge = shouldIgnoreStatus(cm.Status)
if cm.Status >= 200 && cm.Status < 300 { if cm.Status >= 200 && cm.Status < 300 {
cm.Intent = portal.bridge.Bot cm.Intent = portal.bridge.Bot
if !portal.Encrypted && portal.IsPrivateChat() { if !portal.Encrypted && portal.IsPrivateChat() {
@ -705,6 +738,8 @@ func (portal *Portal) convertGoogleMessage(ctx context.Context, source *User, ev
} else { } else {
log.Warn().Str("reply_to_id", cm.ReplyTo).Msg("Reply target message not found") log.Warn().Str("reply_to_id", cm.ReplyTo).Msg("Reply target message not found")
} }
} else if msg.IsFakeMXID() {
log.Debug().Str("reply_to_id", msg.ID).Msg("Ignoring reply to non-bridged message")
} else { } else {
replyTo = msg.MXID replyTo = msg.MXID
} }
@ -1570,7 +1605,7 @@ func (portal *Portal) HandleMatrixReadReceipt(brUser bridge.User, eventID id.Eve
log.Err(err).Msg("Failed to get target message to handle read receipt") log.Err(err).Msg("Failed to get target message to handle read receipt")
return return
} else if targetMessage == nil { } else if targetMessage == nil {
lastMessage, err := portal.bridge.DB.Message.GetLastInChat(ctx, portal.Key) lastMessage, err := portal.bridge.DB.Message.GetLastInChatWithMXID(ctx, portal.Key)
if err != nil { if err != nil {
log.Err(err).Msg("Failed to get last message to handle read receipt") log.Err(err).Msg("Failed to get last message to handle read receipt")
return return

View file

@ -919,8 +919,8 @@ func (user *User) markSelfReadFull(portal *Portal, lastMessageID string) {
} }
ctx := context.TODO() ctx := context.TODO()
lastMessage, err := user.bridge.DB.Message.GetByID(ctx, portal.Key, lastMessageID) lastMessage, err := user.bridge.DB.Message.GetByID(ctx, portal.Key, lastMessageID)
if err == nil && lastMessage == nil { if err == nil && lastMessage == nil || lastMessage.IsFakeMXID() {
lastMessage, err = user.bridge.DB.Message.GetLastInChat(ctx, portal.Key) lastMessage, err = user.bridge.DB.Message.GetLastInChatWithMXID(ctx, portal.Key)
} }
if err != nil { if err != nil {
user.zlog.Warn().Err(err).Msg("Failed to get last message in chat to mark it as read") user.zlog.Warn().Err(err).Msg("Failed to get last message in chat to mark it as read")