Skip to content

Commit 7f9b349

Browse files
toverdijk authored and mp911de committed
Extend PostgresqlSqlLexer to handle PG14 SQL-standard function body syntax
- Lexing/parsing is now done in two steps: first only tokenize, then parse into statements
- Added support for function bodies ("BEGIN ATOMIC")
- Added a test case for newly supported grammar

[resolves #512][#513]
1 parent c74bd6e commit 7f9b349

File tree

5 files changed

+108
-73
lines changed

5 files changed

+108
-73
lines changed

src/main/java/io/r2dbc/postgresql/ParsedSql.java

+10-22
Original file line number | Diff line number | Diff line change
@@ -24,20 +24,20 @@ class ParsedSql {
2424

2525
private final String sql;
2626

27-
private final List<TokenizedStatement> statements;
27+
private final List<Statement> statements;
2828

2929
private final int statementCount;
3030

3131
private final int parameterCount;
3232

33-
public ParsedSql(String sql, List<TokenizedStatement> statements) {
33+
public ParsedSql(String sql, List<Statement> statements) {
3434
this.sql = sql;
3535
this.statements = statements;
3636
this.statementCount = statements.size();
3737
this.parameterCount = getParameterCount(statements);
3838
}
3939

40-
List<TokenizedStatement> getStatements() {
40+
List<Statement> getStatements() {
4141
return this.statements;
4242
}
4343

@@ -53,16 +53,16 @@ public String getSql() {
5353
return sql;
5454
}
5555

56-
private static int getParameterCount(List<TokenizedStatement> statements) {
56+
private static int getParameterCount(List<Statement> statements) {
5757
int sum = 0;
58-
for (TokenizedStatement statement : statements){
58+
for (Statement statement : statements){
5959
sum += statement.getParameterCount();
6060
}
6161
return sum;
6262
}
6363

6464
public boolean hasDefaultTokenValue(String... tokenValues) {
65-
for (TokenizedStatement statement : this.statements) {
65+
for (Statement statement : this.statements) {
6666
for (Token token : statement.getTokens()) {
6767
if (token.getType() == TokenType.DEFAULT) {
6868
for (String value : tokenValues) {
@@ -129,24 +129,17 @@ public String toString() {
129129

130130
}
131131

132-
static class TokenizedStatement {
133-
134-
private final String sql;
132+
static class Statement {
135133

136134
private final List<Token> tokens;
137135

138136
private final int parameterCount;
139137

140-
public TokenizedStatement(String sql, List<Token> tokens) {
138+
public Statement(List<Token> tokens) {
141139
this.tokens = tokens;
142-
this.sql = sql;
143140
this.parameterCount = readParameterCount(tokens);
144141
}
145142

146-
public String getSql() {
147-
return this.sql;
148-
}
149-
150143
public List<Token> getTokens() {
151144
return this.tokens;
152145
}
@@ -164,19 +157,14 @@ public boolean equals(Object o) {
164157
return false;
165158
}
166159

167-
TokenizedStatement that = (TokenizedStatement) o;
160+
Statement that = (Statement) o;
168161

169-
if (!this.sql.equals(that.sql)) {
170-
return false;
171-
}
172162
return this.tokens.equals(that.tokens);
173163
}
174164

175165
@Override
176166
public int hashCode() {
177-
int result = this.sql.hashCode();
178-
result = 31 * result + this.tokens.hashCode();
179-
return result;
167+
return this.tokens.hashCode();
180168
}
181169

182170
@Override

src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java

+1-1
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ final class PostgresqlBatch implements io.r2dbc.postgresql.api.PostgresqlBatch {
4040
public PostgresqlBatch add(String sql) {
4141
Assert.requireNonNull(sql, "sql must not be null");
4242

43-
if (!(PostgresqlSqlParser.tokenize(sql).getParameterCount() == 0)) {
43+
if (!(PostgresqlSqlParser.parse(sql).getParameterCount() == 0)) {
4444
throw new IllegalArgumentException(String.format("Statement '%s' is not supported. This is often due to the presence of parameters.", sql));
4545
}
4646

src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java

+41-16
Original file line number | Diff line number | Diff line change
@@ -38,11 +38,8 @@ class PostgresqlSqlParser {
3838
Arrays.sort(SPECIAL_AND_OPERATOR_CHARS);
3939
}
4040

41-
public static ParsedSql tokenize(String sql) {
41+
private static List<ParsedSql.Token> tokenize(String sql) {
4242
List<ParsedSql.Token> tokens = new ArrayList<>();
43-
List<ParsedSql.TokenizedStatement> statements = new ArrayList<>();
44-
45-
int statementStartIndex = 0;
4643
int i = 0;
4744
while (i < sql.length()) {
4845
char c = sql.charAt(i);
@@ -87,21 +84,48 @@ public static ParsedSql tokenize(String sql) {
8784
}
8885

8986
i += token.getValue().length();
87+
tokens.add(token);
88+
}
89+
return tokens;
90+
}
9091

91-
if (token.getType() == ParsedSql.TokenType.STATEMENT_END) {
92+
public static ParsedSql parse(String sql) {
93+
List<ParsedSql.Token> tokens = tokenize(sql);
94+
List<ParsedSql.Statement> statements = new ArrayList<>();
95+
List<Boolean> functionBodyList = new ArrayList<>();
9296

93-
tokens.add(token);
94-
statements.add(new ParsedSql.TokenizedStatement(sql.substring(statementStartIndex, i), tokens));
97+
List<ParsedSql.Token> currentStatementTokens = new ArrayList<>();
98+
for (int i = 0; i < tokens.size(); i++) {
99+
ParsedSql.Token current = tokens.get(i);
100+
currentStatementTokens.add(current);
95101

96-
tokens = new ArrayList<>();
97-
statementStartIndex = i + 1;
98-
} else {
99-
tokens.add(token);
102+
if (current.getType() == ParsedSql.TokenType.DEFAULT) {
103+
String currentValue = current.getValue();
104+
105+
if (currentValue.equalsIgnoreCase("BEGIN")) {
106+
if (i + 1 < tokens.size() && tokens.get(i + 1).getValue().equalsIgnoreCase("ATOMIC")) {
107+
functionBodyList.add(true);
108+
} else {
109+
functionBodyList.add(false);
110+
}
111+
} else if (currentValue.equalsIgnoreCase("END") && !functionBodyList.isEmpty()) {
112+
functionBodyList.remove(functionBodyList.size() - 1);
113+
}
114+
} else if (current.getType().equals(ParsedSql.TokenType.STATEMENT_END)) {
115+
boolean inFunctionBody = false;
116+
117+
for (boolean b : functionBodyList) {
118+
inFunctionBody |= b;
119+
}
120+
if (!inFunctionBody) {
121+
statements.add(new ParsedSql.Statement(currentStatementTokens));
122+
currentStatementTokens = new ArrayList<>();
123+
}
100124
}
101125
}
102-
// If tokens is not empty, implicit statement end
103-
if (!tokens.isEmpty()) {
104-
statements.add(new ParsedSql.TokenizedStatement(sql.substring(statementStartIndex), tokens));
126+
127+
if (!currentStatementTokens.isEmpty()) {
128+
statements.add(new ParsedSql.Statement(currentStatementTokens));
105129
}
106130

107131
return new ParsedSql(sql, statements);
@@ -209,12 +233,13 @@ private static ParsedSql.Token getQuotedIdentifierToken(String sql, int beginInd
209233
}
210234
}
211235

212-
private static boolean isAsciiLetter(char c){
236+
private static boolean isAsciiLetter(char c) {
213237
char lower = Character.toLowerCase(c);
214238
return lower >= 'a' && lower <= 'z';
215239
}
216240

217-
private static boolean isAsciiDigit(char c){
241+
private static boolean isAsciiDigit(char c) {
218242
return c >= '0' && c <= '9';
219243
}
244+
220245
}

src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java

+1-1
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,7 @@ final class PostgresqlStatement implements io.r2dbc.postgresql.api.PostgresqlSta
7373

7474
PostgresqlStatement(ConnectionResources resources, String sql) {
7575
this.resources = Assert.requireNonNull(resources, "resources must not be null");
76-
this.parsedSql = PostgresqlSqlParser.tokenize(Assert.requireNonNull(sql, "sql must not be null"));
76+
this.parsedSql = PostgresqlSqlParser.parse(Assert.requireNonNull(sql, "sql must not be null"));
7777
this.connectionContext = resources.getClient().getContext();
7878
this.bindings = new ArrayDeque<>(this.parsedSql.getParameterCount());
7979

src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java

+55-33
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
import java.util.List;
2424

2525
import static org.junit.jupiter.api.Assertions.assertEquals;
26+
import static org.junit.jupiter.api.Assertions.assertIterableEquals;
2627
import static org.junit.jupiter.api.Assertions.assertThrows;
2728

2829
class PostgresqlSqlParserTest {
@@ -116,48 +117,49 @@ class SingleTokenExceptionTests {
116117

117118
@Test
118119
void unclosedSingleQuotedStringThrowsIllegalArgumentException() {
119-
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("'test"));
120+
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("'test"));
120121
}
121122

122123
@Test
123124
void unclosedDollarQuotedStringThrowsIllegalArgumentException() {
124-
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$$test"));
125+
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$$test"));
125126
}
126127

127128
@Test
128129
void unclosedTaggedDollarQuotedStringThrowsIllegalArgumentException() {
129-
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$abc$test"));
130+
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$abc$test"));
130131
}
131132

132133
@Test
133134
void unclosedQuotedIdentifierThrowsIllegalArgumentException() {
134-
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("\"test"));
135+
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("\"test"));
135136
}
136137

137138
@Test
138139
void unclosedBlockCommentThrowsIllegalArgumentException() {
139-
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("/*test"));
140+
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("/*test"));
140141
}
141142

142143
@Test
143144
void unclosedNestedBlockCommentThrowsIllegalArgumentException() {
144-
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("/*/*test*/"));
145+
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("/*/*test*/"));
145146
}
146147

147148
@Test
148149
void invalidParameterCharacterThrowsIllegalArgumentException() {
149-
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$1test"));
150+
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$1test"));
150151
}
151152

152153
@Test
153154
void invalidTaggedDollarQuoteThrowsIllegalArgumentException() {
154-
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$a b$test$a b$"));
155+
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$a b$test$a b$"));
155156
}
156157

157158
@Test
158159
void unclosedTaggedDollarQuoteThrowsIllegalArgumentException() {
159-
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$abc"));
160+
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$abc"));
160161
}
162+
161163
}
162164

163165
@Nested
@@ -242,13 +244,33 @@ void simpleSelectStatementIsTokenized() {
242244
);
243245
}
244246

247+
@Test
248+
void simpleSelectStatementWithFunctionBodyIsTokenized() {
249+
assertSingleStatementEquals("CREATE FUNCTION test() BEGIN ATOMIC SELECT 1; SELECT 2; END",
250+
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "CREATE"),
251+
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "FUNCTION"),
252+
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "test"),
253+
new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "("),
254+
new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, ")"),
255+
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "BEGIN"),
256+
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "ATOMIC"),
257+
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"),
258+
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "1"),
259+
new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";"),
260+
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"),
261+
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "2"),
262+
new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";"),
263+
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "END")
264+
);
265+
}
266+
245267
}
246268

247269
void assertSingleStatementEquals(String sql, ParsedSql.Token... tokens) {
248-
ParsedSql parsedSql = PostgresqlSqlParser.tokenize(sql);
270+
ParsedSql parsedSql = PostgresqlSqlParser.parse(sql);
249271
assertEquals(1, parsedSql.getStatements().size(), "Parse returned zero or more than 2 statements");
250-
ParsedSql.TokenizedStatement statement = parsedSql.getStatements().get(0);
251-
assertEquals(new ParsedSql.TokenizedStatement(sql, Arrays.asList(tokens)), statement);
272+
ParsedSql.Statement statement = parsedSql.getStatements().get(0);
273+
assertIterableEquals(Arrays.asList(tokens), statement.getTokens());
252274
}
253275

254276
}
@@ -258,30 +280,30 @@ class MultipleStatementTests {
258280

259281
@Test
260282
void simpleMultipleStatementIsTokenized() {
261-
ParsedSql parsedSql = PostgresqlSqlParser.tokenize("DELETE * FROM X; SELECT 1;");
262-
List<ParsedSql.TokenizedStatement> statements = parsedSql.getStatements();
283+
ParsedSql parsedSql = PostgresqlSqlParser.parse("DELETE * FROM X; SELECT 1;");
284+
List<ParsedSql.Statement> statements = parsedSql.getStatements();
263285
assertEquals(2, statements.size());
264-
ParsedSql.TokenizedStatement statementA = statements.get(0);
265-
ParsedSql.TokenizedStatement statementB = statements.get(1);
266-
267-
assertEquals(new ParsedSql.TokenizedStatement("DELETE * FROM X;",
268-
Arrays.asList(
269-
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "DELETE"),
270-
new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "*"),
271-
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "FROM"),
272-
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "X"),
273-
new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";")
274-
)),
275-
statementA
286+
ParsedSql.Statement statementA = statements.get(0);
287+
ParsedSql.Statement statementB = statements.get(1);
288+
289+
assertIterableEquals(
290+
Arrays.asList(
291+
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "DELETE"),
292+
new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "*"),
293+
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "FROM"),
294+
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "X"),
295+
new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";")
296+
),
297+
statementA.getTokens()
276298
);
277299

278-
assertEquals(new ParsedSql.TokenizedStatement("SELECT 1;",
279-
Arrays.asList(
280-
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"),
281-
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "1"),
282-
new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";")
283-
)),
284-
statementB
300+
assertIterableEquals(
301+
Arrays.asList(
302+
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"),
303+
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "1"),
304+
new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";")
305+
),
306+
statementB.getTokens()
285307
);
286308

287309
}

0 commit comments

Comments
 (0)