Skip to content

Commit 64d6e97

Browse files
committed
Re-add savepoint method to Transaction
Revives #184. The rewrite for async/await and Tokio accidentally lost functionality that allowed users to assign specific names to savepoints when using nested transactions. This functionality had originally been added in #184 and had been updated in #374. This commit revives this functionality using a similar scheme to the one that existed before. This should allow CockroachDB users to update to the next patch release of version `0.17`.
1 parent e3d3c6d commit 64d6e97

File tree

3 files changed

+97
-16
lines changed

3 files changed

+97
-16
lines changed

postgres/src/test.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,57 @@ fn nested_transactions() {
151151
assert_eq!(rows[2].get::<_, i32>(0), 4);
152152
}
153153

154+
#[test]
155+
fn savepoints() {
156+
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
157+
158+
client
159+
.batch_execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)")
160+
.unwrap();
161+
162+
let mut transaction = client.transaction().unwrap();
163+
164+
transaction
165+
.execute("INSERT INTO foo (id) VALUES (1)", &[])
166+
.unwrap();
167+
168+
let mut savepoint1 = transaction.savepoint("savepoint1").unwrap();
169+
170+
savepoint1
171+
.execute("INSERT INTO foo (id) VALUES (2)", &[])
172+
.unwrap();
173+
174+
savepoint1.rollback().unwrap();
175+
176+
let rows = transaction
177+
.query("SELECT id FROM foo ORDER BY id", &[])
178+
.unwrap();
179+
assert_eq!(rows.len(), 1);
180+
assert_eq!(rows[0].get::<_, i32>(0), 1);
181+
182+
let mut savepoint2 = transaction.savepoint("savepoint2").unwrap();
183+
184+
savepoint2
185+
.execute("INSERT INTO foo (id) VALUES(3)", &[])
186+
.unwrap();
187+
188+
let mut savepoint3 = savepoint2.savepoint("savepoint3").unwrap();
189+
190+
savepoint3
191+
.execute("INSERT INTO foo (id) VALUES(4)", &[])
192+
.unwrap();
193+
194+
savepoint3.commit().unwrap();
195+
savepoint2.commit().unwrap();
196+
transaction.commit().unwrap();
197+
198+
let rows = client.query("SELECT id FROM foo ORDER BY id", &[]).unwrap();
199+
assert_eq!(rows.len(), 3);
200+
assert_eq!(rows[0].get::<_, i32>(0), 1);
201+
assert_eq!(rows[1].get::<_, i32>(0), 3);
202+
assert_eq!(rows[2].get::<_, i32>(0), 4);
203+
}
204+
154205
#[test]
155206
fn copy_in() {
156207
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();

postgres/src/transaction.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,23 @@ impl<'a> Transaction<'a> {
173173
CancelToken::new(self.transaction.cancel_token())
174174
}
175175

176-
/// Like `Client::transaction`.
176+
/// Like `Client::transaction`, but creates a nested transaction via a savepoint.
177177
pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
178178
let transaction = self.connection.block_on(self.transaction.transaction())?;
179179
Ok(Transaction {
180180
connection: self.connection.as_ref(),
181181
transaction,
182182
})
183183
}
184+
/// Like `Client::transaction`, but creates a nested transaction via a savepoint with the specified name.
185+
pub fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
186+
where
187+
I: Into<String>,
188+
{
189+
let transaction = self.connection.block_on(self.transaction.savepoint(name))?;
190+
Ok(Transaction {
191+
connection: self.connection.as_ref(),
192+
transaction,
193+
})
194+
}
184195
}

tokio-postgres/src/transaction.rs

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,26 @@ use tokio::io::{AsyncRead, AsyncWrite};
2323
/// transaction. Transactions can be nested, with inner transactions implemented via safepoints.
2424
pub struct Transaction<'a> {
2525
client: &'a mut Client,
26-
depth: u32,
26+
savepoint: Option<Savepoint>,
2727
done: bool,
2828
}
2929

30+
/// A representation of a PostgreSQL database savepoint.
31+
struct Savepoint {
32+
name: String,
33+
depth: u32,
34+
}
35+
3036
impl<'a> Drop for Transaction<'a> {
3137
fn drop(&mut self) {
3238
if self.done {
3339
return;
3440
}
3541

36-
let query = if self.depth == 0 {
37-
"ROLLBACK".to_string()
42+
let query = if let Some(sp) = self.savepoint.as_ref() {
43+
format!("ROLLBACK TO {}", sp.name)
3844
} else {
39-
format!("ROLLBACK TO sp{}", self.depth)
45+
"ROLLBACK".to_string()
4046
};
4147
let buf = self.client.inner().with_buf(|buf| {
4248
frontend::query(&query, buf).unwrap();
@@ -53,18 +59,18 @@ impl<'a> Transaction<'a> {
5359
pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
5460
Transaction {
5561
client,
56-
depth: 0,
62+
savepoint: None,
5763
done: false,
5864
}
5965
}
6066

6167
/// Consumes the transaction, committing all changes made within it.
6268
pub async fn commit(mut self) -> Result<(), Error> {
6369
self.done = true;
64-
let query = if self.depth == 0 {
65-
"COMMIT".to_string()
70+
let query = if let Some(sp) = self.savepoint.as_ref() {
71+
format!("RELEASE {}", sp.name)
6672
} else {
67-
format!("RELEASE sp{}", self.depth)
73+
"COMMIT".to_string()
6874
};
6975
self.client.batch_execute(&query).await
7076
}
@@ -74,10 +80,10 @@ impl<'a> Transaction<'a> {
7480
/// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
7581
pub async fn rollback(mut self) -> Result<(), Error> {
7682
self.done = true;
77-
let query = if self.depth == 0 {
78-
"ROLLBACK".to_string()
83+
let query = if let Some(sp) = self.savepoint.as_ref() {
84+
format!("ROLLBACK TO {}", sp.name)
7985
} else {
80-
format!("ROLLBACK TO sp{}", self.depth)
86+
"ROLLBACK".to_string()
8187
};
8288
self.client.batch_execute(&query).await
8389
}
@@ -272,15 +278,28 @@ impl<'a> Transaction<'a> {
272278
self.client.cancel_query_raw(stream, tls).await
273279
}
274280

275-
/// Like `Client::transaction`.
281+
/// Like `Client::transaction`, but creates a nested transaction via a savepoint.
276282
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
277-
let depth = self.depth + 1;
278-
let query = format!("SAVEPOINT sp{}", depth);
283+
self._savepoint(None).await
284+
}
285+
286+
/// Like `Client::transaction`, but creates a nested transaction via a savepoint with the specified name.
287+
pub async fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
288+
where
289+
I: Into<String>,
290+
{
291+
self._savepoint(Some(name.into())).await
292+
}
293+
294+
async fn _savepoint(&mut self, name: Option<String>) -> Result<Transaction<'_>, Error> {
295+
let depth = self.savepoint.as_ref().map_or(0, |sp| sp.depth) + 1;
296+
let name = name.unwrap_or_else(|| format!("sp_{}", depth));
297+
let query = format!("SAVEPOINT {}", name);
279298
self.batch_execute(&query).await?;
280299

281300
Ok(Transaction {
282301
client: self.client,
283-
depth,
302+
savepoint: Some(Savepoint { name, depth }),
284303
done: false,
285304
})
286305
}

0 commit comments

Comments
 (0)