Skip to content

Commit

Permalink
Added support for transactions
Browse files Browse the repository at this point in the history
Resolves #53
  • Loading branch information
saberder committed Feb 18, 2025
1 parent 8aa32c7 commit 9a2efb5
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 174 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
.PHONY: mockgen
mockgen:
mockgen --destination mocks/row.go --package=mocks --build_flags=--mod=mod github.com/jackc/pgx/v5 Row
mockgen --source=pgmq.go --destination mocks/pgmq.go --package=mocks
go run go.uber.org/mock/mockgen --destination mocks/row.go --package=mocks --build_flags=--mod=mod github.com/jackc/pgx/v5 Row
go run go.uber.org/mock/mockgen --source=pgmq.go --destination mocks/pgmq.go --package=mocks

.PHONY: test
test: mockgen
Expand Down
32 changes: 27 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ A Go (Golang) client for
[Postgres Message Queue](https://github.com/tembo-io/pgmq) (PGMQ). Based loosely
on the [Rust client](https://github.com/tembo-io/pgmq/tree/main/pgmq-rs).

`pgmq-go` works with [pgx](https://github.com/jackc/pgx). The second argument of most functions only needs to satisfy the [DB](https://pkg.go.dev/github.com/craigpastro/pgmq-go#DB) interface, which means it can take, among others, a `*pgx.Conn`, `*pgxpool.Pool`, or `pgx.Tx`.

## Usage

Start a Postgres instance with the PGMQ extension installed:
Expand All @@ -32,33 +34,53 @@ import (
func main() {
ctx := context.Background()

q, err := pgmq.New(ctx, "postgres://postgres:password@localhost:5432/postgres")
pool, err := pgmq.NewPgxPool(ctx, "postgres://postgres:password@localhost:5432/postgres")
if err != nil {
panic(err)
}

err = pgmq.CreatePGMQExtension(ctx, pool)
if err != nil {
panic(err)
}

err = pgmq.CreateQueue(ctx, pool, "my_queue")
if err != nil {
panic(err)
}

err = q.CreateQueue(ctx, "my_queue")
// We can perform various queue operations using a transaction.
tx, err := pool.Begin(ctx)
if err != nil {
panic(err)
}

id, err := q.Send(ctx, "my_queue", json.RawMessage(`{"foo": "bar"}`))
id, err := pgmq.Send(ctx, tx, "my_queue", json.RawMessage(`{"foo": "bar"}`))
if err != nil {
panic(err)
}

msg, err := q.Read(ctx, "my_queue", 30)
msg, err := pgmq.Read(ctx, tx, "my_queue", 30)
if err != nil {
panic(err)
}

// Archive the message by moving it to the "pgmq.a_<queue_name>" table.
// Alternatively, you can `Delete` the message, or read and delete in one
// call by using `Pop`.
_, err = q.Archive(ctx, "my_queue", id)
_, err = pgmq.Archive(ctx, tx, "my_queue", id)
if err != nil {
panic(err)
}

// Commit the transaction.
err = tx.Commit(ctx)
if err != nil {
panic(err)
}

// Close the connection pool.
pool.Close()
}
```

Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,13 @@ require (
go.opentelemetry.io/otel/sdk v1.21.0 // indirect
go.opentelemetry.io/otel/trace v1.31.0 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/mod v0.18.0 // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
golang.org/x/tools v0.22.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 // indirect
google.golang.org/protobuf v1.33.0 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=
golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
Expand Down Expand Up @@ -176,6 +178,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
Expand Down
26 changes: 0 additions & 26 deletions mocks/pgmq.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

104 changes: 39 additions & 65 deletions pgmq.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,13 @@ type Message struct {
}

type DB interface {
Ping(ctx context.Context) error
Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error)
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
Close()
}

type PGMQ struct {
db DB
}

// New uses the connString to attempt to establish a connection to Postgres.
// Once a connetion is established it will create the PGMQ extension if it
// does not already exist.
func New(ctx context.Context, connString string) (*PGMQ, error) {
// NewPgxPool is a convenience function for creating a new *pgxpool.Pool.
func NewPgxPool(ctx context.Context, connString string) (*pgxpool.Pool, error) {
cfg, err := pgxpool.ParseConfig(connString)
if err != nil {
return nil, fmt.Errorf("error parsing connection string: %w", err)
Expand All @@ -53,41 +45,23 @@ func New(ctx context.Context, connString string) (*PGMQ, error) {
return nil, fmt.Errorf("error creating pool: %w", err)
}

return NewFromDB(ctx, pool)
return pool, nil
}

// NewFromDB is a bring your own DB version of New. Given an implementation
// of DB, it will call Ping to ensure the connection has been established,
// then create the PGMQ extension if it does not already exist.
func NewFromDB(ctx context.Context, db DB) (*PGMQ, error) {
if err := db.Ping(ctx); err != nil {
return nil, err
}

// CreatePGMQExtension will create the PGMQ extension using the provided DB.
func CreatePGMQExtension(ctx context.Context, db DB) error {
_, err := db.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS pgmq CASCADE")
if err != nil {
return nil, fmt.Errorf("error creating pgmq extension: %w", err)
return fmt.Errorf("error creating pgmq extension: %w", err)
}

return &PGMQ{
db: db,
}, nil
}

// Close closes the underlying connection pool.
func (p *PGMQ) Close() {
p.db.Close()
}

// Ping calls the underlying Ping function of the DB interface.
func (p *PGMQ) Ping(ctx context.Context) error {
return p.db.Ping(ctx)
return nil
}

// CreateQueue creates a new queue. This sets up the queue's tables, indexes,
// and metadata.
func (p *PGMQ) CreateQueue(ctx context.Context, queue string) error {
_, err := p.db.Exec(ctx, "SELECT pgmq.create($1)", queue)
func CreateQueue(ctx context.Context, db DB, queue string) error {
_, err := db.Exec(ctx, "SELECT pgmq.create($1)", queue)
if err != nil {
return wrapPostgresError(err)
}
Expand All @@ -98,8 +72,8 @@ func (p *PGMQ) CreateQueue(ctx context.Context, queue string) error {
// CreateUnloggedQueue creates a new unlogged queue, which uses an unlogged
// table under the hood. This sets up the queue's tables, indexes, and
// metadata.
func (p *PGMQ) CreateUnloggedQueue(ctx context.Context, queue string) error {
_, err := p.db.Exec(ctx, "SELECT pgmq.create_unlogged($1)", queue)
func CreateUnloggedQueue(ctx context.Context, db DB, queue string) error {
_, err := db.Exec(ctx, "SELECT pgmq.create_unlogged($1)", queue)
if err != nil {
return wrapPostgresError(err)
}
Expand All @@ -109,8 +83,8 @@ func (p *PGMQ) CreateUnloggedQueue(ctx context.Context, queue string) error {

// DropQueue deletes the given queue. It deletes the queue's tables, indices,
// and metadata. It will return an error if the queue does not exist.
func (p *PGMQ) DropQueue(ctx context.Context, queue string) error {
_, err := p.db.Exec(ctx, "SELECT pgmq.drop_queue($1)", queue)
func DropQueue(ctx context.Context, db DB, queue string) error {
_, err := db.Exec(ctx, "SELECT pgmq.drop_queue($1)", queue)
if err != nil {
return wrapPostgresError(err)
}
Expand All @@ -120,15 +94,15 @@ func (p *PGMQ) DropQueue(ctx context.Context, queue string) error {

// Send sends a single message to a queue. The message id, unique to the
// queue, is returned.
func (p *PGMQ) Send(ctx context.Context, queue string, msg json.RawMessage) (int64, error) {
return p.SendWithDelay(ctx, queue, msg, 0)
func Send(ctx context.Context, db DB, queue string, msg json.RawMessage) (int64, error) {
return SendWithDelay(ctx, db, queue, msg, 0)
}

// SendWithDelay sends a single message to a queue with a delay. The delay
// is specified in seconds. The message id, unique to the queue, is returned.
func (p *PGMQ) SendWithDelay(ctx context.Context, queue string, msg json.RawMessage, delay int) (int64, error) {
func SendWithDelay(ctx context.Context, db DB, queue string, msg json.RawMessage, delay int) (int64, error) {
var msgID int64
err := p.db.
err := db.
QueryRow(ctx, "SELECT * FROM pgmq.send($1, $2, $3::int)", queue, msg, delay).
Scan(&msgID)
if err != nil {
Expand All @@ -141,9 +115,9 @@ func (p *PGMQ) SendWithDelay(ctx context.Context, queue string, msg json.RawMess
// SendWithDelayTimestamp sends a single message to a queue with a delay. The
// delay is specified as a timestamp. The message id, unique to the queue, is
// returned. Only supported in pgmq-pg17 and above.
func (p *PGMQ) SendWithDelayTimestamp(ctx context.Context, queue string, msg json.RawMessage, delay time.Time) (int64, error) {
func SendWithDelayTimestamp(ctx context.Context, db DB, queue string, msg json.RawMessage, delay time.Time) (int64, error) {
var msgID int64
err := p.db.
err := db.
QueryRow(ctx, "SELECT * FROM pgmq.send($1, $2, $3::timestamptz)", queue, msg, delay).
Scan(&msgID)
if err != nil {
Expand All @@ -155,15 +129,15 @@ func (p *PGMQ) SendWithDelayTimestamp(ctx context.Context, queue string, msg jso

// SendBatch sends a batch of messages to a queue. The message ids, unique to
// the queue, are returned.
func (p *PGMQ) SendBatch(ctx context.Context, queue string, msgs []json.RawMessage) ([]int64, error) {
return p.SendBatchWithDelay(ctx, queue, msgs, 0)
func SendBatch(ctx context.Context, db DB, queue string, msgs []json.RawMessage) ([]int64, error) {
return SendBatchWithDelay(ctx, db, queue, msgs, 0)
}

// SendBatchWithDelay sends a batch of messages to a queue with a delay. The
// delay is specified in seconds. The message ids, unique to the queue, are
// returned.
func (p *PGMQ) SendBatchWithDelay(ctx context.Context, queue string, msgs []json.RawMessage, delay int) ([]int64, error) {
rows, err := p.db.Query(ctx, "SELECT * FROM pgmq.send_batch($1, $2::jsonb[], $3::int)", queue, msgs, delay)
func SendBatchWithDelay(ctx context.Context, db DB, queue string, msgs []json.RawMessage, delay int) ([]int64, error) {
rows, err := db.Query(ctx, "SELECT * FROM pgmq.send_batch($1, $2::jsonb[], $3::int)", queue, msgs, delay)
if err != nil {
return nil, wrapPostgresError(err)
}
Expand All @@ -185,8 +159,8 @@ func (p *PGMQ) SendBatchWithDelay(ctx context.Context, queue string, msgs []json
// SendBatchWithDelayTimestamp sends a batch of messages to a queue with a
// delay. The delay is specified as a timestamp. The message ids, unique to
// the queue, are returned.
func (p *PGMQ) SendBatchWithDelayTimestamp(ctx context.Context, queue string, msgs []json.RawMessage, delay time.Time) ([]int64, error) {
rows, err := p.db.Query(ctx, "SELECT * FROM pgmq.send_batch($1, $2::jsonb[], $3::timestamptz)", queue, msgs, delay)
func SendBatchWithDelayTimestamp(ctx context.Context, db DB, queue string, msgs []json.RawMessage, delay time.Time) ([]int64, error) {
rows, err := db.Query(ctx, "SELECT * FROM pgmq.send_batch($1, $2::jsonb[], $3::timestamptz)", queue, msgs, delay)
if err != nil {
return nil, wrapPostgresError(err)
}
Expand All @@ -209,13 +183,13 @@ func (p *PGMQ) SendBatchWithDelayTimestamp(ctx context.Context, queue string, ms
// messages are invisible, an ErrNoRows errors is returned. If a message is
// returned, it is made invisible for the duration of the visibility timeout
// (vt) in seconds.
func (p *PGMQ) Read(ctx context.Context, queue string, vt int64) (*Message, error) {
func Read(ctx context.Context, db DB, queue string, vt int64) (*Message, error) {
if vt == 0 {
vt = vtDefault
}

var msg Message
rows, err := p.db.Query(ctx, "SELECT * FROM pgmq.read($1, $2, $3)", queue, vt, 1)
rows, err := db.Query(ctx, "SELECT * FROM pgmq.read($1, $2, $3)", queue, vt, 1)
if err != nil {
return nil, wrapPostgresError(err)
}
Expand Down Expand Up @@ -243,12 +217,12 @@ func (p *PGMQ) Read(ctx context.Context, queue string, vt int64) (*Message, erro
// messages that are returned are made invisible for the duration of the
// visibility timeout (vt) in seconds. If vt is 0 it will be set to the
// default value, vtDefault.
func (p *PGMQ) ReadBatch(ctx context.Context, queue string, vt int64, numMsgs int64) ([]*Message, error) {
func ReadBatch(ctx context.Context, db DB, queue string, vt int64, numMsgs int64) ([]*Message, error) {
if vt == 0 {
vt = vtDefault
}

rows, err := p.db.Query(ctx, "SELECT * FROM pgmq.read($1, $2, $3)", queue, vt, numMsgs)
rows, err := db.Query(ctx, "SELECT * FROM pgmq.read($1, $2, $3)", queue, vt, numMsgs)
if err != nil {
return nil, wrapPostgresError(err)
}
Expand Down Expand Up @@ -278,9 +252,9 @@ func (p *PGMQ) ReadBatch(ctx context.Context, queue string, vt int64, numMsgs in
// Similar to Read and ReadBatch if no messages are available an ErrNoRows is
// returned. Unlike these methods, the visibility timeout does not apply.
// This is because the message is immediately deleted.
func (p *PGMQ) Pop(ctx context.Context, queue string) (*Message, error) {
func Pop(ctx context.Context, db DB, queue string) (*Message, error) {
var msg Message
rows, err := p.db.Query(ctx, "SELECT * FROM pgmq.pop($1)", queue)
rows, err := db.Query(ctx, "SELECT * FROM pgmq.pop($1)", queue)
if err != nil {
return nil, wrapPostgresError(err)
}
Expand Down Expand Up @@ -308,9 +282,9 @@ func (p *PGMQ) Pop(ctx context.Context, queue string) (*Message, error) {
// id. View messages on the archive table with sql:
//
// SELECT * FROM pgmq.a_<queue_name>;
func (p *PGMQ) Archive(ctx context.Context, queue string, msgID int64) (bool, error) {
func Archive(ctx context.Context, db DB, queue string, msgID int64) (bool, error) {
var archived bool
err := p.db.QueryRow(ctx, "SELECT pgmq.archive($1, $2::bigint)", queue, msgID).Scan(&archived)
err := db.QueryRow(ctx, "SELECT pgmq.archive($1, $2::bigint)", queue, msgID).Scan(&archived)
if err != nil {
return false, wrapPostgresError(err)
}
Expand All @@ -322,8 +296,8 @@ func (p *PGMQ) Archive(ctx context.Context, queue string, msgID int64) (bool, er
// table by their ids. View messages on the archive table with sql:
//
// SELECT * FROM pgmq.a_<queue_name>;
func (p *PGMQ) ArchiveBatch(ctx context.Context, queue string, msgIDs []int64) ([]int64, error) {
rows, err := p.db.Query(ctx, "SELECT pgmq.archive($1, $2::bigint[])", queue, msgIDs)
func ArchiveBatch(ctx context.Context, db DB, queue string, msgIDs []int64) ([]int64, error) {
rows, err := db.Query(ctx, "SELECT pgmq.archive($1, $2::bigint[])", queue, msgIDs)
if err != nil {
return nil, wrapPostgresError(err)
}
Expand All @@ -344,9 +318,9 @@ func (p *PGMQ) ArchiveBatch(ctx context.Context, queue string, msgIDs []int64) (
// Delete deletes a message from the queue by its id. This is a permanent
// delete and cannot be undone. If you want to retain a log of the message,
// use the Archive method.
func (p *PGMQ) Delete(ctx context.Context, queue string, msgID int64) (bool, error) {
func Delete(ctx context.Context, db DB, queue string, msgID int64) (bool, error) {
var deleted bool
err := p.db.QueryRow(ctx, "SELECT pgmq.delete($1, $2::bigint)", queue, msgID).Scan(&deleted)
err := db.QueryRow(ctx, "SELECT pgmq.delete($1, $2::bigint)", queue, msgID).Scan(&deleted)
if err != nil {
return false, wrapPostgresError(err)
}
Expand All @@ -357,8 +331,8 @@ func (p *PGMQ) Delete(ctx context.Context, queue string, msgID int64) (bool, err
// DeleteBatch deletes a batch of messages from the queue by their ids. This
// is a permanent delete and cannot be undone. If you want to retain a log of
// the messages, use the ArchiveBatch method.
func (p *PGMQ) DeleteBatch(ctx context.Context, queue string, msgIDs []int64) ([]int64, error) {
rows, err := p.db.Query(ctx, "SELECT pgmq.delete($1, $2::bigint[])", queue, msgIDs)
func DeleteBatch(ctx context.Context, db DB, queue string, msgIDs []int64) ([]int64, error) {
rows, err := db.Query(ctx, "SELECT pgmq.delete($1, $2::bigint[])", queue, msgIDs)
if err != nil {
return nil, wrapPostgresError(err)
}
Expand Down
Loading

0 comments on commit 9a2efb5

Please sign in to comment.