diff --git a/commands.go b/commands.go index f442fa2..dd2b2b7 100644 --- a/commands.go +++ b/commands.go @@ -380,7 +380,7 @@ func fnDeletePortal(ce *WrappedCommandEvent) { ce.ZLog.Info().Str("conversation_id", ce.Portal.ID).Msg("Deleting portal from command") ce.Portal.Delete() - ce.Portal.Cleanup(false) + ce.Portal.Cleanup() } var cmdDeleteAllPortals = &commands.FullHandler{ @@ -419,7 +419,7 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) { roomYeeting := ce.Bridge.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) if roomYeeting { leave = func(portal *Portal) { - portal.Cleanup(false) + portal.Cleanup() } } ce.Reply("Found %d portals, deleting...", len(portals)) @@ -431,7 +431,7 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) { ce.Reply("Finished deleting portal info. Now cleaning up rooms in background.") go func() { for _, portal := range portals { - portal.Cleanup(false) + portal.Cleanup() } ce.Reply("Finished background cleanup of deleted portal rooms.") }() diff --git a/database/puppet.go b/database/puppet.go index c9be16d..49e617b 100644 --- a/database/puppet.go +++ b/database/puppet.go @@ -43,6 +43,11 @@ func (pq *PuppetQuery) GetAll(ctx context.Context) ([]*Puppet, error) { return getAll[*Puppet](pq, ctx, "SELECT id, receiver, phone, name, name_set, avatar_id, avatar_mxc, avatar_set, contact_info_set FROM puppet") } +func (pq *PuppetQuery) DeleteAllForUser(ctx context.Context, userID int) error { + _, err := pq.db.Conn(ctx).ExecContext(ctx, "DELETE FROM puppet WHERE receiver=$1", userID) + return err +} + func (pq *PuppetQuery) Get(ctx context.Context, key Key) (*Puppet, error) { return get[*Puppet](pq, ctx, "SELECT id, receiver, phone, name, name_set, avatar_id, avatar_mxc, avatar_set, contact_info_set FROM puppet WHERE id=$1 AND receiver=$2", key.ID, key.Receiver) } diff --git a/database/user.go b/database/user.go index 23804e3..5251988 100644 --- a/database/user.go +++ b/database/user.go @@ -61,10 +61,6 @@ func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, er return get[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE mxid=$1`, userID) } -func (uq *UserQuery) GetByPhone(ctx context.Context, phone string) (*User, error) { - return get[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE phone_id=$1`, phone) -} - type User struct { db *Database diff --git a/main.go b/main.go index 1d62676..15dd026 100644 --- a/main.go +++ b/main.go @@ -50,7 +50,6 @@ type GMBridge struct { Provisioning *ProvisioningAPI usersByMXID map[id.UserID]*User - usersByPhone map[string]*User usersLock sync.Mutex spaceRooms map[id.RoomID]*User spaceRoomsLock sync.Mutex @@ -125,7 +124,7 @@ func (br *GMBridge) StartUsers() { } func (br *GMBridge) Stop() { - for _, user := range br.usersByPhone { + for _, user := range br.usersByMXID { if user.Client == nil { continue } @@ -149,7 +148,6 @@ func (br *GMBridge) GetConfigPtr() interface{} { func main() { br := &GMBridge{ usersByMXID: make(map[id.UserID]*User), - usersByPhone: make(map[string]*User), spaceRooms: make(map[id.RoomID]*User), managementRooms: make(map[id.RoomID]*User), portalsByMXID: make(map[id.RoomID]*Portal), diff --git a/portal.go b/portal.go index 5457bb1..3445a00 100644 --- a/portal.go +++ b/portal.go @@ -1796,40 +1796,17 @@ func (portal *Portal) Delete() { portal.bridge.portalsLock.Unlock() } -func (portal *Portal) GetMatrixUsers() ([]id.UserID, error) { - members, err := portal.MainIntent().JoinedMembers(portal.MXID) - if err != nil { - return nil, fmt.Errorf("failed to get member list: %w", err) - } - var users []id.UserID - for userID := range members.Joined { - _, isPuppet := portal.bridge.ParsePuppetMXID(userID) - if !isPuppet && userID != portal.bridge.Bot.UserID { - users = append(users, userID) - } - } - return users, nil -} - -func (portal *Portal) CleanupIfEmpty() { - users, err := portal.GetMatrixUsers() - if err != nil { - portal.zlog.Err(err).Msg("Failed to get Matrix user list to determine if portal needs to be cleaned up") - return - } - - if len(users) == 0 { - portal.zlog.Info().Msg("Room seems to be empty, cleaning up...") - portal.Delete() - portal.Cleanup(false) - } -} - -func (portal *Portal) Cleanup(puppetsOnly bool) { +func (portal *Portal) Cleanup() { if len(portal.MXID) == 0 { return } - intent := portal.MainIntent() + intent := portal.bridge.Bot + if portal.IsPrivateChat() { + intent = portal.bridge.AS.Intent(portal.bridge.FormatPuppetMXID(database.Key{ + ID: portal.OtherUserID, + Receiver: portal.Receiver, + })) + } if portal.bridge.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) { err := intent.BeeperDeleteRoom(portal.MXID) if err != nil && !errors.Is(err, mautrix.MNotFound) { @@ -1846,13 +1823,12 @@ func (portal *Portal) Cleanup(puppetsOnly bool) { if member == intent.UserID { continue } - puppet := portal.bridge.GetPuppetByMXID(member) - if puppet != nil { - _, err = puppet.DefaultIntent().LeaveRoom(portal.MXID) + if portal.bridge.IsGhost(member) { + _, err = portal.bridge.AS.Intent(member).LeaveRoom(portal.MXID) if err != nil { portal.zlog.Err(err).Msg("Failed to leave as puppet while cleaning up portal") } - } else if !puppetsOnly { + } else { _, err = intent.KickUser(portal.MXID, &mautrix.ReqKickUser{UserID: member, Reason: "Deleting portal"}) if err != nil { portal.zlog.Err(err).Msg("Failed to kick user while cleaning up portal") diff --git a/puppet.go b/puppet.go index a21e716..4f37071 100644 --- a/puppet.go +++ b/puppet.go @@ -57,6 +57,20 @@ func (br *GMBridge) GetPuppetByMXID(mxid id.UserID) *Puppet { return br.GetPuppetByKey(key, "") } +func (br *GMBridge) DeleteAllPuppetsForUser(userID int) { + br.puppetsLock.Lock() + defer br.puppetsLock.Unlock() + err := br.DB.Puppet.DeleteAllForUser(context.Background(), userID) + if err != nil { + br.ZLog.Err(err).Msg("Failed to delete all ghosts for user from database") + } + for key, puppet := range br.puppetsByKey { + if puppet.Receiver == userID { + delete(br.puppetsByKey, key) + } + } +} + func (br *GMBridge) GetPuppetByKey(key database.Key, phone string) *Puppet { br.puppetsLock.Lock() defer br.puppetsLock.Unlock() diff --git a/user.go b/user.go index 912e33c..46bd37b 100644 --- a/user.go +++ b/user.go @@ -161,38 +161,6 @@ func (br *GMBridge) GetUserByMXIDIfExists(userID id.UserID) *User { return br.getUserByMXID(userID, true) } -func (br *GMBridge) GetUserByPhone(phone string) *User { - br.usersLock.Lock() - defer br.usersLock.Unlock() - user, ok := br.usersByPhone[phone] - if !ok { - dbUser, err := br.DB.User.GetByPhone(context.TODO(), phone) - if err != nil { - br.ZLog.Err(err). - Str("phone", phone). - Msg("Failed to load user from database") - } - return br.loadDBUser(dbUser, nil) - } - return user -} - -func (user *User) addToPhoneMap() { - user.bridge.usersLock.Lock() - user.bridge.usersByPhone[user.PhoneID] = user - user.bridge.usersLock.Unlock() -} - -func (user *User) removeFromPhoneMap(state status.BridgeState) { - user.bridge.usersLock.Lock() - phoneUser, ok := user.bridge.usersByPhone[user.PhoneID] - if ok && user == phoneUser { - delete(user.bridge.usersByPhone, user.PhoneID) - } - user.bridge.usersLock.Unlock() - user.BridgeState.Send(state) -} - func (br *GMBridge) GetAllUsersWithSession() []*User { return br.loadManyUsers(br.DB.User.GetAllWithSession) } @@ -237,12 +205,6 @@ func (br *GMBridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User { } user := br.NewUser(dbUser) br.usersByMXID[user.MXID] = user - if user.Session != nil && user.PhoneID != "" { - br.usersByPhone[user.PhoneID] = user - } else { - user.Session = nil - user.PhoneID = "" - } if len(user.ManagementRoom) > 0 { br.managementRooms[user.ManagementRoom] = user } @@ -545,7 +507,6 @@ func (user *User) HasSession() bool { func (user *User) DeleteSession() { user.Session = nil user.SelfParticipantIDs = []string{} - user.PhoneID = "" err := user.Update(context.TODO()) if err != nil { user.zlog.Err(err).Msg("Failed to delete session from database") @@ -624,8 +585,14 @@ func (user *User) syncHandleEvent(event any) { user.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnected}) case *events.PairSuccessful: user.Session = user.Client.AuthData + if user.PhoneID != "" && user.PhoneID != v.GetMobile().GetSourceID() { + user.zlog.Warn(). + Str("old_phone_id", user.PhoneID). + Str("new_phone_id", v.GetMobile().GetSourceID()). + Msg("Phone ID changed, resetting state") + user.ResetState() + } user.PhoneID = v.GetMobile().GetSourceID() - user.addToPhoneMap() err := user.Update(context.TODO()) if err != nil { user.zlog.Err(err).Msg("Failed to update session in database") @@ -660,6 +627,23 @@ func (user *User) syncHandleEvent(event any) { } } +func (user *User) ResetState() { + portals := user.bridge.GetAllPortalsForUser(user.RowID) + user.zlog.Debug().Int("portal_count", len(portals)).Msg("Deleting portals") + for _, portal := range portals { + portal.Delete() + } + user.bridge.DeleteAllPuppetsForUser(user.RowID) + user.PhoneID = "" + go func() { + user.zlog.Debug().Msg("Cleaning up portal rooms in background") + for _, portal := range portals { + portal.Cleanup() + } + user.zlog.Debug().Msg("Finished cleaning up portals") + }() +} + func (user *User) aggressiveSetActive() { sleepTimes := []int{5, 10, 30} for i := 0; i < 3; i++ { @@ -798,9 +782,9 @@ func (user *User) Logout(state status.BridgeState, unpair bool) (logoutOK bool) logoutOK = true } } - user.removeFromPhoneMap(state) user.DeleteConnection() user.DeleteSession() + user.BridgeState.Send(state) return } @@ -821,7 +805,7 @@ func (user *User) syncConversation(v *gmproto.Conversation, source string) { case gmproto.ConversationStatus_DELETED: log.Info().Msg("Got delete event, cleaning up portal") portal.Delete() - portal.Cleanup(false) + portal.Cleanup() default: if v.Participants == nil { log.Debug().Msg("Not syncing conversation with nil participants")