Skip to content

Commit 1f4c436

Browse files
ArjanSchoutenmp911de
authored andcommitted
Add COPY FROM support.
The driver now provides an API to support COPY FROM STDIN. [#500][resolves #183]
1 parent d008d64 commit 1f4c436

File tree

9 files changed

+655
-3
lines changed

9 files changed

+655
-3
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
/*
2+
* Copyright 2019 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.r2dbc.postgresql;
18+
19+
import io.netty.buffer.ByteBuf;
20+
import io.netty.buffer.Unpooled;
21+
import io.r2dbc.postgresql.api.PostgresqlConnection;
22+
import io.r2dbc.postgresql.util.PostgresqlServerExtension;
23+
import org.junit.platform.commons.annotation.Testable;
24+
import org.openjdk.jmh.annotations.Benchmark;
25+
import org.openjdk.jmh.annotations.BenchmarkMode;
26+
import org.openjdk.jmh.annotations.Level;
27+
import org.openjdk.jmh.annotations.Mode;
28+
import org.openjdk.jmh.annotations.OutputTimeUnit;
29+
import org.openjdk.jmh.annotations.Param;
30+
import org.openjdk.jmh.annotations.Scope;
31+
import org.openjdk.jmh.annotations.Setup;
32+
import org.openjdk.jmh.annotations.State;
33+
import org.openjdk.jmh.annotations.TearDown;
34+
import org.openjdk.jmh.infra.Blackhole;
35+
import org.postgresql.copy.CopyManager;
36+
import org.postgresql.jdbc.PgConnection;
37+
import reactor.core.publisher.Mono;
38+
39+
import java.io.File;
40+
import java.io.FileInputStream;
41+
import java.io.FileNotFoundException;
42+
import java.io.FileOutputStream;
43+
import java.io.IOException;
44+
import java.io.InputStream;
45+
import java.io.OutputStream;
46+
import java.nio.MappedByteBuffer;
47+
import java.nio.channels.FileChannel;
48+
import java.nio.charset.StandardCharsets;
49+
import java.nio.file.Files;
50+
import java.nio.file.Path;
51+
import java.sql.SQLException;
52+
import java.sql.Statement;
53+
import java.util.concurrent.TimeUnit;
54+
import java.util.stream.IntStream;
55+
56+
/**
57+
* Benchmarks for Copy operation. Contains the following execution methods:
58+
*/
59+
@BenchmarkMode(Mode.Throughput)
60+
@OutputTimeUnit(TimeUnit.SECONDS)
61+
@Testable
62+
public class CopyInBenchmarks extends BenchmarkSettings {
63+
64+
private static PostgresqlServerExtension extension = new PostgresqlServerExtension();
65+
66+
@State(Scope.Benchmark)
67+
public static class ConnectionHolder {
68+
69+
@Param({"0", "1", "100", "1000000"})
70+
int rows;
71+
72+
final PgConnection jdbc;
73+
74+
final CopyManager copyManager;
75+
76+
final PostgresqlConnection r2dbc;
77+
78+
Path csvFile;
79+
80+
public ConnectionHolder() {
81+
82+
extension.initialize();
83+
try {
84+
jdbc = extension.getDataSource().getConnection()
85+
.unwrap(PgConnection.class);
86+
copyManager = jdbc.getCopyAPI();
87+
Statement statement = jdbc.createStatement();
88+
89+
try {
90+
statement.execute("DROP TABLE IF EXISTS simple_test");
91+
} catch (SQLException e) {
92+
}
93+
94+
statement.execute("CREATE TABLE simple_test (name VARCHAR(255), age int)");
95+
96+
jdbc.setAutoCommit(false);
97+
98+
r2dbc = new PostgresqlConnectionFactory(extension.getConnectionConfiguration()).create().block();
99+
} catch (SQLException e) {
100+
throw new RuntimeException(e);
101+
}
102+
}
103+
104+
@Setup(Level.Trial)
105+
public void doSetup() throws IOException {
106+
csvFile = Files.createTempFile("jmh-input", ".csv");
107+
108+
try (OutputStream outputStream = new FileOutputStream(csvFile.toFile())) {
109+
IntStream.range(0, rows)
110+
.mapToObj(i -> "some-input" + i + ";" + i + "\n")
111+
.forEach(row -> {
112+
try {
113+
outputStream.write(row.getBytes(StandardCharsets.UTF_8));
114+
} catch (Exception e) {
115+
throw new RuntimeException(e);
116+
}
117+
});
118+
}
119+
}
120+
121+
@TearDown(Level.Trial)
122+
public void doTearDown() throws IOException {
123+
Files.delete(csvFile);
124+
}
125+
126+
}
127+
128+
@Benchmark
129+
public void copyInR2dbc(ConnectionHolder connectionHolder, Blackhole voodoo) throws IOException {
130+
File file = connectionHolder.csvFile.toFile();
131+
try (FileInputStream fileInputStream = new FileInputStream(file);
132+
FileChannel fileChannel = fileInputStream.getChannel()) {
133+
MappedByteBuffer mappedByteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, file.length());
134+
ByteBuf byteBuf = Unpooled.wrappedBuffer(mappedByteBuffer);
135+
136+
Long rowsInserted = connectionHolder.r2dbc.copyIn("COPY simple_test (name, age) FROM STDIN DELIMITER ';'", Mono.just(byteBuf))
137+
.block();
138+
139+
voodoo.consume(rowsInserted);
140+
}
141+
}
142+
143+
@Benchmark
144+
public void copyInJdbc(ConnectionHolder connectionHolder, Blackhole voodoo) throws IOException, SQLException {
145+
try (InputStream inputStream = new FileInputStream(connectionHolder.csvFile.toFile())) {
146+
147+
Long rowsInserted = connectionHolder.copyManager.copyIn("COPY simple_test (name, age) FROM STDIN DELIMITER ';'", inputStream);
148+
149+
voodoo.consume(rowsInserted);
150+
}
151+
}
152+
153+
}

