Skip to content

Commit 90f7bc2

Browse files
committed
feat: add copy in support for r2dbc postgresql driver.
This commit contains the happy-flow copy-in functionality without much testing. It should be seen as a design proposal for the r2dbc-postgresql team. If the r2dbc-postgresql team agrees with the proposed solution, we can address all testing requirements and fix all TODO comments.
1 parent b812b9d commit 90f7bc2

File tree

9 files changed

+656
-3
lines changed

9 files changed

+656
-3
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/*
2+
* Copyright 2019 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.r2dbc.postgresql;
18+
19+
import io.r2dbc.postgresql.api.PostgresqlConnection;
20+
import io.r2dbc.postgresql.util.PostgresqlServerExtension;
21+
import org.junit.platform.commons.annotation.Testable;
22+
import org.openjdk.jmh.annotations.Benchmark;
23+
import org.openjdk.jmh.annotations.BenchmarkMode;
24+
import org.openjdk.jmh.annotations.Level;
25+
import org.openjdk.jmh.annotations.Mode;
26+
import org.openjdk.jmh.annotations.OutputTimeUnit;
27+
import org.openjdk.jmh.annotations.Param;
28+
import org.openjdk.jmh.annotations.Scope;
29+
import org.openjdk.jmh.annotations.Setup;
30+
import org.openjdk.jmh.annotations.State;
31+
import org.openjdk.jmh.annotations.TearDown;
32+
import org.openjdk.jmh.infra.Blackhole;
33+
import org.postgresql.copy.CopyManager;
34+
import org.postgresql.jdbc.PgConnection;
35+
import org.springframework.core.io.buffer.DataBuffer;
36+
import org.springframework.core.io.buffer.DataBufferUtils;
37+
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
38+
import reactor.core.publisher.Flux;
39+
40+
import java.io.FileInputStream;
41+
import java.io.FileOutputStream;
42+
import java.io.IOException;
43+
import java.io.InputStream;
44+
import java.io.OutputStream;
45+
import java.nio.ByteBuffer;
46+
import java.nio.charset.StandardCharsets;
47+
import java.nio.file.Files;
48+
import java.nio.file.Path;
49+
import java.nio.file.StandardOpenOption;
50+
import java.sql.SQLException;
51+
import java.sql.Statement;
52+
import java.util.concurrent.TimeUnit;
53+
import java.util.stream.IntStream;
54+
55+
/**
56+
* Benchmarks for Copy operation. Contains the following execution methods:
57+
*/
58+
@BenchmarkMode(Mode.Throughput)
59+
@OutputTimeUnit(TimeUnit.SECONDS)
60+
@Testable
61+
public class CopyInBenchmarks extends BenchmarkSettings {
62+
63+
private static PostgresqlServerExtension extension = new PostgresqlServerExtension();
64+
65+
@State(Scope.Benchmark)
66+
public static class ConnectionHolder {
67+
68+
@Param({"0", "1", "100", "1000000"})
69+
int rows;
70+
71+
final PgConnection jdbc;
72+
73+
final CopyManager copyManager;
74+
75+
final PostgresqlConnection r2dbc;
76+
77+
Path csvFile;
78+
79+
public ConnectionHolder() {
80+
81+
extension.initialize();
82+
try {
83+
jdbc = extension.getDataSource().getConnection()
84+
.unwrap(PgConnection.class);
85+
copyManager = jdbc.getCopyAPI();
86+
Statement statement = jdbc.createStatement();
87+
88+
try {
89+
statement.execute("DROP TABLE IF EXISTS simple_test");
90+
} catch (SQLException e) {
91+
}
92+
93+
statement.execute("CREATE TABLE simple_test (name VARCHAR(255), age int)");
94+
95+
jdbc.setAutoCommit(false);
96+
97+
r2dbc = new PostgresqlConnectionFactory(extension.getConnectionConfiguration()).create().block();
98+
} catch (SQLException e) {
99+
throw new RuntimeException(e);
100+
}
101+
}
102+
103+
@Setup(Level.Trial)
104+
public void doSetup() throws IOException {
105+
csvFile = Files.createTempFile("jmh-input", ".csv");
106+
107+
try (OutputStream outputStream = new FileOutputStream(csvFile.toFile())) {
108+
IntStream.range(0, rows)
109+
.mapToObj(i -> "some-input" + i + ";" + i + "\n")
110+
.forEach(row -> {
111+
try {
112+
outputStream.write(row.getBytes(StandardCharsets.UTF_8));
113+
} catch (Exception e) {
114+
throw new RuntimeException(e);
115+
}
116+
});
117+
}
118+
}
119+
120+
@TearDown(Level.Trial)
121+
public void doTearDown() throws IOException {
122+
Files.delete(csvFile);
123+
}
124+
125+
}
126+
127+
@Benchmark
128+
public void copyInR2dbc(ConnectionHolder connectionHolder, Blackhole voodoo) {
129+
int bufferSize = 65536; // BufferSize is the same as the one from JDBC's CopyManager
130+
Flux<ByteBuffer> input = DataBufferUtils.read(connectionHolder.csvFile, DefaultDataBufferFactory.sharedInstance, bufferSize, StandardOpenOption.READ)
131+
.map(DataBuffer::asByteBuffer);
132+
133+
Long rowsInserted = connectionHolder.r2dbc.copyIn("COPY simple_test (name, age) FROM STDIN DELIMITER ';'", input)
134+
.block();
135+
136+
voodoo.consume(rowsInserted);
137+
}
138+
139+
@Benchmark
140+
public void copyInJdbc(ConnectionHolder connectionHolder, Blackhole voodoo) throws IOException, SQLException {
141+
try (InputStream inputStream = new FileInputStream(connectionHolder.csvFile.toFile())) {
142+
143+
Long rowsInserted = connectionHolder.copyManager.copyIn("COPY simple_test (name, age) FROM STDIN DELIMITER ';'", inputStream);
144+
145+
voodoo.consume(rowsInserted);
146+
}
147+
}
148+
149+
}

