22
22
import java .util .concurrent .BlockingQueue ;
23
23
import java .util .concurrent .LinkedBlockingQueue ;
24
24
import java .util .concurrent .TimeUnit ;
25
- import java .util .function .Predicate ;
25
+ import java .util .function .BiPredicate ;
26
26
import org .slf4j .Logger ;
27
27
import org .slf4j .LoggerFactory ;
28
28
@@ -33,11 +33,11 @@ final class DynamicBatch<T> implements AutoCloseable {
33
33
private static final int MAX_BATCH_SIZE = 8192 ;
34
34
35
35
private final BlockingQueue <T > requests = new LinkedBlockingQueue <>();
36
- private final Predicate <List <T >> consumer ;
36
+ private final BiPredicate <List <T >, Boolean > consumer ;
37
37
private final int configuredBatchSize ;
38
38
private final Thread thread ;
39
39
40
- DynamicBatch (Predicate <List <T >> consumer , int batchSize ) {
40
+ DynamicBatch (BiPredicate <List <T >, Boolean > consumer , int batchSize ) {
41
41
this .consumer = consumer ;
42
42
this .configuredBatchSize = min (max (batchSize , MIN_BATCH_SIZE ), MAX_BATCH_SIZE );
43
43
this .thread = ConcurrencyUtils .defaultThreadFactory ().newThread (this ::loop );
@@ -53,8 +53,10 @@ void add(T item) {
53
53
}
54
54
55
55
private void loop () {
56
- int batchSize = this .configuredBatchSize ;
57
- List <T > batch = new ArrayList <>(batchSize );
56
+ State <T > state = new State <>();
57
+ state .batchSize = this .configuredBatchSize ;
58
+ state .items = new ArrayList <>(state .batchSize );
59
+ state .retry = false ;
58
60
Thread currentThread = Thread .currentThread ();
59
61
T item ;
60
62
while (!currentThread .isInterrupted ()) {
@@ -65,44 +67,50 @@ private void loop() {
65
67
return ;
66
68
}
67
69
if (item != null ) {
68
- batch .add (item );
69
- if (batch .size () >= batchSize ) {
70
- if (this .completeBatch (batch )) {
71
- batchSize = min (batchSize * 2 , MAX_BATCH_SIZE );
72
- batch = new ArrayList <>(batchSize );
73
- }
70
+ state .items .add (item );
71
+ if (state .items .size () >= state .batchSize ) {
72
+ this .maybeCompleteBatch (state , true );
74
73
} else {
75
74
item = this .requests .poll ();
76
75
if (item == null ) {
77
- if (this .completeBatch (batch )) {
78
- batchSize = max (batchSize / 2 , MIN_BATCH_SIZE );
79
- batch = new ArrayList <>(batchSize );
80
- }
76
+ this .maybeCompleteBatch (state , false );
81
77
} else {
82
- batch .add (item );
83
- if (batch .size () >= batchSize ) {
84
- if (this .completeBatch (batch )) {
85
- batchSize = min (batchSize * 2 , MAX_BATCH_SIZE );
86
- batch = new ArrayList <>(batchSize );
87
- }
78
+ state .items .add (item );
79
+ if (state .items .size () >= state .batchSize ) {
80
+ this .maybeCompleteBatch (state , true );
88
81
}
89
82
}
90
83
}
91
84
} else {
92
- if (this .completeBatch (batch )) {
93
- batchSize = min (batchSize * 2 , MAX_BATCH_SIZE );
94
- batch = new ArrayList <>(batchSize );
95
- }
85
+ this .maybeCompleteBatch (state , false );
96
86
}
97
87
}
98
88
}
99
89
100
- private boolean completeBatch (List <T > items ) {
90
+ private static final class State <T > {
91
+
92
+ int batchSize ;
93
+ List <T > items ;
94
+ boolean retry ;
95
+ }
96
+
97
+ private void maybeCompleteBatch (State <T > state , boolean increaseIfCompleted ) {
101
98
try {
102
- return this .consumer .test (items );
99
+ boolean completed = this .consumer .test (state .items , state .retry );
100
+ if (completed ) {
101
+ if (increaseIfCompleted ) {
102
+ state .batchSize = min (state .batchSize * 2 , MAX_BATCH_SIZE );
103
+ } else {
104
+ state .batchSize = max (state .batchSize / 2 , MIN_BATCH_SIZE );
105
+ }
106
+ state .items = new ArrayList <>(state .batchSize );
107
+ state .retry = false ;
108
+ } else {
109
+ state .retry = true ;
110
+ }
103
111
} catch (Exception e ) {
104
112
LOGGER .warn ("Error during dynamic batch completion: {}" , e .getMessage ());
105
- return false ;
113
+ state . retry = true ;
106
114
}
107
115
}
108
116
0 commit comments