diff --git a/src/main/java/com/rabbitmq/client/IncrementingCorrelationIdGenerator.java b/src/main/java/com/rabbitmq/client/IncrementingCorrelationIdGenerator.java new file mode 100644 index 0000000000..e9f8012627 --- /dev/null +++ b/src/main/java/com/rabbitmq/client/IncrementingCorrelationIdGenerator.java @@ -0,0 +1,22 @@ +package com.rabbitmq.client; + +import java.util.function.Supplier; + +public class IncrementingCorrelationIdGenerator implements Supplier { + + private final String _prefix; + private int _correlationId; + + public IncrementingCorrelationIdGenerator(String _prefix) { + this._prefix = _prefix; + } + + @Override + public String get() { + return _prefix + _correlationId++; + } + + public int getCorrelationId() { + return _correlationId; + } +} diff --git a/src/main/java/com/rabbitmq/client/RpcClient.java b/src/main/java/com/rabbitmq/client/RpcClient.java index 2dfd50c830..326e5a2f44 100644 --- a/src/main/java/com/rabbitmq/client/RpcClient.java +++ b/src/main/java/com/rabbitmq/client/RpcClient.java @@ -28,6 +28,7 @@ import java.util.Map.Entry; import java.util.concurrent.TimeoutException; import java.util.function.Function; +import java.util.function.Supplier; import com.rabbitmq.client.impl.MethodArgumentReader; import com.rabbitmq.client.impl.MethodArgumentWriter; @@ -84,7 +85,7 @@ public class RpcClient { /** Map from request correlation ID to continuation BlockingCell */ private final Map> _continuationMap = new HashMap>(); /** Contains the most recently-used request correlation ID */ - private int _correlationId; + private final Supplier _correlationIdGenerator; /** Consumer attached to our reply queue */ private DefaultConsumer _consumer; @@ -109,7 +110,7 @@ public RpcClient(RpcClientParams params) throws _timeout = params.getTimeout(); _useMandatory = params.shouldUseMandatory(); _replyHandler = params.getReplyHandler(); - _correlationId = 0; + _correlationIdGenerator = params.getCorrelationIdGenerator(); _consumer = setupConsumer(); if (_useMandatory) { @@ -208,8 +209,7 @@ public Response doCall(AMQP.BasicProperties props, byte[] message, int timeout) BlockingCell k = new BlockingCell(); String replyId; synchronized (_continuationMap) { - _correlationId++; - replyId = "" + _correlationId; + replyId = _correlationIdGenerator.get(); props = ((props==null) ? new AMQP.BasicProperties.Builder() : props.builder()) .correlationId(replyId).replyTo(_replyTo).build(); _continuationMap.put(replyId, k); @@ -392,9 +392,14 @@ public Map> getContinuationMap() { /** * Retrieve the correlation id. * @return the most recently used correlation id + * @deprecated Only works for {@link IncrementingCorrelationIdGenerator} */ public int getCorrelationId() { - return _correlationId; + if (_correlationIdGenerator instanceof IncrementingCorrelationIdGenerator) { + return ((IncrementingCorrelationIdGenerator) _correlationIdGenerator).getCorrelationId(); + } else { + throw new UnsupportedOperationException(); + } } /** diff --git a/src/main/java/com/rabbitmq/client/RpcClientParams.java b/src/main/java/com/rabbitmq/client/RpcClientParams.java index ce046a6cb6..db32896cd5 100644 --- a/src/main/java/com/rabbitmq/client/RpcClientParams.java +++ b/src/main/java/com/rabbitmq/client/RpcClientParams.java @@ -16,6 +16,7 @@ package com.rabbitmq.client; import java.util.function.Function; +import java.util.function.Supplier; /** * Holder class to configure a {@link RpcClient}. @@ -54,6 +55,8 @@ public class RpcClientParams { */ private Function replyHandler = RpcClient.DEFAULT_REPLY_HANDLER; + private Supplier correlationIdGenerator = new IncrementingCorrelationIdGenerator(""); + /** * Set the channel to use for communication. * @@ -170,6 +173,15 @@ public boolean shouldUseMandatory() { return useMandatory; } + public RpcClientParams correlationIdGenerator(Supplier correlationIdGenerator) { + this.correlationIdGenerator = correlationIdGenerator; + return this; + } + + public Supplier getCorrelationIdGenerator() { + return correlationIdGenerator; + } + public Function getReplyHandler() { return replyHandler; } diff --git a/src/test/java/com/rabbitmq/client/test/RpcTest.java b/src/test/java/com/rabbitmq/client/test/RpcTest.java index 66251738d3..8837f726c7 100644 --- a/src/test/java/com/rabbitmq/client/test/RpcTest.java +++ b/src/test/java/com/rabbitmq/client/test/RpcTest.java @@ -24,6 +24,7 @@ import com.rabbitmq.client.impl.recovery.RecordedQueue; import com.rabbitmq.client.impl.recovery.TopologyRecoveryFilter; import com.rabbitmq.tools.Host; +import org.hamcrest.CoreMatchers; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -39,9 +40,7 @@ import java.util.concurrent.atomic.AtomicInteger; import static org.awaitility.Awaitility.waitAtMost; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.junit.Assert.*; public class RpcTest { @@ -138,6 +137,25 @@ public void rpcUnroutableWithMandatoryFlagShouldThrowUnroutableException() throw client.close(); } + @Test + public void rpcCustomCorrelatorId() throws Exception { + rpcServer = new TestRpcServer(serverChannel, queue); + new Thread(() -> { + try { + rpcServer.mainloop(); + } catch (Exception e) { + // safe to ignore when loops ends/server is canceled + } + }).start(); + RpcClient client = new RpcClient(new RpcClientParams() + .channel(clientChannel).exchange("").routingKey(queue).timeout(1000) + .correlationIdGenerator(new IncrementingCorrelationIdGenerator("myPrefix-")) + ); + RpcClient.Response response = client.doCall(null, "hello".getBytes()); + assertThat(response.getProperties().getCorrelationId(), CoreMatchers.equalTo("myPrefix-0")); + client.close(); + } + @Test public void rpcCustomReplyHandler() throws Exception { rpcServer = new TestRpcServer(serverChannel, queue); @@ -156,7 +174,6 @@ public void rpcCustomReplyHandler() throws Exception { return RpcClient.DEFAULT_REPLY_HANDLER.apply(reply); }) ); - assertEquals(0, replyHandlerCalls.get()); RpcClient.Response response = client.doCall(null, "hello".getBytes()); assertEquals(1, replyHandlerCalls.get()); assertEquals("*** hello ***", new String(response.getBody()));