src/main/java/io/r2dbc/postgresql/PostgresqlConnection.java

+6
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import reactor.util.Loggers;
5050
import reactor.util.annotation.Nullable;
5151

52+
import java.nio.ByteBuffer;
5253
import java.time.Duration;
5354
import java.util.concurrent.atomic.AtomicReference;
5455
import java.util.function.Function;
@@ -406,6 +407,11 @@ public void onComplete() {
406407
});
407408
}
408409

410+
@Override
411+
public Mono<Long> copyIn(String sql, Publisher<ByteBuffer> stdin) {
412+
return new PostgresqlCopyIn(resources).copy(sql, stdin);
413+
}
414+
409415
private static Function<TransactionStatus, String> getTransactionIsolationLevelQuery(IsolationLevel isolationLevel) {
410416
return transactionStatus -> {
411417
if (transactionStatus == OPEN) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright 2017 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.r2dbc.postgresql;
18+
19+
import io.netty.buffer.Unpooled;
20+
import io.netty.util.ReferenceCountUtil;
21+
import io.netty.util.ReferenceCounted;
22+
import io.r2dbc.postgresql.client.Client;
23+
import io.r2dbc.postgresql.message.backend.BackendMessage;
24+
import io.r2dbc.postgresql.message.backend.CommandComplete;
25+
import io.r2dbc.postgresql.message.backend.CopyInResponse;
26+
import io.r2dbc.postgresql.message.backend.ReadyForQuery;
27+
import io.r2dbc.postgresql.message.frontend.CopyData;
28+
import io.r2dbc.postgresql.message.frontend.CopyDone;
29+
import io.r2dbc.postgresql.message.frontend.CopyFail;
30+
import io.r2dbc.postgresql.message.frontend.FrontendMessage;
31+
import io.r2dbc.postgresql.message.frontend.Query;
32+
import io.r2dbc.postgresql.util.Assert;
33+
import io.r2dbc.postgresql.util.Operators;
34+
import org.reactivestreams.Publisher;
35+
import reactor.core.publisher.Flux;
36+
import reactor.core.publisher.Mono;
37+
38+
import java.nio.ByteBuffer;
39+
40+
import static io.r2dbc.postgresql.PostgresqlResult.toResult;
41+
42+
/**
43+
* An implementation for {@link CopyData} PostgreSQL queries.
44+
*/
45+
final class PostgresqlCopyIn {
46+
47+
private final ConnectionResources context;
48+
49+
PostgresqlCopyIn(ConnectionResources context) {
50+
this.context = Assert.requireNonNull(context, "context must not be null");
51+
}
52+
53+
public Mono<Long> copy(String sql, Publisher<ByteBuffer> stdin) {
54+
Flux<FrontendMessage> insertData = Flux.from(stdin)
55+
.map(buffer -> new CopyData(Unpooled.wrappedBuffer(buffer)));
56+
57+
return copyIn(sql, insertData);
58+
}
59+
60+
private Mono<Long> copyIn(String sql, Flux<FrontendMessage> frontendMessages) {
61+
Client client = context.getClient();
62+
63+
Flux<BackendMessage> backendMessages = frontendMessages
64+
.doOnNext(client::send)
65+
.doOnError(e -> !(e instanceof IllegalArgumentException), (e) -> sendCopyFail(e.getMessage()))
66+
.doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release)
67+
.thenMany(client.exchange(Mono.just(CopyDone.INSTANCE)));
68+
69+
return startCopy(sql)
70+
.concatWith(backendMessages)
71+
.doOnCancel(() -> sendCopyFail("Cancelled"))
72+
.as(Operators::discardOnCancel)
73+
.as(messages -> toResult(context, messages, ExceptionFactory.INSTANCE).getRowsUpdated());
74+
}
75+
76+
private Flux<BackendMessage> startCopy(String sql) {
77+
return context.getClient().exchange(
78+
// ReadyForQuery is returned when an invalid query is provided
79+
backendMessage -> backendMessage instanceof CopyInResponse || backendMessage instanceof ReadyForQuery,
80+
Mono.just(new Query(sql))
81+
)
82+
.doOnNext(message -> {
83+
if (message instanceof CommandComplete) {
84+
throw new IllegalArgumentException("Copy from stdin query expected, sql='" + sql + "', message=" + message);
85+
}
86+
});
87+
}
88+
89+
private void sendCopyFail(String message) {
90+
context.getClient().exchange(Mono.just(new CopyFail("Copy operation failed: " + message)))
91+
.as(Operators::discardOnCancel)
92+
.subscribe();
93+
}
94+
95+
@Override
96+
public String toString() {
97+
return "PostgresqlCopyIn{" +
98+
"context=" + this.context +
99+
'}';
100+
}
101+
102+
}

src/main/java/io/r2dbc/postgresql/api/PostgresqlConnection.java

+11
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@
2222
import io.r2dbc.spi.R2dbcNonTransientResourceException;
2323
import io.r2dbc.spi.TransactionDefinition;
2424
import io.r2dbc.spi.ValidationDepth;
25+
import org.reactivestreams.Publisher;
2526
import org.reactivestreams.Subscriber;
2627
import reactor.core.publisher.Flux;
2728
import reactor.core.publisher.Mono;
2829

30+
import java.nio.ByteBuffer;
2931
import java.time.Duration;
3032

3133
/**
@@ -170,4 +172,13 @@ public interface PostgresqlConnection extends Connection {
170172
@Override
171173
Mono<Boolean> validate(ValidationDepth depth);
172174

175+
/**
176+
* Copy bulk data from the client into a PostgreSQL table using a {@code COPY ... FROM STDIN} statement.
177+
*
178+
* @param sql the COPY sql statement
179+
* @param stdin the ByteBuffer publisher
180+
* @return a {@link Mono} emitting the number of rows inserted
181+
*/
182+
Mono<Long> copyIn(String sql, Publisher<ByteBuffer> stdin);
183+
173184
}

