Skip to content

Commit 9211c2d

Browse files
committed
Make MessageChannelPartitionHandler extend AbstractPartitionHandler
1 parent 2fdb68d commit 9211c2d

File tree

3 files changed

+25
-41
lines changed

3 files changed

+25
-41
lines changed

spring-batch-core/src/main/java/org/springframework/batch/core/partition/support/AbstractPartitionHandler.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2006-2021 the original author or authors.
2+
* Copyright 2006-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -35,7 +35,7 @@
3535
*/
3636
public abstract class AbstractPartitionHandler implements PartitionHandler {
3737

38-
private int gridSize = 1;
38+
protected int gridSize = 1;
3939

4040
/**
4141
* Executes the specified {@link StepExecution} instances and returns an updated view

spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java

Lines changed: 21 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2009-2021 the original author or authors.
2+
* Copyright 2009-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -15,14 +15,14 @@
1515
*/
1616
package org.springframework.batch.integration.partition;
1717

18-
import java.util.ArrayList;
19-
import java.util.Collection;
18+
import java.util.HashSet;
2019
import java.util.Iterator;
2120
import java.util.List;
2221
import java.util.Set;
2322
import java.util.concurrent.Callable;
2423
import java.util.concurrent.Future;
2524
import java.util.concurrent.TimeUnit;
25+
import java.util.stream.Collectors;
2626

2727
import javax.sql.DataSource;
2828

@@ -35,6 +35,7 @@
3535
import org.springframework.batch.core.explore.support.JobExplorerFactoryBean;
3636
import org.springframework.batch.core.partition.PartitionHandler;
3737
import org.springframework.batch.core.partition.StepExecutionSplitter;
38+
import org.springframework.batch.core.partition.support.AbstractPartitionHandler;
3839
import org.springframework.batch.core.repository.JobRepository;
3940
import org.springframework.batch.poller.DirectPoller;
4041
import org.springframework.batch.poller.Poller;
@@ -85,12 +86,10 @@
8586
*
8687
*/
8788
@MessageEndpoint
88-
public class MessageChannelPartitionHandler implements PartitionHandler, InitializingBean {
89+
public class MessageChannelPartitionHandler extends AbstractPartitionHandler implements InitializingBean {
8990

9091
private static Log logger = LogFactory.getLog(MessageChannelPartitionHandler.class);
9192

92-
private int gridSize = 1;
93-
9493
private MessagingTemplate messagingGateway;
9594

9695
private String stepName;
@@ -187,18 +186,6 @@ public void setMessagingOperations(MessagingTemplate messagingGateway) {
187186
this.messagingGateway = messagingGateway;
188187
}
189188

190-
/**
191-
* Passed to the {@link StepExecutionSplitter} in the
192-
* {@link #handle(StepExecutionSplitter, StepExecution)} method, instructing it how
193-
* many {@link StepExecution} instances are required, ideally. The
194-
* {@link StepExecutionSplitter} is allowed to ignore the grid size in the case of a
195-
* restart, since the input data partitions must be preserved.
196-
* @param gridSize the number of step executions that will be created
197-
*/
198-
public void setGridSize(int gridSize) {
199-
this.gridSize = gridSize;
200-
}
201-
202189
/**
203190
* The name of the {@link Step} that will be used to execute the partitioned
204191
* {@link StepExecution}. This is a regular Spring Batch step, with all the business
@@ -234,19 +221,17 @@ public void setReplyChannel(PollableChannel replyChannel) {
234221
*
235222
* @see PartitionHandler#handle(StepExecutionSplitter, StepExecution)
236223
*/
237-
public Collection<StepExecution> handle(StepExecutionSplitter stepExecutionSplitter,
238-
final StepExecution managerStepExecution) throws Exception {
239-
240-
final Set<StepExecution> split = stepExecutionSplitter.split(managerStepExecution, gridSize);
224+
@Override
225+
protected Set<StepExecution> doHandle(StepExecution managerStepExecution, Set<StepExecution> partitionStepExecutions) throws Exception {
241226

242-
if (CollectionUtils.isEmpty(split)) {
243-
return split;
227+
if (CollectionUtils.isEmpty(partitionStepExecutions)) {
228+
return partitionStepExecutions;
244229
}
245230

246231
int count = 0;
247232

248-
for (StepExecution stepExecution : split) {
249-
Message<StepExecutionRequest> request = createMessage(count++, split.size(),
233+
for (StepExecution stepExecution : partitionStepExecutions) {
234+
Message<StepExecutionRequest> request = createMessage(count++, partitionStepExecutions.size(),
250235
new StepExecutionRequest(stepName, stepExecution.getJobExecutionId(), stepExecution.getId()),
251236
replyChannel);
252237
if (logger.isDebugEnabled()) {
@@ -259,17 +244,17 @@ public Collection<StepExecution> handle(StepExecutionSplitter stepExecutionSplit
259244
return receiveReplies(replyChannel);
260245
}
261246
else {
262-
return pollReplies(managerStepExecution, split);
247+
return pollReplies(managerStepExecution, partitionStepExecutions);
263248
}
264249
}
265250

266-
private Collection<StepExecution> pollReplies(final StepExecution managerStepExecution,
251+
private Set<StepExecution> pollReplies(final StepExecution managerStepExecution,
267252
final Set<StepExecution> split) throws Exception {
268-
final Collection<StepExecution> result = new ArrayList<>(split.size());
253+
final Set<StepExecution> result = new HashSet<>(split.size());
269254

270-
Callable<Collection<StepExecution>> callback = new Callable<Collection<StepExecution>>() {
255+
Callable<Set<StepExecution>> callback = new Callable<Set<StepExecution>>() {
271256
@Override
272-
public Collection<StepExecution> call() throws Exception {
257+
public Set<StepExecution> call() throws Exception {
273258

274259
for (Iterator<StepExecution> stepExecutionIterator = split.iterator(); stepExecutionIterator
275260
.hasNext();) {
@@ -298,8 +283,8 @@ public Collection<StepExecution> call() throws Exception {
298283
}
299284
};
300285

301-
Poller<Collection<StepExecution>> poller = new DirectPoller<>(pollInterval);
302-
Future<Collection<StepExecution>> resultsFuture = poller.poll(callback);
286+
Poller<Set<StepExecution>> poller = new DirectPoller<>(pollInterval);
287+
Future<Set<StepExecution>> resultsFuture = poller.poll(callback);
303288

304289
if (timeout >= 0) {
305290
return resultsFuture.get(timeout, TimeUnit.MILLISECONDS);
@@ -309,9 +294,8 @@ public Collection<StepExecution> call() throws Exception {
309294
}
310295
}
311296

312-
private Collection<StepExecution> receiveReplies(PollableChannel currentReplyChannel) {
313-
@SuppressWarnings("unchecked")
314-
Message<Collection<StepExecution>> message = (Message<Collection<StepExecution>>) messagingGateway
297+
private Set<StepExecution> receiveReplies(PollableChannel currentReplyChannel) {
298+
Message<Set<StepExecution>> message = (Message<Set<StepExecution>>) messagingGateway
315299
.receive(currentReplyChannel);
316300

317301
if (message == null) {
@@ -321,7 +305,7 @@ else if (logger.isDebugEnabled()) {
321305
logger.debug("Received replies: " + message);
322306
}
323307

324-
return message.getPayload();
308+
return new HashSet<>(message.getPayload());
325309
}
326310

327311
private Message<StepExecutionRequest> createMessage(int sequenceNumber, int sequenceSize,

spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ void testHandleNoReply() throws Exception {
8484
HashSet<StepExecution> stepExecutions = new HashSet<>();
8585
stepExecutions.add(new StepExecution("step1", new JobExecution(5L)));
8686
when(stepExecutionSplitter.split(any(StepExecution.class), eq(1))).thenReturn(stepExecutions);
87-
when(message.getPayload()).thenReturn(Collections.emptyList());
87+
when(message.getPayload()).thenReturn(Collections.emptySet());
8888
when(operations.receive((PollableChannel) any())).thenReturn(message);
8989
// set
9090
messageChannelPartitionHandler.setMessagingOperations(operations);
@@ -112,7 +112,7 @@ void testHandleWithReplyChannel() throws Exception {
112112
HashSet<StepExecution> stepExecutions = new HashSet<>();
113113
stepExecutions.add(new StepExecution("step1", new JobExecution(5L)));
114114
when(stepExecutionSplitter.split(any(StepExecution.class), eq(1))).thenReturn(stepExecutions);
115-
when(message.getPayload()).thenReturn(Collections.emptyList());
115+
when(message.getPayload()).thenReturn(Collections.emptySet());
116116
when(operations.receive(replyChannel)).thenReturn(message);
117117
// set
118118
messageChannelPartitionHandler.setMessagingOperations(operations);

0 commit comments

Comments
 (0)