src/main/java/io/r2dbc/postgresql/PostgresqlConnection.java

+7
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package io.r2dbc.postgresql;
1818

19+
import io.netty.buffer.ByteBuf;
1920
import io.r2dbc.postgresql.api.ErrorDetails;
2021
import io.r2dbc.postgresql.api.Notification;
2122
import io.r2dbc.postgresql.api.PostgresTransactionDefinition;
@@ -49,6 +50,7 @@
4950
import reactor.util.Loggers;
5051
import reactor.util.annotation.Nullable;
5152

53+
import java.nio.ByteBuffer;
5254
import java.time.Duration;
5355
import java.util.concurrent.atomic.AtomicReference;
5456
import java.util.function.Function;
@@ -406,6 +408,11 @@ public void onComplete() {
406408
});
407409
}
408410

411+
@Override
412+
public Mono<Long> copyIn(String sql, Publisher<ByteBuf> stdin) {
413+
return new PostgresqlCopyIn(resources).copy(sql, stdin);
414+
}
415+
409416
private static Function<TransactionStatus, String> getTransactionIsolationLevelQuery(IsolationLevel isolationLevel) {
410417
return transactionStatus -> {
411418
if (transactionStatus == OPEN) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright 2017 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.r2dbc.postgresql;
18+
19+
import io.netty.buffer.ByteBuf;
20+
import io.netty.util.ReferenceCountUtil;
21+
import io.netty.util.ReferenceCounted;
22+
import io.r2dbc.postgresql.client.Client;
23+
import io.r2dbc.postgresql.message.backend.BackendMessage;
24+
import io.r2dbc.postgresql.message.backend.CommandComplete;
25+
import io.r2dbc.postgresql.message.backend.CopyInResponse;
26+
import io.r2dbc.postgresql.message.backend.ReadyForQuery;
27+
import io.r2dbc.postgresql.message.frontend.CopyData;
28+
import io.r2dbc.postgresql.message.frontend.CopyDone;
29+
import io.r2dbc.postgresql.message.frontend.CopyFail;
30+
import io.r2dbc.postgresql.message.frontend.Query;
31+
import io.r2dbc.postgresql.util.Assert;
32+
import io.r2dbc.postgresql.util.Operators;
33+
import org.reactivestreams.Publisher;
34+
import reactor.core.publisher.Flux;
35+
import reactor.core.publisher.Mono;
36+
37+
import static io.r2dbc.postgresql.PostgresqlResult.toResult;
38+
39+
/**
40+
* An implementation for {@link CopyData} PostgreSQL queries.
41+
*/
42+
final class PostgresqlCopyIn {
43+
44+
private final ConnectionResources context;
45+
46+
PostgresqlCopyIn(ConnectionResources context) {
47+
this.context = Assert.requireNonNull(context, "context must not be null");
48+
}
49+
50+
Mono<Long> copy(String sql, Publisher<ByteBuf> stdin) {
51+
return Flux.from(stdin)
52+
.map(CopyData::new)
53+
.as(messages -> copyIn(sql, messages));
54+
}
55+
56+
private Mono<Long> copyIn(String sql, Flux<CopyData> copyDataMessages) {
57+
Client client = context.getClient();
58+
59+
Flux<BackendMessage> backendMessages = copyDataMessages
60+
.doOnNext(client::send)
61+
.doOnError((e) -> sendCopyFail(e.getMessage()))
62+
.doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release)
63+
.thenMany(client.exchange(Mono.just(CopyDone.INSTANCE)));
64+
65+
return startCopy(sql)
66+
.concatWith(backendMessages)
67+
.doOnCancel(() -> sendCopyFail("Cancelled"))
68+
.as(Operators::discardOnCancel)
69+
.as(messages -> toResult(context, messages, ExceptionFactory.INSTANCE).getRowsUpdated());
70+
}
71+
72+
private Flux<BackendMessage> startCopy(String sql) {
73+
return context.getClient().exchange(
74+
// ReadyForQuery is returned when an invalid query is provided
75+
backendMessage -> backendMessage instanceof CopyInResponse || backendMessage instanceof ReadyForQuery,
76+
Mono.just(new Query(sql))
77+
)
78+
.doOnNext(message -> {
79+
if (message instanceof CommandComplete) {
80+
throw new IllegalArgumentException("Copy from stdin query expected, sql='" + sql + "', message=" + message);
81+
}
82+
});
83+
}
84+
85+
private void sendCopyFail(String message) {
86+
context.getClient().exchange(Mono.just(new CopyFail("Copy operation failed: " + message)))
87+
.as(Operators::discardOnCancel)
88+
.subscribe();
89+
}
90+
91+
@Override
92+
public String toString() {
93+
return "PostgresqlCopyIn{" +
94+
"context=" + this.context +
95+
'}';
96+
}
97+
98+
}

src/main/java/io/r2dbc/postgresql/api/PostgresqlConnection.java

+11
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616

1717
package io.r2dbc.postgresql.api;
1818

19+
import io.netty.buffer.ByteBuf;
1920
import io.r2dbc.postgresql.message.frontend.CancelRequest;
2021
import io.r2dbc.spi.Connection;
2122
import io.r2dbc.spi.IsolationLevel;
2223
import io.r2dbc.spi.R2dbcNonTransientResourceException;
2324
import io.r2dbc.spi.TransactionDefinition;
2425
import io.r2dbc.spi.ValidationDepth;
26+
import org.reactivestreams.Publisher;
2527
import org.reactivestreams.Subscriber;
2628
import reactor.core.publisher.Flux;
2729
import reactor.core.publisher.Mono;
@@ -170,4 +172,13 @@ public interface PostgresqlConnection extends Connection {
170172
@Override
171173
Mono<Boolean> validate(ValidationDepth depth);
172174

175+
/**
176+
* Use COPY FROM STDIN for very fast copying into a database table.
177+
*
178+
* @param sql the COPY … FROM STDIN sql statement
179+
* @param stdin the ByteBuf publisher
180+
* @return a {@link Mono} with the amount of rows inserted
181+
*/
182+
Mono<Long> copyIn(String sql, Publisher<ByteBuf> stdin);
183+
173184
}

src/test/java/io/r2dbc/postgresql/PostgresqlConnectionUnitTests.java

+25
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,18 @@
2121
import io.r2dbc.postgresql.client.TestClient;
2222
import io.r2dbc.postgresql.client.Version;
2323
import io.r2dbc.postgresql.codec.MockCodecs;
24+
import io.r2dbc.postgresql.message.Format;
2425
import io.r2dbc.postgresql.message.backend.CommandComplete;
26+
import io.r2dbc.postgresql.message.backend.CopyInResponse;
2527
import io.r2dbc.postgresql.message.backend.ErrorResponse;
28+
import io.r2dbc.postgresql.message.backend.ReadyForQuery;
29+
import io.r2dbc.postgresql.message.frontend.CopyDone;
2630
import io.r2dbc.postgresql.message.frontend.Query;
2731
import io.r2dbc.postgresql.message.frontend.Terminate;
2832
import io.r2dbc.spi.IsolationLevel;
2933
import io.r2dbc.spi.R2dbcNonTransientResourceException;
3034
import org.junit.jupiter.api.Test;
35+
import reactor.core.publisher.Flux;
3136
import reactor.test.StepVerifier;
3237

3338
import java.time.Duration;
@@ -38,6 +43,7 @@
3843
import static io.r2dbc.postgresql.client.TransactionStatus.IDLE;
3944
import static io.r2dbc.postgresql.client.TransactionStatus.OPEN;
4045
import static io.r2dbc.spi.IsolationLevel.READ_COMMITTED;
46+
import static java.util.Collections.emptySet;
4147
import static org.assertj.core.api.Assertions.assertThat;
4248
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
4349
import static org.mockito.Mockito.RETURNS_SMART_NULLS;
@@ -502,6 +508,25 @@ void setTransactionIsolationLevelNonOpen() {
502508
.verifyComplete();
503509
}
504510

511+
@Test
512+
void copyIn() {
513+
Client client = TestClient.builder()
514+
.transactionStatus(IDLE)
515+
.expectRequest(new Query("some-sql")).thenRespond(new CopyInResponse(emptySet(), Format.FORMAT_TEXT))
516+
.expectRequest(CopyDone.INSTANCE).thenRespond(
517+
new CommandComplete("cmd", 1, 0),
518+
new ReadyForQuery(ReadyForQuery.TransactionStatus.IDLE)
519+
)
520+
.build();
521+
522+
PostgresqlConnection connection = createConnection(client, MockCodecs.empty(), this.statementCache);
523+
524+
connection.copyIn("some-sql", Flux.empty())
525+
.as(StepVerifier::create)
526+
.expectNext(0L)
527+
.verifyComplete();
528+
}
529+
505530
@Test
506531
void setStatementTimeout() {
507532
Client client = TestClient.builder()

0 commit comments

Comments
 (0)