diff --git a/commands.go b/commands.go index 930e91b..c4f7283 100644 --- a/commands.go +++ b/commands.go @@ -17,7 +17,6 @@ package main import ( - "context" "fmt" "github.com/skip2/go-qrcode" @@ -88,7 +87,7 @@ func fnLogin(ce *WrappedCommandEvent) { return } - ch, err := ce.User.Login(context.Background(), 6) + ch, err := ce.User.Login(6) if err != nil { ce.ZLog.Err(err).Msg("Failed to start login") ce.Reply("Failed to start login: %v", err) diff --git a/provisioning.go b/provisioning.go index 17613d7..ba642b5 100644 --- a/provisioning.go +++ b/provisioning.go @@ -313,7 +313,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { return } - ch, err := user.Login(context.Background(), 5) + ch, err := user.Login(5) if err != nil && !errors.Is(err, ErrLoginInProgress) { log.Err(err).Msg("Failed to start login via provisioning API") jsonResponse(w, http.StatusInternalServerError, Error{ @@ -335,7 +335,10 @@ Loop: break Loop } } - if !hasItem { + if !hasItem && r.URL.Query().Get("return_immediately") == "true" && user.lastQRCode != "" { + log.Debug().Msg("Nothing in QR channel, returning last code immediately") + item.qr = user.lastQRCode + } else if !hasItem { log.Debug().Msg("Nothing in QR channel, waiting for next item") select { case item = <-ch: diff --git a/user.go b/user.go index ccf238f..44f41f3 100644 --- a/user.go +++ b/user.go @@ -76,6 +76,8 @@ type User struct { pairSuccessChan chan struct{} ongoingLoginChan <-chan qrChannelItem loginChanReadLock sync.Mutex + lastQRCode string + cancelLogin func() DoublePuppetIntent *appservice.IntentAPI } @@ -414,7 +416,7 @@ func (qci qrChannelItem) IsEmpty() bool { return !qci.success && qci.qr == "" && qci.err == nil } -func (user *User) Login(ctx context.Context, maxAttempts int) (<-chan qrChannelItem, error) { +func (user *User) Login(maxAttempts int) (<-chan qrChannelItem, error) { user.connLock.Lock() defer user.connLock.Unlock() if user.Session != nil { @@ -437,8 +439,11 @@ func (user *User) Login(ctx context.Context, maxAttempts int) (<-chan qrChannelI } Segment.Track(user.MXID, "$login_start") ch := make(chan qrChannelItem, maxAttempts+2) + ctx, cancel := context.WithCancel(context.Background()) + user.cancelLogin = cancel user.ongoingLoginChan = ch ch <- qrChannelItem{qr: qr} + user.lastQRCode = qr go func() { ticker := time.NewTicker(30 * time.Second) success := false @@ -450,8 +455,11 @@ func (user *User) Login(ctx context.Context, maxAttempts int) (<-chan qrChannelI } user.pairSuccessChan = nil user.ongoingLoginChan = nil + user.lastQRCode = "" close(ch) user.loginInProgress.Store(false) + cancel() + user.cancelLogin = nil }() for { maxAttempts-- @@ -470,6 +478,7 @@ func (user *User) Login(ctx context.Context, maxAttempts int) (<-chan qrChannelI return } ch <- qrChannelItem{qr: qr} + user.lastQRCode = qr case <-pairSuccessChan: ch <- qrChannelItem{success: true} success = true