Skip to content

Commit 924c271

Browse files
committed
Add unit tests for configuring a custom AsyncTaskExecutor for ClusterCommandExceutor.
Closes #2594
1 parent 35f8eb9 commit 924c271

File tree

4 files changed

+156
-66
lines changed

4 files changed

+156
-66
lines changed

src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -233,36 +233,42 @@ private <T> MultiNodeResult<T> collectResults(Map<NodeExecution, Future<NodeResu
233233
while (!done) {
234234

235235
done = true;
236+
236237
for (Map.Entry<NodeExecution, Future<NodeResult<T>>> entry : futures.entrySet()) {
237238

238239
if (!entry.getValue().isDone() && !entry.getValue().isCancelled()) {
239240
done = false;
240241
} else {
241242

242243
NodeExecution execution = entry.getKey();
244+
243245
try {
244246

245247
String futureId = ObjectUtils.getIdentityHexString(entry.getValue());
248+
246249
if (!saveGuard.contains(futureId)) {
247250

248251
if (execution.isPositional()) {
249252
result.add(execution.getPositionalKey(), entry.getValue().get());
250253
} else {
251254
result.add(entry.getValue().get());
252255
}
256+
253257
saveGuard.add(futureId);
254258
}
255-
} catch (ExecutionException e) {
259+
} catch (ExecutionException cause) {
256260

257-
RuntimeException ex = convertToDataAccessException((Exception) e.getCause());
261+
RuntimeException exception = convertToDataAccessException((Exception) cause.getCause());
258262

259-
exceptions.put(execution.getNode(), ex != null ? ex : e.getCause());
260-
} catch (InterruptedException e) {
263+
exceptions.put(execution.getNode(), exception != null ? exception : cause.getCause());
264+
} catch (InterruptedException cause) {
261265

262266
Thread.currentThread().interrupt();
263267

264-
RuntimeException ex = convertToDataAccessException((Exception) e.getCause());
265-
exceptions.put(execution.getNode(), ex != null ? ex : e.getCause());
268+
RuntimeException exception = convertToDataAccessException((Exception) cause.getCause());
269+
270+
exceptions.put(execution.getNode(), exception != null ? exception : cause.getCause());
271+
266272
break;
267273
}
268274
}
@@ -271,7 +277,6 @@ private <T> MultiNodeResult<T> collectResults(Map<NodeExecution, Future<NodeResu
271277
try {
272278
Thread.sleep(10);
273279
} catch (InterruptedException e) {
274-
275280
done = true;
276281
Thread.currentThread().interrupt();
277282
}
@@ -280,18 +285,19 @@ private <T> MultiNodeResult<T> collectResults(Map<NodeExecution, Future<NodeResu
280285
if (!exceptions.isEmpty()) {
281286
throw new ClusterCommandExecutionFailureException(new ArrayList<>(exceptions.values()));
282287
}
288+
283289
return result;
284290
}
285291

286292
/**
287293
* Run {@link MultiKeyClusterCommandCallback} with on a curated set of nodes serving one or more keys.
288294
*
289-
* @param cmd must not be {@literal null}.
295+
* @param commandCallback must not be {@literal null}.
290296
* @return never {@literal null}.
291297
* @throws ClusterCommandExecutionFailureException if a failure occurs while executing the given
292298
* {@link MultiKeyClusterCommandCallback command}.
293299
*/
294-
public <S, T> MultiNodeResult<T> executeMultiKeyCommand(MultiKeyClusterCommandCallback<S, T> cmd,
300+
public <S, T> MultiNodeResult<T> executeMultiKeyCommand(MultiKeyClusterCommandCallback<S, T> commandCallback,
295301
Iterable<byte[]> keys) {
296302

297303
Map<RedisClusterNode, PositionalKeys> nodeKeyMap = new HashMap<>();
@@ -309,19 +315,19 @@ public <S, T> MultiNodeResult<T> executeMultiKeyCommand(MultiKeyClusterCommandCa
309315

310316
if (entry.getKey().isMaster()) {
311317
for (PositionalKey key : entry.getValue()) {
312-
futures.put(new NodeExecution(entry.getKey(), key),
313-
executor.submit(() -> executeMultiKeyCommandOnSingleNode(cmd, entry.getKey(), key.getBytes())));
318+
futures.put(new NodeExecution(entry.getKey(), key), this.executor.submit(() ->
319+
executeMultiKeyCommandOnSingleNode(commandCallback, entry.getKey(), key.getBytes())));
314320
}
315321
}
316322
}
317323

318324
return collectResults(futures);
319325
}
320326

321-
private <S, T> NodeResult<T> executeMultiKeyCommandOnSingleNode(MultiKeyClusterCommandCallback<S, T> cmd,
327+
private <S, T> NodeResult<T> executeMultiKeyCommandOnSingleNode(MultiKeyClusterCommandCallback<S, T> commandCallback,
322328
RedisClusterNode node, byte[] key) {
323329

324-
Assert.notNull(cmd, "MultiKeyCommandCallback must not be null");
330+
Assert.notNull(commandCallback, "MultiKeyCommandCallback must not be null");
325331
Assert.notNull(node, "RedisClusterNode must not be null");
326332
Assert.notNull(key, "Keys for execution must not be null");
327333

@@ -330,7 +336,7 @@ private <S, T> NodeResult<T> executeMultiKeyCommandOnSingleNode(MultiKeyClusterC
330336
Assert.notNull(client, "Could not acquire resource for node; Is your cluster info up to date");
331337

332338
try {
333-
return new NodeResult<>(node, cmd.doInCluster(client, key), key);
339+
return new NodeResult<>(node, commandCallback.doInCluster(client, key), key);
334340
} catch (RuntimeException ex) {
335341

336342
RuntimeException translatedException = convertToDataAccessException(ex);
@@ -345,8 +351,8 @@ private ClusterTopology getClusterTopology() {
345351
}
346352

347353
@Nullable
348-
private DataAccessException convertToDataAccessException(Exception e) {
349-
return exceptionTranslationStrategy.translate(e);
354+
private DataAccessException convertToDataAccessException(Exception cause) {
355+
return exceptionTranslationStrategy.translate(cause);
350356
}
351357

352358
/**
@@ -361,12 +367,12 @@ public void setMaxRedirects(int maxRedirects) {
361367
@Override
362368
public void destroy() throws Exception {
363369

364-
if (executor instanceof DisposableBean) {
365-
((DisposableBean) executor).destroy();
370+
if (this.executor instanceof DisposableBean disposableBean) {
371+
disposableBean.destroy();
366372
}
367373

368-
if (resourceProvider instanceof DisposableBean) {
369-
((DisposableBean) resourceProvider).destroy();
374+
if (this.resourceProvider instanceof DisposableBean disposableBean) {
375+
disposableBean.destroy();
370376
}
371377
}
372378

src/main/java/org/springframework/data/redis/connection/lettuce/LettuceConnectionFactory.java

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@
6262
import org.springframework.data.redis.connection.RedisClusterConnection;
6363
import org.springframework.data.redis.connection.RedisConfiguration;
6464
import org.springframework.data.redis.connection.RedisConfiguration.ClusterConfiguration;
65-
import org.springframework.data.redis.connection.RedisConfiguration.DomainSocketConfiguration;
6665
import org.springframework.data.redis.connection.RedisConfiguration.WithDatabaseIndex;
6766
import org.springframework.data.redis.connection.RedisConfiguration.WithPassword;
6867
import org.springframework.data.redis.connection.RedisConnection;
@@ -354,10 +353,15 @@ public LettuceConnectionFactory(RedisStandaloneConfiguration standaloneConfigura
354353
this.configuration = this.standaloneConfig;
355354
}
356355

356+
@Nullable
357+
protected ClusterCommandExecutor getClusterCommandExecutor() {
358+
return this.clusterCommandExecutor;
359+
}
360+
357361
@Override
358362
public void start() {
359363

360-
State current = state.getAndUpdate(state -> isCreatedOrStopped(state) ? State.STARTING : state);
364+
State current = this.state.getAndUpdate(state -> isCreatedOrStopped(state) ? State.STARTING : state);
361365

362366
if (isCreatedOrStopped(current)) {
363367

@@ -370,7 +374,7 @@ public void start() {
370374
this.clusterCommandExecutor = newClusterCommandExecutor();
371375
}
372376

373-
state.set(State.STARTED);
377+
this.state.set(State.STARTED);
374378

375379
if (getEagerInitialization() && getShareNativeConnection()) {
376380
initConnection();
@@ -459,7 +463,7 @@ public void setPhase(int phase) {
459463

460464
@Override
461465
public boolean isRunning() {
462-
return State.STARTED.equals(state.get());
466+
return State.STARTED.equals(this.state.get());
463467
}
464468

465469
@Override
@@ -473,17 +477,20 @@ public void afterPropertiesSet() {
473477
public void destroy() {
474478

475479
stop();
476-
client = null;
480+
this.client = null;
481+
482+
ClusterCommandExecutor clusterCommandExecutor = getClusterCommandExecutor();
477483

478484
if (clusterCommandExecutor != null) {
479485
try {
480486
clusterCommandExecutor.destroy();
481-
} catch (Exception ex) {
482-
log.warn("Cannot properly close cluster command executor", ex);
487+
this.clusterCommandExecutor = null;
488+
} catch (Exception cause) {
489+
log.warn("Cannot properly close cluster command executor", cause);
483490
}
484491
}
485492

486-
state.set(State.DESTROYED);
493+
this.state.set(State.DESTROYED);
487494
}
488495

489496
private void dispose(@Nullable LettuceConnectionProvider connectionProvider) {
@@ -511,7 +518,7 @@ public RedisConnection getConnection() {
511518
LettuceConnection connection = doCreateLettuceConnection(getSharedConnection(), connectionProvider,
512519
getTimeout(), getDatabase());
513520

514-
connection.setConvertPipelineAndTxResults(convertPipelineAndTxResults);
521+
connection.setConvertPipelineAndTxResults(this.convertPipelineAndTxResults);
515522

516523
return connection;
517524
}
@@ -531,8 +538,8 @@ public RedisClusterConnection getClusterConnection() {
531538

532539
LettuceClusterTopologyProvider topologyProvider = new LettuceClusterTopologyProvider(clusterClient);
533540

534-
return doCreateLettuceClusterConnection(sharedConnection, connectionProvider, topologyProvider,
535-
clusterCommandExecutor, clientConfiguration.getCommandTimeout());
541+
return doCreateLettuceClusterConnection(sharedConnection, this.connectionProvider, topologyProvider,
542+
getClusterCommandExecutor(), this.clientConfiguration.getCommandTimeout());
536543
}
537544

538545
/**
@@ -858,7 +865,7 @@ public void setValidateConnection(boolean validateConnection) {
858865
* @return native connection shared.
859866
*/
860867
public boolean getShareNativeConnection() {
861-
return shareNativeConnection;
868+
return this.shareNativeConnection;
862869
}
863870

864871
/**
@@ -881,7 +888,7 @@ public void setShareNativeConnection(boolean shareNativeConnection) {
881888
* @since 2.2
882889
*/
883890
public boolean getEagerInitialization() {
884-
return eagerInitialization;
891+
return this.eagerInitialization;
885892
}
886893

887894
/**
@@ -1203,7 +1210,7 @@ protected StatefulConnection<ByteBuffer, ByteBuffer> getSharedReactiveConnection
12031210
return shareNativeConnection ? getOrCreateSharedReactiveConnection().getConnection() : null;
12041211
}
12051212

1206-
private LettuceConnectionProvider createConnectionProvider(AbstractRedisClient client, RedisCodec<?, ?> codec) {
1213+
LettuceConnectionProvider createConnectionProvider(AbstractRedisClient client, RedisCodec<?, ?> codec) {
12071214

12081215
LettuceConnectionProvider connectionProvider = doCreateConnectionProvider(client, codec);
12091216

src/test/java/org/springframework/data/redis/connection/jedis/JedisConnectionFactoryUnitTests.java

Lines changed: 69 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,16 @@
1515
*/
1616
package org.springframework.data.redis.connection.jedis;
1717

18-
import static org.assertj.core.api.Assertions.*;
19-
import static org.mockito.Mockito.*;
20-
21-
import redis.clients.jedis.JedisClientConfig;
22-
import redis.clients.jedis.JedisCluster;
23-
import redis.clients.jedis.JedisPoolConfig;
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
20+
import static org.mockito.Mockito.any;
21+
import static org.mockito.Mockito.doReturn;
22+
import static org.mockito.Mockito.eq;
23+
import static org.mockito.Mockito.mock;
24+
import static org.mockito.Mockito.never;
25+
import static org.mockito.Mockito.spy;
26+
import static org.mockito.Mockito.times;
27+
import static org.mockito.Mockito.verify;
2428

2529
import java.io.IOException;
2630
import java.security.NoSuchAlgorithmException;
@@ -33,15 +37,25 @@
3337
import javax.net.ssl.SSLParameters;
3438
import javax.net.ssl.SSLSocketFactory;
3539

36-
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
3740
import org.junit.jupiter.api.Test;
41+
42+
import org.springframework.core.task.AsyncTaskExecutor;
43+
import org.springframework.data.redis.connection.ClusterCommandExecutor;
44+
import org.springframework.data.redis.connection.ClusterTopologyProvider;
3845
import org.springframework.data.redis.connection.RedisClusterConfiguration;
3946
import org.springframework.data.redis.connection.RedisPassword;
4047
import org.springframework.data.redis.connection.RedisSentinelConfiguration;
4148
import org.springframework.data.redis.connection.RedisStandaloneConfiguration;
4249
import org.springframework.data.redis.connection.jedis.JedisConnectionFactory.State;
50+
import org.springframework.lang.Nullable;
4351
import org.springframework.test.util.ReflectionTestUtils;
4452

53+
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
54+
55+
import redis.clients.jedis.JedisClientConfig;
56+
import redis.clients.jedis.JedisCluster;
57+
import redis.clients.jedis.JedisPoolConfig;
58+
4559
/**
4660
* Unit tests for {@link JedisConnectionFactory}.
4761
*
@@ -335,24 +349,59 @@ void afterPropertiesTriggersConnectionInitialization() {
335349
assertThat(connectionFactory.isRunning()).isTrue();
336350
}
337351

338-
private JedisConnectionFactory initSpyedConnectionFactory(RedisSentinelConfiguration sentinelConfig,
339-
JedisPoolConfig poolConfig) {
352+
@Test // GH-2594
353+
void configuresCustomTaskExecutorCorrectly() {
354+
355+
AsyncTaskExecutor mockTaskExecutor = mock(AsyncTaskExecutor.class);
356+
ClusterTopologyProvider mockClusterTopologyProvider = mock(ClusterTopologyProvider.class);
357+
JedisCluster mockJedisCluster = mock(JedisCluster.class);
358+
359+
RedisClusterConfiguration clusterConfiguration = new RedisClusterConfiguration();
360+
361+
clusterConfiguration.setAsyncTaskExecutor(mockTaskExecutor);
362+
363+
JedisConnectionFactory connectionFactory = initSpyedConnectionFactory(clusterConfiguration, null);
364+
365+
doReturn(false).when(connectionFactory).getUsePool();
366+
doReturn(mockJedisCluster).when(connectionFactory).createCluster();
367+
doReturn(mockClusterTopologyProvider).when(connectionFactory).createTopologyProvider(eq(mockJedisCluster));
368+
369+
connectionFactory.start();
370+
371+
assertThat(connectionFactory.isRunning()).isTrue();
372+
373+
ClusterCommandExecutor clusterCommandExecutor = connectionFactory.getClusterCommandExecutor();
374+
375+
assertThat(clusterCommandExecutor).isNotNull();
376+
assertThat(ReflectionTestUtils.getField(clusterCommandExecutor, "executor")).isEqualTo(mockTaskExecutor);
377+
}
378+
379+
private JedisConnectionFactory initSpyedConnectionFactory(RedisSentinelConfiguration sentinelConfiguration,
380+
@Nullable JedisPoolConfig poolConfig) {
340381

341382
// we have to use a spy here as jedis would start connecting to redis sentinels when the pool is created.
342-
JedisConnectionFactory factorySpy = spy(new JedisConnectionFactory(sentinelConfig, poolConfig));
343-
doReturn(null).when(factorySpy).createRedisSentinelPool(any(RedisSentinelConfiguration.class));
344-
doReturn(null).when(factorySpy).createRedisPool();
345-
return factorySpy;
383+
JedisConnectionFactory connectionFactorySpy = spy(new JedisConnectionFactory(sentinelConfiguration, poolConfig));
384+
385+
doReturn(null).when(connectionFactorySpy)
386+
.createRedisSentinelPool(any(RedisSentinelConfiguration.class));
387+
388+
doReturn(null).when(connectionFactorySpy).createRedisPool();
389+
390+
return connectionFactorySpy;
346391
}
347392

348-
private JedisConnectionFactory initSpyedConnectionFactory(RedisClusterConfiguration clusterConfig,
349-
JedisPoolConfig poolConfig) {
393+
private JedisConnectionFactory initSpyedConnectionFactory(RedisClusterConfiguration clusterConfiguration,
394+
@Nullable JedisPoolConfig poolConfig) {
350395

351396
JedisCluster clusterMock = mock(JedisCluster.class);
352-
JedisConnectionFactory factorySpy = spy(new JedisConnectionFactory(clusterConfig));
353-
doReturn(clusterMock).when(factorySpy).createCluster(any(RedisClusterConfiguration.class),
354-
any(GenericObjectPoolConfig.class));
355-
doReturn(null).when(factorySpy).createRedisPool();
356-
return factorySpy;
397+
398+
JedisConnectionFactory connectionFactorySpy = spy(new JedisConnectionFactory(clusterConfiguration, poolConfig));
399+
400+
doReturn(clusterMock).when(connectionFactorySpy)
401+
.createCluster(any(RedisClusterConfiguration.class), any(GenericObjectPoolConfig.class));
402+
403+
doReturn(null).when(connectionFactorySpy).createRedisPool();
404+
405+
return connectionFactorySpy;
357406
}
358407
}

0 commit comments

Comments
 (0)