diff --git a/src/main/java/com/rabbitmq/client/impl/AMQConnection.java b/src/main/java/com/rabbitmq/client/impl/AMQConnection.java index 1ce083af30..3dfdf7be51 100644 --- a/src/main/java/com/rabbitmq/client/impl/AMQConnection.java +++ b/src/main/java/com/rabbitmq/client/impl/AMQConnection.java @@ -58,6 +58,8 @@ public class AMQConnection extends ShutdownNotifierComponent implements Connecti private final ScheduledExecutorService heartbeatExecutor; private final ExecutorService shutdownExecutor; private Thread mainLoopThread; + private final AtomicBoolean ioLoopThreadSet = new AtomicBoolean(false); + private volatile Thread ioLoopThread; private ThreadFactory threadFactory = Executors.defaultThreadFactory(); private String id; @@ -504,6 +506,7 @@ public void startMainLoop() { MainLoop loop = new MainLoop(); final String name = "AMQP Connection " + getHostAddress() + ":" + getPort(); mainLoopThread = Environment.newThread(threadFactory, loop, name); + ioLoopThread(mainLoopThread); mainLoopThread.start(); } @@ -1104,7 +1107,7 @@ public void close(int closeCode, boolean abort) throws IOException { - boolean sync = !(Thread.currentThread() == mainLoopThread); + boolean sync = !(Thread.currentThread() == ioLoopThread); try { AMQP.Connection.Close reason = @@ -1195,6 +1198,12 @@ public void setId(String id) { this.id = id; } + public void ioLoopThread(Thread thread) { + if (this.ioLoopThreadSet.compareAndSet(false, true)) { + this.ioLoopThread = thread; + } + } + public int getChannelRpcTimeout() { return channelRpcTimeout; } diff --git a/src/main/java/com/rabbitmq/client/impl/nio/NioLoop.java b/src/main/java/com/rabbitmq/client/impl/nio/NioLoop.java index ae7fa970e9..b143429ea7 100644 --- a/src/main/java/com/rabbitmq/client/impl/nio/NioLoop.java +++ b/src/main/java/com/rabbitmq/client/impl/nio/NioLoop.java @@ -157,6 +157,7 @@ public void run() { if (frame != null) { try { + state.getConnection().ioLoopThread(Thread.currentThread()); boolean noProblem = state.getConnection().handleReadFrame(frame); if (noProblem && (!state.getConnection().isRunning() || state.getConnection().hasBrokerInitiatedShutdown())) { // looks like the frame was Close-Ok or Close diff --git a/src/test/java/com/rabbitmq/client/test/BlockedConnectionTest.java b/src/test/java/com/rabbitmq/client/test/BlockedConnectionTest.java new file mode 100644 index 0000000000..f8748ed27e --- /dev/null +++ b/src/test/java/com/rabbitmq/client/test/BlockedConnectionTest.java @@ -0,0 +1,62 @@ +// Copyright (c) 2023 VMware, Inc. or its affiliates. All rights reserved. +// +// This software, the RabbitMQ Java client library, is triple-licensed under the +// Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 +// ("GPL") and the Apache License version 2 ("ASL"). For the MPL, please see +// LICENSE-MPL-RabbitMQ. For the GPL, please see LICENSE-GPL2. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. + +package com.rabbitmq.client.test; + +import static com.rabbitmq.client.test.TestUtils.LatchConditions.completed; +import static com.rabbitmq.client.test.TestUtils.waitAtMost; +import static org.assertj.core.api.Assertions.assertThat; + +import com.rabbitmq.client.Channel; +import com.rabbitmq.client.Connection; +import com.rabbitmq.client.ConnectionFactory; +import java.util.concurrent.CountDownLatch; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +public class BlockedConnectionTest extends BrokerTestCase { + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void errorInBlockListenerShouldCloseConnection(boolean nio) throws Exception { + ConnectionFactory cf = TestUtils.connectionFactory(); + if (nio) { + cf.useNio(); + } else { + cf.useBlockingIo(); + } + Connection c = cf.newConnection(); + CountDownLatch shutdownLatch = new CountDownLatch(1); + c.addShutdownListener(cause -> shutdownLatch.countDown()); + CountDownLatch blockedLatch = new CountDownLatch(1); + c.addBlockedListener( + reason -> { + blockedLatch.countDown(); + throw new RuntimeException("error in blocked listener!"); + }, + () -> {}); + try { + block(); + Channel ch = c.createChannel(); + ch.basicPublish("", "", null, "dummy".getBytes()); + assertThat(blockedLatch).is(completed()); + } finally { + unblock(); + } + assertThat(shutdownLatch).is(completed()); + waitAtMost(() -> !c.isOpen()); + } + +} diff --git a/src/test/java/com/rabbitmq/client/test/ClientTestSuite.java b/src/test/java/com/rabbitmq/client/test/ClientTestSuite.java index 33eb4aa8ed..2162103e0f 100644 --- a/src/test/java/com/rabbitmq/client/test/ClientTestSuite.java +++ b/src/test/java/com/rabbitmq/client/test/ClientTestSuite.java @@ -73,7 +73,8 @@ OAuth2ClientCredentialsGrantCredentialsProviderTest.class, RefreshCredentialsTest.class, AMQConnectionRefreshCredentialsTest.class, - ValueWriterTest.class + ValueWriterTest.class, + BlockedConnectionTest.class }) public class ClientTestSuite {