Add option to get last code immediately in provisioning API

This commit is contained in:
Tulir Asokan 2023-07-18 14:51:43 +03:00
parent 30bfa14141
commit ecabb22407
3 changed files with 16 additions and 5 deletions

View file

@ -17,7 +17,6 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"github.com/skip2/go-qrcode" "github.com/skip2/go-qrcode"
@ -88,7 +87,7 @@ func fnLogin(ce *WrappedCommandEvent) {
return return
} }
ch, err := ce.User.Login(context.Background(), 6) ch, err := ce.User.Login(6)
if err != nil { if err != nil {
ce.ZLog.Err(err).Msg("Failed to start login") ce.ZLog.Err(err).Msg("Failed to start login")
ce.Reply("Failed to start login: %v", err) ce.Reply("Failed to start login: %v", err)

View file

@ -313,7 +313,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
return return
} }
ch, err := user.Login(context.Background(), 5) ch, err := user.Login(5)
if err != nil && !errors.Is(err, ErrLoginInProgress) { if err != nil && !errors.Is(err, ErrLoginInProgress) {
log.Err(err).Msg("Failed to start login via provisioning API") log.Err(err).Msg("Failed to start login via provisioning API")
jsonResponse(w, http.StatusInternalServerError, Error{ jsonResponse(w, http.StatusInternalServerError, Error{
@ -335,7 +335,10 @@ Loop:
break Loop break Loop
} }
} }
if !hasItem { if !hasItem && r.URL.Query().Get("return_immediately") == "true" && user.lastQRCode != "" {
log.Debug().Msg("Nothing in QR channel, returning last code immediately")
item.qr = user.lastQRCode
} else if !hasItem {
log.Debug().Msg("Nothing in QR channel, waiting for next item") log.Debug().Msg("Nothing in QR channel, waiting for next item")
select { select {
case item = <-ch: case item = <-ch:

11
user.go
View file

@ -76,6 +76,8 @@ type User struct {
pairSuccessChan chan struct{} pairSuccessChan chan struct{}
ongoingLoginChan <-chan qrChannelItem ongoingLoginChan <-chan qrChannelItem
loginChanReadLock sync.Mutex loginChanReadLock sync.Mutex
lastQRCode string
cancelLogin func()
DoublePuppetIntent *appservice.IntentAPI DoublePuppetIntent *appservice.IntentAPI
} }
@ -414,7 +416,7 @@ func (qci qrChannelItem) IsEmpty() bool {
return !qci.success && qci.qr == "" && qci.err == nil return !qci.success && qci.qr == "" && qci.err == nil
} }
func (user *User) Login(ctx context.Context, maxAttempts int) (<-chan qrChannelItem, error) { func (user *User) Login(maxAttempts int) (<-chan qrChannelItem, error) {
user.connLock.Lock() user.connLock.Lock()
defer user.connLock.Unlock() defer user.connLock.Unlock()
if user.Session != nil { if user.Session != nil {
@ -437,8 +439,11 @@ func (user *User) Login(ctx context.Context, maxAttempts int) (<-chan qrChannelI
} }
Segment.Track(user.MXID, "$login_start") Segment.Track(user.MXID, "$login_start")
ch := make(chan qrChannelItem, maxAttempts+2) ch := make(chan qrChannelItem, maxAttempts+2)
ctx, cancel := context.WithCancel(context.Background())
user.cancelLogin = cancel
user.ongoingLoginChan = ch user.ongoingLoginChan = ch
ch <- qrChannelItem{qr: qr} ch <- qrChannelItem{qr: qr}
user.lastQRCode = qr
go func() { go func() {
ticker := time.NewTicker(30 * time.Second) ticker := time.NewTicker(30 * time.Second)
success := false success := false
@ -450,8 +455,11 @@ func (user *User) Login(ctx context.Context, maxAttempts int) (<-chan qrChannelI
} }
user.pairSuccessChan = nil user.pairSuccessChan = nil
user.ongoingLoginChan = nil user.ongoingLoginChan = nil
user.lastQRCode = ""
close(ch) close(ch)
user.loginInProgress.Store(false) user.loginInProgress.Store(false)
cancel()
user.cancelLogin = nil
}() }()
for { for {
maxAttempts-- maxAttempts--
@ -470,6 +478,7 @@ func (user *User) Login(ctx context.Context, maxAttempts int) (<-chan qrChannelI
return return
} }
ch <- qrChannelItem{qr: qr} ch <- qrChannelItem{qr: qr}
user.lastQRCode = qr
case <-pairSuccessChan: case <-pairSuccessChan:
ch <- qrChannelItem{success: true} ch <- qrChannelItem{success: true}
success = true success = true