Skip to content

Commit a4a68d5

Browse files
committed
Ensure transactions roll back immediately on drop
Closes #635
1 parent 4fd7527 commit a4a68d5

File tree

2 files changed

+83
-33
lines changed

2 files changed

+83
-33
lines changed

postgres/src/test.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,31 @@ fn transaction_drop() {
100100
assert_eq!(rows.len(), 0);
101101
}
102102

103+
#[test]
104+
fn transaction_drop_immediate_rollback() {
105+
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
106+
let mut client2 = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
107+
108+
client
109+
.simple_query("CREATE TABLE IF NOT EXISTS foo (id SERIAL PRIMARY KEY)")
110+
.unwrap();
111+
112+
client
113+
.execute("INSERT INTO foo VALUES (1) ON CONFLICT DO NOTHING", &[])
114+
.unwrap();
115+
116+
let mut transaction = client.transaction().unwrap();
117+
118+
transaction
119+
.execute("SELECT * FROM foo FOR UPDATE", &[])
120+
.unwrap();
121+
122+
drop(transaction);
123+
124+
let rows = client2.query("SELECT * FROM foo FOR UPDATE", &[]).unwrap();
125+
assert_eq!(rows.len(), 1);
126+
}
127+
103128
#[test]
104129
fn nested_transactions() {
105130
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();

postgres/src/transaction.rs

Lines changed: 58 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,15 @@ use tokio_postgres::{Error, Row, SimpleQueryMessage};
99
/// in the transaction. Transactions can be nested, with inner transactions implemented via savepoints.
1010
pub struct Transaction<'a> {
1111
connection: ConnectionRef<'a>,
12-
transaction: tokio_postgres::Transaction<'a>,
12+
transaction: Option<tokio_postgres::Transaction<'a>>,
13+
}
14+
15+
impl<'a> Drop for Transaction<'a> {
16+
fn drop(&mut self) {
17+
if let Some(transaction) = self.transaction.take() {
18+
let _ = self.connection.block_on(transaction.rollback());
19+
}
20+
}
1321
}
1422

1523
impl<'a> Transaction<'a> {
@@ -19,31 +27,38 @@ impl<'a> Transaction<'a> {
1927
) -> Transaction<'a> {
2028
Transaction {
2129
connection,
22-
transaction,
30+
transaction: Some(transaction),
2331
}
2432
}
2533

2634
/// Consumes the transaction, committing all changes made within it.
2735
pub fn commit(mut self) -> Result<(), Error> {
28-
self.connection.block_on(self.transaction.commit())
36+
self.connection
37+
.block_on(self.transaction.take().unwrap().commit())
2938
}
3039

3140
/// Rolls the transaction back, discarding all changes made within it.
3241
///
3342
/// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
3443
pub fn rollback(mut self) -> Result<(), Error> {
35-
self.connection.block_on(self.transaction.rollback())
44+
self.connection
45+
.block_on(self.transaction.take().unwrap().rollback())
3646
}
3747

3848
/// Like `Client::prepare`.
3949
pub fn prepare(&mut self, query: &str) -> Result<Statement, Error> {
40-
self.connection.block_on(self.transaction.prepare(query))
50+
self.connection
51+
.block_on(self.transaction.as_ref().unwrap().prepare(query))
4152
}
4253

4354
/// Like `Client::prepare_typed`.
4455
pub fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result<Statement, Error> {
45-
self.connection
46-
.block_on(self.transaction.prepare_typed(query, types))
56+
self.connection.block_on(
57+
self.transaction
58+
.as_ref()
59+
.unwrap()
60+
.prepare_typed(query, types),
61+
)
4762
}
4863

