Use dbutil.QueryHelper for database stuff

This commit is contained in:
Tulir Asokan 2024-02-25 00:18:06 +02:00
parent 6c4d8d8744
commit aa7c66496f
6 changed files with 102 additions and 211 deletions

View file

@ -1,5 +1,5 @@
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
// Copyright (C) 2023 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,8 +17,6 @@
package database
import (
"context"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"go.mau.fi/util/dbutil"
@ -36,45 +34,13 @@ type Database struct {
Reaction *ReactionQuery
}
func New(baseDB *dbutil.Database) *Database {
db := &Database{Database: baseDB}
func New(db *dbutil.Database) *Database {
db.UpgradeTable = upgrades.Table
db.User = &UserQuery{db: db}
db.Portal = &PortalQuery{db: db}
db.Puppet = &PuppetQuery{db: db}
db.Message = &MessageQuery{db: db}
db.Reaction = &ReactionQuery{db: db}
return db
}
type dataStruct[T any] interface {
Scan(row dbutil.Scannable) (T, error)
}
type queryStruct[T dataStruct[T]] interface {
New() T
getDB() *Database
}
func get[T dataStruct[T]](qs queryStruct[T], ctx context.Context, query string, args ...any) (T, error) {
return qs.New().Scan(qs.getDB().Conn(ctx).QueryRowContext(ctx, query, args...))
}
func getAll[T dataStruct[T]](qs queryStruct[T], ctx context.Context, query string, args ...any) ([]T, error) {
rows, err := qs.getDB().Conn(ctx).QueryContext(ctx, query, args...)
if err != nil {
return nil, err
return &Database{
User: &UserQuery{dbutil.MakeQueryHelper(db, newUser)},
Portal: &PortalQuery{dbutil.MakeQueryHelper(db, newPortal)},
Puppet: &PuppetQuery{dbutil.MakeQueryHelper(db, newPuppet)},
Message: &MessageQuery{dbutil.MakeQueryHelper(db, newMessage)},
Reaction: &ReactionQuery{dbutil.MakeQueryHelper(db, newReaction)},
}
items := make([]T, 0)
defer func() {
_ = rows.Close()
}()
for rows.Next() {
item, err := qs.New().Scan(rows)
if err != nil {
return nil, err
}
items = append(items, item)
}
return items, rows.Err()
}

View file

@ -1,5 +1,5 @@
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
// Copyright (C) 2023 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -18,8 +18,6 @@ package database
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
@ -31,17 +29,11 @@ import (
)
type MessageQuery struct {
db *Database
*dbutil.QueryHelper[*Message]
}
func (mq *MessageQuery) New() *Message {
return &Message{
db: mq.db,
}
}
func (mq *MessageQuery) getDB() *Database {
return mq.db
func newMessage(qh *dbutil.QueryHelper[*Message]) *Message {
return &Message{qh: qh}
}
const (
@ -84,24 +76,23 @@ const (
)
func (mq *MessageQuery) GetByID(ctx context.Context, receiver int, messageID string) (*Message, error) {
return get[*Message](mq, ctx, getMessageByIDQuery, receiver, messageID)
return mq.QueryOne(ctx, getMessageByIDQuery, receiver, messageID)
}
func (mq *MessageQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Message, error) {
return get[*Message](mq, ctx, getMessageByMXIDQuery, mxid)
return mq.QueryOne(ctx, getMessageByMXIDQuery, mxid)
}
func (mq *MessageQuery) GetLastInChat(ctx context.Context, chat Key) (*Message, error) {
return get[*Message](mq, ctx, getLastMessageInChatQuery, chat.ID, chat.Receiver)
return mq.QueryOne(ctx, getLastMessageInChatQuery, chat.ID, chat.Receiver)
}
func (mq *MessageQuery) GetLastInChatWithMXID(ctx context.Context, chat Key) (*Message, error) {
return get[*Message](mq, ctx, getLastMessageInChatWithMXIDQuery, chat.ID, chat.Receiver)
return mq.QueryOne(ctx, getLastMessageInChatWithMXIDQuery, chat.ID, chat.Receiver)
}
func (mq *MessageQuery) DeleteAllInChat(ctx context.Context, chat Key) error {
_, err := mq.db.Conn(ctx).ExecContext(ctx, deleteAllMessagesInChatQuery, chat.ID, chat.Receiver)
return err
return mq.Exec(ctx, deleteAllMessagesInChatQuery, chat.ID, chat.Receiver)
}
type MediaPart struct {
@ -132,7 +123,7 @@ func (ms *MessageStatus) HasPendingMediaParts() bool {
}
type Message struct {
db *Database
qh *dbutil.QueryHelper[*Message]
Chat Key
ID string
@ -146,9 +137,7 @@ type Message struct {
func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) {
var ts int64
err := row.Scan(&msg.Chat.ID, &msg.Chat.Receiver, &msg.ID, &msg.MXID, &msg.RoomID, &msg.Sender, &ts, dbutil.JSON{Data: &msg.Status})
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
if err != nil {
return nil, err
}
if ts != 0 {
@ -162,13 +151,12 @@ func (msg *Message) sqlVariables() []any {
}
func (msg *Message) Insert(ctx context.Context) error {
_, err := msg.db.Conn(ctx).ExecContext(ctx, insertMessageQuery, msg.sqlVariables()...)
return err
return msg.qh.Exec(ctx, insertMessageQuery, msg.sqlVariables()...)
}
func (mq *MessageQuery) MassInsert(ctx context.Context, messages []*Message) error {
valueStringFormat := "($1, $2, $%d, $%d, $3, $%d, $%d, $%d)"
if mq.db.Dialect == dbutil.SQLite {
if mq.GetDB().Dialect == dbutil.SQLite {
valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
}
placeholders := make([]string, len(messages))
@ -186,23 +174,19 @@ func (mq *MessageQuery) MassInsert(ctx context.Context, messages []*Message) err
placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5)
}
query := massInsertMessageQueryPrefix + strings.Join(placeholders, ",")
_, err := mq.db.Conn(ctx).ExecContext(ctx, query, params...)
return err
return mq.Exec(ctx, query, params...)
}
func (msg *Message) Update(ctx context.Context) error {
_, err := msg.db.Conn(ctx).ExecContext(ctx, updateMessageQuery, msg.sqlVariables()...)
return err
return msg.qh.Exec(ctx, updateMessageQuery, msg.sqlVariables()...)
}
func (msg *Message) UpdateStatus(ctx context.Context) error {
_, err := msg.db.Conn(ctx).ExecContext(ctx, updateMessageStatusQuery, dbutil.JSON{Data: &msg.Status}, msg.Timestamp.UnixMicro(), msg.Chat.Receiver, msg.ID)
return err
return msg.qh.Exec(ctx, updateMessageStatusQuery, dbutil.JSON{Data: &msg.Status}, msg.Timestamp.UnixMicro(), msg.Chat.Receiver, msg.ID)
}
func (msg *Message) Delete(ctx context.Context) error {
_, err := msg.db.Conn(ctx).ExecContext(ctx, deleteMessageQuery, msg.Chat.ID, msg.Chat.Receiver, msg.ID)
return err
return msg.qh.Exec(ctx, deleteMessageQuery, msg.Chat.ID, msg.Chat.Receiver, msg.ID)
}
func (msg *Message) IsFakeMXID() bool {

View file

@ -1,5 +1,5 @@
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
// Copyright (C) 2023 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -19,7 +19,6 @@ package database
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/rs/zerolog"
@ -30,17 +29,11 @@ import (
)
type PortalQuery struct {
db *Database
*dbutil.QueryHelper[*Portal]
}
func (pq *PortalQuery) New() *Portal {
return &Portal{
db: pq.db,
}
}
func (pq *PortalQuery) getDB() *Database {
return pq.db
func newPortal(qh *dbutil.QueryHelper[*Portal]) *Portal {
return &Portal{qh: qh}
}
const (
@ -62,23 +55,23 @@ const (
)
func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) {
return getAll[*Portal](pq, ctx, getAllPortalsQuery)
return pq.QueryMany(ctx, getAllPortalsQuery)
}
func (pq *PortalQuery) GetAllForUser(ctx context.Context, receiver int) ([]*Portal, error) {
return getAll[*Portal](pq, ctx, getAllPortalsForUserQuery, receiver)
return pq.QueryMany(ctx, getAllPortalsForUserQuery, receiver)
}
func (pq *PortalQuery) GetByKey(ctx context.Context, key Key) (*Portal, error) {
return get[*Portal](pq, ctx, getPortalByKeyQuery, key.ID, key.Receiver)
return pq.QueryOne(ctx, getPortalByKeyQuery, key.ID, key.Receiver)
}
func (pq *PortalQuery) GetByOtherUser(ctx context.Context, key Key) (*Portal, error) {
return get[*Portal](pq, ctx, getPortalByOtherUserQuery, key.ID, key.Receiver)
return pq.QueryOne(ctx, getPortalByOtherUserQuery, key.ID, key.Receiver)
}
func (pq *PortalQuery) GetByMXID(ctx context.Context, mxid id.RoomID) (*Portal, error) {
return get[*Portal](pq, ctx, getPortalByMXIDQuery, mxid)
return pq.QueryOne(ctx, getPortalByMXIDQuery, mxid)
}
type Key struct {
@ -95,7 +88,7 @@ func (p Key) MarshalZerologObject(e *zerolog.Event) {
}
type Portal struct {
db *Database
qh *dbutil.QueryHelper[*Portal]
Key
OutgoingID string
@ -112,10 +105,11 @@ type Portal struct {
func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
var mxid, selfUserID, otherUserID sql.NullString
var convType int
err := row.Scan(&portal.ID, &portal.Receiver, &selfUserID, &otherUserID, &convType, &mxid, &portal.Name, &portal.NameSet, &portal.Encrypted, &portal.InSpace)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
err := row.Scan(
&portal.ID, &portal.Receiver, &selfUserID, &otherUserID, &convType, &mxid,
&portal.Name, &portal.NameSet, &portal.Encrypted, &portal.InSpace,
)
if err != nil {
return nil, err
}
portal.Type = gmproto.ConversationType(convType)
@ -126,33 +120,21 @@ func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
}
func (portal *Portal) sqlVariables() []any {
var mxid, selfUserID, otherUserID *string
if portal.MXID != "" {
mxid = (*string)(&portal.MXID)
}
if portal.OutgoingID != "" {
selfUserID = &portal.OutgoingID
}
if portal.OtherUserID != "" {
otherUserID = &portal.OtherUserID
}
return []any{
portal.ID, portal.Receiver, selfUserID, otherUserID, int(portal.Type), mxid, portal.Name, portal.NameSet,
portal.Encrypted, portal.InSpace,
portal.ID, portal.Receiver, dbutil.StrPtr(portal.OutgoingID), dbutil.StrPtr(portal.OtherUserID),
int(portal.Type), dbutil.StrPtr(portal.MXID),
portal.Name, portal.NameSet, portal.Encrypted, portal.InSpace,
}
}
func (portal *Portal) Insert(ctx context.Context) error {
_, err := portal.db.Conn(ctx).ExecContext(ctx, insertPortalQuery, portal.sqlVariables()...)
return err
return portal.qh.Exec(ctx, insertPortalQuery, portal.sqlVariables()...)
}
func (portal *Portal) Update(ctx context.Context) error {
_, err := portal.db.Conn(ctx).ExecContext(ctx, updatePortalQuery, portal.sqlVariables()...)
return err
return portal.qh.Exec(ctx, updatePortalQuery, portal.sqlVariables()...)
}
func (portal *Portal) Delete(ctx context.Context) error {
_, err := portal.db.Conn(ctx).ExecContext(ctx, deletePortalQuery, portal.ID, portal.Receiver)
return err
return portal.qh.Exec(ctx, deletePortalQuery, portal.ID, portal.Receiver)
}

View file

@ -1,5 +1,5 @@
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
// Copyright (C) 2023 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -18,8 +18,6 @@ package database
import (
"context"
"database/sql"
"errors"
"time"
"go.mau.fi/util/dbutil"
@ -27,13 +25,11 @@ import (
)
type PuppetQuery struct {
db *Database
*dbutil.QueryHelper[*Puppet]
}
func (pq *PuppetQuery) New() *Puppet {
return &Puppet{
db: pq.db,
}
func newPuppet(qh *dbutil.QueryHelper[*Puppet]) *Puppet {
return &Puppet{qh: qh}
}
const (
@ -50,21 +46,16 @@ const (
`
)
func (pq *PuppetQuery) getDB() *Database {
return pq.db
}
func (pq *PuppetQuery) DeleteAllForUser(ctx context.Context, userID int) error {
_, err := pq.db.Conn(ctx).ExecContext(ctx, deleteAllPuppetsForUserQuery, userID)
return err
return pq.Exec(ctx, deleteAllPuppetsForUserQuery, userID)
}
func (pq *PuppetQuery) Get(ctx context.Context, key Key) (*Puppet, error) {
return get[*Puppet](pq, ctx, getPuppetQuery, key.ID, key.Receiver)
return pq.QueryOne(ctx, getPuppetQuery, key.ID, key.Receiver)
}
type Puppet struct {
db *Database
qh *dbutil.QueryHelper[*Puppet]
Key
Phone string
@ -81,10 +72,11 @@ type Puppet struct {
func (puppet *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) {
var avatarHash []byte
var avatarUpdateTS int64
err := row.Scan(&puppet.ID, &puppet.Receiver, &puppet.Phone, &puppet.ContactID, &puppet.Name, &puppet.NameSet, &avatarHash, &puppet.AvatarMXC, &puppet.AvatarSet, &avatarUpdateTS, &puppet.ContactInfoSet)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
err := row.Scan(
&puppet.ID, &puppet.Receiver, &puppet.Phone, &puppet.ContactID, &puppet.Name, &puppet.NameSet,
&avatarHash, &puppet.AvatarMXC, &puppet.AvatarSet, &avatarUpdateTS, &puppet.ContactInfoSet,
)
if err != nil {
return nil, err
}
if len(avatarHash) == 32 {
@ -95,15 +87,16 @@ func (puppet *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) {
}
func (puppet *Puppet) sqlVariables() []any {
return []any{puppet.ID, puppet.Receiver, puppet.Phone, puppet.ContactID, puppet.Name, puppet.NameSet, puppet.AvatarHash[:], &puppet.AvatarMXC, puppet.AvatarSet, puppet.AvatarUpdateTS.UnixMilli(), puppet.ContactInfoSet}
return []any{
puppet.ID, puppet.Receiver, puppet.Phone, puppet.ContactID, puppet.Name, puppet.NameSet,
puppet.AvatarHash[:], &puppet.AvatarMXC, puppet.AvatarSet, puppet.AvatarUpdateTS.UnixMilli(), puppet.ContactInfoSet,
}
}
func (puppet *Puppet) Insert(ctx context.Context) error {
_, err := puppet.db.Conn(ctx).ExecContext(ctx, insertPuppetQuery, puppet.sqlVariables()...)
return err
return puppet.qh.Exec(ctx, insertPuppetQuery, puppet.sqlVariables()...)
}
func (puppet *Puppet) Update(ctx context.Context) error {
_, err := puppet.db.Conn(ctx).ExecContext(ctx, updatePuppetQuery, puppet.sqlVariables()...)
return err
return puppet.qh.Exec(ctx, updatePuppetQuery, puppet.sqlVariables()...)
}

View file

@ -1,5 +1,5 @@
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
// Copyright (C) 2023 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -18,8 +18,6 @@ package database
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
@ -28,17 +26,11 @@ import (
)
type ReactionQuery struct {
db *Database
*dbutil.QueryHelper[*Reaction]
}
func (rq *ReactionQuery) New() *Reaction {
return &Reaction{
db: rq.db,
}
}
func (rq *ReactionQuery) getDB() *Database {
return rq.db
func newReaction(qh *dbutil.QueryHelper[*Reaction]) *Reaction {
return &Reaction{qh: qh}
}
const (
@ -68,24 +60,23 @@ const (
)
func (rq *ReactionQuery) GetByID(ctx context.Context, receiver int, messageID, sender string) (*Reaction, error) {
return get[*Reaction](rq, ctx, getReactionByIDQuery, receiver, messageID, sender)
return rq.QueryOne(ctx, getReactionByIDQuery, receiver, messageID, sender)
}
func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) {
return get[*Reaction](rq, ctx, getReactionByMXIDQuery, mxid)
return rq.QueryOne(ctx, getReactionByMXIDQuery, mxid)
}
func (rq *ReactionQuery) GetAllByMessage(ctx context.Context, receiver int, messageID string) ([]*Reaction, error) {
return getAll[*Reaction](rq, ctx, getReactionsByMessageIDQuery, receiver, messageID)
return rq.QueryMany(ctx, getReactionsByMessageIDQuery, receiver, messageID)
}
func (rq *ReactionQuery) DeleteAllByMessage(ctx context.Context, chat Key, messageID string) error {
_, err := rq.db.Conn(ctx).ExecContext(ctx, deleteReactionsByMessageIDQuery, chat.ID, chat.Receiver, messageID)
return err
return rq.Exec(ctx, deleteReactionsByMessageIDQuery, chat.ID, chat.Receiver, messageID)
}
type Reaction struct {
db *Database
qh *dbutil.QueryHelper[*Reaction]
Chat Key
MessageID string
@ -95,23 +86,16 @@ type Reaction struct {
}
func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) {
err := row.Scan(&r.Chat.ID, &r.Chat.Receiver, &r.MessageID, &r.Sender, &r.Reaction, &r.MXID)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
}
return r, nil
return dbutil.ValueOrErr(r, row.Scan(&r.Chat.ID, &r.Chat.Receiver, &r.MessageID, &r.Sender, &r.Reaction, &r.MXID))
}
func (r *Reaction) Insert(ctx context.Context) error {
_, err := r.db.Conn(ctx).ExecContext(ctx, insertReactionQuery, r.Chat.ID, r.Chat.Receiver, r.MessageID, r.Sender, r.Reaction, r.MXID)
return err
return r.qh.Exec(ctx, insertReactionQuery, r.Chat.ID, r.Chat.Receiver, r.MessageID, r.Sender, r.Reaction, r.MXID)
}
func (rq *ReactionQuery) MassInsert(ctx context.Context, reactions []*Reaction) error {
valueStringFormat := "($1, $2, $%d, $%d, $%d, $%d)"
if rq.db.Dialect == dbutil.SQLite {
if rq.GetDB().Dialect == dbutil.SQLite {
valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
}
placeholders := make([]string, len(reactions))
@ -127,11 +111,9 @@ func (rq *ReactionQuery) MassInsert(ctx context.Context, reactions []*Reaction)
placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4)
}
query := strings.Replace(insertReactionQuery, "($1, $2, $3, $4, $5, $6)", strings.Join(placeholders, ","), 1)
_, err := rq.db.Conn(ctx).ExecContext(ctx, query, params...)
return err
return rq.Exec(ctx, query, params...)
}
func (r *Reaction) Delete(ctx context.Context) error {
_, err := r.db.Conn(ctx).ExecContext(ctx, deleteReactionQuery, r.Chat.ID, r.Chat.Receiver, r.MessageID, r.Sender)
return err
return r.qh.Exec(ctx, deleteReactionQuery, r.Chat.ID, r.Chat.Receiver, r.MessageID, r.Sender)
}

View file

@ -1,5 +1,5 @@
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
// Copyright (C) 2023 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -20,7 +20,6 @@ import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"sync"
@ -33,17 +32,11 @@ import (
)
type UserQuery struct {
db *Database
*dbutil.QueryHelper[*User]
}
func (uq *UserQuery) New() *User {
return &User{
db: uq.db,
}
}
func (uq *UserQuery) getDB() *Database {
return uq.db
func newUser(qh *dbutil.QueryHelper[*User]) *User {
return &User{qh: qh}
}
const (
@ -63,22 +56,23 @@ const (
management_room=$7, space_room=$8, access_token=$9
WHERE mxid=$1
`
updateuserParticipantIDsQuery = `UPDATE "user" SET self_participant_ids=$2 WHERE mxid=$1`
)
func (uq *UserQuery) GetAllWithSession(ctx context.Context) ([]*User, error) {
return getAll[*User](uq, ctx, getAllUsersWithSessionQuery)
return uq.QueryMany(ctx, getAllUsersWithSessionQuery)
}
func (uq *UserQuery) GetAllWithDoublePuppet(ctx context.Context) ([]*User, error) {
return getAll[*User](uq, ctx, getAllUsersWithDoublePuppetQuery)
return uq.QueryMany(ctx, getAllUsersWithDoublePuppetQuery)
}
func (uq *UserQuery) GetByRowID(ctx context.Context, rowID int) (*User, error) {
return get[*User](uq, ctx, getUserByRowIDQuery, rowID)
return uq.QueryOne(ctx, getUserByRowIDQuery, rowID)
}
func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, error) {
return get[*User](uq, ctx, getUserByMXIDQuery, userID)
return uq.QueryOne(ctx, getUserByMXIDQuery, userID)
}
type Settings struct {
@ -90,7 +84,7 @@ type Settings struct {
}
type User struct {
db *Database
qh *dbutil.QueryHelper[*User]
RowID int
MXID id.UserID
@ -114,10 +108,11 @@ type User struct {
func (user *User) Scan(row dbutil.Scannable) (*User, error) {
var phoneID, session, managementRoom, spaceRoom, accessToken sql.NullString
var selfParticipantIDs, simMetadata, settings string
err := row.Scan(&user.RowID, &user.MXID, &phoneID, &session, &selfParticipantIDs, &simMetadata, &settings, &managementRoom, &spaceRoom, &accessToken)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
err := row.Scan(
&user.RowID, &user.MXID, &phoneID, &session, &selfParticipantIDs, &simMetadata,
&settings, &managementRoom, &spaceRoom, &accessToken,
)
if err != nil {
return nil, err
}
if session.String != "" {
@ -152,24 +147,12 @@ func (user *User) Scan(row dbutil.Scannable) (*User, error) {
}
func (user *User) sqlVariables() []any {
var phoneID, session, managementRoom, spaceRoom, accessToken *string
if user.PhoneID != "" {
phoneID = &user.PhoneID
}
var session *string
if user.Session != nil {
data, _ := json.Marshal(user.Session)
strData := string(data)
session = &strData
}
if user.ManagementRoom != "" {
managementRoom = (*string)(&user.ManagementRoom)
}
if user.SpaceRoom != "" {
spaceRoom = (*string)(&user.SpaceRoom)
}
if user.AccessToken != "" {
accessToken = &user.AccessToken
}
user.selfParticipantIDsLock.RLock()
selfParticipantIDs, _ := json.Marshal(user.SelfParticipantIDs)
user.selfParticipantIDsLock.RUnlock()
@ -183,7 +166,10 @@ func (user *User) sqlVariables() []any {
if err != nil {
panic(err)
}
return []any{user.MXID, phoneID, session, string(selfParticipantIDs), string(simMetadata), string(settings), managementRoom, spaceRoom, accessToken}
return []any{
user.MXID, dbutil.StrPtr(user.PhoneID), session, string(selfParticipantIDs), string(simMetadata),
string(settings), dbutil.StrPtr(user.ManagementRoom), dbutil.StrPtr(user.SpaceRoom), dbutil.StrPtr(user.AccessToken),
}
}
func (user *User) IsSelfParticipantID(id string) bool {
@ -268,20 +254,18 @@ func (user *User) AddSelfParticipantID(ctx context.Context, id string) error {
if !slices.Contains(user.SelfParticipantIDs, id) {
user.SelfParticipantIDs = append(user.SelfParticipantIDs, id)
selfParticipantIDs, _ := json.Marshal(user.SelfParticipantIDs)
_, err := user.db.Conn(ctx).ExecContext(ctx, `UPDATE "user" SET self_participant_ids=$2 WHERE mxid=$1`, user.MXID, selfParticipantIDs)
return err
return user.qh.Exec(ctx, updateuserParticipantIDsQuery, user.MXID, selfParticipantIDs)
}
return nil
}
func (user *User) Insert(ctx context.Context) error {
err := user.db.Conn(ctx).
QueryRowContext(ctx, insertUserQuery, user.sqlVariables()...).
err := user.qh.GetDB().
QueryRow(ctx, insertUserQuery, user.sqlVariables()...).
Scan(&user.RowID)
return err
}
func (user *User) Update(ctx context.Context) error {
_, err := user.db.Conn(ctx).ExecContext(ctx, updateUserQuery, user.sqlVariables()...)
return err
return user.qh.Exec(ctx, updateUserQuery, user.sqlVariables()...)
}