Skip to content

Commit 64b5a22

Browse files
mp911deschauder
authored andcommitted
Differentiate between JPQL and native queries in count query derivation.
We now consider whether a query is a native one when deriving a count query for pagination. Previously, the generated queries used JPQL syntax that doesn't comply with native SQL syntax rules. Closes #2773 Original pull request #2777
1 parent 6dc008f commit 64b5a22

File tree

7 files changed

+217
-24
lines changed

7 files changed

+217
-24
lines changed

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/DefaultQueryEnhancer.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public String detectAlias() {
4646

4747
@Override
4848
public String createCountQueryFor(@Nullable String countProjection) {
49-
return QueryUtils.createCountQueryFor(this.query.getQueryString(), countProjection);
49+
return QueryUtils.createCountQueryFor(this.query.getQueryString(), countProjection, this.query.isNativeQuery());
5050
}
5151

5252
@Override

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JSqlParserQueryEnhancer.java

+17-6
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515
*/
1616
package org.springframework.data.jpa.repository.query;
1717

18-
import static org.springframework.data.jpa.repository.query.JSqlParserUtils.getJSqlCount;
19-
import static org.springframework.data.jpa.repository.query.JSqlParserUtils.getJSqlLower;
20-
import static org.springframework.data.jpa.repository.query.QueryUtils.checkSortExpression;
18+
import static org.springframework.data.jpa.repository.query.JSqlParserUtils.*;
19+
import static org.springframework.data.jpa.repository.query.QueryUtils.*;
2120

2221
import net.sf.jsqlparser.JSQLParserException;
2322
import net.sf.jsqlparser.expression.Alias;
@@ -29,11 +28,23 @@
2928
import net.sf.jsqlparser.statement.delete.Delete;
3029
import net.sf.jsqlparser.statement.insert.Insert;
3130
import net.sf.jsqlparser.statement.merge.Merge;
32-
import net.sf.jsqlparser.statement.select.*;
31+
import net.sf.jsqlparser.statement.select.OrderByElement;
32+
import net.sf.jsqlparser.statement.select.PlainSelect;
33+
import net.sf.jsqlparser.statement.select.Select;
34+
import net.sf.jsqlparser.statement.select.SelectBody;
35+
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
36+
import net.sf.jsqlparser.statement.select.SelectItem;
37+
import net.sf.jsqlparser.statement.select.SetOperationList;
38+
import net.sf.jsqlparser.statement.select.WithItem;
3339
import net.sf.jsqlparser.statement.update.Update;
3440
import net.sf.jsqlparser.statement.values.ValuesStatement;
3541

36-
import java.util.*;
42+
import java.util.ArrayList;
43+
import java.util.Collections;
44+
import java.util.HashSet;
45+
import java.util.List;
46+
import java.util.Objects;
47+
import java.util.Set;
3748
import java.util.stream.Collectors;
3849

3950
import org.springframework.data.domain.Sort;
@@ -400,7 +411,7 @@ public String createCountQueryFor(@Nullable String countProjection) {
400411
return selectBody.toString();
401412
}
402413

403-
String countProp = tableAlias == null ? "*" : tableAlias;
414+
String countProp = query.isNativeQuery() ? (distinct ? "*" : "1") : tableAlias == null ? "*" : tableAlias;
404415

405416
Function jSqlCount = getJSqlCount(Collections.singletonList(countProp), distinct);
406417
selectBody.setSelectItems(Collections.singletonList(new SelectExpressionItem(jSqlCount)));

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/QueryUtils.java

+37-6
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,23 @@
1818
import static jakarta.persistence.metamodel.Attribute.PersistentAttributeType.*;
1919
import static java.util.regex.Pattern.*;
2020

21-
import jakarta.persistence.*;
22-
import jakarta.persistence.criteria.*;
23-
import jakarta.persistence.metamodel.*;
21+
import jakarta.persistence.EntityManager;
22+
import jakarta.persistence.ManyToOne;
23+
import jakarta.persistence.OneToOne;
24+
import jakarta.persistence.Parameter;
25+
import jakarta.persistence.Query;
26+
import jakarta.persistence.criteria.CriteriaBuilder;
27+
import jakarta.persistence.criteria.Expression;
28+
import jakarta.persistence.criteria.Fetch;
29+
import jakarta.persistence.criteria.From;
30+
import jakarta.persistence.criteria.Join;
31+
import jakarta.persistence.criteria.JoinType;
32+
import jakarta.persistence.metamodel.Attribute;
2433
import jakarta.persistence.metamodel.Attribute.PersistentAttributeType;
34+
import jakarta.persistence.metamodel.Bindable;
35+
import jakarta.persistence.metamodel.ManagedType;
36+
import jakarta.persistence.metamodel.PluralAttribute;
37+
import jakarta.persistence.metamodel.SingularAttribute;
2538

2639
import java.lang.annotation.Annotation;
2740
import java.lang.reflect.AnnotatedElement;
@@ -570,6 +583,19 @@ public static String createCountQueryFor(String originalQuery) {
570583
*/
571584
@Deprecated
572585
public static String createCountQueryFor(String originalQuery, @Nullable String countProjection) {
586+
return createCountQueryFor(originalQuery, countProjection, false);
587+
}
588+
589+
/**
590+
* Creates a count projected query from the given original query.
591+
*
592+
* @param originalQuery must not be {@literal null}.
593+
* @param countProjection may be {@literal null}.
594+
* @param nativeQuery whether the underlying query is a native query.
595+
* @return a query String to be used a count query for pagination. Guaranteed to be not {@literal null}.
596+
* @since 2.7.8
597+
*/
598+
static String createCountQueryFor(String originalQuery, @Nullable String countProjection, boolean nativeQuery) {
573599

574600
Assert.hasText(originalQuery, "OriginalQuery must not be null or empty");
575601

@@ -591,9 +617,14 @@ public static String createCountQueryFor(String originalQuery, @Nullable String
591617

592618
String replacement = useVariable ? SIMPLE_COUNT_VALUE : complexCountValue;
593619

594-
String alias = QueryUtils.detectAlias(originalQuery);
595-
if ("*".equals(variable) && alias != null) {
596-
replacement = alias;
620+
if (nativeQuery && (variable.contains(",") || "*".equals(variable))) {
621+
replacement = "1";
622+
} else {
623+
624+
String alias = QueryUtils.detectAlias(originalQuery);
625+
if (("*".equals(variable) && alias != null)) {
626+
replacement = alias;
627+
}
597628
}
598629

599630
countQuery = matcher.replaceFirst(String.format(COUNT_REPLACEMENT_TEMPLATE, replacement));
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright 2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.jpa.repository.query;
17+
18+
/**
19+
* TCK Tests for {@link DefaultQueryEnhancer}.
20+
*
21+
* @author Mark Paluch
22+
*/
23+
public class DefaultQueryEnhancerUnitTests extends QueryEnhancerTckTests {
24+
25+
@Override
26+
QueryEnhancer createQueryEnhancer(DeclaredQuery declaredQuery) {
27+
return new DefaultQueryEnhancer(declaredQuery);
28+
}
29+
30+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright 2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.jpa.repository.query;
17+
18+
/**
19+
* TCK Tests for {@link JSqlParserQueryEnhancer}.
20+
*
21+
* @author Mark Paluch
22+
*/
23+
public class JSqlParserQueryEnhancerUnitTests extends QueryEnhancerTckTests {
24+
25+
@Override
26+
QueryEnhancer createQueryEnhancer(DeclaredQuery declaredQuery) {
27+
return new JSqlParserQueryEnhancer(declaredQuery);
28+
}
29+
30+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright 2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.jpa.repository.query;
17+
18+
import static org.assertj.core.api.Assertions.*;
19+
20+
import java.util.stream.Stream;
21+
22+
import org.junit.jupiter.params.ParameterizedTest;
23+
import org.junit.jupiter.params.provider.Arguments;
24+
import org.junit.jupiter.params.provider.MethodSource;
25+
26+
/**
27+
* TCK Tests for {@link QueryEnhancer}.
28+
*
29+
* @author Mark Paluch
30+
*/
31+
abstract class QueryEnhancerTckTests {
32+
33+
@ParameterizedTest
34+
@MethodSource("nativeCountQueries") // GH-2773
35+
void shouldDeriveNativeCountQuery(String query, String expected) {
36+
37+
DeclaredQuery declaredQuery = DeclaredQuery.of(query, true);
38+
QueryEnhancer enhancer = createQueryEnhancer(declaredQuery);
39+
String countQueryFor = enhancer.createCountQueryFor(null);
40+
41+
assertThat(countQueryFor).isEqualToIgnoringCase(expected);
42+
}
43+
44+
static Stream<Arguments> nativeCountQueries() {
45+
46+
return Stream.of(Arguments.of( //
47+
"SELECT * FROM table_name some_alias", //
48+
"select count(1) FROM table_name some_alias"), //
49+
50+
Arguments.of( //
51+
"SELECT name FROM table_name some_alias", //
52+
"select count(name) FROM table_name some_alias"), //
53+
54+
Arguments.of( //
55+
"SELECT DISTINCT name FROM table_name some_alias", //
56+
"select count(DISTINCT name) FROM table_name some_alias"));
57+
}
58+
59+
@ParameterizedTest // GH-2773
60+
@MethodSource("jpqlCountQueries")
61+
void shouldDeriveJpqlCountQuery(String query, String expected) {
62+
63+
DeclaredQuery declaredQuery = DeclaredQuery.of(query, false);
64+
QueryEnhancer enhancer = createQueryEnhancer(declaredQuery);
65+
String countQueryFor = enhancer.createCountQueryFor(null);
66+
67+
assertThat(countQueryFor).isEqualToIgnoringCase(expected);
68+
}
69+
70+
static Stream<Arguments> jpqlCountQueries() {
71+
72+
return Stream.of(Arguments.of( //
73+
"SELECT some_alias FROM table_name some_alias", //
74+
"select count(some_alias) FROM table_name some_alias"), //
75+
76+
Arguments.of( //
77+
"SELECT name FROM table_name some_alias", //
78+
"select count(name) FROM table_name some_alias"), //
79+
80+
Arguments.of( //
81+
"SELECT DISTINCT name FROM table_name some_alias", //
82+
"select count(DISTINCT name) FROM table_name some_alias"));
83+
}
84+
85+
abstract QueryEnhancer createQueryEnhancer(DeclaredQuery declaredQuery);
86+
87+
}

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/QueryEnhancerUnitTests.java

+15-11
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515
*/
1616
package org.springframework.data.jpa.repository.query;
1717

18-
import static org.assertj.core.api.Assertions.assertThat;
19-
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
20-
import static org.assertj.core.api.Assertions.assertThatThrownBy;
18+
import static org.assertj.core.api.Assertions.*;
2119

2220
import java.util.Arrays;
2321
import java.util.Collections;
@@ -229,7 +227,12 @@ void createsCountQueryForNestedReferenceCorrectly() {
229227

230228
@Test // DATAJPA-420
231229
void createsCountQueryForScalarSelects() {
232-
assertCountQuery("select p.lastname,p.firstname from Person p", "select count(p) from Person p", true);
230+
assertCountQuery("select p.lastname,p.firstname from Person p", "select count(p) from Person p", false);
231+
}
232+
233+
@Test // DATAJPA-420
234+
void createsCountQueryForNativeScalarSelects() {
235+
assertCountQuery("select p.lastname,p.firstname from Person p", "select count(1) from Person p", true);
233236
}
234237

235238
@Test // DATAJPA-456
@@ -490,7 +493,7 @@ void createCountQuerySupportsWhitespaceCharacters() {
490493
" order by user.name\n ", true);
491494

492495
assertThat(getEnhancer(query).createCountQueryFor())
493-
.isEqualToIgnoringCase("select count(user) from User user where user.age = 18");
496+
.isEqualToIgnoringCase("select count(1) from User user where user.age = 18");
494497
}
495498

496499
@Test
@@ -503,7 +506,7 @@ void createCountQuerySupportsLineBreaksInSelectClause() {
503506
" order\nby\nuser.name\n ", true);
504507

505508
assertThat(getEnhancer(query).createCountQueryFor())
506-
.isEqualToIgnoringCase("select count(user) from User user where user.age = 18");
509+
.isEqualToIgnoringCase("select count(1) from User user where user.age = 18");
507510
}
508511

509512
@Test // DATAJPA-1061
@@ -724,17 +727,17 @@ void countQueryUsesCorrectVariable() {
724727

725728
QueryEnhancer queryEnhancer = getEnhancer(nativeQuery);
726729
String countQueryFor = queryEnhancer.createCountQueryFor();
727-
assertThat(countQueryFor).isEqualTo("SELECT count(*) FROM User WHERE created_at > $1");
730+
assertThat(countQueryFor).isEqualTo("SELECT count(1) FROM User WHERE created_at > $1");
728731

729732
nativeQuery = new StringQuery("SELECT * FROM (select * from test) ", true);
730733
queryEnhancer = getEnhancer(nativeQuery);
731734
countQueryFor = queryEnhancer.createCountQueryFor();
732-
assertThat(countQueryFor).isEqualTo("SELECT count(*) FROM (SELECT * FROM test)");
735+
assertThat(countQueryFor).isEqualTo("SELECT count(1) FROM (SELECT * FROM test)");
733736

734737
nativeQuery = new StringQuery("SELECT * FROM (select * from test) as test", true);
735738
queryEnhancer = getEnhancer(nativeQuery);
736739
countQueryFor = queryEnhancer.createCountQueryFor();
737-
assertThat(countQueryFor).isEqualTo("SELECT count(test) FROM (SELECT * FROM test) AS test");
740+
assertThat(countQueryFor).isEqualTo("SELECT count(1) FROM (SELECT * FROM test) AS test");
738741
}
739742

740743
@Test // GH-2555
@@ -864,7 +867,7 @@ void withStatementsWorksWithJSQLParser() {
864867

865868
assertThat(queryEnhancer.createCountQueryFor()).isEqualToIgnoringCase(
866869
"with sample_data (day, value) AS (VALUES ((0, 13), (1, 12), (2, 15), (3, 4), (4, 8), (5, 16)))\n"
867-
+ "SELECT count(a) FROM sample_data AS a");
870+
+ "SELECT count(1) FROM sample_data AS a");
868871
assertThat(queryEnhancer.applySorting(Sort.by("day").descending())).endsWith("ORDER BY a.day DESC");
869872
assertThat(queryEnhancer.getJoinAliases()).isEmpty();
870873
assertThat(queryEnhancer.detectAlias()).isEqualToIgnoringCase("a");
@@ -887,7 +890,7 @@ void multipleWithStatementsWorksWithJSQLParser() {
887890

888891
assertThat(queryEnhancer.createCountQueryFor()).isEqualToIgnoringCase(
889892
"with sample_data (day, value) AS (VALUES ((0, 13), (1, 12), (2, 15), (3, 4), (4, 8), (5, 16))),test2 AS (VALUES (1, 2, 3))\n"
890-
+ "SELECT count(a) FROM sample_data AS a");
893+
+ "SELECT count(1) FROM sample_data AS a");
891894
assertThat(queryEnhancer.applySorting(Sort.by("day").descending())).endsWith("ORDER BY a.day DESC");
892895
assertThat(queryEnhancer.getJoinAliases()).isEmpty();
893896
assertThat(queryEnhancer.detectAlias()).isEqualToIgnoringCase("a");
@@ -985,4 +988,5 @@ private static void assertCountQuery(StringQuery originalQuery, String countQuer
985988
private static QueryEnhancer getEnhancer(DeclaredQuery query) {
986989
return QueryEnhancerFactory.forQuery(query);
987990
}
991+
988992
}

0 commit comments

Comments
 (0)