4964
/// Like `Client::execute`.
@@ -52,7 +67,7 @@ impl<'a> Transaction<'a> {
5267
T: ?Sized + ToStatement,
5368
{
5469
self.connection
55-
.block_on(self.transaction.execute(query, params))
70+
.block_on(self.transaction.as_ref().unwrap().execute(query, params))
5671
}
5772

5873
/// Like `Client::query`.
@@ -61,7 +76,7 @@ impl<'a> Transaction<'a> {
6176
T: ?Sized + ToStatement,
6277
{
6378
self.connection
64-
.block_on(self.transaction.query(query, params))
79+
.block_on(self.transaction.as_ref().unwrap().query(query, params))
6580
}
6681

6782
/// Like `Client::query_one`.
@@ -70,7 +85,7 @@ impl<'a> Transaction<'a> {
7085
T: ?Sized + ToStatement,
7186
{
7287
self.connection
73-
.block_on(self.transaction.query_one(query, params))
88+
.block_on(self.transaction.as_ref().unwrap().query_one(query, params))
7489
}
7590

7691
/// Like `Client::query_opt`.
@@ -83,7 +98,7 @@ impl<'a> Transaction<'a> {
8398
T: ?Sized + ToStatement,
8499
{
85100
self.connection
86-
.block_on(self.transaction.query_opt(query, params))
101+
.block_on(self.transaction.as_ref().unwrap().query_opt(query, params))
87102
}
88103

89104
/// Like `Client::query_raw`.
@@ -95,7 +110,7 @@ impl<'a> Transaction<'a> {
95110
{
96111
let stream = self
97112
.connection
98-
.block_on(self.transaction.query_raw(query, params))?;
113+
.block_on(self.transaction.as_ref().unwrap().query_raw(query, params))?;
99114
Ok(RowIter::new(self.connection.as_ref(), stream))
100115
}
101116

@@ -114,16 +129,20 @@ impl<'a> Transaction<'a> {
114129
T: ?Sized + ToStatement,
115130
{
116131
self.connection
117-
.block_on(self.transaction.bind(query, params))
132+
.block_on(self.transaction.as_ref().unwrap().bind(query, params))
118133
}
119134

120135
/// Continues execution of a portal, returning the next set of rows.
121136
///
122137
/// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
123138
/// `query_portal`. If the requested number is negative or 0, all remaining rows will be returned.
124139
pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
125-
self.connection
126-
.block_on(self.transaction.query_portal(portal, max_rows))
140+
self.connection.block_on(
141+
self.transaction
142+
.as_ref()
143+
.unwrap()
144+
.query_portal(portal, max_rows),
145+
)
127146
}
128147

129148
/// The maximally flexible version of `query_portal`.
@@ -132,9 +151,12 @@ impl<'a> Transaction<'a> {
132151
portal: &Portal,
133152
max_rows: i32,
134153
) -> Result<RowIter<'_>, Error> {
135-
let stream = self
136-
.connection
137-
.block_on(self.transaction.query_portal_raw(portal, max_rows))?;
154+
let stream = self.connection.block_on(
155+
self.transaction
156+
.as_ref()
157+
.unwrap()
158+
.query_portal_raw(portal, max_rows),
159+
)?;
138160
Ok(RowIter::new(self.connection.as_ref(), stream))
139161
}
140162

@@ -143,7 +165,9 @@ impl<'a> Transaction<'a> {
143165
where
144166
T: ?Sized + ToStatement,
145167
{
146-
let sink = self.connection.block_on(self.transaction.copy_in(query))?;
168+
let sink = self
169+
.connection
170+
.block_on(self.transaction.as_ref().unwrap().copy_in(query))?;
147171
Ok(CopyInWriter::new(self.connection.as_ref(), sink))
148172
}
149173

@@ -152,44 +176,45 @@ impl<'a> Transaction<'a> {
152176
where
153177
T: ?Sized + ToStatement,
154178
{
155-
let stream = self.connection.block_on(self.transaction.copy_out(query))?;
179+
let stream = self
180+
.connection
181+
.block_on(self.transaction.as_ref().unwrap().copy_out(query))?;
156182
Ok(CopyOutReader::new(self.connection.as_ref(), stream))
157183
}
158184

159185
/// Like `Client::simple_query`.
160186
pub fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
161187
self.connection
162-
.block_on(self.transaction.simple_query(query))
188+
.block_on(self.transaction.as_ref().unwrap().simple_query(query))
163189
}
164190

165191
/// Like `Client::batch_execute`.
166192
pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
167193
self.connection
168-
.block_on(self.transaction.batch_execute(query))
194+
.block_on(self.transaction.as_ref().unwrap().batch_execute(query))
169195
}
170196

171197
/// Like `Client::cancel_token`.
172198
pub fn cancel_token(&self) -> CancelToken {
173-
CancelToken::new(self.transaction.cancel_token())
199+
CancelToken::new(self.transaction.as_ref().unwrap().cancel_token())
174200
}
175201

176202
/// Like `Client::transaction`, but creates a nested transaction via a savepoint.
177203
pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
178-
let transaction = self.connection.block_on(self.transaction.transaction())?;
179-
Ok(Transaction {
180-
connection: self.connection.as_ref(),
181-
transaction,
182-
})
204+
let transaction = self
205+
.connection
206+
.block_on(self.transaction.as_mut().unwrap().transaction())?;
207+
Ok(Transaction::new(self.connection.as_ref(), transaction))
183208
}
209+
184210
/// Like `Client::transaction`, but creates a nested transaction via a savepoint with the specified name.
185211
pub fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
186212
where
187213
I: Into<String>,
188214
{
189-
let transaction = self.connection.block_on(self.transaction.savepoint(name))?;
190-
Ok(Transaction {
191-
connection: self.connection.as_ref(),
192-
transaction,
193-
})
215+
let transaction = self
216+
.connection
217+
.block_on(self.transaction.as_mut().unwrap().savepoint(name))?;
218+
Ok(Transaction::new(self.connection.as_ref(), transaction))
194219
}
195220
}

0 commit comments

Comments
 (0)