Delete state if phone ID changes on login

This commit is contained in:
Tulir Asokan 2023-08-10 15:55:33 +03:00
parent efb0008ca0
commit bd213bf550
7 changed files with 60 additions and 87 deletions

View file

@ -380,7 +380,7 @@ func fnDeletePortal(ce *WrappedCommandEvent) {
ce.ZLog.Info().Str("conversation_id", ce.Portal.ID).Msg("Deleting portal from command")
ce.Portal.Delete()
ce.Portal.Cleanup(false)
ce.Portal.Cleanup()
}
var cmdDeleteAllPortals = &commands.FullHandler{
@ -419,7 +419,7 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) {
roomYeeting := ce.Bridge.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting)
if roomYeeting {
leave = func(portal *Portal) {
portal.Cleanup(false)
portal.Cleanup()
}
}
ce.Reply("Found %d portals, deleting...", len(portals))
@ -431,7 +431,7 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) {
ce.Reply("Finished deleting portal info. Now cleaning up rooms in background.")
go func() {
for _, portal := range portals {
portal.Cleanup(false)
portal.Cleanup()
}
ce.Reply("Finished background cleanup of deleted portal rooms.")
}()

View file

@ -43,6 +43,11 @@ func (pq *PuppetQuery) GetAll(ctx context.Context) ([]*Puppet, error) {
return getAll[*Puppet](pq, ctx, "SELECT id, receiver, phone, name, name_set, avatar_id, avatar_mxc, avatar_set, contact_info_set FROM puppet")
}
func (pq *PuppetQuery) DeleteAllForUser(ctx context.Context, userID int) error {
_, err := pq.db.Conn(ctx).ExecContext(ctx, "DELETE FROM puppet WHERE receiver=$1", userID)
return err
}
func (pq *PuppetQuery) Get(ctx context.Context, key Key) (*Puppet, error) {
return get[*Puppet](pq, ctx, "SELECT id, receiver, phone, name, name_set, avatar_id, avatar_mxc, avatar_set, contact_info_set FROM puppet WHERE id=$1 AND receiver=$2", key.ID, key.Receiver)
}

View file

@ -61,10 +61,6 @@ func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, er
return get[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE mxid=$1`, userID)
}
func (uq *UserQuery) GetByPhone(ctx context.Context, phone string) (*User, error) {
return get[*User](uq, ctx, `SELECT rowid, mxid, phone_id, session, self_participant_ids, management_room, space_room, access_token FROM "user" WHERE phone_id=$1`, phone)
}
type User struct {
db *Database

View file

@ -50,7 +50,6 @@ type GMBridge struct {
Provisioning *ProvisioningAPI
usersByMXID map[id.UserID]*User
usersByPhone map[string]*User
usersLock sync.Mutex
spaceRooms map[id.RoomID]*User
spaceRoomsLock sync.Mutex
@ -125,7 +124,7 @@ func (br *GMBridge) StartUsers() {
}
func (br *GMBridge) Stop() {
for _, user := range br.usersByPhone {
for _, user := range br.usersByMXID {
if user.Client == nil {
continue
}
@ -149,7 +148,6 @@ func (br *GMBridge) GetConfigPtr() interface{} {
func main() {
br := &GMBridge{
usersByMXID: make(map[id.UserID]*User),
usersByPhone: make(map[string]*User),
spaceRooms: make(map[id.RoomID]*User),
managementRooms: make(map[id.RoomID]*User),
portalsByMXID: make(map[id.RoomID]*Portal),

View file

@ -1796,40 +1796,17 @@ func (portal *Portal) Delete() {
portal.bridge.portalsLock.Unlock()
}
func (portal *Portal) GetMatrixUsers() ([]id.UserID, error) {
members, err := portal.MainIntent().JoinedMembers(portal.MXID)
if err != nil {
return nil, fmt.Errorf("failed to get member list: %w", err)
}
var users []id.UserID
for userID := range members.Joined {
_, isPuppet := portal.bridge.ParsePuppetMXID(userID)
if !isPuppet && userID != portal.bridge.Bot.UserID {
users = append(users, userID)
}
}
return users, nil
}
func (portal *Portal) CleanupIfEmpty() {
users, err := portal.GetMatrixUsers()
if err != nil {
portal.zlog.Err(err).Msg("Failed to get Matrix user list to determine if portal needs to be cleaned up")
return
}
if len(users) == 0 {
portal.zlog.Info().Msg("Room seems to be empty, cleaning up...")
portal.Delete()
portal.Cleanup(false)
}
}
func (portal *Portal) Cleanup(puppetsOnly bool) {
func (portal *Portal) Cleanup() {
if len(portal.MXID) == 0 {
return
}
intent := portal.MainIntent()
intent := portal.bridge.Bot
if portal.IsPrivateChat() {
intent = portal.bridge.AS.Intent(portal.bridge.FormatPuppetMXID(database.Key{
ID: portal.OtherUserID,
Receiver: portal.Receiver,
}))
}
if portal.bridge.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) {
err := intent.BeeperDeleteRoom(portal.MXID)
if err != nil && !errors.Is(err, mautrix.MNotFound) {
@ -1846,13 +1823,12 @@ func (portal *Portal) Cleanup(puppetsOnly bool) {
if member == intent.UserID {
continue
}
puppet := portal.bridge.GetPuppetByMXID(member)
if puppet != nil {
_, err = puppet.DefaultIntent().LeaveRoom(portal.MXID)
if portal.bridge.IsGhost(member) {
_, err = portal.bridge.AS.Intent(member).LeaveRoom(portal.MXID)
if err != nil {
portal.zlog.Err(err).Msg("Failed to leave as puppet while cleaning up portal")
}
} else if !puppetsOnly {
} else {
_, err = intent.KickUser(portal.MXID, &mautrix.ReqKickUser{UserID: member, Reason: "Deleting portal"})
if err != nil {
portal.zlog.Err(err).Msg("Failed to kick user while cleaning up portal")

View file

@ -57,6 +57,20 @@ func (br *GMBridge) GetPuppetByMXID(mxid id.UserID) *Puppet {
return br.GetPuppetByKey(key, "")
}
func (br *GMBridge) DeleteAllPuppetsForUser(userID int) {
br.puppetsLock.Lock()
defer br.puppetsLock.Unlock()
err := br.DB.Puppet.DeleteAllForUser(context.Background(), userID)
if err != nil {
br.ZLog.Err(err).Msg("Failed to delete all ghosts for user from database")
}
for key, puppet := range br.puppetsByKey {
if puppet.Receiver == userID {
delete(br.puppetsByKey, key)
}
}
}
func (br *GMBridge) GetPuppetByKey(key database.Key, phone string) *Puppet {
br.puppetsLock.Lock()
defer br.puppetsLock.Unlock()

68
user.go
View file

@ -161,38 +161,6 @@ func (br *GMBridge) GetUserByMXIDIfExists(userID id.UserID) *User {
return br.getUserByMXID(userID, true)
}
func (br *GMBridge) GetUserByPhone(phone string) *User {
br.usersLock.Lock()
defer br.usersLock.Unlock()
user, ok := br.usersByPhone[phone]
if !ok {
dbUser, err := br.DB.User.GetByPhone(context.TODO(), phone)
if err != nil {
br.ZLog.Err(err).
Str("phone", phone).
Msg("Failed to load user from database")
}
return br.loadDBUser(dbUser, nil)
}
return user
}
func (user *User) addToPhoneMap() {
user.bridge.usersLock.Lock()
user.bridge.usersByPhone[user.PhoneID] = user
user.bridge.usersLock.Unlock()
}
func (user *User) removeFromPhoneMap(state status.BridgeState) {
user.bridge.usersLock.Lock()
phoneUser, ok := user.bridge.usersByPhone[user.PhoneID]
if ok && user == phoneUser {
delete(user.bridge.usersByPhone, user.PhoneID)
}
user.bridge.usersLock.Unlock()
user.BridgeState.Send(state)
}
func (br *GMBridge) GetAllUsersWithSession() []*User {
return br.loadManyUsers(br.DB.User.GetAllWithSession)
}
@ -237,12 +205,6 @@ func (br *GMBridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User {
}
user := br.NewUser(dbUser)
br.usersByMXID[user.MXID] = user
if user.Session != nil && user.PhoneID != "" {
br.usersByPhone[user.PhoneID] = user
} else {
user.Session = nil
user.PhoneID = ""
}
if len(user.ManagementRoom) > 0 {
br.managementRooms[user.ManagementRoom] = user
}
@ -545,7 +507,6 @@ func (user *User) HasSession() bool {
func (user *User) DeleteSession() {
user.Session = nil
user.SelfParticipantIDs = []string{}
user.PhoneID = ""
err := user.Update(context.TODO())
if err != nil {
user.zlog.Err(err).Msg("Failed to delete session from database")
@ -624,8 +585,14 @@ func (user *User) syncHandleEvent(event any) {
user.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnected})
case *events.PairSuccessful:
user.Session = user.Client.AuthData
if user.PhoneID != "" && user.PhoneID != v.GetMobile().GetSourceID() {
user.zlog.Warn().
Str("old_phone_id", user.PhoneID).
Str("new_phone_id", v.GetMobile().GetSourceID()).
Msg("Phone ID changed, resetting state")
user.ResetState()
}
user.PhoneID = v.GetMobile().GetSourceID()
user.addToPhoneMap()
err := user.Update(context.TODO())
if err != nil {
user.zlog.Err(err).Msg("Failed to update session in database")
@ -660,6 +627,23 @@ func (user *User) syncHandleEvent(event any) {
}
}
func (user *User) ResetState() {
portals := user.bridge.GetAllPortalsForUser(user.RowID)
user.zlog.Debug().Int("portal_count", len(portals)).Msg("Deleting portals")
for _, portal := range portals {
portal.Delete()
}
user.bridge.DeleteAllPuppetsForUser(user.RowID)
user.PhoneID = ""
go func() {
user.zlog.Debug().Msg("Cleaning up portal rooms in background")
for _, portal := range portals {
portal.Cleanup()
}
user.zlog.Debug().Msg("Finished cleaning up portals")
}()
}
func (user *User) aggressiveSetActive() {
sleepTimes := []int{5, 10, 30}
for i := 0; i < 3; i++ {
@ -798,9 +782,9 @@ func (user *User) Logout(state status.BridgeState, unpair bool) (logoutOK bool)
logoutOK = true
}
}
user.removeFromPhoneMap(state)
user.DeleteConnection()
user.DeleteSession()
user.BridgeState.Send(state)
return
}
@ -821,7 +805,7 @@ func (user *User) syncConversation(v *gmproto.Conversation, source string) {
case gmproto.ConversationStatus_DELETED:
log.Info().Msg("Got delete event, cleaning up portal")
portal.Delete()
portal.Cleanup(false)
portal.Cleanup()
default:
if v.Participants == nil {
log.Debug().Msg("Not syncing conversation with nil participants")