Skip to content

Commit ea29642

Browse files
serezakorotaevschauder
authored andcommitted
Add Stream support to JdbcAggregateOperations
See #1714 Original pull request #1963 Signed-off-by: Sergey Korotaev <[email protected]>
1 parent 4ef0538 commit ea29642

11 files changed

+475
-2
lines changed

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateOperations.java

+44
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.util.List;
1919
import java.util.Optional;
20+
import java.util.stream.Stream;
2021

2122
import org.springframework.dao.IncorrectUpdateSemanticsDataAccessException;
2223
import org.springframework.data.domain.Example;
@@ -35,6 +36,7 @@
3536
* @author Chirag Tailor
3637
* @author Diego Krupitza
3738
* @author Myeonghyeon Lee
39+
* @author Sergey Korotaev
3840
*/
3941
public interface JdbcAggregateOperations {
4042

@@ -165,6 +167,17 @@ public interface JdbcAggregateOperations {
165167
*/
166168
<T> List<T> findAllById(Iterable<?> ids, Class<T> domainType);
167169

170+
/**
171+
* Loads all entities that match one of the ids passed as an argument to a {@link Stream}.
172+
* It is not guaranteed that the number of ids passed in matches the number of entities returned.
173+
*
174+
* @param ids the Ids of the entities to load. Must not be {@code null}.
175+
* @param domainType the type of entities to load. Must not be {@code null}.
176+
* @param <T> type of entities to load.
177+
* @return the loaded entities. Guaranteed to be not {@code null}.
178+
*/
179+
<T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType);
180+
168181
/**
169182
* Load all aggregates of a given type.
170183
*
@@ -174,6 +187,15 @@ public interface JdbcAggregateOperations {
174187
*/
175188
<T> List<T> findAll(Class<T> domainType);
176189

190+
/**
191+
* Load all aggregates of a given type to a {@link Stream}.
192+
*
193+
* @param domainType the type of the aggregate roots. Must not be {@code null}.
194+
* @param <T> the type of the aggregate roots. Must not be {@code null}.
195+
* @return Guaranteed to be not {@code null}.
196+
*/
197+
<T> Stream<T> streamAll(Class<T> domainType);
198+
177199
/**
178200
* Load all aggregates of a given type, sorted.
179201
*
@@ -185,6 +207,17 @@ public interface JdbcAggregateOperations {
185207
*/
186208
<T> List<T> findAll(Class<T> domainType, Sort sort);
187209

210+
/**
211+
* Loads all entities of the given type to a {@link Stream}, sorted.
212+
*
213+
* @param domainType the type of entities to load. Must not be {@code null}.
214+
* @param <T> the type of entities to load.
215+
* @param sort the sorting information. Must not be {@code null}.
216+
* @return Guaranteed to be not {@code null}.
217+
* @since 2.0
218+
*/
219+
<T> Stream<T> streamAll(Class<T> domainType, Sort sort);
220+
188221
/**
189222
* Load a page of (potentially sorted) aggregates of a given type.
190223
*
@@ -218,6 +251,17 @@ public interface JdbcAggregateOperations {
218251
*/
219252
<T> List<T> findAll(Query query, Class<T> domainType);
220253

254+
/**
255+
* Execute a {@code SELECT} query and convert the resulting items to a {@link Stream}.
256+
*
257+
* @param query must not be {@literal null}.
258+
* @param domainType the type of entities. Must not be {@code null}.
259+
* @return a non-null list with all the matching results.
260+
* @throws org.springframework.dao.IncorrectResultSizeDataAccessException if more than one match found.
261+
* @since 3.0
262+
*/
263+
<T> Stream<T> streamAll(Query query, Class<T> domainType);
264+
221265
/**
222266
* Returns a {@link Page} of entities matching the given {@link Query}. In case no match could be found, an empty
223267
* {@link Page} is returned.

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java

+34
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.Optional;
2626
import java.util.function.Function;
2727
import java.util.stream.Collectors;
28+
import java.util.stream.Stream;
2829
import java.util.stream.StreamSupport;
2930

3031
import org.springframework.context.ApplicationContext;
@@ -68,6 +69,7 @@
6869
* @author Myeonghyeon Lee
6970
* @author Chirag Tailor
7071
* @author Diego Krupitza
72+
* @author Sergey Korotaev
7173
*/
7274
public class JdbcAggregateTemplate implements JdbcAggregateOperations {
7375

@@ -283,6 +285,16 @@ public <T> List<T> findAll(Class<T> domainType, Sort sort) {
283285
return triggerAfterConvert(all);
284286
}
285287

288+
@Override
289+
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
290+
291+
Assert.notNull(domainType, "Domain type must not be null");
292+
293+
Stream<T> allStreamable = accessStrategy.streamAll(domainType, sort);
294+
295+
return allStreamable.map(this::triggerAfterConvert);
296+
}
297+
286298
@Override
287299
public <T> Page<T> findAll(Class<T> domainType, Pageable pageable) {
288300

@@ -307,6 +319,11 @@ public <T> List<T> findAll(Query query, Class<T> domainType) {
307319
return triggerAfterConvert(all);
308320
}
309321

322+
@Override
323+
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
324+
return accessStrategy.streamAll(query, domainType).map(this::triggerAfterConvert);
325+
}
326+
310327
@Override
311328
public <T> Page<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
312329

@@ -325,6 +342,12 @@ public <T> List<T> findAll(Class<T> domainType) {
325342
return triggerAfterConvert(all);
326343
}
327344

345+
@Override
346+
public <T> Stream<T> streamAll(Class<T> domainType) {
347+
Iterable<T> items = triggerAfterConvert(accessStrategy.findAll(domainType));
348+
return StreamSupport.stream(items.spliterator(), false).map(this::triggerAfterConvert);
349+
}
350+
328351
@Override
329352
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
330353

@@ -335,6 +358,17 @@ public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
335358
return triggerAfterConvert(allById);
336359
}
337360

