1
1
/*
2
- * Copyright 2009-2021 the original author or authors.
2
+ * Copyright 2009-2022 the original author or authors.
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
15
15
*/
16
16
package org .springframework .batch .integration .partition ;
17
17
18
- import java .util .ArrayList ;
19
- import java .util .Collection ;
18
+ import java .util .HashSet ;
20
19
import java .util .Iterator ;
21
20
import java .util .List ;
22
21
import java .util .Set ;
23
22
import java .util .concurrent .Callable ;
24
23
import java .util .concurrent .Future ;
25
24
import java .util .concurrent .TimeUnit ;
25
+ import java .util .stream .Collectors ;
26
26
27
27
import javax .sql .DataSource ;
28
28
35
35
import org .springframework .batch .core .explore .support .JobExplorerFactoryBean ;
36
36
import org .springframework .batch .core .partition .PartitionHandler ;
37
37
import org .springframework .batch .core .partition .StepExecutionSplitter ;
38
+ import org .springframework .batch .core .partition .support .AbstractPartitionHandler ;
38
39
import org .springframework .batch .core .repository .JobRepository ;
39
40
import org .springframework .batch .poller .DirectPoller ;
40
41
import org .springframework .batch .poller .Poller ;
85
86
*
86
87
*/
87
88
@ MessageEndpoint
88
- public class MessageChannelPartitionHandler implements PartitionHandler , InitializingBean {
89
+ public class MessageChannelPartitionHandler extends AbstractPartitionHandler implements InitializingBean {
89
90
90
91
private static Log logger = LogFactory .getLog (MessageChannelPartitionHandler .class );
91
92
92
- private int gridSize = 1 ;
93
-
94
93
private MessagingTemplate messagingGateway ;
95
94
96
95
private String stepName ;
@@ -187,18 +186,6 @@ public void setMessagingOperations(MessagingTemplate messagingGateway) {
187
186
this .messagingGateway = messagingGateway ;
188
187
}
189
188
190
- /**
191
- * Passed to the {@link StepExecutionSplitter} in the
192
- * {@link #handle(StepExecutionSplitter, StepExecution)} method, instructing it how
193
- * many {@link StepExecution} instances are required, ideally. The
194
- * {@link StepExecutionSplitter} is allowed to ignore the grid size in the case of a
195
- * restart, since the input data partitions must be preserved.
196
- * @param gridSize the number of step executions that will be created
197
- */
198
- public void setGridSize (int gridSize ) {
199
- this .gridSize = gridSize ;
200
- }
201
-
202
189
/**
203
190
* The name of the {@link Step} that will be used to execute the partitioned
204
191
* {@link StepExecution}. This is a regular Spring Batch step, with all the business
@@ -234,19 +221,17 @@ public void setReplyChannel(PollableChannel replyChannel) {
234
221
*
235
222
* @see PartitionHandler#handle(StepExecutionSplitter, StepExecution)
236
223
*/
237
- public Collection <StepExecution > handle (StepExecutionSplitter stepExecutionSplitter ,
238
- final StepExecution managerStepExecution ) throws Exception {
239
-
240
- final Set <StepExecution > split = stepExecutionSplitter .split (managerStepExecution , gridSize );
224
+ @ Override
225
+ protected Set <StepExecution > doHandle (StepExecution managerStepExecution , Set <StepExecution > partitionStepExecutions ) throws Exception {
241
226
242
- if (CollectionUtils .isEmpty (split )) {
243
- return split ;
227
+ if (CollectionUtils .isEmpty (partitionStepExecutions )) {
228
+ return partitionStepExecutions ;
244
229
}
245
230
246
231
int count = 0 ;
247
232
248
- for (StepExecution stepExecution : split ) {
249
- Message <StepExecutionRequest > request = createMessage (count ++, split .size (),
233
+ for (StepExecution stepExecution : partitionStepExecutions ) {
234
+ Message <StepExecutionRequest > request = createMessage (count ++, partitionStepExecutions .size (),
250
235
new StepExecutionRequest (stepName , stepExecution .getJobExecutionId (), stepExecution .getId ()),
251
236
replyChannel );
252
237
if (logger .isDebugEnabled ()) {
@@ -259,17 +244,17 @@ public Collection<StepExecution> handle(StepExecutionSplitter stepExecutionSplit
259
244
return receiveReplies (replyChannel );
260
245
}
261
246
else {
262
- return pollReplies (managerStepExecution , split );
247
+ return pollReplies (managerStepExecution , partitionStepExecutions );
263
248
}
264
249
}
265
250
266
- private Collection <StepExecution > pollReplies (final StepExecution managerStepExecution ,
251
+ private Set <StepExecution > pollReplies (final StepExecution managerStepExecution ,
267
252
final Set <StepExecution > split ) throws Exception {
268
- final Collection <StepExecution > result = new ArrayList <>(split .size ());
253
+ final Set <StepExecution > result = new HashSet <>(split .size ());
269
254
270
- Callable <Collection <StepExecution >> callback = new Callable <Collection <StepExecution >>() {
255
+ Callable <Set <StepExecution >> callback = new Callable <Set <StepExecution >>() {
271
256
@ Override
272
- public Collection <StepExecution > call () throws Exception {
257
+ public Set <StepExecution > call () throws Exception {
273
258
274
259
for (Iterator <StepExecution > stepExecutionIterator = split .iterator (); stepExecutionIterator
275
260
.hasNext ();) {
@@ -298,8 +283,8 @@ public Collection<StepExecution> call() throws Exception {
298
283
}
299
284
};
300
285
301
- Poller <Collection <StepExecution >> poller = new DirectPoller <>(pollInterval );
302
- Future <Collection <StepExecution >> resultsFuture = poller .poll (callback );
286
+ Poller <Set <StepExecution >> poller = new DirectPoller <>(pollInterval );
287
+ Future <Set <StepExecution >> resultsFuture = poller .poll (callback );
303
288
304
289
if (timeout >= 0 ) {
305
290
return resultsFuture .get (timeout , TimeUnit .MILLISECONDS );
@@ -309,9 +294,8 @@ public Collection<StepExecution> call() throws Exception {
309
294
}
310
295
}
311
296
312
- private Collection <StepExecution > receiveReplies (PollableChannel currentReplyChannel ) {
313
- @ SuppressWarnings ("unchecked" )
314
- Message <Collection <StepExecution >> message = (Message <Collection <StepExecution >>) messagingGateway
297
+ private Set <StepExecution > receiveReplies (PollableChannel currentReplyChannel ) {
298
+ Message <Set <StepExecution >> message = (Message <Set <StepExecution >>) messagingGateway
315
299
.receive (currentReplyChannel );
316
300
317
301
if (message == null ) {
@@ -321,7 +305,7 @@ else if (logger.isDebugEnabled()) {
321
305
logger .debug ("Received replies: " + message );
322
306
}
323
307
324
- return message .getPayload ();
308
+ return new HashSet <>( message .getPayload () );
325
309
}
326
310
327
311
private Message <StepExecutionRequest > createMessage (int sequenceNumber , int sequenceSize ,
0 commit comments