
Commit 14784f8

dtenedor authored and ragnarok56 committed
[SPARK-44748][SQL] Query execution for the PARTITION BY clause in UDTF TABLE arguments
### What changes were proposed in this pull request?

This PR implements query execution support for the PARTITION BY and ORDER BY clauses for UDTF TABLE arguments.

* The query planning support was added in [1], [2], and [3]. After those changes, the planner added a projection to compute the PARTITION BY expressions, plus a repartition operator, plus a sort operator.
* In this PR, the Python executor receives the indexes of these expressions within the input table's rows, and compares the values of the projected partitioning expressions between consecutive rows.
* When the values change, this marks the boundary between partitions, so we call the UDTF instance's `terminate` method, then destroy it and create a new one for the next partition.

[1] apache#42100
[2] apache#42174
[3] apache#42351

Example:

```python
# Make a test UDTF to yield an output row with the same value
# consumed from the last input row in the input table or partition.
class TestUDTF:
    def eval(self, row: Row):
        self._last = row['input']
        self._partition_col = row['partition_col']

    def terminate(self):
        yield self._partition_col, self._last

func = udtf(TestUDTF, returnType='partition_col: int, last: int')
self.spark.udtf.register('test_udtf', func)
self.spark.sql('''
    WITH t AS (
      SELECT id AS partition_col, 1 AS input FROM range(0, 2)
      UNION ALL
      SELECT id AS partition_col, 2 AS input FROM range(0, 2)
    )
    SELECT * FROM test_udtf(TABLE(t) PARTITION BY partition_col ORDER BY input)
    ''').collect()
> [Row(partition_col=0, last=2), Row(partition_col=1, last=2)]
```

### Why are the changes needed?

This brings full end-to-end execution for the PARTITION BY and/or ORDER BY clauses for UDTF TABLE arguments.

### Does this PR introduce _any_ user-facing change?

Yes, see above.

### How was this patch tested?

This PR adds end-to-end testing in `test_udtf.py`.

Closes apache#42420 from dtenedor/inspect-partition-by.

Authored-by: Daniel Tenedorio <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
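For orientation before the diffs below, here is a minimal, Spark-independent sketch of the boundary-detection idea described above. The names `run_partitioned`, `make_udtf`, and `partition_indexes` are illustrative, not part of this patch, and the sketch assumes a UDTF that only emits rows from `terminate`, as in the example.

```python
from typing import Any, Callable, Iterator, List, Sequence


def run_partitioned(
    rows: Sequence[Sequence[Any]],
    partition_indexes: List[int],
    make_udtf: Callable,
) -> Iterator:
    """Feed pre-sorted rows to one UDTF instance per partition (sketch only)."""
    udtf = make_udtf()
    prev_key = None
    for row in rows:
        # Compare the projected partitioning values of consecutive rows.
        key = tuple(row[i] for i in partition_indexes)
        if prev_key is not None and key != prev_key:
            # Boundary found: flush the old instance, then start a fresh one.
            yield from udtf.terminate()
            udtf = make_udtf()
        udtf.eval(row)
        prev_key = key
    yield from udtf.terminate()
```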
1 parent 1f58395 commit 14784f8

File tree

7 files changed: +313 −37 lines changed


python/pyspark/sql/tests/test_udtf.py (+167)
```diff
@@ -1964,6 +1964,173 @@ def eval(self, a, b=100):
             with self.subTest(query_no=i):
                 assertDataFrameEqual(df, [Row(a=10, b="z")])
 
+    def test_udtf_with_table_argument_and_partition_by(self):
+        class TestUDTF:
+            def __init__(self):
+                self._sum = 0
+                self._partition_col = None
+
+            def eval(self, row: Row):
+                self._sum += row["input"]
+                if self._partition_col is not None and self._partition_col != row["partition_col"]:
+                    # Make sure that all values of the partitioning column are the same
+                    # for each row consumed by this method for this instance of the class.
+                    raise Exception(
+                        f"self._partition_col was {self._partition_col} but the row "
+                        + f"value was {row['partition_col']}"
+                    )
+                self._partition_col = row["partition_col"]
+
+            def terminate(self):
+                yield self._partition_col, self._sum
+
+        # This is a basic example.
+        func = udtf(TestUDTF, returnType="partition_col: int, total: int")
+        self.spark.udtf.register("test_udtf", func)
+        self.assertEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id AS partition_col, 1 AS input FROM range(1, 21)
+                  UNION ALL
+                  SELECT id AS partition_col, 2 AS input FROM range(1, 21)
+                )
+                SELECT partition_col, total
+                FROM test_udtf(TABLE(t) PARTITION BY partition_col - 1)
+                ORDER BY 1, 2
+                """
+            ).collect(),
+            [Row(partition_col=x, total=3) for x in range(1, 21)],
+        )
+
+        # These cases partition by constant values.
+        for str_first, str_second, result_first, result_second in (
+            ("123", "456", 123, 456),
+            ("123", "NULL", None, 123),
+        ):
+            self.assertEqual(
+                self.spark.sql(
+                    f"""
+                    WITH t AS (
+                      SELECT {str_first} AS partition_col, id AS input FROM range(0, 2)
+                      UNION ALL
+                      SELECT {str_second} AS partition_col, id AS input FROM range(0, 2)
+                    )
+                    SELECT partition_col, total
+                    FROM test_udtf(TABLE(t) PARTITION BY partition_col)
+                    ORDER BY 1, 2
+                    """
+                ).collect(),
+                [
+                    Row(partition_col=result_first, total=1),
+                    Row(partition_col=result_second, total=1),
+                ],
+            )
+
+        # Combine a lateral join with a TABLE argument with PARTITION BY.
+        func = udtf(TestUDTF, returnType="partition_col: int, total: int")
+        self.spark.udtf.register("test_udtf", func)
+        self.assertEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id AS partition_col, 1 AS input FROM range(1, 3)
+                  UNION ALL
+                  SELECT id AS partition_col, 2 AS input FROM range(1, 3)
+                )
+                SELECT v.a, v.b, f.partition_col, f.total
+                FROM VALUES (0, 1) AS v(a, b),
+                  LATERAL test_udtf(TABLE(t) PARTITION BY partition_col - 1) f
+                ORDER BY 1, 2, 3, 4
+                """
+            ).collect(),
+            [Row(a=0, b=1, partition_col=1, total=3), Row(a=0, b=1, partition_col=2, total=3)],
+        )
+
+    def test_udtf_with_table_argument_and_partition_by_and_order_by(self):
+        class TestUDTF:
+            def __init__(self):
+                self._last = None
+                self._partition_col = None
+
+            def eval(self, row: Row, partition_col: str):
+                # Make sure that all values of the partitioning column are the same
+                # for each row consumed by this method for this instance of the class.
+                if self._partition_col is not None and self._partition_col != row[partition_col]:
+                    raise Exception(
+                        f"self._partition_col was {self._partition_col} but the row "
+                        + f"value was {row[partition_col]}"
+                    )
+                self._last = row["input"]
+                self._partition_col = row[partition_col]
+
+            def terminate(self):
+                yield self._partition_col, self._last
+
+        func = udtf(TestUDTF, returnType="partition_col: int, last: int")
+        self.spark.udtf.register("test_udtf", func)
+        for order_by_str, result_val in (
+            ("input ASC", 2),
+            ("input + 1 ASC", 2),
+            ("input DESC", 1),
+            ("input - 1 DESC", 1),
+        ):
+            self.assertEqual(
+                self.spark.sql(
+                    f"""
+                    WITH t AS (
+                      SELECT id AS partition_col, 1 AS input FROM range(1, 21)
+                      UNION ALL
+                      SELECT id AS partition_col, 2 AS input FROM range(1, 21)
+                    )
+                    SELECT partition_col, last
+                    FROM test_udtf(
+                      row => TABLE(t) PARTITION BY partition_col - 1 ORDER BY {order_by_str},
+                      partition_col => 'partition_col')
+                    ORDER BY 1, 2
+                    """
+                ).collect(),
+                [Row(partition_col=x, last=result_val) for x in range(1, 21)],
+            )
+
+    def test_udtf_with_table_argument_with_single_partition(self):
+        class TestUDTF:
+            def __init__(self):
+                self._count = 0
+                self._sum = 0
+                self._last = None
+
+            def eval(self, row: Row):
+                # Make sure that the rows arrive in the expected order.
+                if self._last is not None and self._last > row["input"]:
+                    raise Exception(
+                        f"self._last was {self._last} but the row value was {row['input']}"
+                    )
+                self._count += 1
+                self._last = row["input"]
+                self._sum += row["input"]
+
+            def terminate(self):
+                yield self._count, self._sum, self._last
+
+        func = udtf(TestUDTF, returnType="count: int, total: int, last: int")
+        self.spark.udtf.register("test_udtf", func)
+        self.assertEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id AS partition_col, 1 AS input FROM range(1, 21)
+                  UNION ALL
+                  SELECT id AS partition_col, 2 AS input FROM range(1, 21)
+                )
+                SELECT count, total, last
+                FROM test_udtf(TABLE(t) WITH SINGLE PARTITION ORDER BY (input, partition_col))
+                ORDER BY 1, 2
+                """
+            ).collect(),
+            [Row(count=40, total=60, last=2)],
+        )
+
 
 class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
     @classmethod
```

python/pyspark/worker.py (+85 −4)
```diff
@@ -23,7 +23,7 @@
 import time
 from inspect import getfullargspec
 import json
-from typing import Any, Iterable, Iterator
+from typing import Any, Callable, Iterable, Iterator
 
 import traceback
 import faulthandler
@@ -53,7 +53,7 @@
     ApplyInPandasWithStateSerializer,
 )
 from pyspark.sql.pandas.types import to_arrow_type
-from pyspark.sql.types import BinaryType, StringType, StructType, _parse_datatype_json_string
+from pyspark.sql.types import BinaryType, Row, StringType, StructType, _parse_datatype_json_string
 from pyspark.util import fail_on_stopiteration, try_simplify_traceback
 from pyspark import shuffle
 from pyspark.errors import PySparkRuntimeError, PySparkTypeError
@@ -609,7 +609,8 @@ def read_udtf(pickleSer, infile, eval_type):
             kwargs_offsets[name] = offset
         else:
             args_offsets.append(offset)
-
+    num_partition_child_indexes = read_int(infile)
+    partition_child_indexes = [read_int(infile) for i in range(num_partition_child_indexes)]
     handler = read_command(pickleSer, infile)
     if not isinstance(handler, type):
         raise PySparkRuntimeError(
@@ -623,9 +624,89 @@ def read_udtf(pickleSer, infile, eval_type):
             f"The return type of a UDTF must be a struct type, but got {type(return_type)}."
         )
 
+    class UDTFWithPartitions:
+        """
+        This implements the logic of a UDTF that accepts an input TABLE argument with one or more
+        PARTITION BY expressions.
+
+        For example, let's assume we have a table like:
+            CREATE TABLE t (c1 INT, c2 INT) USING delta;
+        Then for the following queries:
+            SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2);
+        The partition_child_indexes will be: 0, 1.
+            SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4);
+        The partition_child_indexes will be: 0, 2 (where we add a projection for "c2 + 4").
+        """
+
+        def __init__(self, create_udtf: Callable, partition_child_indexes: list):
+            """
+            Creates a new instance of this class to wrap the provided UDTF with another one that
+            checks the values of projected partitioning expressions on consecutive rows to figure
+            out when the partition boundaries change.
+
+            Parameters
+            ----------
+            create_udtf: function
+                Function to create a new instance of the UDTF to be invoked.
+            partition_child_indexes: list
+                List of integers identifying zero-based indexes of the columns of the input table
+                that contain projected partitioning expressions. This class will inspect these
+                values for each pair of consecutive input rows. When they change, this indicates
+                the boundary between two partitions, and we will invoke the 'terminate' method on
+                the UDTF class instance and then destroy it and create a new one to implement the
+                desired partitioning semantics.
+            """
+            self._create_udtf: Callable = create_udtf
+            self._udtf = create_udtf()
+            self._prev_arguments: list = list()
+            self._partition_child_indexes: list = partition_child_indexes
+
+        def eval(self, *args, **kwargs) -> Iterator:
+            changed_partitions = self._check_partition_boundaries(
+                list(args) + list(kwargs.values())
+            )
+            if changed_partitions:
+                if self._udtf.terminate is not None:
+                    result = self._udtf.terminate()
+                    if result is not None:
+                        for row in result:
+                            yield row
+                self._udtf = self._create_udtf()
+            if self._udtf.eval is not None:
+                result = self._udtf.eval(*args, **kwargs)
+                if result is not None:
+                    for row in result:
+                        yield row
+
+        def terminate(self) -> Iterator:
+            if self._udtf.terminate is not None:
+                return self._udtf.terminate()
+            return iter(())
+
+        def _check_partition_boundaries(self, arguments: list) -> bool:
+            result = False
+            if len(self._prev_arguments) > 0:
+                cur_table_arg = self._get_table_arg(arguments)
+                prev_table_arg = self._get_table_arg(self._prev_arguments)
+                cur_partitions_args = []
+                prev_partitions_args = []
+                for i in partition_child_indexes:
+                    cur_partitions_args.append(cur_table_arg[i])
+                    prev_partitions_args.append(prev_table_arg[i])
+                self._prev_arguments = arguments
+                result = any(k != v for k, v in zip(cur_partitions_args, prev_partitions_args))
+            self._prev_arguments = arguments
+            return result
+
+        def _get_table_arg(self, inputs: list) -> Row:
+            return [x for x in inputs if type(x) is Row][0]
+
     # Instantiate the UDTF class.
     try:
-        udtf = handler()
+        if len(partition_child_indexes) > 0:
+            udtf = UDTFWithPartitions(handler, partition_child_indexes)
+        else:
+            udtf = handler()
     except Exception as e:
         raise PySparkRuntimeError(
             error_class="UDTF_EXEC_ERROR",
```

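As a usage note, `UDTFWithPartitions` is local to `read_udtf`, so it cannot be imported directly; the following hypothetical sketch only illustrates its contract, assuming the class were available at module scope and that the UDTF's `eval` receives the whole `Row`:

```python
from pyspark.sql.types import Row


class SumPerPartition:
    """Toy handler: sums the 'input' column within one partition."""

    def __init__(self):
        self._sum = 0

    def eval(self, row: Row):
        self._sum += row["input"]

    def terminate(self):
        yield (self._sum,)


# Hypothetical driver, assuming UDTFWithPartitions were importable.
# Column 0 of the TABLE argument holds the projected partitioning value.
# wrapped = UDTFWithPartitions(SumPerPartition, partition_child_indexes=[0])
# for r in [Row(partition_col=0, input=1), Row(partition_col=0, input=2),
#           Row(partition_col=1, input=5)]:
#     print(list(wrapped.eval(r)))   # prints [(3,)] at the partition change
# print(list(wrapped.terminate()))   # prints [(5,)] for the final partition
```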
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala (+28 −2)
```diff
@@ -2098,6 +2098,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         }
 
         val tableArgs = mutable.ArrayBuffer.empty[LogicalPlan]
+        val functionTableSubqueryArgs =
+          mutable.ArrayBuffer.empty[FunctionTableSubqueryArgumentExpression]
         val tvf = resolvedFunc.transformAllExpressionsWithPruning(
           _.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION), ruleId) {
           case t: FunctionTableSubqueryArgumentExpression =>
@@ -2110,6 +2112,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                 "PARTITION BY clause, but only Python table functions support this clause")
             }
             tableArgs.append(SubqueryAlias(alias, t.evaluable))
+            functionTableSubqueryArgs.append(t)
             UnresolvedAttribute(Seq(alias, "c"))
         }
         if (tableArgs.nonEmpty) {
@@ -2118,11 +2121,34 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
               tableArgs.size)
           }
           val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
+          // Propagate the column indexes for TABLE arguments to the PythonUDTF instance.
+          def assignUDTFPartitionColumnIndexes(
+              fn: PythonUDTFPartitionColumnIndexes => LogicalPlan): Option[LogicalPlan] = {
+            val indexes: Seq[Int] = functionTableSubqueryArgs.headOption
+              .map(_.partitioningExpressionIndexes).getOrElse(Seq.empty)
+            if (indexes.nonEmpty) {
+              Some(fn(PythonUDTFPartitionColumnIndexes(indexes)))
+            } else {
+              None
+            }
+          }
+          val tvfWithTableColumnIndexes: LogicalPlan = tvf match {
+            case g @ Generate(p: PythonUDTF, _, _, _, _, _) =>
+              assignUDTFPartitionColumnIndexes(
+                i => g.copy(generator = p.copy(pythonUDTFPartitionColumnIndexes = Some(i))))
+                .getOrElse(g)
+            case g @ Generate(p: UnresolvedPolymorphicPythonUDTF, _, _, _, _, _) =>
+              assignUDTFPartitionColumnIndexes(
+                i => g.copy(generator = p.copy(pythonUDTFPartitionColumnIndexes = Some(i))))
+                .getOrElse(g)
+            case _ =>
+              tvf
+          }
           Project(
             Seq(UnresolvedStar(Some(Seq(alias)))),
             LateralJoin(
               tableArgs.reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)),
-              LateralSubquery(SubqueryAlias(alias, tvf)), Inner, None)
+              LateralSubquery(SubqueryAlias(alias, tvfWithTableColumnIndexes)), Inner, None)
           )
         } else {
           tvf
@@ -2200,7 +2226,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       case u: UnresolvedPolymorphicPythonUDTF => withPosition(u) {
         val elementSchema = u.resolveElementSchema(u.func, u.children)
         PythonUDTF(u.name, u.func, elementSchema, u.children,
-          u.evalType, u.udfDeterministic, u.resultId)
+          u.evalType, u.udfDeterministic, u.resultId, u.pythonUDTFPartitionColumnIndexes)
      }
   }
 }
```

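The `partitioningExpressionIndexes` propagated here are the zero-based positions of the partitioning expressions among the TABLE argument's projected columns. A small illustrative Python check of how those indexes line up (it mirrors the `UDTFWithPartitions` docstring example above and is not code from this patch):

```python
# Illustrative only: plain column references reuse an existing column, while
# computed expressions such as "c2 + 4" get an extra projected column appended.
table_columns = ["c1", "c2"]
partition_exprs = ["c1", "c2 + 4"]

projected = list(table_columns)
indexes = []
for expr in partition_exprs:
    if expr in projected:
        indexes.append(projected.index(expr))
    else:
        projected.append(expr)
        indexes.append(len(projected) - 1)

assert projected == ["c1", "c2", "c2 + 4"]
assert indexes == [0, 2]  # matches "partition_child_indexes will be: 0, 2"
```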
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala (−18)
```diff
@@ -163,21 +163,3 @@ case class FunctionTableSubqueryArgumentExpression(
 
   private lazy val subqueryOutputs: Map[Expression, Int] = plan.output.zipWithIndex.toMap
 }
-
-object FunctionTableSubqueryArgumentExpression {
-  /**
-   * Returns a sequence of zero-based integer indexes identifying the values of a Python UDTF's
-   * 'eval' method's *args list that correspond to partitioning columns of the input TABLE argument.
-   */
-  def partitionChildIndexes(udtfArguments: Seq[Expression]): Seq[Int] = {
-    udtfArguments.zipWithIndex.flatMap { case (expr, index) =>
-      expr match {
-        case f: FunctionTableSubqueryArgumentExpression =>
-          f.partitioningExpressionIndexes.map(_ + index)
-        case _ =>
-          Seq()
-      }
-    }
-  }
-}
-
```