src/test/java/io/r2dbc/postgresql/PostgresqlConnectionUnitTests.java

+25
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,18 @@
2121
import io.r2dbc.postgresql.client.TestClient;
2222
import io.r2dbc.postgresql.client.Version;
2323
import io.r2dbc.postgresql.codec.MockCodecs;
24+
import io.r2dbc.postgresql.message.Format;
2425
import io.r2dbc.postgresql.message.backend.CommandComplete;
26+
import io.r2dbc.postgresql.message.backend.CopyInResponse;
2527
import io.r2dbc.postgresql.message.backend.ErrorResponse;
28+
import io.r2dbc.postgresql.message.backend.ReadyForQuery;
29+
import io.r2dbc.postgresql.message.frontend.CopyDone;
2630
import io.r2dbc.postgresql.message.frontend.Query;
2731
import io.r2dbc.postgresql.message.frontend.Terminate;
2832
import io.r2dbc.spi.IsolationLevel;
2933
import io.r2dbc.spi.R2dbcNonTransientResourceException;
3034
import org.junit.jupiter.api.Test;
35+
import reactor.core.publisher.Flux;
3136
import reactor.test.StepVerifier;
3237

3338
import java.time.Duration;
@@ -38,6 +43,7 @@
3843
import static io.r2dbc.postgresql.client.TransactionStatus.IDLE;
3944
import static io.r2dbc.postgresql.client.TransactionStatus.OPEN;
4045
import static io.r2dbc.spi.IsolationLevel.READ_COMMITTED;
46+
import static java.util.Collections.emptySet;
4147
import static org.assertj.core.api.Assertions.assertThat;
4248
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
4349
import static org.mockito.Mockito.RETURNS_SMART_NULLS;
@@ -502,6 +508,25 @@ void setTransactionIsolationLevelNonOpen() {
502508
.verifyComplete();
503509
}
504510

511+
@Test
512+
void copyIn() {
513+
Client client = TestClient.builder()
514+
.transactionStatus(IDLE)
515+
.expectRequest(new Query("some-sql")).thenRespond(new CopyInResponse(emptySet(), Format.FORMAT_TEXT))
516+
.expectRequest(CopyDone.INSTANCE).thenRespond(
517+
new CommandComplete("cmd", 1, 0),
518+
new ReadyForQuery(ReadyForQuery.TransactionStatus.IDLE)
519+
)
520+
.build();
521+
522+
PostgresqlConnection connection = createConnection(client, MockCodecs.empty(), this.statementCache);
523+
524+
connection.copyIn("some-sql", Flux.empty())
525+
.as(StepVerifier::create)
526+
.expectNext(0L)
527+
.verifyComplete();
528+
}
529+
505530
@Test
506531
void setStatementTimeout() {
507532
Client client = TestClient.builder()

0 commit comments

Comments
 (0)