Skip to content

Commit 75904a9

Browse files
committed
Attempt two-pass parsing in SLL prediction and fall back to LL prediction.
Due to grammar ambiguities, we fall back to LL prediction considering contextual ambiguity resolution. Closes spring-projects#3757
1 parent 3dfd0c7 commit 75904a9

File tree

9 files changed

+70
-77
lines changed

9 files changed

+70
-77
lines changed

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryEnhancer.java

+31-4
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.function.BiFunction;
2121
import java.util.function.Function;
2222

23+
import org.antlr.v4.runtime.BailErrorStrategy;
2324
import org.antlr.v4.runtime.CharStream;
2425
import org.antlr.v4.runtime.CharStreams;
2526
import org.antlr.v4.runtime.CommonTokenStream;
@@ -28,7 +29,9 @@
2829
import org.antlr.v4.runtime.ParserRuleContext;
2930
import org.antlr.v4.runtime.TokenStream;
3031
import org.antlr.v4.runtime.atn.PredictionMode;
32+
import org.antlr.v4.runtime.misc.ParseCancellationException;
3133
import org.antlr.v4.runtime.tree.ParseTreeVisitor;
34+
3235
import org.springframework.data.domain.Sort;
3336
import org.springframework.data.repository.query.ReturnedType;
3437
import org.springframework.lang.Nullable;
@@ -68,9 +71,35 @@ class JpaQueryEnhancer<Q extends QueryInformation> implements QueryEnhancer {
6871
this.projection = tokens.isEmpty() ? "" : new QueryRenderer.TokenRenderer(tokens).render();
6972
}
7073

