Skip to content

Commit c8244be

Browse files
provide some tests to verify behaviour
1 parent ada6c3d commit c8244be

File tree

2 files changed

+229
-4
lines changed

2 files changed

+229
-4
lines changed

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/AbstractStringBasedJpaQuery.java

+14-4
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,11 @@ public AbstractStringBasedJpaQuery(JpaQueryMethod method, EntityManager em, Stri
9393
public Query doCreateQuery(JpaParametersParameterAccessor accessor) {
9494

9595
Sort sort = accessor.getSort();
96-
String sortedQueryString = sort.isSorted() ? QueryEnhancerFactory.forQuery(query) //
97-
.applySorting(sort, query.getAlias()) : query.getQueryString();
96+
String sortedQueryString = applySortingIfNecessary(query, sort);
9897

9998
ResultProcessor processor = getQueryMethod().getResultProcessor().withDynamicProjection(accessor);
10099

101-
Query query = createJpaQuery(sortedQueryString, accessor.getSort(), accessor.getPageable(),
102-
processor.getReturnedType());
100+
Query query = createJpaQuery(sortedQueryString, sort, accessor.getPageable(), processor.getReturnedType());
103101

104102
QueryParameterSetter.QueryMetadata metadata = metadataCache.getMetadata(sortedQueryString, query);
105103

@@ -108,6 +106,18 @@ public Query doCreateQuery(JpaParametersParameterAccessor accessor) {
108106
return parameterBinder.get().bindAndPrepare(query, metadata, accessor);
109107
}
110108

109+
private String applySortingIfNecessary(DeclaredQuery query, Sort sort) {
110+
111+
if (sort.isUnsorted()) {
112+
return query.getQueryString();
113+
}
114+
return applySorting(query, sort);
115+
}
116+
117+
protected String applySorting(DeclaredQuery query, Sort sort) {
118+
return QueryEnhancerFactory.forQuery(query).applySorting(sort, query.getAlias());
119+
}
120+
111121
@Override
112122
protected ParameterBinder createBinder() {
113123

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
/*
2+
* Copyright 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.jpa.repository.query;
17+
18+
import static org.mockito.Mockito.*;
19+
20+
import jakarta.persistence.EntityManager;
21+
import jakarta.persistence.metamodel.Metamodel;
22+
23+
import java.lang.reflect.Method;
24+
import java.util.ArrayList;
25+
import java.util.List;
26+
import java.util.function.Supplier;
27+
28+
import org.assertj.core.api.Assertions;
29+
import org.assertj.core.util.Arrays;
30+
import org.junit.jupiter.api.Test;
31+
import org.mockito.Mockito;
32+
import org.springframework.core.annotation.AnnotatedElementUtils;
33+
import org.springframework.data.domain.Pageable;
34+
import org.springframework.data.domain.Sort;
35+
import org.springframework.data.jpa.provider.QueryExtractor;
36+
import org.springframework.data.jpa.repository.Query;
37+
import org.springframework.data.jpa.repository.QueryRewriter;
38+
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
39+
import org.springframework.data.repository.Repository;
40+
import org.springframework.data.repository.core.RepositoryMetadata;
41+
import org.springframework.data.repository.core.support.DefaultRepositoryMetadata;
42+
import org.springframework.data.repository.query.ParametersSource;
43+
import org.springframework.data.repository.query.QueryMethodEvaluationContextProvider;
44+
import org.springframework.data.repository.query.ReturnedType;
45+
import org.springframework.expression.spel.standard.SpelExpressionParser;
46+
import org.springframework.lang.Nullable;
47+
import org.springframework.util.LinkedMultiValueMap;
48+
import org.springframework.util.MultiValueMap;
49+
import org.springframework.util.ReflectionUtils;
50+
51+
/**
52+
* @author Christoph Strobl
53+
*/
54+
public class AbstractStringBasedJpaQueryUnitTests {
55+
56+
@Test // GH-3310
57+
void shouldNotAttemptToAppendSortIfNoSortArgumentPresent() {
58+
59+
InvocationCapturingStringQueryStub stringQuery = forMethod(TestRepo.class, "find");
60+
stringQuery.createQueryWithArguments();
61+
62+
stringQuery.neverCalled("applySorting");
63+
}
64+
65+
@Test // GH-3310
66+
void shouldNotAttemptToAppendSortIfSortIndicatesUnsorted() {
67+
68+
InvocationCapturingStringQueryStub stringQuery = forMethod(TestRepo.class, "find", Sort.class);
69+
stringQuery.createQueryWithArguments(Sort.unsorted());
70+
71+
stringQuery.neverCalled("applySorting");
72+
}
73+
74+
@Test // GH-3310
75+
void shouldAppendSortIfSortPresent() {
76+
77+
InvocationCapturingStringQueryStub stringQuery = forMethod(TestRepo.class, "find", Sort.class);
78+
stringQuery.createQueryWithArguments(Sort.by("name"));
79+
80+
stringQuery.called("applySorting").times(1);
81+
}
82+
83+
interface TestRepo extends Repository<Object, Object> {
84+
85+
@Query("SELECT e FROM Employee e")
86+
Object find();
87+
88+
@Query("SELECT e FROM Employee e")
89+
Object find(Sort sort);
90+
}
91+
92+
static InvocationCapturingStringQueryStub forMethod(Class<?> repository, String method, Class<?>... args) {
93+
94+
Method respositoryMethod = ReflectionUtils.findMethod(repository, method, args);
95+
RepositoryMetadata repositoryMetadata = new DefaultRepositoryMetadata(repository);
96+
SpelAwareProxyProjectionFactory projectionFactory = Mockito.mock(SpelAwareProxyProjectionFactory.class);
97+
QueryExtractor queryExtractor = Mockito.mock(QueryExtractor.class);
98+
JpaQueryMethod queryMethod = new JpaQueryMethod(respositoryMethod, repositoryMetadata, projectionFactory,
99+
queryExtractor);
100+
101+
Query query = AnnotatedElementUtils.getMergedAnnotation(respositoryMethod, Query.class);
102+
103+
return new InvocationCapturingStringQueryStub(respositoryMethod, queryMethod, query.value(), query.countQuery(),
104+
new SpelExpressionParser());
105+
106+
}
107+
108+
static class InvocationCapturingStringQueryStub extends AbstractStringBasedJpaQuery {
109+
110+
private final Method targetMethod;
111+
private final MultiValueMap<String, Arguments> capturedArguments = new LinkedMultiValueMap<>(3);
112+
113+
InvocationCapturingStringQueryStub(Method targetMethod, JpaQueryMethod queryMethod, String queryString,
114+
@Nullable String countQueryString, SpelExpressionParser parser) {
115+
super(queryMethod, new Supplier<EntityManager>() {
116+
117+
@Override
118+
public EntityManager get() {
119+
120+
EntityManager em = Mockito.mock(EntityManager.class);
121+
122+
Metamodel meta = mock(Metamodel.class);
123+
when(em.getMetamodel()).thenReturn(meta);
124+
when(em.getDelegate()).thenReturn(new Object()); // some generic jpa
125+
126+
return em;
127+
}
128+
}.get(), queryString, countQueryString, Mockito.mock(QueryRewriter.class),
129+
Mockito.mock(QueryMethodEvaluationContextProvider.class), parser);
130+
131+
this.targetMethod = targetMethod;
132+
}
133+
134+
@Override
135+
protected String applySorting(DeclaredQuery query, Sort sort) {
136+
137+
captureInvocation("applySorting", query, sort);
138+
139+
return super.applySorting(query, sort);
140+
}
141+
142+
@Override
143+
protected jakarta.persistence.Query createJpaQuery(String queryString, Sort sort, @Nullable Pageable pageable,
144+
ReturnedType returnedType) {
145+
146+
captureInvocation("createJpaQuery", queryString, sort, pageable, returnedType);
147+
148+
jakarta.persistence.Query jpaQuery = super.createJpaQuery(queryString, sort, pageable, returnedType);
149+
return jpaQuery == null ? Mockito.mock(jakarta.persistence.Query.class) : jpaQuery;
150+
}
151+
152+
// --> convenience for tests
153+
154+
JpaParameters getParameters() {
155+
return new JpaParameters(ParametersSource.of(targetMethod));
156+
}
157+
158+
JpaParametersParameterAccessor getParameterAccessor(Object... args) {
159+
return new JpaParametersParameterAccessor(getParameters(), args);
160+
}
161+
162+
jakarta.persistence.Query createQueryWithArguments(Object... args) {
163+
return doCreateQuery(getParameterAccessor(args));
164+
}
165+
166+
// --> capturing methods
167+
168+
private void captureInvocation(String key, Object... args) {
169+
capturedArguments.add(key, new Arguments(args));
170+
}
171+
172+
// --> verification methdos
173+
174+
int getInvocationCount(String method) {
175+
176+
List<Arguments> invocations = capturedArguments.get(method);
177+
return invocations != null ? invocations.size() : 0;
178+
}
179+
180+
public void neverCalled(String method) {
181+
called(method).never();
182+
}
183+
184+
public Times called(String method) {
185+
186+
return (invocationCount -> {
187+
188+
int actualCount = getInvocationCount(method);
189+
Assertions.assertThat(actualCount)
190+
.withFailMessage(
191+
() -> "Expected %d invocations for %s, but recorded %d".formatted(invocationCount, method, actualCount))
192+
.isEqualTo(invocationCount);
193+
});
194+
}
195+
196+
static class Arguments {
197+
198+
List<Object> values = new ArrayList<>(3);
199+
200+
public Arguments(Object... values) {
201+
this.values = Arrays.asList(values);
202+
}
203+
}
204+
205+
interface Times {
206+
207+
void times(int invocationCount);
208+
209+
default void never() {
210+
times(0);
211+
}
212+
}
213+
214+
}
215+
}

0 commit comments

Comments
 (0)