Store all own participant IDs for proper multi-sim support

This commit is contained in:
Tulir Asokan 2023-07-19 20:32:01 +03:00
parent bf277e197f
commit 907e6af77b
16 changed files with 99 additions and 55 deletions

View file

@ -51,7 +51,7 @@ func (user *User) GetRemoteID() string {
if user == nil {
return ""
}
return user.Phone
return user.PhoneID
}
func (user *User) GetRemoteName() string {

View file

@ -263,9 +263,9 @@ func fnPing(ce *WrappedCommandEvent) {
ce.Reply("You're not logged into Google Messages.")
}
} else if ce.User.Client == nil || !ce.User.Client.IsConnected() {
ce.Reply("You're logged in as %s, but you don't have a Google Messages connection.", ce.User.Phone)
ce.Reply("You're logged in as %s, but you don't have a Google Messages connection.", ce.User.PhoneID)
} else {
ce.Reply("Logged in as %s, connection to Google Messages may be OK", ce.User.Phone)
ce.Reply("Logged in as %s, connection to Google Messages may be OK", ce.User.PhoneID)
}
}

View file

@ -74,7 +74,7 @@ type Portal struct {
db *Database
Key
SelfUserID string
OutgoingID string
OtherUserID string
MXID id.RoomID
@ -96,7 +96,7 @@ func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
return nil, err
}
portal.MXID = id.RoomID(mxid.String)
portal.SelfUserID = selfUserID.String
portal.OutgoingID = selfUserID.String
portal.OtherUserID = otherUserID.String
return portal, nil
}
@ -106,8 +106,8 @@ func (portal *Portal) sqlVariables() []any {
if portal.MXID != "" {
mxid = (*string)(&portal.MXID)
}
if portal.SelfUserID != "" {
selfUserID = &portal.SelfUserID
if portal.OutgoingID != "" {
selfUserID = &portal.OutgoingID
}
if portal.OtherUserID != "" {
otherUserID = &portal.OtherUserID

View file

@ -1,4 +1,4 @@
-- v0 -> v1: Latest revision
-- v0 -> v2: Latest revision
CREATE TABLE "user" (
-- only: postgres
@ -6,9 +6,11 @@ CREATE TABLE "user" (
-- only: sqlite
rowid INTEGER PRIMARY KEY,
mxid TEXT NOT NULL UNIQUE,
phone TEXT UNIQUE,
session jsonb,
mxid TEXT NOT NULL UNIQUE,
phone_id TEXT UNIQUE,
session jsonb,
self_participant_ids jsonb NOT NULL DEFAULT '[]',
management_room TEXT,
space_room TEXT,

View file

@ -0,0 +1,3 @@
-- v2: Store all self-participant IDs
ALTER TABLE "user" RENAME COLUMN phone TO phone_id;
ALTER TABLE "user" ADD COLUMN self_participant_ids jsonb NOT NULL DEFAULT '[]';

View file

@ -22,7 +22,9 @@ import (
"encoding/json"
"errors"
"fmt"
"sync"
"golang.org/x/exp/slices"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
@ -44,23 +46,23 @@ func (uq *UserQuery) getDB() *Database {
}
func (uq *UserQuery) GetAllWithSession(ctx context.Context) ([]*User, error) {
return getAll[*User](uq, ctx, `SELECT rowid, mxid, phone, session, management_room, space_room, access_token FROM "user" WHERE phone<>'' AND session IS NOT NULL`)
return getAll[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE session IS NOT NULL`)
}
func (uq *UserQuery) GetAllWithDoublePuppet(ctx context.Context) ([]*User, error) {
return getAll[*User](uq, ctx, `SELECT rowid, mxid, phone, session, management_room, space_room, access_token FROM "user" WHERE access_token<>''`)
return getAll[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE access_token<>''`)
}
func (uq *UserQuery) GetByRowID(ctx context.Context, rowID int) (*User, error) {
return get[*User](uq, ctx, `SELECT rowid, mxid, phone, session, management_room, space_room, access_token FROM "user" WHERE rowid=$1`, rowID)
return get[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE rowid=$1`, rowID)
}
func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, error) {
return get[*User](uq, ctx, `SELECT rowid, mxid, phone, session, management_room, space_room, access_token FROM "user" WHERE mxid=$1`, userID)
return get[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE mxid=$1`, userID)
}
func (uq *UserQuery) GetByPhone(ctx context.Context, phone string) (*User, error) {
return get[*User](uq, ctx, `SELECT rowid, mxid, phone, session, management_room, space_room, access_token FROM "user" WHERE phone=$1`, phone)
return get[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE phone_id=$1`, phone)
}
type User struct {
@ -68,18 +70,22 @@ type User struct {
RowID int
MXID id.UserID
Phone string
PhoneID string
Session *libgm.AuthData
ManagementRoom id.RoomID
SpaceRoom id.RoomID
SelfParticipantIDs []string
selfParticipantIDsLock sync.RWMutex
AccessToken string
}
func (user *User) Scan(row dbutil.Scannable) (*User, error) {
var phone, session, managementRoom, spaceRoom, accessToken sql.NullString
err := row.Scan(&user.RowID, &user.MXID, &phone, &session, &managementRoom, &spaceRoom, &accessToken)
var phoneID, session, managementRoom, spaceRoom, accessToken sql.NullString
var selfParticipantIDs string
err := row.Scan(&user.RowID, &user.MXID, &phoneID, &session, &selfParticipantIDs, &managementRoom, &spaceRoom, &accessToken)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
@ -93,7 +99,13 @@ func (user *User) Scan(row dbutil.Scannable) (*User, error) {
}
user.Session = &sess
}
user.Phone = phone.String
user.selfParticipantIDsLock.Lock()
err = json.Unmarshal([]byte(selfParticipantIDs), &user.SelfParticipantIDs)
user.selfParticipantIDsLock.Unlock()
if err != nil {
return nil, fmt.Errorf("failed to parse self participant IDs: %w", err)
}
user.PhoneID = phoneID.String
user.AccessToken = accessToken.String
user.ManagementRoom = id.RoomID(managementRoom.String)
user.SpaceRoom = id.RoomID(spaceRoom.String)
@ -101,9 +113,9 @@ func (user *User) Scan(row dbutil.Scannable) (*User, error) {
}
func (user *User) sqlVariables() []any {
var phone, session, managementRoom, spaceRoom, accessToken *string
if user.Phone != "" {
phone = &user.Phone
var phoneID, session, managementRoom, spaceRoom, accessToken *string
if user.PhoneID != "" {
phoneID = &user.PhoneID
}
if user.Session != nil {
data, _ := json.Marshal(user.Session)
@ -119,17 +131,38 @@ func (user *User) sqlVariables() []any {
if user.AccessToken != "" {
accessToken = &user.AccessToken
}
return []any{user.MXID, phone, session, managementRoom, spaceRoom, accessToken}
user.selfParticipantIDsLock.RLock()
selfParticipantIDs, _ := json.Marshal(user.SelfParticipantIDs)
user.selfParticipantIDsLock.RUnlock()
return []any{user.MXID, phoneID, session, string(selfParticipantIDs), managementRoom, spaceRoom, accessToken}
}
func (user *User) IsSelfParticipantID(id string) bool {
user.selfParticipantIDsLock.RLock()
defer user.selfParticipantIDsLock.RUnlock()
return slices.Contains(user.SelfParticipantIDs, id)
}
func (user *User) AddSelfParticipantID(ctx context.Context, id string) error {
user.selfParticipantIDsLock.Lock()
defer user.selfParticipantIDsLock.Unlock()
if !slices.Contains(user.SelfParticipantIDs, id) {
user.SelfParticipantIDs = append(user.SelfParticipantIDs, id)
selfParticipantIDs, _ := json.Marshal(user.SelfParticipantIDs)
_, err := user.db.Conn(ctx).ExecContext(ctx, `UPDATE "user" SET self_participant_ids=$2 WHERE mxid=$1`, user.MXID, selfParticipantIDs)
return err
}
return nil
}
func (user *User) Insert(ctx context.Context) error {
err := user.db.Conn(ctx).
QueryRowContext(ctx, `INSERT INTO "user" (mxid, phone, session, management_room, space_room, access_token) VALUES ($1, $2, $3, $4, $5, $6) RETURNING rowid`, user.sqlVariables()...).
QueryRowContext(ctx, `INSERT INTO "user" (mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING rowid`, user.sqlVariables()...).
Scan(&user.RowID)
return err
}
func (user *User) Update(ctx context.Context) error {
_, err := user.db.Conn(ctx).ExecContext(ctx, `UPDATE "user" SET phone=$2, session=$3, management_room=$4, space_room=$5, access_token=$6 WHERE mxid=$1`, user.sqlVariables()...)
_, err := user.db.Conn(ctx).ExecContext(ctx, `UPDATE "user" SET phone_id=$2, session=$3, self_participant_ids=$4, management_room=$5, space_room=$6, access_token=$7 WHERE mxid=$1`, user.sqlVariables()...)
return err
}

4
go.mod
View file

@ -8,8 +8,9 @@ require (
github.com/rs/zerolog v1.29.1
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
go.mau.fi/mautrix-gmessages/libgm v0.1.0
google.golang.org/protobuf v1.31.0
maunium.net/go/maulogger/v2 v2.4.1
maunium.net/go/mautrix v0.15.4
maunium.net/go/mautrix v0.15.5-0.20230719135321-8c3bd7722909
)
require (
@ -33,7 +34,6 @@ require (
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect
golang.org/x/net v0.12.0 // indirect
golang.org/x/sys v0.10.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

4
go.sum
View file

@ -81,5 +81,5 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8=
maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho=
maunium.net/go/mautrix v0.15.4 h1:Ug3n2Mo+9Yb94AjZTWJQSNHmShaksEzZi85EPl3S3P0=
maunium.net/go/mautrix v0.15.4/go.mod h1:dBaDmsnOOBM4a+gKcgefXH73pHGXm+MCJzCs1dXFgrw=
maunium.net/go/mautrix v0.15.5-0.20230719135321-8c3bd7722909 h1:Bw7WMyeODxjGjBK3v9+kE7Dm5b8HAEawmYMmYYSXvow=
maunium.net/go/mautrix v0.15.5-0.20230719135321-8c3bd7722909/go.mod h1:dBaDmsnOOBM4a+gKcgefXH73pHGXm+MCJzCs1dXFgrw=

View file

@ -2598,7 +2598,7 @@ type MessagePayload struct {
TmpID string `protobuf:"bytes,1,opt,name=tmpID,proto3" json:"tmpID,omitempty"`
MessagePayloadContent *MessagePayloadContent `protobuf:"bytes,6,opt,name=messagePayloadContent,proto3" json:"messagePayloadContent,omitempty"`
ConversationID string `protobuf:"bytes,7,opt,name=conversationID,proto3" json:"conversationID,omitempty"`
SelfParticipantID string `protobuf:"bytes,9,opt,name=selfParticipantID,proto3" json:"selfParticipantID,omitempty"` // might be participantID
ParticipantID string `protobuf:"bytes,9,opt,name=participantID,proto3" json:"participantID,omitempty"`
MessageInfo []*MessageInfo `protobuf:"bytes,10,rep,name=messageInfo,proto3" json:"messageInfo,omitempty"`
TmpID2 string `protobuf:"bytes,12,opt,name=tmpID2,proto3" json:"tmpID2,omitempty"`
}
@ -2656,9 +2656,9 @@ func (x *MessagePayload) GetConversationID() string {
return ""
}
func (x *MessagePayload) GetSelfParticipantID() string {
func (x *MessagePayload) GetParticipantID() string {
if x != nil {
return x.SelfParticipantID
return x.ParticipantID
}
return ""
}

Binary file not shown.

View file

@ -269,7 +269,7 @@ message MessagePayload {
string tmpID = 1;
MessagePayloadContent messagePayloadContent = 6;
string conversationID = 7;
string selfParticipantID = 9; // might be participantID
string participantID = 9;
repeated conversations.MessageInfo messageInfo = 10;
string tmpID2 = 12;
}

View file

@ -1760,7 +1760,7 @@ type Conversation struct {
LastMessageTimestamp int64 `protobuf:"varint,5,opt,name=lastMessageTimestamp,proto3" json:"lastMessageTimestamp,omitempty"`
Unread bool `protobuf:"varint,6,opt,name=unread,proto3" json:"unread,omitempty"`
IsGroupChat bool `protobuf:"varint,10,opt,name=isGroupChat,proto3" json:"isGroupChat,omitempty"` // not certain
SelfParticipantID string `protobuf:"bytes,11,opt,name=selfParticipantID,proto3" json:"selfParticipantID,omitempty"`
DefaultOutgoingID string `protobuf:"bytes,11,opt,name=defaultOutgoingID,proto3" json:"defaultOutgoingID,omitempty"`
// bool bool1 = 13;
Status ConvUpdateTypes `protobuf:"varint,12,opt,name=status,proto3,enum=conversations.ConvUpdateTypes" json:"status,omitempty"`
AvatarHexColor string `protobuf:"bytes,15,opt,name=avatarHexColor,proto3" json:"avatarHexColor,omitempty"`
@ -1847,9 +1847,9 @@ func (x *Conversation) GetIsGroupChat() bool {
return false
}
func (x *Conversation) GetSelfParticipantID() string {
func (x *Conversation) GetDefaultOutgoingID() string {
if x != nil {
return x.SelfParticipantID
return x.DefaultOutgoingID
}
return ""
}

Binary file not shown.

View file

@ -130,7 +130,7 @@ message Conversation {
bool unread = 6;
bool isGroupChat = 10; // not certain
string selfParticipantID = 11;
string defaultOutgoingID = 11;
//bool bool1 = 13;
ConvUpdateTypes status = 12;

View file

@ -518,7 +518,7 @@ type ConvertedMessage struct {
}
func (portal *Portal) getIntent(ctx context.Context, source *User, participant string) *appservice.IntentAPI {
if participant == portal.SelfUserID {
if source.IsSelfParticipantID(participant) {
intent := source.DoublePuppetIntent
if intent == nil {
zerolog.Ctx(ctx).Debug().Msg("Dropping message from self as double puppeting is not enabled")
@ -712,6 +712,12 @@ func (portal *Portal) SyncParticipants(source *User, metadata *gmproto.Conversat
var manyParticipants bool
for _, participant := range metadata.Participants {
if participant.IsMe {
err := source.AddSelfParticipantID(context.TODO(), participant.ID.ParticipantID)
if err != nil {
portal.zlog.Warn().Err(err).
Str("participant_id", participant.ID.ParticipantID).
Msg("Failed to save self participant ID")
}
continue
} else if participant.ID.Number == "" {
portal.zlog.Warn().Interface("participant", participant).Msg("No number found in non-self participant entry")
@ -784,8 +790,8 @@ func (portal *Portal) UpdateName(name string, updateInfo bool) bool {
func (portal *Portal) UpdateMetadata(user *User, info *gmproto.Conversation) []id.UserID {
participants, update := portal.SyncParticipants(user, info)
if portal.SelfUserID != info.SelfParticipantID {
portal.SelfUserID = info.SelfParticipantID
if portal.OutgoingID != info.DefaultOutgoingID {
portal.OutgoingID = info.DefaultOutgoingID
update = true
}
if portal.MXID != "" {
@ -1178,10 +1184,10 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, co
ConversationID: portal.ID,
TmpID: txnID,
MessagePayload: &gmproto.MessagePayload{
ConversationID: portal.ID,
TmpID: txnID,
TmpID2: txnID,
SelfParticipantID: portal.SelfUserID,
ConversationID: portal.ID,
TmpID: txnID,
TmpID2: txnID,
ParticipantID: portal.OutgoingID,
},
}
@ -1342,7 +1348,7 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) error
return errTargetNotFound
}
existingReaction, err := portal.bridge.DB.Reaction.GetByID(ctx, portal.Key, msg.ID, portal.SelfUserID)
existingReaction, err := portal.bridge.DB.Reaction.GetByID(ctx, portal.Key, msg.ID, portal.OutgoingID)
if err != nil {
log.Err(err).Msg("Failed to get existing reaction")
return fmt.Errorf("failed to get existing reaction from database")
@ -1367,7 +1373,7 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) error
existingReaction = portal.bridge.DB.Reaction.New()
existingReaction.Chat = portal.Key
existingReaction.MessageID = msg.ID
existingReaction.Sender = portal.SelfUserID
existingReaction.Sender = portal.OutgoingID
} else if sender.DoublePuppetIntent != nil {
_, err = sender.DoublePuppetIntent.RedactEvent(portal.MXID, existingReaction.MXID)
if err != nil {

16
user.go
View file

@ -173,15 +173,15 @@ func (br *GMBridge) GetUserByPhone(phone string) *User {
func (user *User) addToPhoneMap() {
user.bridge.usersLock.Lock()
user.bridge.usersByPhone[user.Phone] = user
user.bridge.usersByPhone[user.PhoneID] = user
user.bridge.usersLock.Unlock()
}
func (user *User) removeFromPhoneMap(state status.BridgeState) {
user.bridge.usersLock.Lock()
phoneUser, ok := user.bridge.usersByPhone[user.Phone]
phoneUser, ok := user.bridge.usersByPhone[user.PhoneID]
if ok && user == phoneUser {
delete(user.bridge.usersByPhone, user.Phone)
delete(user.bridge.usersByPhone, user.PhoneID)
}
user.bridge.usersLock.Unlock()
user.BridgeState.Send(state)
@ -231,11 +231,11 @@ func (br *GMBridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User {
}
user := br.NewUser(dbUser)
br.usersByMXID[user.MXID] = user
if user.Session != nil && user.Phone != "" {
br.usersByPhone[user.Phone] = user
if user.Session != nil && user.PhoneID != "" {
br.usersByPhone[user.PhoneID] = user
} else {
user.Session = nil
user.Phone = ""
user.PhoneID = ""
}
if len(user.ManagementRoom) > 0 {
br.managementRooms[user.ManagementRoom] = user
@ -540,7 +540,7 @@ func (user *User) HasSession() bool {
func (user *User) DeleteSession() {
user.Session = nil
user.Phone = ""
user.PhoneID = ""
err := user.Update(context.TODO())
if err != nil {
user.zlog.Err(err).Msg("Failed to delete session from database")
@ -601,7 +601,7 @@ func (user *User) HandleEvent(event interface{}) {
user.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnected})
case *events.PairSuccessful:
user.Session = user.Client.AuthData
user.Phone = v.GetMobile().GetSourceID()
user.PhoneID = v.GetMobile().GetSourceID()
user.tryAutomaticDoublePuppeting()
user.addToPhoneMap()
err := user.Update(context.TODO())