Add support for messages moving to different chats

This commit is contained in:
Tulir Asokan 2023-08-14 14:32:12 +03:00
parent 3f417ba719
commit b05212e47d
8 changed files with 192 additions and 49 deletions

View file

@ -188,6 +188,7 @@ func (portal *Portal) backfillSendBatch(ctx context.Context, converted []*Conver
for _, msg := range converted {
dbm := portal.bridge.DB.Message.New()
dbm.Chat = portal.Key
dbm.RoomID = portal.MXID
dbm.ID = msg.ID
dbm.Sender = msg.SenderID
dbm.Timestamp = msg.Timestamp

View file

@ -46,27 +46,27 @@ func (mq *MessageQuery) getDB() *Database {
const (
getMessageByIDQuery = `
SELECT conv_id, conv_receiver, id, mxid, sender, timestamp, status FROM message
WHERE conv_id=$1 AND conv_receiver=$2 AND id=$3
SELECT conv_id, conv_receiver, id, mxid, mx_room, sender, timestamp, status FROM message
WHERE conv_receiver=$1 AND id=$2
`
getLastMessageInChatQuery = `
SELECT conv_id, conv_receiver, id, mxid, sender, timestamp, status FROM message
SELECT conv_id, conv_receiver, id, mxid, mx_room, sender, timestamp, status FROM message
WHERE conv_id=$1 AND conv_receiver=$2
ORDER BY timestamp DESC LIMIT 1
`
getLastMessageInChatWithMXIDQuery = `
SELECT conv_id, conv_receiver, id, mxid, sender, timestamp, status FROM message
SELECT conv_id, conv_receiver, id, mxid, mx_room, sender, timestamp, status FROM message
WHERE conv_id=$1 AND conv_receiver=$2 AND mxid NOT LIKE '$fake::%'
ORDER BY timestamp DESC LIMIT 1
`
getMessageByMXIDQuery = `
SELECT conv_id, conv_receiver, id, mxid, sender, timestamp, status FROM message
SELECT conv_id, conv_receiver, id, mxid, mx_room, sender, timestamp, status FROM message
WHERE mxid=$1
`
)
func (mq *MessageQuery) GetByID(ctx context.Context, chat Key, messageID string) (*Message, error) {
return get[*Message](mq, ctx, getMessageByIDQuery, chat.ID, chat.Receiver, messageID)
func (mq *MessageQuery) GetByID(ctx context.Context, receiver int, messageID string) (*Message, error) {
return get[*Message](mq, ctx, getMessageByIDQuery, receiver, messageID)
}
func (mq *MessageQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Message, error) {
@ -114,6 +114,7 @@ type Message struct {
Chat Key
ID string
MXID id.EventID
RoomID id.RoomID
Sender string
Timestamp time.Time
Status MessageStatus
@ -121,7 +122,7 @@ type Message struct {
func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) {
var ts int64
err := row.Scan(&msg.Chat.ID, &msg.Chat.Receiver, &msg.ID, &msg.MXID, &msg.Sender, &ts, dbutil.JSON{Data: &msg.Status})
err := row.Scan(&msg.Chat.ID, &msg.Chat.Receiver, &msg.ID, &msg.MXID, &msg.RoomID, &msg.Sender, &ts, dbutil.JSON{Data: &msg.Status})
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
@ -134,28 +135,29 @@ func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) {
}
func (msg *Message) sqlVariables() []any {
return []any{msg.Chat.ID, msg.Chat.Receiver, msg.ID, msg.MXID, msg.Sender, msg.Timestamp.UnixMicro(), dbutil.JSON{Data: &msg.Status}}
return []any{msg.Chat.ID, msg.Chat.Receiver, msg.ID, msg.MXID, msg.RoomID, msg.Sender, msg.Timestamp.UnixMicro(), dbutil.JSON{Data: &msg.Status}}
}
func (msg *Message) Insert(ctx context.Context) error {
_, err := msg.db.Conn(ctx).ExecContext(ctx, `
INSERT INTO message (conv_id, conv_receiver, id, mxid, sender, timestamp, status)
VALUES ($1, $2, $3, $4, $5, $6, $7)
INSERT INTO message (conv_id, conv_receiver, id, mxid, mx_room, sender, timestamp, status)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
`, msg.sqlVariables()...)
return err
}
func (mq *MessageQuery) MassInsert(ctx context.Context, messages []*Message) error {
valueStringFormat := "($1, $2, $%d, $%d, $%d, $%d, $%d)"
valueStringFormat := "($1, $2, $%d, $%d, $3, $%d, $%d, $%d)"
if mq.db.Dialect == dbutil.SQLite {
valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
}
placeholders := make([]string, len(messages))
params := make([]any, 2+len(messages)*5)
params := make([]any, 3+len(messages)*5)
params[0] = messages[0].Chat.ID
params[1] = messages[0].Chat.Receiver
params[2] = messages[0].RoomID
for i, msg := range messages {
baseIndex := 2 + i*5
baseIndex := 3 + i*5
params[baseIndex] = msg.ID
params[baseIndex+1] = msg.MXID
params[baseIndex+2] = msg.Sender
@ -164,15 +166,24 @@ func (mq *MessageQuery) MassInsert(ctx context.Context, messages []*Message) err
placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5)
}
query := `
INSERT INTO message (conv_id, conv_receiver, id, mxid, sender, timestamp, status)
INSERT INTO message (conv_id, conv_receiver, id, mxid, mx_room, sender, timestamp, status)
VALUES
` + strings.Join(placeholders, ",")
_, err := mq.db.Conn(ctx).ExecContext(ctx, query, params...)
return err
}
func (msg *Message) Update(ctx context.Context) error {
_, err := msg.db.Conn(ctx).ExecContext(ctx, `
UPDATE message
SET conv_id=$1, mxid=$4, mx_room=$5, sender=$6, timestamp=$7, status=$8
WHERE conv_receiver=$2 AND id=$3
`, msg.sqlVariables()...)
return err
}
func (msg *Message) UpdateStatus(ctx context.Context) error {
_, err := msg.db.Conn(ctx).ExecContext(ctx, "UPDATE message SET status=$1, timestamp=$2 WHERE conv_id=$3 AND conv_receiver=$4 AND id=$5", dbutil.JSON{Data: &msg.Status}, msg.Timestamp.UnixMicro(), msg.Chat.ID, msg.Chat.Receiver, msg.ID)
_, err := msg.db.Conn(ctx).ExecContext(ctx, "UPDATE message SET status=$1, timestamp=$2 WHERE conv_receiver=$3 AND id=$4", dbutil.JSON{Data: &msg.Status}, msg.Timestamp.UnixMicro(), msg.Chat.Receiver, msg.ID)
return err
}

View file

@ -44,13 +44,17 @@ func (rq *ReactionQuery) getDB() *Database {
const (
getReactionByIDQuery = `
SELECT conv_id, conv_receiver, msg_id, sender, reaction, mxid FROM reaction
WHERE conv_id=$1 AND conv_receiver=$2 AND msg_id=$3 AND sender=$4
WHERE conv_receiver=$1 AND msg_id=$2 AND sender=$3
`
getReactionByMXIDQuery = `
SELECT conv_id, conv_receiver, msg_id, sender, reaction, mxid FROM reaction
WHERE mxid=$1
`
getReactionsByMessageIDQuery = `
SELECT conv_id, conv_receiver, msg_id, sender, reaction, mxid FROM reaction
WHERE conv_receiver=$1 AND msg_id=$2
`
deleteReactionsByMessageIDQuery = `
SELECT conv_id, conv_receiver, msg_id, sender, reaction, mxid FROM reaction
WHERE conv_id=$1 AND conv_receiver=$2 AND msg_id=$3
`
@ -62,16 +66,21 @@ const (
`
)
func (rq *ReactionQuery) GetByID(ctx context.Context, chat Key, messageID, sender string) (*Reaction, error) {
return get[*Reaction](rq, ctx, getReactionByIDQuery, chat.ID, chat.Receiver, messageID, sender)
func (rq *ReactionQuery) GetByID(ctx context.Context, receiver int, messageID, sender string) (*Reaction, error) {
return get[*Reaction](rq, ctx, getReactionByIDQuery, receiver, messageID, sender)
}
func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) {
return get[*Reaction](rq, ctx, getReactionByMXIDQuery, mxid)
}
func (rq *ReactionQuery) GetAllByMessage(ctx context.Context, chat Key, messageID string) ([]*Reaction, error) {
return getAll[*Reaction](rq, ctx, getReactionsByMessageIDQuery, chat.ID, chat.Receiver, messageID)
func (rq *ReactionQuery) GetAllByMessage(ctx context.Context, receiver int, messageID string) ([]*Reaction, error) {
return getAll[*Reaction](rq, ctx, getReactionsByMessageIDQuery, receiver, messageID)
}
func (rq *ReactionQuery) DeleteAllByMessage(ctx context.Context, chat Key, messageID string) error {
_, err := rq.db.Conn(ctx).ExecContext(ctx, deleteReactionsByMessageIDQuery, chat.ID, chat.Receiver, messageID)
return err
}
type Reaction struct {

View file

@ -59,14 +59,16 @@ CREATE TABLE message (
conv_receiver BIGINT NOT NULL,
id TEXT NOT NULL,
mxid TEXT NOT NULL,
mx_room TEXT NOT NULL,
sender TEXT NOT NULL,
timestamp BIGINT NOT NULL,
status jsonb NOT NULL,
PRIMARY KEY (conv_id, conv_receiver, id),
PRIMARY KEY (conv_receiver, id),
CONSTRAINT message_portal_fkey FOREIGN KEY (conv_id, conv_receiver) REFERENCES portal(id, receiver) ON DELETE CASCADE,
CONSTRAINT message_mxid_unique UNIQUE (mxid)
);
CREATE INDEX message_conv_timestamp_idx ON message(conv_id, conv_receiver, timestamp);
CREATE TABLE reaction (
conv_id TEXT NOT NULL,
@ -76,7 +78,7 @@ CREATE TABLE reaction (
reaction TEXT NOT NULL,
mxid TEXT NOT NULL,
PRIMARY KEY (conv_id, conv_receiver, msg_id, sender),
CONSTRAINT reaction_message_fkey FOREIGN KEY (conv_id, conv_receiver, msg_id) REFERENCES message(conv_id, conv_receiver, id) ON DELETE CASCADE,
PRIMARY KEY (conv_receiver, msg_id, sender),
CONSTRAINT reaction_message_fkey FOREIGN KEY (conv_receiver, msg_id) REFERENCES message(conv_receiver, id) ON DELETE CASCADE,
CONSTRAINT reaction_mxid_unique UNIQUE (mxid)
)

View file

@ -0,0 +1,20 @@
-- v4: Drop conversation ID from message primary key
-- transaction: off
BEGIN TRANSACTION;
ALTER TABLE reaction DROP CONSTRAINT reaction_pkey;
ALTER TABLE reaction DROP CONSTRAINT reaction_message_fkey;
ALTER TABLE reaction ADD PRIMARY KEY (conv_receiver, msg_id, sender);
ALTER TABLE message DROP CONSTRAINT message_pkey;
DELETE FROM message WHERE (status->>'type')::integer IN (101, 102, 103, 104, 105, 106, 107, 110, 111, 112, 113, 114);
ALTER TABLE message ADD PRIMARY KEY (conv_receiver, id);
CREATE INDEX message_conv_timestamp_idx ON message(conv_id, conv_receiver, timestamp);
ALTER TABLE reaction ADD CONSTRAINT reaction_message_fkey FOREIGN KEY (conv_receiver, msg_id) REFERENCES message (conv_receiver, id) ON DELETE CASCADE;
ALTER TABLE message ADD COLUMN mx_room TEXT NOT NULL DEFAULT '';
UPDATE message SET mx_room = (SELECT mxid FROM portal WHERE id=message.conv_id AND receiver=message.conv_receiver);
ALTER TABLE message ALTER COLUMN mx_room DROP DEFAULT;
COMMIT;

View file

@ -0,0 +1,50 @@
-- v4: Drop conversation ID from message primary key
-- transaction: off
PRAGMA foreign_keys = OFF;
BEGIN TRANSACTION;
CREATE TABLE message_new (
conv_id TEXT NOT NULL,
conv_receiver BIGINT NOT NULL,
id TEXT NOT NULL,
mxid TEXT NOT NULL,
mx_room TEXT NOT NULL,
sender TEXT NOT NULL,
timestamp BIGINT NOT NULL,
status jsonb NOT NULL,
PRIMARY KEY (conv_receiver, id),
CONSTRAINT message_portal_fkey FOREIGN KEY (conv_id, conv_receiver) REFERENCES portal(id, receiver) ON DELETE CASCADE,
CONSTRAINT message_mxid_unique UNIQUE (mxid)
);
INSERT INTO message_new (conv_id, conv_receiver, id, mxid, mx_room, sender, timestamp, status)
SELECT conv_id, conv_receiver, id, mxid, (SELECT mxid FROM portal WHERE id=conv_id AND receiver=conv_receiver), sender, timestamp, status
FROM message
WHERE status->>'type' NOT IN (101, 102, 103, 104, 105, 106, 107, 110, 111, 112, 113, 114);
DROP TABLE message;
ALTER TABLE message_new RENAME TO message;
CREATE INDEX message_conv_timestamp_idx ON message(conv_id, conv_receiver, timestamp);
CREATE TABLE reaction_new (
conv_id TEXT NOT NULL,
conv_receiver BIGINT NOT NULL,
msg_id TEXT NOT NULL,
sender TEXT NOT NULL,
reaction TEXT NOT NULL,
mxid TEXT NOT NULL,
PRIMARY KEY (conv_receiver, msg_id, sender),
CONSTRAINT reaction_message_fkey FOREIGN KEY (conv_receiver, msg_id) REFERENCES message(conv_receiver, id) ON DELETE CASCADE,
CONSTRAINT reaction_mxid_unique UNIQUE (mxid)
);
INSERT INTO reaction_new (conv_id, conv_receiver, msg_id, sender, reaction, mxid)
SELECT conv_id, conv_receiver, msg_id, sender, reaction, mxid
FROM reaction;
DROP TABLE reaction;
ALTER TABLE reaction_new RENAME TO reaction;
PRAGMA foreign_key_check;
COMMIT;
PRAGMA FOREIGN_KEYS = ON;

View file

@ -355,7 +355,7 @@ func isSuccessfullySentStatus(status gmproto.MessageStatusType) bool {
func downloadPendingStatusMessage(status gmproto.MessageStatusType) string {
switch status {
case gmproto.MessageStatusType_INCOMING_YET_TO_MANUAL_DOWNLOAD:
return "Attachment message (auto-download is disabled)"
return "Attachment message (auto-download is disabled, use Messages on Android to download)"
case gmproto.MessageStatusType_INCOMING_MANUAL_DOWNLOADING,
gmproto.MessageStatusType_INCOMING_AUTO_DOWNLOADING,
gmproto.MessageStatusType_INCOMING_RETRYING_MANUAL_DOWNLOAD,
@ -391,32 +391,68 @@ func isFailSendStatus(status gmproto.MessageStatusType) bool {
}
}
func (portal *Portal) handleExistingMessageUpdate(ctx context.Context, source *User, dbMsg *database.Message, evt *gmproto.Message) {
log := zerolog.Ctx(ctx)
portal.syncReactions(ctx, source, dbMsg, evt.Reactions)
newStatus := evt.GetMessageStatus().GetStatus()
if dbMsg.Status.Type == newStatus && !(dbMsg.Status.HasPendingMediaParts() && !hasInProgressMedia(evt)) {
func (portal *Portal) redactMessage(ctx context.Context, msg *database.Message) {
if msg.IsFakeMXID() {
return
}
log.Debug().Str("old_status", dbMsg.Status.Type.String()).Msg("Message status changed")
log := zerolog.Ctx(ctx)
intent := portal.MainIntent()
if msg.Chat.ID != portal.ID {
otherPortal := portal.bridge.GetExistingPortalByKey(msg.Chat)
if otherPortal != nil {
intent = otherPortal.MainIntent()
}
}
for partID, part := range msg.Status.MediaParts {
if part.EventID != "" {
if _, err := intent.RedactEvent(msg.RoomID, part.EventID); err != nil {
log.Err(err).Str("part_id", partID).Msg("Failed to redact part of message")
}
part.EventID = ""
msg.Status.MediaParts[partID] = part
}
}
if _, err := intent.RedactEvent(msg.RoomID, msg.MXID); err != nil {
log.Err(err).Msg("Failed to redact message")
}
msg.MXID = ""
}
func (portal *Portal) handleExistingMessageUpdate(ctx context.Context, source *User, dbMsg *database.Message, evt *gmproto.Message) {
log := *zerolog.Ctx(ctx)
newStatus := evt.GetMessageStatus().GetStatus()
chatIDChanged := dbMsg.Chat.ID != portal.ID
if dbMsg.Status.Type == newStatus && !chatIDChanged && !(dbMsg.Status.HasPendingMediaParts() && !hasInProgressMedia(evt)) {
portal.syncReactions(ctx, source, dbMsg, evt.Reactions)
return
}
if chatIDChanged {
log = log.With().Str("old_chat_id", dbMsg.Chat.ID).Logger()
log.Debug().
Str("old_room_id", dbMsg.RoomID.String()).
Str("sender_id", dbMsg.Sender).
Msg("Redacting events from old room")
ctx = log.WithContext(ctx)
err := portal.bridge.DB.Reaction.DeleteAllByMessage(ctx, dbMsg.Chat, dbMsg.ID)
if err != nil {
log.Warn().Err(err).Msg("Failed to delete db reactions for message that moved to another room")
}
portal.redactMessage(ctx, dbMsg)
}
log.Debug().
Str("old_status", dbMsg.Status.Type.String()).
Msg("Message status changed")
switch {
case newStatus == gmproto.MessageStatusType_MESSAGE_DELETED:
for partID, part := range dbMsg.Status.MediaParts {
if part.EventID != "" {
if _, err := portal.MainIntent().RedactEvent(portal.MXID, part.EventID); err != nil {
log.Err(err).Str("part_iD", partID).Msg("Failed to redact part of deleted message")
}
}
}
if _, err := portal.MainIntent().RedactEvent(portal.MXID, dbMsg.MXID); err != nil {
log.Err(err).Msg("Failed to redact deleted message")
} else if err = dbMsg.Delete(ctx); err != nil {
portal.redactMessage(ctx, dbMsg)
if err := dbMsg.Delete(ctx); err != nil {
log.Err(err).Msg("Failed to delete message from database")
} else {
log.Debug().Msg("Handled message deletion")
}
return
case dbMsg.Status.MediaStatus != downloadPendingStatusMessage(newStatus),
case chatIDChanged,
dbMsg.Status.MediaStatus != downloadPendingStatusMessage(newStatus),
dbMsg.Status.HasPendingMediaParts() && !hasInProgressMedia(evt),
dbMsg.Status.PartCount != len(evt.MessageInfo):
converted := portal.convertGoogleMessage(ctx, source, evt, false)
@ -428,7 +464,9 @@ func (portal *Portal) handleExistingMessageUpdate(ctx context.Context, source *U
for i, part := range converted.Parts {
isEdit := true
ts := time.Now().UnixMilli()
if i == 0 {
if chatIDChanged {
isEdit = false
} else if i == 0 {
part.Content.SetEdit(dbMsg.MXID)
} else if existingPart, ok := dbMsg.Status.MediaParts[part.ID]; ok {
part.Content.SetEdit(existingPart.EventID)
@ -443,6 +481,9 @@ func (portal *Portal) handleExistingMessageUpdate(ctx context.Context, source *U
eventIDs = append(eventIDs, resp.EventID)
}
if i == 0 {
if chatIDChanged {
dbMsg.MXID = resp.EventID
}
dbMsg.Status.MediaParts[""] = database.MediaPart{PendingMedia: part.PendingMedia}
} else if !isEdit {
dbMsg.Status.MediaParts[part.ID] = database.MediaPart{EventID: resp.EventID, PendingMedia: part.PendingMedia}
@ -485,10 +526,18 @@ func (portal *Portal) handleExistingMessageUpdate(ctx context.Context, source *U
dbMsg.Status.Type = newStatus
dbMsg.Status.PartCount = len(evt.MessageInfo)
dbMsg.Timestamp = time.UnixMicro(evt.GetTimestamp())
err := dbMsg.UpdateStatus(ctx)
var err error
if chatIDChanged {
dbMsg.Chat = portal.Key
dbMsg.RoomID = portal.MXID
err = dbMsg.Update(ctx)
} else {
err = dbMsg.UpdateStatus(ctx)
}
if err != nil {
log.Warn().Err(err).Msg("Failed to save updated message status to database")
}
portal.syncReactions(ctx, source, dbMsg, evt.Reactions)
}
func (portal *Portal) handleExistingMessage(ctx context.Context, source *User, evt *gmproto.Message, outgoingOnly bool) bool {
@ -500,7 +549,7 @@ func (portal *Portal) handleExistingMessage(ctx context.Context, source *User, e
} else if outgoingOnly {
return false
}
existingMsg, err := portal.bridge.DB.Message.GetByID(ctx, portal.Key, evt.MessageID)
existingMsg, err := portal.bridge.DB.Message.GetByID(ctx, portal.Receiver, evt.MessageID)
if err != nil {
log.Err(err).Msg("Failed to check if message is duplicate")
} else if existingMsg != nil {
@ -603,7 +652,7 @@ func (portal *Portal) sendMessageParts(ctx context.Context, converted *Converted
func (portal *Portal) syncReactions(ctx context.Context, source *User, message *database.Message, reactions []*gmproto.ReactionEntry) {
log := zerolog.Ctx(ctx)
existing, err := portal.bridge.DB.Reaction.GetAllByMessage(ctx, portal.Key, message.ID)
existing, err := portal.bridge.DB.Reaction.GetAllByMessage(ctx, portal.Receiver, message.ID)
if err != nil {
log.Err(err).Msg("Failed to get existing reactions from db to sync reactions")
return
@ -758,7 +807,7 @@ func (portal *Portal) convertGoogleMessage(ctx context.Context, source *User, ev
var replyTo id.EventID
if evt.GetReplyMessage() != nil {
cm.ReplyTo = evt.GetReplyMessage().GetMessageID()
msg, err := portal.bridge.DB.Message.GetByID(ctx, portal.Key, cm.ReplyTo)
msg, err := portal.bridge.DB.Message.GetByID(ctx, portal.Receiver, cm.ReplyTo)
if err != nil {
log.Err(err).Str("reply_to_id", cm.ReplyTo).Msg("Failed to get reply target message")
} else if msg == nil {
@ -933,6 +982,7 @@ func (portal *Portal) isRecentlyHandled(id string) bool {
func (portal *Portal) markHandled(cm *ConvertedMessage, eventID id.EventID, mediaParts map[string]database.MediaPart, recent bool) *database.Message {
msg := portal.bridge.DB.Message.New()
msg.Chat = portal.Key
msg.RoomID = portal.MXID
msg.ID = cm.ID
msg.MXID = eventID
msg.Timestamp = cm.Timestamp
@ -1691,7 +1741,7 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) error
return errTargetNotFound
}
existingReaction, err := portal.bridge.DB.Reaction.GetByID(ctx, portal.Key, msg.ID, portal.OutgoingID)
existingReaction, err := portal.bridge.DB.Reaction.GetByID(ctx, portal.Receiver, msg.ID, portal.OutgoingID)
if err != nil {
log.Err(err).Msg("Failed to get existing reaction")
return fmt.Errorf("failed to get existing reaction from database")

View file

@ -900,7 +900,7 @@ func (user *User) markSelfReadFull(portal *Portal, lastMessageID string) {
return
}
ctx := context.TODO()
lastMessage, err := user.bridge.DB.Message.GetByID(ctx, portal.Key, lastMessageID)
lastMessage, err := user.bridge.DB.Message.GetByID(ctx, portal.Receiver, lastMessageID)
if err == nil && lastMessage == nil || lastMessage.IsFakeMXID() {
lastMessage, err = user.bridge.DB.Message.GetLastInChatWithMXID(ctx, portal.Key)
}