Skip to content

Commit 70df7c7

Browse files
committed
Guard WindowPredicate with doOnDiscard(…).
Windowed fluxes now properly discard ref-counted objects avoiding memory leaks upon cancellation. [#492] Signed-off-by: Mark Paluch <[email protected]>
1 parent 46f3ed1 commit 70df7c7

File tree

6 files changed

+81
-41
lines changed

6 files changed

+81
-41
lines changed

src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java

+14-12
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
import io.netty.buffer.ByteBuf;
2020
import io.netty.buffer.Unpooled;
21+
import io.netty.util.ReferenceCountUtil;
22+
import io.netty.util.ReferenceCounted;
2123
import io.r2dbc.postgresql.client.Binding;
2224
import io.r2dbc.postgresql.client.ConnectionContext;
2325
import io.r2dbc.postgresql.client.EncodedParameter;
@@ -230,20 +232,20 @@ private Flux<io.r2dbc.postgresql.api.PostgresqlResult> execute(String sql) {
230232
.doOnSubscribe(it -> bindings.emitNext(iterator.next(), Sinks.EmitFailureHandler.FAIL_FAST));
231233

232234
}).cast(io.r2dbc.postgresql.api.PostgresqlResult.class);
233-
} else {
234-
// Simple Query protocol
235-
if (this.fetchSize != NO_LIMIT) {
236-
return ExtendedFlowDelegate.runQuery(this.resources, factory, sql, Binding.EMPTY, Collections.emptyList(), this.fetchSize)
237-
.windowUntil(WINDOW_UNTIL)
238-
.map(messages -> PostgresqlResult.toResult(this.resources, messages, factory))
239-
.as(Operators::discardOnCancel);
240-
}
235+
}
241236

242-
return SimpleQueryMessageFlow.exchange(this.resources.getClient(), sql)
243-
.windowUntil(WINDOW_UNTIL)
244-
.map(messages -> PostgresqlResult.toResult(this.resources, messages, factory))
245-
.as(Operators::discardOnCancel);
237+
Flux<BackendMessage> exchange;
238+
// Simple Query protocol
239+
if (this.fetchSize != NO_LIMIT) {
240+
exchange = ExtendedFlowDelegate.runQuery(this.resources, factory, sql, Binding.EMPTY, Collections.emptyList(), this.fetchSize);
241+
} else {
242+
exchange = SimpleQueryMessageFlow.exchange(this.resources.getClient(), sql);
246243
}
244+
245+
return exchange.windowUntil(WINDOW_UNTIL)
246+
.doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release) // ensure release of rows within WindowPredicate
247+
.map(messages -> PostgresqlResult.toResult(this.resources, messages, factory))
248+
.as(Operators::discardOnCancel);
247249
}
248250

