Skip to content

Commit c9245bc

Browse files
authored
fix: enhance transaction functionality (#1281)
### Motivation Various fixes and refactoring for transaction. ### Modifications * Employ context in the `Commit` and `Abort` methods * Use client operation timeout * Use `atomic.Int32` for the state * Make all state reads atomic * Clean up and improve error messages
1 parent 0612938 commit c9245bc

File tree

5 files changed

+117
-94
lines changed

5 files changed

+117
-94
lines changed

.golangci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
# Run `make lint` from the root path of this project to check code with golangci-lint.
1919

2020
run:
21-
deadline: 6m
21+
timeout: 5m
2222

2323
linters:
2424
# Uncomment this line to run only the explicitly enabled linters

pulsar/consumer_partition.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -532,8 +532,8 @@ func (pc *partitionConsumer) internalAckWithTxn(req *ackWithTxnRequest) {
532532
req.err = newError(ConsumerClosed, "Failed to ack by closing or closed consumer")
533533
return
534534
}
535-
if req.Transaction.state != TxnOpen {
536-
pc.log.WithField("state", req.Transaction.state).Error("Failed to ack by a non-open transaction.")
535+
if req.Transaction.state.Load() != int32(TxnOpen) {
536+
pc.log.WithField("state", req.Transaction.state.Load()).Error("Failed to ack by a non-open transaction.")
537537
req.err = newError(InvalidStatus, "Failed to ack by a non-open transaction.")
538538
return
539539
}

pulsar/producer_partition.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,8 +1135,8 @@ func (p *partitionProducer) prepareTransaction(sr *sendRequest) error {
11351135
}
11361136

11371137
txn := (sr.msg.Transaction).(*transaction)
1138-
if txn.state != TxnOpen {
1139-
p.log.WithField("state", txn.state).Error("Failed to send message" +
1138+
if txn.state.Load() != int32(TxnOpen) {
1139+
p.log.WithField("state", txn.state.Load()).Error("Failed to send message" +
11401140
" by a non-open transaction.")
11411141
return joinErrors(ErrTransaction,
11421142
fmt.Errorf("failed to send message by a non-open transaction"))

pulsar/transaction_impl.go

Lines changed: 64 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ package pulsar
1919

2020
import (
2121
"context"
22+
"errors"
23+
"fmt"
2224
"sync"
2325
"sync/atomic"
2426
"time"
@@ -33,9 +35,9 @@ type subscription struct {
3335
}
3436

3537
type transaction struct {
36-
sync.Mutex
38+
mu sync.Mutex
3739
txnID TxnID
38-
state TxnState
40+
state atomic.Int32
3941
tcClient *transactionCoordinatorClient
4042
registerPartitions map[string]bool
4143
registerAckSubscriptions map[subscription]bool
@@ -54,96 +56,106 @@ type transaction struct {
5456
// 1. When the transaction is committed or aborted, a bool will be read from opsFlow chan.
5557
// 2. When the opsCount increment from 0 to 1, a bool will be read from opsFlow chan.
5658
opsFlow chan bool
57-
opsCount int32
59+
opsCount atomic.Int32
5860
opTimeout time.Duration
5961
log log.Logger
6062
}
6163

6264
func newTransaction(id TxnID, tcClient *transactionCoordinatorClient, timeout time.Duration) *transaction {
6365
transaction := &transaction{
6466
txnID: id,
65-
state: TxnOpen,
6667
registerPartitions: make(map[string]bool),
6768
registerAckSubscriptions: make(map[subscription]bool),
6869
opsFlow: make(chan bool, 1),
69-
opTimeout: 5 * time.Second,
70+
opTimeout: tcClient.client.operationTimeout,
7071
tcClient: tcClient,
7172
}
72-
//This means there are not pending requests with this transaction. The transaction can be committed or aborted.
73+
transaction.state.Store(int32(TxnOpen))
74+
// This means there are not pending requests with this transaction. The transaction can be committed or aborted.
7375
transaction.opsFlow <- true
7476
go func() {
75-
//Set the state of the transaction to timeout after timeout
77+
// Set the state of the transaction to timeout after timeout
7678
<-time.After(timeout)
77-
atomic.CompareAndSwapInt32((*int32)(&transaction.state), int32(TxnOpen), int32(TxnTimeout))
79+
transaction.state.CompareAndSwap(int32(TxnOpen), int32(TxnTimeout))
7880
}()
7981
transaction.log = tcClient.log.SubLogger(log.Fields{})
8082
return transaction
8183
}
8284

8385
func (txn *transaction) GetState() TxnState {
84-
return txn.state
86+
return TxnState(txn.state.Load())
8587
}
8688

87-
func (txn *transaction) Commit(_ context.Context) error {
88-
if !(atomic.CompareAndSwapInt32((*int32)(&txn.state), int32(TxnOpen), int32(TxnCommitting)) ||
89-
txn.state == TxnCommitting) {
90-
return newError(InvalidStatus, "Expect transaction state is TxnOpen but "+txn.state.string())
89+
func (txn *transaction) Commit(ctx context.Context) error {
90+
if !(txn.state.CompareAndSwap(int32(TxnOpen), int32(TxnCommitting))) {
91+
txnState := txn.state.Load()
92+
return newError(InvalidStatus, txnStateErrorMessage(TxnOpen, TxnState(txnState)))
9193
}
9294

93-
//Wait for all operations to complete
95+
// Wait for all operations to complete
9496
select {
9597
case <-txn.opsFlow:
98+
case <-ctx.Done():
99+
txn.state.Store(int32(TxnOpen))
100+
return ctx.Err()
96101
case <-time.After(txn.opTimeout):
102+
txn.state.Store(int32(TxnTimeout))
97103
return newError(TimeoutError, "There are some operations that are not completed after the timeout.")
98104
}
99-
//Send commit transaction command to transaction coordinator
105+
// Send commit transaction command to transaction coordinator
100106
err := txn.tcClient.endTxn(&txn.txnID, pb.TxnAction_COMMIT)
101107
if err == nil {
102-
atomic.StoreInt32((*int32)(&txn.state), int32(TxnCommitted))
108+
txn.state.Store(int32(TxnCommitted))
103109
} else {
104-
if e, ok := err.(*Error); ok && (e.Result() == TransactionNoFoundError || e.Result() == InvalidStatus) {
105-
atomic.StoreInt32((*int32)(&txn.state), int32(TxnError))
110+
var e *Error
111+
if errors.As(err, &e) && (e.Result() == TransactionNoFoundError || e.Result() == InvalidStatus) {
112+
txn.state.Store(int32(TxnError))
106113
return err
107114
}
108115
txn.opsFlow <- true
109116
}
110117
return err
111118
}
112119

113-
func (txn *transaction) Abort(_ context.Context) error {
114-
if !(atomic.CompareAndSwapInt32((*int32)(&txn.state), int32(TxnOpen), int32(TxnAborting)) ||
115-
txn.state == TxnAborting) {
116-
return newError(InvalidStatus, "Expect transaction state is TxnOpen but "+txn.state.string())
120+
func (txn *transaction) Abort(ctx context.Context) error {
121+
if !(txn.state.CompareAndSwap(int32(TxnOpen), int32(TxnAborting))) {
122+
txnState := txn.state.Load()
123+
return newError(InvalidStatus, txnStateErrorMessage(TxnOpen, TxnState(txnState)))
117124
}
118125

119-
//Wait for all operations to complete
126+
// Wait for all operations to complete
120127
select {
121128
case <-txn.opsFlow:
129+
case <-ctx.Done():
130+
txn.state.Store(int32(TxnOpen))
131+
return ctx.Err()
122132
case <-time.After(txn.opTimeout):
133+
txn.state.Store(int32(TxnTimeout))
123134
return newError(TimeoutError, "There are some operations that are not completed after the timeout.")
124135
}
125-
//Send abort transaction command to transaction coordinator
136+
// Send abort transaction command to transaction coordinator
126137
err := txn.tcClient.endTxn(&txn.txnID, pb.TxnAction_ABORT)
127138
if err == nil {
128-
atomic.StoreInt32((*int32)(&txn.state), int32(TxnAborted))
139+
txn.state.Store(int32(TxnAborted))
129140
} else {
130-
if e, ok := err.(*Error); ok && (e.Result() == TransactionNoFoundError || e.Result() == InvalidStatus) {
131-
atomic.StoreInt32((*int32)(&txn.state), int32(TxnError))
132-
} else {
133-
txn.opsFlow <- true
141+
var e *Error
142+
if errors.As(err, &e) && (e.Result() == TransactionNoFoundError || e.Result() == InvalidStatus) {
143+
txn.state.Store(int32(TxnError))
144+
return err
134145
}
146+
txn.opsFlow <- true
135147
}
136148
return err
137149
}
138150

139151
func (txn *transaction) registerSendOrAckOp() error {
140-
if atomic.AddInt32(&txn.opsCount, 1) == 1 {
141-
//There are new operations that not completed
152+
if txn.opsCount.Add(1) == 1 {
153+
// There are new operations that were not completed
142154
select {
143155
case <-txn.opsFlow:
144156
return nil
145157
case <-time.After(txn.opTimeout):
146-
if _, err := txn.checkIfOpen(); err != nil {
158+
if err := txn.verifyOpen(); err != nil {
147159
return err
148160
}
149161
return newError(TimeoutError, "Failed to get the semaphore to register the send/ack operation")
@@ -154,23 +166,22 @@ func (txn *transaction) registerSendOrAckOp() error {
154166

155167
func (txn *transaction) endSendOrAckOp(err error) {
156168
if err != nil {
157-
atomic.StoreInt32((*int32)(&txn.state), int32(TxnError))
169+
txn.state.Store(int32(TxnError))
158170
}
159-
if atomic.AddInt32(&txn.opsCount, -1) == 0 {
160-
//This means there are not pending send/ack requests
171+
if txn.opsCount.Add(-1) == 0 {
172+
// This means there are no pending send/ack requests
161173
txn.opsFlow <- true
162174
}
163175
}
164176

165177
func (txn *transaction) registerProducerTopic(topic string) error {
166-
isOpen, err := txn.checkIfOpen()
167-
if !isOpen {
178+
if err := txn.verifyOpen(); err != nil {
168179
return err
169180
}
170181
_, ok := txn.registerPartitions[topic]
171182
if !ok {
172-
txn.Lock()
173-
defer txn.Unlock()
183+
txn.mu.Lock()
184+
defer txn.mu.Unlock()
174185
if _, ok = txn.registerPartitions[topic]; !ok {
175186
err := txn.tcClient.addPublishPartitionToTxn(&txn.txnID, []string{topic})
176187
if err != nil {
@@ -183,8 +194,7 @@ func (txn *transaction) registerProducerTopic(topic string) error {
183194
}
184195

185196
func (txn *transaction) registerAckTopic(topic string, subName string) error {
186-
isOpen, err := txn.checkIfOpen()
187-
if !isOpen {
197+
if err := txn.verifyOpen(); err != nil {
188198
return err
189199
}
190200
sub := subscription{
@@ -193,8 +203,8 @@ func (txn *transaction) registerAckTopic(topic string, subName string) error {
193203
}
194204
_, ok := txn.registerAckSubscriptions[sub]
195205
if !ok {
196-
txn.Lock()
197-
defer txn.Unlock()
206+
txn.mu.Lock()
207+
defer txn.mu.Unlock()
198208
if _, ok = txn.registerAckSubscriptions[sub]; !ok {
199209
err := txn.tcClient.addSubscriptionToTxn(&txn.txnID, topic, subName)
200210
if err != nil {
@@ -210,14 +220,15 @@ func (txn *transaction) GetTxnID() TxnID {
210220
return txn.txnID
211221
}
212222

213-
func (txn *transaction) checkIfOpen() (bool, error) {
214-
if txn.state == TxnOpen {
215-
return true, nil
223+
func (txn *transaction) verifyOpen() error {
224+
txnState := txn.state.Load()
225+
if txnState != int32(TxnOpen) {
226+
return newError(InvalidStatus, txnStateErrorMessage(TxnOpen, TxnState(txnState)))
216227
}
217-
return false, newError(InvalidStatus, "Expect transaction state is TxnOpen but "+txn.state.string())
228+
return nil
218229
}
219230

220-
func (state TxnState) string() string {
231+
func (state TxnState) String() string {
221232
switch state {
222233
case TxnOpen:
223234
return "TxnOpen"
@@ -237,3 +248,8 @@ func (state TxnState) string() string {
237248
return "Unknown"
238249
}
239250
}
251+
252+
//nolint:unparam
253+
func txnStateErrorMessage(expected, actual TxnState) string {
254+
return fmt.Sprintf("Expected transaction state: %s, actual: %s", expected, actual)
255+
}

0 commit comments

Comments
 (0)