diff --git a/src/main/java/com/rabbitmq/stream/impl/ConsumersCoordinator.java b/src/main/java/com/rabbitmq/stream/impl/ConsumersCoordinator.java index 99191cee2a..ce8fdf9f74 100644 --- a/src/main/java/com/rabbitmq/stream/impl/ConsumersCoordinator.java +++ b/src/main/java/com/rabbitmq/stream/impl/ConsumersCoordinator.java @@ -14,13 +14,7 @@ // info@rabbitmq.com. package com.rabbitmq.stream.impl; -import static com.rabbitmq.stream.impl.Utils.convertCodeToException; -import static com.rabbitmq.stream.impl.Utils.formatConstant; -import static com.rabbitmq.stream.impl.Utils.isSac; -import static com.rabbitmq.stream.impl.Utils.jsonField; -import static com.rabbitmq.stream.impl.Utils.namedFunction; -import static com.rabbitmq.stream.impl.Utils.namedRunnable; -import static com.rabbitmq.stream.impl.Utils.quote; +import static com.rabbitmq.stream.impl.Utils.*; import static java.lang.String.format; import com.rabbitmq.stream.*; @@ -51,6 +45,8 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import java.util.function.*; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -80,6 +76,7 @@ class ConsumersCoordinator { new DefaultExecutorServiceFactory( Runtime.getRuntime().availableProcessors(), 10, "rabbitmq-stream-consumer-connection-"); private final boolean forceReplica; + private final Lock coordinatorLock = new ReentrantLock(); ConsumersCoordinator( StreamEnvironment environment, @@ -116,47 +113,51 @@ Runnable subscribe( MessageHandler messageHandler, Map subscriptionProperties, ConsumerFlowStrategy flowStrategy) { - List candidates = findBrokersForStream(stream, forceReplica); - Client.Broker newNode = pickBroker(candidates); - if (newNode == null) { - throw new IllegalStateException("No available node to subscribe to"); - } - - // create stream subscription to track final and changing state of this very subscription - // we keep this instance when we move the subscription from a client to another one - SubscriptionTracker subscriptionTracker = - new SubscriptionTracker( - this.trackerIdSequence.getAndIncrement(), - consumer, - stream, - offsetSpecification, - trackingReference, - subscriptionListener, - trackingClosingCallback, - messageHandler, - subscriptionProperties, - flowStrategy); + return lock( + this.coordinatorLock, + () -> { + List candidates = findBrokersForStream(stream, forceReplica); + Client.Broker newNode = pickBroker(candidates); + if (newNode == null) { + throw new IllegalStateException("No available node to subscribe to"); + } - try { - addToManager(newNode, subscriptionTracker, offsetSpecification, true); - } catch (ConnectionStreamException e) { - // these exceptions are not public - throw new StreamException(e.getMessage()); - } + // create stream subscription to track final and changing state of this very subscription + // we keep this instance when we move the subscription from a client to another one + SubscriptionTracker subscriptionTracker = + new SubscriptionTracker( + this.trackerIdSequence.getAndIncrement(), + consumer, + stream, + offsetSpecification, + trackingReference, + subscriptionListener, + trackingClosingCallback, + messageHandler, + subscriptionProperties, + flowStrategy); + + try { + addToManager(newNode, subscriptionTracker, offsetSpecification, true); + } catch (ConnectionStreamException e) { + // these exceptions are not public + throw new StreamException(e.getMessage()); + } - if (debug) { - this.trackers.add(subscriptionTracker); - return () -> { - try { - this.trackers.remove(subscriptionTracker); - } catch (Exception e) { - LOGGER.debug("Error while removing subscription tracker from list"); - } - subscriptionTracker.cancel(); - }; - } else { - return subscriptionTracker::cancel; - } + if (debug) { + this.trackers.add(subscriptionTracker); + return () -> { + try { + this.trackers.remove(subscriptionTracker); + } catch (Exception e) { + LOGGER.debug("Error while removing subscription tracker from list"); + } + subscriptionTracker.cancel(); + }; + } else { + return subscriptionTracker::cancel; + } + }); } private void addToManager( diff --git a/src/main/java/com/rabbitmq/stream/impl/ProducersCoordinator.java b/src/main/java/com/rabbitmq/stream/impl/ProducersCoordinator.java index 143ef57204..57e82d67fb 100644 --- a/src/main/java/com/rabbitmq/stream/impl/ProducersCoordinator.java +++ b/src/main/java/com/rabbitmq/stream/impl/ProducersCoordinator.java @@ -14,12 +14,7 @@ // info@rabbitmq.com. package com.rabbitmq.stream.impl; -import static com.rabbitmq.stream.impl.Utils.callAndMaybeRetry; -import static com.rabbitmq.stream.impl.Utils.formatConstant; -import static com.rabbitmq.stream.impl.Utils.jsonField; -import static com.rabbitmq.stream.impl.Utils.namedFunction; -import static com.rabbitmq.stream.impl.Utils.namedRunnable; -import static com.rabbitmq.stream.impl.Utils.quote; +import static com.rabbitmq.stream.impl.Utils.*; import static java.util.stream.Collectors.toSet; import com.rabbitmq.stream.BackOffDelayPolicy; @@ -52,6 +47,8 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -75,6 +72,7 @@ class ProducersCoordinator { private final ExecutorServiceFactory executorServiceFactory = new DefaultExecutorServiceFactory( Runtime.getRuntime().availableProcessors(), 10, "rabbitmq-stream-producer-connection-"); + private final Lock coordinatorLock = new ReentrantLock(); ProducersCoordinator( StreamEnvironment environment, @@ -94,19 +92,26 @@ private static String keyForNode(Client.Broker broker) { } Runnable registerProducer(StreamProducer producer, String reference, String stream) { - ProducerTracker tracker = - new ProducerTracker(trackerIdSequence.getAndIncrement(), reference, stream, producer); - if (debug) { - this.producerTrackers.add(tracker); - } - return registerAgentTracker(tracker, stream); + return lock( + this.coordinatorLock, + () -> { + ProducerTracker tracker = + new ProducerTracker(trackerIdSequence.getAndIncrement(), reference, stream, producer); + if (debug) { + this.producerTrackers.add(tracker); + } + return registerAgentTracker(tracker, stream); + }); } Runnable registerTrackingConsumer(StreamConsumer consumer) { - return registerAgentTracker( - new TrackingConsumerTracker( - trackerIdSequence.getAndIncrement(), consumer.stream(), consumer), - consumer.stream()); + return lock( + this.coordinatorLock, + () -> + registerAgentTracker( + new TrackingConsumerTracker( + trackerIdSequence.getAndIncrement(), consumer.stream(), consumer), + consumer.stream())); } private Runnable registerAgentTracker(AgentTracker tracker, String stream) { diff --git a/src/main/java/com/rabbitmq/stream/impl/Utils.java b/src/main/java/com/rabbitmq/stream/impl/Utils.java index f600f98809..70d4086ad5 100644 --- a/src/main/java/com/rabbitmq/stream/impl/Utils.java +++ b/src/main/java/com/rabbitmq/stream/impl/Utils.java @@ -39,6 +39,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.Lock; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.LongConsumer; @@ -662,4 +663,13 @@ boolean get() { return this.value; } } + + static T lock(Lock lock, Supplier action) { + lock.lock(); + try { + return action.get(); + } finally { + lock.unlock(); + } + } } diff --git a/src/test/java/com/rabbitmq/stream/impl/StreamEnvironmentTest.java b/src/test/java/com/rabbitmq/stream/impl/StreamEnvironmentTest.java index 9a7d23e210..003e0eae79 100644 --- a/src/test/java/com/rabbitmq/stream/impl/StreamEnvironmentTest.java +++ b/src/test/java/com/rabbitmq/stream/impl/StreamEnvironmentTest.java @@ -14,6 +14,7 @@ // info@rabbitmq.com. package com.rabbitmq.stream.impl; +import static com.rabbitmq.stream.impl.TestUtils.CountDownLatchConditions.completed; import static com.rabbitmq.stream.impl.TestUtils.ExceptionConditions.responseCode; import static com.rabbitmq.stream.impl.TestUtils.latchAssert; import static com.rabbitmq.stream.impl.TestUtils.localhost; @@ -69,9 +70,7 @@ import java.util.Random; import java.util.Set; import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.CountDownLatch; +import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -82,11 +81,7 @@ import javax.net.ssl.SNIHostName; import javax.net.ssl.SSLParameters; import org.assertj.core.api.ThrowableAssert.ThrowingCallable; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.TestInfo; +import org.junit.jupiter.api.*; import org.junit.jupiter.api.condition.EnabledIfSystemProperty; import org.junit.jupiter.api.condition.EnabledOnOs; import org.junit.jupiter.api.condition.OS; @@ -778,4 +773,44 @@ void nativeEpollWorksOnLinux() { epollEventLoopGroup.shutdownGracefully(0, 0, SECONDS); } } + + @Test + void enforceEntityPerConnectionLimits() { + int entityCount = 10; + int limit = 3; + ExecutorService executor = Executors.newCachedThreadPool(); + try (Environment env = + environmentBuilder + .maxProducersByConnection(limit) + .maxConsumersByConnection(limit) + .maxTrackingConsumersByConnection(limit) + .build()) { + CountDownLatch latch = new CountDownLatch(entityCount * 2); + IntStream.range(0, entityCount) + .forEach( + i -> { + executor.execute( + () -> { + env.producerBuilder().stream(stream).name(String.valueOf(i)).build(); + latch.countDown(); + }); + }); + IntStream.range(0, entityCount) + .forEach( + i -> { + executor.execute( + () -> { + env.consumerBuilder().stream(stream).messageHandler((ctx, msg) -> {}).build(); + latch.countDown(); + }); + }); + assertThat(latch).is(completed()); + EnvironmentInfo envInfo = MonitoringTestUtils.extract(env); + int expectedConnectionCount = entityCount / limit + 1; + assertThat(envInfo.getProducers().clientCount()).isEqualTo(expectedConnectionCount); + assertThat(envInfo.getConsumers().clients()).hasSize(expectedConnectionCount); + } finally { + executor.shutdownNow(); + } + } }