Add context for google login and cancel if HTTP request is cancelled

This commit is contained in:
Tulir Asokan 2024-02-26 16:10:32 +02:00
parent 7db7fdf20b
commit 59b3b7d0ec
7 changed files with 55 additions and 33 deletions

View file

@ -176,7 +176,7 @@ func fnLoginGoogleCookies(ce *WrappedCommandEvent) {
return return
} }
ce.Redact() ce.Redact()
err = ce.User.LoginGoogle(cookies, func(emoji string) { err = ce.User.LoginGoogle(ce.Ctx, cookies, func(emoji string) {
ce.Reply(emoji) ce.Reply(emoji)
}) })
if err != nil { if err != nil {

View file

@ -2,6 +2,7 @@ package main
import ( import (
"bufio" "bufio"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -56,7 +57,7 @@ func main() {
cli = libgm.NewClient(&sess, log) cli = libgm.NewClient(&sess, log)
cli.SetEventHandler(evtHandler) cli.SetEventHandler(evtHandler)
if doLogin { if doLogin {
err = cli.DoGaiaPairing(func(emoji string) { err = cli.DoGaiaPairing(context.TODO(), func(emoji string) {
fmt.Println(emoji) fmt.Println(emoji)
}) })
if err != nil { if err != nil {

View file

@ -24,6 +24,11 @@ const ContentTypeProtobuf = "application/x-protobuf"
const ContentTypePBLite = "application/json+protobuf" const ContentTypePBLite = "application/json+protobuf"
func (c *Client) makeProtobufHTTPRequest(url string, data proto.Message, contentType string) (*http.Response, error) { func (c *Client) makeProtobufHTTPRequest(url string, data proto.Message, contentType string) (*http.Response, error) {
ctx := c.Logger.WithContext(context.TODO())
return c.makeProtobufHTTPRequestContext(ctx, url, data, contentType)
}
func (c *Client) makeProtobufHTTPRequestContext(ctx context.Context, url string, data proto.Message, contentType string) (*http.Response, error) {
var body []byte var body []byte
var err error var err error
switch contentType { switch contentType {
@ -37,7 +42,6 @@ func (c *Client) makeProtobufHTTPRequest(url string, data proto.Message, content
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctx := c.Logger.WithContext(context.TODO())
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -17,6 +17,7 @@
package libgm package libgm
import ( import (
"context"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
"crypto/rand" "crypto/rand"
@ -61,15 +62,15 @@ func (c *Client) baseSignInGaiaPayload() *gmproto.SignInGaiaRequest {
} }
} }
func (c *Client) signInGaiaInitial() (*gmproto.SignInGaiaResponse, error) { func (c *Client) signInGaiaInitial(ctx context.Context) (*gmproto.SignInGaiaResponse, error) {
payload := c.baseSignInGaiaPayload() payload := c.baseSignInGaiaPayload()
payload.UnknownInt3 = 1 payload.UnknownInt3 = 1
return typedHTTPResponse[*gmproto.SignInGaiaResponse]( return typedHTTPResponse[*gmproto.SignInGaiaResponse](
c.makeProtobufHTTPRequest(util.SignInGaiaURL, payload, ContentTypePBLite), c.makeProtobufHTTPRequestContext(ctx, util.SignInGaiaURL, payload, ContentTypePBLite),
) )
} }
func (c *Client) signInGaiaGetToken() (*gmproto.SignInGaiaResponse, error) { func (c *Client) signInGaiaGetToken(ctx context.Context) (*gmproto.SignInGaiaResponse, error) {
key, err := x509.MarshalPKIXPublicKey(c.AuthData.RefreshKey.GetPublicKey()) key, err := x509.MarshalPKIXPublicKey(c.AuthData.RefreshKey.GetPublicKey())
if err != nil { if err != nil {
return nil, err return nil, err
@ -80,7 +81,7 @@ func (c *Client) signInGaiaGetToken() (*gmproto.SignInGaiaResponse, error) {
SomeData: key, SomeData: key,
} }
resp, err := typedHTTPResponse[*gmproto.SignInGaiaResponse]( resp, err := typedHTTPResponse[*gmproto.SignInGaiaResponse](
c.makeProtobufHTTPRequest(util.SignInGaiaURL, payload, ContentTypePBLite), c.makeProtobufHTTPRequestContext(ctx, util.SignInGaiaURL, payload, ContentTypePBLite),
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -242,11 +243,11 @@ var (
ErrPairingTimeout = errors.New("pairing timed out") ErrPairingTimeout = errors.New("pairing timed out")
) )
func (c *Client) DoGaiaPairing(emojiCallback func(string)) error { func (c *Client) DoGaiaPairing(ctx context.Context, emojiCallback func(string)) error {
if len(c.AuthData.Cookies) == 0 { if len(c.AuthData.Cookies) == 0 {
return ErrNoCookies return ErrNoCookies
} }
sigResp, err := c.signInGaiaGetToken() sigResp, err := c.signInGaiaGetToken(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to prepare gaia pairing: %w", err) return fmt.Errorf("failed to prepare gaia pairing: %w", err)
} }
@ -272,7 +273,7 @@ func (c *Client) DoGaiaPairing(emojiCallback func(string)) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to prepare pairing payloads: %w", err) return fmt.Errorf("failed to prepare pairing payloads: %w", err)
} }
serverInit, err := c.sendGaiaPairingMessage(ps, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_INIT, clientInit) serverInit, err := c.sendGaiaPairingMessage(ctx, ps, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_INIT, clientInit)
if err != nil { if err != nil {
return fmt.Errorf("failed to send client init: %w", err) return fmt.Errorf("failed to send client init: %w", err)
} }
@ -281,7 +282,7 @@ func (c *Client) DoGaiaPairing(emojiCallback func(string)) error {
return fmt.Errorf("error processing server init: %w", err) return fmt.Errorf("error processing server init: %w", err)
} }
emojiCallback(pairingEmoji) emojiCallback(pairingEmoji)
finishResp, err := c.sendGaiaPairingMessage(ps, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_FINISHED, clientFinish) finishResp, err := c.sendGaiaPairingMessage(ctx, ps, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_FINISHED, clientFinish)
if finishResp.GetFinishErrorType() != 0 { if finishResp.GetFinishErrorType() != 0 {
switch finishResp.GetFinishErrorCode() { switch finishResp.GetFinishErrorCode() {
case 5: case 5:
@ -312,8 +313,8 @@ func (c *Client) DoGaiaPairing(emojiCallback func(string)) error {
return nil return nil
} }
func (c *Client) sendGaiaPairingMessage(sess PairingSession, action gmproto.ActionType, msg []byte) (*gmproto.GaiaPairingResponseContainer, error) { func (c *Client) sendGaiaPairingMessage(ctx context.Context, sess PairingSession, action gmproto.ActionType, msg []byte) (*gmproto.GaiaPairingResponseContainer, error) {
resp, err := c.sessionHandler.sendMessageWithParams(SendMessageParams{ respCh, err := c.sessionHandler.sendAsyncMessage(SendMessageParams{
Action: action, Action: action,
Data: &gmproto.GaiaPairingRequestContainer{ Data: &gmproto.GaiaPairingRequestContainer{
PairingAttemptID: sess.UUID.String(), PairingAttemptID: sess.UUID.String(),
@ -324,18 +325,21 @@ func (c *Client) sendGaiaPairingMessage(sess PairingSession, action gmproto.Acti
DontEncrypt: true, DontEncrypt: true,
CustomTTL: (300 * time.Second).Microseconds(), CustomTTL: (300 * time.Second).Microseconds(),
MessageType: gmproto.MessageType_GAIA_2, MessageType: gmproto.MessageType_GAIA_2,
NoPingOnTimeout: true,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
select {
case resp := <-respCh:
var respDat gmproto.GaiaPairingResponseContainer var respDat gmproto.GaiaPairingResponseContainer
err = proto.Unmarshal(resp.Message.UnencryptedData, &respDat) err = proto.Unmarshal(resp.Message.UnencryptedData, &respDat)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &respDat, nil return &respDat, nil
case <-ctx.Done():
return nil, ctx.Err()
}
} }
func (c *Client) UnpairGaia() error { func (c *Client) UnpairGaia() error {
@ -344,6 +348,5 @@ func (c *Client) UnpairGaia() error {
Data: &gmproto.RevokeGaiaPairingRequest{ Data: &gmproto.RevokeGaiaPairingRequest{
PairingAttemptID: c.AuthData.PairingID.String(), PairingAttemptID: c.AuthData.PairingID.String(),
}, },
NoPingOnTimeout: true,
}) })
} }

View file

@ -141,10 +141,6 @@ func (s *SessionHandler) sendMessageWithParams(params SendMessageParams) (*Incom
return nil, err return nil, err
} }
if params.NoPingOnTimeout {
return <-ch, nil
}
select { select {
case resp := <-ch: case resp := <-ch:
return resp, nil return resp, nil
@ -175,8 +171,6 @@ type SendMessageParams struct {
CustomTTL int64 CustomTTL int64
DontEncrypt bool DontEncrypt bool
MessageType gmproto.MessageType MessageType gmproto.MessageType
NoPingOnTimeout bool
} }
func (s *SessionHandler) buildMessage(params SendMessageParams) (string, proto.Message, error) { func (s *SessionHandler) buildMessage(params SendMessageParams) (string, proto.Message, error) {

View file

@ -364,7 +364,7 @@ func (prov *ProvisioningAPI) GoogleLoginWait(w http.ResponseWriter, r *http.Requ
log := prov.zlog.With().Str("user_id", user.MXID.String()).Str("endpoint", "login").Logger() log := prov.zlog.With().Str("user_id", user.MXID.String()).Str("endpoint", "login").Logger()
err := user.AsyncLoginGoogleWait() err := user.AsyncLoginGoogleWait(r.Context())
if err != nil { if err != nil {
log.Err(err).Msg("Failed to wait for google login") log.Err(err).Msg("Failed to wait for google login")
switch { switch {
@ -388,6 +388,12 @@ func (prov *ProvisioningAPI) GoogleLoginWait(w http.ResponseWriter, r *http.Requ
Error: err.Error(), Error: err.Error(),
ErrCode: "timeout", ErrCode: "timeout",
}) })
case errors.Is(err, context.Canceled):
// This should only happen if the client already disconnected, so clients will probably never see this error code.
jsonResponse(w, http.StatusBadRequest, Error{
Error: err.Error(),
ErrCode: "context-cancelled",
})
default: default:
jsonResponse(w, http.StatusInternalServerError, Error{ jsonResponse(w, http.StatusInternalServerError, Error{
Error: "Failed to finish login", Error: "Failed to finish login",

24
user.go
View file

@ -484,8 +484,13 @@ func (user *User) AsyncLoginGoogleStart(cookies map[string]string) (outEmoji str
outEmoji = emoji outEmoji = emoji
initialWait.Done() initialWait.Done()
} }
var ctx context.Context
ctx, user.cancelLogin = context.WithCancel(context.Background())
go func() { go func() {
err := user.LoginGoogle(cookies, callback) defer func() {
user.cancelLogin = nil
}()
err := user.LoginGoogle(ctx, cookies, callback)
if !callbackDone { if !callbackDone {
user.zlog.Err(err).Msg("Async google login failed before callback") user.zlog.Err(err).Msg("Async google login failed before callback")
initialWait.Done() initialWait.Done()
@ -505,15 +510,24 @@ func (user *User) AsyncLoginGoogleStart(cookies map[string]string) (outEmoji str
return return
} }
func (user *User) AsyncLoginGoogleWait() error { func (user *User) AsyncLoginGoogleWait(ctx context.Context) error {
ch := user.googleAsyncPairErrChan.Swap(nil) ch := user.googleAsyncPairErrChan.Swap(nil)
if ch == nil { if ch == nil {
return ErrNoLoginInProgress return ErrNoLoginInProgress
} }
return <-*ch select {
case ret := <-*ch:
return ret
case <-ctx.Done():
user.zlog.Err(ctx.Err()).Msg("Login wait context canceled, canceling login")
if cancelLogin := user.cancelLogin; cancelLogin != nil {
cancelLogin()
}
return ctx.Err()
}
} }
func (user *User) LoginGoogle(cookies map[string]string, emojiCallback func(string)) error { func (user *User) LoginGoogle(ctx context.Context, cookies map[string]string, emojiCallback func(string)) error {
user.connLock.Lock() user.connLock.Lock()
defer user.connLock.Unlock() defer user.connLock.Unlock()
if user.Session != nil { if user.Session != nil {
@ -533,7 +547,7 @@ func (user *User) LoginGoogle(cookies map[string]string, emojiCallback func(stri
authData.Cookies = cookies authData.Cookies = cookies
user.createClient(authData) user.createClient(authData)
Analytics.Track(user.MXID, "$login_start") Analytics.Track(user.MXID, "$login_start")
err := user.Client.DoGaiaPairing(emojiCallback) err := user.Client.DoGaiaPairing(ctx, emojiCallback)
if err != nil { if err != nil {
user.unlockedDeleteConnection() user.unlockedDeleteConnection()
return err return err