diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateOperations.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateOperations.java index 89d60baa39..b1fd42b91f 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateOperations.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateOperations.java @@ -17,6 +17,7 @@ import java.util.List; import java.util.Optional; +import java.util.stream.Stream; import org.springframework.dao.IncorrectUpdateSemanticsDataAccessException; import org.springframework.data.domain.Example; @@ -35,6 +36,7 @@ * @author Chirag Tailor * @author Diego Krupitza * @author Myeonghyeon Lee + * @author Sergey Korotaev */ public interface JdbcAggregateOperations { @@ -165,6 +167,17 @@ public interface JdbcAggregateOperations { */ List findAllById(Iterable ids, Class domainType); + /** + * Loads all entities that match one of the ids passed as an argument to a {@link Stream}. + * It is not guaranteed that the number of ids passed in matches the number of entities returned. + * + * @param ids the Ids of the entities to load. Must not be {@code null}. + * @param domainType the type of entities to load. Must not be {@code null}. + * @param type of entities to load. + * @return the loaded entities. Guaranteed to be not {@code null}. + */ + Stream streamAllByIds(Iterable ids, Class domainType); + /** * Load all aggregates of a given type. * @@ -174,6 +187,15 @@ public interface JdbcAggregateOperations { */ List findAll(Class domainType); + /** + * Load all aggregates of a given type to a {@link Stream}. + * + * @param domainType the type of the aggregate roots. Must not be {@code null}. + * @param the type of the aggregate roots. Must not be {@code null}. + * @return Guaranteed to be not {@code null}. + */ + Stream streamAll(Class domainType); + /** * Load all aggregates of a given type, sorted. * @@ -185,6 +207,17 @@ public interface JdbcAggregateOperations { */ List findAll(Class domainType, Sort sort); + /** + * Loads all entities of the given type to a {@link Stream}, sorted. + * + * @param domainType the type of entities to load. Must not be {@code null}. + * @param the type of entities to load. + * @param sort the sorting information. Must not be {@code null}. + * @return Guaranteed to be not {@code null}. + * @since 2.0 + */ + Stream streamAll(Class domainType, Sort sort); + /** * Load a page of (potentially sorted) aggregates of a given type. * @@ -218,6 +251,17 @@ public interface JdbcAggregateOperations { */ List findAll(Query query, Class domainType); + /** + * Execute a {@code SELECT} query and convert the resulting items to a {@link Stream}. + * + * @param query must not be {@literal null}. + * @param domainType the type of entities. Must not be {@code null}. + * @return a non-null list with all the matching results. + * @throws org.springframework.dao.IncorrectResultSizeDataAccessException if more than one match found. + * @since 3.0 + */ + Stream streamAll(Query query, Class domainType); + /** * Returns a {@link Page} of entities matching the given {@link Query}. In case no match could be found, an empty * {@link Page} is returned. diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java index 520211d439..2f4fb617e0 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java @@ -25,6 +25,7 @@ import java.util.Optional; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; import java.util.stream.StreamSupport; import org.springframework.context.ApplicationContext; @@ -68,6 +69,7 @@ * @author Myeonghyeon Lee * @author Chirag Tailor * @author Diego Krupitza + * @author Sergey Korotaev */ public class JdbcAggregateTemplate implements JdbcAggregateOperations { @@ -283,6 +285,16 @@ public List findAll(Class domainType, Sort sort) { return triggerAfterConvert(all); } + @Override + public Stream streamAll(Class domainType, Sort sort) { + + Assert.notNull(domainType, "Domain type must not be null"); + + Stream allStreamable = accessStrategy.streamAll(domainType, sort); + + return allStreamable.map(this::triggerAfterConvert); + } + @Override public Page findAll(Class domainType, Pageable pageable) { @@ -309,6 +321,11 @@ public List findAll(Query query, Class domainType) { return Streamable.of(all).toList(); } + @Override + public Stream streamAll(Query query, Class domainType) { + return accessStrategy.streamAll(query, domainType).map(this::triggerAfterConvert); + } + @Override public Page findAll(Query query, Class domainType, Pageable pageable) { @@ -327,6 +344,12 @@ public List findAll(Class domainType) { return triggerAfterConvert(all); } + @Override + public Stream streamAll(Class domainType) { + Iterable items = triggerAfterConvert(accessStrategy.findAll(domainType)); + return StreamSupport.stream(items.spliterator(), false).map(this::triggerAfterConvert); + } + @Override public List findAllById(Iterable ids, Class domainType) { @@ -337,6 +360,17 @@ public List findAllById(Iterable ids, Class domainType) { return triggerAfterConvert(allById); } + @Override + public Stream streamAllByIds(Iterable ids, Class domainType) { + + Assert.notNull(ids, "Ids must not be null"); + Assert.notNull(domainType, "Domain type must not be null"); + + Stream allByIdStreamable = accessStrategy.streamAllByIds(ids, domainType); + + return allByIdStreamable.map(this::triggerAfterConvert); + } + @Override public void delete(S aggregateRoot) { diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/CascadingDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/CascadingDataAccessStrategy.java index 68160467a4..8135f810f8 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/CascadingDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/CascadingDataAccessStrategy.java @@ -22,6 +22,7 @@ import java.util.Optional; import java.util.function.Consumer; import java.util.function.Function; +import java.util.stream.Stream; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; @@ -42,6 +43,7 @@ * @author Myeonghyeon Lee * @author Chirag Tailor * @author Diego Krupitza + * @author Sergey Korotaev * @since 1.1 */ public class CascadingDataAccessStrategy implements DataAccessStrategy { @@ -132,11 +134,21 @@ public Iterable findAll(Class domainType) { return collect(das -> das.findAll(domainType)); } + @Override + public Stream streamAll(Class domainType) { + return collect(das -> das.streamAll(domainType)); + } + @Override public Iterable findAllById(Iterable ids, Class domainType) { return collect(das -> das.findAllById(ids, domainType)); } + @Override + public Stream streamAllByIds(Iterable ids, Class domainType) { + return collect(das -> das.streamAllByIds(ids, domainType)); + } + @Override public Iterable findAllByPath(Identifier identifier, PersistentPropertyPath path) { @@ -153,6 +165,11 @@ public Iterable findAll(Class domainType, Sort sort) { return collect(das -> das.findAll(domainType, sort)); } + @Override + public Stream streamAll(Class domainType, Sort sort) { + return collect(das -> das.streamAll(domainType, sort)); + } + @Override public Iterable findAll(Class domainType, Pageable pageable) { return collect(das -> das.findAll(domainType, pageable)); @@ -168,6 +185,11 @@ public Iterable findAll(Query query, Class domainType) { return collect(das -> das.findAll(query, domainType)); } + @Override + public Stream streamAll(Query query, Class domainType) { + return collect(das -> das.streamAll(query, domainType)); + } + @Override public Iterable findAll(Query query, Class domainType, Pageable pageable) { return collect(das -> das.findAll(query, domainType, pageable)); diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DataAccessStrategy.java index 6b7e3c8ec1..406d2e22f1 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DataAccessStrategy.java @@ -18,6 +18,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.stream.Stream; import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.data.domain.Pageable; @@ -41,6 +42,7 @@ * @author Myeonghyeon Lee * @author Chirag Tailor * @author Diego Krupitza + * @author Sergey Korotaev */ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationResolver { @@ -252,6 +254,16 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR @Override Iterable findAll(Class domainType); + /** + * Loads all entities of the given type to a {@link Stream}. + * + * @param domainType the type of entities to load. Must not be {@code null}. + * @param the type of entities to load. + * @return Guaranteed to be not {@code null}. + */ + @Override + Stream streamAll(Class domainType); + /** * Loads all entities that match one of the ids passed as an argument. It is not guaranteed that the number of ids * passed in matches the number of entities returned. @@ -264,6 +276,18 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR @Override Iterable findAllById(Iterable ids, Class domainType); + /** + * Loads all entities that match one of the ids passed as an argument to a {@link Stream}. + * It is not guaranteed that the number of ids passed in matches the number of entities returned. + * + * @param ids the Ids of the entities to load. Must not be {@code null}. + * @param domainType the type of entities to load. Must not be {@code null}. + * @param type of entities to load. + * @return the loaded entities. Guaranteed to be not {@code null}. + */ + @Override + Stream streamAllByIds(Iterable ids, Class domainType); + @Override Iterable findAllByPath(Identifier identifier, PersistentPropertyPath path); @@ -280,6 +304,18 @@ Iterable findAllByPath(Identifier identifier, @Override Iterable findAll(Class domainType, Sort sort); + /** + * Loads all entities of the given type to a {@link Stream}, sorted. + * + * @param domainType the type of entities to load. Must not be {@code null}. + * @param the type of entities to load. + * @param sort the sorting information. Must not be {@code null}. + * @return Guaranteed to be not {@code null}. + * @since 2.0 + */ + @Override + Stream streamAll(Class domainType, Sort sort); + /** * Loads all entities of the given type, paged and sorted. * @@ -316,6 +352,18 @@ Iterable findAllByPath(Identifier identifier, @Override Iterable findAll(Query query, Class domainType); + /** + * Execute a {@code SELECT} query and convert the resulting items to a {@link Stream}. + * + * @param query must not be {@literal null}. + * @param domainType the type of entities. Must not be {@code null}. + * @return a non-null list with all the matching results. + * @throws org.springframework.dao.IncorrectResultSizeDataAccessException if more than one match found. + * @since 3.0 + */ + @Override + Stream streamAll(Query query, Class domainType); + /** * Execute a {@code SELECT} query and convert the resulting items to a {@link Iterable}. Applies the {@link Pageable} * to the result. diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java index 4d210d516d..340646066f 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java @@ -22,6 +22,7 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.stream.Stream; import org.springframework.dao.EmptyResultDataAccessException; import org.springframework.dao.OptimisticLockingFailureException; @@ -60,6 +61,7 @@ * @author Radim Tlusty * @author Chirag Tailor * @author Diego Krupitza + * @author Sergey Korotaev * @since 1.1 */ public class DefaultDataAccessStrategy implements DataAccessStrategy { @@ -276,6 +278,11 @@ public List findAll(Class domainType) { return operations.query(sql(domainType).getFindAll(), getEntityRowMapper(domainType)); } + @Override + public Stream streamAll(Class domainType) { + return operations.queryForStream(sql(domainType).getFindAll(), new MapSqlParameterSource(), getEntityRowMapper(domainType)); + } + @Override public List findAllById(Iterable ids, Class domainType) { @@ -288,6 +295,19 @@ public List findAllById(Iterable ids, Class domainType) { return operations.query(findAllInListSql, parameterSource, getEntityRowMapper(domainType)); } + @Override + public Stream streamAllByIds(Iterable ids, Class domainType) { + + if (!ids.iterator().hasNext()) { + return Stream.empty(); + } + + SqlParameterSource parameterSource = sqlParametersFactory.forQueryByIds(ids, domainType); + String findAllInListSql = sql(domainType).getFindAllInList(); + + return operations.queryForStream(findAllInListSql, parameterSource, getEntityRowMapper(domainType)); + } + @Override @SuppressWarnings("unchecked") public List findAllByPath(Identifier identifier, @@ -342,6 +362,11 @@ public List findAll(Class domainType, Sort sort) { return operations.query(sql(domainType).getFindAll(sort), getEntityRowMapper(domainType)); } + @Override + public Stream streamAll(Class domainType, Sort sort) { + return operations.queryForStream(sql(domainType).getFindAll(sort), new MapSqlParameterSource(), getEntityRowMapper(domainType)); + } + @Override public List findAll(Class domainType, Pageable pageable) { return operations.query(sql(domainType).getFindAll(pageable), getEntityRowMapper(domainType)); @@ -369,6 +394,15 @@ public List findAll(Query query, Class domainType) { return operations.query(sqlQuery, parameterSource, getEntityRowMapper(domainType)); } + @Override + public Stream streamAll(Query query, Class domainType) { + + MapSqlParameterSource parameterSource = new MapSqlParameterSource(); + String sqlQuery = sql(domainType).selectByQuery(query, parameterSource); + + return operations.queryForStream(sqlQuery, parameterSource, getEntityRowMapper(domainType)); + } + @Override public List findAll(Query query, Class domainType, Pageable pageable) { diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DelegatingDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DelegatingDataAccessStrategy.java index 8acf774fb7..f28650e50d 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DelegatingDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DelegatingDataAccessStrategy.java @@ -17,6 +17,7 @@ import java.util.List; import java.util.Optional; +import java.util.stream.Stream; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; @@ -37,6 +38,7 @@ * @author Myeonghyeon Lee * @author Chirag Tailor * @author Diego Krupitza + * @author Sergey Korotaev * @since 1.1 */ public class DelegatingDataAccessStrategy implements DataAccessStrategy { @@ -135,11 +137,21 @@ public Iterable findAll(Class domainType) { return delegate.findAll(domainType); } + @Override + public Stream streamAll(Class domainType) { + return delegate.streamAll(domainType); + } + @Override public Iterable findAllById(Iterable ids, Class domainType) { return delegate.findAllById(ids, domainType); } + @Override + public Stream streamAllByIds(Iterable ids, Class domainType) { + return delegate.streamAllByIds(ids, domainType); + } + @Override public Iterable findAllByPath(Identifier identifier, PersistentPropertyPath path) { @@ -156,6 +168,11 @@ public Iterable findAll(Class domainType, Sort sort) { return delegate.findAll(domainType, sort); } + @Override + public Stream streamAll(Class domainType, Sort sort) { + return delegate.streamAll(domainType, sort); + } + @Override public Iterable findAll(Class domainType, Pageable pageable) { return delegate.findAll(domainType, pageable); @@ -171,6 +188,11 @@ public Iterable findAll(Query query, Class domainType) { return delegate.findAll(query, domainType); } + @Override + public Stream streamAll(Query query, Class domainType) { + return delegate.streamAll(query, domainType); + } + @Override public Iterable findAll(Query query, Class domainType, Pageable pageable) { return delegate.findAll(query, domainType, pageable); diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/ReadingDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/ReadingDataAccessStrategy.java index a23c0c1748..3b48b9fc5d 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/ReadingDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/ReadingDataAccessStrategy.java @@ -17,6 +17,7 @@ package org.springframework.data.jdbc.core.convert; import java.util.Optional; +import java.util.stream.Stream; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; @@ -27,6 +28,7 @@ * The finding methods of a {@link DataAccessStrategy}. * * @author Jens Schauder + * @author Sergey Korotaev * @since 3.2 */ interface ReadingDataAccessStrategy { @@ -51,6 +53,15 @@ interface ReadingDataAccessStrategy { */ Iterable findAll(Class domainType); + /** + * Loads all entities of the given type to a {@link Stream}. + * + * @param domainType the type of entities to load. Must not be {@code null}. + * @param the type of entities to load. + * @return Guaranteed to be not {@code null}. + */ + Stream streamAll(Class domainType); + /** * Loads all entities that match one of the ids passed as an argument. It is not guaranteed that the number of ids * passed in matches the number of entities returned. @@ -62,6 +73,17 @@ interface ReadingDataAccessStrategy { */ Iterable findAllById(Iterable ids, Class domainType); + /** + * Loads all entities that match one of the ids passed as an argument to a {@link Stream}. + * It is not guaranteed that the number of ids passed in matches the number of entities returned. + * + * @param ids the Ids of the entities to load. Must not be {@code null}. + * @param domainType the type of entities to load. Must not be {@code null}. + * @param type of entities to load. + * @return the loaded entities. Guaranteed to be not {@code null}. + */ + Stream streamAllByIds(Iterable ids, Class domainType); + /** * Loads all entities of the given type, sorted. * @@ -73,6 +95,17 @@ interface ReadingDataAccessStrategy { */ Iterable findAll(Class domainType, Sort sort); + /** + * Loads all entities of the given type to a {@link Stream}, sorted. + * + * @param domainType the type of entities to load. Must not be {@code null}. + * @param the type of entities to load. + * @param sort the sorting information. Must not be {@code null}. + * @return Guaranteed to be not {@code null}. + * @since 2.0 + */ + Stream streamAll(Class domainType, Sort sort); + /** * Loads all entities of the given type, paged and sorted. * @@ -106,6 +139,17 @@ interface ReadingDataAccessStrategy { */ Iterable findAll(Query query, Class domainType); + /** + * Execute a {@code SELECT} query and convert the resulting items to a {@link Stream}. + * + * @param query must not be {@literal null}. + * @param domainType the type of entities. Must not be {@code null}. + * @return a non-null list with all the matching results. + * @throws org.springframework.dao.IncorrectResultSizeDataAccessException if more than one match found. + * @since 3.0 + */ + Stream streamAll(Query query, Class domainType); + /** * Execute a {@code SELECT} query and convert the resulting items to a {@link Iterable}. Applies the {@link Pageable} * to the result. diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java index 7fdab7d981..1313f082a8 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java @@ -18,6 +18,7 @@ import java.util.List; import java.util.Optional; +import java.util.stream.Stream; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; @@ -32,6 +33,7 @@ * * @author Jens Schauder * @author Mark Paluch + * @author Sergey Korotaev * @since 3.2 */ class SingleQueryDataAccessStrategy implements ReadingDataAccessStrategy { @@ -56,16 +58,31 @@ public List findAll(Class domainType) { return aggregateReader.findAll(getPersistentEntity(domainType)); } + @Override + public Stream streamAll(Class domainType) { + throw new UnsupportedOperationException(); + } + @Override public List findAllById(Iterable ids, Class domainType) { return aggregateReader.findAllById(ids, getPersistentEntity(domainType)); } + @Override + public Stream streamAllByIds(Iterable ids, Class domainType) { + throw new UnsupportedOperationException(); + } + @Override public List findAll(Class domainType, Sort sort) { throw new UnsupportedOperationException(); } + @Override + public Stream streamAll(Class domainType, Sort sort) { + throw new UnsupportedOperationException(); + } + @Override public List findAll(Class domainType, Pageable pageable) { throw new UnsupportedOperationException(); @@ -81,6 +98,11 @@ public List findAll(Query query, Class domainType) { return aggregateReader.findAll(query, getPersistentEntity(domainType)); } + @Override + public Stream streamAll(Query query, Class domainType) { + throw new UnsupportedOperationException(); + } + @Override public List findAll(Query query, Class domainType, Pageable pageable) { throw new UnsupportedOperationException(); diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java index 3b8b8efd34..29f9f6ca98 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java @@ -22,7 +22,10 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; +import org.apache.ibatis.cursor.Cursor; import org.apache.ibatis.session.SqlSession; import org.mybatis.spring.SqlSessionTemplate; import org.springframework.dao.EmptyResultDataAccessException; @@ -59,6 +62,7 @@ * @author Chirag Tailor * @author Christopher Klein * @author Mikhail Polivakha + * @author Sergey Korotaev */ public class MyBatisDataAccessStrategy implements DataAccessStrategy { @@ -263,12 +267,28 @@ public List findAll(Class domainType) { return sqlSession().selectList(statement, parameter); } + @Override + public Stream streamAll(Class domainType) { + String statement = namespace(domainType) + ".streamAll"; + MyBatisContext parameter = new MyBatisContext(null, null, domainType, Collections.emptyMap()); + Cursor cursor = sqlSession().selectCursor(statement, parameter); + return StreamSupport.stream(cursor.spliterator(), false); + } + @Override public List findAllById(Iterable ids, Class domainType) { return sqlSession().selectList(namespace(domainType) + ".findAllById", new MyBatisContext(ids, null, domainType, Collections.emptyMap())); } + @Override + public Stream streamAllByIds(Iterable ids, Class domainType) { + String statement = namespace(domainType) + ".streamAllByIds"; + MyBatisContext parameter = new MyBatisContext(ids, null, domainType, Collections.emptyMap()); + Cursor cursor = sqlSession().selectCursor(statement, parameter); + return StreamSupport.stream(cursor.spliterator(), false); + } + @Override public List findAllByPath(Identifier identifier, PersistentPropertyPath path) { @@ -296,6 +316,19 @@ public List findAll(Class domainType, Sort sort) { new MyBatisContext(null, null, domainType, additionalContext)); } + @Override + public Stream streamAll(Class domainType, Sort sort) { + + Map additionalContext = new HashMap<>(); + additionalContext.put("sort", sort); + + String statement = namespace(domainType) + ".streamAllSorted"; + MyBatisContext parameter = new MyBatisContext(null, null, domainType, additionalContext); + + Cursor cursor = sqlSession().selectCursor(statement, parameter); + return StreamSupport.stream(cursor.spliterator(), false); + } + @Override public List findAll(Class domainType, Pageable pageable) { @@ -315,6 +348,11 @@ public List findAll(Query query, Class probeType) { throw new UnsupportedOperationException("Not implemented"); } + @Override + public Stream streamAll(Query query, Class probeType) { + throw new UnsupportedOperationException("Not implemented"); + } + @Override public List findAll(Query query, Class probeType, Pageable pageable) { throw new UnsupportedOperationException("Not implemented"); diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java index d1a085bd26..7a10047a1e 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java @@ -27,8 +27,8 @@ import java.util.ArrayList; import java.util.function.Function; import java.util.stream.IntStream; +import java.util.stream.Stream; -import org.assertj.core.api.SoftAssertions; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationEventPublisher; @@ -81,6 +81,7 @@ * @author Mikhail Polivakha * @author Chirag Tailor * @author Vincent Galloy + * @author Sergey Korotaev */ @IntegrationTest abstract class AbstractJdbcAggregateTemplateIntegrationTests { @@ -309,6 +310,18 @@ void saveAndLoadManyEntitiesWithReferencedEntity() { .containsExactly(tuple(legoSet.id, legoSet.manual.id, legoSet.manual.content)); } + @Test // GH-1714 + void saveAndLoadManeEntitiesWithReferenceEntityLikeStream() { + + template.save(legoSet); + + Stream streamable = template.streamAll(LegoSet.class); + + assertThat(streamable) + .extracting("id", "manual.id", "manual.content") // + .containsExactly(tuple(legoSet.id, legoSet.manual.id, legoSet.manual.content)); + } + @Test // DATAJDBC-101 void saveAndLoadManyEntitiesWithReferencedEntitySorted() { @@ -323,6 +336,20 @@ void saveAndLoadManyEntitiesWithReferencedEntitySorted() { .containsExactly("Frozen", "Lava", "Star"); } + @Test // GH-1714 + void saveAndLoadManyEntitiesWithReferencedEntitySortedLikeStream() { + + template.save(createLegoSet("Lava")); + template.save(createLegoSet("Star")); + template.save(createLegoSet("Frozen")); + + Stream reloadedLegoSets = template.streamAll(LegoSet.class, Sort.by("name")); + + assertThat(reloadedLegoSets) // + .extracting("name") // + .containsExactly("Frozen", "Lava", "Star"); + } + @Test // DATAJDBC-101 void saveAndLoadManyEntitiesWithReferencedEntitySortedAndPaged() { @@ -360,6 +387,12 @@ void findByNonPropertySortFails() { .isInstanceOf(InvalidPersistentPropertyPath.class); } + @Test // GH-1714 + void findByNonPropertySortLikeStreamFails() { + assertThatThrownBy(() -> template.streamAll(LegoSet.class, Sort.by("somethingNotExistant"))) + .isInstanceOf(InvalidPersistentPropertyPath.class); + } + @Test // DATAJDBC-112 void saveAndLoadManyEntitiesByIdWithReferencedEntity() { @@ -371,6 +404,17 @@ void saveAndLoadManyEntitiesByIdWithReferencedEntity() { .contains(tuple(legoSet.id, legoSet.manual.id, legoSet.manual.content)); } + @Test // GH-1714 + void saveAndLoadManyEntitiesByIdWithReferencedEntityLikeStream() { + + template.save(legoSet); + + Stream reloadedLegoSets = template.streamAllByIds(singletonList(legoSet.id), LegoSet.class); + + assertThat(reloadedLegoSets).hasSize(1).extracting("id", "manual.id", "manual.content") + .contains(tuple(legoSet.id, legoSet.manual.id, legoSet.manual.content)); + } + @Test // DATAJDBC-112 void saveAndLoadAnEntityWithReferencedNullEntity() { diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategyUnitTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategyUnitTests.java index 8dd69cb2cd..19889b7f30 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategyUnitTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategyUnitTests.java @@ -22,7 +22,12 @@ import static org.mockito.Mockito.*; import static org.springframework.data.relational.core.sql.SqlIdentifier.*; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Stream; +import org.apache.ibatis.cursor.Cursor; import org.apache.ibatis.session.SqlSession; +import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; @@ -43,6 +48,7 @@ * @author Mark Paluch * @author Tyler Van Gorder * @author Chirag Tailor + * @author Sergey Korotaev */ public class MyBatisDataAccessStrategyUnitTests { @@ -241,6 +247,36 @@ public void findAll() { ); } + @Test + public void streamAll() { + + String value = "some answer"; + + Cursor cursor = getCursor(value); + + when(session.selectCursor(anyString(), any())).then(answer -> cursor); + + Stream streamable = accessStrategy.streamAll(String.class); + + verify(session).selectCursor(eq("java.lang.StringMapper.streamAll"), captor.capture()); + + assertThat(streamable).isNotNull().containsExactly(value); + + assertThat(captor.getValue()) // + .isNotNull() // + .extracting( // + MyBatisContext::getInstance, // + MyBatisContext::getId, // + MyBatisContext::getDomainType, // + c -> c.get("key") // + ).containsExactly( // + null, // + null, // + String.class, // + null // + ); + } + @Test // DATAJDBC-123 public void findAllById() { @@ -263,6 +299,33 @@ public void findAllById() { ); } + @Test + public void streamAllByIds() { + + String value = "some answer 2"; + Cursor cursor = getCursor(value); + + when(session.selectCursor(anyString(), any())).then(answer -> cursor); + + accessStrategy.streamAllByIds(asList("id1", "id2"), String.class); + + verify(session).selectCursor(eq("java.lang.StringMapper.streamAllByIds"), captor.capture()); + + assertThat(captor.getValue()) // + .isNotNull() // + .extracting( // + MyBatisContext::getInstance, // + MyBatisContext::getId, // + MyBatisContext::getDomainType, // + c -> c.get("key") // + ).containsExactly( // + null, // + asList("id1", "id2"), // + String.class, // + null // + ); + } + @SuppressWarnings("unchecked") @Test // DATAJDBC-384 public void findAllByPath() { @@ -367,6 +430,33 @@ public void findAllSorted() { ); } + @Test + public void streamAllSorted() { + + String value = "some answer 3"; + Cursor cursor = getCursor(value); + + when(session.selectCursor(anyString(), any())).then(answer -> cursor); + + accessStrategy.streamAll(String.class, Sort.by("length")); + + verify(session).selectCursor(eq("java.lang.StringMapper.streamAllSorted"), captor.capture()); + + assertThat(captor.getValue()) // + .isNotNull() // + .extracting( // + MyBatisContext::getInstance, // + MyBatisContext::getId, // + MyBatisContext::getDomainType, // + c -> c.get("sort") // + ).containsExactly( // + null, // + null, // + String.class, // + Sort.by("length") // + ); + } + @Test // DATAJDBC-101 public void findAllPaged() { @@ -399,5 +489,36 @@ private static class ChildOne { ChildTwo two; } - private static class ChildTwo {} + private static class ChildTwo { + } + + private Cursor getCursor(String value) { + return new Cursor<>() { + @Override + public boolean isOpen() { + return false; + } + + @Override + public boolean isConsumed() { + return false; + } + + @Override + public int getCurrentIndex() { + return 0; + } + + @Override + public void close() { + + } + + @NotNull + @Override + public Iterator iterator() { + return List.of(value).iterator(); + } + }; + } }