Use dbutil.QueryHelper for database stuff
This commit is contained in:
parent
6c4d8d8744
commit
aa7c66496f
6 changed files with 102 additions and 211 deletions
|
@ -1,5 +1,5 @@
|
||||||
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
|
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
|
||||||
// Copyright (C) 2023 Tulir Asokan
|
// Copyright (C) 2024 Tulir Asokan
|
||||||
//
|
//
|
||||||
// This program is free software: you can redistribute it and/or modify
|
// This program is free software: you can redistribute it and/or modify
|
||||||
// it under the terms of the GNU Affero General Public License as published by
|
// it under the terms of the GNU Affero General Public License as published by
|
||||||
|
@ -17,8 +17,6 @@
|
||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
|
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
"go.mau.fi/util/dbutil"
|
"go.mau.fi/util/dbutil"
|
||||||
|
@ -36,45 +34,13 @@ type Database struct {
|
||||||
Reaction *ReactionQuery
|
Reaction *ReactionQuery
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(baseDB *dbutil.Database) *Database {
|
func New(db *dbutil.Database) *Database {
|
||||||
db := &Database{Database: baseDB}
|
|
||||||
db.UpgradeTable = upgrades.Table
|
db.UpgradeTable = upgrades.Table
|
||||||
db.User = &UserQuery{db: db}
|
return &Database{
|
||||||
db.Portal = &PortalQuery{db: db}
|
User: &UserQuery{dbutil.MakeQueryHelper(db, newUser)},
|
||||||
db.Puppet = &PuppetQuery{db: db}
|
Portal: &PortalQuery{dbutil.MakeQueryHelper(db, newPortal)},
|
||||||
db.Message = &MessageQuery{db: db}
|
Puppet: &PuppetQuery{dbutil.MakeQueryHelper(db, newPuppet)},
|
||||||
db.Reaction = &ReactionQuery{db: db}
|
Message: &MessageQuery{dbutil.MakeQueryHelper(db, newMessage)},
|
||||||
return db
|
Reaction: &ReactionQuery{dbutil.MakeQueryHelper(db, newReaction)},
|
||||||
}
|
|
||||||
|
|
||||||
type dataStruct[T any] interface {
|
|
||||||
Scan(row dbutil.Scannable) (T, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type queryStruct[T dataStruct[T]] interface {
|
|
||||||
New() T
|
|
||||||
getDB() *Database
|
|
||||||
}
|
|
||||||
|
|
||||||
func get[T dataStruct[T]](qs queryStruct[T], ctx context.Context, query string, args ...any) (T, error) {
|
|
||||||
return qs.New().Scan(qs.getDB().Conn(ctx).QueryRowContext(ctx, query, args...))
|
|
||||||
}
|
|
||||||
|
|
||||||
func getAll[T dataStruct[T]](qs queryStruct[T], ctx context.Context, query string, args ...any) ([]T, error) {
|
|
||||||
rows, err := qs.getDB().Conn(ctx).QueryContext(ctx, query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
items := make([]T, 0)
|
|
||||||
defer func() {
|
|
||||||
_ = rows.Close()
|
|
||||||
}()
|
|
||||||
for rows.Next() {
|
|
||||||
item, err := qs.New().Scan(rows)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
items = append(items, item)
|
|
||||||
}
|
|
||||||
return items, rows.Err()
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
|
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
|
||||||
// Copyright (C) 2023 Tulir Asokan
|
// Copyright (C) 2024 Tulir Asokan
|
||||||
//
|
//
|
||||||
// This program is free software: you can redistribute it and/or modify
|
// This program is free software: you can redistribute it and/or modify
|
||||||
// it under the terms of the GNU Affero General Public License as published by
|
// it under the terms of the GNU Affero General Public License as published by
|
||||||
|
@ -18,8 +18,6 @@ package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -31,17 +29,11 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type MessageQuery struct {
|
type MessageQuery struct {
|
||||||
db *Database
|
*dbutil.QueryHelper[*Message]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mq *MessageQuery) New() *Message {
|
func newMessage(qh *dbutil.QueryHelper[*Message]) *Message {
|
||||||
return &Message{
|
return &Message{qh: qh}
|
||||||
db: mq.db,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mq *MessageQuery) getDB() *Database {
|
|
||||||
return mq.db
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -84,24 +76,23 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (mq *MessageQuery) GetByID(ctx context.Context, receiver int, messageID string) (*Message, error) {
|
func (mq *MessageQuery) GetByID(ctx context.Context, receiver int, messageID string) (*Message, error) {
|
||||||
return get[*Message](mq, ctx, getMessageByIDQuery, receiver, messageID)
|
return mq.QueryOne(ctx, getMessageByIDQuery, receiver, messageID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mq *MessageQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Message, error) {
|
func (mq *MessageQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Message, error) {
|
||||||
return get[*Message](mq, ctx, getMessageByMXIDQuery, mxid)
|
return mq.QueryOne(ctx, getMessageByMXIDQuery, mxid)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mq *MessageQuery) GetLastInChat(ctx context.Context, chat Key) (*Message, error) {
|
func (mq *MessageQuery) GetLastInChat(ctx context.Context, chat Key) (*Message, error) {
|
||||||
return get[*Message](mq, ctx, getLastMessageInChatQuery, chat.ID, chat.Receiver)
|
return mq.QueryOne(ctx, getLastMessageInChatQuery, chat.ID, chat.Receiver)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mq *MessageQuery) GetLastInChatWithMXID(ctx context.Context, chat Key) (*Message, error) {
|
func (mq *MessageQuery) GetLastInChatWithMXID(ctx context.Context, chat Key) (*Message, error) {
|
||||||
return get[*Message](mq, ctx, getLastMessageInChatWithMXIDQuery, chat.ID, chat.Receiver)
|
return mq.QueryOne(ctx, getLastMessageInChatWithMXIDQuery, chat.ID, chat.Receiver)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mq *MessageQuery) DeleteAllInChat(ctx context.Context, chat Key) error {
|
func (mq *MessageQuery) DeleteAllInChat(ctx context.Context, chat Key) error {
|
||||||
_, err := mq.db.Conn(ctx).ExecContext(ctx, deleteAllMessagesInChatQuery, chat.ID, chat.Receiver)
|
return mq.Exec(ctx, deleteAllMessagesInChatQuery, chat.ID, chat.Receiver)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type MediaPart struct {
|
type MediaPart struct {
|
||||||
|
@ -132,7 +123,7 @@ func (ms *MessageStatus) HasPendingMediaParts() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
db *Database
|
qh *dbutil.QueryHelper[*Message]
|
||||||
|
|
||||||
Chat Key
|
Chat Key
|
||||||
ID string
|
ID string
|
||||||
|
@ -146,9 +137,7 @@ type Message struct {
|
||||||
func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) {
|
func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) {
|
||||||
var ts int64
|
var ts int64
|
||||||
err := row.Scan(&msg.Chat.ID, &msg.Chat.Receiver, &msg.ID, &msg.MXID, &msg.RoomID, &msg.Sender, &ts, dbutil.JSON{Data: &msg.Status})
|
err := row.Scan(&msg.Chat.ID, &msg.Chat.Receiver, &msg.ID, &msg.MXID, &msg.RoomID, &msg.Sender, &ts, dbutil.JSON{Data: &msg.Status})
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if err != nil {
|
||||||
return nil, nil
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if ts != 0 {
|
if ts != 0 {
|
||||||
|
@ -162,13 +151,12 @@ func (msg *Message) sqlVariables() []any {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (msg *Message) Insert(ctx context.Context) error {
|
func (msg *Message) Insert(ctx context.Context) error {
|
||||||
_, err := msg.db.Conn(ctx).ExecContext(ctx, insertMessageQuery, msg.sqlVariables()...)
|
return msg.qh.Exec(ctx, insertMessageQuery, msg.sqlVariables()...)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mq *MessageQuery) MassInsert(ctx context.Context, messages []*Message) error {
|
func (mq *MessageQuery) MassInsert(ctx context.Context, messages []*Message) error {
|
||||||
valueStringFormat := "($1, $2, $%d, $%d, $3, $%d, $%d, $%d)"
|
valueStringFormat := "($1, $2, $%d, $%d, $3, $%d, $%d, $%d)"
|
||||||
if mq.db.Dialect == dbutil.SQLite {
|
if mq.GetDB().Dialect == dbutil.SQLite {
|
||||||
valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
|
valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
|
||||||
}
|
}
|
||||||
placeholders := make([]string, len(messages))
|
placeholders := make([]string, len(messages))
|
||||||
|
@ -186,23 +174,19 @@ func (mq *MessageQuery) MassInsert(ctx context.Context, messages []*Message) err
|
||||||
placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5)
|
placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5)
|
||||||
}
|
}
|
||||||
query := massInsertMessageQueryPrefix + strings.Join(placeholders, ",")
|
query := massInsertMessageQueryPrefix + strings.Join(placeholders, ",")
|
||||||
_, err := mq.db.Conn(ctx).ExecContext(ctx, query, params...)
|
return mq.Exec(ctx, query, params...)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (msg *Message) Update(ctx context.Context) error {
|
func (msg *Message) Update(ctx context.Context) error {
|
||||||
_, err := msg.db.Conn(ctx).ExecContext(ctx, updateMessageQuery, msg.sqlVariables()...)
|
return msg.qh.Exec(ctx, updateMessageQuery, msg.sqlVariables()...)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (msg *Message) UpdateStatus(ctx context.Context) error {
|
func (msg *Message) UpdateStatus(ctx context.Context) error {
|
||||||
_, err := msg.db.Conn(ctx).ExecContext(ctx, updateMessageStatusQuery, dbutil.JSON{Data: &msg.Status}, msg.Timestamp.UnixMicro(), msg.Chat.Receiver, msg.ID)
|
return msg.qh.Exec(ctx, updateMessageStatusQuery, dbutil.JSON{Data: &msg.Status}, msg.Timestamp.UnixMicro(), msg.Chat.Receiver, msg.ID)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (msg *Message) Delete(ctx context.Context) error {
|
func (msg *Message) Delete(ctx context.Context) error {
|
||||||
_, err := msg.db.Conn(ctx).ExecContext(ctx, deleteMessageQuery, msg.Chat.ID, msg.Chat.Receiver, msg.ID)
|
return msg.qh.Exec(ctx, deleteMessageQuery, msg.Chat.ID, msg.Chat.Receiver, msg.ID)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (msg *Message) IsFakeMXID() bool {
|
func (msg *Message) IsFakeMXID() bool {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
|
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
|
||||||
// Copyright (C) 2023 Tulir Asokan
|
// Copyright (C) 2024 Tulir Asokan
|
||||||
//
|
//
|
||||||
// This program is free software: you can redistribute it and/or modify
|
// This program is free software: you can redistribute it and/or modify
|
||||||
// it under the terms of the GNU Affero General Public License as published by
|
// it under the terms of the GNU Affero General Public License as published by
|
||||||
|
@ -19,7 +19,6 @@ package database
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
@ -30,17 +29,11 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type PortalQuery struct {
|
type PortalQuery struct {
|
||||||
db *Database
|
*dbutil.QueryHelper[*Portal]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pq *PortalQuery) New() *Portal {
|
func newPortal(qh *dbutil.QueryHelper[*Portal]) *Portal {
|
||||||
return &Portal{
|
return &Portal{qh: qh}
|
||||||
db: pq.db,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pq *PortalQuery) getDB() *Database {
|
|
||||||
return pq.db
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -62,23 +55,23 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) {
|
func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) {
|
||||||
return getAll[*Portal](pq, ctx, getAllPortalsQuery)
|
return pq.QueryMany(ctx, getAllPortalsQuery)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pq *PortalQuery) GetAllForUser(ctx context.Context, receiver int) ([]*Portal, error) {
|
func (pq *PortalQuery) GetAllForUser(ctx context.Context, receiver int) ([]*Portal, error) {
|
||||||
return getAll[*Portal](pq, ctx, getAllPortalsForUserQuery, receiver)
|
return pq.QueryMany(ctx, getAllPortalsForUserQuery, receiver)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pq *PortalQuery) GetByKey(ctx context.Context, key Key) (*Portal, error) {
|
func (pq *PortalQuery) GetByKey(ctx context.Context, key Key) (*Portal, error) {
|
||||||
return get[*Portal](pq, ctx, getPortalByKeyQuery, key.ID, key.Receiver)
|
return pq.QueryOne(ctx, getPortalByKeyQuery, key.ID, key.Receiver)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pq *PortalQuery) GetByOtherUser(ctx context.Context, key Key) (*Portal, error) {
|
func (pq *PortalQuery) GetByOtherUser(ctx context.Context, key Key) (*Portal, error) {
|
||||||
return get[*Portal](pq, ctx, getPortalByOtherUserQuery, key.ID, key.Receiver)
|
return pq.QueryOne(ctx, getPortalByOtherUserQuery, key.ID, key.Receiver)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pq *PortalQuery) GetByMXID(ctx context.Context, mxid id.RoomID) (*Portal, error) {
|
func (pq *PortalQuery) GetByMXID(ctx context.Context, mxid id.RoomID) (*Portal, error) {
|
||||||
return get[*Portal](pq, ctx, getPortalByMXIDQuery, mxid)
|
return pq.QueryOne(ctx, getPortalByMXIDQuery, mxid)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Key struct {
|
type Key struct {
|
||||||
|
@ -95,7 +88,7 @@ func (p Key) MarshalZerologObject(e *zerolog.Event) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Portal struct {
|
type Portal struct {
|
||||||
db *Database
|
qh *dbutil.QueryHelper[*Portal]
|
||||||
|
|
||||||
Key
|
Key
|
||||||
OutgoingID string
|
OutgoingID string
|
||||||
|
@ -112,10 +105,11 @@ type Portal struct {
|
||||||
func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
|
func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
|
||||||
var mxid, selfUserID, otherUserID sql.NullString
|
var mxid, selfUserID, otherUserID sql.NullString
|
||||||
var convType int
|
var convType int
|
||||||
err := row.Scan(&portal.ID, &portal.Receiver, &selfUserID, &otherUserID, &convType, &mxid, &portal.Name, &portal.NameSet, &portal.Encrypted, &portal.InSpace)
|
err := row.Scan(
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
&portal.ID, &portal.Receiver, &selfUserID, &otherUserID, &convType, &mxid,
|
||||||
return nil, nil
|
&portal.Name, &portal.NameSet, &portal.Encrypted, &portal.InSpace,
|
||||||
} else if err != nil {
|
)
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
portal.Type = gmproto.ConversationType(convType)
|
portal.Type = gmproto.ConversationType(convType)
|
||||||
|
@ -126,33 +120,21 @@ func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (portal *Portal) sqlVariables() []any {
|
func (portal *Portal) sqlVariables() []any {
|
||||||
var mxid, selfUserID, otherUserID *string
|
|
||||||
if portal.MXID != "" {
|
|
||||||
mxid = (*string)(&portal.MXID)
|
|
||||||
}
|
|
||||||
if portal.OutgoingID != "" {
|
|
||||||
selfUserID = &portal.OutgoingID
|
|
||||||
}
|
|
||||||
if portal.OtherUserID != "" {
|
|
||||||
otherUserID = &portal.OtherUserID
|
|
||||||
}
|
|
||||||
return []any{
|
return []any{
|
||||||
portal.ID, portal.Receiver, selfUserID, otherUserID, int(portal.Type), mxid, portal.Name, portal.NameSet,
|
portal.ID, portal.Receiver, dbutil.StrPtr(portal.OutgoingID), dbutil.StrPtr(portal.OtherUserID),
|
||||||
portal.Encrypted, portal.InSpace,
|
int(portal.Type), dbutil.StrPtr(portal.MXID),
|
||||||
|
portal.Name, portal.NameSet, portal.Encrypted, portal.InSpace,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (portal *Portal) Insert(ctx context.Context) error {
|
func (portal *Portal) Insert(ctx context.Context) error {
|
||||||
_, err := portal.db.Conn(ctx).ExecContext(ctx, insertPortalQuery, portal.sqlVariables()...)
|
return portal.qh.Exec(ctx, insertPortalQuery, portal.sqlVariables()...)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (portal *Portal) Update(ctx context.Context) error {
|
func (portal *Portal) Update(ctx context.Context) error {
|
||||||
_, err := portal.db.Conn(ctx).ExecContext(ctx, updatePortalQuery, portal.sqlVariables()...)
|
return portal.qh.Exec(ctx, updatePortalQuery, portal.sqlVariables()...)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (portal *Portal) Delete(ctx context.Context) error {
|
func (portal *Portal) Delete(ctx context.Context) error {
|
||||||
_, err := portal.db.Conn(ctx).ExecContext(ctx, deletePortalQuery, portal.ID, portal.Receiver)
|
return portal.qh.Exec(ctx, deletePortalQuery, portal.ID, portal.Receiver)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
|
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
|
||||||
// Copyright (C) 2023 Tulir Asokan
|
// Copyright (C) 2024 Tulir Asokan
|
||||||
//
|
//
|
||||||
// This program is free software: you can redistribute it and/or modify
|
// This program is free software: you can redistribute it and/or modify
|
||||||
// it under the terms of the GNU Affero General Public License as published by
|
// it under the terms of the GNU Affero General Public License as published by
|
||||||
|
@ -18,8 +18,6 @@ package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"go.mau.fi/util/dbutil"
|
"go.mau.fi/util/dbutil"
|
||||||
|
@ -27,13 +25,11 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type PuppetQuery struct {
|
type PuppetQuery struct {
|
||||||
db *Database
|
*dbutil.QueryHelper[*Puppet]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pq *PuppetQuery) New() *Puppet {
|
func newPuppet(qh *dbutil.QueryHelper[*Puppet]) *Puppet {
|
||||||
return &Puppet{
|
return &Puppet{qh: qh}
|
||||||
db: pq.db,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -50,21 +46,16 @@ const (
|
||||||
`
|
`
|
||||||
)
|
)
|
||||||
|
|
||||||
func (pq *PuppetQuery) getDB() *Database {
|
|
||||||
return pq.db
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pq *PuppetQuery) DeleteAllForUser(ctx context.Context, userID int) error {
|
func (pq *PuppetQuery) DeleteAllForUser(ctx context.Context, userID int) error {
|
||||||
_, err := pq.db.Conn(ctx).ExecContext(ctx, deleteAllPuppetsForUserQuery, userID)
|
return pq.Exec(ctx, deleteAllPuppetsForUserQuery, userID)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pq *PuppetQuery) Get(ctx context.Context, key Key) (*Puppet, error) {
|
func (pq *PuppetQuery) Get(ctx context.Context, key Key) (*Puppet, error) {
|
||||||
return get[*Puppet](pq, ctx, getPuppetQuery, key.ID, key.Receiver)
|
return pq.QueryOne(ctx, getPuppetQuery, key.ID, key.Receiver)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Puppet struct {
|
type Puppet struct {
|
||||||
db *Database
|
qh *dbutil.QueryHelper[*Puppet]
|
||||||
|
|
||||||
Key
|
Key
|
||||||
Phone string
|
Phone string
|
||||||
|
@ -81,10 +72,11 @@ type Puppet struct {
|
||||||
func (puppet *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) {
|
func (puppet *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) {
|
||||||
var avatarHash []byte
|
var avatarHash []byte
|
||||||
var avatarUpdateTS int64
|
var avatarUpdateTS int64
|
||||||
err := row.Scan(&puppet.ID, &puppet.Receiver, &puppet.Phone, &puppet.ContactID, &puppet.Name, &puppet.NameSet, &avatarHash, &puppet.AvatarMXC, &puppet.AvatarSet, &avatarUpdateTS, &puppet.ContactInfoSet)
|
err := row.Scan(
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
&puppet.ID, &puppet.Receiver, &puppet.Phone, &puppet.ContactID, &puppet.Name, &puppet.NameSet,
|
||||||
return nil, nil
|
&avatarHash, &puppet.AvatarMXC, &puppet.AvatarSet, &avatarUpdateTS, &puppet.ContactInfoSet,
|
||||||
} else if err != nil {
|
)
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(avatarHash) == 32 {
|
if len(avatarHash) == 32 {
|
||||||
|
@ -95,15 +87,16 @@ func (puppet *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (puppet *Puppet) sqlVariables() []any {
|
func (puppet *Puppet) sqlVariables() []any {
|
||||||
return []any{puppet.ID, puppet.Receiver, puppet.Phone, puppet.ContactID, puppet.Name, puppet.NameSet, puppet.AvatarHash[:], &puppet.AvatarMXC, puppet.AvatarSet, puppet.AvatarUpdateTS.UnixMilli(), puppet.ContactInfoSet}
|
return []any{
|
||||||
|
puppet.ID, puppet.Receiver, puppet.Phone, puppet.ContactID, puppet.Name, puppet.NameSet,
|
||||||
|
puppet.AvatarHash[:], &puppet.AvatarMXC, puppet.AvatarSet, puppet.AvatarUpdateTS.UnixMilli(), puppet.ContactInfoSet,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (puppet *Puppet) Insert(ctx context.Context) error {
|
func (puppet *Puppet) Insert(ctx context.Context) error {
|
||||||
_, err := puppet.db.Conn(ctx).ExecContext(ctx, insertPuppetQuery, puppet.sqlVariables()...)
|
return puppet.qh.Exec(ctx, insertPuppetQuery, puppet.sqlVariables()...)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (puppet *Puppet) Update(ctx context.Context) error {
|
func (puppet *Puppet) Update(ctx context.Context) error {
|
||||||
_, err := puppet.db.Conn(ctx).ExecContext(ctx, updatePuppetQuery, puppet.sqlVariables()...)
|
return puppet.qh.Exec(ctx, updatePuppetQuery, puppet.sqlVariables()...)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
|
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
|
||||||
// Copyright (C) 2023 Tulir Asokan
|
// Copyright (C) 2024 Tulir Asokan
|
||||||
//
|
//
|
||||||
// This program is free software: you can redistribute it and/or modify
|
// This program is free software: you can redistribute it and/or modify
|
||||||
// it under the terms of the GNU Affero General Public License as published by
|
// it under the terms of the GNU Affero General Public License as published by
|
||||||
|
@ -18,8 +18,6 @@ package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -28,17 +26,11 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type ReactionQuery struct {
|
type ReactionQuery struct {
|
||||||
db *Database
|
*dbutil.QueryHelper[*Reaction]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rq *ReactionQuery) New() *Reaction {
|
func newReaction(qh *dbutil.QueryHelper[*Reaction]) *Reaction {
|
||||||
return &Reaction{
|
return &Reaction{qh: qh}
|
||||||
db: rq.db,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rq *ReactionQuery) getDB() *Database {
|
|
||||||
return rq.db
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -68,24 +60,23 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (rq *ReactionQuery) GetByID(ctx context.Context, receiver int, messageID, sender string) (*Reaction, error) {
|
func (rq *ReactionQuery) GetByID(ctx context.Context, receiver int, messageID, sender string) (*Reaction, error) {
|
||||||
return get[*Reaction](rq, ctx, getReactionByIDQuery, receiver, messageID, sender)
|
return rq.QueryOne(ctx, getReactionByIDQuery, receiver, messageID, sender)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) {
|
func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) {
|
||||||
return get[*Reaction](rq, ctx, getReactionByMXIDQuery, mxid)
|
return rq.QueryOne(ctx, getReactionByMXIDQuery, mxid)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rq *ReactionQuery) GetAllByMessage(ctx context.Context, receiver int, messageID string) ([]*Reaction, error) {
|
func (rq *ReactionQuery) GetAllByMessage(ctx context.Context, receiver int, messageID string) ([]*Reaction, error) {
|
||||||
return getAll[*Reaction](rq, ctx, getReactionsByMessageIDQuery, receiver, messageID)
|
return rq.QueryMany(ctx, getReactionsByMessageIDQuery, receiver, messageID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rq *ReactionQuery) DeleteAllByMessage(ctx context.Context, chat Key, messageID string) error {
|
func (rq *ReactionQuery) DeleteAllByMessage(ctx context.Context, chat Key, messageID string) error {
|
||||||
_, err := rq.db.Conn(ctx).ExecContext(ctx, deleteReactionsByMessageIDQuery, chat.ID, chat.Receiver, messageID)
|
return rq.Exec(ctx, deleteReactionsByMessageIDQuery, chat.ID, chat.Receiver, messageID)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Reaction struct {
|
type Reaction struct {
|
||||||
db *Database
|
qh *dbutil.QueryHelper[*Reaction]
|
||||||
|
|
||||||
Chat Key
|
Chat Key
|
||||||
MessageID string
|
MessageID string
|
||||||
|
@ -95,23 +86,16 @@ type Reaction struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) {
|
func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) {
|
||||||
err := row.Scan(&r.Chat.ID, &r.Chat.Receiver, &r.MessageID, &r.Sender, &r.Reaction, &r.MXID)
|
return dbutil.ValueOrErr(r, row.Scan(&r.Chat.ID, &r.Chat.Receiver, &r.MessageID, &r.Sender, &r.Reaction, &r.MXID))
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
|
||||||
return nil, nil
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return r, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Reaction) Insert(ctx context.Context) error {
|
func (r *Reaction) Insert(ctx context.Context) error {
|
||||||
_, err := r.db.Conn(ctx).ExecContext(ctx, insertReactionQuery, r.Chat.ID, r.Chat.Receiver, r.MessageID, r.Sender, r.Reaction, r.MXID)
|
return r.qh.Exec(ctx, insertReactionQuery, r.Chat.ID, r.Chat.Receiver, r.MessageID, r.Sender, r.Reaction, r.MXID)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rq *ReactionQuery) MassInsert(ctx context.Context, reactions []*Reaction) error {
|
func (rq *ReactionQuery) MassInsert(ctx context.Context, reactions []*Reaction) error {
|
||||||
valueStringFormat := "($1, $2, $%d, $%d, $%d, $%d)"
|
valueStringFormat := "($1, $2, $%d, $%d, $%d, $%d)"
|
||||||
if rq.db.Dialect == dbutil.SQLite {
|
if rq.GetDB().Dialect == dbutil.SQLite {
|
||||||
valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
|
valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
|
||||||
}
|
}
|
||||||
placeholders := make([]string, len(reactions))
|
placeholders := make([]string, len(reactions))
|
||||||
|
@ -127,11 +111,9 @@ func (rq *ReactionQuery) MassInsert(ctx context.Context, reactions []*Reaction)
|
||||||
placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4)
|
placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4)
|
||||||
}
|
}
|
||||||
query := strings.Replace(insertReactionQuery, "($1, $2, $3, $4, $5, $6)", strings.Join(placeholders, ","), 1)
|
query := strings.Replace(insertReactionQuery, "($1, $2, $3, $4, $5, $6)", strings.Join(placeholders, ","), 1)
|
||||||
_, err := rq.db.Conn(ctx).ExecContext(ctx, query, params...)
|
return rq.Exec(ctx, query, params...)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Reaction) Delete(ctx context.Context) error {
|
func (r *Reaction) Delete(ctx context.Context) error {
|
||||||
_, err := r.db.Conn(ctx).ExecContext(ctx, deleteReactionQuery, r.Chat.ID, r.Chat.Receiver, r.MessageID, r.Sender)
|
return r.qh.Exec(ctx, deleteReactionQuery, r.Chat.ID, r.Chat.Receiver, r.MessageID, r.Sender)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
|
// mautrix-gmessages - A Matrix-Google Messages puppeting bridge.
|
||||||
// Copyright (C) 2023 Tulir Asokan
|
// Copyright (C) 2024 Tulir Asokan
|
||||||
//
|
//
|
||||||
// This program is free software: you can redistribute it and/or modify
|
// This program is free software: you can redistribute it and/or modify
|
||||||
// it under the terms of the GNU Affero General Public License as published by
|
// it under the terms of the GNU Affero General Public License as published by
|
||||||
|
@ -20,7 +20,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
@ -33,17 +32,11 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserQuery struct {
|
type UserQuery struct {
|
||||||
db *Database
|
*dbutil.QueryHelper[*User]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uq *UserQuery) New() *User {
|
func newUser(qh *dbutil.QueryHelper[*User]) *User {
|
||||||
return &User{
|
return &User{qh: qh}
|
||||||
db: uq.db,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uq *UserQuery) getDB() *Database {
|
|
||||||
return uq.db
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -63,22 +56,23 @@ const (
|
||||||
management_room=$7, space_room=$8, access_token=$9
|
management_room=$7, space_room=$8, access_token=$9
|
||||||
WHERE mxid=$1
|
WHERE mxid=$1
|
||||||
`
|
`
|
||||||
|
updateuserParticipantIDsQuery = `UPDATE "user" SET self_participant_ids=$2 WHERE mxid=$1`
|
||||||
)
|
)
|
||||||
|
|
||||||
func (uq *UserQuery) GetAllWithSession(ctx context.Context) ([]*User, error) {
|
func (uq *UserQuery) GetAllWithSession(ctx context.Context) ([]*User, error) {
|
||||||
return getAll[*User](uq, ctx, getAllUsersWithSessionQuery)
|
return uq.QueryMany(ctx, getAllUsersWithSessionQuery)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uq *UserQuery) GetAllWithDoublePuppet(ctx context.Context) ([]*User, error) {
|
func (uq *UserQuery) GetAllWithDoublePuppet(ctx context.Context) ([]*User, error) {
|
||||||
return getAll[*User](uq, ctx, getAllUsersWithDoublePuppetQuery)
|
return uq.QueryMany(ctx, getAllUsersWithDoublePuppetQuery)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uq *UserQuery) GetByRowID(ctx context.Context, rowID int) (*User, error) {
|
func (uq *UserQuery) GetByRowID(ctx context.Context, rowID int) (*User, error) {
|
||||||
return get[*User](uq, ctx, getUserByRowIDQuery, rowID)
|
return uq.QueryOne(ctx, getUserByRowIDQuery, rowID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, error) {
|
func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, error) {
|
||||||
return get[*User](uq, ctx, getUserByMXIDQuery, userID)
|
return uq.QueryOne(ctx, getUserByMXIDQuery, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Settings struct {
|
type Settings struct {
|
||||||
|
@ -90,7 +84,7 @@ type Settings struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
db *Database
|
qh *dbutil.QueryHelper[*User]
|
||||||
|
|
||||||
RowID int
|
RowID int
|
||||||
MXID id.UserID
|
MXID id.UserID
|
||||||
|
@ -114,10 +108,11 @@ type User struct {
|
||||||
func (user *User) Scan(row dbutil.Scannable) (*User, error) {
|
func (user *User) Scan(row dbutil.Scannable) (*User, error) {
|
||||||
var phoneID, session, managementRoom, spaceRoom, accessToken sql.NullString
|
var phoneID, session, managementRoom, spaceRoom, accessToken sql.NullString
|
||||||
var selfParticipantIDs, simMetadata, settings string
|
var selfParticipantIDs, simMetadata, settings string
|
||||||
err := row.Scan(&user.RowID, &user.MXID, &phoneID, &session, &selfParticipantIDs, &simMetadata, &settings, &managementRoom, &spaceRoom, &accessToken)
|
err := row.Scan(
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
&user.RowID, &user.MXID, &phoneID, &session, &selfParticipantIDs, &simMetadata,
|
||||||
return nil, nil
|
&settings, &managementRoom, &spaceRoom, &accessToken,
|
||||||
} else if err != nil {
|
)
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if session.String != "" {
|
if session.String != "" {
|
||||||
|
@ -152,24 +147,12 @@ func (user *User) Scan(row dbutil.Scannable) (*User, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) sqlVariables() []any {
|
func (user *User) sqlVariables() []any {
|
||||||
var phoneID, session, managementRoom, spaceRoom, accessToken *string
|
var session *string
|
||||||
if user.PhoneID != "" {
|
|
||||||
phoneID = &user.PhoneID
|
|
||||||
}
|
|
||||||
if user.Session != nil {
|
if user.Session != nil {
|
||||||
data, _ := json.Marshal(user.Session)
|
data, _ := json.Marshal(user.Session)
|
||||||
strData := string(data)
|
strData := string(data)
|
||||||
session = &strData
|
session = &strData
|
||||||
}
|
}
|
||||||
if user.ManagementRoom != "" {
|
|
||||||
managementRoom = (*string)(&user.ManagementRoom)
|
|
||||||
}
|
|
||||||
if user.SpaceRoom != "" {
|
|
||||||
spaceRoom = (*string)(&user.SpaceRoom)
|
|
||||||
}
|
|
||||||
if user.AccessToken != "" {
|
|
||||||
accessToken = &user.AccessToken
|
|
||||||
}
|
|
||||||
user.selfParticipantIDsLock.RLock()
|
user.selfParticipantIDsLock.RLock()
|
||||||
selfParticipantIDs, _ := json.Marshal(user.SelfParticipantIDs)
|
selfParticipantIDs, _ := json.Marshal(user.SelfParticipantIDs)
|
||||||
user.selfParticipantIDsLock.RUnlock()
|
user.selfParticipantIDsLock.RUnlock()
|
||||||
|
@ -183,7 +166,10 @@ func (user *User) sqlVariables() []any {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
return []any{user.MXID, phoneID, session, string(selfParticipantIDs), string(simMetadata), string(settings), managementRoom, spaceRoom, accessToken}
|
return []any{
|
||||||
|
user.MXID, dbutil.StrPtr(user.PhoneID), session, string(selfParticipantIDs), string(simMetadata),
|
||||||
|
string(settings), dbutil.StrPtr(user.ManagementRoom), dbutil.StrPtr(user.SpaceRoom), dbutil.StrPtr(user.AccessToken),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) IsSelfParticipantID(id string) bool {
|
func (user *User) IsSelfParticipantID(id string) bool {
|
||||||
|
@ -268,20 +254,18 @@ func (user *User) AddSelfParticipantID(ctx context.Context, id string) error {
|
||||||
if !slices.Contains(user.SelfParticipantIDs, id) {
|
if !slices.Contains(user.SelfParticipantIDs, id) {
|
||||||
user.SelfParticipantIDs = append(user.SelfParticipantIDs, id)
|
user.SelfParticipantIDs = append(user.SelfParticipantIDs, id)
|
||||||
selfParticipantIDs, _ := json.Marshal(user.SelfParticipantIDs)
|
selfParticipantIDs, _ := json.Marshal(user.SelfParticipantIDs)
|
||||||
_, err := user.db.Conn(ctx).ExecContext(ctx, `UPDATE "user" SET self_participant_ids=$2 WHERE mxid=$1`, user.MXID, selfParticipantIDs)
|
return user.qh.Exec(ctx, updateuserParticipantIDsQuery, user.MXID, selfParticipantIDs)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) Insert(ctx context.Context) error {
|
func (user *User) Insert(ctx context.Context) error {
|
||||||
err := user.db.Conn(ctx).
|
err := user.qh.GetDB().
|
||||||
QueryRowContext(ctx, insertUserQuery, user.sqlVariables()...).
|
QueryRow(ctx, insertUserQuery, user.sqlVariables()...).
|
||||||
Scan(&user.RowID)
|
Scan(&user.RowID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) Update(ctx context.Context) error {
|
func (user *User) Update(ctx context.Context) error {
|
||||||
_, err := user.db.Conn(ctx).ExecContext(ctx, updateUserQuery, user.sqlVariables()...)
|
return user.qh.Exec(ctx, updateUserQuery, user.sqlVariables()...)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue