Track number of parts in message

This commit is contained in:
Tulir Asokan 2023-08-09 19:49:36 +03:00
parent d1ba596504
commit f72cb7d7da
3 changed files with 9 additions and 5 deletions

View file

@ -194,6 +194,7 @@ func (portal *Portal) backfillSendBatch(ctx context.Context, converted []*Conver
dbm.Sender = msg.SenderID dbm.Sender = msg.SenderID
dbm.Timestamp = msg.Timestamp dbm.Timestamp = msg.Timestamp
dbm.Status.Type = msg.Status dbm.Status.Type = msg.Status
dbm.Status.PartCount = msg.PartCount
dbm.Status.MediaStatus = msg.MediaStatus dbm.Status.MediaStatus = msg.MediaStatus
dbm.Status.MediaParts = make(map[string]database.MediaPart, len(msg.Parts)) dbm.Status.MediaParts = make(map[string]database.MediaPart, len(msg.Parts))

View file

@ -82,6 +82,7 @@ type MessageStatus struct {
MediaStatus string `json:"media_status,omitempty"` MediaStatus string `json:"media_status,omitempty"`
MediaParts map[string]MediaPart `json:"media_parts,omitempty"` MediaParts map[string]MediaPart `json:"media_parts,omitempty"`
PartCount int `json:"part_count,omitempty"`
MSSSent bool `json:"mss_sent,omitempty"` MSSSent bool `json:"mss_sent,omitempty"`
MSSFailSent bool `json:"mss_fail_sent,omitempty"` MSSFailSent bool `json:"mss_fail_sent,omitempty"`

View file

@ -325,6 +325,7 @@ func (portal *Portal) isOutgoingMessage(msg *gmproto.Message) *database.Message
ID: msg.MessageID, ID: msg.MessageID,
Timestamp: time.UnixMicro(msg.GetTimestamp()), Timestamp: time.UnixMicro(msg.GetTimestamp()),
SenderID: msg.ParticipantID, SenderID: msg.ParticipantID,
PartCount: len(msg.GetMessageInfo()),
}, out.ID, nil, true) }, out.ID, nil, true)
} }
return nil return nil
@ -414,7 +415,8 @@ func (portal *Portal) handleExistingMessageUpdate(ctx context.Context, source *U
} }
return return
case dbMsg.Status.MediaStatus != downloadPendingStatusMessage(newStatus), case dbMsg.Status.MediaStatus != downloadPendingStatusMessage(newStatus),
dbMsg.Status.HasPendingMediaParts() && !hasInProgressMedia(evt): dbMsg.Status.HasPendingMediaParts() && !hasInProgressMedia(evt),
dbMsg.Status.PartCount != len(evt.MessageInfo):
converted := portal.convertGoogleMessage(ctx, source, evt, false) converted := portal.convertGoogleMessage(ctx, source, evt, false)
dbMsg.Status.MediaStatus = converted.MediaStatus dbMsg.Status.MediaStatus = converted.MediaStatus
if dbMsg.Status.MediaParts == nil { if dbMsg.Status.MediaParts == nil {
@ -479,6 +481,7 @@ func (portal *Portal) handleExistingMessageUpdate(ctx context.Context, source *U
// TODO do something? // TODO do something?
} }
dbMsg.Status.Type = newStatus dbMsg.Status.Type = newStatus
dbMsg.Status.PartCount = len(evt.MessageInfo)
dbMsg.Timestamp = time.UnixMicro(evt.GetTimestamp()) dbMsg.Timestamp = time.UnixMicro(evt.GetTimestamp())
err := dbMsg.UpdateStatus(ctx) err := dbMsg.UpdateStatus(ctx)
if err != nil { if err != nil {
@ -502,10 +505,6 @@ func (portal *Portal) handleMessage(source *User, evt *gmproto.Message) {
Str("action", "handle google message"). Str("action", "handle google message").
Logger() Logger()
ctx := log.WithContext(context.TODO()) ctx := log.WithContext(context.TODO())
//if hasInProgressMedia(evt) {
// log.Debug().Msg("Not handling incoming message that doesn't have full media yet")
// return
//}
if existingMsg := portal.isOutgoingMessage(evt); existingMsg != nil { if existingMsg := portal.isOutgoingMessage(evt); existingMsg != nil {
log.Debug().Str("event_id", existingMsg.MXID.String()).Msg("Got echo for outgoing message") log.Debug().Str("event_id", existingMsg.MXID.String()).Msg("Got echo for outgoing message")
portal.handleExistingMessageUpdate(ctx, source, existingMsg, evt) portal.handleExistingMessageUpdate(ctx, source, existingMsg, evt)
@ -636,6 +635,7 @@ type ConvertedMessage struct {
Timestamp time.Time Timestamp time.Time
ReplyTo string ReplyTo string
Parts []ConvertedMessagePart Parts []ConvertedMessagePart
PartCount int
Status gmproto.MessageStatusType Status gmproto.MessageStatusType
MediaStatus string MediaStatus string
@ -679,6 +679,7 @@ func (portal *Portal) convertGoogleMessage(ctx context.Context, source *User, ev
cm.Status = evt.GetMessageStatus().GetStatus() cm.Status = evt.GetMessageStatus().GetStatus()
cm.SenderID = evt.ParticipantID cm.SenderID = evt.ParticipantID
cm.ID = evt.MessageID cm.ID = evt.MessageID
cm.PartCount = len(evt.GetMessageInfo())
cm.Timestamp = time.UnixMicro(evt.Timestamp) cm.Timestamp = time.UnixMicro(evt.Timestamp)
if cm.Status >= 200 && cm.Status < 300 { if cm.Status >= 200 && cm.Status < 300 {
cm.Intent = portal.bridge.Bot cm.Intent = portal.bridge.Bot
@ -873,6 +874,7 @@ func (portal *Portal) markHandled(cm *ConvertedMessage, eventID id.EventID, medi
msg.Timestamp = cm.Timestamp msg.Timestamp = cm.Timestamp
msg.Sender = cm.SenderID msg.Sender = cm.SenderID
msg.Status.Type = cm.Status msg.Status.Type = cm.Status
msg.Status.PartCount = cm.PartCount
msg.Status.MediaStatus = cm.MediaStatus msg.Status.MediaStatus = cm.MediaStatus
msg.Status.MediaParts = mediaParts msg.Status.MediaParts = mediaParts
err := msg.Insert(context.TODO()) err := msg.Insert(context.TODO())