74+
/**
75+
* Parse the query and return the parser context (AST). This method attempts parsing the query using
76+
* {@link PredictionMode#SLL} first to attempt a fast-path parse without using the context. If that fails, it retries
77+
* using {@link PredictionMode#LL} which is much slower, however it allows for contextual ambiguity resolution.
78+
*/
7179
static <P extends Parser> ParserRuleContext parse(String query, Function<CharStream, Lexer> lexerFactoryFunction,
7280
Function<TokenStream, P> parserFactoryFunction, Function<P, ParserRuleContext> parseFunction) {
7381

82+
P parser = getParser(query, lexerFactoryFunction, parserFactoryFunction);
83+
84+
parser.getInterpreter().setPredictionMode(PredictionMode.SLL);
85+
parser.setErrorHandler(new BailErrorStrategy());
86+
87+
try {
88+
89+
return parseFunction.apply(parser);
90+
} catch (BadJpqlGrammarException | ParseCancellationException e) {
91+
92+
parser = getParser(query, lexerFactoryFunction, parserFactoryFunction);
93+
// fall back to LL(*)-based parsing
94+
parser.getInterpreter().setPredictionMode(PredictionMode.LL);
95+
96+
return parseFunction.apply(parser);
97+
}
98+
}
99+
100+
private static <P extends Parser> P getParser(String query, Function<CharStream, Lexer> lexerFactoryFunction,
101+
Function<TokenStream, P> parserFactoryFunction) {
102+
74103
Lexer lexer = lexerFactoryFunction.apply(CharStreams.fromString(query));
75104
P parser = parserFactoryFunction.apply(new CommonTokenStream(lexer));
76105

@@ -82,11 +111,11 @@ static <P extends Parser> ParserRuleContext parse(String query, Function<CharStr
82111

83112
configureParser(query, grammar.toUpperCase(), lexer, parser);
84113

85-
return parseFunction.apply(parser);
114+
return parser;
86115
}
87116

88117
/**
89-
* Apply common configuration (SLL prediction for performance, our own error listeners).
118+
* Apply common configuration.
90119
*
91120
* @param query the query input to parse.
92121
* @param grammar name of the grammar.
@@ -100,8 +129,6 @@ static void configureParser(String query, String grammar, Lexer lexer, Parser pa
100129
lexer.removeErrorListeners();
101130
lexer.addErrorListener(errorListener);
102131

103-
parser.getInterpreter().setPredictionMode(PredictionMode.SLL);
104-
105132
parser.removeErrorListeners();
106133
parser.addErrorListener(errorListener);
107134
}

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/EqlComplianceTests.java

+3-9
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717

1818
import static org.assertj.core.api.Assertions.*;
1919

20-
import org.antlr.v4.runtime.CharStreams;
21-
import org.antlr.v4.runtime.CommonTokenStream;
2220
import org.junit.jupiter.api.Test;
21+
2322
import org.springframework.data.jpa.repository.query.QueryRenderer.TokenRenderer;
2423

2524
/**
@@ -40,14 +39,9 @@ class EqlComplianceTests {
4039
*/
4140
private static String parseWithoutChanges(String query) {
4241

43-
EqlLexer lexer = new EqlLexer(CharStreams.fromString(query));
44-
EqlParser parser = new EqlParser(new CommonTokenStream(lexer));
45-
46-
parser.addErrorListener(new BadJpqlGrammarErrorListener(query));
47-
48-
EqlParser.StartContext parsedQuery = parser.start();
42+
JpaQueryEnhancer.EqlQueryParser parser = JpaQueryEnhancer.EqlQueryParser.parseQuery(query);
4943

50-
return TokenRenderer.render(new EqlQueryRenderer().visit(parsedQuery));
44+
return TokenRenderer.render(new EqlQueryRenderer().visit(parser.getContext()));
5145
}
5246

5347
private void assertQuery(String query) {

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/EqlQueryRendererTests.java

+3-9
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@
1919

2020
import java.util.stream.Stream;
2121

22-
import org.antlr.v4.runtime.CharStreams;
23-
import org.antlr.v4.runtime.CommonTokenStream;
2422
import org.junit.jupiter.api.Disabled;
2523
import org.junit.jupiter.api.Test;
2624
import org.junit.jupiter.params.ParameterizedTest;
2725
import org.junit.jupiter.params.provider.Arguments;
2826
import org.junit.jupiter.params.provider.MethodSource;
2927
import org.junit.jupiter.params.provider.ValueSource;
28+
3029
import org.springframework.data.jpa.repository.query.QueryRenderer.TokenRenderer;
3130

3231
/**
@@ -47,14 +46,9 @@ class EqlQueryRendererTests {
4746
*/
4847
private static String parseWithoutChanges(String query) {
4948

50-
EqlLexer lexer = new EqlLexer(CharStreams.fromString(query));
51-
EqlParser parser = new EqlParser(new CommonTokenStream(lexer));
52-
53-
parser.addErrorListener(new BadJpqlGrammarErrorListener(query));
54-
55-
EqlParser.StartContext parsedQuery = parser.start();
49+
JpaQueryEnhancer.EqlQueryParser parser = JpaQueryEnhancer.EqlQueryParser.parseQuery(query);
5650

57-
return TokenRenderer.render(new EqlQueryRenderer().visit(parsedQuery));
51+
return TokenRenderer.render(new EqlQueryRenderer().visit(parser.getContext()));
5852
}
5953

6054
static Stream<Arguments> reservedWords() {

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/EqlSpecificationTests.java

+3-9
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,9 @@
1717

1818
import static org.assertj.core.api.Assertions.*;
1919

20-
import org.antlr.v4.runtime.CharStreams;
21-
import org.antlr.v4.runtime.CommonTokenStream;
2220
import org.junit.jupiter.api.Disabled;
2321
import org.junit.jupiter.api.Test;
22+
2423
import org.springframework.data.jpa.repository.query.QueryRenderer.TokenRenderer;
2524

2625
/**
@@ -37,14 +36,9 @@ class EqlSpecificationTests {
3736

3837
private static String parseWithoutChanges(String query) {
3938

40-
EqlLexer lexer = new EqlLexer(CharStreams.fromString(query));
41-
EqlParser parser = new EqlParser(new CommonTokenStream(lexer));
42-
43-
parser.addErrorListener(new BadJpqlGrammarErrorListener(query));
44-
45-
EqlParser.StartContext parsedQuery = parser.start();
39+
JpaQueryEnhancer.EqlQueryParser parser = JpaQueryEnhancer.EqlQueryParser.parseQuery(query);
4640

47-
return TokenRenderer.render(new EqlQueryRenderer().visit(parsedQuery));
41+
return TokenRenderer.render(new EqlQueryRenderer().visit(parser.getContext()));
4842
}
4943

5044
private void assertQuery(String query) {

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/HqlQueryRendererTests.java

+18-9
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919

2020
import java.util.stream.Stream;
2121

22-
import org.antlr.v4.runtime.CharStreams;
23-
import org.antlr.v4.runtime.CommonTokenStream;
2422
import org.junit.jupiter.api.Disabled;
2523
import org.junit.jupiter.api.Test;
2624
import org.junit.jupiter.params.ParameterizedTest;
@@ -50,14 +48,9 @@ class HqlQueryRendererTests {
5048
*/
5149
private static String parseWithoutChanges(String query) {
5250

53-
HqlLexer lexer = new HqlLexer(CharStreams.fromString(query));
54-
HqlParser parser = new HqlParser(new CommonTokenStream(lexer));
51+
JpaQueryEnhancer.HqlQueryParser parser = JpaQueryEnhancer.HqlQueryParser.parseQuery(query);
5552

56-
parser.addErrorListener(new BadJpqlGrammarErrorListener(query));
57-
58-
HqlParser.StartContext parsedQuery = parser.start();
59-
60-
QueryTokenStream tokens = new HqlQueryRenderer().visit(parsedQuery);
53+
QueryTokenStream tokens = new HqlQueryRenderer().visit(parser.getContext());
6154
return QueryRenderer.from(tokens).render();
6255
}
6356

@@ -1891,6 +1884,22 @@ group by extract(epoch from departureTime)
18911884
""");
18921885
}
18931886

1887+
@Test // GH-3757
1888+
void arithmeticDate() {
1889+
1890+
assertQuery("SELECT a FROM foo a WHERE (cast(a.createdAt as date) - CURRENT_DATE()) BY day - 2 = 0");
1891+
assertQuery("SELECT a FROM foo a WHERE (cast(a.createdAt as date) - CURRENT_DATE()) BY day - 2 = 0");
1892+
assertQuery("SELECT a FROM foo a WHERE (cast(a.createdAt as date)) BY day - 2 = 0");
1893+
1894+
assertQuery("SELECT f.start - 1 minute FROM foo f");
1895+
1896+
assertQuery("SELECT f FROM foo f WHERE (cast(f.start as date) - CURRENT_DATE()) BY day - 2 = 0");
1897+
assertQuery("SELECT 1 week - 1 day FROM foo f");
1898+
assertQuery("SELECT f.birthday - local date day FROM foo f");
1899+
assertQuery("SELECT local datetime - f.birthday FROM foo f");
1900+
assertQuery("SELECT (1 year) by day FROM foo f");
1901+
}
1902+
18941903
@ParameterizedTest // GH-3342
18951904
@ValueSource(
18961905
strings = { "select 1 from User", "select -1 from User", "select +1 from User", "select +1 * -100 from User",

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/HqlSpecificationTests.java

+4-10
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@
1717

1818
import static org.assertj.core.api.Assertions.*;
1919

20-
import org.antlr.v4.runtime.CharStreams;
21-
import org.antlr.v4.runtime.CommonTokenStream;
2220
import org.junit.jupiter.api.Disabled;
2321
import org.junit.jupiter.api.Test;
2422
import org.junit.jupiter.params.ParameterizedTest;
2523
import org.junit.jupiter.params.provider.ValueSource;
24+
2625
import org.springframework.data.jpa.repository.query.QueryRenderer.TokenRenderer;
2726

2827
/**
@@ -43,14 +42,9 @@ class HqlSpecificationTests {
4342

4443
private static String parseWithoutChanges(String query) {
4544

46-
HqlLexer lexer = new HqlLexer(CharStreams.fromString(query));
47-
HqlParser parser = new HqlParser(new CommonTokenStream(lexer));
48-
49-
parser.addErrorListener(new BadJpqlGrammarErrorListener(query));
45+
JpaQueryEnhancer.HqlQueryParser parser = JpaQueryEnhancer.HqlQueryParser.parseQuery(query);
5046

51-
HqlParser.StartContext parsedQuery = parser.start();
52-
53-
return TokenRenderer.render(new HqlQueryRenderer().visit(parsedQuery));
47+
return TokenRenderer.render(new HqlQueryRenderer().visit(parser.getContext()));
5448
}
5549

5650
private void assertQuery(String query) {
@@ -490,7 +484,7 @@ void position() {
490484
"from Call c ");
491485

492486
assertQuery("select POSITION(c.number IN 'foo') + 1 AS pos " + //
493-
"from Call c ");
487+
"from Call c ");
494488
}
495489

496490
@Test // GH-3689

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JpqlComplianceTests.java

+2-9
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717

1818
import static org.assertj.core.api.Assertions.*;
1919

20-
import org.antlr.v4.runtime.CharStreams;
21-
import org.antlr.v4.runtime.CommonTokenStream;
2220
import org.junit.jupiter.api.Test;
2321

2422
/**
@@ -32,14 +30,9 @@ class JpqlComplianceTests {
3230

3331
private static String parseWithoutChanges(String query) {
3432

35-
JpqlLexer lexer = new JpqlLexer(CharStreams.fromString(query));
36-
JpqlParser parser = new JpqlParser(new CommonTokenStream(lexer));
33+
JpaQueryEnhancer.JpqlQueryParser parser = JpaQueryEnhancer.JpqlQueryParser.parseQuery(query);
3734

38-
parser.addErrorListener(new BadJpqlGrammarErrorListener(query));
39-
40-
JpqlParser.StartContext parsedQuery = parser.start();
41-
42-
return QueryRenderer.render(new JpqlQueryRenderer().visit(parsedQuery));
35+
return QueryRenderer.render(new JpqlQueryRenderer().visit(parser.getContext()));
4336
}
4437

4538
private void assertQuery(String query) {

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JpqlQueryRendererTests.java

+3-9
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@
1919

2020
import java.util.stream.Stream;
2121

22-
import org.antlr.v4.runtime.CharStreams;
23-
import org.antlr.v4.runtime.CommonTokenStream;
2422
import org.junit.jupiter.api.Disabled;
2523
import org.junit.jupiter.api.Test;
2624
import org.junit.jupiter.params.ParameterizedTest;
2725
import org.junit.jupiter.params.provider.Arguments;
2826
import org.junit.jupiter.params.provider.MethodSource;
2927
import org.junit.jupiter.params.provider.ValueSource;
28+
3029
import org.springframework.data.jpa.repository.query.QueryRenderer.TokenRenderer;
3130

3231
/**
@@ -48,14 +47,9 @@ class JpqlQueryRendererTests {
4847
*/
4948
private static String parseWithoutChanges(String query) {
5049

51-
JpqlLexer lexer = new JpqlLexer(CharStreams.fromString(query));
52-
JpqlParser parser = new JpqlParser(new CommonTokenStream(lexer));
53-
54-
parser.addErrorListener(new BadJpqlGrammarErrorListener(query));
55-
56-
JpqlParser.StartContext parsedQuery = parser.start();
50+
JpaQueryEnhancer.JpqlQueryParser parser = JpaQueryEnhancer.JpqlQueryParser.parseQuery(query);
5751

58-
return TokenRenderer.render(new JpqlQueryRenderer().visit(parsedQuery));
52+
return TokenRenderer.render(new JpqlQueryRenderer().visit(parser.getContext()));
5953
}
6054

6155
static Stream<Arguments> reservedWords() {

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JpqlSpecificationTests.java

+3-9
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,9 @@
1717

1818
import static org.assertj.core.api.Assertions.*;
1919

20-
import org.antlr.v4.runtime.CharStreams;
21-
import org.antlr.v4.runtime.CommonTokenStream;
2220
import org.junit.jupiter.api.Disabled;
2321
import org.junit.jupiter.api.Test;
22+
2423
import org.springframework.data.jpa.repository.query.QueryRenderer.TokenRenderer;
2524

2625
/**
@@ -41,14 +40,9 @@ class JpqlSpecificationTests {
4140
*/
4241
private static String parseWithoutChanges(String query) {
4342

44-
JpqlLexer lexer = new JpqlLexer(CharStreams.fromString(query));
45-
JpqlParser parser = new JpqlParser(new CommonTokenStream(lexer));
46-
47-
parser.addErrorListener(new BadJpqlGrammarErrorListener(query));
48-
49-
JpqlParser.StartContext parsedQuery = parser.start();
43+
JpaQueryEnhancer.JpqlQueryParser parser = JpaQueryEnhancer.JpqlQueryParser.parseQuery(query);
5044

51-
return TokenRenderer.render(new JpqlQueryRenderer().visit(parsedQuery));
45+
return TokenRenderer.render(new JpqlQueryRenderer().visit(parser.getContext()));
5246
}
5347

5448
private void assertQuery(String query) {

0 commit comments

Comments
 (0)