Skip to content

Commit 451c0c3

Browse files
committed
Add support for arbitrary where clauses in Single Query Loading.
Closes #1601
1 parent 6be9529 commit 451c0c3

File tree

6 files changed

+142
-25
lines changed

6 files changed

+142
-25
lines changed

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java

+44-1
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,22 @@
2121
import java.util.Iterator;
2222
import java.util.List;
2323
import java.util.Map;
24+
import java.util.Optional;
25+
import java.util.function.BiFunction;
2426

2527
import org.springframework.dao.IncorrectResultSizeDataAccessException;
2628
import org.springframework.data.relational.core.dialect.Dialect;
2729
import org.springframework.data.relational.core.mapping.AggregatePath;
2830
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
31+
import org.springframework.data.relational.core.query.CriteriaDefinition;
32+
import org.springframework.data.relational.core.query.Query;
33+
import org.springframework.data.relational.core.sql.Condition;
34+
import org.springframework.data.relational.core.sql.Table;
2935
import org.springframework.data.relational.core.sqlgeneration.AliasFactory;
3036
import org.springframework.data.relational.core.sqlgeneration.SingleQuerySqlGenerator;
3137
import org.springframework.data.relational.core.sqlgeneration.SqlGenerator;
3238
import org.springframework.data.relational.domain.RowDocument;
39+
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
3340
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
3441
import org.springframework.lang.Nullable;
3542
import org.springframework.util.Assert;
@@ -89,6 +96,35 @@ public Iterable<T> findAllById(Iterable<?> ids) {
8996
return jdbcTemplate.query(sqlGenerator.findAllById(), Map.of("ids", convertedIds), this::extractAll);
9097
}
9198

99+
public Iterable<T> findAllBy(Query query) {
100+
101+
MapSqlParameterSource parameterSource = new MapSqlParameterSource();
102+
BiFunction<Table, RelationalPersistentEntity, Condition> condition = createConditionSource(query, parameterSource);
103+
return jdbcTemplate.query(sqlGenerator.findAllByCondition(condition), parameterSource, this::extractAll);
104+
}
105+
106+
public Optional<T> findOneByQuery(Query query) {
107+
108+
MapSqlParameterSource parameterSource = new MapSqlParameterSource();
109+
BiFunction<Table, RelationalPersistentEntity, Condition> condition = createConditionSource(query, parameterSource);
110+
111+
return Optional.ofNullable(
112+
jdbcTemplate.query(sqlGenerator.findAllByCondition(condition), parameterSource, this::extractZeroOrOne));
113+
}
114+
115+
private BiFunction<Table, RelationalPersistentEntity, Condition> createConditionSource(Query query, MapSqlParameterSource parameterSource) {
116+
117+
QueryMapper queryMapper = new QueryMapper(converter);
118+
119+
BiFunction<Table, RelationalPersistentEntity, Condition> condition = (table, aggregate) -> {
120+
Optional<CriteriaDefinition> criteria = query.getCriteria();
121+
return criteria
122+
.map(criteriaDefinition -> queryMapper.getMappedObject(parameterSource, criteriaDefinition, table, aggregate))
123+
.orElse(null);
124+
};
125+
return condition;
126+
}
127+
92128
/**
93129
* Extracts a list of aggregates from the given {@link ResultSet} by utilizing the
94130
* {@link RowDocumentResultSetExtractor} and the {@link JdbcConverter}. When used as a method reference this conforms
@@ -115,7 +151,8 @@ private List<T> extractAll(ResultSet rs) throws SQLException {
115151
* to the {@link org.springframework.jdbc.core.ResultSetExtractor} contract.
116152
*
117153
* @param @param rs the {@link ResultSet} from which to extract the data. Must not be {(}@literal null}.
118-
* @return The single instance when the conversion results in exactly one instance. If the {@literal ResultSet} is empty, null is returned.
154+
* @return The single instance when the conversion results in exactly one instance. If the {@literal ResultSet} is
155+
* empty, null is returned.
119156
* @throws SQLException
120157
* @throws IncorrectResultSizeDataAccessException when the conversion yields more than one instance.
121158
*/
@@ -190,9 +227,15 @@ public String findAllById() {
190227
return findAllById;
191228
}
192229