361+
@Override
362+
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
363+
364+
Assert.notNull(ids, "Ids must not be null");
365+
Assert.notNull(domainType, "Domain type must not be null");
366+
367+
Stream<T> allByIdStreamable = accessStrategy.streamAllByIds(ids, domainType);
368+
369+
return allByIdStreamable.map(this::triggerAfterConvert);
370+
}
371+
338372
@Override
339373
public <S> void delete(S aggregateRoot) {
340374

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/CascadingDataAccessStrategy.java

+22
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.Optional;
2323
import java.util.function.Consumer;
2424
import java.util.function.Function;
25+
import java.util.stream.Stream;
2526

2627
import org.springframework.data.domain.Pageable;
2728
import org.springframework.data.domain.Sort;
@@ -42,6 +43,7 @@
4243
* @author Myeonghyeon Lee
4344
* @author Chirag Tailor
4445
* @author Diego Krupitza
46+
* @author Sergey Korotaev
4547
* @since 1.1
4648
*/
4749
public class CascadingDataAccessStrategy implements DataAccessStrategy {
@@ -132,11 +134,21 @@ public <T> Iterable<T> findAll(Class<T> domainType) {
132134
return collect(das -> das.findAll(domainType));
133135
}
134136

137+
@Override
138+
public <T> Stream<T> streamAll(Class<T> domainType) {
139+
return collect(das -> das.streamAll(domainType));
140+
}
141+
135142
@Override
136143
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
137144
return collect(das -> das.findAllById(ids, domainType));
138145
}
139146

147+
@Override
148+
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
149+
return collect(das -> das.streamAllByIds(ids, domainType));
150+
}
151+
140152
@Override
141153
public Iterable<Object> findAllByPath(Identifier identifier,
142154
PersistentPropertyPath<? extends RelationalPersistentProperty> path) {
@@ -153,6 +165,11 @@ public <T> Iterable<T> findAll(Class<T> domainType, Sort sort) {
153165
return collect(das -> das.findAll(domainType, sort));
154166
}
155167

168+
@Override
169+
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
170+
return collect(das -> das.streamAll(domainType, sort));
171+
}
172+
156173
@Override
157174
public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
158175
return collect(das -> das.findAll(domainType, pageable));
@@ -168,6 +185,11 @@ public <T> Iterable<T> findAll(Query query, Class<T> domainType) {
168185
return collect(das -> das.findAll(query, domainType));
169186
}
170187

188+
@Override
189+
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
190+
return collect(das -> das.streamAll(query, domainType));
191+
}
192+
171193
@Override
172194
public <T> Iterable<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
173195
return collect(das -> das.findAll(query, domainType, pageable));

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DataAccessStrategy.java

+48
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.List;
1919
import java.util.Map;
2020
import java.util.Optional;
21+
import java.util.stream.Stream;
2122

