24
24
import net .sf .jsqlparser .expression .Function ;
25
25
import net .sf .jsqlparser .parser .CCJSqlParserUtil ;
26
26
import net .sf .jsqlparser .schema .Column ;
27
+ import net .sf .jsqlparser .statement .Statement ;
28
+ import net .sf .jsqlparser .statement .delete .Delete ;
27
29
import net .sf .jsqlparser .statement .select .OrderByElement ;
28
30
import net .sf .jsqlparser .statement .select .PlainSelect ;
29
31
import net .sf .jsqlparser .statement .select .Select ;
30
32
import net .sf .jsqlparser .statement .select .SelectExpressionItem ;
31
33
import net .sf .jsqlparser .statement .select .SelectItem ;
34
+ import net .sf .jsqlparser .statement .update .Update ;
32
35
33
36
import java .util .ArrayList ;
34
37
import java .util .Collections ;
54
57
public class JSqlParserQueryEnhancer implements QueryEnhancer {
55
58
56
59
private final DeclaredQuery query ;
60
+ private final ParsedType parsedType ;
57
61
58
62
/**
59
63
* @param query the query we want to enhance. Must not be {@literal null}.
60
64
*/
61
65
public JSqlParserQueryEnhancer (DeclaredQuery query ) {
62
66
this .query = query ;
67
+ this .parsedType = detectParsedType ();
68
+ }
69
+
70
+ /**
71
+ * Detects what type of query is provided.
72
+ *
73
+ * @return the parsed type
74
+ */
75
+ private ParsedType detectParsedType () {
76
+ try {
77
+ Statement statement = CCJSqlParserUtil .parse (this .query .getQueryString ());
78
+
79
+ if (statement instanceof Update ) {
80
+ return ParsedType .UPDATE ;
81
+ } else if (statement instanceof Delete ) {
82
+ return ParsedType .DELETE ;
83
+ } else if (statement instanceof Select ) {
84
+ return ParsedType .SELECT ;
85
+ } else {
86
+ return ParsedType .SELECT ;
87
+ }
88
+
89
+ } catch (JSQLParserException e ) {
90
+ throw new IllegalArgumentException ("The query you provided is not a valid SQL Query!" , e );
91
+ }
63
92
}
64
93
65
94
@ Override
66
95
public String applySorting (Sort sort , @ Nullable String alias ) {
67
-
68
96
String queryString = query .getQueryString ();
69
97
Assert .hasText (queryString , "Query must not be null or empty!" );
70
98
99
+ if (this .parsedType != ParsedType .SELECT ) {
100
+ return queryString ;
101
+ }
102
+
71
103
if (sort .isUnsorted ()) {
72
104
return queryString ;
73
105
}
@@ -120,6 +152,10 @@ private Set<String> getSelectionAliases(PlainSelect selectBody) {
120
152
*/
121
153
Set <String > getSelectionAliases () {
122
154
155
+ if (this .parsedType != ParsedType .SELECT ) {
156
+ return new HashSet <>();
157
+ }
158
+
123
159
Select selectStatement = parseSelectStatement (this .query .getQueryString ());
124
160
PlainSelect selectBody = (PlainSelect ) selectStatement .getSelectBody ();
125
161
return this .getSelectionAliases (selectBody );
@@ -132,6 +168,9 @@ Set<String> getSelectionAliases() {
132
168
* @return a {@literal Set} of aliases used in the query. Guaranteed to be not {@literal null}.
133
169
*/
134
170
private Set <String > getJoinAliases (String query ) {
171
+ if (this .parsedType != ParsedType .SELECT ) {
172
+ return new HashSet <>();
173
+ }
135
174
return getJoinAliases ((PlainSelect ) parseSelectStatement (query ).getSelectBody ());
136
175
}
137
176
@@ -211,6 +250,10 @@ public String detectAlias() {
211
250
@ Nullable
212
251
private String detectAlias (String query ) {
213
252
253
+ if (this .parsedType != ParsedType .SELECT ) {
254
+ return null ;
255
+ }
256
+
214
257
Select selectStatement = parseSelectStatement (query );
215
258
PlainSelect selectBody = (PlainSelect ) selectStatement .getSelectBody ();
216
259
return detectAlias (selectBody );
@@ -233,6 +276,10 @@ private static String detectAlias(PlainSelect selectBody) {
233
276
@ Override
234
277
public String createCountQueryFor (@ Nullable String countProjection ) {
235
278
279
+ if (this .parsedType != ParsedType .SELECT ) {
280
+ return this .query .getQueryString ();
281
+ }
282
+
236
283
Assert .hasText (this .query .getQueryString (), "OriginalQuery must not be null or empty!" );
237
284
238
285
Select selectStatement = parseSelectStatement (this .query .getQueryString ());
@@ -278,6 +325,10 @@ public String createCountQueryFor(@Nullable String countProjection) {
278
325
@ Override
279
326
public String getProjection () {
280
327
328
+ if (this .parsedType != ParsedType .SELECT ) {
329
+ return "" ;
330
+ }
331
+
281
332
Assert .hasText (query .getQueryString (), "Query must not be null or empty!" );
282
333
283
334
Select selectStatement = parseSelectStatement (query .getQueryString ());
@@ -327,3 +378,15 @@ public DeclaredQuery getQuery() {
327
378
return this .query ;
328
379
}
329
380
}
381
+
382
+ /**
383
+ * An enum to represent the top level parsed statement of the provided query.
384
+ * <ul>
385
+ * <li>{@code ParsedType.DELETE}: means the top level statement is {@link Delete}</li>
386
+ * <li>{@code ParsedType.UPDATE}: means the top level statement is {@link Update}</li>
387
+ * <li>{@code ParsedType.SELECT}: means the top level statement is {@link Select}</li>
388
+ * </ul>
389
+ */
390
+ enum ParsedType {
391
+ DELETE , UPDATE , SELECT ;
392
+ }
0 commit comments