Store all own participant IDs for proper multi-sim support

This commit is contained in:
Tulir Asokan 2023-07-19 20:32:01 +03:00
parent bf277e197f
commit 907e6af77b
16 changed files with 99 additions and 55 deletions

View file

@ -51,7 +51,7 @@ func (user *User) GetRemoteID() string {
if user == nil { if user == nil {
return "" return ""
} }
return user.Phone return user.PhoneID
} }
func (user *User) GetRemoteName() string { func (user *User) GetRemoteName() string {

View file

@ -263,9 +263,9 @@ func fnPing(ce *WrappedCommandEvent) {
ce.Reply("You're not logged into Google Messages.") ce.Reply("You're not logged into Google Messages.")
} }
} else if ce.User.Client == nil || !ce.User.Client.IsConnected() { } else if ce.User.Client == nil || !ce.User.Client.IsConnected() {
ce.Reply("You're logged in as %s, but you don't have a Google Messages connection.", ce.User.Phone) ce.Reply("You're logged in as %s, but you don't have a Google Messages connection.", ce.User.PhoneID)
} else { } else {
ce.Reply("Logged in as %s, connection to Google Messages may be OK", ce.User.Phone) ce.Reply("Logged in as %s, connection to Google Messages may be OK", ce.User.PhoneID)
} }
} }

View file

