Stop bridging protocol switch messages
This commit is contained in:
parent
f72cb7d7da
commit
211f000b28
4 changed files with 70 additions and 42 deletions
40
backfill.go
40
backfill.go
|
@ -24,6 +24,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"go.mau.fi/util/random"
|
||||||
"maunium.net/go/mautrix"
|
"maunium.net/go/mautrix"
|
||||||
"maunium.net/go/mautrix/event"
|
"maunium.net/go/mautrix/event"
|
||||||
"maunium.net/go/mautrix/id"
|
"maunium.net/go/mautrix/id"
|
||||||
|
@ -197,6 +198,11 @@ func (portal *Portal) backfillSendBatch(ctx context.Context, converted []*Conver
|
||||||
dbm.Status.PartCount = msg.PartCount
|
dbm.Status.PartCount = msg.PartCount
|
||||||
dbm.Status.MediaStatus = msg.MediaStatus
|
dbm.Status.MediaStatus = msg.MediaStatus
|
||||||
dbm.Status.MediaParts = make(map[string]database.MediaPart, len(msg.Parts))
|
dbm.Status.MediaParts = make(map[string]database.MediaPart, len(msg.Parts))
|
||||||
|
if msg.DontBridge {
|
||||||
|
dbm.MXID = id.EventID(fmt.Sprintf("$fake::%s", random.String(37)))
|
||||||
|
dbMessages = append(dbMessages, dbm)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
for i, part := range msg.Parts {
|
for i, part := range msg.Parts {
|
||||||
content := event.Content{
|
content := event.Content{
|
||||||
|
@ -252,42 +258,16 @@ func (portal *Portal) backfillSendBatch(ctx context.Context, converted []*Conver
|
||||||
}
|
}
|
||||||
|
|
||||||
func (portal *Portal) backfillSendLegacy(ctx context.Context, converted []*ConvertedMessage) id.EventID {
|
func (portal *Portal) backfillSendLegacy(ctx context.Context, converted []*ConvertedMessage) id.EventID {
|
||||||
log := zerolog.Ctx(ctx)
|
|
||||||
var lastEventID id.EventID
|
var lastEventID id.EventID
|
||||||
eventIDs := make(map[string]id.EventID)
|
eventIDs := make(map[string]id.EventID)
|
||||||
for _, msg := range converted {
|
for _, msg := range converted {
|
||||||
if len(msg.Parts) == 0 {
|
if len(msg.Parts) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var msgFirstEventID id.EventID
|
msgEventIDs := portal.sendMessageParts(ctx, msg, eventIDs)
|
||||||
mediaParts := make(map[string]database.MediaPart, len(msg.Parts)-1)
|
if len(msgEventIDs) > 0 {
|
||||||
for i, part := range msg.Parts {
|
eventIDs[msg.ID] = msgEventIDs[0]
|
||||||
if msg.ReplyTo != "" && part.Content.RelatesTo == nil {
|
lastEventID = msgEventIDs[len(msgEventIDs)-1]
|
||||||
replyToEvent, ok := eventIDs[msg.ReplyTo]
|
|
||||||
if ok {
|
|
||||||
part.Content.RelatesTo = &event.RelatesTo{
|
|
||||||
InReplyTo: &event.InReplyTo{EventID: replyToEvent},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
resp, err := portal.sendMessage(msg.Intent, event.EventMessage, part.Content, part.Extra, msg.Timestamp.UnixMilli())
|
|
||||||
if err != nil {
|
|
||||||
log.Err(err).Str("message_id", msg.ID).Int("part", i).Msg("Failed to send message")
|
|
||||||
} else {
|
|
||||||
if msgFirstEventID == "" {
|
|
||||||
msgFirstEventID = resp.EventID
|
|
||||||
eventIDs[msg.ID] = resp.EventID
|
|
||||||
} else {
|
|
||||||
mediaParts[part.ID] = database.MediaPart{
|
|
||||||
EventID: resp.EventID,
|
|
||||||
PendingMedia: part.PendingMedia,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
lastEventID = resp.EventID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if msgFirstEventID != "" {
|
|
||||||
portal.markHandled(msg, msgFirstEventID, mediaParts, false)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return lastEventID
|
return lastEventID
|
||||||
|
|
|
@ -54,6 +54,11 @@ const (
|
||||||
WHERE conv_id=$1 AND conv_receiver=$2
|
WHERE conv_id=$1 AND conv_receiver=$2
|
||||||
ORDER BY timestamp DESC LIMIT 1
|
ORDER BY timestamp DESC LIMIT 1
|
||||||
`
|
`
|
||||||
|
getLastMessageInChatWithMXIDQuery = `
|
||||||
|
SELECT conv_id, conv_receiver, id, mxid, sender, timestamp, status FROM message
|
||||||
|
WHERE conv_id=$1 AND conv_receiver=$2 AND mxid NOT LIKE '$fake::%'
|
||||||
|
ORDER BY timestamp DESC LIMIT 1
|
||||||
|
`
|
||||||
getMessageByMXIDQuery = `
|
getMessageByMXIDQuery = `
|
||||||
SELECT conv_id, conv_receiver, id, mxid, sender, timestamp, status FROM message
|
SELECT conv_id, conv_receiver, id, mxid, sender, timestamp, status FROM message
|
||||||
WHERE mxid=$1
|
WHERE mxid=$1
|
||||||
|
@ -72,6 +77,10 @@ func (mq *MessageQuery) GetLastInChat(ctx context.Context, chat Key) (*Message,
|
||||||
return get[*Message](mq, ctx, getLastMessageInChatQuery, chat.ID, chat.Receiver)
|
return get[*Message](mq, ctx, getLastMessageInChatQuery, chat.ID, chat.Receiver)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (mq *MessageQuery) GetLastInChatWithMXID(ctx context.Context, chat Key) (*Message, error) {
|
||||||
|
return get[*Message](mq, ctx, getLastMessageInChatWithMXIDQuery, chat.ID, chat.Receiver)
|
||||||
|
}
|
||||||
|
|
||||||
type MediaPart struct {
|
type MediaPart struct {
|
||||||
EventID id.EventID `json:"mxid,omitempty"`
|
EventID id.EventID `json:"mxid,omitempty"`
|
||||||
PendingMedia bool `json:"pending_media,omitempty"`
|
PendingMedia bool `json:"pending_media,omitempty"`
|
||||||
|
@ -171,3 +180,7 @@ func (msg *Message) Delete(ctx context.Context) error {
|
||||||
_, err := msg.db.Conn(ctx).ExecContext(ctx, "DELETE FROM message WHERE conv_id=$1 AND conv_receiver=$2 AND id=$3", msg.Chat.ID, msg.Chat.Receiver, msg.ID)
|
_, err := msg.db.Conn(ctx).ExecContext(ctx, "DELETE FROM message WHERE conv_id=$1 AND conv_receiver=$2 AND id=$3", msg.Chat.ID, msg.Chat.Receiver, msg.ID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (msg *Message) IsFakeMXID() bool {
|
||||||
|
return strings.HasPrefix(msg.MXID.String(), "$fake:")
|
||||||
|
}
|
||||||
|
|
55
portal.go
55
portal.go
|
@ -31,6 +31,7 @@ import (
|
||||||
"github.com/gabriel-vasile/mimetype"
|
"github.com/gabriel-vasile/mimetype"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"go.mau.fi/util/exerrors"
|
"go.mau.fi/util/exerrors"
|
||||||
|
"go.mau.fi/util/random"
|
||||||
"go.mau.fi/util/variationselector"
|
"go.mau.fi/util/variationselector"
|
||||||
"maunium.net/go/mautrix"
|
"maunium.net/go/mautrix"
|
||||||
"maunium.net/go/mautrix/appservice"
|
"maunium.net/go/mautrix/appservice"
|
||||||
|
@ -523,19 +524,38 @@ func (portal *Portal) handleMessage(source *User, evt *gmproto.Message) {
|
||||||
}
|
}
|
||||||
|
|
||||||
converted := portal.convertGoogleMessage(ctx, source, evt, false)
|
converted := portal.convertGoogleMessage(ctx, source, evt, false)
|
||||||
if converted == nil {
|
eventIDs := portal.sendMessageParts(ctx, converted, nil)
|
||||||
return
|
if len(eventIDs) > 0 {
|
||||||
} else if len(converted.Parts) == 0 {
|
portal.sendDeliveryReceipt(eventIDs[len(eventIDs)-1])
|
||||||
log.Debug().Msg("Didn't get any converted parts from message")
|
log.Debug().Interface("event_ids", eventIDs).Msg("Handled message")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (portal *Portal) sendMessageParts(ctx context.Context, converted *ConvertedMessage, replyToMap map[string]id.EventID) []id.EventID {
|
||||||
|
if converted == nil {
|
||||||
|
return nil
|
||||||
|
} else if len(converted.Parts) == 0 {
|
||||||
|
zerolog.Ctx(ctx).Debug().Msg("Didn't get any converted parts from message")
|
||||||
|
return nil
|
||||||
|
} else if converted.DontBridge {
|
||||||
|
zerolog.Ctx(ctx).Debug().Msg("Ignored incoming tombstone message")
|
||||||
|
portal.markHandled(converted, id.EventID(fmt.Sprintf("$fake::%s", random.String(37))), nil, true)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
eventIDs := make([]id.EventID, 0, len(converted.Parts))
|
eventIDs := make([]id.EventID, 0, len(converted.Parts))
|
||||||
mediaParts := make(map[string]database.MediaPart, len(converted.Parts)-1)
|
mediaParts := make(map[string]database.MediaPart, len(converted.Parts)-1)
|
||||||
for _, part := range converted.Parts {
|
for i, part := range converted.Parts {
|
||||||
|
if replyToMap != nil && converted.ReplyTo != "" && part.Content.RelatesTo == nil {
|
||||||
|
replyToEvent, ok := replyToMap[converted.ReplyTo]
|
||||||
|
if ok {
|
||||||
|
part.Content.RelatesTo = &event.RelatesTo{
|
||||||
|
InReplyTo: &event.InReplyTo{EventID: replyToEvent},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
resp, err := portal.sendMessage(converted.Intent, event.EventMessage, part.Content, part.Extra, converted.Timestamp.UnixMilli())
|
resp, err := portal.sendMessage(converted.Intent, event.EventMessage, part.Content, part.Extra, converted.Timestamp.UnixMilli())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Err(err).Msg("Failed to send message")
|
zerolog.Ctx(ctx).Err(err).Int("part_index", i).Str("part_id", part.ID).Msg("Failed to send message")
|
||||||
} else {
|
} else {
|
||||||
eventIDs = append(eventIDs, resp.EventID)
|
eventIDs = append(eventIDs, resp.EventID)
|
||||||
if len(eventIDs) > 1 {
|
if len(eventIDs) > 1 {
|
||||||
|
@ -549,8 +569,7 @@ func (portal *Portal) handleMessage(source *User, evt *gmproto.Message) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
portal.markHandled(converted, eventIDs[0], mediaParts, true)
|
portal.markHandled(converted, eventIDs[0], mediaParts, true)
|
||||||
portal.sendDeliveryReceipt(eventIDs[len(eventIDs)-1])
|
return eventIDs
|
||||||
log.Debug().Interface("event_ids", eventIDs).Msg("Handled message")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (portal *Portal) syncReactions(ctx context.Context, source *User, message *database.Message, reactions []*gmproto.ReactionEntry) {
|
func (portal *Portal) syncReactions(ctx context.Context, source *User, message *database.Message, reactions []*gmproto.ReactionEntry) {
|
||||||
|
@ -637,6 +656,8 @@ type ConvertedMessage struct {
|
||||||
Parts []ConvertedMessagePart
|
Parts []ConvertedMessagePart
|
||||||
PartCount int
|
PartCount int
|
||||||
|
|
||||||
|
DontBridge bool
|
||||||
|
|
||||||
Status gmproto.MessageStatusType
|
Status gmproto.MessageStatusType
|
||||||
MediaStatus string
|
MediaStatus string
|
||||||
}
|
}
|
||||||
|
@ -672,6 +693,17 @@ func addDownloadStatus(content *event.MessageEventContent, status string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldIgnoreStatus(status gmproto.MessageStatusType) bool {
|
||||||
|
switch status {
|
||||||
|
case gmproto.MessageStatusType_TOMBSTONE_PROTOCOL_SWITCH_TO_TEXT,
|
||||||
|
gmproto.MessageStatusType_TOMBSTONE_PROTOCOL_SWITCH_TO_RCS,
|
||||||
|
gmproto.MessageStatusType_TOMBSTONE_PROTOCOL_SWITCH_TO_ENCRYPTED_RCS:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (portal *Portal) convertGoogleMessage(ctx context.Context, source *User, evt *gmproto.Message, backfill bool) *ConvertedMessage {
|
func (portal *Portal) convertGoogleMessage(ctx context.Context, source *User, evt *gmproto.Message, backfill bool) *ConvertedMessage {
|
||||||
log := zerolog.Ctx(ctx)
|
log := zerolog.Ctx(ctx)
|
||||||
|
|
||||||
|
@ -681,6 +713,7 @@ func (portal *Portal) convertGoogleMessage(ctx context.Context, source *User, ev
|
||||||
cm.ID = evt.MessageID
|
cm.ID = evt.MessageID
|
||||||
cm.PartCount = len(evt.GetMessageInfo())
|
cm.PartCount = len(evt.GetMessageInfo())
|
||||||
cm.Timestamp = time.UnixMicro(evt.Timestamp)
|
cm.Timestamp = time.UnixMicro(evt.Timestamp)
|
||||||
|
cm.DontBridge = shouldIgnoreStatus(cm.Status)
|
||||||
if cm.Status >= 200 && cm.Status < 300 {
|
if cm.Status >= 200 && cm.Status < 300 {
|
||||||
cm.Intent = portal.bridge.Bot
|
cm.Intent = portal.bridge.Bot
|
||||||
if !portal.Encrypted && portal.IsPrivateChat() {
|
if !portal.Encrypted && portal.IsPrivateChat() {
|
||||||
|
@ -705,6 +738,8 @@ func (portal *Portal) convertGoogleMessage(ctx context.Context, source *User, ev
|
||||||
} else {
|
} else {
|
||||||
log.Warn().Str("reply_to_id", cm.ReplyTo).Msg("Reply target message not found")
|
log.Warn().Str("reply_to_id", cm.ReplyTo).Msg("Reply target message not found")
|
||||||
}
|
}
|
||||||
|
} else if msg.IsFakeMXID() {
|
||||||
|
log.Debug().Str("reply_to_id", msg.ID).Msg("Ignoring reply to non-bridged message")
|
||||||
} else {
|
} else {
|
||||||
replyTo = msg.MXID
|
replyTo = msg.MXID
|
||||||
}
|
}
|
||||||
|
@ -1570,7 +1605,7 @@ func (portal *Portal) HandleMatrixReadReceipt(brUser bridge.User, eventID id.Eve
|
||||||
log.Err(err).Msg("Failed to get target message to handle read receipt")
|
log.Err(err).Msg("Failed to get target message to handle read receipt")
|
||||||
return
|
return
|
||||||
} else if targetMessage == nil {
|
} else if targetMessage == nil {
|
||||||
lastMessage, err := portal.bridge.DB.Message.GetLastInChat(ctx, portal.Key)
|
lastMessage, err := portal.bridge.DB.Message.GetLastInChatWithMXID(ctx, portal.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Err(err).Msg("Failed to get last message to handle read receipt")
|
log.Err(err).Msg("Failed to get last message to handle read receipt")
|
||||||
return
|
return
|
||||||
|
|
4
user.go
4
user.go
|
@ -919,8 +919,8 @@ func (user *User) markSelfReadFull(portal *Portal, lastMessageID string) {
|
||||||
}
|
}
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
lastMessage, err := user.bridge.DB.Message.GetByID(ctx, portal.Key, lastMessageID)
|
lastMessage, err := user.bridge.DB.Message.GetByID(ctx, portal.Key, lastMessageID)
|
||||||
if err == nil && lastMessage == nil {
|
if err == nil && lastMessage == nil || lastMessage.IsFakeMXID() {
|
||||||
lastMessage, err = user.bridge.DB.Message.GetLastInChat(ctx, portal.Key)
|
lastMessage, err = user.bridge.DB.Message.GetLastInChatWithMXID(ctx, portal.Key)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
user.zlog.Warn().Err(err).Msg("Failed to get last message in chat to mark it as read")
|
user.zlog.Warn().Err(err).Msg("Failed to get last message in chat to mark it as read")
|
||||||
|
|
Loading…
Reference in a new issue