From 6602e6c9379c754ee26f407127d8a364fa03f0c2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 17 Jul 2023 01:13:46 +0300 Subject: [PATCH] Add provisioning API --- bridgestate.go | 11 +- commands.go | 2 + main.go | 16 +-- provisioning.go | 307 ++++++++++++++++++++++++++++++++++++++++++++++++ user.go | 30 ++++- 5 files changed, 348 insertions(+), 18 deletions(-) create mode 100644 provisioning.go diff --git a/bridgestate.go b/bridgestate.go index b9b47ec..6fd463c 100644 --- a/bridgestate.go +++ b/bridgestate.go @@ -17,7 +17,10 @@ package main import ( + "net/http" + "maunium.net/go/mautrix/bridge/status" + "maunium.net/go/mautrix/id" ) const ( @@ -55,7 +58,7 @@ func (user *User) GetRemoteName() string { return "" } -/*func (prov *ProvisioningAPI) BridgeStatePing(w http.ResponseWriter, r *http.Request) { +func (prov *ProvisioningAPI) BridgeStatePing(w http.ResponseWriter, r *http.Request) { if !prov.bridge.AS.CheckServerToken(w, r) { return } @@ -69,11 +72,11 @@ func (user *User) GetRemoteName() string { remote.StateEvent = status.StateConnected } else if user.Session != nil { remote.StateEvent = status.StateConnecting - remote.Error = WAConnecting + //remote.Error = WAConnecting } // else: unconfigured } else if user.Session != nil { remote.StateEvent = status.StateBadCredentials - remote.Error = WANotConnected + //remote.Error = WANotConnected } // else: unconfigured global = global.Fill(nil) resp := status.GlobalBridgeState{ @@ -89,4 +92,4 @@ func (user *User) GetRemoteName() string { if len(resp.RemoteStates) > 0 { user.BridgeState.SetPrev(remote) } -}*/ +} diff --git a/commands.go b/commands.go index 6d07c3c..930e91b 100644 --- a/commands.go +++ b/commands.go @@ -112,6 +112,8 @@ func fnLogin(ce *WrappedCommandEvent) { MsgType: event.MsgNotice, Body: "Successfully logged in", }, prevEvent) + default: + ce.ZLog.Error().Any("item_data", item).Msg("Unknown item in QR channel") } } ce.ZLog.Trace().Msg("Login command finished") diff --git a/main.go b/main.go index 4403a1f..7a88b7a 100644 --- a/main.go +++ b/main.go @@ -45,9 +45,9 @@ var ExampleConfig string type GMBridge struct { bridge.Bridge - Config *config.Config - DB *database.Database - //Provisioning *ProvisioningAPI + Config *config.Config + DB *database.Database + Provisioning *ProvisioningAPI usersByMXID map[id.UserID]*User usersByPhone map[string]*User @@ -89,15 +89,15 @@ func (br *GMBridge) Init() { ss := br.Config.Bridge.Provisioning.SharedSecret if len(ss) > 0 && ss != "disable" { - //br.Provisioning = &ProvisioningAPI{bridge: br} + br.Provisioning = &ProvisioningAPI{bridge: br} } } func (br *GMBridge) Start() { - //if br.Provisioning != nil { - // br.ZLog.Debug().Msg("Initializing provisioning API") - // br.Provisioning.Init() - //} + if br.Provisioning != nil { + br.ZLog.Debug().Msg("Initializing provisioning API") + br.Provisioning.Init() + } br.WaitWebsocketConnected() go br.StartUsers() } diff --git a/provisioning.go b/provisioning.go new file mode 100644 index 0000000..ccd80bb --- /dev/null +++ b/provisioning.go @@ -0,0 +1,307 @@ +// mautrix-gmessages - A Matrix-Google Messages puppeting bridge. +// Copyright (C) 2023 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package main + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "strings" + "time" + + "github.com/rs/zerolog" + + log "maunium.net/go/maulogger/v2" + + "maunium.net/go/mautrix/bridge/status" + "maunium.net/go/mautrix/id" +) + +type ProvisioningAPI struct { + bridge *GMBridge + log log.Logger + zlog zerolog.Logger +} + +func (prov *ProvisioningAPI) Init() { + prov.log = prov.bridge.Log.Sub("Provisioning") + prov.zlog = prov.bridge.ZLog.With().Str("component", "provisioning").Logger() + + prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.Bridge.Provisioning.Prefix) + r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.Bridge.Provisioning.Prefix).Subrouter() + r.Use(prov.AuthMiddleware) + r.HandleFunc("/v1/ping", prov.Ping).Methods(http.MethodGet) + r.HandleFunc("/v1/login", prov.Login).Methods(http.MethodPost) + r.HandleFunc("/v1/logout", prov.Logout).Methods(http.MethodPost) + r.HandleFunc("/v1/delete_session", prov.DeleteSession).Methods(http.MethodPost) + r.HandleFunc("/v1/disconnect", prov.Disconnect).Methods(http.MethodPost) + r.HandleFunc("/v1/reconnect", prov.Reconnect).Methods(http.MethodPost) + r.HandleFunc("/v1/contacts", prov.ListContacts).Methods(http.MethodGet) + r.HandleFunc("/v1/start_chat", prov.StartChat).Methods(http.MethodPost) + prov.bridge.AS.Router.HandleFunc("/_matrix/app/com.beeper.asmux/ping", prov.BridgeStatePing).Methods(http.MethodPost) + prov.bridge.AS.Router.HandleFunc("/_matrix/app/com.beeper.bridge_state", prov.BridgeStatePing).Methods(http.MethodPost) + + // Deprecated, just use /disconnect + r.HandleFunc("/v1/delete_connection", prov.Disconnect).Methods(http.MethodPost) +} + +type responseWrap struct { + http.ResponseWriter + statusCode int +} + +func (rw *responseWrap) WriteHeader(statusCode int) { + rw.ResponseWriter.WriteHeader(statusCode) + rw.statusCode = statusCode +} + +func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + auth = auth[len("Bearer "):] + } + if auth != prov.bridge.Config.Bridge.Provisioning.SharedSecret { + prov.log.Infof("Authentication token does not match shared secret") + jsonResponse(w, http.StatusForbidden, map[string]interface{}{ + "error": "Authentication token does not match shared secret", + "errcode": "M_FORBIDDEN", + }) + return + } + userID := r.URL.Query().Get("user_id") + user := prov.bridge.GetUserByMXID(id.UserID(userID)) + start := time.Now() + wWrap := &responseWrap{w, 200} + h.ServeHTTP(wWrap, r.WithContext(context.WithValue(r.Context(), "user", user))) + duration := time.Now().Sub(start).Seconds() + prov.log.Infofln("%s %s from %s took %.2f seconds and returned status %d", r.Method, r.URL.Path, user.MXID, duration, wWrap.statusCode) + }) +} + +type Error struct { + Success bool `json:"success"` + Error string `json:"error"` + ErrCode string `json:"errcode"` +} + +type Response struct { + Success bool `json:"success"` + Status string `json:"status"` +} + +func (prov *ProvisioningAPI) DeleteSession(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value("user").(*User) + if user.Session == nil && user.Client == nil { + jsonResponse(w, http.StatusNotFound, Error{ + Error: "Nothing to purge: no session information stored and no active connection.", + ErrCode: "no session", + }) + return + } + user.Logout(status.BridgeState{StateEvent: status.StateLoggedOut}, false) + jsonResponse(w, http.StatusOK, Response{true, "Session information purged"}) +} + +func (prov *ProvisioningAPI) Disconnect(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value("user").(*User) + if user.Client == nil { + jsonResponse(w, http.StatusNotFound, Error{ + Error: "You don't have a Google Messages connection.", + ErrCode: "no connection", + }) + return + } + user.DeleteConnection() + jsonResponse(w, http.StatusOK, Response{true, "Disconnected from Google Messages"}) + user.BridgeState.Send(status.BridgeState{StateEvent: status.StateTransientDisconnect, Error: GMNotConnected}) +} + +func (prov *ProvisioningAPI) Reconnect(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value("user").(*User) + if user.Client == nil { + if user.Session == nil { + jsonResponse(w, http.StatusForbidden, Error{ + Error: "No existing connection and no session. Please log in first.", + ErrCode: "no session", + }) + } else { + user.Connect() + jsonResponse(w, http.StatusAccepted, Response{true, "Created connection to Google Messages."}) + } + } else { + user.DeleteConnection() + user.BridgeState.Send(status.BridgeState{StateEvent: status.StateTransientDisconnect, Error: GMNotConnected}) + user.Connect() + jsonResponse(w, http.StatusAccepted, Response{true, "Restarted connection to Google Messages"}) + } +} + +func (prov *ProvisioningAPI) ListContacts(w http.ResponseWriter, r *http.Request) { + if user := r.Context().Value("user").(*User); user.Client == nil { + jsonResponse(w, http.StatusBadRequest, Error{ + Error: "User is not connected to Google Messages", + ErrCode: "no session", + }) + } else if contacts, err := user.Client.ListContacts(); err != nil { + prov.log.Errorfln("Failed to fetch %s's contacts: %v", user.MXID, err) + jsonResponse(w, http.StatusInternalServerError, Error{ + Error: "Internal server error while fetching contact list", + ErrCode: "failed to get contacts", + }) + } else { + jsonResponse(w, http.StatusOK, contacts) + } +} + +func (prov *ProvisioningAPI) StartChat(w http.ResponseWriter, r *http.Request) { + if user := r.Context().Value("user").(*User); user.Client == nil { + jsonResponse(w, http.StatusBadRequest, Error{ + Error: "User is not connected to Google Messages", + ErrCode: "no session", + }) + } + +} + +func (prov *ProvisioningAPI) Ping(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value("user").(*User) + gm := map[string]interface{}{ + "has_session": user.Session != nil, + "conn": nil, + } + if user.Session != nil { + gm["phone_id"] = user.Session.Mobile.SourceID + gm["browser_id"] = user.Session.Browser.SourceID + } + if user.Client != nil { + gm["conn"] = map[string]interface{}{ + "is_connected": user.Client.IsConnected(), + "is_logged_in": user.Client.IsLoggedIn(), + } + } + resp := map[string]interface{}{ + "mxid": user.MXID, + "admin": user.Admin, + "whitelisted": user.Whitelisted, + "management_room": user.ManagementRoom, + "space_room": user.SpaceRoom, + "gmessages": gm, + } + jsonResponse(w, http.StatusOK, resp) +} + +func jsonResponse(w http.ResponseWriter, status int, response interface{}) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(response) +} + +func (prov *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value("user").(*User) + if user.Session == nil { + jsonResponse(w, http.StatusOK, Error{ + Error: "You're not logged in", + ErrCode: "not logged in", + }) + return + } + + user.Logout(status.BridgeState{StateEvent: status.StateLoggedOut}, true) + jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."}) +} + +type LoginResponse struct { + Status string `json:"status"` + Code string `json:"code,omitempty"` + ErrCode string `json:"errcode,omitempty"` + Error string `json:"error,omitempty"` +} + +func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { + userID := r.URL.Query().Get("user_id") + user := prov.bridge.GetUserByMXID(id.UserID(userID)) + + log := prov.zlog.With().Str("user_id", user.MXID.String()).Str("endpoint", "login").Logger() + + if user.IsLoggedIn() { + jsonResponse(w, http.StatusOK, LoginResponse{Status: "success", ErrCode: "already logged in"}) + return + } + + ch, err := user.Login(context.Background(), 5) + if err != nil && !errors.Is(err, ErrLoginInProgress) { + log.Err(err).Msg("Failed to start login via provisioning API") + jsonResponse(w, http.StatusInternalServerError, Error{ + Error: "Failed to start login", + ErrCode: "start login fail", + }) + return + } + + var item, prevItem qrChannelItem + var hasItem bool +Loop: + for { + prevItem = item + select { + case item = <-ch: + hasItem = true + default: + break Loop + } + } + if !hasItem { + log.Debug().Msg("Nothing in QR channel, waiting for next item") + select { + case item = <-ch: + case <-r.Context().Done(): + log.Warn().Err(r.Context().Err()).Msg("Client left while waiting for QR code") + return + } + } else if item.IsEmpty() && !prevItem.IsEmpty() { + item = prevItem + } + + switch { + case item.qr != "": + log.Debug().Msg("Got code in QR channel") + Segment.Track(user.MXID, "$qrcode_retrieved") + jsonResponse(w, http.StatusOK, LoginResponse{Status: "qr", Code: item.qr}) + case item.err != nil: + log.Err(item.err).Msg("Got error in QR channel") + Segment.Track(user.MXID, "$login_failure") + var resp LoginResponse + switch { + case errors.Is(item.err, ErrLoginTimeout): + resp = LoginResponse{ErrCode: "timeout", Error: "Scanning QR code timed out"} + default: + resp = LoginResponse{ErrCode: "unknown", Error: "Login failed"} + } + resp.Status = "fail" + jsonResponse(w, http.StatusOK, resp) + case item.success: + log.Debug().Msg("Got pair success in QR channel") + Segment.Track(user.MXID, "$login_success") + jsonResponse(w, http.StatusOK, LoginResponse{Status: "success"}) + default: + log.Error().Any("item_data", item).Msg("Unknown item in QR channel") + resp := LoginResponse{Status: "fail", ErrCode: "internal-error", Error: "Unknown item in login channel"} + jsonResponse(w, http.StatusInternalServerError, resp) + } +} diff --git a/user.go b/user.go index b77ed1f..8895085 100644 --- a/user.go +++ b/user.go @@ -24,6 +24,7 @@ import ( "net/http" "strings" "sync" + "sync/atomic" "time" "github.com/rs/zerolog" @@ -71,7 +72,10 @@ type User struct { batteryLow bool mobileData bool - pairSuccessChan chan struct{} + loginInProgress atomic.Bool + pairSuccessChan chan struct{} + ongoingLoginChan <-chan qrChannelItem + loginChanReadLock sync.Mutex DoublePuppetIntent *appservice.IntentAPI } @@ -406,15 +410,20 @@ type qrChannelItem struct { err error } +func (qci qrChannelItem) IsEmpty() bool { + return !qci.success && qci.qr == "" && qci.err == nil +} + func (user *User) Login(ctx context.Context, maxAttempts int) (<-chan qrChannelItem, error) { user.connLock.Lock() defer user.connLock.Unlock() if user.Session != nil { return nil, ErrAlreadyLoggedIn - } else if user.Client != nil { + } else if !user.loginInProgress.CompareAndSwap(false, true) { + return user.ongoingLoginChan, ErrLoginInProgress + } + if user.Client != nil { user.unlockedDeleteConnection() - } else if user.pairSuccessChan != nil { - return nil, ErrLoginInProgress } pairSuccessChan := make(chan struct{}) user.pairSuccessChan = pairSuccessChan @@ -423,9 +432,12 @@ func (user *User) Login(ctx context.Context, maxAttempts int) (<-chan qrChannelI if err != nil { user.DeleteConnection() user.pairSuccessChan = nil + user.loginInProgress.Store(false) return nil, fmt.Errorf("failed to connect to Google Messages: %w", err) } + Segment.Track(user.MXID, "$login_start") ch := make(chan qrChannelItem, maxAttempts+2) + user.ongoingLoginChan = ch ch <- qrChannelItem{qr: qr} go func() { ticker := time.NewTicker(30 * time.Second) @@ -437,14 +449,21 @@ func (user *User) Login(ctx context.Context, maxAttempts int) (<-chan qrChannelI user.DeleteConnection() } user.pairSuccessChan = nil + user.ongoingLoginChan = nil close(ch) + user.loginInProgress.Store(false) }() - for ; maxAttempts > 0; maxAttempts-- { + for { + maxAttempts-- select { case <-ctx.Done(): user.zlog.Debug().Err(ctx.Err()).Msg("Login context cancelled") return case <-ticker.C: + if maxAttempts <= 0 { + ch <- qrChannelItem{err: ErrLoginTimeout} + return + } qr, err := user.Client.RefreshPhoneRelay() if err != nil { ch <- qrChannelItem{err: fmt.Errorf("failed to refresh QR code: %w", err)} @@ -457,7 +476,6 @@ func (user *User) Login(ctx context.Context, maxAttempts int) (<-chan qrChannelI return } } - ch <- qrChannelItem{err: ErrLoginTimeout} }() return ch, nil }