Add support for messages moving to different chats

This commit is contained in:
Tulir Asokan 2023-08-14 14:32:12 +03:00
parent 3f417ba719
commit b05212e47d
8 changed files with 192 additions and 49 deletions

View file

@ -188,6 +188,7 @@ func (portal *Portal) backfillSendBatch(ctx context.Context, converted []*Conver
for _, msg := range converted { for _, msg := range converted {
dbm := portal.bridge.DB.Message.New() dbm := portal.bridge.DB.Message.New()
dbm.Chat = portal.Key dbm.Chat = portal.Key
dbm.RoomID = portal.MXID
dbm.ID = msg.ID dbm.ID = msg.ID
dbm.Sender = msg.SenderID dbm.Sender = msg.SenderID
dbm.Timestamp = msg.Timestamp dbm.Timestamp = msg.Timestamp

View file

@ -46,27 +46,27 @@ func (mq *MessageQuery) getDB() *Database {
const ( const (
getMessageByIDQuery = ` getMessageByIDQuery = `
SELECT conv_id, conv_receiver, id, mxid, sender, timestamp, status FROM message SELECT conv_id, conv_receiver, id, mxid, mx_room, sender, timestamp, status FROM message
WHERE conv_id=$1 AND conv_receiver=$2 AND id=$3 WHERE conv_receiver=$1 AND id=$2
` `
getLastMessageInChatQuery = ` getLastMessageInChatQuery = `
SELECT conv_id, conv_receiver, id, mxid, sender, timestamp, status FROM message SELECT conv_id, conv_receiver, id, mxid, mx_room, sender, timestamp, status FROM message
WHERE conv_id=$1 AND conv_receiver=$2 WHERE conv_id=$1 AND conv_receiver=$2
ORDER BY timestamp DESC LIMIT 1 ORDER BY timestamp DESC LIMIT 1
` `
getLastMessageInChatWithMXIDQuery = ` getLastMessageInChatWithMXIDQuery = `
SELECT conv_id, conv_receiver, id, mxid, sender, timestamp, status FROM message SELECT conv_id, conv_receiver, id, mxid, mx_room, sender, timestamp, status FROM message
WHERE conv_id=$1 AND conv_receiver=$2 AND mxid NOT LIKE '$fake::%' WHERE conv_id=$1 AND conv_receiver=$2 AND mxid NOT LIKE '$fake::%'
ORDER BY timestamp DESC LIMIT 1 ORDER BY timestamp DESC LIMIT 1
` `
getMessageByMXIDQuery = ` getMessageByMXIDQuery = `
SELECT conv_id, conv_receiver, id, mxid, sender, timestamp, status FROM message SELECT conv_id, conv_receiver, id, mxid, mx_room, sender, timestamp, status FROM message
WHERE mxid=$1 WHERE mxid=$1
` `
) )
func (mq *MessageQuery) GetByID(ctx context.Context, chat Key, messageID string) (*Message, error) { func (mq *MessageQuery) GetByID(ctx context.Context, receiver int, messageID string) (*Message, error) {
return get[*Message](mq, ctx, getMessageByIDQuery, chat.ID, chat.Receiver, messageID) return get[*Message](mq, ctx, getMessageByIDQuery, receiver, messageID)
} }
func (mq *MessageQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Message, error) { func (mq *MessageQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Message, error) {
@ -114,6 +114,7 @@ type Message struct {
Chat Key Chat Key
ID string ID string
MXID id.EventID MXID id.EventID
RoomID id.RoomID
Sender string Sender string
Timestamp time.Time Timestamp time.Time
Status MessageStatus Status MessageStatus
@ -121,7 +122,7 @@ type Message struct {
func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) { func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) {
var ts int64 var ts int64
err := row.Scan(&msg.Chat.ID, &msg.Chat.Receiver, &msg.ID, &msg.MXID, &msg.Sender, &ts, dbutil.JSON{Data: &msg.Status}) err := row.Scan(&msg.Chat.ID, &msg.Chat.Receiver, &msg.ID, &msg.MXID, &msg.RoomID, &msg.Sender, &ts, dbutil.JSON{Data: &msg.Status})
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, nil return nil, nil
} else if err != nil { } else if err != nil {
@ -134,28 +135,29 @@ func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) {
} }
func (msg *Message) sqlVariables() []any { func (msg *Message) sqlVariables() []any {
return []any{msg.Chat.ID, msg.Chat.Receiver, msg.ID, msg.MXID, msg.Sender, msg.Timestamp.UnixMicro(), dbutil.JSON{Data: &msg.Status}} return []any{msg.Chat.ID, msg.Chat.Receiver, msg.ID, msg.MXID, msg.RoomID, msg.Sender, msg.Timestamp.UnixMicro(), dbutil.JSON{Data: &msg.Status}}
} }
func (msg *Message) Insert(ctx context.Context) error { func (msg *Message) Insert(ctx context.Context) error {
_, err := msg.db.Conn(ctx).ExecContext(ctx, ` _, err := msg.db.Conn(ctx).ExecContext(ctx, `
INSERT INTO message (conv_id, conv_receiver, id, mxid, sender, timestamp, status) INSERT INTO message (conv_id, conv_receiver, id, mxid, mx_room, sender, timestamp, status)
VALUES ($1, $2, $3, $4, $5, $6, $7) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
`, msg.sqlVariables()...) `, msg.sqlVariables()...)
return err return err
} }
func (mq *MessageQuery) MassInsert(ctx context.Context, messages []*Message) error { func (mq *MessageQuery) MassInsert(ctx context.Context, messages []*Message) error {
valueStringFormat := "($1, $2, $%d, $%d, $%d, $%d, $%d)" valueStringFormat := "($1, $2, $%d, $%d, $3, $%d, $%d, $%d)"
if mq.db.Dialect == dbutil.SQLite { if mq.db.Dialect == dbutil.SQLite {
valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?") valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
} }
placeholders := make([]string, len(messages)) placeholders := make([]string, len(messages))
params := make([]any, 2+len(messages)*5) params := make([]any, 3+len(messages)*5)
params[0] = messages[0].Chat.ID params[0] = messages[0].Chat.ID
params[1] = messages[0].Chat.Receiver params[1] = messages[0].Chat.Receiver
params[2] = messages[0].RoomID
for i, msg := range messages { for i, msg := range messages {
baseIndex := 2 + i*5 baseIndex := 3 + i*5
params[baseIndex] = msg.ID params[baseIndex] = msg.ID
params[baseIndex+1] = msg.MXID params[baseIndex+1] = msg.MXID
params[baseIndex+2] = msg.Sender params[baseIndex+2] = msg.Sender
@ -164,15 +166,24 @@ func (mq *MessageQuery) MassInsert(ctx context.Context, messages []*Message) err
placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5) placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5)
} }
query := ` query := `
INSERT INTO message (conv_id, conv_receiver, id, mxid, sender, timestamp, status) INSERT INTO message (conv_id, conv_receiver, id, mxid, mx_room, sender, timestamp, status)
VALUES VALUES
` + strings.Join(placeholders, ",") ` + strings.Join(placeholders, ",")
_, err := mq.db.Conn(ctx).ExecContext(ctx, query, params...) _, err := mq.db.Conn(ctx).ExecContext(ctx, query, params...)
return err return err
} }
func (msg *Message) Update(ctx context.Context) error {
_, err := msg.db.Conn(ctx).ExecContext(ctx, `
UPDATE message
SET conv_id=$1, mxid=$4, mx_room=$5, sender=$6, timestamp=$7, status=$8
WHERE conv_receiver=$2 AND id=$3
`, msg.sqlVariables()...)
return err
}
func (msg *Message) UpdateStatus(ctx context.Context) error { func (msg *Message) UpdateStatus(ctx context.Context) error {
_, err := msg.db.Conn(ctx).ExecContext(ctx, "UPDATE message SET status=$1, timestamp=$2 WHERE conv_id=$3 AND conv_receiver=$4 AND id=$5", dbutil.JSON{Data: &msg.Status}, msg.Timestamp.UnixMicro(), msg.Chat.ID, msg.Chat.Receiver, msg.ID) _, err := msg.db.Conn(ctx).ExecContext(ctx, "UPDATE message SET status=$1, timestamp=$2 WHERE conv_receiver=$3 AND id=$4", dbutil.JSON{Data: &msg.Status}, msg.Timestamp.UnixMicro(), msg.Chat.Receiver, msg.ID)
return err return err
} }

View file

@ -44,13 +44,17 @@ func (rq *ReactionQuery) getDB() *Database {
const ( const (
getReactionByIDQuery = ` getReactionByIDQuery = `
SELECT conv_id, conv_receiver, msg_id, sender, reaction, mxid FROM reaction SELECT conv_id, conv_receiver, msg_id, sender, reaction, mxid FROM reaction
WHERE conv_id=$1 AND conv_receiver=$2 AND msg_id=$3 AND sender=$4 WHERE conv_receiver=$1 AND msg_id=$2 AND sender=$3
` `
getReactionByMXIDQuery = ` getReactionByMXIDQuery = `
SELECT conv_id, conv_receiver, msg_id, sender, reaction, mxid FROM reaction SELECT conv_id, conv_receiver, msg_id, sender, reaction, mxid FROM reaction
WHERE mxid=$1 WHERE mxid=$1
` `
getReactionsByMessageIDQuery = ` getReactionsByMessageIDQuery = `
SELECT conv_id, conv_receiver, msg_id, sender, reaction, mxid FROM reaction
WHERE conv_receiver=$1 AND msg_id=$2
`
deleteReactionsByMessageIDQuery = `
SELECT conv_id, conv_receiver, msg_id, sender, reaction, mxid FROM reaction SELECT conv_id, conv_receiver, msg_id, sender, reaction, mxid FROM reaction
WHERE conv_id=$1 AND conv_receiver=$2 AND msg_id=$3 WHERE conv_id=$1 AND conv_receiver=$2 AND msg_id=$3
` `
@ -62,16 +66,21 @@ const (
` `
) )
func (rq *ReactionQuery) GetByID(ctx context.Context, chat Key, messageID, sender string) (*Reaction, error) { func (rq *ReactionQuery) GetByID(ctx context.Context, receiver int, messageID, sender string) (*Reaction, error) {
return get[*Reaction](rq, ctx, getReactionByIDQuery, chat.ID, chat.Receiver, messageID, sender) return get[*Reaction](rq, ctx, getReactionByIDQuery, receiver, messageID, sender)
} }
func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) { func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) {
return get[*Reaction](rq, ctx, getReactionByMXIDQuery, mxid) return get[*Reaction](rq, ctx, getReactionByMXIDQuery, mxid)
} }
func (rq *ReactionQuery) GetAllByMessage(ctx context.Context, chat Key, messageID string) ([]*Reaction, error) { func (rq *ReactionQuery) GetAllByMessage(ctx context.Context, receiver int, messageID string) ([]*Reaction, error) {
return getAll[*Reaction](rq, ctx, getReactionsByMessageIDQuery, chat.ID, chat.Receiver, messageID) return getAll[*Reaction](rq, ctx, getReactionsByMessageIDQuery, receiver, messageID)
}
func (rq *ReactionQuery) DeleteAllByMessage(ctx context.Context, chat Key, messageID string) error {
_, err := rq.db.Conn(ctx).ExecContext(ctx, deleteReactionsByMessageIDQuery, chat.ID, chat.Receiver, messageID)
return err
} }
type Reaction struct { type Reaction struct {

View file

@ -59,14 +59,16 @@ CREATE TABLE message (
conv_receiver BIGINT NOT NULL, conv_receiver BIGINT NOT NULL,
id TEXT NOT NULL, id TEXT NOT NULL,
mxid TEXT NOT NULL, mxid TEXT NOT NULL,
mx_room TEXT NOT NULL,
sender TEXT NOT NULL, sender TEXT NOT NULL,
timestamp BIGINT NOT NULL, timestamp BIGINT NOT NULL,
status jsonb NOT NULL, status jsonb NOT NULL,
PRIMARY KEY (conv_id, conv_receiver, id), PRIMARY KEY (conv_receiver, id),
CONSTRAINT message_portal_fkey FOREIGN KEY (conv_id, conv_receiver) REFERENCES portal(id, receiver) ON DELETE CASCADE, CONSTRAINT message_portal_fkey FOREIGN KEY (conv_id, conv_receiver) REFERENCES portal(id, receiver) ON DELETE CASCADE,
CONSTRAINT message_mxid_unique UNIQUE (mxid) CONSTRAINT message_mxid_unique UNIQUE (mxid)
); );
CREATE INDEX message_conv_timestamp_idx ON message(conv_id, conv_receiver, timestamp);
CREATE TABLE reaction ( CREATE TABLE reaction (
conv_id TEXT NOT NULL, conv_id TEXT NOT NULL,
@ -76,7 +78,7 @@ CREATE TABLE reaction (
reaction TEXT NOT NULL, reaction TEXT NOT NULL,
mxid TEXT NOT NULL, mxid TEXT NOT NULL,
PRIMARY KEY (conv_id, conv_receiver, msg_id, sender), PRIMARY KEY (conv_receiver, msg_id, sender),
CONSTRAINT reaction_message_fkey FOREIGN KEY (conv_id, conv_receiver, msg_id) REFERENCES message(conv_id, conv_receiver, id) ON DELETE CASCADE, CONSTRAINT reaction_message_fkey FOREIGN KEY (conv_receiver, msg_id) REFERENCES message(conv_receiver, id) ON DELETE CASCADE,
CONSTRAINT reaction_mxid_unique UNIQUE (mxid) CONSTRAINT reaction_mxid_unique UNIQUE (mxid)
) )

View file

@ -0,0 +1,20 @@
-- v4: Drop conversation ID from message primary key
-- transaction: off
BEGIN TRANSACTION;
ALTER TABLE reaction DROP CONSTRAINT reaction_pkey;
ALTER TABLE reaction DROP CONSTRAINT reaction_message_fkey;
ALTER TABLE reaction ADD PRIMARY KEY (conv_receiver, msg_id, sender);
ALTER TABLE message DROP CONSTRAINT message_pkey;
DELETE FROM message WHERE (status->>'type')::integer IN (101, 102, 103, 104, 105, 106, 107, 110, 111, 112, 113, 114);
ALTER TABLE message ADD PRIMARY KEY (conv_receiver, id);
CREATE INDEX message_conv_timestamp_idx ON message(conv_id, conv_receiver, timestamp);
ALTER TABLE reaction ADD CONSTRAINT reaction_message_fkey FOREIGN KEY (conv_receiver, msg_id) REFERENCES message (conv_receiver, id) ON DELETE CASCADE;
ALTER TABLE message ADD COLUMN mx_room TEXT NOT NULL DEFAULT '';
UPDATE message SET mx_room = (SELECT mxid FROM portal WHERE id=message.conv_id AND receiver=message.conv_receiver);
ALTER TABLE message ALTER COLUMN mx_room DROP DEFAULT;
COMMIT;

View file

@ -0,0 +1,50 @@
-- v4: Drop conversation ID from message primary key
-- transaction: off
PRAGMA foreign_keys = OFF;
BEGIN TRANSACTION;
CREATE TABLE message_new (
conv_id TEXT NOT NULL,
conv_receiver BIGINT NOT NULL,
id TEXT NOT NULL,
mxid TEXT NOT NULL,
mx_room TEXT NOT NULL,
sender TEXT NOT NULL,
timestamp BIGINT NOT NULL,
status jsonb NOT NULL,
PRIMARY KEY (conv_receiver, id),
CONSTRAINT message_portal_fkey FOREIGN KEY (conv_id, conv_receiver) REFERENCES portal(id, receiver) ON DELETE CASCADE,
CONSTRAINT message_mxid_unique UNIQUE (mxid)
);
INSERT INTO message_new (conv_id, conv_receiver, id, mxid, mx_room, sender, timestamp, status)
SELECT conv_id, conv_receiver, id, mxid, (SELECT mxid FROM portal WHERE id=conv_id AND receiver=conv_receiver), sender, timestamp, status
FROM message
WHERE status->>'type' NOT IN (101, 102, 103, 104, 105, 106, 107, 110, 111, 112, 113, 114);
DROP TABLE message;
ALTER TABLE message_new RENAME TO message;
CREATE INDEX message_conv_timestamp_idx ON message(conv_id, conv_receiver, timestamp);
CREATE TABLE reaction_new (
conv_id TEXT NOT NULL,
conv_receiver BIGINT NOT NULL,
msg_id TEXT NOT NULL,
sender TEXT NOT NULL,
reaction TEXT NOT NULL,
mxid TEXT NOT NULL,
PRIMARY KEY (conv_receiver, msg_id, sender),
CONSTRAINT reaction_message_fkey FOREIGN KEY (conv_receiver, msg_id) REFERENCES message(conv_receiver, id) ON DELETE CASCADE,
CONSTRAINT reaction_mxid_unique UNIQUE (mxid)
);
INSERT INTO reaction_new (conv_id, conv_receiver, msg_id, sender, reaction, mxid)
SELECT conv_id, conv_receiver, msg_id, sender, reaction, mxid
FROM reaction;
DROP TABLE reaction;
ALTER TABLE reaction_new RENAME TO reaction;
PRAGMA foreign_key_check;
COMMIT;
PRAGMA FOREIGN_KEYS = ON;

View file

@ -355,7 +355,7 @@ func isSuccessfullySentStatus(status gmproto.MessageStatusType) bool {
func downloadPendingStatusMessage(status gmproto.MessageStatusType) string { func downloadPendingStatusMessage(status gmproto.MessageStatusType) string {
switch status { switch status {
case gmproto.MessageStatusType_INCOMING_YET_TO_MANUAL_DOWNLOAD: case gmproto.MessageStatusType_INCOMING_YET_TO_MANUAL_DOWNLOAD:
return "Attachment message (auto-download is disabled)" return "Attachment message (auto-download is disabled, use Messages on Android to download)"
case gmproto.MessageStatusType_INCOMING_MANUAL_DOWNLOADING, case gmproto.MessageStatusType_INCOMING_MANUAL_DOWNLOADING,
gmproto.MessageStatusType_INCOMING_AUTO_DOWNLOADING, gmproto.MessageStatusType_INCOMING_AUTO_DOWNLOADING,
gmproto.MessageStatusType_INCOMING_RETRYING_MANUAL_DOWNLOAD, gmproto.MessageStatusType_INCOMING_RETRYING_MANUAL_DOWNLOAD,
@ -391,32 +391,68 @@ func isFailSendStatus(status gmproto.MessageStatusType) bool {
} }
} }
func (portal *Portal) handleExistingMessageUpdate(ctx context.Context, source *User, dbMsg *database.Message, evt *gmproto.Message) { func (portal *Portal) redactMessage(ctx context.Context, msg *database.Message) {
log := zerolog.Ctx(ctx) if msg.IsFakeMXID() {
portal.syncReactions(ctx, source, dbMsg, evt.Reactions)
newStatus := evt.GetMessageStatus().GetStatus()
if dbMsg.Status.Type == newStatus && !(dbMsg.Status.HasPendingMediaParts() && !hasInProgressMedia(evt)) {
return return
} }
log.Debug().Str("old_status", dbMsg.Status.Type.String()).Msg("Message status changed") log := zerolog.Ctx(ctx)
intent := portal.MainIntent()
if msg.Chat.ID != portal.ID {
otherPortal := portal.bridge.GetExistingPortalByKey(msg.Chat)
if otherPortal != nil {
intent = otherPortal.MainIntent()
}
}
for partID, part := range msg.Status.MediaParts {
if part.EventID != "" {
if _, err := intent.RedactEvent(msg.RoomID, part.EventID); err != nil {
log.Err(err).Str("part_id", partID).Msg("Failed to redact part of message")
}
part.EventID = ""
msg.Status.MediaParts[partID] = part
}
}
if _, err := intent.RedactEvent(msg.RoomID, msg.MXID); err != nil {
log.Err(err).Msg("Failed to redact message")
}
msg.MXID = ""
}
func (portal *Portal) handleExistingMessageUpdate(ctx context.Context, source *User, dbMsg *database.Message, evt *gmproto.Message) {
log := *zerolog.Ctx(ctx)
newStatus := evt.GetMessageStatus().GetStatus()
chatIDChanged := dbMsg.Chat.ID != portal.ID
if dbMsg.Status.Type == newStatus && !chatIDChanged && !(dbMsg.Status.HasPendingMediaParts() && !hasInProgressMedia(evt)) {
portal.syncReactions(ctx, source, dbMsg, evt.Reactions)
return
}
if chatIDChanged {
log = log.With().Str("old_chat_id", dbMsg.Chat.ID).Logger()
log.Debug().
Str("old_room_id", dbMsg.RoomID.String()).
Str("sender_id", dbMsg.Sender).
Msg("Redacting events from old room")
ctx = log.WithContext(ctx)
err := portal.bridge.DB.Reaction.DeleteAllByMessage(ctx, dbMsg.Chat, dbMsg.ID)
if err != nil {
log.Warn().Err(err).Msg("Failed to delete db reactions for message that moved to another room")
}
portal.redactMessage(ctx, dbMsg)
}
log.Debug().
Str("old_status", dbMsg.Status.Type.String()).
Msg("Message status changed")
switch { switch {
case newStatus == gmproto.MessageStatusType_MESSAGE_DELETED: case newStatus == gmproto.MessageStatusType_MESSAGE_DELETED:
for partID, part := range dbMsg.Status.MediaParts { portal.redactMessage(ctx, dbMsg)
if part.EventID != "" { if err := dbMsg.Delete(ctx); err != nil {
if _, err := portal.MainIntent().RedactEvent(portal.MXID, part.EventID); err != nil {
log.Err(err).Str("part_iD", partID).Msg("Failed to redact part of deleted message")
}
}
}
if _, err := portal.MainIntent().RedactEvent(portal.MXID, dbMsg.MXID); err != nil {
log.Err(err).Msg("Failed to redact deleted message")
} else if err = dbMsg.Delete(ctx); err != nil {
log.Err(err).Msg("Failed to delete message from database") log.Err(err).Msg("Failed to delete message from database")
} else { } else {
log.Debug().Msg("Handled message deletion") log.Debug().Msg("Handled message deletion")
} }
return return
case dbMsg.Status.MediaStatus != downloadPendingStatusMessage(newStatus), case chatIDChanged,
dbMsg.Status.MediaStatus != downloadPendingStatusMessage(newStatus),
dbMsg.Status.HasPendingMediaParts() && !hasInProgressMedia(evt), dbMsg.Status.HasPendingMediaParts() && !hasInProgressMedia(evt),
dbMsg.Status.PartCount != len(evt.MessageInfo): dbMsg.Status.PartCount != len(evt.MessageInfo):
converted := portal.convertGoogleMessage(ctx, source, evt, false) converted := portal.convertGoogleMessage(ctx, source, evt, false)
@ -428,7 +464,9 @@ func (portal *Portal) handleExistingMessageUpdate(ctx context.Context, source *U
for i, part := range converted.Parts { for i, part := range converted.Parts {
isEdit := true isEdit := true
ts := time.Now().UnixMilli() ts := time.Now().UnixMilli()
if i == 0 { if chatIDChanged {
isEdit = false
} else if i == 0 {
part.Content.SetEdit(dbMsg.MXID) part.Content.SetEdit(dbMsg.MXID)
} else if existingPart, ok := dbMsg.Status.MediaParts[part.ID]; ok { } else if existingPart, ok := dbMsg.Status.MediaParts[part.ID]; ok {
part.Content.SetEdit(existingPart.EventID) part.Content.SetEdit(existingPart.EventID)
@ -443,6 +481,9 @@ func (portal *Portal) handleExistingMessageUpdate(ctx context.Context, source *U
eventIDs = append(eventIDs, resp.EventID) eventIDs = append(eventIDs, resp.EventID)
} }
if i == 0 { if i == 0 {
if chatIDChanged {
dbMsg.MXID = resp.EventID
}
dbMsg.Status.MediaParts[""] = database.MediaPart{PendingMedia: part.PendingMedia} dbMsg.Status.MediaParts[""] = database.MediaPart{PendingMedia: part.PendingMedia}
} else if !isEdit { } else if !isEdit {
dbMsg.Status.MediaParts[part.ID] = database.MediaPart{EventID: resp.EventID, PendingMedia: part.PendingMedia} dbMsg.Status.MediaParts[part.ID] = database.MediaPart{EventID: resp.EventID, PendingMedia: part.PendingMedia}
@ -485,10 +526,18 @@ func (portal *Portal) handleExistingMessageUpdate(ctx context.Context, source *U
dbMsg.Status.Type = newStatus dbMsg.Status.Type = newStatus
dbMsg.Status.PartCount = len(evt.MessageInfo) dbMsg.Status.PartCount = len(evt.MessageInfo)
dbMsg.Timestamp = time.UnixMicro(evt.GetTimestamp()) dbMsg.Timestamp = time.UnixMicro(evt.GetTimestamp())
err := dbMsg.UpdateStatus(ctx) var err error
if chatIDChanged {
dbMsg.Chat = portal.Key
dbMsg.RoomID = portal.MXID
err = dbMsg.Update(ctx)
} else {
err = dbMsg.UpdateStatus(ctx)
}
if err != nil { if err != nil {
log.Warn().Err(err).Msg("Failed to save updated message status to database") log.Warn().Err(err).Msg("Failed to save updated message status to database")
} }
portal.syncReactions(ctx, source, dbMsg, evt.Reactions)
} }
func (portal *Portal) handleExistingMessage(ctx context.Context, source *User, evt *gmproto.Message, outgoingOnly bool) bool { func (portal *Portal) handleExistingMessage(ctx context.Context, source *User, evt *gmproto.Message, outgoingOnly bool) bool {
@ -500,7 +549,7 @@ func (portal *Portal) handleExistingMessage(ctx context.Context, source *User, e
} else if outgoingOnly { } else if outgoingOnly {
return false return false
} }
existingMsg, err := portal.bridge.DB.Message.GetByID(ctx, portal.Key, evt.MessageID) existingMsg, err := portal.bridge.DB.Message.GetByID(ctx, portal.Receiver, evt.MessageID)
if err != nil { if err != nil {
log.Err(err).Msg("Failed to check if message is duplicate") log.Err(err).Msg("Failed to check if message is duplicate")
} else if existingMsg != nil { } else if existingMsg != nil {
@ -603,7 +652,7 @@ func (portal *Portal) sendMessageParts(ctx context.Context, converted *Converted
func (portal *Portal) syncReactions(ctx context.Context, source *User, message *database.Message, reactions []*gmproto.ReactionEntry) { func (portal *Portal) syncReactions(ctx context.Context, source *User, message *database.Message, reactions []*gmproto.ReactionEntry) {
log := zerolog.Ctx(ctx) log := zerolog.Ctx(ctx)
existing, err := portal.bridge.DB.Reaction.GetAllByMessage(ctx, portal.Key, message.ID) existing, err := portal.bridge.DB.Reaction.GetAllByMessage(ctx, portal.Receiver, message.ID)
if err != nil { if err != nil {
log.Err(err).Msg("Failed to get existing reactions from db to sync reactions") log.Err(err).Msg("Failed to get existing reactions from db to sync reactions")
return return
@ -758,7 +807,7 @@ func (portal *Portal) convertGoogleMessage(ctx context.Context, source *User, ev
var replyTo id.EventID var replyTo id.EventID
if evt.GetReplyMessage() != nil { if evt.GetReplyMessage() != nil {
cm.ReplyTo = evt.GetReplyMessage().GetMessageID() cm.ReplyTo = evt.GetReplyMessage().GetMessageID()
msg, err := portal.bridge.DB.Message.GetByID(ctx, portal.Key, cm.ReplyTo) msg, err := portal.bridge.DB.Message.GetByID(ctx, portal.Receiver, cm.ReplyTo)
if err != nil { if err != nil {
log.Err(err).Str("reply_to_id", cm.ReplyTo).Msg("Failed to get reply target message") log.Err(err).Str("reply_to_id", cm.ReplyTo).Msg("Failed to get reply target message")
} else if msg == nil { } else if msg == nil {
@ -933,6 +982,7 @@ func (portal *Portal) isRecentlyHandled(id string) bool {
func (portal *Portal) markHandled(cm *ConvertedMessage, eventID id.EventID, mediaParts map[string]database.MediaPart, recent bool) *database.Message { func (portal *Portal) markHandled(cm *ConvertedMessage, eventID id.EventID, mediaParts map[string]database.MediaPart, recent bool) *database.Message {
msg := portal.bridge.DB.Message.New() msg := portal.bridge.DB.Message.New()
msg.Chat = portal.Key msg.Chat = portal.Key
msg.RoomID = portal.MXID
msg.ID = cm.ID msg.ID = cm.ID
msg.MXID = eventID msg.MXID = eventID
msg.Timestamp = cm.Timestamp msg.Timestamp = cm.Timestamp
@ -1691,7 +1741,7 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) error
return errTargetNotFound return errTargetNotFound
} }
existingReaction, err := portal.bridge.DB.Reaction.GetByID(ctx, portal.Key, msg.ID, portal.OutgoingID) existingReaction, err := portal.bridge.DB.Reaction.GetByID(ctx, portal.Receiver, msg.ID, portal.OutgoingID)
if err != nil { if err != nil {
log.Err(err).Msg("Failed to get existing reaction") log.Err(err).Msg("Failed to get existing reaction")
return fmt.Errorf("failed to get existing reaction from database") return fmt.Errorf("failed to get existing reaction from database")

View file

@ -900,7 +900,7 @@ func (user *User) markSelfReadFull(portal *Portal, lastMessageID string) {
return return
} }
ctx := context.TODO() ctx := context.TODO()
lastMessage, err := user.bridge.DB.Message.GetByID(ctx, portal.Key, lastMessageID) lastMessage, err := user.bridge.DB.Message.GetByID(ctx, portal.Receiver, lastMessageID)
if err == nil && lastMessage == nil || lastMessage.IsFakeMXID() { if err == nil && lastMessage == nil || lastMessage.IsFakeMXID() {
lastMessage, err = user.bridge.DB.Message.GetLastInChatWithMXID(ctx, portal.Key) lastMessage, err = user.bridge.DB.Message.GetLastInChatWithMXID(ctx, portal.Key)
} }