Commit 071feab
[SPARK-44531][CONNECT][SQL] Move encoder inference to sql/api
### What changes were proposed in this pull request?
This PR moves encoder inference (ScalaReflection/RowEncoder/JavaTypeInference) into sql/api.

### Why are the changes needed?
We want to use encoder inference in the Spark Connect Scala client. The client's dependency on catalyst is going away, so we need to move this.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

Closes #42134 from hvanhovell/SPARK-44531.

Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
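At call sites across the repo the migration is mechanical; a minimal before/after sketch distilled from the hunks below (the helper name `serializerFor` is illustrative, not part of the patch):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.StructType

def serializerFor(schema: StructType) = {
  // Before this patch: RowEncoder(schema).createSerializer()
  // After: the Row-typed apply lives on ExpressionEncoder, while
  // RowEncoder (now in sql/api) keeps only the AgnosticEncoder factory.
  ExpressionEncoder(schema).createSerializer()
}
```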
1 parent a84e2b1 commit 071feab

File tree: 45 files changed, +205 -184 lines

Note: large commits hide some content by default, so only a subset of the 45 changed files is shown below.

mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@ import org.apache.hadoop.mapreduce.Job
 import org.apache.spark.ml.image.ImageSchema
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
 import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
@@ -90,7 +90,7 @@ private[image] class ImageFileFormat extends FileFormat with DataSourceRegister
     if (requiredSchema.isEmpty) {
       filteredResult.map(_ => emptyUnsafeRow)
     } else {
-      val toRow = RowEncoder(requiredSchema).createSerializer()
+      val toRow = ExpressionEncoder(requiredSchema).createSerializer()
       filteredResult.map(row => toRow(row))
     }
   }

mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala

Lines changed: 2 additions & 2 deletions
@@ -31,7 +31,7 @@ import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -167,7 +167,7 @@ private[libsvm] class LibSVMFileFormat
       LabeledPoint(label, Vectors.sparse(numFeatures, indices, values))
     }

-    val toRow = RowEncoder(dataSchema).createSerializer()
+    val toRow = ExpressionEncoder(dataSchema).createSerializer()
     val fullOutput = dataSchema.map { f =>
       AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
     }

project/MimaExcludes.scala

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,9 @@ object MimaExcludes {

   // Exclude rules for 3.5.x from 3.4.0
   lazy val v35excludes = defaultExcludes ++ Seq(
+    // [SPARK-44531][CONNECT][SQL] Move encoder inference to sql/api
+    ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.DataTypes"),
+    ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.SQLUserDefinedType"),
     // [SPARK-43165][SQL] Move canWrite to DataTypeUtils
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.types.DataType.canWrite"),
     // [SPARK-43195][CORE] Remove unnecessary serializable wrapper in HadoopFSUtils

sql/api/pom.xml

Lines changed: 4 additions & 0 deletions
@@ -35,6 +35,10 @@
   </properties>

   <dependencies>
+    <dependency>
+      <groupId>org.scala-lang</groupId>
+      <artifactId>scala-reflect</artifactId>
+    </dependency>
     <dependency>
       <groupId>org.scala-lang.modules</groupId>
       <artifactId>scala-parser-combinators_${scala.binary.version}</artifactId>

sql/api/src/main/scala/org/apache/spark/sql/SqlApiConf.scala

Lines changed: 2 additions & 0 deletions
@@ -41,6 +41,7 @@ private[sql] trait SqlApiConf {
   def timestampType: AtomicType
   def allowNegativeScaleOfDecimalEnabled: Boolean
   def charVarcharAsString: Boolean
+  def datetimeJava8ApiEnabled: Boolean
 }

 private[sql] object SqlApiConf {
@@ -76,4 +77,5 @@ private[sql] object DefaultSqlApiConf extends SqlApiConf {
   override def timestampType: AtomicType = TimestampType
   override def allowNegativeScaleOfDecimalEnabled: Boolean = false
   override def charVarcharAsString: Boolean = false
+  override def datetimeJava8ApiEnabled: Boolean = false
 }
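SqlApiConf is the sql/api-side mirror of SQLConf, so this new entry lets code in the module consult spark.sql.datetime.java8API.enabled without a catalyst dependency. A minimal sketch of the consumption pattern (a hypothetical helper, not part of the patch; SqlApiConf is private[sql], so the code must live under org.apache.spark.sql):

```scala
package org.apache.spark.sql.example // hypothetical package under org.apache.spark.sql

import org.apache.spark.sql.SqlApiConf

private[sql] object TimestampExternalType {
  // Mirrors the branch the RowEncoder change below takes: Java 8 time types
  // when spark.sql.datetime.java8API.enabled is set, java.sql types otherwise.
  def forCurrentConf: Class[_] =
    if (SqlApiConf.get.datetimeJava8ApiEnabled) classOf[java.time.Instant]
    else classOf[java.sql.Timestamp]
}
```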

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala renamed to sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 4 additions & 5 deletions
@@ -24,10 +24,9 @@ import javax.annotation.Nonnull
 import scala.annotation.tailrec
 import scala.reflect.ClassTag

-import org.apache.spark.SPARK_DOC_ROOT
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, STRICT_DATE_ENCODER, STRICT_INSTANT_ENCODER, STRICT_LOCAL_DATE_ENCODER, STRICT_TIMESTAMP_ENCODER, StringEncoder, UDTEncoder, YearMonthIntervalEncoder}
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.EncoderErrors
 import org.apache.spark.sql.types._

 /**
@@ -116,7 +115,7 @@ object JavaTypeInference {

     case c: Class[_] =>
       if (seenTypeSet.contains(c)) {
-        throw QueryExecutionErrors.cannotHaveCircularReferencesInBeanClassError(c)
+        throw EncoderErrors.cannotHaveCircularReferencesInBeanClassError(c)
       }

       // TODO: we should only collect properties that have getter and setter. However, some tests
@@ -139,7 +138,7 @@ object JavaTypeInference {
       JavaBeanEncoder(ClassTag(c), fields)

     case _ =>
-      throw QueryExecutionErrors.cannotFindEncoderForTypeError(t.toString, SPARK_DOC_ROOT)
+      throw EncoderErrors.cannotFindEncoderForTypeError(t.toString)
   }

   def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
@@ -197,7 +196,7 @@ object JavaTypeInference {
         }
       }
     }
-    throw QueryExecutionErrors.unreachableError()
+    throw EncoderErrors.unreachableError()
   }
 }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala renamed to sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 6 additions & 7 deletions
@@ -26,12 +26,11 @@ import scala.util.{Failure, Success}

 import org.apache.commons.lang3.reflect.ConstructorUtils

-import org.apache.spark.SPARK_DOC_ROOT
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.EncoderErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval

@@ -378,13 +377,13 @@ object ScalaReflection extends ScalaReflection {

       case t if definedByConstructorParams(t) =>
         if (seenTypeSet.contains(t)) {
-          throw QueryExecutionErrors.cannotHaveCircularReferencesInClassError(t.toString)
+          throw EncoderErrors.cannotHaveCircularReferencesInClassError(t.toString)
         }
         val params = getConstructorParameters(t).map {
           case (fieldName, fieldType) =>
             if (SourceVersion.isKeyword(fieldName) ||
                 !SourceVersion.isIdentifier(encodeFieldNameToIdentifier(fieldName))) {
-              throw QueryExecutionErrors.cannotUseInvalidJavaIdentifierAsFieldNameError(
+              throw EncoderErrors.cannotUseInvalidJavaIdentifierAsFieldNameError(
                 fieldName,
                 path)
             }
@@ -397,7 +396,7 @@ object ScalaReflection extends ScalaReflection {
         }
         ProductEncoder(ClassTag(getClassFromType(t)), params)
       case _ =>
-        throw QueryExecutionErrors.cannotFindEncoderForTypeError(tpe.toString, SPARK_DOC_ROOT)
+        throw EncoderErrors.cannotFindEncoderForTypeError(tpe.toString)
     }
   }
 }
@@ -478,7 +477,7 @@ trait ScalaReflection extends Logging {
    */
   private def getCompanionConstructor(tpe: Type): Symbol = {
     def throwUnsupportedOperation = {
-      throw QueryExecutionErrors.cannotFindConstructorForTypeError(tpe.toString)
+      throw EncoderErrors.cannotFindConstructorForTypeError(tpe.toString)
     }
     tpe.typeSymbol.asClass.companion match {
       case NoSymbol => throwUnsupportedOperation
@@ -501,7 +500,7 @@ trait ScalaReflection extends Logging {
     val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find(
       s => s.isMethod && s.asMethod.isPrimaryConstructor)
     if (primaryConstructorSymbol.isEmpty) {
-      throw QueryExecutionErrors.primaryConstructorNotFoundError(tpe.getClass)
+      throw EncoderErrors.primaryConstructorNotFoundError(tpe.getClass)
     } else {
       primaryConstructorSymbol.get.asMethod.paramLists
     }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala renamed to sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 5 additions & 14 deletions
@@ -20,10 +20,9 @@ package org.apache.spark.sql.catalyst.encoders
 import scala.collection.mutable
 import scala.reflect.classTag

-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Row, SqlApiConf}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMonthIntervalEncoder}
-import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.errors.EncoderErrors
 import org.apache.spark.sql.types._

 /**
@@ -59,14 +58,6 @@ import org.apache.spark.sql.types._
  * }}}
  */
 object RowEncoder {
-  def apply(schema: StructType, lenient: Boolean): ExpressionEncoder[Row] = {
-    ExpressionEncoder(encoderFor(schema, lenient))
-  }
-
-  def apply(schema: StructType): ExpressionEncoder[Row] = {
-    apply(schema, lenient = false)
-  }
-
   def encoderFor(schema: StructType): AgnosticEncoder[Row] = {
     encoderFor(schema, lenient = false)
   }
@@ -89,10 +80,10 @@ object RowEncoder {
     case dt: DecimalType => JavaDecimalEncoder(dt, lenientSerialization = true)
     case BinaryType => BinaryEncoder
     case StringType => StringEncoder
-    case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => InstantEncoder(lenient)
+    case TimestampType if SqlApiConf.get.datetimeJava8ApiEnabled => InstantEncoder(lenient)
     case TimestampType => TimestampEncoder(lenient)
     case TimestampNTZType => LocalDateTimeEncoder
-    case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateEncoder(lenient)
+    case DateType if SqlApiConf.get.datetimeJava8ApiEnabled => LocalDateEncoder(lenient)
     case DateType => DateEncoder(lenient)
     case CalendarIntervalType => CalendarIntervalEncoder
     case _: DayTimeIntervalType => DayTimeIntervalEncoder
@@ -106,7 +97,7 @@ object RowEncoder {
       annotation.udt()
     } else {
       UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse {
-        throw QueryExecutionErrors.userDefinedTypeNotAnnotatedAndRegisteredError(udt)
+        throw EncoderErrors.userDefinedTypeNotAnnotatedAndRegisteredError(udt)
       }
     }
     UDTEncoder(udt, udtClass.asInstanceOf[Class[_ <: UserDefinedType[_]]])
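With the switch to SqlApiConf, encoder selection for datetime columns behaves the same whether RowEncoder runs in catalyst or in the connect client. A small sketch (assumes an active SparkSession named `spark` so that SqlApiConf.get resolves to the session's SQLConf, and that encoderFor reads the conf at call time):

```scala
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{StructField, StructType, TimestampType}

// With the flag set, TimestampType columns map to java.time.Instant
// (InstantEncoder); with it unset, to java.sql.Timestamp (TimestampEncoder).
spark.conf.set("spark.sql.datetime.java8API.enabled", "true")
val encoder = RowEncoder.encoderFor(StructType(Seq(StructField("ts", TimestampType))))
```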

sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala

Lines changed: 7 additions & 1 deletion
@@ -18,7 +18,7 @@ package org.apache.spark.sql.errors

 import java.util.Locale

-import org.apache.spark.QueryContext
+import org.apache.spark.{QueryContext, SparkRuntimeException}
 import org.apache.spark.sql.catalyst.trees.SQLQueryContext
 import org.apache.spark.sql.catalyst.util.{AttributeNameParser, QuotingUtils, SparkStringUtils}
 import org.apache.spark.sql.types.{AbstractDataType, DataType, TypeCollection}
@@ -73,4 +73,10 @@ private[sql] trait DataTypeErrorsBase {
   def getQueryContext(sqlContext: SQLQueryContext): Array[QueryContext] = {
     if (sqlContext == null) Array.empty else Array(sqlContext.asInstanceOf[QueryContext])
   }
+
+  def unreachableError(err: String = ""): SparkRuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "_LEGACY_ERROR_TEMP_2028",
+      messageParameters = Map("err" -> err))
+  }
 }
sql/api/src/main/scala/org/apache/spark/sql/errors/EncoderErrors.scala

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.errors
+
+import org.apache.spark.{SparkBuildInfo, SparkException, SparkRuntimeException, SparkUnsupportedOperationException}
+import org.apache.spark.sql.catalyst.WalkedTypePath
+import org.apache.spark.sql.types.UserDefinedType
+
+object EncoderErrors extends DataTypeErrorsBase {
+  def userDefinedTypeNotAnnotatedAndRegisteredError(udt: UserDefinedType[_]): Throwable = {
+    new SparkException(
+      errorClass = "_LEGACY_ERROR_TEMP_2155",
+      messageParameters = Map(
+        "userClass" -> udt.userClass.getName),
+      cause = null)
+  }
+
+  def cannotFindEncoderForTypeError(typeName: String): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException(
+      errorClass = "ENCODER_NOT_FOUND",
+      messageParameters = Map(
+        "typeName" -> typeName,
+        "docroot" -> SparkBuildInfo.spark_doc_root))
+  }
+
+  def cannotHaveCircularReferencesInBeanClassError(
+      clazz: Class[_]): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException(
+      errorClass = "_LEGACY_ERROR_TEMP_2138",
+      messageParameters = Map("clazz" -> clazz.toString()))
+  }
+
+  def cannotFindConstructorForTypeError(tpe: String): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException(
+      errorClass = "_LEGACY_ERROR_TEMP_2144",
+      messageParameters = Map(
+        "tpe" -> tpe))
+  }
+
+  def cannotHaveCircularReferencesInClassError(t: String): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException(
+      errorClass = "_LEGACY_ERROR_TEMP_2139",
+      messageParameters = Map("t" -> t))
+  }
+
+  def cannotUseInvalidJavaIdentifierAsFieldNameError(
+      fieldName: String, walkedTypePath: WalkedTypePath): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException(
+      errorClass = "_LEGACY_ERROR_TEMP_2140",
+      messageParameters = Map(
+        "fieldName" -> fieldName,
+        "walkedTypePath" -> walkedTypePath.toString()))
+  }
+
+  def primaryConstructorNotFoundError(cls: Class[_]): SparkRuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "_LEGACY_ERROR_TEMP_2021",
+      messageParameters = Map("cls" -> cls.toString()))
+  }
+}
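As a quick illustration of the new error surface (a hypothetical snippet, not part of the patch): a type with no supported encoder now fails through EncoderErrors, with the same ENCODER_NOT_FOUND error class as before the move.

```scala
import scala.util.Try
import org.apache.spark.sql.catalyst.ScalaReflection

// java.util.Random is neither a Product nor a supported primitive, so
// inference throws SparkUnsupportedOperationException with error class
// ENCODER_NOT_FOUND, now raised via EncoderErrors in sql/api.
val attempt = Try(ScalaReflection.encoderFor[java.util.Random])
```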

sql/catalyst/pom.xml

Lines changed: 0 additions & 5 deletions
@@ -35,11 +35,6 @@
   </properties>

   <dependencies>
-    <dependency>
-      <groupId>org.scala-lang</groupId>
-      <artifactId>scala-reflect</artifactId>
-    </dependency>
-
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-core_${scala.binary.version}</artifactId>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 7 additions & 1 deletion
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.encoders
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag

-import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.{Encoder, Row}
 import org.apache.spark.sql.catalyst.{DeserializerBuildHelper, InternalRow, JavaTypeInference, ScalaReflection, SerializerBuildHelper}
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
@@ -58,6 +58,12 @@ object ExpressionEncoder {
       enc.clsTag)
   }

+  def apply(schema: StructType): ExpressionEncoder[Row] = apply(schema, lenient = false)
+
+  def apply(schema: StructType, lenient: Boolean): ExpressionEncoder[Row] = {
+    apply(RowEncoder.encoderFor(schema, lenient))
+  }
+
   // TODO: improve error message for java bean encoder.
   def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
     apply(JavaTypeInference.encoderFor(beanClass))
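A minimal round-trip sketch of the relocated entry points, based on the signatures above (createSerializer, createDeserializer, and resolveAndBind are the existing ExpressionEncoder APIs):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

val schema = StructType(Seq(
  StructField("id", LongType),
  StructField("name", StringType)))

// ExpressionEncoder(schema) now replaces the removed RowEncoder(schema).
val encoder = ExpressionEncoder(schema)
val toInternal = encoder.createSerializer()
val internalRow = toInternal(Row(1L, "spark"))

// Deserializers require a resolved and bound encoder.
val fromInternal = encoder.resolveAndBind().createDeserializer()
val roundTripped: Row = fromInternal(internalRow)
```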

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala

Lines changed: 4 additions & 4 deletions
@@ -149,8 +149,8 @@ object MapPartitionsInR {
       broadcastVars,
       encoder.schema,
       schema,
-      CatalystSerde.generateObjAttr(RowEncoder(schema)),
-      deserialized))(RowEncoder(schema))
+      CatalystSerde.generateObjAttr(ExpressionEncoder(schema)),
+      deserialized))(ExpressionEncoder(schema))
   }
 }
@@ -606,8 +606,8 @@ object FlatMapGroupsInR {
       UnresolvedDeserializer(valueDeserializer, dataAttributes),
       groupingAttributes,
       dataAttributes,
-      CatalystSerde.generateObjAttr(RowEncoder(schema)),
-      child))(RowEncoder(schema))
+      CatalystSerde.generateObjAttr(ExpressionEncoder(schema)),
+      child))(ExpressionEncoder(schema))
   }
 }