249251
private static void tryNextBinding(Iterator<Binding> iterator, Sinks.Many<Binding> bindingSink, AtomicBoolean canceled) {

src/main/java/io/r2dbc/postgresql/message/backend/DataRow.java

+8-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,14 @@ public final class DataRow extends AbstractReferenceCounted implements BackendMe
4747
* @throws IllegalArgumentException if {@code columns} is {@code null}
4848
*/
4949
public DataRow(ByteBuf... columns) {
50-
this.columns = Assert.requireNonNull(columns, "columns must not be null");
50+
51+
if (columns == null) {
52+
this.columns = new ByteBuf[0];
53+
release();
54+
throw new IllegalArgumentException("columns must not be null");
55+
}
56+
57+
this.columns = columns;
5158
}
5259

5360
@Override

src/test/java/io/r2dbc/postgresql/PostgresqlRowUnitTests.java

+11-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import io.r2dbc.postgresql.codec.MockCodecs;
2121
import io.r2dbc.postgresql.message.backend.DataRow;
2222
import io.r2dbc.postgresql.message.backend.RowDescription;
23+
import io.r2dbc.postgresql.util.ReferenceCountedCleaner;
24+
import org.junit.jupiter.api.AfterEach;
2325
import org.junit.jupiter.api.Test;
2426

2527
import java.util.Arrays;
@@ -41,6 +43,8 @@
4143
*/
4244
final class PostgresqlRowUnitTests {
4345

46+
private final ReferenceCountedCleaner cleaner = new ReferenceCountedCleaner();
47+
4448
private final List<RowDescription.Field> columns = Arrays.asList(
4549
new RowDescription.Field((short) 100, 200, 300, (short) 400, FORMAT_BINARY, "test-name-1", 500),
4650
new RowDescription.Field((short) 300, 400, 300, (short) 400, FORMAT_TEXT, "test-name-2", 500),
@@ -49,6 +53,11 @@ final class PostgresqlRowUnitTests {
4953

5054
private final ByteBuf[] data = new ByteBuf[]{TEST.buffer(4).writeInt(100), TEST.buffer(4).writeInt(300), null};
5155

56+
@AfterEach
57+
void tearDown() {
58+
cleaner.clean();
59+
}
60+
5261
@Test
5362
void constructorNoContext() {
5463
assertThatIllegalArgumentException().isThrownBy(() -> new PostgresqlRow(null, null, Collections.emptyList(), null))
@@ -156,7 +165,7 @@ void toRow() {
156165
.build();
157166

158167
RowDescription description = new RowDescription(Collections.singletonList(new RowDescription.Field((short) 200, 300, (short) 400, (short) 500, FORMAT_TEXT, "test-name-1", 600)));
159-
PostgresqlRow row = PostgresqlRow.toRow(MockContext.builder().codecs(codecs).build(), new DataRow(TEST.buffer(4).writeInt(100)),
168+
PostgresqlRow row = PostgresqlRow.toRow(MockContext.builder().codecs(codecs).build(), cleaner.capture(new DataRow(TEST.buffer(4).writeInt(100))),
160169
codecs, description);
161170

162171
assertThat(row.get(0, Object.class)).isSameAs(value);
@@ -170,7 +179,7 @@ void toRowNoDataRow() {
170179

171180
@Test
172181
void toRowNoRowDescription() {
173-
assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlRow.toRow(MockContext.empty(), new DataRow(TEST.buffer(4).writeInt(100)), MockCodecs.empty(), null))
182+
assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlRow.toRow(MockContext.empty(), cleaner.capture(new DataRow(TEST.buffer(4).writeInt(100))), MockCodecs.empty(), null))
174183
.withMessage("rowDescription must not be null");
175184
}
176185

src/test/java/io/r2dbc/postgresql/message/backend/BackendMessageAssert.java

+6-24
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@
1717
package io.r2dbc.postgresql.message.backend;
1818

1919
import io.netty.buffer.ByteBuf;
20-
import io.netty.util.ReferenceCountUtil;
20+
import io.netty.util.ReferenceCounted;
21+
import io.r2dbc.postgresql.util.ReferenceCountedCleaner;
2122
import org.assertj.core.api.AbstractObjectAssert;
2223
import org.assertj.core.api.ObjectAssert;
2324
import org.springframework.util.ReflectionUtils;
2425

2526
import java.lang.reflect.Method;
26-
import java.util.ArrayList;
27-
import java.util.List;
2827
import java.util.Objects;
2928
import java.util.function.Function;
3029

@@ -35,7 +34,7 @@
3534
*/
3635
final class BackendMessageAssert extends AbstractObjectAssert<BackendMessageAssert, Class<? extends BackendMessage>> {
3736

38-
private Cleaner cleaner = new Cleaner();
37+
private ReferenceCountedCleaner cleaner = new ReferenceCountedCleaner();
3938

4039
private BackendMessageAssert(Class<? extends BackendMessage> actual) {
4140
super(actual, BackendMessageAssert.class);
@@ -45,7 +44,7 @@ static BackendMessageAssert assertThat(Class<? extends BackendMessage> actual) {
4544
return new BackendMessageAssert(actual);
4645
}
4746

48-
BackendMessageAssert cleaner(Cleaner cleaner) {
47+
BackendMessageAssert cleaner(ReferenceCountedCleaner cleaner) {
4948
this.cleaner = cleaner;
5049
return this;
5150
}
@@ -61,28 +60,11 @@ <T extends BackendMessage> ObjectAssert<T> decoded(Function<ByteBuf, ByteBuf> de
6160
ReflectionUtils.makeAccessible(method);
6261
T actual = (T) ReflectionUtils.invokeMethod(method, null, decoded.apply(TEST.buffer()));
6362

64-
return new ObjectAssert<>(this.cleaner.capture(actual));
63+
return new ObjectAssert<>((T) (actual instanceof ReferenceCounted ? this.cleaner.capture((ReferenceCounted) actual) : actual));
6564
}
6665

67-
public Cleaner cleaner() {
66+
public ReferenceCountedCleaner cleaner() {
6867
return this.cleaner;
6968
}
7069

71-
static class Cleaner {
72-
73-
private final List<Object> objects = new ArrayList<>();
74-
75-
public void clean() {
76-
this.objects.forEach(ReferenceCountUtil::release);
77-
this.objects.clear();
78-
}
79-
80-
public <T> T capture(T object) {
81-
this.objects.add(object);
82-
83-
return object;
84-
}
85-
86-
}
87-
8870
}

src/test/java/io/r2dbc/postgresql/message/backend/DataRowUnitTests.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
package io.r2dbc.postgresql.message.backend;
1818

1919
import io.netty.buffer.ByteBuf;
20+
import io.r2dbc.postgresql.util.ReferenceCountedCleaner;
2021
import org.junit.jupiter.api.AfterEach;
2122
import org.junit.jupiter.api.Test;
2223

23-
import static io.r2dbc.postgresql.message.backend.BackendMessageAssert.Cleaner;
2424
import static io.r2dbc.postgresql.message.backend.BackendMessageAssert.assertThat;
2525
import static io.r2dbc.postgresql.util.TestByteBufAllocator.TEST;
2626
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@@ -30,7 +30,7 @@
3030
*/
3131
final class DataRowUnitTests {
3232

33-
private final Cleaner cleaner = new Cleaner();
33+
private final ReferenceCountedCleaner cleaner = new ReferenceCountedCleaner();
3434

3535
@AfterEach
3636
void tearDown() {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright 2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.r2dbc.postgresql.util;
18+
19+
import io.netty.util.ReferenceCountUtil;
20+
import io.netty.util.ReferenceCounted;
21+
22+
import java.util.ArrayList;
23+
import java.util.List;
24+
25+
public class ReferenceCountedCleaner {
26+
27+
private final List<Object> objects = new ArrayList<>();
28+
29+
public void clean() {
30+
this.objects.forEach(ReferenceCountUtil::release);
31+
this.objects.clear();
32+
}
33+
34+
public <T extends ReferenceCounted> T capture(T object) {
35+
this.objects.add(object);
36+
37+
return object;
38+
}
39+
40+
}

0 commit comments

Comments
 (0)