Add context for google login and cancel if HTTP request is cancelled
This commit is contained in:
parent
7db7fdf20b
commit
59b3b7d0ec
7 changed files with 55 additions and 33 deletions
|
@ -176,7 +176,7 @@ func fnLoginGoogleCookies(ce *WrappedCommandEvent) {
|
|||
return
|
||||
}
|
||||
ce.Redact()
|
||||
err = ce.User.LoginGoogle(cookies, func(emoji string) {
|
||||
err = ce.User.LoginGoogle(ce.Ctx, cookies, func(emoji string) {
|
||||
ce.Reply(emoji)
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
@ -2,6 +2,7 @@ package main
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -56,7 +57,7 @@ func main() {
|
|||
cli = libgm.NewClient(&sess, log)
|
||||
cli.SetEventHandler(evtHandler)
|
||||
if doLogin {
|
||||
err = cli.DoGaiaPairing(func(emoji string) {
|
||||
err = cli.DoGaiaPairing(context.TODO(), func(emoji string) {
|
||||
fmt.Println(emoji)
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
@ -24,6 +24,11 @@ const ContentTypeProtobuf = "application/x-protobuf"
|
|||
const ContentTypePBLite = "application/json+protobuf"
|
||||
|
||||
func (c *Client) makeProtobufHTTPRequest(url string, data proto.Message, contentType string) (*http.Response, error) {
|
||||
ctx := c.Logger.WithContext(context.TODO())
|
||||
return c.makeProtobufHTTPRequestContext(ctx, url, data, contentType)
|
||||
}
|
||||
|
||||
func (c *Client) makeProtobufHTTPRequestContext(ctx context.Context, url string, data proto.Message, contentType string) (*http.Response, error) {
|
||||
var body []byte
|
||||
var err error
|
||||
switch contentType {
|
||||
|
@ -37,7 +42,6 @@ func (c *Client) makeProtobufHTTPRequest(url string, data proto.Message, content
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx := c.Logger.WithContext(context.TODO())
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package libgm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
|
@ -61,15 +62,15 @@ func (c *Client) baseSignInGaiaPayload() *gmproto.SignInGaiaRequest {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *Client) signInGaiaInitial() (*gmproto.SignInGaiaResponse, error) {
|
||||
func (c *Client) signInGaiaInitial(ctx context.Context) (*gmproto.SignInGaiaResponse, error) {
|
||||
payload := c.baseSignInGaiaPayload()
|
||||
payload.UnknownInt3 = 1
|
||||
return typedHTTPResponse[*gmproto.SignInGaiaResponse](
|
||||
c.makeProtobufHTTPRequest(util.SignInGaiaURL, payload, ContentTypePBLite),
|
||||
c.makeProtobufHTTPRequestContext(ctx, util.SignInGaiaURL, payload, ContentTypePBLite),
|
||||
)
|
||||
}
|
||||
|
||||
func (c *Client) signInGaiaGetToken() (*gmproto.SignInGaiaResponse, error) {
|
||||
func (c *Client) signInGaiaGetToken(ctx context.Context) (*gmproto.SignInGaiaResponse, error) {
|
||||
key, err := x509.MarshalPKIXPublicKey(c.AuthData.RefreshKey.GetPublicKey())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -80,7 +81,7 @@ func (c *Client) signInGaiaGetToken() (*gmproto.SignInGaiaResponse, error) {
|
|||
SomeData: key,
|
||||
}
|
||||
resp, err := typedHTTPResponse[*gmproto.SignInGaiaResponse](
|
||||
c.makeProtobufHTTPRequest(util.SignInGaiaURL, payload, ContentTypePBLite),
|
||||
c.makeProtobufHTTPRequestContext(ctx, util.SignInGaiaURL, payload, ContentTypePBLite),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -242,11 +243,11 @@ var (
|
|||
ErrPairingTimeout = errors.New("pairing timed out")
|
||||
)
|
||||
|
||||
func (c *Client) DoGaiaPairing(emojiCallback func(string)) error {
|
||||
func (c *Client) DoGaiaPairing(ctx context.Context, emojiCallback func(string)) error {
|
||||
if len(c.AuthData.Cookies) == 0 {
|
||||
return ErrNoCookies
|
||||
}
|
||||
sigResp, err := c.signInGaiaGetToken()
|
||||
sigResp, err := c.signInGaiaGetToken(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare gaia pairing: %w", err)
|
||||
}
|
||||
|
@ -272,7 +273,7 @@ func (c *Client) DoGaiaPairing(emojiCallback func(string)) error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare pairing payloads: %w", err)
|
||||
}
|
||||
serverInit, err := c.sendGaiaPairingMessage(ps, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_INIT, clientInit)
|
||||
serverInit, err := c.sendGaiaPairingMessage(ctx, ps, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_INIT, clientInit)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send client init: %w", err)
|
||||
}
|
||||
|
@ -281,7 +282,7 @@ func (c *Client) DoGaiaPairing(emojiCallback func(string)) error {
|
|||
return fmt.Errorf("error processing server init: %w", err)
|
||||
}
|
||||
emojiCallback(pairingEmoji)
|
||||
finishResp, err := c.sendGaiaPairingMessage(ps, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_FINISHED, clientFinish)
|
||||
finishResp, err := c.sendGaiaPairingMessage(ctx, ps, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_FINISHED, clientFinish)
|
||||
if finishResp.GetFinishErrorType() != 0 {
|
||||
switch finishResp.GetFinishErrorCode() {
|
||||
case 5:
|
||||
|
@ -312,8 +313,8 @@ func (c *Client) DoGaiaPairing(emojiCallback func(string)) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) sendGaiaPairingMessage(sess PairingSession, action gmproto.ActionType, msg []byte) (*gmproto.GaiaPairingResponseContainer, error) {
|
||||
resp, err := c.sessionHandler.sendMessageWithParams(SendMessageParams{
|
||||
func (c *Client) sendGaiaPairingMessage(ctx context.Context, sess PairingSession, action gmproto.ActionType, msg []byte) (*gmproto.GaiaPairingResponseContainer, error) {
|
||||
respCh, err := c.sessionHandler.sendAsyncMessage(SendMessageParams{
|
||||
Action: action,
|
||||
Data: &gmproto.GaiaPairingRequestContainer{
|
||||
PairingAttemptID: sess.UUID.String(),
|
||||
|
@ -324,18 +325,21 @@ func (c *Client) sendGaiaPairingMessage(sess PairingSession, action gmproto.Acti
|
|||
DontEncrypt: true,
|
||||
CustomTTL: (300 * time.Second).Microseconds(),
|
||||
MessageType: gmproto.MessageType_GAIA_2,
|
||||
|
||||
NoPingOnTimeout: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var respDat gmproto.GaiaPairingResponseContainer
|
||||
err = proto.Unmarshal(resp.Message.UnencryptedData, &respDat)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
select {
|
||||
case resp := <-respCh:
|
||||
var respDat gmproto.GaiaPairingResponseContainer
|
||||
err = proto.Unmarshal(resp.Message.UnencryptedData, &respDat)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &respDat, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
return &respDat, nil
|
||||
}
|
||||
|
||||
func (c *Client) UnpairGaia() error {
|
||||
|
@ -344,6 +348,5 @@ func (c *Client) UnpairGaia() error {
|
|||
Data: &gmproto.RevokeGaiaPairingRequest{
|
||||
PairingAttemptID: c.AuthData.PairingID.String(),
|
||||
},
|
||||
NoPingOnTimeout: true,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -141,10 +141,6 @@ func (s *SessionHandler) sendMessageWithParams(params SendMessageParams) (*Incom
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if params.NoPingOnTimeout {
|
||||
return <-ch, nil
|
||||
}
|
||||
|
||||
select {
|
||||
case resp := <-ch:
|
||||
return resp, nil
|
||||
|
@ -175,8 +171,6 @@ type SendMessageParams struct {
|
|||
CustomTTL int64
|
||||
DontEncrypt bool
|
||||
MessageType gmproto.MessageType
|
||||
|
||||
NoPingOnTimeout bool
|
||||
}
|
||||
|
||||
func (s *SessionHandler) buildMessage(params SendMessageParams) (string, proto.Message, error) {
|
||||
|
|
|
@ -364,7 +364,7 @@ func (prov *ProvisioningAPI) GoogleLoginWait(w http.ResponseWriter, r *http.Requ
|
|||
|
||||
log := prov.zlog.With().Str("user_id", user.MXID.String()).Str("endpoint", "login").Logger()
|
||||
|
||||
err := user.AsyncLoginGoogleWait()
|
||||
err := user.AsyncLoginGoogleWait(r.Context())
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to wait for google login")
|
||||
switch {
|
||||
|
@ -388,6 +388,12 @@ func (prov *ProvisioningAPI) GoogleLoginWait(w http.ResponseWriter, r *http.Requ
|
|||
Error: err.Error(),
|
||||
ErrCode: "timeout",
|
||||
})
|
||||
case errors.Is(err, context.Canceled):
|
||||
// This should only happen if the client already disconnected, so clients will probably never see this error code.
|
||||
jsonResponse(w, http.StatusBadRequest, Error{
|
||||
Error: err.Error(),
|
||||
ErrCode: "context-cancelled",
|
||||
})
|
||||
default:
|
||||
jsonResponse(w, http.StatusInternalServerError, Error{
|
||||
Error: "Failed to finish login",
|
||||
|
|
24
user.go
24
user.go
|
@ -484,8 +484,13 @@ func (user *User) AsyncLoginGoogleStart(cookies map[string]string) (outEmoji str
|
|||
outEmoji = emoji
|
||||
initialWait.Done()
|
||||
}
|
||||
var ctx context.Context
|
||||
ctx, user.cancelLogin = context.WithCancel(context.Background())
|
||||
go func() {
|
||||
err := user.LoginGoogle(cookies, callback)
|
||||
defer func() {
|
||||
user.cancelLogin = nil
|
||||
}()
|
||||
err := user.LoginGoogle(ctx, cookies, callback)
|
||||
if !callbackDone {
|
||||
user.zlog.Err(err).Msg("Async google login failed before callback")
|
||||
initialWait.Done()
|
||||
|
@ -505,15 +510,24 @@ func (user *User) AsyncLoginGoogleStart(cookies map[string]string) (outEmoji str
|
|||
return
|
||||
}
|
||||
|
||||
func (user *User) AsyncLoginGoogleWait() error {
|
||||
func (user *User) AsyncLoginGoogleWait(ctx context.Context) error {
|
||||
ch := user.googleAsyncPairErrChan.Swap(nil)
|
||||
if ch == nil {
|
||||
return ErrNoLoginInProgress
|
||||
}
|
||||
return <-*ch
|
||||
select {
|
||||
case ret := <-*ch:
|
||||
return ret
|
||||
case <-ctx.Done():
|
||||
user.zlog.Err(ctx.Err()).Msg("Login wait context canceled, canceling login")
|
||||
if cancelLogin := user.cancelLogin; cancelLogin != nil {
|
||||
cancelLogin()
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) LoginGoogle(cookies map[string]string, emojiCallback func(string)) error {
|
||||
func (user *User) LoginGoogle(ctx context.Context, cookies map[string]string, emojiCallback func(string)) error {
|
||||
user.connLock.Lock()
|
||||
defer user.connLock.Unlock()
|
||||
if user.Session != nil {
|
||||
|
@ -533,7 +547,7 @@ func (user *User) LoginGoogle(cookies map[string]string, emojiCallback func(stri
|
|||
authData.Cookies = cookies
|
||||
user.createClient(authData)
|
||||
Analytics.Track(user.MXID, "$login_start")
|
||||
err := user.Client.DoGaiaPairing(emojiCallback)
|
||||
err := user.Client.DoGaiaPairing(ctx, emojiCallback)
|
||||
if err != nil {
|
||||
user.unlockedDeleteConnection()
|
||||
return err
|
||||
|
|
Loading…
Reference in a new issue