@ -74,7 +74,7 @@ type Portal struct {
db *Database db *Database
Key Key
SelfUserID string OutgoingID string
OtherUserID string OtherUserID string
MXID id.RoomID MXID id.RoomID
@ -96,7 +96,7 @@ func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
return nil, err return nil, err
} }
portal.MXID = id.RoomID(mxid.String) portal.MXID = id.RoomID(mxid.String)
portal.SelfUserID = selfUserID.String portal.OutgoingID = selfUserID.String
portal.OtherUserID = otherUserID.String portal.OtherUserID = otherUserID.String
return portal, nil return portal, nil
} }
@ -106,8 +106,8 @@ func (portal *Portal) sqlVariables() []any {
if portal.MXID != "" { if portal.MXID != "" {
mxid = (*string)(&portal.MXID) mxid = (*string)(&portal.MXID)
} }
if portal.SelfUserID != "" { if portal.OutgoingID != "" {
selfUserID = &portal.SelfUserID selfUserID = &portal.OutgoingID
} }
if portal.OtherUserID != "" { if portal.OtherUserID != "" {
otherUserID = &portal.OtherUserID otherUserID = &portal.OtherUserID

View file

@ -1,4 +1,4 @@
-- v0 -> v1: Latest revision -- v0 -> v2: Latest revision
CREATE TABLE "user" ( CREATE TABLE "user" (
-- only: postgres -- only: postgres
@ -7,9 +7,11 @@ CREATE TABLE "user" (
rowid INTEGER PRIMARY KEY, rowid INTEGER PRIMARY KEY,
mxid TEXT NOT NULL UNIQUE, mxid TEXT NOT NULL UNIQUE,
phone TEXT UNIQUE, phone_id TEXT UNIQUE,
session jsonb, session jsonb,
self_participant_ids jsonb NOT NULL DEFAULT '[]',
management_room TEXT, management_room TEXT,
space_room TEXT, space_room TEXT,

View file

@ -0,0 +1,3 @@
-- v2: Store all self-participant IDs
ALTER TABLE "user" RENAME COLUMN phone TO phone_id;
ALTER TABLE "user" ADD COLUMN self_participant_ids jsonb NOT NULL DEFAULT '[]';

View file

@ -22,7 +22,9 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"sync"
"golang.org/x/exp/slices"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil" "maunium.net/go/mautrix/util/dbutil"
@ -44,23 +46,23 @@ func (uq *UserQuery) getDB() *Database {
} }
func (uq *UserQuery) GetAllWithSession(ctx context.Context) ([]*User, error) { func (uq *UserQuery) GetAllWithSession(ctx context.Context) ([]*User, error) {
return getAll[*User](uq, ctx, `SELECT rowid, mxid, phone, session, management_room, space_room, access_token FROM "user" WHERE phone<>'' AND session IS NOT NULL`) return getAll[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE session IS NOT NULL`)
} }
func (uq *UserQuery) GetAllWithDoublePuppet(ctx context.Context) ([]*User, error) { func (uq *UserQuery) GetAllWithDoublePuppet(ctx context.Context) ([]*User, error) {
return getAll[*User](uq, ctx, `SELECT rowid, mxid, phone, session, management_room, space_room, access_token FROM "user" WHERE access_token<>''`) return getAll[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE access_token<>''`)
} }
func (uq *UserQuery) GetByRowID(ctx context.Context, rowID int) (*User, error) { func (uq *UserQuery) GetByRowID(ctx context.Context, rowID int) (*User, error) {
return get[*User](uq, ctx, `SELECT rowid, mxid, phone, session, management_room, space_room, access_token FROM "user" WHERE rowid=$1`, rowID) return get[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE rowid=$1`, rowID)
} }
func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, error) { func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, error) {
return get[*User](uq, ctx, `SELECT rowid, mxid, phone, session, management_room, space_room, access_token FROM "user" WHERE mxid=$1`, userID) return get[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE mxid=$1`, userID)
} }
func (uq *UserQuery) GetByPhone(ctx context.Context, phone string) (*User, error) { func (uq *UserQuery) GetByPhone(ctx context.Context, phone string) (*User, error) {
return get[*User](uq, ctx, `SELECT rowid, mxid, phone, session, management_room, space_room, access_token FROM "user" WHERE phone=$1`, phone) return get[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE phone_id=$1`, phone)
} }
type User struct { type User struct {
@ -68,18 +70,22 @@ type User struct {
RowID int RowID int
MXID id.UserID MXID id.UserID
Phone string PhoneID string
Session *libgm.AuthData Session *libgm.AuthData
ManagementRoom id.RoomID ManagementRoom id.RoomID
SpaceRoom id.RoomID SpaceRoom id.RoomID
SelfParticipantIDs []string
selfParticipantIDsLock sync.RWMutex
AccessToken string AccessToken string
} }
func (user *User) Scan(row dbutil.Scannable) (*User, error) { func (user *User) Scan(row dbutil.Scannable) (*User, error) {
var phone, session, managementRoom, spaceRoom, accessToken sql.NullString var phoneID, session, managementRoom, spaceRoom, accessToken sql.NullString
err := row.Scan(&user.RowID, &user.MXID, &phone, &session, &managementRoom, &spaceRoom, &accessToken) var selfParticipantIDs string
err := row.Scan(&user.RowID, &user.MXID, &phoneID, &session, &selfParticipantIDs, &managementRoom, &spaceRoom, &accessToken)
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, nil return nil, nil
} else if err != nil { } else if err != nil {
@ -93,7 +99,13 @@ func (user *User) Scan(row dbutil.Scannable) (*User, error) {
} }
user.Session = &sess user.Session = &sess
} }
user.Phone = phone.String user.selfParticipantIDsLock.Lock()
err = json.Unmarshal([]byte(selfParticipantIDs), &user.SelfParticipantIDs)
user.selfParticipantIDsLock.Unlock()
if err != nil {
return nil, fmt.Errorf("failed to parse self participant IDs: %w", err)
}
user.PhoneID = phoneID.String
user.AccessToken = accessToken.String user.AccessToken = accessToken.String
user.ManagementRoom = id.RoomID(managementRoom.String) user.ManagementRoom = id.RoomID(managementRoom.String)
user.SpaceRoom = id.RoomID(spaceRoom.String) user.SpaceRoom = id.RoomID(spaceRoom.String)
@ -101,9 +113,9 @@ func (user *User) Scan(row dbutil.Scannable) (*User, error) {
} }
func (user *User) sqlVariables() []any { func (user *User) sqlVariables() []any {
var phone, session, managementRoom, spaceRoom, accessToken *string var phoneID, session, managementRoom, spaceRoom, accessToken *string
if user.Phone != "" { if user.PhoneID != "" {
phone = &user.Phone phoneID = &user.PhoneID
} }
if user.Session != nil { if user.Session != nil {
data, _ := json.Marshal(user.Session) data, _ := json.Marshal(user.Session)
@ -119,17 +131,38 @@ func (user *User) sqlVariables() []any {
if user.AccessToken != "" { if user.AccessToken != "" {
accessToken = &user.AccessToken accessToken = &user.AccessToken
} }
return []any{user.MXID, phone, session, managementRoom, spaceRoom, accessToken} user.selfParticipantIDsLock.RLock()
selfParticipantIDs, _ := json.Marshal(user.SelfParticipantIDs)
user.selfParticipantIDsLock.RUnlock()
return []any{user.MXID, phoneID, session, string(selfParticipantIDs), managementRoom, spaceRoom, accessToken}
}
func (user *User) IsSelfParticipantID(id string) bool {
user.selfParticipantIDsLock.RLock()
defer user.selfParticipantIDsLock.RUnlock()
return slices.Contains(user.SelfParticipantIDs, id)
}
func (user *User) AddSelfParticipantID(ctx context.Context, id string) error {
user.selfParticipantIDsLock.Lock()
defer user.selfParticipantIDsLock.Unlock()
if !slices.Contains(user.SelfParticipantIDs, id) {
user.SelfParticipantIDs = append(user.SelfParticipantIDs, id)
selfParticipantIDs, _ := json.Marshal(user.SelfParticipantIDs)
_, err := user.db.Conn(ctx).ExecContext(ctx, `UPDATE "user" SET self_participant_ids=$2 WHERE mxid=$1`, user.MXID, selfParticipantIDs)
return err
}
return nil
} }
func (user *User) Insert(ctx context.Context) error { func (user *User) Insert(ctx context.Context) error {
err := user.db.Conn(ctx). err := user.db.Conn(ctx).
QueryRowContext(ctx, `INSERT INTO "user" (mxid, phone, session, management_room, space_room, access_token) VALUES ($1, $2, $3, $4, $5, $6) RETURNING rowid`, user.sqlVariables()...). QueryRowContext(ctx, `INSERT INTO "user" (mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING rowid`, user.sqlVariables()...).
Scan(&user.RowID) Scan(&user.RowID)
return err return err
} }
func (user *User) Update(ctx context.Context) error { func (user *User) Update(ctx context.Context) error {
_, err := user.db.Conn(ctx).ExecContext(ctx, `UPDATE "user" SET phone=$2, session=$3, management_room=$4, space_room=$5, access_token=$6 WHERE mxid=$1`, user.sqlVariables()...) _, err := user.db.Conn(ctx).ExecContext(ctx, `UPDATE "user" SET phone_id=$2, session=$3, self_participant_ids=$4, management_room=$5, space_room=$6, access_token=$7 WHERE mxid=$1`, user.sqlVariables()...)
return err return err
} }

4
go.mod
View file

@ -8,8 +8,9 @@ require (
github.com/rs/zerolog v1.29.1 github.com/rs/zerolog v1.29.1
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
go.mau.fi/mautrix-gmessages/libgm v0.1.0 go.mau.fi/mautrix-gmessages/libgm v0.1.0
google.golang.org/protobuf v1.31.0
maunium.net/go/maulogger/v2 v2.4.1 maunium.net/go/maulogger/v2 v2.4.1
maunium.net/go/mautrix v0.15.4 maunium.net/go/mautrix v0.15.5-0.20230719135321-8c3bd7722909
) )
require ( require (
@ -33,7 +34,6 @@ require (
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect
golang.org/x/net v0.12.0 // indirect golang.org/x/net v0.12.0 // indirect
golang.org/x/sys v0.10.0 // indirect golang.org/x/sys v0.10.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect

4
go.sum
View file

@ -81,5 +81,5 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8= maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8=
maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho= maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho=
maunium.net/go/mautrix v0.15.4 h1:Ug3n2Mo+9Yb94AjZTWJQSNHmShaksEzZi85EPl3S3P0= maunium.net/go/mautrix v0.15.5-0.20230719135321-8c3bd7722909 h1:Bw7WMyeODxjGjBK3v9+kE7Dm5b8HAEawmYMmYYSXvow=
maunium.net/go/mautrix v0.15.4/go.mod h1:dBaDmsnOOBM4a+gKcgefXH73pHGXm+MCJzCs1dXFgrw= maunium.net/go/mautrix v0.15.5-0.20230719135321-8c3bd7722909/go.mod h1:dBaDmsnOOBM4a+gKcgefXH73pHGXm+MCJzCs1dXFgrw=

View file

@ -2598,7 +2598,7 @@ type MessagePayload struct {
TmpID string `protobuf:"bytes,1,opt,name=tmpID,proto3" json:"tmpID,omitempty"` TmpID string `protobuf:"bytes,1,opt,name=tmpID,proto3" json:"tmpID,omitempty"`
MessagePayloadContent *MessagePayloadContent `protobuf:"bytes,6,opt,name=messagePayloadContent,proto3" json:"messagePayloadContent,omitempty"` MessagePayloadContent *MessagePayloadContent `protobuf:"bytes,6,opt,name=messagePayloadContent,proto3" json:"messagePayloadContent,omitempty"`
ConversationID string `protobuf:"bytes,7,opt,name=conversationID,proto3" json:"conversationID,omitempty"` ConversationID string `protobuf:"bytes,7,opt,name=conversationID,proto3" json:"conversationID,omitempty"`
SelfParticipantID string `protobuf:"bytes,9,opt,name=selfParticipantID,proto3" json:"selfParticipantID,omitempty"` // might be participantID ParticipantID string `protobuf:"bytes,9,opt,name=participantID,proto3" json:"participantID,omitempty"`
MessageInfo []*MessageInfo `protobuf:"bytes,10,rep,name=messageInfo,proto3" json:"messageInfo,omitempty"` MessageInfo []*MessageInfo `protobuf:"bytes,10,rep,name=messageInfo,proto3" json:"messageInfo,omitempty"`
TmpID2 string `protobuf:"bytes,12,opt,name=tmpID2,proto3" json:"tmpID2,omitempty"` TmpID2 string `protobuf:"bytes,12,opt,name=tmpID2,proto3" json:"tmpID2,omitempty"`
} }
@ -2656,9 +2656,9 @@ func (x *MessagePayload) GetConversationID() string {
return "" return ""
} }
func (x *MessagePayload) GetSelfParticipantID() string { func (x *MessagePayload) GetParticipantID() string {
if x != nil { if x != nil {
return x.SelfParticipantID return x.ParticipantID
} }
return "" return ""
} }

Binary file not shown.

View file

@ -269,7 +269,7 @@ message MessagePayload {
string tmpID = 1; string tmpID = 1;
MessagePayloadContent messagePayloadContent = 6; MessagePayloadContent messagePayloadContent = 6;
string conversationID = 7; string conversationID = 7;
string selfParticipantID = 9; // might be participantID string participantID = 9;
repeated conversations.MessageInfo messageInfo = 10; repeated conversations.MessageInfo messageInfo = 10;
string tmpID2 = 12; string tmpID2 = 12;
} }

View file

@ -1760,7 +1760,7 @@ type Conversation struct {
LastMessageTimestamp int64 `protobuf:"varint,5,opt,name=lastMessageTimestamp,proto3" json:"lastMessageTimestamp,omitempty"` LastMessageTimestamp int64 `protobuf:"varint,5,opt,name=lastMessageTimestamp,proto3" json:"lastMessageTimestamp,omitempty"`
Unread bool `protobuf:"varint,6,opt,name=unread,proto3" json:"unread,omitempty"` Unread bool `protobuf:"varint,6,opt,name=unread,proto3" json:"unread,omitempty"`
IsGroupChat bool `protobuf:"varint,10,opt,name=isGroupChat,proto3" json:"isGroupChat,omitempty"` // not certain IsGroupChat bool `protobuf:"varint,10,opt,name=isGroupChat,proto3" json:"isGroupChat,omitempty"` // not certain
SelfParticipantID string `protobuf:"bytes,11,opt,name=selfParticipantID,proto3" json:"selfParticipantID,omitempty"` DefaultOutgoingID string `protobuf:"bytes,11,opt,name=defaultOutgoingID,proto3" json:"defaultOutgoingID,omitempty"`
// bool bool1 = 13; // bool bool1 = 13;
Status ConvUpdateTypes `protobuf:"varint,12,opt,name=status,proto3,enum=conversations.ConvUpdateTypes" json:"status,omitempty"` Status ConvUpdateTypes `protobuf:"varint,12,opt,name=status,proto3,enum=conversations.ConvUpdateTypes" json:"status,omitempty"`
AvatarHexColor string `protobuf:"bytes,15,opt,name=avatarHexColor,proto3" json:"avatarHexColor,omitempty"` AvatarHexColor string `protobuf:"bytes,15,opt,name=avatarHexColor,proto3" json:"avatarHexColor,omitempty"`
@ -1847,9 +1847,9 @@ func (x *Conversation) GetIsGroupChat() bool {
return false return false
} }
func (x *Conversation) GetSelfParticipantID() string { func (x *Conversation) GetDefaultOutgoingID() string {
if x != nil { if x != nil {
return x.SelfParticipantID return x.DefaultOutgoingID
} }
return "" return ""
} }

Binary file not shown.

View file

@ -130,7 +130,7 @@ message Conversation {
bool unread = 6; bool unread = 6;
bool isGroupChat = 10; // not certain bool isGroupChat = 10; // not certain
string selfParticipantID = 11; string defaultOutgoingID = 11;
//bool bool1 = 13; //bool bool1 = 13;
ConvUpdateTypes status = 12; ConvUpdateTypes status = 12;

View file

@ -518,7 +518,7 @@ type ConvertedMessage struct {
} }
func (portal *Portal) getIntent(ctx context.Context, source *User, participant string) *appservice.IntentAPI { func (portal *Portal) getIntent(ctx context.Context, source *User, participant string) *appservice.IntentAPI {
if participant == portal.SelfUserID { if source.IsSelfParticipantID(participant) {
intent := source.DoublePuppetIntent intent := source.DoublePuppetIntent
if intent == nil { if intent == nil {
zerolog.Ctx(ctx).Debug().Msg("Dropping message from self as double puppeting is not enabled") zerolog.Ctx(ctx).Debug().Msg("Dropping message from self as double puppeting is not enabled")
@ -712,6 +712,12 @@ func (portal *Portal) SyncParticipants(source *User, metadata *gmproto.Conversat
var manyParticipants bool var manyParticipants bool
for _, participant := range metadata.Participants { for _, participant := range metadata.Participants {
if participant.IsMe { if participant.IsMe {
err := source.AddSelfParticipantID(context.TODO(), participant.ID.ParticipantID)
if err != nil {
portal.zlog.Warn().Err(err).
Str("participant_id", participant.ID.ParticipantID).
Msg("Failed to save self participant ID")
}
continue continue
} else if participant.ID.Number == "" { } else if participant.ID.Number == "" {
portal.zlog.Warn().Interface("participant", participant).Msg("No number found in non-self participant entry") portal.zlog.Warn().Interface("participant", participant).Msg("No number found in non-self participant entry")
@ -784,8 +790,8 @@ func (portal *Portal) UpdateName(name string, updateInfo bool) bool {
func (portal *Portal) UpdateMetadata(user *User, info *gmproto.Conversation) []id.UserID { func (portal *Portal) UpdateMetadata(user *User, info *gmproto.Conversation) []id.UserID {
participants, update := portal.SyncParticipants(user, info) participants, update := portal.SyncParticipants(user, info)
if portal.SelfUserID != info.SelfParticipantID { if portal.OutgoingID != info.DefaultOutgoingID {
portal.SelfUserID = info.SelfParticipantID portal.OutgoingID = info.DefaultOutgoingID
update = true update = true
} }
if portal.MXID != "" { if portal.MXID != "" {
@ -1181,7 +1187,7 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, co
ConversationID: portal.ID, ConversationID: portal.ID,
TmpID: txnID, TmpID: txnID,
TmpID2: txnID, TmpID2: txnID,
SelfParticipantID: portal.SelfUserID, ParticipantID: portal.OutgoingID,
}, },
} }
@ -1342,7 +1348,7 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) error
return errTargetNotFound return errTargetNotFound
} }
existingReaction, err := portal.bridge.DB.Reaction.GetByID(ctx, portal.Key, msg.ID, portal.SelfUserID) existingReaction, err := portal.bridge.DB.Reaction.GetByID(ctx, portal.Key, msg.ID, portal.OutgoingID)
if err != nil { if err != nil {
log.Err(err).Msg("Failed to get existing reaction") log.Err(err).Msg("Failed to get existing reaction")
return fmt.Errorf("failed to get existing reaction from database") return fmt.Errorf("failed to get existing reaction from database")
@ -1367,7 +1373,7 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) error
existingReaction = portal.bridge.DB.Reaction.New() existingReaction = portal.bridge.DB.Reaction.New()
existingReaction.Chat = portal.Key existingReaction.Chat = portal.Key
existingReaction.MessageID = msg.ID existingReaction.MessageID = msg.ID
existingReaction.Sender = portal.SelfUserID existingReaction.Sender = portal.OutgoingID
} else if sender.DoublePuppetIntent != nil { } else if sender.DoublePuppetIntent != nil {
_, err = sender.DoublePuppetIntent.RedactEvent(portal.MXID, existingReaction.MXID) _, err = sender.DoublePuppetIntent.RedactEvent(portal.MXID, existingReaction.MXID)
if err != nil { if err != nil {

16
user.go
View file

@ -173,15 +173,15 @@ func (br *GMBridge) GetUserByPhone(phone string) *User {
func (user *User) addToPhoneMap() { func (user *User) addToPhoneMap() {
user.bridge.usersLock.Lock() user.bridge.usersLock.Lock()
user.bridge.usersByPhone[user.Phone] = user user.bridge.usersByPhone[user.PhoneID] = user
user.bridge.usersLock.Unlock() user.bridge.usersLock.Unlock()
} }
func (user *User) removeFromPhoneMap(state status.BridgeState) { func (user *User) removeFromPhoneMap(state status.BridgeState) {
user.bridge.usersLock.Lock() user.bridge.usersLock.Lock()
phoneUser, ok := user.bridge.usersByPhone[user.Phone] phoneUser, ok := user.bridge.usersByPhone[user.PhoneID]
if ok && user == phoneUser { if ok && user == phoneUser {
delete(user.bridge.usersByPhone, user.Phone) delete(user.bridge.usersByPhone, user.PhoneID)
} }
user.bridge.usersLock.Unlock() user.bridge.usersLock.Unlock()
user.BridgeState.Send(state) user.BridgeState.Send(state)
@ -231,11 +231,11 @@ func (br *GMBridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User {
} }
user := br.NewUser(dbUser) user := br.NewUser(dbUser)
br.usersByMXID[user.MXID] = user br.usersByMXID[user.MXID] = user
if user.Session != nil && user.Phone != "" { if user.Session != nil && user.PhoneID != "" {
br.usersByPhone[user.Phone] = user br.usersByPhone[user.PhoneID] = user
} else { } else {
user.Session = nil user.Session = nil
user.Phone = "" user.PhoneID = ""
} }
if len(user.ManagementRoom) > 0 { if len(user.ManagementRoom) > 0 {
br.managementRooms[user.ManagementRoom] = user br.managementRooms[user.ManagementRoom] = user
@ -540,7 +540,7 @@ func (user *User) HasSession() bool {
func (user *User) DeleteSession() { func (user *User) DeleteSession() {
user.Session = nil user.Session = nil
user.Phone = "" user.PhoneID = ""
err := user.Update(context.TODO()) err := user.Update(context.TODO())
if err != nil { if err != nil {
user.zlog.Err(err).Msg("Failed to delete session from database") user.zlog.Err(err).Msg("Failed to delete session from database")
@ -601,7 +601,7 @@ func (user *User) HandleEvent(event interface{}) {
user.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnected}) user.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnected})
case *events.PairSuccessful: case *events.PairSuccessful:
user.Session = user.Client.AuthData user.Session = user.Client.AuthData
user.Phone = v.GetMobile().GetSourceID() user.PhoneID = v.GetMobile().GetSourceID()
user.tryAutomaticDoublePuppeting() user.tryAutomaticDoublePuppeting()
user.addToPhoneMap() user.addToPhoneMap()
err := user.Update(context.TODO()) err := user.Update(context.TODO())