230+
@Override
231+
public String findAllByCondition(BiFunction<Table, RelationalPersistentEntity, Condition> conditionSource) {
232+
return delegate.findAllByCondition(conditionSource);
233+
}
234+
193235
@Override
194236
public AliasFactory getAliasFactory() {
195237
return delegate.getAliasFactory();
196238
}
239+
197240
}
198241
}

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,13 @@ public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
7777

7878
@Override
7979
public <T> Optional<T> findOne(Query query, Class<T> domainType) {
80-
return Optional.empty();
80+
return getReader(domainType).findOneByQuery(query);
8181
}
8282

8383
@Override
8484
public <T> Iterable<T> findAll(Query query, Class<T> domainType) {
85-
throw new UnsupportedOperationException();
85+
86+
return getReader(domainType).findAllBy(query);
8687
}
8788

8889
@Override

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryFallbackDataAccessStrategy.java

+32-6
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
package org.springframework.data.jdbc.core.convert;
1717

1818
import java.util.Collections;
19+
import java.util.Optional;
1920

2021
import org.springframework.data.mapping.PersistentPropertyPath;
2122
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
23+
import org.springframework.data.relational.core.query.Query;
2224
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
2325
import org.springframework.util.Assert;
2426

@@ -85,13 +87,37 @@ public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
8587
return super.findAllById(ids, domainType);
8688
}
8789

90+
public <T> Optional<T> findOne(Query query, Class<T> domainType) {
91+
92+
if (isSingleSelectQuerySupported(domainType) && isSingleSelectQuerySupported(query)) {
93+
return singleSelectDelegate.findOne(query, domainType);
94+
}
95+
96+
return super.findOne(query, domainType);
97+
}
98+
99+
@Override
100+
public <T> Iterable<T> findAll(Query query, Class<T> domainType) {
101+
102+
if (isSingleSelectQuerySupported(domainType) && isSingleSelectQuerySupported(query)) {
103+
return singleSelectDelegate.findAll(query, domainType);
104+
}
105+
106+
return super.findAll(query, domainType);
107+
}
108+
109+
private static boolean isSingleSelectQuerySupported(Query query) {
110+
return !query.isSorted() && !query.isLimited();
111+
}
112+
88113
private boolean isSingleSelectQuerySupported(Class<?> entityType) {
89114

90-
return sqlGeneratorSource.getDialect().supportsSingleQueryLoading()//
91-
&& entityQualifiesForSingleSelectQuery(entityType);
115+
return converter.getMappingContext().isSingleQueryLoadingEnabled()
116+
&& sqlGeneratorSource.getDialect().supportsSingleQueryLoading()//
117+
&& entityQualifiesForSingleQueryLoading(entityType);
92118
}
93119

