Add timeouts for HTTP requests

This commit is contained in:
Tulir Asokan 2024-03-05 19:14:02 +02:00
parent 96c09b4752
commit 90b5346763
4 changed files with 24 additions and 14 deletions

View file

@ -6,6 +6,7 @@ import (
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/url" "net/url"
"time" "time"
@ -77,8 +78,9 @@ type Client struct {
AuthData *AuthData AuthData *AuthData
cfg *gmproto.Config cfg *gmproto.Config
proxy Proxy httpTransport *http.Transport
http *http.Client http *http.Client
lphttp *http.Client
} }
func NewAuthData() *AuthData { func NewAuthData() *AuthData {
@ -92,11 +94,17 @@ func NewClient(authData *AuthData, logger zerolog.Logger) *Client {
sessionHandler := &SessionHandler{ sessionHandler := &SessionHandler{
responseWaiters: make(map[string]chan<- *IncomingRPCMessage), responseWaiters: make(map[string]chan<- *IncomingRPCMessage),
} }
transport := &http.Transport{
DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 20 * time.Second,
}
cli := &Client{ cli := &Client{
AuthData: authData, AuthData: authData,
Logger: logger, Logger: logger,
sessionHandler: sessionHandler, sessionHandler: sessionHandler,
http: &http.Client{}, http: &http.Client{Transport: transport, Timeout: 2 * time.Minute},
lphttp: &http.Client{Transport: transport, Timeout: 30 * time.Minute},
pingShortCircuit: make(chan struct{}), pingShortCircuit: make(chan struct{}),
} }
@ -127,11 +135,7 @@ func (c *Client) SetProxy(proxy string) error {
if err != nil { if err != nil {
c.Logger.Fatal().Err(err).Msg("Failed to set proxy") c.Logger.Fatal().Err(err).Msg("Failed to set proxy")
} }
proxyUrl := http.ProxyURL(proxyParsed) c.httpTransport.Proxy = http.ProxyURL(proxyParsed)
c.http.Transport = &http.Transport{
Proxy: proxyUrl,
}
c.proxy = proxyUrl
c.Logger.Debug().Any("proxy", proxyParsed.Host).Msg("SetProxy") c.Logger.Debug().Any("proxy", proxyParsed.Host).Msg("SetProxy")
return nil return nil
} }

View file

@ -25,10 +25,10 @@ const ContentTypePBLite = "application/json+protobuf"
func (c *Client) makeProtobufHTTPRequest(url string, data proto.Message, contentType string) (*http.Response, error) { func (c *Client) makeProtobufHTTPRequest(url string, data proto.Message, contentType string) (*http.Response, error) {
ctx := c.Logger.WithContext(context.TODO()) ctx := c.Logger.WithContext(context.TODO())
return c.makeProtobufHTTPRequestContext(ctx, url, data, contentType) return c.makeProtobufHTTPRequestContext(ctx, url, data, contentType, false)
} }
func (c *Client) makeProtobufHTTPRequestContext(ctx context.Context, url string, data proto.Message, contentType string) (*http.Response, error) { func (c *Client) makeProtobufHTTPRequestContext(ctx context.Context, url string, data proto.Message, contentType string, longPoll bool) (*http.Response, error) {
var body []byte var body []byte
var err error var err error
switch contentType { switch contentType {
@ -48,7 +48,11 @@ func (c *Client) makeProtobufHTTPRequestContext(ctx context.Context, url string,
} }
util.BuildRelayHeaders(req, contentType, "*/*") util.BuildRelayHeaders(req, contentType, "*/*")
c.AddCookieHeaders(req) c.AddCookieHeaders(req)
res, reqErr := c.http.Do(req) client := c.http
if longPoll {
client = c.lphttp
}
res, reqErr := client.Do(req)
if reqErr != nil { if reqErr != nil {
return res, reqErr return res, reqErr
} }

View file

@ -3,6 +3,7 @@ package libgm
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
@ -214,6 +215,7 @@ func (c *Client) doLongPoll(loggedIn bool, onFirstConnect func()) {
defer func() { defer func() {
log.Debug().Msg("Long polling stopped") log.Debug().Msg("Long polling stopped")
}() }()
ctx := log.WithContext(context.TODO())
log.Debug().Str("listen_uuid", listenReqID).Msg("Long polling starting") log.Debug().Str("listen_uuid", listenReqID).Msg("Long polling starting")
dittoPing := make(chan struct{}, 1) dittoPing := make(chan struct{}, 1)
@ -253,7 +255,7 @@ func (c *Client) doLongPoll(loggedIn bool, onFirstConnect func()) {
url = util.ReceiveMessagesURLGoogle url = util.ReceiveMessagesURLGoogle
payload.Auth.Network = util.GoogleNetwork payload.Auth.Network = util.GoogleNetwork
} }
resp, err := c.makeProtobufHTTPRequest(url, payload, ContentTypePBLite) resp, err := c.makeProtobufHTTPRequestContext(ctx, url, payload, ContentTypePBLite, true)
if err != nil { if err != nil {
if loggedIn { if loggedIn {
c.triggerEvent(&events.ListenTemporaryError{Error: err}) c.triggerEvent(&events.ListenTemporaryError{Error: err})

View file

@ -67,7 +67,7 @@ func (c *Client) signInGaiaInitial(ctx context.Context) (*gmproto.SignInGaiaResp
payload := c.baseSignInGaiaPayload() payload := c.baseSignInGaiaPayload()
payload.UnknownInt3 = 1 payload.UnknownInt3 = 1
return typedHTTPResponse[*gmproto.SignInGaiaResponse]( return typedHTTPResponse[*gmproto.SignInGaiaResponse](
c.makeProtobufHTTPRequestContext(ctx, util.SignInGaiaURL, payload, ContentTypePBLite), c.makeProtobufHTTPRequestContext(ctx, util.SignInGaiaURL, payload, ContentTypePBLite, false),
) )
} }
@ -82,7 +82,7 @@ func (c *Client) signInGaiaGetToken(ctx context.Context) (*gmproto.SignInGaiaRes
SomeData: key, SomeData: key,
} }
resp, err := typedHTTPResponse[*gmproto.SignInGaiaResponse]( resp, err := typedHTTPResponse[*gmproto.SignInGaiaResponse](
c.makeProtobufHTTPRequestContext(ctx, util.SignInGaiaURL, payload, ContentTypePBLite), c.makeProtobufHTTPRequestContext(ctx, util.SignInGaiaURL, payload, ContentTypePBLite, false),
) )
if err != nil { if err != nil {
return nil, err return nil, err