Skip to content

Commit 4220df5

Browse files
christophstroblmp911de
authored andcommitted
Accept index names as hint for aggregations.
Closes #4238 Original pull request: #4243
1 parent 95c6d15 commit 4220df5

File tree

5 files changed

+116
-10
lines changed

5 files changed

+116
-10
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java

+26-6
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import org.apache.commons.logging.LogFactory;
3131
import org.bson.Document;
3232
import org.bson.conversions.Bson;
33-
3433
import org.springframework.beans.BeansException;
3534
import org.springframework.context.ApplicationContext;
3635
import org.springframework.context.ApplicationContextAware;
@@ -634,15 +633,17 @@ public MongoCollection<Document> createCollection(String collectionName,
634633
}
635634

636635
@Override
637-
public MongoCollection<Document> createView(String name, Class<?> source, AggregationPipeline pipeline, @Nullable ViewOptions options) {
636+
public MongoCollection<Document> createView(String name, Class<?> source, AggregationPipeline pipeline,
637+
@Nullable ViewOptions options) {
638638

639639
return createView(name, getCollectionName(source),
640640
queryOperations.createAggregation(Aggregation.newAggregation(source, pipeline.getOperations()), source),
641641
options);
642642
}
643643

644644
@Override
645-
public MongoCollection<Document> createView(String name, String source, AggregationPipeline pipeline, @Nullable ViewOptions options) {
645+
public MongoCollection<Document> createView(String name, String source, AggregationPipeline pipeline,
646+
@Nullable ViewOptions options) {
646647

647648
return createView(name, source,
648649
queryOperations.createAggregation(Aggregation.newAggregation(pipeline.getOperations()), (Class<?>) null),
@@ -654,7 +655,8 @@ private MongoCollection<Document> createView(String name, String source, Aggrega
654655
return doCreateView(name, source, aggregation.getAggregationPipeline(), options);
655656
}
656657

657-
protected MongoCollection<Document> doCreateView(String name, String source, List<Document> pipeline, @Nullable ViewOptions options) {
658+
protected MongoCollection<Document> doCreateView(String name, String source, List<Document> pipeline,
659+
@Nullable ViewOptions options) {
658660

659661
CreateViewOptions viewOptions = new CreateViewOptions();
660662
if (options != null) {
@@ -2065,7 +2067,16 @@ protected <O> AggregationResults<O> doAggregate(Aggregation aggregation, String
20652067
}
20662068

20672069
options.getComment().ifPresent(aggregateIterable::comment);
2068-
options.getHint().ifPresent(aggregateIterable::hint);
2070+
if (options.getHintObject().isPresent()) {
2071+
Object hintObject = options.getHintObject().get();
2072+
if (hintObject instanceof String hintString) {
2073+
aggregateIterable = aggregateIterable.hintString(hintString);
2074+
} else if (hintObject instanceof Document hintDocument) {
2075+
aggregateIterable = aggregateIterable.hint(hintDocument);
2076+
} else {
2077+
throw new IllegalStateException("Unable to read hint of type %s".formatted(hintObject.getClass()));
2078+
}
2079+
}
20692080

20702081
if (options.hasExecutionTimeLimit()) {
20712082
aggregateIterable = aggregateIterable.maxTime(options.getMaxTime().toMillis(), TimeUnit.MILLISECONDS);
@@ -2124,7 +2135,16 @@ protected <O> Stream<O> aggregateStream(Aggregation aggregation, String collecti
21242135
}
21252136

21262137
options.getComment().ifPresent(cursor::comment);
2127-
options.getHint().ifPresent(cursor::hint);
2138+
if (options.getHintObject().isPresent()) {
2139+
Object hintObject = options.getHintObject().get();
2140+
if (hintObject instanceof String hintString) {
2141+
cursor = cursor.hintString(hintString);
2142+
} else if (hintObject instanceof Document hintDocument) {
2143+
cursor = cursor.hint(hintDocument);
2144+
} else {
2145+
throw new IllegalStateException("Unable to read hint of type %s".formatted(hintObject.getClass()));
2146+
}
2147+
}
21282148

21292149
Class<?> domainType = aggregation instanceof TypedAggregation ? ((TypedAggregation) aggregation).getInputType()
21302150
: null;

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java

+10-1
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,16 @@ private <O> Flux<O> aggregateAndMap(MongoCollection<Document> collection, List<D
938938
}
939939

940940
options.getComment().ifPresent(cursor::comment);
941-
options.getHint().ifPresent(cursor::hint);
941+
if (options.getHintObject().isPresent()) {
942+
Object hintObject = options.getHintObject().get();
943+
if (hintObject instanceof String hintString) {
944+
cursor = cursor.hintString(hintString);
945+
} else if (hintObject instanceof Document hintDocument) {
946+
cursor = cursor.hint(hintDocument);
947+
} else {
948+
throw new IllegalStateException("Unable to read hint of type %s".formatted(hintObject.getClass()));
949+
}
950+
}
942951

943952
Optionals.firstNonEmpty(options::getCollation, () -> operations.forType(inputType).getCollation()) //
944953
.map(Collation::toMongoCollation) //

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOptions.java

+55-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import org.bson.Document;
2222
import org.springframework.data.mongodb.core.query.Collation;
23+
import org.springframework.data.mongodb.util.BsonUtils;
2324
import org.springframework.lang.Nullable;
2425
import org.springframework.util.Assert;
2526

@@ -53,7 +54,7 @@ public class AggregationOptions {
5354
private final Optional<Document> cursor;
5455
private final Optional<Collation> collation;
5556
private final Optional<String> comment;
56-
private final Optional<Document> hint;
57+
private final Optional<Object> hint;
5758
private Duration maxTime = Duration.ZERO;
5859
private ResultOptions resultOptions = ResultOptions.READ;
5960
private DomainTypeMapping domainTypeMapping = DomainTypeMapping.RELAXED;
@@ -113,7 +114,7 @@ public AggregationOptions(boolean allowDiskUse, boolean explain, @Nullable Docum
113114
* @since 3.1
114115
*/
115116
private AggregationOptions(boolean allowDiskUse, boolean explain, @Nullable Document cursor,
116-
@Nullable Collation collation, @Nullable String comment, @Nullable Document hint) {
117+
@Nullable Collation collation, @Nullable String comment, @Nullable Object hint) {
117118

118119
this.allowDiskUse = allowDiskUse;
119120
this.explain = explain;
@@ -242,6 +243,44 @@ public Optional<String> getComment() {
242243
* @since 3.1
243244
*/
244245
public Optional<Document> getHint() {
246+
return hint.map(it -> {
247+
if (it instanceof Document doc) {
248+
return doc;
249+
}
250+
if (it instanceof String hintString) {
251+
if (BsonUtils.isJsonDocument(hintString)) {
252+
return BsonUtils.parse(hintString, null);
253+
}
254+
}
255+
throw new IllegalStateException("Unable to read hint of type %s".formatted(it.getClass()));
256+
});
257+
}
258+
259+
/**
260+
* Get the hint (indexName) used to to fulfill the aggregation.
261+
*
262+
* @return never {@literal null}.
263+
* @since 4.1
264+
*/
265+
public Optional<String> getHintAsString() {
266+
return hint.map(it -> {
267+
if (it instanceof String hintString) {
268+
return hintString;
269+
}
270+
if (it instanceof Document doc) {
271+
return BsonUtils.toJson(doc);
272+
}
273+
throw new IllegalStateException("Unable to read hint of type %s".formatted(it.getClass()));
274+
});
275+
}
276+
277+
/**
278+
* Get the hint used to to fulfill the aggregation.
279+
*
280+
* @return never {@literal null}.
281+
* @since 4.1
282+
*/
283+
public Optional<Object> getHintObject() {
245284
return hint;
246285
}
247286

@@ -361,7 +400,7 @@ public static class Builder {
361400
private @Nullable Document cursor;
362401
private @Nullable Collation collation;
363402
private @Nullable String comment;
364-
private @Nullable Document hint;
403+
private @Nullable Object hint;
365404
private @Nullable Duration maxTime;
366405
private @Nullable ResultOptions resultOptions;
367406
private @Nullable DomainTypeMapping domainTypeMapping;
@@ -454,6 +493,19 @@ public Builder hint(@Nullable Document hint) {
454493
return this;
455494
}
456495

496+
/**
497+
* Define a hint that is used by query optimizer to to fulfill the aggregation.
498+
*
499+
* @param indexName can be {@literal null}.
500+
* @return this.
501+
* @since 4.1
502+
*/
503+
public Builder hint(@Nullable String indexName) {
504+
505+
this.hint = indexName;
506+
return this;
507+
}
508+
457509
/**
458510
* Set the time limit for processing.
459511
*

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java

+12
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ void beforeEach() {
204204
when(aggregateIterable.map(any())).thenReturn(aggregateIterable);
205205
when(aggregateIterable.maxTime(anyLong(), any())).thenReturn(aggregateIterable);
206206
when(aggregateIterable.into(any())).thenReturn(Collections.emptyList());
207+
when(aggregateIterable.hint(any())).thenReturn(aggregateIterable);
208+
when(aggregateIterable.hintString(any())).thenReturn(aggregateIterable);
207209
when(distinctIterable.collation(any())).thenReturn(distinctIterable);
208210
when(distinctIterable.map(any())).thenReturn(distinctIterable);
209211
when(distinctIterable.into(any())).thenReturn(Collections.emptyList());
@@ -497,6 +499,16 @@ void aggregateShouldHonorOptionsHint() {
497499
verify(aggregateIterable).hint(hint);
498500
}
499501

502+
@Test // GH-4238
503+
void aggregateShouldHonorOptionsHintString() {
504+
505+
AggregationOptions options = AggregationOptions.builder().hint("index-1").build();
506+
507+
template.aggregate(newAggregation(Aggregation.unwind("foo")).withOptions(options), "collection-1", Wrapper.class);
508+
509+
verify(aggregateIterable).hintString("index-1");
510+
}
511+
500512
@Test // GH-3542
501513
void aggregateShouldUseRelaxedMappingByDefault() {
502514

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateUnitTests.java

+13
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import lombok.AllArgsConstructor;
2424
import lombok.Data;
2525
import lombok.NoArgsConstructor;
26+
import org.springframework.data.mongodb.core.MongoTemplateUnitTests.Wrapper;
27+
import org.springframework.data.mongodb.core.aggregation.Aggregation;
2628
import reactor.core.publisher.Flux;
2729
import reactor.core.publisher.Mono;
2830
import reactor.test.StepVerifier;
@@ -666,6 +668,17 @@ void aggregateShouldHonorOptionsHint() {
666668
verify(aggregatePublisher).hint(hint);
667669
}
668670

671+
@Test // GH-4238
672+
void aggregateShouldHonorOptionsHintString() {
673+
674+
AggregationOptions options = AggregationOptions.builder().hint("index-1").build();
675+
676+
template.aggregate(newAggregation(Sith.class, project("id")).withOptions(options), AutogenerateableId.class,
677+
Document.class).subscribe();
678+
679+
verify(aggregatePublisher).hintString("index-1");
680+
}
681+
669682
@Test // DATAMONGO-2390
670683
void aggregateShouldNoApplyZeroOrNegativeMaxTime() {
671684

0 commit comments

Comments
 (0)