Skip to content

Commit 596fbc1

Browse files
[SPARK-33556][ML] Add array_to_vector function for dataframe column
### What changes were proposed in this pull request?

Add array_to_vector function for dataframe column.

### Why are the changes needed?

Utility function for array to vector conversion.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Scala unit test & doctest.

Closes #30498 from WeichenXu123/array_to_vec.

Lead-authored-by: Weichen Xu <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
1 parent f5d2165 commit 596fbc1

File tree

5 files changed

+68
-3
lines changed

5 files changed

+68
-3
lines changed

mllib/src/main/scala/org/apache/spark/ml/functions.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.ml
1919

2020
import org.apache.spark.annotation.Since
21-
import org.apache.spark.ml.linalg.{SparseVector, Vector}
21+
import org.apache.spark.ml.linalg.{SparseVector, Vector, Vectors}
2222
import org.apache.spark.mllib.linalg.{Vector => OldVector}
2323
import org.apache.spark.sql.Column
2424
import org.apache.spark.sql.functions.udf
@@ -72,6 +72,20 @@ object functions {
7272
}
7373
}
7474

75+
private val arrayToVectorUdf = udf { array: Seq[Double] =>
76+
Vectors.dense(array.toArray)
77+
}
78+
79+
/**
80+
* Converts a column of array of numeric type into a column of dense vectors in MLlib.
81+
* @param v the column of array&lt;NumericType&gt; type
82+
* @return a column of type `org.apache.spark.ml.linalg.Vector`
83+
* @since 3.1.0
84+
*/
85+
def array_to_vector(v: Column): Column = {
86+
arrayToVectorUdf(v)
87+
}
88+
7589
private[ml] def checkNonNegativeWeight = udf {
7690
value: Double =>
7791
require(value >= 0, s"illegal weight value: $value. weight must be >= 0.0.")

mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
package org.apache.spark.ml
1919

2020
import org.apache.spark.SparkException
21-
import org.apache.spark.ml.functions.vector_to_array
22-
import org.apache.spark.ml.linalg.Vectors
21+
import org.apache.spark.ml.functions.{array_to_vector, vector_to_array}
22+
import org.apache.spark.ml.linalg.{Vector, Vectors}
2323
import org.apache.spark.ml.util.MLTest
2424
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
2525
import org.apache.spark.sql.functions.col
@@ -87,4 +87,18 @@ class FunctionsSuite extends MLTest {
8787
assert(thrown2.getMessage.contains(
8888
s"Unsupported dtype: float16. Valid values: float64, float32."))
8989
}
90+
91+
test("test array_to_vector") {
92+
val df1 = Seq(Tuple1(Array(0.5, 1.5))).toDF("c1")
93+
val resultVec = df1.select(array_to_vector(col("c1"))).collect()(0)(0).asInstanceOf[Vector]
94+
assert(resultVec === Vectors.dense(Array(0.5, 1.5)))
95+
96+
val df2 = Seq(Tuple1(Array(1.5f, 2.5f))).toDF("c1")
97+
val resultVec2 = df2.select(array_to_vector(col("c1"))).collect()(0)(0).asInstanceOf[Vector]
98+
assert(resultVec2 === Vectors.dense(Array(1.5, 2.5)))
99+
100+
val df3 = Seq(Tuple1(Array(1, 2))).toDF("c1")
101+
val resultVec3 = df3.select(array_to_vector(col("c1"))).collect()(0)(0).asInstanceOf[Vector]
102+
assert(resultVec3 === Vectors.dense(Array(1.0, 2.0)))
103+
}
90104
}

python/docs/source/reference/pyspark.ml.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ ML Functions
196196
.. autosummary::
197197
:toctree: api/
198198

199+
array_to_vector
199200
vector_to_array
200201

201202

python/pyspark/ml/functions.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,40 @@ def vector_to_array(col, dtype="float64"):
6969
sc._jvm.org.apache.spark.ml.functions.vector_to_array(_to_java_column(col), dtype))
7070

7171

72+
def array_to_vector(col):
73+
"""
74+
Converts a column of array of numeric type into a column of dense vectors in MLlib.
75+
76+
.. versionadded:: 3.1.0
77+
78+
Parameters
79+
----------
80+
col : :py:class:`pyspark.sql.Column` or str
81+
Input column
82+
83+
Returns
84+
-------
85+
:py:class:`pyspark.sql.Column`
86+
The converted column of MLlib dense vectors.
87+
88+
Examples
89+
--------
90+
>>> from pyspark.ml.functions import array_to_vector
91+
>>> df1 = spark.createDataFrame([([1.5, 2.5],),], schema='v1 array<double>')
92+
>>> df1.select(array_to_vector('v1').alias('vec1')).collect()
93+
[Row(vec1=DenseVector([1.5, 2.5]))]
94+
>>> df2 = spark.createDataFrame([([1.5, 3.5],),], schema='v1 array<float>')
95+
>>> df2.select(array_to_vector('v1').alias('vec1')).collect()
96+
[Row(vec1=DenseVector([1.5, 3.5]))]
97+
>>> df3 = spark.createDataFrame([([1, 3],),], schema='v1 array<int>')
98+
>>> df3.select(array_to_vector('v1').alias('vec1')).collect()
99+
[Row(vec1=DenseVector([1.0, 3.0]))]
100+
"""
101+
sc = SparkContext._active_spark_context
102+
return Column(
103+
sc._jvm.org.apache.spark.ml.functions.array_to_vector(_to_java_column(col)))
104+
105+
72106
def _test():
73107
import doctest
74108
from pyspark.sql import SparkSession

python/pyspark/ml/functions.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,5 @@ from pyspark import SparkContext as SparkContext, since as since # noqa: F401
2020
from pyspark.sql.column import Column as Column
2121

2222
def vector_to_array(col: Column) -> Column: ...
23+
24+
def array_to_vector(col: Column) -> Column: ...

0 commit comments

Comments
 (0)