Use dbutil.QueryHelper for database stuff

This commit is contained in:
Tulir Asokan 2024-02-25 00:18:06 +02:00
parent 6c4d8d8744
commit aa7c66496f
6 changed files with 102 additions and 211 deletions

View file

@ -1,5 +1,5 @@
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge. // mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
// Copyright (C) 2023 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -17,8 +17,6 @@
package database package database
import ( import (
"context"
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"go.mau.fi/util/dbutil" "go.mau.fi/util/dbutil"
@ -36,45 +34,13 @@ type Database struct {
Reaction *ReactionQuery Reaction *ReactionQuery
} }
func New(baseDB *dbutil.Database) *Database { func New(db *dbutil.Database) *Database {
db := &Database{Database: baseDB}
db.UpgradeTable = upgrades.Table db.UpgradeTable = upgrades.Table
db.User = &UserQuery{db: db} return &Database{
db.Portal = &PortalQuery{db: db} User: &UserQuery{dbutil.MakeQueryHelper(db, newUser)},
db.Puppet = &PuppetQuery{db: db} Portal: &PortalQuery{dbutil.MakeQueryHelper(db, newPortal)},
db.Message = &MessageQuery{db: db} Puppet: &PuppetQuery{dbutil.MakeQueryHelper(db, newPuppet)},
db.Reaction = &ReactionQuery{db: db} Message: &MessageQuery{dbutil.MakeQueryHelper(db, newMessage)},
return db Reaction: &ReactionQuery{dbutil.MakeQueryHelper(db, newReaction)},
}
type dataStruct[T any] interface {
Scan(row dbutil.Scannable) (T, error)
}
type queryStruct[T dataStruct[T]] interface {
New() T
getDB() *Database
}
func get[T dataStruct[T]](qs queryStruct[T], ctx context.Context, query string, args ...any) (T, error) {
return qs.New().Scan(qs.getDB().Conn(ctx).QueryRowContext(ctx, query, args...))
}
func getAll[T dataStruct[T]](qs queryStruct[T], ctx context.Context, query string, args ...any) ([]T, error) {
rows, err := qs.getDB().Conn(ctx).QueryContext(ctx, query, args...)
if err != nil {
return nil, err
} }
items := make([]T, 0)
defer func() {
_ = rows.Close()
}()
for rows.Next() {
item, err := qs.New().Scan(rows)
if err != nil {
return nil, err
}
items = append(items, item)
}
return items, rows.Err()
} }

View file

@ -1,5 +1,5 @@
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge. // mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
// Copyright (C) 2023 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -18,8 +18,6 @@ package database
import ( import (
"context" "context"
"database/sql"
"errors"
"fmt" "fmt"
"strings" "strings"
"time" "time"
@ -31,17 +29,11 @@ import (
) )
type MessageQuery struct { type MessageQuery struct {
db *Database *dbutil.QueryHelper[*Message]
} }
func (mq *MessageQuery) New() *Message { func newMessage(qh *dbutil.QueryHelper[*Message]) *Message {
return &Message{ return &Message{qh: qh}
db: mq.db,
}
}
func (mq *MessageQuery) getDB() *Database {
return mq.db
} }
const ( const (
@ -84,24 +76,23 @@ const (
) )
func (mq *MessageQuery) GetByID(ctx context.Context, receiver int, messageID string) (*Message, error) { func (mq *MessageQuery) GetByID(ctx context.Context, receiver int, messageID string) (*Message, error) {
return get[*Message](mq, ctx, getMessageByIDQuery, receiver, messageID) return mq.QueryOne(ctx, getMessageByIDQuery, receiver, messageID)
} }
func (mq *MessageQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Message, error) { func (mq *MessageQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Message, error) {
return get[*Message](mq, ctx, getMessageByMXIDQuery, mxid) return mq.QueryOne(ctx, getMessageByMXIDQuery, mxid)
} }
func (mq *MessageQuery) GetLastInChat(ctx context.Context, chat Key) (*Message, error) { func (mq *MessageQuery) GetLastInChat(ctx context.Context, chat Key) (*Message, error) {
return get[*Message](mq, ctx, getLastMessageInChatQuery, chat.ID, chat.Receiver) return mq.QueryOne(ctx, getLastMessageInChatQuery, chat.ID, chat.Receiver)
} }
func (mq *MessageQuery) GetLastInChatWithMXID(ctx context.Context, chat Key) (*Message, error) { func (mq *MessageQuery) GetLastInChatWithMXID(ctx context.Context, chat Key) (*Message, error) {
return get[*Message](mq, ctx, getLastMessageInChatWithMXIDQuery, chat.ID, chat.Receiver) return mq.QueryOne(ctx, getLastMessageInChatWithMXIDQuery, chat.ID, chat.Receiver)
} }
func (mq *MessageQuery) DeleteAllInChat(ctx context.Context, chat Key) error { func (mq *MessageQuery) DeleteAllInChat(ctx context.Context, chat Key) error {
_, err := mq.db.Conn(ctx).ExecContext(ctx, deleteAllMessagesInChatQuery, chat.ID, chat.Receiver) return mq.Exec(ctx, deleteAllMessagesInChatQuery, chat.ID, chat.Receiver)
return err
} }
type MediaPart struct { type MediaPart struct {
@ -132,7 +123,7 @@ func (ms *MessageStatus) HasPendingMediaParts() bool {
} }
type Message struct { type Message struct {
db *Database qh *dbutil.QueryHelper[*Message]
Chat Key Chat Key
ID string ID string
@ -146,9 +137,7 @@ type Message struct {
func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) { func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) {
var ts int64 var ts int64
err := row.Scan(&msg.Chat.ID, &msg.Chat.Receiver, &msg.ID, &msg.MXID, &msg.RoomID, &msg.Sender, &ts, dbutil.JSON{Data: &msg.Status}) err := row.Scan(&msg.Chat.ID, &msg.Chat.Receiver, &msg.ID, &msg.MXID, &msg.RoomID, &msg.Sender, &ts, dbutil.JSON{Data: &msg.Status})
if errors.Is(err, sql.ErrNoRows) { if err != nil {
return nil, nil
} else if err != nil {
return nil, err return nil, err
} }
if ts != 0 { if ts != 0 {
@ -162,13 +151,12 @@ func (msg *Message) sqlVariables() []any {
} }
func (msg *Message) Insert(ctx context.Context) error { func (msg *Message) Insert(ctx context.Context) error {
_, err := msg.db.Conn(ctx).ExecContext(ctx, insertMessageQuery, msg.sqlVariables()...) return msg.qh.Exec(ctx, insertMessageQuery, msg.sqlVariables()...)
return err
} }
func (mq *MessageQuery) MassInsert(ctx context.Context, messages []*Message) error { func (mq *MessageQuery) MassInsert(ctx context.Context, messages []*Message) error {
valueStringFormat := "($1, $2, $%d, $%d, $3, $%d, $%d, $%d)" valueStringFormat := "($1, $2, $%d, $%d, $3, $%d, $%d, $%d)"
if mq.db.Dialect == dbutil.SQLite { if mq.GetDB().Dialect == dbutil.SQLite {
valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?") valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
} }
placeholders := make([]string, len(messages)) placeholders := make([]string, len(messages))
@ -186,23 +174,19 @@ func (mq *MessageQuery) MassInsert(ctx context.Context, messages []*Message) err
placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5) placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5)
} }
query := massInsertMessageQueryPrefix + strings.Join(placeholders, ",") query := massInsertMessageQueryPrefix + strings.Join(placeholders, ",")
_, err := mq.db.Conn(ctx).ExecContext(ctx, query, params...) return mq.Exec(ctx, query, params...)
return err
} }
func (msg *Message) Update(ctx context.Context) error { func (msg *Message) Update(ctx context.Context) error {
_, err := msg.db.Conn(ctx).ExecContext(ctx, updateMessageQuery, msg.sqlVariables()...) return msg.qh.Exec(ctx, updateMessageQuery, msg.sqlVariables()...)
return err
} }
func (msg *Message) UpdateStatus(ctx context.Context) error { func (msg *Message) UpdateStatus(ctx context.Context) error {
_, err := msg.db.Conn(ctx).ExecContext(ctx, updateMessageStatusQuery, dbutil.JSON{Data: &msg.Status}, msg.Timestamp.UnixMicro(), msg.Chat.Receiver, msg.ID) return msg.qh.Exec(ctx, updateMessageStatusQuery, dbutil.JSON{Data: &msg.Status}, msg.Timestamp.UnixMicro(), msg.Chat.Receiver, msg.ID)
return err
} }
func (msg *Message) Delete(ctx context.Context) error { func (msg *Message) Delete(ctx context.Context) error {
_, err := msg.db.Conn(ctx).ExecContext(ctx, deleteMessageQuery, msg.Chat.ID, msg.Chat.Receiver, msg.ID) return msg.qh.Exec(ctx, deleteMessageQuery, msg.Chat.ID, msg.Chat.Receiver, msg.ID)
return err
} }
func (msg *Message) IsFakeMXID() bool { func (msg *Message) IsFakeMXID() bool {

View file

@ -1,5 +1,5 @@
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge. // mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
// Copyright (C) 2023 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -19,7 +19,6 @@ package database
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -30,17 +29,11 @@ import (
) )
type PortalQuery struct { type PortalQuery struct {
db *Database *dbutil.QueryHelper[*Portal]
} }
func (pq *PortalQuery) New() *Portal { func newPortal(qh *dbutil.QueryHelper[*Portal]) *Portal {
return &Portal{ return &Portal{qh: qh}
db: pq.db,
}
}
func (pq *PortalQuery) getDB() *Database {
return pq.db
} }
const ( const (
@ -62,23 +55,23 @@ const (
) )
func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) { func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) {
return getAll[*Portal](pq, ctx, getAllPortalsQuery) return pq.QueryMany(ctx, getAllPortalsQuery)
} }
func (pq *PortalQuery) GetAllForUser(ctx context.Context, receiver int) ([]*Portal, error) { func (pq *PortalQuery) GetAllForUser(ctx context.Context, receiver int) ([]*Portal, error) {
return getAll[*Portal](pq, ctx, getAllPortalsForUserQuery, receiver) return pq.QueryMany(ctx, getAllPortalsForUserQuery, receiver)
} }
func (pq *PortalQuery) GetByKey(ctx context.Context, key Key) (*Portal, error) { func (pq *PortalQuery) GetByKey(ctx context.Context, key Key) (*Portal, error) {
return get[*Portal](pq, ctx, getPortalByKeyQuery, key.ID, key.Receiver) return pq.QueryOne(ctx, getPortalByKeyQuery, key.ID, key.Receiver)
} }
func (pq *PortalQuery) GetByOtherUser(ctx context.Context, key Key) (*Portal, error) { func (pq *PortalQuery) GetByOtherUser(ctx context.Context, key Key) (*Portal, error) {
return get[*Portal](pq, ctx, getPortalByOtherUserQuery, key.ID, key.Receiver) return pq.QueryOne(ctx, getPortalByOtherUserQuery, key.ID, key.Receiver)
} }
func (pq *PortalQuery) GetByMXID(ctx context.Context, mxid id.RoomID) (*Portal, error) { func (pq *PortalQuery) GetByMXID(ctx context.Context, mxid id.RoomID) (*Portal, error) {
return get[*Portal](pq, ctx, getPortalByMXIDQuery, mxid) return pq.QueryOne(ctx, getPortalByMXIDQuery, mxid)
} }
type Key struct { type Key struct {
@ -95,7 +88,7 @@ func (p Key) MarshalZerologObject(e *zerolog.Event) {
} }
type Portal struct { type Portal struct {
db *Database qh *dbutil.QueryHelper[*Portal]
Key Key
OutgoingID string OutgoingID string
@ -112,10 +105,11 @@ type Portal struct {
func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) { func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
var mxid, selfUserID, otherUserID sql.NullString var mxid, selfUserID, otherUserID sql.NullString
var convType int var convType int
err := row.Scan(&portal.ID, &portal.Receiver, &selfUserID, &otherUserID, &convType, &mxid, &portal.Name, &portal.NameSet, &portal.Encrypted, &portal.InSpace) err := row.Scan(
if errors.Is(err, sql.ErrNoRows) { &portal.ID, &portal.Receiver, &selfUserID, &otherUserID, &convType, &mxid,
return nil, nil &portal.Name, &portal.NameSet, &portal.Encrypted, &portal.InSpace,
} else if err != nil { )
if err != nil {
return nil, err return nil, err
} }
portal.Type = gmproto.ConversationType(convType) portal.Type = gmproto.ConversationType(convType)
@ -126,33 +120,21 @@ func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
} }
func (portal *Portal) sqlVariables() []any { func (portal *Portal) sqlVariables() []any {
var mxid, selfUserID, otherUserID *string
if portal.MXID != "" {
mxid = (*string)(&portal.MXID)
}
if portal.OutgoingID != "" {
selfUserID = &portal.OutgoingID
}
if portal.OtherUserID != "" {
otherUserID = &portal.OtherUserID
}
return []any{ return []any{
portal.ID, portal.Receiver, selfUserID, otherUserID, int(portal.Type), mxid, portal.Name, portal.NameSet, portal.ID, portal.Receiver, dbutil.StrPtr(portal.OutgoingID), dbutil.StrPtr(portal.OtherUserID),
portal.Encrypted, portal.InSpace, int(portal.Type), dbutil.StrPtr(portal.MXID),
portal.Name, portal.NameSet, portal.Encrypted, portal.InSpace,
} }
} }
func (portal *Portal) Insert(ctx context.Context) error { func (portal *Portal) Insert(ctx context.Context) error {
_, err := portal.db.Conn(ctx).ExecContext(ctx, insertPortalQuery, portal.sqlVariables()...) return portal.qh.Exec(ctx, insertPortalQuery, portal.sqlVariables()...)
return err
} }
func (portal *Portal) Update(ctx context.Context) error { func (portal *Portal) Update(ctx context.Context) error {
_, err := portal.db.Conn(ctx).ExecContext(ctx, updatePortalQuery, portal.sqlVariables()...) return portal.qh.Exec(ctx, updatePortalQuery, portal.sqlVariables()...)
return err
} }
func (portal *Portal) Delete(ctx context.Context) error { func (portal *Portal) Delete(ctx context.Context) error {
_, err := portal.db.Conn(ctx).ExecContext(ctx, deletePortalQuery, portal.ID, portal.Receiver) return portal.qh.Exec(ctx, deletePortalQuery, portal.ID, portal.Receiver)
return err
} }

View file

@ -1,5 +1,5 @@
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge. // mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
// Copyright (C) 2023 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -18,8 +18,6 @@ package database
import ( import (
"context" "context"
"database/sql"
"errors"
"time" "time"
"go.mau.fi/util/dbutil" "go.mau.fi/util/dbutil"
@ -27,13 +25,11 @@ import (
) )
type PuppetQuery struct { type PuppetQuery struct {
db *Database *dbutil.QueryHelper[*Puppet]
} }
func (pq *PuppetQuery) New() *Puppet { func newPuppet(qh *dbutil.QueryHelper[*Puppet]) *Puppet {
return &Puppet{ return &Puppet{qh: qh}
db: pq.db,
}
} }
const ( const (
@ -50,21 +46,16 @@ const (
` `
) )
func (pq *PuppetQuery) getDB() *Database {
return pq.db
}
func (pq *PuppetQuery) DeleteAllForUser(ctx context.Context, userID int) error { func (pq *PuppetQuery) DeleteAllForUser(ctx context.Context, userID int) error {
_, err := pq.db.Conn(ctx).ExecContext(ctx, deleteAllPuppetsForUserQuery, userID) return pq.Exec(ctx, deleteAllPuppetsForUserQuery, userID)
return err
} }
func (pq *PuppetQuery) Get(ctx context.Context, key Key) (*Puppet, error) { func (pq *PuppetQuery) Get(ctx context.Context, key Key) (*Puppet, error) {
return get[*Puppet](pq, ctx, getPuppetQuery, key.ID, key.Receiver) return pq.QueryOne(ctx, getPuppetQuery, key.ID, key.Receiver)
} }
type Puppet struct { type Puppet struct {
db *Database qh *dbutil.QueryHelper[*Puppet]
Key Key
Phone string Phone string
@ -81,10 +72,11 @@ type Puppet struct {
func (puppet *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) { func (puppet *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) {
var avatarHash []byte var avatarHash []byte
var avatarUpdateTS int64 var avatarUpdateTS int64
err := row.Scan(&puppet.ID, &puppet.Receiver, &puppet.Phone, &puppet.ContactID, &puppet.Name, &puppet.NameSet, &avatarHash, &puppet.AvatarMXC, &puppet.AvatarSet, &avatarUpdateTS, &puppet.ContactInfoSet) err := row.Scan(
if errors.Is(err, sql.ErrNoRows) { &puppet.ID, &puppet.Receiver, &puppet.Phone, &puppet.ContactID, &puppet.Name, &puppet.NameSet,
return nil, nil &avatarHash, &puppet.AvatarMXC, &puppet.AvatarSet, &avatarUpdateTS, &puppet.ContactInfoSet,
} else if err != nil { )
if err != nil {
return nil, err return nil, err
} }
if len(avatarHash) == 32 { if len(avatarHash) == 32 {
@ -95,15 +87,16 @@ func (puppet *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) {
} }
func (puppet *Puppet) sqlVariables() []any { func (puppet *Puppet) sqlVariables() []any {
return []any{puppet.ID, puppet.Receiver, puppet.Phone, puppet.ContactID, puppet.Name, puppet.NameSet, puppet.AvatarHash[:], &puppet.AvatarMXC, puppet.AvatarSet, puppet.AvatarUpdateTS.UnixMilli(), puppet.ContactInfoSet} return []any{
puppet.ID, puppet.Receiver, puppet.Phone, puppet.ContactID, puppet.Name, puppet.NameSet,
puppet.AvatarHash[:], &puppet.AvatarMXC, puppet.AvatarSet, puppet.AvatarUpdateTS.UnixMilli(), puppet.ContactInfoSet,
}
} }
func (puppet *Puppet) Insert(ctx context.Context) error { func (puppet *Puppet) Insert(ctx context.Context) error {
_, err := puppet.db.Conn(ctx).ExecContext(ctx, insertPuppetQuery, puppet.sqlVariables()...) return puppet.qh.Exec(ctx, insertPuppetQuery, puppet.sqlVariables()...)
return err
} }
func (puppet *Puppet) Update(ctx context.Context) error { func (puppet *Puppet) Update(ctx context.Context) error {
_, err := puppet.db.Conn(ctx).ExecContext(ctx, updatePuppetQuery, puppet.sqlVariables()...) return puppet.qh.Exec(ctx, updatePuppetQuery, puppet.sqlVariables()...)
return err
} }

View file

@ -1,5 +1,5 @@
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge. // mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
// Copyright (C) 2023 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -18,8 +18,6 @@ package database
import ( import (
"context" "context"
"database/sql"
"errors"
"fmt" "fmt"
"strings" "strings"
@ -28,17 +26,11 @@ import (
) )
type ReactionQuery struct { type ReactionQuery struct {
db *Database *dbutil.QueryHelper[*Reaction]
} }
func (rq *ReactionQuery) New() *Reaction { func newReaction(qh *dbutil.QueryHelper[*Reaction]) *Reaction {
return &Reaction{ return &Reaction{qh: qh}
db: rq.db,
}
}
func (rq *ReactionQuery) getDB() *Database {
return rq.db
} }
const ( const (
@ -68,24 +60,23 @@ const (
) )
func (rq *ReactionQuery) GetByID(ctx context.Context, receiver int, messageID, sender string) (*Reaction, error) { func (rq *ReactionQuery) GetByID(ctx context.Context, receiver int, messageID, sender string) (*Reaction, error) {
return get[*Reaction](rq, ctx, getReactionByIDQuery, receiver, messageID, sender) return rq.QueryOne(ctx, getReactionByIDQuery, receiver, messageID, sender)
} }
func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) { func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) {
return get[*Reaction](rq, ctx, getReactionByMXIDQuery, mxid) return rq.QueryOne(ctx, getReactionByMXIDQuery, mxid)
} }
func (rq *ReactionQuery) GetAllByMessage(ctx context.Context, receiver int, messageID string) ([]*Reaction, error) { func (rq *ReactionQuery) GetAllByMessage(ctx context.Context, receiver int, messageID string) ([]*Reaction, error) {
return getAll[*Reaction](rq, ctx, getReactionsByMessageIDQuery, receiver, messageID) return rq.QueryMany(ctx, getReactionsByMessageIDQuery, receiver, messageID)
} }
func (rq *ReactionQuery) DeleteAllByMessage(ctx context.Context, chat Key, messageID string) error { func (rq *ReactionQuery) DeleteAllByMessage(ctx context.Context, chat Key, messageID string) error {
_, err := rq.db.Conn(ctx).ExecContext(ctx, deleteReactionsByMessageIDQuery, chat.ID, chat.Receiver, messageID) return rq.Exec(ctx, deleteReactionsByMessageIDQuery, chat.ID, chat.Receiver, messageID)
return err
} }
type Reaction struct { type Reaction struct {
db *Database qh *dbutil.QueryHelper[*Reaction]
Chat Key Chat Key
MessageID string MessageID string
@ -95,23 +86,16 @@ type Reaction struct {
} }
func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) { func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) {
err := row.Scan(&r.Chat.ID, &r.Chat.Receiver, &r.MessageID, &r.Sender, &r.Reaction, &r.MXID) return dbutil.ValueOrErr(r, row.Scan(&r.Chat.ID, &r.Chat.Receiver, &r.MessageID, &r.Sender, &r.Reaction, &r.MXID))
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
}
return r, nil
} }
func (r *Reaction) Insert(ctx context.Context) error { func (r *Reaction) Insert(ctx context.Context) error {
_, err := r.db.Conn(ctx).ExecContext(ctx, insertReactionQuery, r.Chat.ID, r.Chat.Receiver, r.MessageID, r.Sender, r.Reaction, r.MXID) return r.qh.Exec(ctx, insertReactionQuery, r.Chat.ID, r.Chat.Receiver, r.MessageID, r.Sender, r.Reaction, r.MXID)
return err
} }
func (rq *ReactionQuery) MassInsert(ctx context.Context, reactions []*Reaction) error { func (rq *ReactionQuery) MassInsert(ctx context.Context, reactions []*Reaction) error {
valueStringFormat := "($1, $2, $%d, $%d, $%d, $%d)" valueStringFormat := "($1, $2, $%d, $%d, $%d, $%d)"
if rq.db.Dialect == dbutil.SQLite { if rq.GetDB().Dialect == dbutil.SQLite {
valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?") valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
} }
placeholders := make([]string, len(reactions)) placeholders := make([]string, len(reactions))
@ -127,11 +111,9 @@ func (rq *ReactionQuery) MassInsert(ctx context.Context, reactions []*Reaction)
placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4) placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4)
} }
query := strings.Replace(insertReactionQuery, "($1, $2, $3, $4, $5, $6)", strings.Join(placeholders, ","), 1) query := strings.Replace(insertReactionQuery, "($1, $2, $3, $4, $5, $6)", strings.Join(placeholders, ","), 1)
_, err := rq.db.Conn(ctx).ExecContext(ctx, query, params...) return rq.Exec(ctx, query, params...)
return err
} }
func (r *Reaction) Delete(ctx context.Context) error { func (r *Reaction) Delete(ctx context.Context) error {
_, err := r.db.Conn(ctx).ExecContext(ctx, deleteReactionQuery, r.Chat.ID, r.Chat.Receiver, r.MessageID, r.Sender) return r.qh.Exec(ctx, deleteReactionQuery, r.Chat.ID, r.Chat.Receiver, r.MessageID, r.Sender)
return err
} }

View file

@ -1,5 +1,5 @@
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge. // mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
// Copyright (C) 2023 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -20,7 +20,6 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"sync" "sync"
@ -33,17 +32,11 @@ import (
) )
type UserQuery struct { type UserQuery struct {
db *Database *dbutil.QueryHelper[*User]
} }
func (uq *UserQuery) New() *User { func newUser(qh *dbutil.QueryHelper[*User]) *User {
return &User{ return &User{qh: qh}
db: uq.db,
}
}
func (uq *UserQuery) getDB() *Database {
return uq.db
} }
const ( const (
@ -63,22 +56,23 @@ const (
management_room=$7, space_room=$8, access_token=$9 management_room=$7, space_room=$8, access_token=$9
WHERE mxid=$1 WHERE mxid=$1
` `
updateuserParticipantIDsQuery = `UPDATE "user" SET self_participant_ids=$2 WHERE mxid=$1`
) )
func (uq *UserQuery) GetAllWithSession(ctx context.Context) ([]*User, error) { func (uq *UserQuery) GetAllWithSession(ctx context.Context) ([]*User, error) {
return getAll[*User](uq, ctx, getAllUsersWithSessionQuery) return uq.QueryMany(ctx, getAllUsersWithSessionQuery)
} }
func (uq *UserQuery) GetAllWithDoublePuppet(ctx context.Context) ([]*User, error) { func (uq *UserQuery) GetAllWithDoublePuppet(ctx context.Context) ([]*User, error) {
return getAll[*User](uq, ctx, getAllUsersWithDoublePuppetQuery) return uq.QueryMany(ctx, getAllUsersWithDoublePuppetQuery)
} }
func (uq *UserQuery) GetByRowID(ctx context.Context, rowID int) (*User, error) { func (uq *UserQuery) GetByRowID(ctx context.Context, rowID int) (*User, error) {
return get[*User](uq, ctx, getUserByRowIDQuery, rowID) return uq.QueryOne(ctx, getUserByRowIDQuery, rowID)
} }
func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, error) { func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, error) {
return get[*User](uq, ctx, getUserByMXIDQuery, userID) return uq.QueryOne(ctx, getUserByMXIDQuery, userID)
} }
type Settings struct { type Settings struct {
@ -90,7 +84,7 @@ type Settings struct {
} }
type User struct { type User struct {
db *Database qh *dbutil.QueryHelper[*User]
RowID int RowID int
MXID id.UserID MXID id.UserID
@ -114,10 +108,11 @@ type User struct {
func (user *User) Scan(row dbutil.Scannable) (*User, error) { func (user *User) Scan(row dbutil.Scannable) (*User, error) {
var phoneID, session, managementRoom, spaceRoom, accessToken sql.NullString var phoneID, session, managementRoom, spaceRoom, accessToken sql.NullString
var selfParticipantIDs, simMetadata, settings string var selfParticipantIDs, simMetadata, settings string
err := row.Scan(&user.RowID, &user.MXID, &phoneID, &session, &selfParticipantIDs, &simMetadata, &settings, &managementRoom, &spaceRoom, &accessToken) err := row.Scan(
if errors.Is(err, sql.ErrNoRows) { &user.RowID, &user.MXID, &phoneID, &session, &selfParticipantIDs, &simMetadata,
return nil, nil &settings, &managementRoom, &spaceRoom, &accessToken,
} else if err != nil { )
if err != nil {
return nil, err return nil, err
} }
if session.String != "" { if session.String != "" {
@ -152,24 +147,12 @@ func (user *User) Scan(row dbutil.Scannable) (*User, error) {
} }
func (user *User) sqlVariables() []any { func (user *User) sqlVariables() []any {
var phoneID, session, managementRoom, spaceRoom, accessToken *string var session *string
if user.PhoneID != "" {
phoneID = &user.PhoneID
}
if user.Session != nil { if user.Session != nil {
data, _ := json.Marshal(user.Session) data, _ := json.Marshal(user.Session)
strData := string(data) strData := string(data)
session = &strData session = &strData
} }
if user.ManagementRoom != "" {
managementRoom = (*string)(&user.ManagementRoom)
}
if user.SpaceRoom != "" {
spaceRoom = (*string)(&user.SpaceRoom)
}
if user.AccessToken != "" {
accessToken = &user.AccessToken
}
user.selfParticipantIDsLock.RLock() user.selfParticipantIDsLock.RLock()
selfParticipantIDs, _ := json.Marshal(user.SelfParticipantIDs) selfParticipantIDs, _ := json.Marshal(user.SelfParticipantIDs)
user.selfParticipantIDsLock.RUnlock() user.selfParticipantIDsLock.RUnlock()
@ -183,7 +166,10 @@ func (user *User) sqlVariables() []any {
if err != nil { if err != nil {
panic(err) panic(err)
} }
return []any{user.MXID, phoneID, session, string(selfParticipantIDs), string(simMetadata), string(settings), managementRoom, spaceRoom, accessToken} return []any{
user.MXID, dbutil.StrPtr(user.PhoneID), session, string(selfParticipantIDs), string(simMetadata),
string(settings), dbutil.StrPtr(user.ManagementRoom), dbutil.StrPtr(user.SpaceRoom), dbutil.StrPtr(user.AccessToken),
}
} }
func (user *User) IsSelfParticipantID(id string) bool { func (user *User) IsSelfParticipantID(id string) bool {
@ -268,20 +254,18 @@ func (user *User) AddSelfParticipantID(ctx context.Context, id string) error {
if !slices.Contains(user.SelfParticipantIDs, id) { if !slices.Contains(user.SelfParticipantIDs, id) {
user.SelfParticipantIDs = append(user.SelfParticipantIDs, id) user.SelfParticipantIDs = append(user.SelfParticipantIDs, id)
selfParticipantIDs, _ := json.Marshal(user.SelfParticipantIDs) selfParticipantIDs, _ := json.Marshal(user.SelfParticipantIDs)
_, err := user.db.Conn(ctx).ExecContext(ctx, `UPDATE "user" SET self_participant_ids=$2 WHERE mxid=$1`, user.MXID, selfParticipantIDs) return user.qh.Exec(ctx, updateuserParticipantIDsQuery, user.MXID, selfParticipantIDs)
return err
} }
return nil return nil
} }
func (user *User) Insert(ctx context.Context) error { func (user *User) Insert(ctx context.Context) error {
err := user.db.Conn(ctx). err := user.qh.GetDB().
QueryRowContext(ctx, insertUserQuery, user.sqlVariables()...). QueryRow(ctx, insertUserQuery, user.sqlVariables()...).
Scan(&user.RowID) Scan(&user.RowID)
return err return err
} }
func (user *User) Update(ctx context.Context) error { func (user *User) Update(ctx context.Context) error {
_, err := user.db.Conn(ctx).ExecContext(ctx, updateUserQuery, user.sqlVariables()...) return user.qh.Exec(ctx, updateUserQuery, user.sqlVariables()...)
return err
} }