Add option to get last code immediately in provisioning API
This commit is contained in:
parent
30bfa14141
commit
ecabb22407
3 changed files with 16 additions and 5 deletions
|
@ -17,7 +17,6 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/skip2/go-qrcode"
|
||||
|
@ -88,7 +87,7 @@ func fnLogin(ce *WrappedCommandEvent) {
|
|||
return
|
||||
}
|
||||
|
||||
ch, err := ce.User.Login(context.Background(), 6)
|
||||
ch, err := ce.User.Login(6)
|
||||
if err != nil {
|
||||
ce.ZLog.Err(err).Msg("Failed to start login")
|
||||
ce.Reply("Failed to start login: %v", err)
|
||||
|
|
|
@ -313,7 +313,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
ch, err := user.Login(context.Background(), 5)
|
||||
ch, err := user.Login(5)
|
||||
if err != nil && !errors.Is(err, ErrLoginInProgress) {
|
||||
log.Err(err).Msg("Failed to start login via provisioning API")
|
||||
jsonResponse(w, http.StatusInternalServerError, Error{
|
||||
|
@ -335,7 +335,10 @@ Loop:
|
|||
break Loop
|
||||
}
|
||||
}
|
||||
if !hasItem {
|
||||
if !hasItem && r.URL.Query().Get("return_immediately") == "true" && user.lastQRCode != "" {
|
||||
log.Debug().Msg("Nothing in QR channel, returning last code immediately")
|
||||
item.qr = user.lastQRCode
|
||||
} else if !hasItem {
|
||||
log.Debug().Msg("Nothing in QR channel, waiting for next item")
|
||||
select {
|
||||
case item = <-ch:
|
||||
|
|
11
user.go
11
user.go
|
@ -76,6 +76,8 @@ type User struct {
|
|||
pairSuccessChan chan struct{}
|
||||
ongoingLoginChan <-chan qrChannelItem
|
||||
loginChanReadLock sync.Mutex
|
||||
lastQRCode string
|
||||
cancelLogin func()
|
||||
|
||||
DoublePuppetIntent *appservice.IntentAPI
|
||||
}
|
||||
|
@ -414,7 +416,7 @@ func (qci qrChannelItem) IsEmpty() bool {
|
|||
return !qci.success && qci.qr == "" && qci.err == nil
|
||||
}
|
||||
|
||||
func (user *User) Login(ctx context.Context, maxAttempts int) (<-chan qrChannelItem, error) {
|
||||
func (user *User) Login(maxAttempts int) (<-chan qrChannelItem, error) {
|
||||
user.connLock.Lock()
|
||||
defer user.connLock.Unlock()
|
||||
if user.Session != nil {
|
||||
|
@ -437,8 +439,11 @@ func (user *User) Login(ctx context.Context, maxAttempts int) (<-chan qrChannelI
|
|||
}
|
||||
Segment.Track(user.MXID, "$login_start")
|
||||
ch := make(chan qrChannelItem, maxAttempts+2)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
user.cancelLogin = cancel
|
||||
user.ongoingLoginChan = ch
|
||||
ch <- qrChannelItem{qr: qr}
|
||||
user.lastQRCode = qr
|
||||
go func() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
success := false
|
||||
|
@ -450,8 +455,11 @@ func (user *User) Login(ctx context.Context, maxAttempts int) (<-chan qrChannelI
|
|||
}
|
||||
user.pairSuccessChan = nil
|
||||
user.ongoingLoginChan = nil
|
||||
user.lastQRCode = ""
|
||||
close(ch)
|
||||
user.loginInProgress.Store(false)
|
||||
cancel()
|
||||
user.cancelLogin = nil
|
||||
}()
|
||||
for {
|
||||
maxAttempts--
|
||||
|
@ -470,6 +478,7 @@ func (user *User) Login(ctx context.Context, maxAttempts int) (<-chan qrChannelI
|
|||
return
|
||||
}
|
||||
ch <- qrChannelItem{qr: qr}
|
||||
user.lastQRCode = qr
|
||||
case <-pairSuccessChan:
|
||||
ch <- qrChannelItem{success: true}
|
||||
success = true
|
||||
|
|
Loading…
Reference in a new issue