
Add COPY FROM support #500


Closed
wants to merge 3 commits into from

153 changes: 153 additions & 0 deletions src/jmh/java/io/r2dbc/postgresql/CopyInBenchmarks.java
@@ -0,0 +1,153 @@
/*
* Copyright 2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.r2dbc.postgresql;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.r2dbc.postgresql.api.PostgresqlConnection;
import io.r2dbc.postgresql.util.PostgresqlServerExtension;
import org.junit.platform.commons.annotation.Testable;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.infra.Blackhole;
import org.postgresql.copy.CopyManager;
import org.postgresql.jdbc.PgConnection;
import reactor.core.publisher.Mono;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;

/**
 * Benchmarks for the COPY FROM STDIN operation. Contains the following execution methods:
 * {@link #copyInR2dbc} (reactive {@code copyIn}) and {@link #copyInJdbc} (JDBC {@link CopyManager}).
 */
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.SECONDS)
@Testable
public class CopyInBenchmarks extends BenchmarkSettings {

private static PostgresqlServerExtension extension = new PostgresqlServerExtension();

@State(Scope.Benchmark)
public static class ConnectionHolder {

@Param({"0", "1", "100", "1000000"})
int rows;

final PgConnection jdbc;

final CopyManager copyManager;

final PostgresqlConnection r2dbc;

Path csvFile;

public ConnectionHolder() {

extension.initialize();
try {
jdbc = extension.getDataSource().getConnection()
.unwrap(PgConnection.class);
copyManager = jdbc.getCopyAPI();
Statement statement = jdbc.createStatement();

try {
statement.execute("DROP TABLE IF EXISTS simple_test");
} catch (SQLException e) {
}

statement.execute("CREATE TABLE simple_test (name VARCHAR(255), age int)");

jdbc.setAutoCommit(false);

r2dbc = new PostgresqlConnectionFactory(extension.getConnectionConfiguration()).create().block();
} catch (SQLException e) {
throw new RuntimeException(e);
}
}

@Setup(Level.Trial)
public void doSetup() throws IOException {
csvFile = Files.createTempFile("jmh-input", ".csv");

try (OutputStream outputStream = new FileOutputStream(csvFile.toFile())) {
IntStream.range(0, rows)
.mapToObj(i -> "some-input" + i + ";" + i + "\n")
.forEach(row -> {
try {
outputStream.write(row.getBytes(StandardCharsets.UTF_8));
} catch (Exception e) {
throw new RuntimeException(e);
}
});
}
}

@TearDown(Level.Trial)
public void doTearDown() throws IOException {
Files.delete(csvFile);
}

}

@Benchmark
public void copyInR2dbc(ConnectionHolder connectionHolder, Blackhole voodoo) throws IOException {
File file = connectionHolder.csvFile.toFile();
try (FileInputStream fileInputStream = new FileInputStream(file);
FileChannel fileChannel = fileInputStream.getChannel()) {
MappedByteBuffer mappedByteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, file.length());
ByteBuf byteBuf = Unpooled.wrappedBuffer(mappedByteBuffer);

Long rowsInserted = connectionHolder.r2dbc.copyIn("COPY simple_test (name, age) FROM STDIN DELIMITER ';'", Mono.just(byteBuf))
.block();

voodoo.consume(rowsInserted);
}
}

@Benchmark
public void copyInJdbc(ConnectionHolder connectionHolder, Blackhole voodoo) throws IOException, SQLException {
try (InputStream inputStream = new FileInputStream(connectionHolder.csvFile.toFile())) {

Long rowsInserted = connectionHolder.copyManager.copyIn("COPY simple_test (name, age) FROM STDIN DELIMITER ';'", inputStream);

voodoo.consume(rowsInserted);
}
}

}
7 changes: 7 additions & 0 deletions src/main/java/io/r2dbc/postgresql/PostgresqlConnection.java
@@ -16,6 +16,7 @@

package io.r2dbc.postgresql;

import io.netty.buffer.ByteBuf;
import io.r2dbc.postgresql.api.ErrorDetails;
import io.r2dbc.postgresql.api.Notification;
import io.r2dbc.postgresql.api.PostgresTransactionDefinition;
@@ -49,6 +50,7 @@
import reactor.util.Loggers;
import reactor.util.annotation.Nullable;

import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
@@ -406,6 +408,11 @@ public void onComplete() {
});
}

@Override
public Mono<Long> copyIn(String sql, Publisher<ByteBuf> stdin) {
return new PostgresqlCopyIn(resources).copy(sql, stdin);
}

private static Function<TransactionStatus, String> getTransactionIsolationLevelQuery(IsolationLevel isolationLevel) {
return transactionStatus -> {
if (transactionStatus == OPEN) {
98 changes: 98 additions & 0 deletions src/main/java/io/r2dbc/postgresql/PostgresqlCopyIn.java
@@ -0,0 +1,98 @@
/*
* Copyright 2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.r2dbc.postgresql;

import io.netty.buffer.ByteBuf;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.r2dbc.postgresql.client.Client;
import io.r2dbc.postgresql.message.backend.BackendMessage;
import io.r2dbc.postgresql.message.backend.CommandComplete;
import io.r2dbc.postgresql.message.backend.CopyInResponse;
import io.r2dbc.postgresql.message.backend.ReadyForQuery;
import io.r2dbc.postgresql.message.frontend.CopyData;
import io.r2dbc.postgresql.message.frontend.CopyDone;
import io.r2dbc.postgresql.message.frontend.CopyFail;
import io.r2dbc.postgresql.message.frontend.Query;
import io.r2dbc.postgresql.util.Assert;
import io.r2dbc.postgresql.util.Operators;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import static io.r2dbc.postgresql.PostgresqlResult.toResult;

/**
 * Handles the PostgreSQL COPY FROM STDIN operation by streaming {@link CopyData} frames to the server.
 */
final class PostgresqlCopyIn {

private final ConnectionResources context;

PostgresqlCopyIn(ConnectionResources context) {
this.context = Assert.requireNonNull(context, "context must not be null");
}

Mono<Long> copy(String sql, Publisher<ByteBuf> stdin) {
return Flux.from(stdin)
.map(CopyData::new)
.as(messages -> copyIn(sql, messages));
}

private Mono<Long> copyIn(String sql, Flux<CopyData> copyDataMessages) {
Client client = context.getClient();

Flux<BackendMessage> backendMessages = copyDataMessages
.doOnNext(client::send)
.doOnError((e) -> sendCopyFail(e.getMessage()))
.doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release)
.thenMany(client.exchange(Mono.just(CopyDone.INSTANCE)));
Collaborator

For the client exchange, it makes sense to start with a many-sink (Sinks.Many.asFlux()) so that we keep a single conversation/publisher open with the server. Otherwise we call exchange multiple times, which generates a bit of overhead and may interfere with other connection activity, causing protocol resynchronization.

I can apply this change during the merge as it bears some complexity.

Contributor Author

I made an attempt, but it is a bit complex since we have to wait for the CopyInResponse before sending the data frames.
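
To make the discussion concrete, here is a rough, untested sketch of the many-sink idea (the method name copyInSingleConversation is hypothetical and not part of this PR). It assumes the surrounding PostgresqlCopyIn class plus imports for reactor.core.publisher.Sinks and io.r2dbc.postgresql.message.frontend.FrontendMessage; CopyData frames are only emitted after a CopyInResponse is observed, and the cancellation/discard handling of the current implementation is omitted for brevity.

    // Hypothetical sketch only: a single exchange(...) conversation driven by a many-sink.
    private Mono<Long> copyInSingleConversation(String sql, Flux<CopyData> copyDataMessages) {
        Sinks.Many<FrontendMessage> requests = Sinks.many().unicast().onBackpressureBuffer();

        Flux<BackendMessage> messages = this.context.getClient().exchange(requests.asFlux())
            .doOnSubscribe(subscription -> requests.tryEmitNext(new Query(sql)))
            .doOnNext(message -> {
                if (message instanceof CopyInResponse) {
                    // the server is ready to receive data: stream the frames into the same conversation
                    copyDataMessages.subscribe(
                        requests::tryEmitNext,
                        error -> {
                            requests.tryEmitNext(new CopyFail("Copy operation failed: " + error.getMessage()));
                            requests.tryEmitComplete();
                        },
                        () -> {
                            requests.tryEmitNext(CopyDone.INSTANCE);
                            requests.tryEmitComplete();
                        });
                }
            });

        return messages.as(m -> toResult(this.context, m, ExceptionFactory.INSTANCE).getRowsUpdated());
    }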


return startCopy(sql)
.concatWith(backendMessages)
.doOnCancel(() -> sendCopyFail("Cancelled"))
.as(Operators::discardOnCancel)
.as(messages -> toResult(context, messages, ExceptionFactory.INSTANCE).getRowsUpdated());
}

private Flux<BackendMessage> startCopy(String sql) {
return context.getClient().exchange(
// ReadyForQuery is returned when an invalid query is provided
backendMessage -> backendMessage instanceof CopyInResponse || backendMessage instanceof ReadyForQuery,
Mono.just(new Query(sql))
)
.doOnNext(message -> {
if (message instanceof CommandComplete) {
throw new IllegalArgumentException("Copy from stdin query expected, sql='" + sql + "', message=" + message);
}
});
}

private void sendCopyFail(String message) {
context.getClient().exchange(Mono.just(new CopyFail("Copy operation failed: " + message)))
.as(Operators::discardOnCancel)
.subscribe();
}

@Override
public String toString() {
return "PostgresqlCopyIn{" +
"context=" + this.context +
'}';
}

}
11 changes: 11 additions & 0 deletions src/main/java/io/r2dbc/postgresql/api/PostgresqlConnection.java
@@ -16,12 +16,14 @@

package io.r2dbc.postgresql.api;

import io.netty.buffer.ByteBuf;
import io.r2dbc.postgresql.message.frontend.CancelRequest;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.IsolationLevel;
import io.r2dbc.spi.R2dbcNonTransientResourceException;
import io.r2dbc.spi.TransactionDefinition;
import io.r2dbc.spi.ValidationDepth;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
@@ -170,4 +172,13 @@ public interface PostgresqlConnection extends Connection {
@Override
Mono<Boolean> validate(ValidationDepth depth);

/**
 * Use COPY FROM STDIN for very fast copying into a database table.
 *
 * @param sql   the COPY … FROM STDIN SQL statement
 * @param stdin the {@link ByteBuf} publisher providing the data to copy
 * @return a {@link Mono} emitting the number of rows inserted
 */
Mono<Long> copyIn(String sql, Publisher<ByteBuf> stdin);

}
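
For illustration, a minimal caller-side sketch of the new method, modeled on the benchmark above (the connection variable, table name and CSV payload are assumptions; requires io.netty.buffer.Unpooled and java.nio.charset.StandardCharsets):

    // Hypothetical usage sketch; assumes an open PostgresqlConnection named `connection`
    // and a table created as: CREATE TABLE simple_test (name VARCHAR(255), age int)
    ByteBuf data = Unpooled.copiedBuffer("Alice;30\nBob;25\n", StandardCharsets.UTF_8);

    Mono<Long> rowsInserted = connection.copyIn(
        "COPY simple_test (name, age) FROM STDIN DELIMITER ';'",
        Mono.just(data));

    rowsInserted.subscribe(count -> System.out.println("Inserted " + count + " rows"));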
@@ -21,13 +21,18 @@
import io.r2dbc.postgresql.client.TestClient;
import io.r2dbc.postgresql.client.Version;
import io.r2dbc.postgresql.codec.MockCodecs;
import io.r2dbc.postgresql.message.Format;
import io.r2dbc.postgresql.message.backend.CommandComplete;
import io.r2dbc.postgresql.message.backend.CopyInResponse;
import io.r2dbc.postgresql.message.backend.ErrorResponse;
import io.r2dbc.postgresql.message.backend.ReadyForQuery;
import io.r2dbc.postgresql.message.frontend.CopyDone;
import io.r2dbc.postgresql.message.frontend.Query;
import io.r2dbc.postgresql.message.frontend.Terminate;
import io.r2dbc.spi.IsolationLevel;
import io.r2dbc.spi.R2dbcNonTransientResourceException;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import reactor.test.StepVerifier;

import java.time.Duration;
@@ -38,6 +43,7 @@
import static io.r2dbc.postgresql.client.TransactionStatus.IDLE;
import static io.r2dbc.postgresql.client.TransactionStatus.OPEN;
import static io.r2dbc.spi.IsolationLevel.READ_COMMITTED;
import static java.util.Collections.emptySet;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.Mockito.RETURNS_SMART_NULLS;
@@ -502,6 +508,25 @@ void setTransactionIsolationLevelNonOpen() {
.verifyComplete();
}

@Test
void copyIn() {
Client client = TestClient.builder()
.transactionStatus(IDLE)
.expectRequest(new Query("some-sql")).thenRespond(new CopyInResponse(emptySet(), Format.FORMAT_TEXT))
.expectRequest(CopyDone.INSTANCE).thenRespond(
new CommandComplete("cmd", 1, 0),
new ReadyForQuery(ReadyForQuery.TransactionStatus.IDLE)
)
.build();

PostgresqlConnection connection = createConnection(client, MockCodecs.empty(), this.statementCache);

connection.copyIn("some-sql", Flux.empty())
.as(StepVerifier::create)
.expectNext(0L)
.verifyComplete();
}

@Test
void setStatementTimeout() {
Client client = TestClient.builder()