Add lock for updating cookies

This commit is contained in:
Tulir Asokan 2024-04-17 00:17:18 +03:00
parent 1d3ef74817
commit d70ddb415b
8 changed files with 61 additions and 41 deletions

View file

@ -39,11 +39,54 @@ type AuthData struct {
SessionID uuid.UUID `json:"session_id,omitempty"` SessionID uuid.UUID `json:"session_id,omitempty"`
DestRegID uuid.UUID `json:"dest_reg_id,omitempty"` DestRegID uuid.UUID `json:"dest_reg_id,omitempty"`
PairingID uuid.UUID `json:"pairing_id,omitempty"` PairingID uuid.UUID `json:"pairing_id,omitempty"`
Cookies map[string]string `json:"cookies,omitempty"` Cookies map[string]string `json:"cookies,omitempty"`
CookiesLock sync.RWMutex `json:"-"`
}
func (ad *AuthData) SetCookies(cookies map[string]string) {
ad.CookiesLock.Lock()
ad.Cookies = cookies
ad.CookiesLock.Unlock()
}
func (ad *AuthData) AddCookiesToRequest(req *http.Request) {
ad.CookiesLock.RLock()
defer ad.CookiesLock.RUnlock()
if ad.Cookies == nil {
return
}
for name, value := range ad.Cookies {
req.AddCookie(&http.Cookie{Name: name, Value: value})
}
sapisid, ok := ad.Cookies["SAPISID"]
if ok {
req.Header.Set("Authorization", sapisidHash(util.MessagesBaseURL, sapisid))
}
}
func (ad *AuthData) UpdateCookiesFromResponse(resp *http.Response) {
ad.CookiesLock.Lock()
defer ad.CookiesLock.Unlock()
if ad.Cookies == nil {
return
}
for _, cookie := range resp.Cookies() {
ad.Cookies[cookie.Name] = cookie.Value
}
}
func (ad *AuthData) HasCookies() bool {
if ad == nil {
return false
}
ad.CookiesLock.RLock()
defer ad.CookiesLock.RUnlock()
return ad.Cookies != nil
} }
func (ad *AuthData) AuthNetwork() string { func (ad *AuthData) AuthNetwork() string {
if ad.Cookies != nil { if ad.HasCookies() {
return util.GoogleNetwork return util.GoogleNetwork
} }
return "" return ""
@ -253,11 +296,11 @@ func (c *Client) FetchConfig() (*gmproto.Config, error) {
req.Header.Set("sec-fetch-site", "same-origin") req.Header.Set("sec-fetch-site", "same-origin")
req.Header.Del("x-user-agent") req.Header.Del("x-user-agent")
req.Header.Del("origin") req.Header.Del("origin")
c.AddCookieHeaders(req) c.AuthData.AddCookiesToRequest(req)
resp, err := c.http.Do(req) resp, err := c.http.Do(req)
if resp != nil { if resp != nil {
c.HandleCookieUpdates(resp) c.AuthData.UpdateCookiesFromResponse(resp)
} }
config, err := typedHTTPResponse[*gmproto.Config](resp, err) config, err := typedHTTPResponse[*gmproto.Config](resp, err)
if err != nil { if err != nil {

View file

@ -112,7 +112,7 @@ func main() {
func saveSession() { func saveSession() {
file := mustReturn(os.OpenFile("session.json", os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)) file := mustReturn(os.OpenFile("session.json", os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600))
must(json.NewEncoder(file).Encode(sess)) must(json.NewEncoder(file).Encode(&sess))
_ = file.Close() _ = file.Close()
} }

View file

@ -47,7 +47,7 @@ func (c *Client) makeProtobufHTTPRequestContext(ctx context.Context, url string,
return nil, err return nil, err
} }
util.BuildRelayHeaders(req, contentType, "*/*") util.BuildRelayHeaders(req, contentType, "*/*")
c.AddCookieHeaders(req) c.AuthData.AddCookiesToRequest(req)
client := c.http client := c.http
if longPoll { if longPoll {
client = c.lphttp client = c.lphttp
@ -56,32 +56,10 @@ func (c *Client) makeProtobufHTTPRequestContext(ctx context.Context, url string,
if reqErr != nil { if reqErr != nil {
return res, reqErr return res, reqErr
} }
c.HandleCookieUpdates(res) c.AuthData.UpdateCookiesFromResponse(res)
return res, nil return res, nil
} }
func (c *Client) AddCookieHeaders(req *http.Request) {
if c.AuthData == nil || c.AuthData.Cookies == nil {
return
}
for k, v := range c.AuthData.Cookies {
req.AddCookie(&http.Cookie{Name: k, Value: v})
}
sapisid, ok := c.AuthData.Cookies["SAPISID"]
if ok {
req.Header.Set("Authorization", sapisidHash(util.MessagesBaseURL, sapisid))
}
}
func (c *Client) HandleCookieUpdates(resp *http.Response) {
if c.AuthData.Cookies == nil {
return
}
for _, cookie := range resp.Cookies() {
c.AuthData.Cookies[cookie.Name] = cookie.Value
}
}
func sapisidHash(origin, sapisid string) string { func sapisidHash(origin, sapisid string) string {
ts := time.Now().Unix() ts := time.Now().Unix()
hash := sha1.Sum([]byte(fmt.Sprintf("%d %s %s", ts, sapisid, origin))) hash := sha1.Sum([]byte(fmt.Sprintf("%d %s %s", ts, sapisid, origin)))

View file

@ -289,9 +289,8 @@ func (c *Client) doLongPoll(loggedIn bool, onFirstConnect func()) {
}, },
} }
url := util.ReceiveMessagesURL url := util.ReceiveMessagesURL
if c.AuthData.Cookies != nil { if c.AuthData.HasCookies() {
url = util.ReceiveMessagesURLGoogle url = util.ReceiveMessagesURLGoogle
payload.Auth.Network = util.GoogleNetwork
} }
resp, err := c.makeProtobufHTTPRequestContext(ctx, url, payload, ContentTypePBLite, true) resp, err := c.makeProtobufHTTPRequestContext(ctx, url, payload, ContentTypePBLite, true)
if err != nil { if err != nil {

View file

@ -152,7 +152,7 @@ func (c *Client) UnpairBugle() (*gmproto.RevokeRelayPairingResponse, error) {
} }
func (c *Client) Unpair() (err error) { func (c *Client) Unpair() (err error) {
if c.AuthData.Cookies != nil { if c.AuthData.HasCookies() {
err = c.UnpairGaia() err = c.UnpairGaia()
} else { } else {
_, err = c.UnpairBugle() _, err = c.UnpairBugle()

View file

@ -255,7 +255,7 @@ type primaryDeviceID struct {
} }
func (c *Client) DoGaiaPairing(ctx context.Context, emojiCallback func(string)) error { func (c *Client) DoGaiaPairing(ctx context.Context, emojiCallback func(string)) error {
if len(c.AuthData.Cookies) == 0 { if !c.AuthData.HasCookies() {
return ErrNoCookies return ErrNoCookies
} }
sigResp, err := c.signInGaiaGetToken(ctx) sigResp, err := c.signInGaiaGetToken(ctx)

View file

@ -38,7 +38,7 @@ func (s *SessionHandler) sendMessageNoResponse(params SendMessageParams) error {
} }
url := util.SendMessageURL url := util.SendMessageURL
if s.client.AuthData.Cookies != nil { if s.client.AuthData.HasCookies() {
url = util.SendMessageURLGoogle url = util.SendMessageURLGoogle
} }
_, err = typedHTTPResponse[*gmproto.OutgoingRPCResponse]( _, err = typedHTTPResponse[*gmproto.OutgoingRPCResponse](
@ -55,7 +55,7 @@ func (s *SessionHandler) sendAsyncMessage(params SendMessageParams) (<-chan *Inc
ch := s.waitResponse(requestID) ch := s.waitResponse(requestID)
url := util.SendMessageURL url := util.SendMessageURL
if s.client.AuthData.Cookies != nil { if s.client.AuthData.HasCookies() {
url = util.SendMessageURLGoogle url = util.SendMessageURLGoogle
} }
_, err = typedHTTPResponse[*gmproto.OutgoingRPCResponse]( _, err = typedHTTPResponse[*gmproto.OutgoingRPCResponse](
@ -100,7 +100,7 @@ func (s *SessionHandler) receiveResponse(msg *IncomingRPCMessage) bool {
if msg.Message == nil { if msg.Message == nil {
return false return false
} }
if s.client.AuthData.Cookies != nil { if s.client.AuthData.HasCookies() {
switch msg.Message.Action { switch msg.Message.Action {
case gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_INIT, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_FINISHED: case gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_INIT, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_FINISHED:
default: default:
@ -291,7 +291,7 @@ func (s *SessionHandler) sendAckRequest() {
Acks: ackMessages, Acks: ackMessages,
} }
url := util.AckMessagesURL url := util.AckMessagesURL
if s.client.AuthData.Cookies != nil { if s.client.AuthData.HasCookies() {
url = util.AckMessagesURLGoogle url = util.AckMessagesURLGoogle
} }
_, err := typedHTTPResponse[*gmproto.OutgoingRPCResponse]( _, err := typedHTTPResponse[*gmproto.OutgoingRPCResponse](

View file

@ -555,7 +555,7 @@ func (user *User) LoginGoogle(ctx context.Context, cookies map[string]string, em
user.pairSuccessChan = nil user.pairSuccessChan = nil
}() }()
authData := libgm.NewAuthData() authData := libgm.NewAuthData()
authData.Cookies = cookies authData.SetCookies(cookies)
user.createClient(authData) user.createClient(authData)
Analytics.Track(user.MXID, "$login_start", map[string]any{"mode": "google"}) Analytics.Track(user.MXID, "$login_start", map[string]any{"mode": "google"})
user.Client.GaiaHackyDeviceSwitcher = user.gaiaHackyDeviceSwitcher user.Client.GaiaHackyDeviceSwitcher = user.gaiaHackyDeviceSwitcher