diff --git a/database/database.go b/database/database.go index a593741..32196c2 100644 --- a/database/database.go +++ b/database/database.go @@ -1,5 +1,5 @@ // mautrix-gmessages - A Matrix-Google Messages puppeting bridge. -// Copyright (C) 2023 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,8 +17,6 @@ package database import ( - "context" - _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" "go.mau.fi/util/dbutil" @@ -36,45 +34,13 @@ type Database struct { Reaction *ReactionQuery } -func New(baseDB *dbutil.Database) *Database { - db := &Database{Database: baseDB} +func New(db *dbutil.Database) *Database { db.UpgradeTable = upgrades.Table - db.User = &UserQuery{db: db} - db.Portal = &PortalQuery{db: db} - db.Puppet = &PuppetQuery{db: db} - db.Message = &MessageQuery{db: db} - db.Reaction = &ReactionQuery{db: db} - return db -} - -type dataStruct[T any] interface { - Scan(row dbutil.Scannable) (T, error) -} - -type queryStruct[T dataStruct[T]] interface { - New() T - getDB() *Database -} - -func get[T dataStruct[T]](qs queryStruct[T], ctx context.Context, query string, args ...any) (T, error) { - return qs.New().Scan(qs.getDB().Conn(ctx).QueryRowContext(ctx, query, args...)) -} - -func getAll[T dataStruct[T]](qs queryStruct[T], ctx context.Context, query string, args ...any) ([]T, error) { - rows, err := qs.getDB().Conn(ctx).QueryContext(ctx, query, args...) - if err != nil { - return nil, err + return &Database{ + User: &UserQuery{dbutil.MakeQueryHelper(db, newUser)}, + Portal: &PortalQuery{dbutil.MakeQueryHelper(db, newPortal)}, + Puppet: &PuppetQuery{dbutil.MakeQueryHelper(db, newPuppet)}, + Message: &MessageQuery{dbutil.MakeQueryHelper(db, newMessage)}, + Reaction: &ReactionQuery{dbutil.MakeQueryHelper(db, newReaction)}, } - items := make([]T, 0) - defer func() { - _ = rows.Close() - }() - for rows.Next() { - item, err := qs.New().Scan(rows) - if err != nil { - return nil, err - } - items = append(items, item) - } - return items, rows.Err() } diff --git a/database/message.go b/database/message.go index c1e2a54..c67f80b 100644 --- a/database/message.go +++ b/database/message.go @@ -1,5 +1,5 @@ // mautrix-gmessages - A Matrix-Google Messages puppeting bridge. -// Copyright (C) 2023 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -18,8 +18,6 @@ package database import ( "context" - "database/sql" - "errors" "fmt" "strings" "time" @@ -31,17 +29,11 @@ import ( ) type MessageQuery struct { - db *Database + *dbutil.QueryHelper[*Message] } -func (mq *MessageQuery) New() *Message { - return &Message{ - db: mq.db, - } -} - -func (mq *MessageQuery) getDB() *Database { - return mq.db +func newMessage(qh *dbutil.QueryHelper[*Message]) *Message { + return &Message{qh: qh} } const ( @@ -84,24 +76,23 @@ const ( ) func (mq *MessageQuery) GetByID(ctx context.Context, receiver int, messageID string) (*Message, error) { - return get[*Message](mq, ctx, getMessageByIDQuery, receiver, messageID) + return mq.QueryOne(ctx, getMessageByIDQuery, receiver, messageID) } func (mq *MessageQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Message, error) { - return get[*Message](mq, ctx, getMessageByMXIDQuery, mxid) + return mq.QueryOne(ctx, getMessageByMXIDQuery, mxid) } func (mq *MessageQuery) GetLastInChat(ctx context.Context, chat Key) (*Message, error) { - return get[*Message](mq, ctx, getLastMessageInChatQuery, chat.ID, chat.Receiver) + return mq.QueryOne(ctx, getLastMessageInChatQuery, chat.ID, chat.Receiver) } func (mq *MessageQuery) GetLastInChatWithMXID(ctx context.Context, chat Key) (*Message, error) { - return get[*Message](mq, ctx, getLastMessageInChatWithMXIDQuery, chat.ID, chat.Receiver) + return mq.QueryOne(ctx, getLastMessageInChatWithMXIDQuery, chat.ID, chat.Receiver) } func (mq *MessageQuery) DeleteAllInChat(ctx context.Context, chat Key) error { - _, err := mq.db.Conn(ctx).ExecContext(ctx, deleteAllMessagesInChatQuery, chat.ID, chat.Receiver) - return err + return mq.Exec(ctx, deleteAllMessagesInChatQuery, chat.ID, chat.Receiver) } type MediaPart struct { @@ -132,7 +123,7 @@ func (ms *MessageStatus) HasPendingMediaParts() bool { } type Message struct { - db *Database + qh *dbutil.QueryHelper[*Message] Chat Key ID string @@ -146,9 +137,7 @@ type Message struct { func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) { var ts int64 err := row.Scan(&msg.Chat.ID, &msg.Chat.Receiver, &msg.ID, &msg.MXID, &msg.RoomID, &msg.Sender, &ts, dbutil.JSON{Data: &msg.Status}) - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } else if err != nil { + if err != nil { return nil, err } if ts != 0 { @@ -162,13 +151,12 @@ func (msg *Message) sqlVariables() []any { } func (msg *Message) Insert(ctx context.Context) error { - _, err := msg.db.Conn(ctx).ExecContext(ctx, insertMessageQuery, msg.sqlVariables()...) - return err + return msg.qh.Exec(ctx, insertMessageQuery, msg.sqlVariables()...) } func (mq *MessageQuery) MassInsert(ctx context.Context, messages []*Message) error { valueStringFormat := "($1, $2, $%d, $%d, $3, $%d, $%d, $%d)" - if mq.db.Dialect == dbutil.SQLite { + if mq.GetDB().Dialect == dbutil.SQLite { valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?") } placeholders := make([]string, len(messages)) @@ -186,23 +174,19 @@ func (mq *MessageQuery) MassInsert(ctx context.Context, messages []*Message) err placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5) } query := massInsertMessageQueryPrefix + strings.Join(placeholders, ",") - _, err := mq.db.Conn(ctx).ExecContext(ctx, query, params...) - return err + return mq.Exec(ctx, query, params...) } func (msg *Message) Update(ctx context.Context) error { - _, err := msg.db.Conn(ctx).ExecContext(ctx, updateMessageQuery, msg.sqlVariables()...) - return err + return msg.qh.Exec(ctx, updateMessageQuery, msg.sqlVariables()...) } func (msg *Message) UpdateStatus(ctx context.Context) error { - _, err := msg.db.Conn(ctx).ExecContext(ctx, updateMessageStatusQuery, dbutil.JSON{Data: &msg.Status}, msg.Timestamp.UnixMicro(), msg.Chat.Receiver, msg.ID) - return err + return msg.qh.Exec(ctx, updateMessageStatusQuery, dbutil.JSON{Data: &msg.Status}, msg.Timestamp.UnixMicro(), msg.Chat.Receiver, msg.ID) } func (msg *Message) Delete(ctx context.Context) error { - _, err := msg.db.Conn(ctx).ExecContext(ctx, deleteMessageQuery, msg.Chat.ID, msg.Chat.Receiver, msg.ID) - return err + return msg.qh.Exec(ctx, deleteMessageQuery, msg.Chat.ID, msg.Chat.Receiver, msg.ID) } func (msg *Message) IsFakeMXID() bool { diff --git a/database/portal.go b/database/portal.go index 5201b5a..48a6a36 100644 --- a/database/portal.go +++ b/database/portal.go @@ -1,5 +1,5 @@ // mautrix-gmessages - A Matrix-Google Messages puppeting bridge. -// Copyright (C) 2023 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -19,7 +19,6 @@ package database import ( "context" "database/sql" - "errors" "fmt" "github.com/rs/zerolog" @@ -30,17 +29,11 @@ import ( ) type PortalQuery struct { - db *Database + *dbutil.QueryHelper[*Portal] } -func (pq *PortalQuery) New() *Portal { - return &Portal{ - db: pq.db, - } -} - -func (pq *PortalQuery) getDB() *Database { - return pq.db +func newPortal(qh *dbutil.QueryHelper[*Portal]) *Portal { + return &Portal{qh: qh} } const ( @@ -62,23 +55,23 @@ const ( ) func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) { - return getAll[*Portal](pq, ctx, getAllPortalsQuery) + return pq.QueryMany(ctx, getAllPortalsQuery) } func (pq *PortalQuery) GetAllForUser(ctx context.Context, receiver int) ([]*Portal, error) { - return getAll[*Portal](pq, ctx, getAllPortalsForUserQuery, receiver) + return pq.QueryMany(ctx, getAllPortalsForUserQuery, receiver) } func (pq *PortalQuery) GetByKey(ctx context.Context, key Key) (*Portal, error) { - return get[*Portal](pq, ctx, getPortalByKeyQuery, key.ID, key.Receiver) + return pq.QueryOne(ctx, getPortalByKeyQuery, key.ID, key.Receiver) } func (pq *PortalQuery) GetByOtherUser(ctx context.Context, key Key) (*Portal, error) { - return get[*Portal](pq, ctx, getPortalByOtherUserQuery, key.ID, key.Receiver) + return pq.QueryOne(ctx, getPortalByOtherUserQuery, key.ID, key.Receiver) } func (pq *PortalQuery) GetByMXID(ctx context.Context, mxid id.RoomID) (*Portal, error) { - return get[*Portal](pq, ctx, getPortalByMXIDQuery, mxid) + return pq.QueryOne(ctx, getPortalByMXIDQuery, mxid) } type Key struct { @@ -95,7 +88,7 @@ func (p Key) MarshalZerologObject(e *zerolog.Event) { } type Portal struct { - db *Database + qh *dbutil.QueryHelper[*Portal] Key OutgoingID string @@ -112,10 +105,11 @@ type Portal struct { func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) { var mxid, selfUserID, otherUserID sql.NullString var convType int - err := row.Scan(&portal.ID, &portal.Receiver, &selfUserID, &otherUserID, &convType, &mxid, &portal.Name, &portal.NameSet, &portal.Encrypted, &portal.InSpace) - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } else if err != nil { + err := row.Scan( + &portal.ID, &portal.Receiver, &selfUserID, &otherUserID, &convType, &mxid, + &portal.Name, &portal.NameSet, &portal.Encrypted, &portal.InSpace, + ) + if err != nil { return nil, err } portal.Type = gmproto.ConversationType(convType) @@ -126,33 +120,21 @@ func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) { } func (portal *Portal) sqlVariables() []any { - var mxid, selfUserID, otherUserID *string - if portal.MXID != "" { - mxid = (*string)(&portal.MXID) - } - if portal.OutgoingID != "" { - selfUserID = &portal.OutgoingID - } - if portal.OtherUserID != "" { - otherUserID = &portal.OtherUserID - } return []any{ - portal.ID, portal.Receiver, selfUserID, otherUserID, int(portal.Type), mxid, portal.Name, portal.NameSet, - portal.Encrypted, portal.InSpace, + portal.ID, portal.Receiver, dbutil.StrPtr(portal.OutgoingID), dbutil.StrPtr(portal.OtherUserID), + int(portal.Type), dbutil.StrPtr(portal.MXID), + portal.Name, portal.NameSet, portal.Encrypted, portal.InSpace, } } func (portal *Portal) Insert(ctx context.Context) error { - _, err := portal.db.Conn(ctx).ExecContext(ctx, insertPortalQuery, portal.sqlVariables()...) - return err + return portal.qh.Exec(ctx, insertPortalQuery, portal.sqlVariables()...) } func (portal *Portal) Update(ctx context.Context) error { - _, err := portal.db.Conn(ctx).ExecContext(ctx, updatePortalQuery, portal.sqlVariables()...) - return err + return portal.qh.Exec(ctx, updatePortalQuery, portal.sqlVariables()...) } func (portal *Portal) Delete(ctx context.Context) error { - _, err := portal.db.Conn(ctx).ExecContext(ctx, deletePortalQuery, portal.ID, portal.Receiver) - return err + return portal.qh.Exec(ctx, deletePortalQuery, portal.ID, portal.Receiver) } diff --git a/database/puppet.go b/database/puppet.go index 90a63dc..f22fd1d 100644 --- a/database/puppet.go +++ b/database/puppet.go @@ -1,5 +1,5 @@ // mautrix-gmessages - A Matrix-Google Messages puppeting bridge. -// Copyright (C) 2023 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -18,8 +18,6 @@ package database import ( "context" - "database/sql" - "errors" "time" "go.mau.fi/util/dbutil" @@ -27,13 +25,11 @@ import ( ) type PuppetQuery struct { - db *Database + *dbutil.QueryHelper[*Puppet] } -func (pq *PuppetQuery) New() *Puppet { - return &Puppet{ - db: pq.db, - } +func newPuppet(qh *dbutil.QueryHelper[*Puppet]) *Puppet { + return &Puppet{qh: qh} } const ( @@ -50,21 +46,16 @@ const ( ` ) -func (pq *PuppetQuery) getDB() *Database { - return pq.db -} - func (pq *PuppetQuery) DeleteAllForUser(ctx context.Context, userID int) error { - _, err := pq.db.Conn(ctx).ExecContext(ctx, deleteAllPuppetsForUserQuery, userID) - return err + return pq.Exec(ctx, deleteAllPuppetsForUserQuery, userID) } func (pq *PuppetQuery) Get(ctx context.Context, key Key) (*Puppet, error) { - return get[*Puppet](pq, ctx, getPuppetQuery, key.ID, key.Receiver) + return pq.QueryOne(ctx, getPuppetQuery, key.ID, key.Receiver) } type Puppet struct { - db *Database + qh *dbutil.QueryHelper[*Puppet] Key Phone string @@ -81,10 +72,11 @@ type Puppet struct { func (puppet *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) { var avatarHash []byte var avatarUpdateTS int64 - err := row.Scan(&puppet.ID, &puppet.Receiver, &puppet.Phone, &puppet.ContactID, &puppet.Name, &puppet.NameSet, &avatarHash, &puppet.AvatarMXC, &puppet.AvatarSet, &avatarUpdateTS, &puppet.ContactInfoSet) - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } else if err != nil { + err := row.Scan( + &puppet.ID, &puppet.Receiver, &puppet.Phone, &puppet.ContactID, &puppet.Name, &puppet.NameSet, + &avatarHash, &puppet.AvatarMXC, &puppet.AvatarSet, &avatarUpdateTS, &puppet.ContactInfoSet, + ) + if err != nil { return nil, err } if len(avatarHash) == 32 { @@ -95,15 +87,16 @@ func (puppet *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) { } func (puppet *Puppet) sqlVariables() []any { - return []any{puppet.ID, puppet.Receiver, puppet.Phone, puppet.ContactID, puppet.Name, puppet.NameSet, puppet.AvatarHash[:], &puppet.AvatarMXC, puppet.AvatarSet, puppet.AvatarUpdateTS.UnixMilli(), puppet.ContactInfoSet} + return []any{ + puppet.ID, puppet.Receiver, puppet.Phone, puppet.ContactID, puppet.Name, puppet.NameSet, + puppet.AvatarHash[:], &puppet.AvatarMXC, puppet.AvatarSet, puppet.AvatarUpdateTS.UnixMilli(), puppet.ContactInfoSet, + } } func (puppet *Puppet) Insert(ctx context.Context) error { - _, err := puppet.db.Conn(ctx).ExecContext(ctx, insertPuppetQuery, puppet.sqlVariables()...) - return err + return puppet.qh.Exec(ctx, insertPuppetQuery, puppet.sqlVariables()...) } func (puppet *Puppet) Update(ctx context.Context) error { - _, err := puppet.db.Conn(ctx).ExecContext(ctx, updatePuppetQuery, puppet.sqlVariables()...) - return err + return puppet.qh.Exec(ctx, updatePuppetQuery, puppet.sqlVariables()...) } diff --git a/database/reaction.go b/database/reaction.go index dae327d..643b7fd 100644 --- a/database/reaction.go +++ b/database/reaction.go @@ -1,5 +1,5 @@ // mautrix-gmessages - A Matrix-Google Messages puppeting bridge. -// Copyright (C) 2023 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -18,8 +18,6 @@ package database import ( "context" - "database/sql" - "errors" "fmt" "strings" @@ -28,17 +26,11 @@ import ( ) type ReactionQuery struct { - db *Database + *dbutil.QueryHelper[*Reaction] } -func (rq *ReactionQuery) New() *Reaction { - return &Reaction{ - db: rq.db, - } -} - -func (rq *ReactionQuery) getDB() *Database { - return rq.db +func newReaction(qh *dbutil.QueryHelper[*Reaction]) *Reaction { + return &Reaction{qh: qh} } const ( @@ -68,24 +60,23 @@ const ( ) func (rq *ReactionQuery) GetByID(ctx context.Context, receiver int, messageID, sender string) (*Reaction, error) { - return get[*Reaction](rq, ctx, getReactionByIDQuery, receiver, messageID, sender) + return rq.QueryOne(ctx, getReactionByIDQuery, receiver, messageID, sender) } func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) { - return get[*Reaction](rq, ctx, getReactionByMXIDQuery, mxid) + return rq.QueryOne(ctx, getReactionByMXIDQuery, mxid) } func (rq *ReactionQuery) GetAllByMessage(ctx context.Context, receiver int, messageID string) ([]*Reaction, error) { - return getAll[*Reaction](rq, ctx, getReactionsByMessageIDQuery, receiver, messageID) + return rq.QueryMany(ctx, getReactionsByMessageIDQuery, receiver, messageID) } func (rq *ReactionQuery) DeleteAllByMessage(ctx context.Context, chat Key, messageID string) error { - _, err := rq.db.Conn(ctx).ExecContext(ctx, deleteReactionsByMessageIDQuery, chat.ID, chat.Receiver, messageID) - return err + return rq.Exec(ctx, deleteReactionsByMessageIDQuery, chat.ID, chat.Receiver, messageID) } type Reaction struct { - db *Database + qh *dbutil.QueryHelper[*Reaction] Chat Key MessageID string @@ -95,23 +86,16 @@ type Reaction struct { } func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) { - err := row.Scan(&r.Chat.ID, &r.Chat.Receiver, &r.MessageID, &r.Sender, &r.Reaction, &r.MXID) - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } else if err != nil { - return nil, err - } - return r, nil + return dbutil.ValueOrErr(r, row.Scan(&r.Chat.ID, &r.Chat.Receiver, &r.MessageID, &r.Sender, &r.Reaction, &r.MXID)) } func (r *Reaction) Insert(ctx context.Context) error { - _, err := r.db.Conn(ctx).ExecContext(ctx, insertReactionQuery, r.Chat.ID, r.Chat.Receiver, r.MessageID, r.Sender, r.Reaction, r.MXID) - return err + return r.qh.Exec(ctx, insertReactionQuery, r.Chat.ID, r.Chat.Receiver, r.MessageID, r.Sender, r.Reaction, r.MXID) } func (rq *ReactionQuery) MassInsert(ctx context.Context, reactions []*Reaction) error { valueStringFormat := "($1, $2, $%d, $%d, $%d, $%d)" - if rq.db.Dialect == dbutil.SQLite { + if rq.GetDB().Dialect == dbutil.SQLite { valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?") } placeholders := make([]string, len(reactions)) @@ -127,11 +111,9 @@ func (rq *ReactionQuery) MassInsert(ctx context.Context, reactions []*Reaction) placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4) } query := strings.Replace(insertReactionQuery, "($1, $2, $3, $4, $5, $6)", strings.Join(placeholders, ","), 1) - _, err := rq.db.Conn(ctx).ExecContext(ctx, query, params...) - return err + return rq.Exec(ctx, query, params...) } func (r *Reaction) Delete(ctx context.Context) error { - _, err := r.db.Conn(ctx).ExecContext(ctx, deleteReactionQuery, r.Chat.ID, r.Chat.Receiver, r.MessageID, r.Sender) - return err + return r.qh.Exec(ctx, deleteReactionQuery, r.Chat.ID, r.Chat.Receiver, r.MessageID, r.Sender) } diff --git a/database/user.go b/database/user.go index 14a2d1c..6e0294e 100644 --- a/database/user.go +++ b/database/user.go @@ -1,5 +1,5 @@ // mautrix-gmessages - A Matrix-Google Messages puppeting bridge. -// Copyright (C) 2023 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -20,7 +20,6 @@ import ( "context" "database/sql" "encoding/json" - "errors" "fmt" "sync" @@ -33,17 +32,11 @@ import ( ) type UserQuery struct { - db *Database + *dbutil.QueryHelper[*User] } -func (uq *UserQuery) New() *User { - return &User{ - db: uq.db, - } -} - -func (uq *UserQuery) getDB() *Database { - return uq.db +func newUser(qh *dbutil.QueryHelper[*User]) *User { + return &User{qh: qh} } const ( @@ -63,22 +56,23 @@ const ( management_room=$7, space_room=$8, access_token=$9 WHERE mxid=$1 ` + updateuserParticipantIDsQuery = `UPDATE "user" SET self_participant_ids=$2 WHERE mxid=$1` ) func (uq *UserQuery) GetAllWithSession(ctx context.Context) ([]*User, error) { - return getAll[*User](uq, ctx, getAllUsersWithSessionQuery) + return uq.QueryMany(ctx, getAllUsersWithSessionQuery) } func (uq *UserQuery) GetAllWithDoublePuppet(ctx context.Context) ([]*User, error) { - return getAll[*User](uq, ctx, getAllUsersWithDoublePuppetQuery) + return uq.QueryMany(ctx, getAllUsersWithDoublePuppetQuery) } func (uq *UserQuery) GetByRowID(ctx context.Context, rowID int) (*User, error) { - return get[*User](uq, ctx, getUserByRowIDQuery, rowID) + return uq.QueryOne(ctx, getUserByRowIDQuery, rowID) } func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, error) { - return get[*User](uq, ctx, getUserByMXIDQuery, userID) + return uq.QueryOne(ctx, getUserByMXIDQuery, userID) } type Settings struct { @@ -90,7 +84,7 @@ type Settings struct { } type User struct { - db *Database + qh *dbutil.QueryHelper[*User] RowID int MXID id.UserID @@ -114,10 +108,11 @@ type User struct { func (user *User) Scan(row dbutil.Scannable) (*User, error) { var phoneID, session, managementRoom, spaceRoom, accessToken sql.NullString var selfParticipantIDs, simMetadata, settings string - err := row.Scan(&user.RowID, &user.MXID, &phoneID, &session, &selfParticipantIDs, &simMetadata, &settings, &managementRoom, &spaceRoom, &accessToken) - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } else if err != nil { + err := row.Scan( + &user.RowID, &user.MXID, &phoneID, &session, &selfParticipantIDs, &simMetadata, + &settings, &managementRoom, &spaceRoom, &accessToken, + ) + if err != nil { return nil, err } if session.String != "" { @@ -152,24 +147,12 @@ func (user *User) Scan(row dbutil.Scannable) (*User, error) { } func (user *User) sqlVariables() []any { - var phoneID, session, managementRoom, spaceRoom, accessToken *string - if user.PhoneID != "" { - phoneID = &user.PhoneID - } + var session *string if user.Session != nil { data, _ := json.Marshal(user.Session) strData := string(data) session = &strData } - if user.ManagementRoom != "" { - managementRoom = (*string)(&user.ManagementRoom) - } - if user.SpaceRoom != "" { - spaceRoom = (*string)(&user.SpaceRoom) - } - if user.AccessToken != "" { - accessToken = &user.AccessToken - } user.selfParticipantIDsLock.RLock() selfParticipantIDs, _ := json.Marshal(user.SelfParticipantIDs) user.selfParticipantIDsLock.RUnlock() @@ -183,7 +166,10 @@ func (user *User) sqlVariables() []any { if err != nil { panic(err) } - return []any{user.MXID, phoneID, session, string(selfParticipantIDs), string(simMetadata), string(settings), managementRoom, spaceRoom, accessToken} + return []any{ + user.MXID, dbutil.StrPtr(user.PhoneID), session, string(selfParticipantIDs), string(simMetadata), + string(settings), dbutil.StrPtr(user.ManagementRoom), dbutil.StrPtr(user.SpaceRoom), dbutil.StrPtr(user.AccessToken), + } } func (user *User) IsSelfParticipantID(id string) bool { @@ -268,20 +254,18 @@ func (user *User) AddSelfParticipantID(ctx context.Context, id string) error { if !slices.Contains(user.SelfParticipantIDs, id) { user.SelfParticipantIDs = append(user.SelfParticipantIDs, id) selfParticipantIDs, _ := json.Marshal(user.SelfParticipantIDs) - _, err := user.db.Conn(ctx).ExecContext(ctx, `UPDATE "user" SET self_participant_ids=$2 WHERE mxid=$1`, user.MXID, selfParticipantIDs) - return err + return user.qh.Exec(ctx, updateuserParticipantIDsQuery, user.MXID, selfParticipantIDs) } return nil } func (user *User) Insert(ctx context.Context) error { - err := user.db.Conn(ctx). - QueryRowContext(ctx, insertUserQuery, user.sqlVariables()...). + err := user.qh.GetDB(). + QueryRow(ctx, insertUserQuery, user.sqlVariables()...). Scan(&user.RowID) return err } func (user *User) Update(ctx context.Context) error { - _, err := user.db.Conn(ctx).ExecContext(ctx, updateUserQuery, user.sqlVariables()...) - return err + return user.qh.Exec(ctx, updateUserQuery, user.sqlVariables()...) }