diff --git a/backfill.go b/backfill.go index 3152437..b4b9127 100644 --- a/backfill.go +++ b/backfill.go @@ -24,6 +24,7 @@ import ( "time" "github.com/rs/zerolog" + "go.mau.fi/util/random" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -197,6 +198,11 @@ func (portal *Portal) backfillSendBatch(ctx context.Context, converted []*Conver dbm.Status.PartCount = msg.PartCount dbm.Status.MediaStatus = msg.MediaStatus dbm.Status.MediaParts = make(map[string]database.MediaPart, len(msg.Parts)) + if msg.DontBridge { + dbm.MXID = id.EventID(fmt.Sprintf("$fake::%s", random.String(37))) + dbMessages = append(dbMessages, dbm) + continue + } for i, part := range msg.Parts { content := event.Content{ @@ -252,42 +258,16 @@ func (portal *Portal) backfillSendBatch(ctx context.Context, converted []*Conver } func (portal *Portal) backfillSendLegacy(ctx context.Context, converted []*ConvertedMessage) id.EventID { - log := zerolog.Ctx(ctx) var lastEventID id.EventID eventIDs := make(map[string]id.EventID) for _, msg := range converted { if len(msg.Parts) == 0 { continue } - var msgFirstEventID id.EventID - mediaParts := make(map[string]database.MediaPart, len(msg.Parts)-1) - for i, part := range msg.Parts { - if msg.ReplyTo != "" && part.Content.RelatesTo == nil { - replyToEvent, ok := eventIDs[msg.ReplyTo] - if ok { - part.Content.RelatesTo = &event.RelatesTo{ - InReplyTo: &event.InReplyTo{EventID: replyToEvent}, - } - } - } - resp, err := portal.sendMessage(msg.Intent, event.EventMessage, part.Content, part.Extra, msg.Timestamp.UnixMilli()) - if err != nil { - log.Err(err).Str("message_id", msg.ID).Int("part", i).Msg("Failed to send message") - } else { - if msgFirstEventID == "" { - msgFirstEventID = resp.EventID - eventIDs[msg.ID] = resp.EventID - } else { - mediaParts[part.ID] = database.MediaPart{ - EventID: resp.EventID, - PendingMedia: part.PendingMedia, - } - } - lastEventID = resp.EventID - } - } - if msgFirstEventID != "" { - portal.markHandled(msg, msgFirstEventID, mediaParts, false) + msgEventIDs := portal.sendMessageParts(ctx, msg, eventIDs) + if len(msgEventIDs) > 0 { + eventIDs[msg.ID] = msgEventIDs[0] + lastEventID = msgEventIDs[len(msgEventIDs)-1] } } return lastEventID diff --git a/database/message.go b/database/message.go index da5e8f2..808b65b 100644 --- a/database/message.go +++ b/database/message.go @@ -54,6 +54,11 @@ const ( WHERE conv_id=$1 AND conv_receiver=$2 ORDER BY timestamp DESC LIMIT 1 ` + getLastMessageInChatWithMXIDQuery = ` + SELECT conv_id, conv_receiver, id, mxid, sender, timestamp, status FROM message + WHERE conv_id=$1 AND conv_receiver=$2 AND mxid NOT LIKE '$fake::%' + ORDER BY timestamp DESC LIMIT 1 + ` getMessageByMXIDQuery = ` SELECT conv_id, conv_receiver, id, mxid, sender, timestamp, status FROM message WHERE mxid=$1 @@ -72,6 +77,10 @@ func (mq *MessageQuery) GetLastInChat(ctx context.Context, chat Key) (*Message, return get[*Message](mq, ctx, getLastMessageInChatQuery, chat.ID, chat.Receiver) } +func (mq *MessageQuery) GetLastInChatWithMXID(ctx context.Context, chat Key) (*Message, error) { + return get[*Message](mq, ctx, getLastMessageInChatWithMXIDQuery, chat.ID, chat.Receiver) +} + type MediaPart struct { EventID id.EventID `json:"mxid,omitempty"` PendingMedia bool `json:"pending_media,omitempty"` @@ -171,3 +180,7 @@ func (msg *Message) Delete(ctx context.Context) error { _, err := msg.db.Conn(ctx).ExecContext(ctx, "DELETE FROM message WHERE conv_id=$1 AND conv_receiver=$2 AND id=$3", msg.Chat.ID, msg.Chat.Receiver, msg.ID) return err } + +func (msg *Message) IsFakeMXID() bool { + return strings.HasPrefix(msg.MXID.String(), "$fake:") +} diff --git a/portal.go b/portal.go index b490541..a54ebc5 100644 --- a/portal.go +++ b/portal.go @@ -31,6 +31,7 @@ import ( "github.com/gabriel-vasile/mimetype" "github.com/rs/zerolog" "go.mau.fi/util/exerrors" + "go.mau.fi/util/random" "go.mau.fi/util/variationselector" "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" @@ -523,19 +524,38 @@ func (portal *Portal) handleMessage(source *User, evt *gmproto.Message) { } converted := portal.convertGoogleMessage(ctx, source, evt, false) - if converted == nil { - return - } else if len(converted.Parts) == 0 { - log.Debug().Msg("Didn't get any converted parts from message") - return + eventIDs := portal.sendMessageParts(ctx, converted, nil) + if len(eventIDs) > 0 { + portal.sendDeliveryReceipt(eventIDs[len(eventIDs)-1]) + log.Debug().Interface("event_ids", eventIDs).Msg("Handled message") } +} +func (portal *Portal) sendMessageParts(ctx context.Context, converted *ConvertedMessage, replyToMap map[string]id.EventID) []id.EventID { + if converted == nil { + return nil + } else if len(converted.Parts) == 0 { + zerolog.Ctx(ctx).Debug().Msg("Didn't get any converted parts from message") + return nil + } else if converted.DontBridge { + zerolog.Ctx(ctx).Debug().Msg("Ignored incoming tombstone message") + portal.markHandled(converted, id.EventID(fmt.Sprintf("$fake::%s", random.String(37))), nil, true) + return nil + } eventIDs := make([]id.EventID, 0, len(converted.Parts)) mediaParts := make(map[string]database.MediaPart, len(converted.Parts)-1) - for _, part := range converted.Parts { + for i, part := range converted.Parts { + if replyToMap != nil && converted.ReplyTo != "" && part.Content.RelatesTo == nil { + replyToEvent, ok := replyToMap[converted.ReplyTo] + if ok { + part.Content.RelatesTo = &event.RelatesTo{ + InReplyTo: &event.InReplyTo{EventID: replyToEvent}, + } + } + } resp, err := portal.sendMessage(converted.Intent, event.EventMessage, part.Content, part.Extra, converted.Timestamp.UnixMilli()) if err != nil { - log.Err(err).Msg("Failed to send message") + zerolog.Ctx(ctx).Err(err).Int("part_index", i).Str("part_id", part.ID).Msg("Failed to send message") } else { eventIDs = append(eventIDs, resp.EventID) if len(eventIDs) > 1 { @@ -549,8 +569,7 @@ func (portal *Portal) handleMessage(source *User, evt *gmproto.Message) { } } portal.markHandled(converted, eventIDs[0], mediaParts, true) - portal.sendDeliveryReceipt(eventIDs[len(eventIDs)-1]) - log.Debug().Interface("event_ids", eventIDs).Msg("Handled message") + return eventIDs } func (portal *Portal) syncReactions(ctx context.Context, source *User, message *database.Message, reactions []*gmproto.ReactionEntry) { @@ -637,6 +656,8 @@ type ConvertedMessage struct { Parts []ConvertedMessagePart PartCount int + DontBridge bool + Status gmproto.MessageStatusType MediaStatus string } @@ -672,6 +693,17 @@ func addDownloadStatus(content *event.MessageEventContent, status string) { } } +func shouldIgnoreStatus(status gmproto.MessageStatusType) bool { + switch status { + case gmproto.MessageStatusType_TOMBSTONE_PROTOCOL_SWITCH_TO_TEXT, + gmproto.MessageStatusType_TOMBSTONE_PROTOCOL_SWITCH_TO_RCS, + gmproto.MessageStatusType_TOMBSTONE_PROTOCOL_SWITCH_TO_ENCRYPTED_RCS: + return true + default: + return false + } +} + func (portal *Portal) convertGoogleMessage(ctx context.Context, source *User, evt *gmproto.Message, backfill bool) *ConvertedMessage { log := zerolog.Ctx(ctx) @@ -681,6 +713,7 @@ func (portal *Portal) convertGoogleMessage(ctx context.Context, source *User, ev cm.ID = evt.MessageID cm.PartCount = len(evt.GetMessageInfo()) cm.Timestamp = time.UnixMicro(evt.Timestamp) + cm.DontBridge = shouldIgnoreStatus(cm.Status) if cm.Status >= 200 && cm.Status < 300 { cm.Intent = portal.bridge.Bot if !portal.Encrypted && portal.IsPrivateChat() { @@ -705,6 +738,8 @@ func (portal *Portal) convertGoogleMessage(ctx context.Context, source *User, ev } else { log.Warn().Str("reply_to_id", cm.ReplyTo).Msg("Reply target message not found") } + } else if msg.IsFakeMXID() { + log.Debug().Str("reply_to_id", msg.ID).Msg("Ignoring reply to non-bridged message") } else { replyTo = msg.MXID } @@ -1570,7 +1605,7 @@ func (portal *Portal) HandleMatrixReadReceipt(brUser bridge.User, eventID id.Eve log.Err(err).Msg("Failed to get target message to handle read receipt") return } else if targetMessage == nil { - lastMessage, err := portal.bridge.DB.Message.GetLastInChat(ctx, portal.Key) + lastMessage, err := portal.bridge.DB.Message.GetLastInChatWithMXID(ctx, portal.Key) if err != nil { log.Err(err).Msg("Failed to get last message to handle read receipt") return diff --git a/user.go b/user.go index 5ce8cff..7708407 100644 --- a/user.go +++ b/user.go @@ -919,8 +919,8 @@ func (user *User) markSelfReadFull(portal *Portal, lastMessageID string) { } ctx := context.TODO() lastMessage, err := user.bridge.DB.Message.GetByID(ctx, portal.Key, lastMessageID) - if err == nil && lastMessage == nil { - lastMessage, err = user.bridge.DB.Message.GetLastInChat(ctx, portal.Key) + if err == nil && lastMessage == nil || lastMessage.IsFakeMXID() { + lastMessage, err = user.bridge.DB.Message.GetLastInChatWithMXID(ctx, portal.Key) } if err != nil { user.zlog.Warn().Err(err).Msg("Failed to get last message in chat to mark it as read")