Add lock for updating cookies

This commit is contained in:
Tulir Asokan 2024-04-17 00:17:18 +03:00
parent 1d3ef74817
commit d70ddb415b
8 changed files with 61 additions and 41 deletions

View file

@ -39,11 +39,54 @@ type AuthData struct {
SessionID uuid.UUID `json:"session_id,omitempty"`
DestRegID uuid.UUID `json:"dest_reg_id,omitempty"`
PairingID uuid.UUID `json:"pairing_id,omitempty"`
Cookies map[string]string `json:"cookies,omitempty"`
CookiesLock sync.RWMutex `json:"-"`
}
func (ad *AuthData) SetCookies(cookies map[string]string) {
ad.CookiesLock.Lock()
ad.Cookies = cookies
ad.CookiesLock.Unlock()
}
func (ad *AuthData) AddCookiesToRequest(req *http.Request) {
ad.CookiesLock.RLock()
defer ad.CookiesLock.RUnlock()
if ad.Cookies == nil {
return
}
for name, value := range ad.Cookies {
req.AddCookie(&http.Cookie{Name: name, Value: value})
}
sapisid, ok := ad.Cookies["SAPISID"]
if ok {
req.Header.Set("Authorization", sapisidHash(util.MessagesBaseURL, sapisid))
}
}
func (ad *AuthData) UpdateCookiesFromResponse(resp *http.Response) {
ad.CookiesLock.Lock()
defer ad.CookiesLock.Unlock()
if ad.Cookies == nil {
return
}
for _, cookie := range resp.Cookies() {
ad.Cookies[cookie.Name] = cookie.Value
}
}
func (ad *AuthData) HasCookies() bool {
if ad == nil {
return false
}
ad.CookiesLock.RLock()
defer ad.CookiesLock.RUnlock()
return ad.Cookies != nil
}
func (ad *AuthData) AuthNetwork() string {
if ad.Cookies != nil {
if ad.HasCookies() {
return util.GoogleNetwork
}
return ""
@ -253,11 +296,11 @@ func (c *Client) FetchConfig() (*gmproto.Config, error) {
req.Header.Set("sec-fetch-site", "same-origin")
req.Header.Del("x-user-agent")
req.Header.Del("origin")
c.AddCookieHeaders(req)
c.AuthData.AddCookiesToRequest(req)
resp, err := c.http.Do(req)
if resp != nil {
c.HandleCookieUpdates(resp)
c.AuthData.UpdateCookiesFromResponse(resp)
}
config, err := typedHTTPResponse[*gmproto.Config](resp, err)
if err != nil {

View file

@ -112,7 +112,7 @@ func main() {
func saveSession() {
file := mustReturn(os.OpenFile("session.json", os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600))
must(json.NewEncoder(file).Encode(sess))
must(json.NewEncoder(file).Encode(&sess))
_ = file.Close()
}

View file

@ -47,7 +47,7 @@ func (c *Client) makeProtobufHTTPRequestContext(ctx context.Context, url string,
return nil, err
}
util.BuildRelayHeaders(req, contentType, "*/*")
c.AddCookieHeaders(req)
c.AuthData.AddCookiesToRequest(req)
client := c.http
if longPoll {
client = c.lphttp
@ -56,32 +56,10 @@ func (c *Client) makeProtobufHTTPRequestContext(ctx context.Context, url string,
if reqErr != nil {
return res, reqErr
}
c.HandleCookieUpdates(res)
c.AuthData.UpdateCookiesFromResponse(res)
return res, nil
}
func (c *Client) AddCookieHeaders(req *http.Request) {
if c.AuthData == nil || c.AuthData.Cookies == nil {
return
}
for k, v := range c.AuthData.Cookies {
req.AddCookie(&http.Cookie{Name: k, Value: v})
}
sapisid, ok := c.AuthData.Cookies["SAPISID"]
if ok {
req.Header.Set("Authorization", sapisidHash(util.MessagesBaseURL, sapisid))
}
}
func (c *Client) HandleCookieUpdates(resp *http.Response) {
if c.AuthData.Cookies == nil {
return
}
for _, cookie := range resp.Cookies() {
c.AuthData.Cookies[cookie.Name] = cookie.Value
}
}
func sapisidHash(origin, sapisid string) string {
ts := time.Now().Unix()
hash := sha1.Sum([]byte(fmt.Sprintf("%d %s %s", ts, sapisid, origin)))

View file

@ -289,9 +289,8 @@ func (c *Client) doLongPoll(loggedIn bool, onFirstConnect func()) {
},
}
url := util.ReceiveMessagesURL
if c.AuthData.Cookies != nil {
if c.AuthData.HasCookies() {
url = util.ReceiveMessagesURLGoogle
payload.Auth.Network = util.GoogleNetwork
}
resp, err := c.makeProtobufHTTPRequestContext(ctx, url, payload, ContentTypePBLite, true)
if err != nil {

View file

@ -152,7 +152,7 @@ func (c *Client) UnpairBugle() (*gmproto.RevokeRelayPairingResponse, error) {
}
func (c *Client) Unpair() (err error) {
if c.AuthData.Cookies != nil {
if c.AuthData.HasCookies() {
err = c.UnpairGaia()
} else {
_, err = c.UnpairBugle()

View file

@ -255,7 +255,7 @@ type primaryDeviceID struct {
}
func (c *Client) DoGaiaPairing(ctx context.Context, emojiCallback func(string)) error {
if len(c.AuthData.Cookies) == 0 {
if !c.AuthData.HasCookies() {
return ErrNoCookies
}
sigResp, err := c.signInGaiaGetToken(ctx)

View file

@ -38,7 +38,7 @@ func (s *SessionHandler) sendMessageNoResponse(params SendMessageParams) error {
}
url := util.SendMessageURL
if s.client.AuthData.Cookies != nil {
if s.client.AuthData.HasCookies() {
url = util.SendMessageURLGoogle
}
_, err = typedHTTPResponse[*gmproto.OutgoingRPCResponse](
@ -55,7 +55,7 @@ func (s *SessionHandler) sendAsyncMessage(params SendMessageParams) (<-chan *Inc
ch := s.waitResponse(requestID)
url := util.SendMessageURL
if s.client.AuthData.Cookies != nil {
if s.client.AuthData.HasCookies() {
url = util.SendMessageURLGoogle
}
_, err = typedHTTPResponse[*gmproto.OutgoingRPCResponse](
@ -100,7 +100,7 @@ func (s *SessionHandler) receiveResponse(msg *IncomingRPCMessage) bool {
if msg.Message == nil {
return false
}
if s.client.AuthData.Cookies != nil {
if s.client.AuthData.HasCookies() {
switch msg.Message.Action {
case gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_INIT, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_FINISHED:
default:
@ -291,7 +291,7 @@ func (s *SessionHandler) sendAckRequest() {
Acks: ackMessages,
}
url := util.AckMessagesURL
if s.client.AuthData.Cookies != nil {
if s.client.AuthData.HasCookies() {
url = util.AckMessagesURLGoogle
}
_, err := typedHTTPResponse[*gmproto.OutgoingRPCResponse](

View file

@ -555,7 +555,7 @@ func (user *User) LoginGoogle(ctx context.Context, cookies map[string]string, em
user.pairSuccessChan = nil
}()
authData := libgm.NewAuthData()
authData.Cookies = cookies
authData.SetCookies(cookies)
user.createClient(authData)
Analytics.Track(user.MXID, "$login_start", map[string]any{"mode": "google"})
user.Client.GaiaHackyDeviceSwitcher = user.gaiaHackyDeviceSwitcher