94-
private boolean entityQualifiesForSingleSelectQuery(Class<?> entityType) {
120+
private boolean entityQualifiesForSingleQueryLoading(Class<?> entityType) {
95121

96122
boolean referenceFound = false;
97123
for (PersistentPropertyPath<RelationalPersistentProperty> path : converter.getMappingContext()
@@ -113,9 +139,9 @@ private boolean entityQualifiesForSingleSelectQuery(Class<?> entityType) {
113139
}
114140

115141
// AggregateReferences aren't supported yet
116-
if (property.isAssociation()) {
117-
return false;
118-
}
142+
// if (property.isAssociation()) {
143+
// return false;
144+
// }
119145
}
120146
return true;
121147

spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java

+47-9
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,7 @@
2424
import static org.springframework.test.context.TestExecutionListeners.MergeMode.*;
2525

2626
import java.time.LocalDateTime;
27-
import java.util.ArrayList;
28-
import java.util.Collections;
29-
import java.util.HashMap;
30-
import java.util.HashSet;
31-
import java.util.Iterator;
32-
import java.util.List;
33-
import java.util.Map;
34-
import java.util.Objects;
35-
import java.util.Set;
27+
import java.util.*;
3628
import java.util.function.Function;
3729
import java.util.stream.IntStream;
3830

@@ -67,6 +59,9 @@
6759
import org.springframework.data.relational.core.mapping.MappedCollection;
6860
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
6961
import org.springframework.data.relational.core.mapping.Table;
62+
import org.springframework.data.relational.core.query.Criteria;
63+
import org.springframework.data.relational.core.query.CriteriaDefinition;
64+
import org.springframework.data.relational.core.query.Query;
7065
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
7166
import org.springframework.test.context.ActiveProfiles;
7267
import org.springframework.test.context.ContextConfiguration;
@@ -233,6 +228,49 @@ void findAllById() {
233228
.containsExactlyInAnyOrder(tuple(entity.id, "entity"), tuple(yetAnother.id, "yetAnother"));
234229
}
235230

231+
@Test // GH-1601
232+
void findAllByQuery() {
233+
234+
ListParent entity = new ListParent();
235+
entity = template.save(entity);
236+
237+
ListParent other = new ListParent();
238+
other = template.save(other);
239+
240+
ListParent yetAnother = new ListParent();
241+
yetAnother = template.save(yetAnother);
242+
243+
CriteriaDefinition criteria = CriteriaDefinition.from(Criteria.where("id").is(other.id));
244+
Query query = Query.query(criteria);
245+
Iterable<ListParent> reloadedById = template.findAll(query,
246+
ListParent.class);
247+
248+
249+
assertThat(reloadedById).extracting(e -> e.id)
250+
.containsExactly(other.id);
251+
}
252+
253+
@Test // GH-1601
254+
void findOneByQuery() {
255+
256+
ListParent entity = new ListParent();
257+
entity = template.save(entity);
258+
259+
ListParent other = new ListParent();
260+
other = template.save(other);
261+
262+
ListParent yetAnother = new ListParent();
263+
yetAnother = template.save(yetAnother);
264+
265+
CriteriaDefinition criteria = CriteriaDefinition.from(Criteria.where("id").is(other.id));
266+
Query query = Query.query(criteria);
267+
Optional<ListParent> reloadedById = template.findOne(query,
268+
ListParent.class);
269+
270+
271+
assertThat(reloadedById).get().extracting(e -> e.id).isEqualTo(other.id);
272+
}
273+
236274
@Test // DATAJDBC-112
237275
@EnabledOnFeature(SUPPORTS_QUOTED_IDS)
238276
void saveAndLoadAnEntityWithReferencedEntityById() {

spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java

+8-7
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.Collection;
2020
import java.util.List;
2121
import java.util.Map;
22+
import java.util.function.BiFunction;
2223

2324
import org.jetbrains.annotations.NotNull;
2425
import org.springframework.data.mapping.PersistentProperty;
@@ -81,20 +82,20 @@ public String findAllById() {
8182
return createSelect(condition);
8283
}
8384

85+
@Override
86+
public String findAllByCondition(BiFunction<Table, RelationalPersistentEntity, Condition> conditionSource) {
87+
Condition condition = conditionSource.apply(table, aggregate);
88+
return createSelect(condition);
89+
}
90+
8491
/**
8592
* @return The {@link AggregatePath} to the id property of the aggregate root.
8693
*/
8794
private AggregatePath getRootIdPath() {
8895
return context.getAggregatePath(aggregate).append(aggregate.getRequiredIdProperty());
8996
}
9097

91-
/**
92-
* Creates a SQL suitable of loading all the data required for constructing complete aggregates.
93-
*
94-
* @param condition a constraint for limiting the aggregates to be loaded.
95-
* @return a {@literal String} containing the generated SQL statement
96-
*/
97-
private String createSelect(Condition condition) {
98+
String createSelect(Condition condition) {
9899

99100
AggregatePath rootPath = context.getAggregatePath(aggregate);
100101
QueryMeta queryMeta = createInlineQuery(rootPath, condition);

spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java

+8
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@
1515
*/
1616
package org.springframework.data.relational.core.sqlgeneration;
1717

18+
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
19+
import org.springframework.data.relational.core.sql.Condition;
20+
import org.springframework.data.relational.core.sql.Table;
21+
22+
import java.util.function.BiFunction;
23+
1824
/**
1925
* Generates SQL statements for loading aggregates.
2026
*
@@ -28,5 +34,7 @@ public interface SqlGenerator {
2834

2935
String findAllById();
3036

37+
String findAllByCondition(BiFunction<Table, RelationalPersistentEntity, Condition> conditionSource);
38+
3139
AliasFactory getAliasFactory();
3240
}

0 commit comments

Comments
 (0)