diff --git a/src/jmh/java/io/r2dbc/postgresql/CopyInBenchmarks.java b/src/jmh/java/io/r2dbc/postgresql/CopyInBenchmarks.java new file mode 100644 index 00000000..47891abb --- /dev/null +++ b/src/jmh/java/io/r2dbc/postgresql/CopyInBenchmarks.java @@ -0,0 +1,153 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.r2dbc.postgresql; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.r2dbc.postgresql.api.PostgresqlConnection; +import io.r2dbc.postgresql.util.PostgresqlServerExtension; +import org.junit.platform.commons.annotation.Testable; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.infra.Blackhole; +import org.postgresql.copy.CopyManager; +import org.postgresql.jdbc.PgConnection; +import reactor.core.publisher.Mono; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + +/** + * Benchmarks for Copy operation. Contains the following execution methods: + */ +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.SECONDS) +@Testable +public class CopyInBenchmarks extends BenchmarkSettings { + + private static PostgresqlServerExtension extension = new PostgresqlServerExtension(); + + @State(Scope.Benchmark) + public static class ConnectionHolder { + + @Param({"0", "1", "100", "1000000"}) + int rows; + + final PgConnection jdbc; + + final CopyManager copyManager; + + final PostgresqlConnection r2dbc; + + Path csvFile; + + public ConnectionHolder() { + + extension.initialize(); + try { + jdbc = extension.getDataSource().getConnection() + .unwrap(PgConnection.class); + copyManager = jdbc.getCopyAPI(); + Statement statement = jdbc.createStatement(); + + try { + statement.execute("DROP TABLE IF EXISTS simple_test"); + } catch (SQLException e) { + } + + statement.execute("CREATE TABLE simple_test (name VARCHAR(255), age int)"); + + jdbc.setAutoCommit(false); + + r2dbc = new PostgresqlConnectionFactory(extension.getConnectionConfiguration()).create().block(); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + @Setup(Level.Trial) + public void doSetup() throws IOException { + csvFile = Files.createTempFile("jmh-input", ".csv"); + + try (OutputStream outputStream = new FileOutputStream(csvFile.toFile())) { + IntStream.range(0, rows) + .mapToObj(i -> "some-input" + i + ";" + i + "\n") + .forEach(row -> { + try { + outputStream.write(row.getBytes(StandardCharsets.UTF_8)); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + } + + @TearDown(Level.Trial) + public void doTearDown() throws IOException { + Files.delete(csvFile); + } + + } + + @Benchmark + public void copyInR2dbc(ConnectionHolder connectionHolder, Blackhole voodoo) throws IOException { + File file = connectionHolder.csvFile.toFile(); + try (FileInputStream fileInputStream = new FileInputStream(file); + FileChannel fileChannel = fileInputStream.getChannel()) { + MappedByteBuffer mappedByteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, file.length()); + ByteBuf byteBuf = Unpooled.wrappedBuffer(mappedByteBuffer); + + Long rowsInserted = connectionHolder.r2dbc.copyIn("COPY simple_test (name, age) FROM STDIN DELIMITER ';'", Mono.just(byteBuf)) + .block(); + + voodoo.consume(rowsInserted); + } + } + + @Benchmark + public void copyInJdbc(ConnectionHolder connectionHolder, Blackhole voodoo) throws IOException, SQLException { + try (InputStream inputStream = new FileInputStream(connectionHolder.csvFile.toFile())) { + + Long rowsInserted = connectionHolder.copyManager.copyIn("COPY simple_test (name, age) FROM STDIN DELIMITER ';'", inputStream); + + voodoo.consume(rowsInserted); + } + } + +} diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnection.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnection.java index f6cf0834..5ba545f8 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnection.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnection.java @@ -16,6 +16,7 @@ package io.r2dbc.postgresql; +import io.netty.buffer.ByteBuf; import io.r2dbc.postgresql.api.ErrorDetails; import io.r2dbc.postgresql.api.Notification; import io.r2dbc.postgresql.api.PostgresTransactionDefinition; @@ -49,6 +50,7 @@ import reactor.util.Loggers; import reactor.util.annotation.Nullable; +import java.nio.ByteBuffer; import java.time.Duration; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; @@ -406,6 +408,11 @@ public void onComplete() { }); } + @Override + public Mono copyIn(String sql, Publisher stdin) { + return new PostgresqlCopyIn(resources).copy(sql, stdin); + } + private static Function getTransactionIsolationLevelQuery(IsolationLevel isolationLevel) { return transactionStatus -> { if (transactionStatus == OPEN) { diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlCopyIn.java b/src/main/java/io/r2dbc/postgresql/PostgresqlCopyIn.java new file mode 100644 index 00000000..55c5b9dc --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlCopyIn.java @@ -0,0 +1,98 @@ +/* + * Copyright 2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.r2dbc.postgresql; + +import io.netty.buffer.ByteBuf; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.r2dbc.postgresql.client.Client; +import io.r2dbc.postgresql.message.backend.BackendMessage; +import io.r2dbc.postgresql.message.backend.CommandComplete; +import io.r2dbc.postgresql.message.backend.CopyInResponse; +import io.r2dbc.postgresql.message.backend.ReadyForQuery; +import io.r2dbc.postgresql.message.frontend.CopyData; +import io.r2dbc.postgresql.message.frontend.CopyDone; +import io.r2dbc.postgresql.message.frontend.CopyFail; +import io.r2dbc.postgresql.message.frontend.Query; +import io.r2dbc.postgresql.util.Assert; +import io.r2dbc.postgresql.util.Operators; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import static io.r2dbc.postgresql.PostgresqlResult.toResult; + +/** + * An implementation for {@link CopyData} PostgreSQL queries. + */ +final class PostgresqlCopyIn { + + private final ConnectionResources context; + + PostgresqlCopyIn(ConnectionResources context) { + this.context = Assert.requireNonNull(context, "context must not be null"); + } + + Mono copy(String sql, Publisher stdin) { + return Flux.from(stdin) + .map(CopyData::new) + .as(messages -> copyIn(sql, messages)); + } + + private Mono copyIn(String sql, Flux copyDataMessages) { + Client client = context.getClient(); + + Flux backendMessages = copyDataMessages + .doOnNext(client::send) + .doOnError((e) -> sendCopyFail(e.getMessage())) + .doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release) + .thenMany(client.exchange(Mono.just(CopyDone.INSTANCE))); + + return startCopy(sql) + .concatWith(backendMessages) + .doOnCancel(() -> sendCopyFail("Cancelled")) + .as(Operators::discardOnCancel) + .as(messages -> toResult(context, messages, ExceptionFactory.INSTANCE).getRowsUpdated()); + } + + private Flux startCopy(String sql) { + return context.getClient().exchange( + // ReadyForQuery is returned when an invalid query is provided + backendMessage -> backendMessage instanceof CopyInResponse || backendMessage instanceof ReadyForQuery, + Mono.just(new Query(sql)) + ) + .doOnNext(message -> { + if (message instanceof CommandComplete) { + throw new IllegalArgumentException("Copy from stdin query expected, sql='" + sql + "', message=" + message); + } + }); + } + + private void sendCopyFail(String message) { + context.getClient().exchange(Mono.just(new CopyFail("Copy operation failed: " + message))) + .as(Operators::discardOnCancel) + .subscribe(); + } + + @Override + public String toString() { + return "PostgresqlCopyIn{" + + "context=" + this.context + + '}'; + } + +} diff --git a/src/main/java/io/r2dbc/postgresql/api/PostgresqlConnection.java b/src/main/java/io/r2dbc/postgresql/api/PostgresqlConnection.java index 6b820a4e..db6ef57a 100644 --- a/src/main/java/io/r2dbc/postgresql/api/PostgresqlConnection.java +++ b/src/main/java/io/r2dbc/postgresql/api/PostgresqlConnection.java @@ -16,12 +16,14 @@ package io.r2dbc.postgresql.api; +import io.netty.buffer.ByteBuf; import io.r2dbc.postgresql.message.frontend.CancelRequest; import io.r2dbc.spi.Connection; import io.r2dbc.spi.IsolationLevel; import io.r2dbc.spi.R2dbcNonTransientResourceException; import io.r2dbc.spi.TransactionDefinition; import io.r2dbc.spi.ValidationDepth; +import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -170,4 +172,13 @@ public interface PostgresqlConnection extends Connection { @Override Mono validate(ValidationDepth depth); + /** + * Use COPY FROM STDIN for very fast copying into a database table. + * + * @param sql the COPY … FROM STDIN sql statement + * @param stdin the ByteBuf publisher + * @return a {@link Mono} with the amount of rows inserted + */ + Mono copyIn(String sql, Publisher stdin); + } diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionUnitTests.java b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionUnitTests.java index 68be9062..41cd1012 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionUnitTests.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionUnitTests.java @@ -21,13 +21,18 @@ import io.r2dbc.postgresql.client.TestClient; import io.r2dbc.postgresql.client.Version; import io.r2dbc.postgresql.codec.MockCodecs; +import io.r2dbc.postgresql.message.Format; import io.r2dbc.postgresql.message.backend.CommandComplete; +import io.r2dbc.postgresql.message.backend.CopyInResponse; import io.r2dbc.postgresql.message.backend.ErrorResponse; +import io.r2dbc.postgresql.message.backend.ReadyForQuery; +import io.r2dbc.postgresql.message.frontend.CopyDone; import io.r2dbc.postgresql.message.frontend.Query; import io.r2dbc.postgresql.message.frontend.Terminate; import io.r2dbc.spi.IsolationLevel; import io.r2dbc.spi.R2dbcNonTransientResourceException; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; import reactor.test.StepVerifier; import java.time.Duration; @@ -38,6 +43,7 @@ import static io.r2dbc.postgresql.client.TransactionStatus.IDLE; import static io.r2dbc.postgresql.client.TransactionStatus.OPEN; import static io.r2dbc.spi.IsolationLevel.READ_COMMITTED; +import static java.util.Collections.emptySet; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.RETURNS_SMART_NULLS; @@ -502,6 +508,25 @@ void setTransactionIsolationLevelNonOpen() { .verifyComplete(); } + @Test + void copyIn() { + Client client = TestClient.builder() + .transactionStatus(IDLE) + .expectRequest(new Query("some-sql")).thenRespond(new CopyInResponse(emptySet(), Format.FORMAT_TEXT)) + .expectRequest(CopyDone.INSTANCE).thenRespond( + new CommandComplete("cmd", 1, 0), + new ReadyForQuery(ReadyForQuery.TransactionStatus.IDLE) + ) + .build(); + + PostgresqlConnection connection = createConnection(client, MockCodecs.empty(), this.statementCache); + + connection.copyIn("some-sql", Flux.empty()) + .as(StepVerifier::create) + .expectNext(0L) + .verifyComplete(); + } + @Test void setStatementTimeout() { Client client = TestClient.builder() diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlCopyInIntegrationTests.java b/src/test/java/io/r2dbc/postgresql/PostgresqlCopyInIntegrationTests.java new file mode 100644 index 00000000..d915df0d --- /dev/null +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlCopyInIntegrationTests.java @@ -0,0 +1,180 @@ +/* + * Copyright 2019-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.r2dbc.postgresql; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.r2dbc.postgresql.ExceptionFactory.PostgresqlBadGrammarException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.jdbc.core.JdbcOperations; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import java.util.Collections; +import java.util.List; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link PostgresqlCopyIn}. + */ +class PostgresqlCopyInIntegrationTests extends AbstractIntegrationTests { + + @BeforeEach + void setUp() { + super.setUp(); + getJdbcOperations().execute("DROP TABLE IF EXISTS test"); + getJdbcOperations().execute("CREATE TABLE test (id SERIAL PRIMARY KEY, val VARCHAR(255), timestamp TIMESTAMP)"); + } + + @AfterEach + void tearDown() { + super.tearDown(); + getJdbcOperations().execute("DROP TABLE IF EXISTS test"); + } + + private JdbcOperations getJdbcOperations() { + return SERVER.getJdbcOperations(); + } + + @Override + protected void customize(PostgresqlConnectionConfiguration.Builder builder) { + builder.preparedStatementCacheQueries(2); + } + + @Test + void shouldCopyDataIntoTable() { + String sql = "COPY test (val) FROM STDIN"; + + Flux data = Flux.just( + byteBuf("d\n"), + byteBuf("d\n"), + byteBuf("e\n") + ); + + this.connection.copyIn(sql, data) + .as(StepVerifier::create) + .expectNext(3L) + .verifyComplete(); + + // Verify the connection is no longer in COPY-IN mode and verify data is copied into the table. + verifyItemsInserted(asList("d", "d", "e")); + } + + @Test + void shouldHandleErrorOnFailureInInput() { + String sql = "COPY test (val) FROM STDIN"; + + Flux data = Flux.just( + byteBuf("d\n") + ) + .concatWith(Mono.error(new RuntimeException("Failed during input generation"))); + + this.connection.copyIn(sql, data) + .as(StepVerifier::create) + .expectError(RuntimeException.class) + .verify(); + + verifyItemsInserted(emptyList()); + } + + @Test + void shouldCopyNothingEmptyFlux() { + String sql = "COPY test (val) FROM STDIN"; + + Flux data = Flux.empty(); + + this.connection.copyIn(sql, data) + .as(StepVerifier::create) + .expectNext(0L) + .verifyComplete(); + } + + @Test + void shouldHandleErrorOnValidNonCopyInQuery() { + String sql = "SELECT 1"; + + Flux input = Flux.just(byteBuf("something,something-invalid\n")); + + this.connection.copyIn(sql, input) + .as(StepVerifier::create) + .consumeErrorWith(e -> assertThat(e) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Copy from stdin query expected, sql='SELECT 1', message=CommandComplete{command=SELECT, rowId=null, rows=1}") + ) + .verify(); + } + + @Test + void shouldHandleErrors() { + String sql = "COPY test (val) FROM STDIN"; + + int characterCountVarcharType = 256; + Flux input = Flux.just(String.join("", Collections.nCopies(characterCountVarcharType, "a"))) + .map(this::byteBuf); + + verifyCopyInFailed(sql, input, "value too long for type character varying(255)"); + } + + @Test + void shouldFailOnInvalidStatement() { + String sql = "COPY invalid command"; + + Flux data = Flux.just(byteBuf("something,something-invalid\n")); + + verifyCopyInFailed(sql, data, "syntax error at or near \"command\""); + } + + @Test + void shouldFailOnInvalidDataType() { + String sql = "COPY test (val, timestamp) FROM STDIN WITH DELIMITER ','"; + + Flux data = Flux.just(byteBuf("something,something-invalid\n")); + + verifyCopyInFailed(sql, data, "invalid input syntax for type timestamp: \"something-invalid\""); + } + + private void verifyCopyInFailed(String sql, Flux data, String message) { + this.connection.copyIn(sql, data) + .as(StepVerifier::create) + .consumeErrorWith(e -> assertThat(e) + .isInstanceOf(PostgresqlBadGrammarException.class) + .hasMessage(message) + ) + .verify(); + } + + private void verifyItemsInserted(List t) { + this.connection.createStatement("SELECT val FROM test") + .execute() + .flatMap(res -> res.map(row -> row.get(0))) + .collectSortedList() + .as(StepVerifier::create) + .expectNext(t) + .verifyComplete(); + } + + private ByteBuf byteBuf(String str) { + return Unpooled.wrappedBuffer(str.getBytes()); + } + +} diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlCopyInUnitTests.java b/src/test/java/io/r2dbc/postgresql/PostgresqlCopyInUnitTests.java new file mode 100644 index 00000000..40c027d8 --- /dev/null +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlCopyInUnitTests.java @@ -0,0 +1,171 @@ +/* + * Copyright 2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.r2dbc.postgresql; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.r2dbc.postgresql.ExceptionFactory.PostgresqlNonTransientResourceException; +import io.r2dbc.postgresql.client.Client; +import io.r2dbc.postgresql.client.TestClient; +import io.r2dbc.postgresql.client.TransactionStatus; +import io.r2dbc.postgresql.message.Format; +import io.r2dbc.postgresql.message.backend.CommandComplete; +import io.r2dbc.postgresql.message.backend.CopyInResponse; +import io.r2dbc.postgresql.message.backend.ErrorResponse; +import io.r2dbc.postgresql.message.backend.ReadyForQuery; +import io.r2dbc.postgresql.message.frontend.CopyData; +import io.r2dbc.postgresql.message.frontend.CopyDone; +import io.r2dbc.postgresql.message.frontend.CopyFail; +import io.r2dbc.postgresql.message.frontend.Query; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; + +import static io.r2dbc.postgresql.message.backend.ReadyForQuery.TransactionStatus.IDLE; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptySet; +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link PostgresqlCopyIn}. + */ +final class PostgresqlCopyInUnitTests { + + @Test + void copyIn() { + ByteBuf byteBuffer = byteBuf("a\n"); + Client client = TestClient.builder() + .expectRequest(new Query("some-sql")).thenRespond(new CopyInResponse(emptySet(), Format.FORMAT_TEXT)) + .expectRequest(new CopyData(byteBuffer), CopyDone.INSTANCE).thenRespond( + new CommandComplete("cmd", 1, 1), + new ReadyForQuery(IDLE) + ).build(); + + new PostgresqlCopyIn(MockContext.builder().client(client).build()) + .copy("some-sql", Flux.just(byteBuffer)) + .as(StepVerifier::create) + .expectNext(1L) + .verifyComplete(); + } + + @Test + void copyInInvalidQuery() { + ByteBuf byteBuffer = byteBuf("a\n"); + String sql = "invalid-sql"; + Client client = TestClient.builder() + .expectRequest(new Query(sql)).thenRespond(new CommandComplete("command", 0, 9)) + .build(); + + new PostgresqlCopyIn(MockContext.builder().client(client).build()) + .copy(sql, Flux.just(byteBuffer)) + .as(StepVerifier::create) + .consumeErrorWith(e -> assertThat(e) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Copy from stdin query expected, sql='invalid-sql', message=CommandComplete{command=command, rowId=0, rows=9}") + ) + .verify(); + } + + @Test + void copyInErrorResponse() { + ByteBuf byteBuffer = byteBuf("a\n"); + Client client = TestClient.builder() + .expectRequest(new Query("some-sql")).thenRespond(new ErrorResponse(emptyList())) + .build(); + + new PostgresqlCopyIn(MockContext.builder().client(client).build()) + .copy("some-sql", Flux.just(byteBuffer)) + .as(StepVerifier::create) + .expectError(PostgresqlNonTransientResourceException.class) + .verify(); + } + + @Test + void copyInEmpty() { + Client client = TestClient.builder() + .transactionStatus(TransactionStatus.IDLE) + .expectRequest(new Query("some-sql")).thenRespond(new CopyInResponse(emptySet(), Format.FORMAT_TEXT)) + .expectRequest(CopyDone.INSTANCE).thenRespond( + new CommandComplete("cmd", 1, 0), + new ReadyForQuery(ReadyForQuery.TransactionStatus.IDLE) + ) + .build(); + + new PostgresqlCopyIn(MockContext.builder().client(client).build()) + .copy("some-sql", Flux.empty()) + .as(StepVerifier::create) + .expectNext(0L) + .verifyComplete(); + } + + @Test + void copyInError() { + TestPublisher testPublisher = TestPublisher.createCold(); + testPublisher.next(byteBuf("a\n")); + testPublisher.next(byteBuf("b\n")); + testPublisher.error(new RuntimeException("Failed")); + + Client client = TestClient.builder() + .expectRequest(new Query("some-sql")).thenRespond(new CopyInResponse(emptySet(), Format.FORMAT_TEXT)) + .expectRequest( + new CopyData(byteBuf("a\n")), + new CopyData(byteBuf("b\n")), + new CopyFail("Copy operation failed: Failed") + ).thenRespond( + new CommandComplete("cmd", 1, 1), + new ReadyForQuery(IDLE) + ).build(); + + new PostgresqlCopyIn(MockContext.builder().client(client).build()) + .copy("some-sql", testPublisher.flux()) + .as(StepVerifier::create) + .expectError(RuntimeException.class) + .verify(); + } + + @Test + void copyInCancel() { + TestPublisher testPublisher = TestPublisher.create(); + + Client client = TestClient.builder() + .expectRequest(new Query("some-sql")).thenRespond(new CopyInResponse(emptySet(), Format.FORMAT_TEXT)) + .expectRequest( + new CopyData(byteBuf("a")), + new CopyData(byteBuf("b")), + new CopyFail("Copy operation failed: Cancelled") + ).thenRespond( + new CommandComplete("cmd", 1, 1), + new ReadyForQuery(IDLE) + ).build(); + + new PostgresqlCopyIn(MockContext.builder().client(client).build()) + .copy("some-sql", testPublisher.flux()) + .as(StepVerifier::create) + .then(() -> { + testPublisher.next(byteBuf("a")); + testPublisher.next(byteBuf("b")); + }) + .thenCancel() + .verify(); + } + + private ByteBuf byteBuf(String str) { + return Unpooled.wrappedBuffer(str.getBytes()); + } + +} diff --git a/src/test/java/io/r2dbc/postgresql/api/MockPostgresqlConnection.java b/src/test/java/io/r2dbc/postgresql/api/MockPostgresqlConnection.java index 7060f9b0..5dc7dd07 100644 --- a/src/test/java/io/r2dbc/postgresql/api/MockPostgresqlConnection.java +++ b/src/test/java/io/r2dbc/postgresql/api/MockPostgresqlConnection.java @@ -16,9 +16,11 @@ package io.r2dbc.postgresql.api; +import io.netty.buffer.ByteBuf; import io.r2dbc.spi.IsolationLevel; import io.r2dbc.spi.TransactionDefinition; import io.r2dbc.spi.ValidationDepth; +import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -132,4 +134,9 @@ public Mono validate(ValidationDepth depth) { return Mono.empty(); } + @Override + public Mono copyIn(String sql, Publisher stdin) { + return Mono.empty(); + } + } diff --git a/src/test/java/io/r2dbc/postgresql/client/TestClient.java b/src/test/java/io/r2dbc/postgresql/client/TestClient.java index 2fdc8216..78dc563f 100644 --- a/src/test/java/io/r2dbc/postgresql/client/TestClient.java +++ b/src/test/java/io/r2dbc/postgresql/client/TestClient.java @@ -326,10 +326,10 @@ public T done() { return this.chain; } - public Exchange.Builder> expectRequest(FrontendMessage request) { - Assert.requireNonNull(request, "request must not be null"); + public Exchange.Builder> expectRequest(FrontendMessage... requests) { + Assert.requireNonNull(requests, "requests must not be null"); - Exchange.Builder> exchange = new Exchange.Builder<>(this, request); + Exchange.Builder> exchange = new Exchange.Builder<>(this, requests); this.exchanges.add(exchange); return exchange; }