Skip to content

Commit a5cf01b

Browse files
DiegoKrupitzagregturn
authored andcommitted
Fixes * bug in createCountQueryFor.
In commit 3e64d9a a bug got introduced that uses the next symbol after the table name for the count function. With this commit this should be now resolved. The count query will use `*` when there is no alias present nor a variable. Related tickets #2341, #2177, #2260, #2511
1 parent 7fe72db commit a5cf01b

File tree

3 files changed

+44
-2
lines changed

3 files changed

+44
-2
lines changed

src/main/java/org/springframework/data/jpa/repository/query/QueryUtils.java

+7-2
Original file line numberDiff line numberDiff line change
@@ -525,14 +525,19 @@ public static String createCountQueryFor(String originalQuery, @Nullable String
525525
boolean useVariable = StringUtils.hasText(variable) //
526526
&& !variable.startsWith(" new") //
527527
&& !variable.startsWith("count(") //
528-
&& !variable.contains(",") //
529-
&& !variable.contains("*");
528+
&& !variable.contains(",");
530529

531530
String complexCountValue = matcher.matches() && StringUtils.hasText(matcher.group(COMPLEX_COUNT_FIRST_INDEX))
532531
? COMPLEX_COUNT_VALUE
533532
: COMPLEX_COUNT_LAST_VALUE;
534533

535534
String replacement = useVariable ? SIMPLE_COUNT_VALUE : complexCountValue;
535+
536+
String alias = QueryUtils.detectAlias(originalQuery);
537+
if("*".equals(variable) && alias != null) {
538+
replacement = alias;
539+
}
540+
536541
countQuery = matcher.replaceFirst(String.format(COUNT_REPLACEMENT_TEMPLATE, replacement));
537542
} else {
538543
countQuery = matcher.replaceFirst(String.format(COUNT_REPLACEMENT_TEMPLATE, countProjection));

src/test/java/org/springframework/data/jpa/repository/query/QueryEnhancerUnitTests.java

+20
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,26 @@ void correctApplySortOnComplexNestedFunctionQuery() {
712712
assertThat(result).containsIgnoringCase("order by dd.institutesIds");
713713
}
714714

715+
716+
@Test //GH-2511
717+
void countQueryUsesCorrectVariable() {
718+
StringQuery nativeQuery = new StringQuery("SELECT * FROM User WHERE created_at > $1", true);
719+
QueryEnhancer queryEnhancer = getEnhancer(nativeQuery);
720+
String countQueryFor = queryEnhancer.createCountQueryFor();
721+
assertThat(countQueryFor).isEqualTo("SELECT count(*) FROM User WHERE created_at > $1");
722+
723+
nativeQuery = new StringQuery("SELECT * FROM (select * from test) ",true);
724+
queryEnhancer = getEnhancer(nativeQuery);
725+
countQueryFor = queryEnhancer.createCountQueryFor();
726+
assertThat(countQueryFor).isEqualTo("SELECT count(*) FROM (SELECT * FROM test)");
727+
728+
nativeQuery = new StringQuery("SELECT * FROM (select * from test) as test",true);
729+
queryEnhancer = getEnhancer(nativeQuery);
730+
countQueryFor = queryEnhancer.createCountQueryFor();
731+
assertThat(countQueryFor).isEqualTo("SELECT count(test) FROM (SELECT * FROM test) AS test");
732+
}
733+
734+
715735
public static Stream<Arguments> detectsJoinAliasesCorrectlySource() {
716736

717737
return Stream.of( //

src/test/java/org/springframework/data/jpa/repository/query/QueryUtilsUnitTests.java

+17
Original file line numberDiff line numberDiff line change
@@ -638,4 +638,21 @@ void applySortingAccountsForNativeWindowFunction() {
638638
"select * from (select * from user order by 1, 2, 3 desc limit 10) u order by u.active asc, age desc");
639639
}
640640

641+
@Test //GH-2511
642+
void countQueryUsesCorrectVariable() {
643+
String countQueryFor = createCountQueryFor("SELECT * FROM User WHERE created_at > $1");
644+
assertThat(countQueryFor).isEqualTo("select count(*) FROM User WHERE created_at > $1");
645+
646+
countQueryFor = createCountQueryFor("SELECT * FROM mytable WHERE nr = :number AND kon = :kon AND datum >= '2019-01-01'");
647+
assertThat(countQueryFor).isEqualTo("select count(*) FROM mytable WHERE nr = :number AND kon = :kon AND datum >= '2019-01-01'");
648+
649+
countQueryFor = createCountQueryFor("SELECT * FROM context ORDER BY time");
650+
assertThat(countQueryFor).isEqualTo("select count(*) FROM context");
651+
652+
countQueryFor = createCountQueryFor("select * FROM users_statuses WHERE (user_created_at BETWEEN $1 AND $2)");
653+
assertThat(countQueryFor).isEqualTo("select count(*) FROM users_statuses WHERE (user_created_at BETWEEN $1 AND $2)");
654+
655+
countQueryFor = createCountQueryFor("SELECT * FROM users_statuses us WHERE (user_created_at BETWEEN :fromDate AND :toDate)");
656+
assertThat(countQueryFor).isEqualTo("select count(us) FROM users_statuses us WHERE (user_created_at BETWEEN :fromDate AND :toDate)");
657+
}
641658
}

0 commit comments

Comments
 (0)