@@ -22,6 +22,7 @@ import scala.collection.JavaConverters._
22
22
23
23
import org .apache .spark .connect .proto
24
24
import org .apache .spark .sql .SparkSession
25
+ import org .apache .spark .sql .catalyst .AliasIdentifier
25
26
import org .apache .spark .sql .catalyst .analysis .{UnresolvedAlias , UnresolvedAttribute , UnresolvedFunction , UnresolvedRelation , UnresolvedStar }
26
27
import org .apache .spark .sql .catalyst .expressions
27
28
import org .apache .spark .sql .catalyst .expressions .{Alias , Attribute , AttributeReference , Expression }
@@ -54,8 +55,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
54
55
}
55
56
56
57
rel.getRelTypeCase match {
57
- case proto.Relation .RelTypeCase .READ => transformReadRel(rel.getRead, common )
58
- case proto.Relation .RelTypeCase .PROJECT => transformProject(rel.getProject, common )
58
+ case proto.Relation .RelTypeCase .READ => transformReadRel(rel.getRead)
59
+ case proto.Relation .RelTypeCase .PROJECT => transformProject(rel.getProject)
59
60
case proto.Relation .RelTypeCase .FILTER => transformFilter(rel.getFilter)
60
61
case proto.Relation .RelTypeCase .LIMIT => transformLimit(rel.getLimit)
61
62
case proto.Relation .RelTypeCase .OFFSET => transformOffset(rel.getOffset)
@@ -66,9 +67,11 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
66
67
case proto.Relation .RelTypeCase .AGGREGATE => transformAggregate(rel.getAggregate)
67
68
case proto.Relation .RelTypeCase .SQL => transformSql(rel.getSql)
68
69
case proto.Relation .RelTypeCase .LOCAL_RELATION =>
69
- transformLocalRelation(rel.getLocalRelation, common )
70
+ transformLocalRelation(rel.getLocalRelation)
70
71
case proto.Relation .RelTypeCase .SAMPLE => transformSample(rel.getSample)
71
72
case proto.Relation .RelTypeCase .RANGE => transformRange(rel.getRange)
73
+ case proto.Relation .RelTypeCase .SUBQUERY_ALIAS =>
74
+ transformSubqueryAlias(rel.getSubqueryAlias)
72
75
case proto.Relation .RelTypeCase .RELTYPE_NOT_SET =>
73
76
throw new IndexOutOfBoundsException (" Expected Relation to be set, but is empty." )
74
77
case _ => throw InvalidPlanInput (s " ${rel.getUnknown} not supported. " )
@@ -79,6 +82,16 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
79
82
session.sessionState.sqlParser.parsePlan(sql.getQuery)
80
83
}
81
84
85
+ private def transformSubqueryAlias (alias : proto.SubqueryAlias ): LogicalPlan = {
86
+ val aliasIdentifier =
87
+ if (alias.getQualifierCount > 0 ) {
88
+ AliasIdentifier .apply(alias.getAlias, alias.getQualifierList.asScala.toSeq)
89
+ } else {
90
+ AliasIdentifier .apply(alias.getAlias)
91
+ }
92
+ SubqueryAlias (aliasIdentifier, transformRelation(alias.getInput))
93
+ }
94
+
82
95
/**
83
96
* All fields of [[proto.Sample ]] are optional. However, given those are proto primitive types,
84
97
* we cannot differentiate if the field is not or set when the field's value equals to the type
@@ -141,35 +154,21 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
141
154
}
142
155
}
143
156
144
- private def transformLocalRelation (
145
- rel : proto.LocalRelation ,
146
- common : Option [proto.RelationCommon ]): LogicalPlan = {
157
+ private def transformLocalRelation (rel : proto.LocalRelation ): LogicalPlan = {
147
158
val attributes = rel.getAttributesList.asScala.map(transformAttribute(_)).toSeq
148
- val relation = new org.apache.spark.sql.catalyst.plans.logical.LocalRelation (attributes)
149
- if (common.nonEmpty && common.get.getAlias.nonEmpty) {
150
- logical.SubqueryAlias (identifier = common.get.getAlias, child = relation)
151
- } else {
152
- relation
153
- }
159
+ new org.apache.spark.sql.catalyst.plans.logical.LocalRelation (attributes)
154
160
}
155
161
156
162
private def transformAttribute (exp : proto.Expression .QualifiedAttribute ): Attribute = {
157
163
AttributeReference (exp.getName, DataTypeProtoConverter .toCatalystType(exp.getType))()
158
164
}
159
165
160
- private def transformReadRel (
161
- rel : proto.Read ,
162
- common : Option [proto.RelationCommon ]): LogicalPlan = {
166
+ private def transformReadRel (rel : proto.Read ): LogicalPlan = {
163
167
val baseRelation = rel.getReadTypeCase match {
164
168
case proto.Read .ReadTypeCase .NAMED_TABLE =>
165
169
val multipartIdentifier =
166
170
CatalystSqlParser .parseMultipartIdentifier(rel.getNamedTable.getUnparsedIdentifier)
167
- val child = UnresolvedRelation (multipartIdentifier)
168
- if (common.nonEmpty && common.get.getAlias.nonEmpty) {
169
- SubqueryAlias (identifier = common.get.getAlias, child = child)
170
- } else {
171
- child
172
- }
171
+ UnresolvedRelation (multipartIdentifier)
173
172
case proto.Read .ReadTypeCase .DATA_SOURCE =>
174
173
if (rel.getDataSource.getFormat == " " ) {
175
174
throw InvalidPlanInput (" DataSource requires a format" )
@@ -193,9 +192,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
193
192
logical.Filter (condition = transformExpression(rel.getCondition), child = baseRel)
194
193
}
195
194
196
- private def transformProject (
197
- rel : proto.Project ,
198
- common : Option [proto.RelationCommon ]): LogicalPlan = {
195
+ private def transformProject (rel : proto.Project ): LogicalPlan = {
199
196
val baseRel = transformRelation(rel.getInput)
200
197
// TODO: support the target field for *.
201
198
val projection =
@@ -204,12 +201,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
204
201
} else {
205
202
rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias (_))
206
203
}
207
- val project = logical.Project (projectList = projection.toSeq, child = baseRel)
208
- if (common.nonEmpty && common.get.getAlias.nonEmpty) {
209
- logical.SubqueryAlias (identifier = common.get.getAlias, child = project)
210
- } else {
211
- project
212
- }
204
+ logical.Project (projectList = projection.toSeq, child = baseRel)
213
205
}
214
206
215
207
private def transformUnresolvedExpression (exp : proto.Expression ): UnresolvedAttribute = {
0 commit comments