Add context for google login and cancel if HTTP request is cancelled
This commit is contained in:
parent
7db7fdf20b
commit
59b3b7d0ec
7 changed files with 55 additions and 33 deletions
|
@ -176,7 +176,7 @@ func fnLoginGoogleCookies(ce *WrappedCommandEvent) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ce.Redact()
|
ce.Redact()
|
||||||
err = ce.User.LoginGoogle(cookies, func(emoji string) {
|
err = ce.User.LoginGoogle(ce.Ctx, cookies, func(emoji string) {
|
||||||
ce.Reply(emoji)
|
ce.Reply(emoji)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -2,6 +2,7 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -56,7 +57,7 @@ func main() {
|
||||||
cli = libgm.NewClient(&sess, log)
|
cli = libgm.NewClient(&sess, log)
|
||||||
cli.SetEventHandler(evtHandler)
|
cli.SetEventHandler(evtHandler)
|
||||||
if doLogin {
|
if doLogin {
|
||||||
err = cli.DoGaiaPairing(func(emoji string) {
|
err = cli.DoGaiaPairing(context.TODO(), func(emoji string) {
|
||||||
fmt.Println(emoji)
|
fmt.Println(emoji)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -24,6 +24,11 @@ const ContentTypeProtobuf = "application/x-protobuf"
|
||||||
const ContentTypePBLite = "application/json+protobuf"
|
const ContentTypePBLite = "application/json+protobuf"
|
||||||
|
|
||||||
func (c *Client) makeProtobufHTTPRequest(url string, data proto.Message, contentType string) (*http.Response, error) {
|
func (c *Client) makeProtobufHTTPRequest(url string, data proto.Message, contentType string) (*http.Response, error) {
|
||||||
|
ctx := c.Logger.WithContext(context.TODO())
|
||||||
|
return c.makeProtobufHTTPRequestContext(ctx, url, data, contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) makeProtobufHTTPRequestContext(ctx context.Context, url string, data proto.Message, contentType string) (*http.Response, error) {
|
||||||
var body []byte
|
var body []byte
|
||||||
var err error
|
var err error
|
||||||
switch contentType {
|
switch contentType {
|
||||||
|
@ -37,7 +42,6 @@ func (c *Client) makeProtobufHTTPRequest(url string, data proto.Message, content
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ctx := c.Logger.WithContext(context.TODO())
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package libgm
|
package libgm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
@ -61,15 +62,15 @@ func (c *Client) baseSignInGaiaPayload() *gmproto.SignInGaiaRequest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) signInGaiaInitial() (*gmproto.SignInGaiaResponse, error) {
|
func (c *Client) signInGaiaInitial(ctx context.Context) (*gmproto.SignInGaiaResponse, error) {
|
||||||
payload := c.baseSignInGaiaPayload()
|
payload := c.baseSignInGaiaPayload()
|
||||||
payload.UnknownInt3 = 1
|
payload.UnknownInt3 = 1
|
||||||
return typedHTTPResponse[*gmproto.SignInGaiaResponse](
|
return typedHTTPResponse[*gmproto.SignInGaiaResponse](
|
||||||
c.makeProtobufHTTPRequest(util.SignInGaiaURL, payload, ContentTypePBLite),
|
c.makeProtobufHTTPRequestContext(ctx, util.SignInGaiaURL, payload, ContentTypePBLite),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) signInGaiaGetToken() (*gmproto.SignInGaiaResponse, error) {
|
func (c *Client) signInGaiaGetToken(ctx context.Context) (*gmproto.SignInGaiaResponse, error) {
|
||||||
key, err := x509.MarshalPKIXPublicKey(c.AuthData.RefreshKey.GetPublicKey())
|
key, err := x509.MarshalPKIXPublicKey(c.AuthData.RefreshKey.GetPublicKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -80,7 +81,7 @@ func (c *Client) signInGaiaGetToken() (*gmproto.SignInGaiaResponse, error) {
|
||||||
SomeData: key,
|
SomeData: key,
|
||||||
}
|
}
|
||||||
resp, err := typedHTTPResponse[*gmproto.SignInGaiaResponse](
|
resp, err := typedHTTPResponse[*gmproto.SignInGaiaResponse](
|
||||||
c.makeProtobufHTTPRequest(util.SignInGaiaURL, payload, ContentTypePBLite),
|
c.makeProtobufHTTPRequestContext(ctx, util.SignInGaiaURL, payload, ContentTypePBLite),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -242,11 +243,11 @@ var (
|
||||||
ErrPairingTimeout = errors.New("pairing timed out")
|
ErrPairingTimeout = errors.New("pairing timed out")
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Client) DoGaiaPairing(emojiCallback func(string)) error {
|
func (c *Client) DoGaiaPairing(ctx context.Context, emojiCallback func(string)) error {
|
||||||
if len(c.AuthData.Cookies) == 0 {
|
if len(c.AuthData.Cookies) == 0 {
|
||||||
return ErrNoCookies
|
return ErrNoCookies
|
||||||
}
|
}
|
||||||
sigResp, err := c.signInGaiaGetToken()
|
sigResp, err := c.signInGaiaGetToken(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to prepare gaia pairing: %w", err)
|
return fmt.Errorf("failed to prepare gaia pairing: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -272,7 +273,7 @@ func (c *Client) DoGaiaPairing(emojiCallback func(string)) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to prepare pairing payloads: %w", err)
|
return fmt.Errorf("failed to prepare pairing payloads: %w", err)
|
||||||
}
|
}
|
||||||
serverInit, err := c.sendGaiaPairingMessage(ps, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_INIT, clientInit)
|
serverInit, err := c.sendGaiaPairingMessage(ctx, ps, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_INIT, clientInit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to send client init: %w", err)
|
return fmt.Errorf("failed to send client init: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -281,7 +282,7 @@ func (c *Client) DoGaiaPairing(emojiCallback func(string)) error {
|
||||||
return fmt.Errorf("error processing server init: %w", err)
|
return fmt.Errorf("error processing server init: %w", err)
|
||||||
}
|
}
|
||||||
emojiCallback(pairingEmoji)
|
emojiCallback(pairingEmoji)
|
||||||
finishResp, err := c.sendGaiaPairingMessage(ps, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_FINISHED, clientFinish)
|
finishResp, err := c.sendGaiaPairingMessage(ctx, ps, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_FINISHED, clientFinish)
|
||||||
if finishResp.GetFinishErrorType() != 0 {
|
if finishResp.GetFinishErrorType() != 0 {
|
||||||
switch finishResp.GetFinishErrorCode() {
|
switch finishResp.GetFinishErrorCode() {
|
||||||
case 5:
|
case 5:
|
||||||
|
@ -312,8 +313,8 @@ func (c *Client) DoGaiaPairing(emojiCallback func(string)) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) sendGaiaPairingMessage(sess PairingSession, action gmproto.ActionType, msg []byte) (*gmproto.GaiaPairingResponseContainer, error) {
|
func (c *Client) sendGaiaPairingMessage(ctx context.Context, sess PairingSession, action gmproto.ActionType, msg []byte) (*gmproto.GaiaPairingResponseContainer, error) {
|
||||||
resp, err := c.sessionHandler.sendMessageWithParams(SendMessageParams{
|
respCh, err := c.sessionHandler.sendAsyncMessage(SendMessageParams{
|
||||||
Action: action,
|
Action: action,
|
||||||
Data: &gmproto.GaiaPairingRequestContainer{
|
Data: &gmproto.GaiaPairingRequestContainer{
|
||||||
PairingAttemptID: sess.UUID.String(),
|
PairingAttemptID: sess.UUID.String(),
|
||||||
|
@ -324,18 +325,21 @@ func (c *Client) sendGaiaPairingMessage(sess PairingSession, action gmproto.Acti
|
||||||
DontEncrypt: true,
|
DontEncrypt: true,
|
||||||
CustomTTL: (300 * time.Second).Microseconds(),
|
CustomTTL: (300 * time.Second).Microseconds(),
|
||||||
MessageType: gmproto.MessageType_GAIA_2,
|
MessageType: gmproto.MessageType_GAIA_2,
|
||||||
|
|
||||||
NoPingOnTimeout: true,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
select {
|
||||||
|
case resp := <-respCh:
|
||||||
var respDat gmproto.GaiaPairingResponseContainer
|
var respDat gmproto.GaiaPairingResponseContainer
|
||||||
err = proto.Unmarshal(resp.Message.UnencryptedData, &respDat)
|
err = proto.Unmarshal(resp.Message.UnencryptedData, &respDat)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &respDat, nil
|
return &respDat, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) UnpairGaia() error {
|
func (c *Client) UnpairGaia() error {
|
||||||
|
@ -344,6 +348,5 @@ func (c *Client) UnpairGaia() error {
|
||||||
Data: &gmproto.RevokeGaiaPairingRequest{
|
Data: &gmproto.RevokeGaiaPairingRequest{
|
||||||
PairingAttemptID: c.AuthData.PairingID.String(),
|
PairingAttemptID: c.AuthData.PairingID.String(),
|
||||||
},
|
},
|
||||||
NoPingOnTimeout: true,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -141,10 +141,6 @@ func (s *SessionHandler) sendMessageWithParams(params SendMessageParams) (*Incom
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.NoPingOnTimeout {
|
|
||||||
return <-ch, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case resp := <-ch:
|
case resp := <-ch:
|
||||||
return resp, nil
|
return resp, nil
|
||||||
|
@ -175,8 +171,6 @@ type SendMessageParams struct {
|
||||||
CustomTTL int64
|
CustomTTL int64
|
||||||
DontEncrypt bool
|
DontEncrypt bool
|
||||||
MessageType gmproto.MessageType
|
MessageType gmproto.MessageType
|
||||||
|
|
||||||
NoPingOnTimeout bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SessionHandler) buildMessage(params SendMessageParams) (string, proto.Message, error) {
|
func (s *SessionHandler) buildMessage(params SendMessageParams) (string, proto.Message, error) {
|
||||||
|
|
|
@ -364,7 +364,7 @@ func (prov *ProvisioningAPI) GoogleLoginWait(w http.ResponseWriter, r *http.Requ
|
||||||
|
|
||||||
log := prov.zlog.With().Str("user_id", user.MXID.String()).Str("endpoint", "login").Logger()
|
log := prov.zlog.With().Str("user_id", user.MXID.String()).Str("endpoint", "login").Logger()
|
||||||
|
|
||||||
err := user.AsyncLoginGoogleWait()
|
err := user.AsyncLoginGoogleWait(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Err(err).Msg("Failed to wait for google login")
|
log.Err(err).Msg("Failed to wait for google login")
|
||||||
switch {
|
switch {
|
||||||
|
@ -388,6 +388,12 @@ func (prov *ProvisioningAPI) GoogleLoginWait(w http.ResponseWriter, r *http.Requ
|
||||||
Error: err.Error(),
|
Error: err.Error(),
|
||||||
ErrCode: "timeout",
|
ErrCode: "timeout",
|
||||||
})
|
})
|
||||||
|
case errors.Is(err, context.Canceled):
|
||||||
|
// This should only happen if the client already disconnected, so clients will probably never see this error code.
|
||||||
|
jsonResponse(w, http.StatusBadRequest, Error{
|
||||||
|
Error: err.Error(),
|
||||||
|
ErrCode: "context-cancelled",
|
||||||
|
})
|
||||||
default:
|
default:
|
||||||
jsonResponse(w, http.StatusInternalServerError, Error{
|
jsonResponse(w, http.StatusInternalServerError, Error{
|
||||||
Error: "Failed to finish login",
|
Error: "Failed to finish login",
|
||||||
|
|
24
user.go
24
user.go
|
@ -484,8 +484,13 @@ func (user *User) AsyncLoginGoogleStart(cookies map[string]string) (outEmoji str
|
||||||
outEmoji = emoji
|
outEmoji = emoji
|
||||||
initialWait.Done()
|
initialWait.Done()
|
||||||
}
|
}
|
||||||
|
var ctx context.Context
|
||||||
|
ctx, user.cancelLogin = context.WithCancel(context.Background())
|
||||||
go func() {
|
go func() {
|
||||||
err := user.LoginGoogle(cookies, callback)
|
defer func() {
|
||||||
|
user.cancelLogin = nil
|
||||||
|
}()
|
||||||
|
err := user.LoginGoogle(ctx, cookies, callback)
|
||||||
if !callbackDone {
|
if !callbackDone {
|
||||||
user.zlog.Err(err).Msg("Async google login failed before callback")
|
user.zlog.Err(err).Msg("Async google login failed before callback")
|
||||||
initialWait.Done()
|
initialWait.Done()
|
||||||
|
@ -505,15 +510,24 @@ func (user *User) AsyncLoginGoogleStart(cookies map[string]string) (outEmoji str
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) AsyncLoginGoogleWait() error {
|
func (user *User) AsyncLoginGoogleWait(ctx context.Context) error {
|
||||||
ch := user.googleAsyncPairErrChan.Swap(nil)
|
ch := user.googleAsyncPairErrChan.Swap(nil)
|
||||||
if ch == nil {
|
if ch == nil {
|
||||||
return ErrNoLoginInProgress
|
return ErrNoLoginInProgress
|
||||||
}
|
}
|
||||||
return <-*ch
|
select {
|
||||||
|
case ret := <-*ch:
|
||||||
|
return ret
|
||||||
|
case <-ctx.Done():
|
||||||
|
user.zlog.Err(ctx.Err()).Msg("Login wait context canceled, canceling login")
|
||||||
|
if cancelLogin := user.cancelLogin; cancelLogin != nil {
|
||||||
|
cancelLogin()
|
||||||
|
}
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) LoginGoogle(cookies map[string]string, emojiCallback func(string)) error {
|
func (user *User) LoginGoogle(ctx context.Context, cookies map[string]string, emojiCallback func(string)) error {
|
||||||
user.connLock.Lock()
|
user.connLock.Lock()
|
||||||
defer user.connLock.Unlock()
|
defer user.connLock.Unlock()
|
||||||
if user.Session != nil {
|
if user.Session != nil {
|
||||||
|
@ -533,7 +547,7 @@ func (user *User) LoginGoogle(cookies map[string]string, emojiCallback func(stri
|
||||||
authData.Cookies = cookies
|
authData.Cookies = cookies
|
||||||
user.createClient(authData)
|
user.createClient(authData)
|
||||||
Analytics.Track(user.MXID, "$login_start")
|
Analytics.Track(user.MXID, "$login_start")
|
||||||
err := user.Client.DoGaiaPairing(emojiCallback)
|
err := user.Client.DoGaiaPairing(ctx, emojiCallback)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
user.unlockedDeleteConnection()
|
user.unlockedDeleteConnection()
|
||||||
return err
|
return err
|
||||||
|
|
Loading…
Reference in a new issue