2223
import org.springframework.dao.OptimisticLockingFailureException;
2324
import org.springframework.data.domain.Pageable;
@@ -41,6 +42,7 @@
4142
* @author Myeonghyeon Lee
4243
* @author Chirag Tailor
4344
* @author Diego Krupitza
45+
* @author Sergey Korotaev
4446
*/
4547
public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationResolver {
4648

@@ -252,6 +254,16 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR
252254
@Override
253255
<T> Iterable<T> findAll(Class<T> domainType);
254256

257+
/**
258+
* Loads all entities of the given type to a {@link Stream}.
259+
*
260+
* @param domainType the type of entities to load. Must not be {@code null}.
261+
* @param <T> the type of entities to load.
262+
* @return Guaranteed to be not {@code null}.
263+
*/
264+
@Override
265+
<T> Stream<T> streamAll(Class<T> domainType);
266+
255267
/**
256268
* Loads all entities that match one of the ids passed as an argument. It is not guaranteed that the number of ids
257269
* passed in matches the number of entities returned.
@@ -264,6 +276,18 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR
264276
@Override
265277
<T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType);
266278

279+
/**
280+
* Loads all entities that match one of the ids passed as an argument to a {@link Stream}.
281+
* It is not guaranteed that the number of ids passed in matches the number of entities returned.
282+
*
283+
* @param ids the Ids of the entities to load. Must not be {@code null}.
284+
* @param domainType the type of entities to load. Must not be {@code null}.
285+
* @param <T> type of entities to load.
286+
* @return the loaded entities. Guaranteed to be not {@code null}.
287+
*/
288+
@Override
289+
<T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType);
290+
267291
@Override
268292
Iterable<Object> findAllByPath(Identifier identifier,
269293
PersistentPropertyPath<? extends RelationalPersistentProperty> path);
@@ -280,6 +304,18 @@ Iterable<Object> findAllByPath(Identifier identifier,
280304
@Override
281305
<T> Iterable<T> findAll(Class<T> domainType, Sort sort);
282306

307+
/**
308+
* Loads all entities of the given type to a {@link Stream}, sorted.
309+
*
310+
* @param domainType the type of entities to load. Must not be {@code null}.
311+
* @param <T> the type of entities to load.
312+
* @param sort the sorting information. Must not be {@code null}.
313+
* @return Guaranteed to be not {@code null}.
314+
* @since 2.0
315+
*/
316+
@Override
317+
<T> Stream<T> streamAll(Class<T> domainType, Sort sort);
318+
283319
/**
284320
* Loads all entities of the given type, paged and sorted.
285321
*
@@ -316,6 +352,18 @@ Iterable<Object> findAllByPath(Identifier identifier,
316352
@Override
317353
<T> Iterable<T> findAll(Query query, Class<T> domainType);
318354

355+
/**
356+
* Execute a {@code SELECT} query and convert the resulting items to a {@link Stream}.
357+
*
358+
* @param query must not be {@literal null}.
359+
* @param domainType the type of entities. Must not be {@code null}.
360+
* @return a non-null list with all the matching results.
361+
* @throws org.springframework.dao.IncorrectResultSizeDataAccessException if more than one match found.
362+
* @since 3.0
363+
*/
364+
@Override
365+
<T> Stream<T> streamAll(Query query, Class<T> domainType);
366+
319367
/**
320368
* Execute a {@code SELECT} query and convert the resulting items to a {@link Iterable}. Applies the {@link Pageable}
321369
* to the result.

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java

+34
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.Collections;
2323
import java.util.List;
2424
import java.util.Optional;
25+
import java.util.stream.Stream;
2526

2627
import org.springframework.dao.EmptyResultDataAccessException;
2728
import org.springframework.dao.OptimisticLockingFailureException;
@@ -60,6 +61,7 @@
6061
* @author Radim Tlusty
6162
* @author Chirag Tailor
6263
* @author Diego Krupitza
64+
* @author Sergey Korotaev
6365
* @since 1.1
6466
*/
6567
public class DefaultDataAccessStrategy implements DataAccessStrategy {
@@ -276,6 +278,11 @@ public <T> List<T> findAll(Class<T> domainType) {
276278
return operations.query(sql(domainType).getFindAll(), getEntityRowMapper(domainType));
277279
}
278280

281+
@Override
282+
public <T> Stream<T> streamAll(Class<T> domainType) {
283+
return operations.queryForStream(sql(domainType).getFindAll(), new MapSqlParameterSource(), getEntityRowMapper(domainType));
284+
}
285+
279286
@Override
280287
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
281288

@@ -288,6 +295,19 @@ public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
288295
return operations.query(findAllInListSql, parameterSource, getEntityRowMapper(domainType));
289296
}
290297

298+
@Override
299+
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
300+
301+
if (!ids.iterator().hasNext()) {
302+
return Stream.empty();
303+
}
304+
305+
SqlParameterSource parameterSource = sqlParametersFactory.forQueryByIds(ids, domainType);
306+
String findAllInListSql = sql(domainType).getFindAllInList();
307+
308+
return operations.queryForStream(findAllInListSql, parameterSource, getEntityRowMapper(domainType));
309+
}
310+
291311
@Override
292312
@SuppressWarnings("unchecked")
293313
public List<Object> findAllByPath(Identifier identifier,
@@ -342,6 +362,11 @@ public <T> List<T> findAll(Class<T> domainType, Sort sort) {
342362
return operations.query(sql(domainType).getFindAll(sort), getEntityRowMapper(domainType));
343363
}
344364

365+
@Override
366+
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
367+
return operations.queryForStream(sql(domainType).getFindAll(sort), new MapSqlParameterSource(), getEntityRowMapper(domainType));
368+
}
369+
345370
@Override
346371
public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {
347372
return operations.query(sql(domainType).getFindAll(pageable), getEntityRowMapper(domainType));
@@ -369,6 +394,15 @@ public <T> List<T> findAll(Query query, Class<T> domainType) {
369394
return operations.query(sqlQuery, parameterSource, getEntityRowMapper(domainType));
370395
}
371396

397+
@Override
398+
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
399+
400+
MapSqlParameterSource parameterSource = new MapSqlParameterSource();
401+
String sqlQuery = sql(domainType).selectByQuery(query, parameterSource);
402+
403+
return operations.queryForStream(sqlQuery, parameterSource, getEntityRowMapper(domainType));
404+
}
405+
372406
@Override
373407
public <T> List<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
374408

0 commit comments

Comments
 (0)