diff --git a/bridgestate.go b/bridgestate.go index 6fd463c..f67ba1a 100644 --- a/bridgestate.go +++ b/bridgestate.go @@ -51,7 +51,7 @@ func (user *User) GetRemoteID() string { if user == nil { return "" } - return user.Phone + return user.PhoneID } func (user *User) GetRemoteName() string { diff --git a/commands.go b/commands.go index c4f7283..9990824 100644 --- a/commands.go +++ b/commands.go @@ -263,9 +263,9 @@ func fnPing(ce *WrappedCommandEvent) { ce.Reply("You're not logged into Google Messages.") } } else if ce.User.Client == nil || !ce.User.Client.IsConnected() { - ce.Reply("You're logged in as %s, but you don't have a Google Messages connection.", ce.User.Phone) + ce.Reply("You're logged in as %s, but you don't have a Google Messages connection.", ce.User.PhoneID) } else { - ce.Reply("Logged in as %s, connection to Google Messages may be OK", ce.User.Phone) + ce.Reply("Logged in as %s, connection to Google Messages may be OK", ce.User.PhoneID) } } diff --git a/database/portal.go b/database/portal.go index ea37137..7b8b978 100644 --- a/database/portal.go +++ b/database/portal.go @@ -74,7 +74,7 @@ type Portal struct { db *Database Key - SelfUserID string + OutgoingID string OtherUserID string MXID id.RoomID @@ -96,7 +96,7 @@ func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) { return nil, err } portal.MXID = id.RoomID(mxid.String) - portal.SelfUserID = selfUserID.String + portal.OutgoingID = selfUserID.String portal.OtherUserID = otherUserID.String return portal, nil } @@ -106,8 +106,8 @@ func (portal *Portal) sqlVariables() []any { if portal.MXID != "" { mxid = (*string)(&portal.MXID) } - if portal.SelfUserID != "" { - selfUserID = &portal.SelfUserID + if portal.OutgoingID != "" { + selfUserID = &portal.OutgoingID } if portal.OtherUserID != "" { otherUserID = &portal.OtherUserID diff --git a/database/upgrades/00-latest-revision.sql b/database/upgrades/00-latest-revision.sql index db0e5b9..c88c537 100644 --- a/database/upgrades/00-latest-revision.sql +++ b/database/upgrades/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v1: Latest revision +-- v0 -> v2: Latest revision CREATE TABLE "user" ( -- only: postgres @@ -6,9 +6,11 @@ CREATE TABLE "user" ( -- only: sqlite rowid INTEGER PRIMARY KEY, - mxid TEXT NOT NULL UNIQUE, - phone TEXT UNIQUE, - session jsonb, + mxid TEXT NOT NULL UNIQUE, + phone_id TEXT UNIQUE, + session jsonb, + + self_participant_ids jsonb NOT NULL DEFAULT '[]', management_room TEXT, space_room TEXT, diff --git a/database/upgrades/02-all-own-ids.sql b/database/upgrades/02-all-own-ids.sql new file mode 100644 index 0000000..6d7e82f --- /dev/null +++ b/database/upgrades/02-all-own-ids.sql @@ -0,0 +1,3 @@ +-- v2: Store all self-participant IDs +ALTER TABLE "user" RENAME COLUMN phone TO phone_id; +ALTER TABLE "user" ADD COLUMN self_participant_ids jsonb NOT NULL DEFAULT '[]'; diff --git a/database/user.go b/database/user.go index 80c249f..0335b0d 100644 --- a/database/user.go +++ b/database/user.go @@ -22,7 +22,9 @@ import ( "encoding/json" "errors" "fmt" + "sync" + "golang.org/x/exp/slices" "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/util/dbutil" @@ -44,23 +46,23 @@ func (uq *UserQuery) getDB() *Database { } func (uq *UserQuery) GetAllWithSession(ctx context.Context) ([]*User, error) { - return getAll[*User](uq, ctx, `SELECT rowid, mxid, phone, session, management_room, space_room, access_token FROM "user" WHERE phone<>'' AND session IS NOT NULL`) + return getAll[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE session IS NOT NULL`) } func (uq *UserQuery) GetAllWithDoublePuppet(ctx context.Context) ([]*User, error) { - return getAll[*User](uq, ctx, `SELECT rowid, mxid, phone, session, management_room, space_room, access_token FROM "user" WHERE access_token<>''`) + return getAll[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE access_token<>''`) } func (uq *UserQuery) GetByRowID(ctx context.Context, rowID int) (*User, error) { - return get[*User](uq, ctx, `SELECT rowid, mxid, phone, session, management_room, space_room, access_token FROM "user" WHERE rowid=$1`, rowID) + return get[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE rowid=$1`, rowID) } func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, error) { - return get[*User](uq, ctx, `SELECT rowid, mxid, phone, session, management_room, space_room, access_token FROM "user" WHERE mxid=$1`, userID) + return get[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE mxid=$1`, userID) } func (uq *UserQuery) GetByPhone(ctx context.Context, phone string) (*User, error) { - return get[*User](uq, ctx, `SELECT rowid, mxid, phone, session, management_room, space_room, access_token FROM "user" WHERE phone=$1`, phone) + return get[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE phone_id=$1`, phone) } type User struct { @@ -68,18 +70,22 @@ type User struct { RowID int MXID id.UserID - Phone string + PhoneID string Session *libgm.AuthData ManagementRoom id.RoomID SpaceRoom id.RoomID + SelfParticipantIDs []string + selfParticipantIDsLock sync.RWMutex + AccessToken string } func (user *User) Scan(row dbutil.Scannable) (*User, error) { - var phone, session, managementRoom, spaceRoom, accessToken sql.NullString - err := row.Scan(&user.RowID, &user.MXID, &phone, &session, &managementRoom, &spaceRoom, &accessToken) + var phoneID, session, managementRoom, spaceRoom, accessToken sql.NullString + var selfParticipantIDs string + err := row.Scan(&user.RowID, &user.MXID, &phoneID, &session, &selfParticipantIDs, &managementRoom, &spaceRoom, &accessToken) if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil { @@ -93,7 +99,13 @@ func (user *User) Scan(row dbutil.Scannable) (*User, error) { } user.Session = &sess } - user.Phone = phone.String + user.selfParticipantIDsLock.Lock() + err = json.Unmarshal([]byte(selfParticipantIDs), &user.SelfParticipantIDs) + user.selfParticipantIDsLock.Unlock() + if err != nil { + return nil, fmt.Errorf("failed to parse self participant IDs: %w", err) + } + user.PhoneID = phoneID.String user.AccessToken = accessToken.String user.ManagementRoom = id.RoomID(managementRoom.String) user.SpaceRoom = id.RoomID(spaceRoom.String) @@ -101,9 +113,9 @@ func (user *User) Scan(row dbutil.Scannable) (*User, error) { } func (user *User) sqlVariables() []any { - var phone, session, managementRoom, spaceRoom, accessToken *string - if user.Phone != "" { - phone = &user.Phone + var phoneID, session, managementRoom, spaceRoom, accessToken *string + if user.PhoneID != "" { + phoneID = &user.PhoneID } if user.Session != nil { data, _ := json.Marshal(user.Session) @@ -119,17 +131,38 @@ func (user *User) sqlVariables() []any { if user.AccessToken != "" { accessToken = &user.AccessToken } - return []any{user.MXID, phone, session, managementRoom, spaceRoom, accessToken} + user.selfParticipantIDsLock.RLock() + selfParticipantIDs, _ := json.Marshal(user.SelfParticipantIDs) + user.selfParticipantIDsLock.RUnlock() + return []any{user.MXID, phoneID, session, string(selfParticipantIDs), managementRoom, spaceRoom, accessToken} +} + +func (user *User) IsSelfParticipantID(id string) bool { + user.selfParticipantIDsLock.RLock() + defer user.selfParticipantIDsLock.RUnlock() + return slices.Contains(user.SelfParticipantIDs, id) +} + +func (user *User) AddSelfParticipantID(ctx context.Context, id string) error { + user.selfParticipantIDsLock.Lock() + defer user.selfParticipantIDsLock.Unlock() + if !slices.Contains(user.SelfParticipantIDs, id) { + user.SelfParticipantIDs = append(user.SelfParticipantIDs, id) + selfParticipantIDs, _ := json.Marshal(user.SelfParticipantIDs) + _, err := user.db.Conn(ctx).ExecContext(ctx, `UPDATE "user" SET self_participant_ids=$2 WHERE mxid=$1`, user.MXID, selfParticipantIDs) + return err + } + return nil } func (user *User) Insert(ctx context.Context) error { err := user.db.Conn(ctx). - QueryRowContext(ctx, `INSERT INTO "user" (mxid, phone, session, management_room, space_room, access_token) VALUES ($1, $2, $3, $4, $5, $6) RETURNING rowid`, user.sqlVariables()...). + QueryRowContext(ctx, `INSERT INTO "user" (mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING rowid`, user.sqlVariables()...). Scan(&user.RowID) return err } func (user *User) Update(ctx context.Context) error { - _, err := user.db.Conn(ctx).ExecContext(ctx, `UPDATE "user" SET phone=$2, session=$3, management_room=$4, space_room=$5, access_token=$6 WHERE mxid=$1`, user.sqlVariables()...) + _, err := user.db.Conn(ctx).ExecContext(ctx, `UPDATE "user" SET phone_id=$2, session=$3, self_participant_ids=$4, management_room=$5, space_room=$6, access_token=$7 WHERE mxid=$1`, user.sqlVariables()...) return err } diff --git a/go.mod b/go.mod index 18ce7bd..af8b998 100644 --- a/go.mod +++ b/go.mod @@ -8,8 +8,9 @@ require ( github.com/rs/zerolog v1.29.1 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e go.mau.fi/mautrix-gmessages/libgm v0.1.0 + google.golang.org/protobuf v1.31.0 maunium.net/go/maulogger/v2 v2.4.1 - maunium.net/go/mautrix v0.15.4 + maunium.net/go/mautrix v0.15.5-0.20230719135321-8c3bd7722909 ) require ( @@ -33,7 +34,6 @@ require ( golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect golang.org/x/net v0.12.0 // indirect golang.org/x/sys v0.10.0 // indirect - google.golang.org/protobuf v1.31.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index de396cf..f622df1 100644 --- a/go.sum +++ b/go.sum @@ -81,5 +81,5 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8= maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho= -maunium.net/go/mautrix v0.15.4 h1:Ug3n2Mo+9Yb94AjZTWJQSNHmShaksEzZi85EPl3S3P0= -maunium.net/go/mautrix v0.15.4/go.mod h1:dBaDmsnOOBM4a+gKcgefXH73pHGXm+MCJzCs1dXFgrw= +maunium.net/go/mautrix v0.15.5-0.20230719135321-8c3bd7722909 h1:Bw7WMyeODxjGjBK3v9+kE7Dm5b8HAEawmYMmYYSXvow= +maunium.net/go/mautrix v0.15.5-0.20230719135321-8c3bd7722909/go.mod h1:dBaDmsnOOBM4a+gKcgefXH73pHGXm+MCJzCs1dXFgrw= diff --git a/libgm/gmproto/client.pb.go b/libgm/gmproto/client.pb.go index 7266955..800fb46 100644 --- a/libgm/gmproto/client.pb.go +++ b/libgm/gmproto/client.pb.go @@ -2598,7 +2598,7 @@ type MessagePayload struct { TmpID string `protobuf:"bytes,1,opt,name=tmpID,proto3" json:"tmpID,omitempty"` MessagePayloadContent *MessagePayloadContent `protobuf:"bytes,6,opt,name=messagePayloadContent,proto3" json:"messagePayloadContent,omitempty"` ConversationID string `protobuf:"bytes,7,opt,name=conversationID,proto3" json:"conversationID,omitempty"` - SelfParticipantID string `protobuf:"bytes,9,opt,name=selfParticipantID,proto3" json:"selfParticipantID,omitempty"` // might be participantID + ParticipantID string `protobuf:"bytes,9,opt,name=participantID,proto3" json:"participantID,omitempty"` MessageInfo []*MessageInfo `protobuf:"bytes,10,rep,name=messageInfo,proto3" json:"messageInfo,omitempty"` TmpID2 string `protobuf:"bytes,12,opt,name=tmpID2,proto3" json:"tmpID2,omitempty"` } @@ -2656,9 +2656,9 @@ func (x *MessagePayload) GetConversationID() string { return "" } -func (x *MessagePayload) GetSelfParticipantID() string { +func (x *MessagePayload) GetParticipantID() string { if x != nil { - return x.SelfParticipantID + return x.ParticipantID } return "" } diff --git a/libgm/gmproto/client.pb.raw b/libgm/gmproto/client.pb.raw index 7660219..b91a56f 100644 Binary files a/libgm/gmproto/client.pb.raw and b/libgm/gmproto/client.pb.raw differ diff --git a/libgm/gmproto/client.proto b/libgm/gmproto/client.proto index 9bf6935..c333f3e 100644 --- a/libgm/gmproto/client.proto +++ b/libgm/gmproto/client.proto @@ -269,7 +269,7 @@ message MessagePayload { string tmpID = 1; MessagePayloadContent messagePayloadContent = 6; string conversationID = 7; - string selfParticipantID = 9; // might be participantID + string participantID = 9; repeated conversations.MessageInfo messageInfo = 10; string tmpID2 = 12; } diff --git a/libgm/gmproto/conversations.pb.go b/libgm/gmproto/conversations.pb.go index d5c75cd..3fe418d 100644 --- a/libgm/gmproto/conversations.pb.go +++ b/libgm/gmproto/conversations.pb.go @@ -1760,7 +1760,7 @@ type Conversation struct { LastMessageTimestamp int64 `protobuf:"varint,5,opt,name=lastMessageTimestamp,proto3" json:"lastMessageTimestamp,omitempty"` Unread bool `protobuf:"varint,6,opt,name=unread,proto3" json:"unread,omitempty"` IsGroupChat bool `protobuf:"varint,10,opt,name=isGroupChat,proto3" json:"isGroupChat,omitempty"` // not certain - SelfParticipantID string `protobuf:"bytes,11,opt,name=selfParticipantID,proto3" json:"selfParticipantID,omitempty"` + DefaultOutgoingID string `protobuf:"bytes,11,opt,name=defaultOutgoingID,proto3" json:"defaultOutgoingID,omitempty"` // bool bool1 = 13; Status ConvUpdateTypes `protobuf:"varint,12,opt,name=status,proto3,enum=conversations.ConvUpdateTypes" json:"status,omitempty"` AvatarHexColor string `protobuf:"bytes,15,opt,name=avatarHexColor,proto3" json:"avatarHexColor,omitempty"` @@ -1847,9 +1847,9 @@ func (x *Conversation) GetIsGroupChat() bool { return false } -func (x *Conversation) GetSelfParticipantID() string { +func (x *Conversation) GetDefaultOutgoingID() string { if x != nil { - return x.SelfParticipantID + return x.DefaultOutgoingID } return "" } diff --git a/libgm/gmproto/conversations.pb.raw b/libgm/gmproto/conversations.pb.raw index 56a8d1b..6edb46f 100644 Binary files a/libgm/gmproto/conversations.pb.raw and b/libgm/gmproto/conversations.pb.raw differ diff --git a/libgm/gmproto/conversations.proto b/libgm/gmproto/conversations.proto index f6044b8..cec73dc 100644 --- a/libgm/gmproto/conversations.proto +++ b/libgm/gmproto/conversations.proto @@ -130,7 +130,7 @@ message Conversation { bool unread = 6; bool isGroupChat = 10; // not certain - string selfParticipantID = 11; + string defaultOutgoingID = 11; //bool bool1 = 13; ConvUpdateTypes status = 12; diff --git a/portal.go b/portal.go index 6ff1f0d..5ff3d62 100644 --- a/portal.go +++ b/portal.go @@ -518,7 +518,7 @@ type ConvertedMessage struct { } func (portal *Portal) getIntent(ctx context.Context, source *User, participant string) *appservice.IntentAPI { - if participant == portal.SelfUserID { + if source.IsSelfParticipantID(participant) { intent := source.DoublePuppetIntent if intent == nil { zerolog.Ctx(ctx).Debug().Msg("Dropping message from self as double puppeting is not enabled") @@ -712,6 +712,12 @@ func (portal *Portal) SyncParticipants(source *User, metadata *gmproto.Conversat var manyParticipants bool for _, participant := range metadata.Participants { if participant.IsMe { + err := source.AddSelfParticipantID(context.TODO(), participant.ID.ParticipantID) + if err != nil { + portal.zlog.Warn().Err(err). + Str("participant_id", participant.ID.ParticipantID). + Msg("Failed to save self participant ID") + } continue } else if participant.ID.Number == "" { portal.zlog.Warn().Interface("participant", participant).Msg("No number found in non-self participant entry") @@ -784,8 +790,8 @@ func (portal *Portal) UpdateName(name string, updateInfo bool) bool { func (portal *Portal) UpdateMetadata(user *User, info *gmproto.Conversation) []id.UserID { participants, update := portal.SyncParticipants(user, info) - if portal.SelfUserID != info.SelfParticipantID { - portal.SelfUserID = info.SelfParticipantID + if portal.OutgoingID != info.DefaultOutgoingID { + portal.OutgoingID = info.DefaultOutgoingID update = true } if portal.MXID != "" { @@ -1178,10 +1184,10 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, co ConversationID: portal.ID, TmpID: txnID, MessagePayload: &gmproto.MessagePayload{ - ConversationID: portal.ID, - TmpID: txnID, - TmpID2: txnID, - SelfParticipantID: portal.SelfUserID, + ConversationID: portal.ID, + TmpID: txnID, + TmpID2: txnID, + ParticipantID: portal.OutgoingID, }, } @@ -1342,7 +1348,7 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) error return errTargetNotFound } - existingReaction, err := portal.bridge.DB.Reaction.GetByID(ctx, portal.Key, msg.ID, portal.SelfUserID) + existingReaction, err := portal.bridge.DB.Reaction.GetByID(ctx, portal.Key, msg.ID, portal.OutgoingID) if err != nil { log.Err(err).Msg("Failed to get existing reaction") return fmt.Errorf("failed to get existing reaction from database") @@ -1367,7 +1373,7 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) error existingReaction = portal.bridge.DB.Reaction.New() existingReaction.Chat = portal.Key existingReaction.MessageID = msg.ID - existingReaction.Sender = portal.SelfUserID + existingReaction.Sender = portal.OutgoingID } else if sender.DoublePuppetIntent != nil { _, err = sender.DoublePuppetIntent.RedactEvent(portal.MXID, existingReaction.MXID) if err != nil { diff --git a/user.go b/user.go index 44f41f3..58e6a81 100644 --- a/user.go +++ b/user.go @@ -173,15 +173,15 @@ func (br *GMBridge) GetUserByPhone(phone string) *User { func (user *User) addToPhoneMap() { user.bridge.usersLock.Lock() - user.bridge.usersByPhone[user.Phone] = user + user.bridge.usersByPhone[user.PhoneID] = user user.bridge.usersLock.Unlock() } func (user *User) removeFromPhoneMap(state status.BridgeState) { user.bridge.usersLock.Lock() - phoneUser, ok := user.bridge.usersByPhone[user.Phone] + phoneUser, ok := user.bridge.usersByPhone[user.PhoneID] if ok && user == phoneUser { - delete(user.bridge.usersByPhone, user.Phone) + delete(user.bridge.usersByPhone, user.PhoneID) } user.bridge.usersLock.Unlock() user.BridgeState.Send(state) @@ -231,11 +231,11 @@ func (br *GMBridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User { } user := br.NewUser(dbUser) br.usersByMXID[user.MXID] = user - if user.Session != nil && user.Phone != "" { - br.usersByPhone[user.Phone] = user + if user.Session != nil && user.PhoneID != "" { + br.usersByPhone[user.PhoneID] = user } else { user.Session = nil - user.Phone = "" + user.PhoneID = "" } if len(user.ManagementRoom) > 0 { br.managementRooms[user.ManagementRoom] = user @@ -540,7 +540,7 @@ func (user *User) HasSession() bool { func (user *User) DeleteSession() { user.Session = nil - user.Phone = "" + user.PhoneID = "" err := user.Update(context.TODO()) if err != nil { user.zlog.Err(err).Msg("Failed to delete session from database") @@ -601,7 +601,7 @@ func (user *User) HandleEvent(event interface{}) { user.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnected}) case *events.PairSuccessful: user.Session = user.Client.AuthData - user.Phone = v.GetMobile().GetSourceID() + user.PhoneID = v.GetMobile().GetSourceID() user.tryAutomaticDoublePuppeting() user.addToPhoneMap() err := user.Update(context.TODO())