
Commit 6f8c620

alkispoly-db authored and srowen committed
[SPARK-35558] Optimizes for multi-quantile retrieval
[SPARK-35558] Optimizes for multi-quantile retrieval

### What changes were proposed in this pull request?
Optimizes the retrieval of approximate quantiles for an array of percentiles.
* Adds an overload for QuantileSummaries.query that accepts an array of percentiles and optimizes the computation to do a single pass over the sketch, avoiding redundant computation.
* Modifies the ApproximatePercentile operator to call into the new method.

All formatting changes are the result of running ./dev/scalafmt.

### Why are the changes needed?
The existing implementation makes a separate call per input percentile, resulting in redundant computation.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added unit tests for the new method.

Closes #32700 from alkispoly-db/spark_35558_approx_quants_array.

Authored-by: Alkis Polyzotis <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
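As a quick illustration of the new call pattern, a minimal sketch (not code from this commit; the constructor arguments and the insert/compress calls are assumptions about the existing QuantileSummaries API):

```scala
import org.apache.spark.sql.catalyst.util.QuantileSummaries

// Build and compress a summary over 10,000 values.
var summary = new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, 0.01)
(1 to 10000).foreach(v => summary = summary.insert(v.toDouble))
summary = summary.compress()

// Old pattern: one scan of the compressed sketch per requested percentile.
val perCall: Seq[Double] = Seq(0.25, 0.5, 0.75).map(p => summary.query(p).get)

// New overload: all percentiles answered in a single pass over the sketch.
val singlePass: Seq[Double] = summary.query(Seq(0.25, 0.5, 0.75)).get
```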
1 parent 510bde4 commit 6f8c620

5 files changed: +149 -57 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala

+2 -9

@@ -261,19 +261,12 @@ object ApproximatePercentile {
      * val Array(p25, median, p75) = percentileDigest.getPercentiles(Array(0.25, 0.5, 0.75))
      * }}}
      */
-    def getPercentiles(percentages: Array[Double]): Array[Double] = {
+    def getPercentiles(percentages: Array[Double]): Seq[Double] = {
       if (!isCompressed) compress()
       if (summaries.count == 0 || percentages.length == 0) {
         Array.emptyDoubleArray
       } else {
-        val result = new Array[Double](percentages.length)
-        var i = 0
-        while (i < percentages.length) {
-          // Since summaries.count != 0, the query here never return None.
-          result(i) = summaries.query(percentages(i)).get
-          i += 1
-        }
-        result
+        summaries.query(percentages).get
       }
     }
 
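A hedged call-site sketch of the changed return type (the relative-error constructor and add() are assumptions about the existing PercentileDigest API, not part of this diff):

```scala
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest

// Accumulate values into a digest; getPercentiles compresses internally.
val digest = new PercentileDigest(0.01)
(1 to 1000).foreach(v => digest.add(v.toDouble))

// getPercentiles now returns Seq[Double] instead of Array[Double],
// so a Seq extractor destructures the result.
val Seq(p25, median, p75) = digest.getPercentiles(Array(0.25, 0.5, 0.75))
```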

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala

+80 -27

@@ -229,46 +229,99 @@ class QuantileSummaries(
   }
 
   /**
-   * Runs a query for a given quantile.
+   * Finds the approximate quantile for a percentile, starting at a specific index in the summary.
+   * This is a helper method that is called as we are making a pass over the summary and a sorted
+   * sequence of input percentiles.
+   *
+   * @param index The point at which to start scanning the summary for an approximate value.
+   * @param minRankAtIndex The accumulated minimum rank at the given index.
+   * @param targetError Target error from the summary.
+   * @param percentile The percentile whose value is computed.
+   * @return A tuple (i, r, a) where: i is the updated index for the next call, r is the updated
+   *         rank at i, and a is the approximate quantile.
+   */
+  private def findApproxQuantile(
+      index: Int,
+      minRankAtIndex: Long,
+      targetError: Double,
+      percentile: Double): (Int, Long, Double) = {
+    var curSample = sampled(index)
+    val rank = math.ceil(percentile * count).toLong
+    var i = index
+    var minRank = minRankAtIndex
+    while (i < sampled.length - 1) {
+      val maxRank = minRank + curSample.delta
+      if (maxRank - targetError <= rank && rank <= minRank + targetError) {
+        return (i, minRank, curSample.value)
+      } else {
+        i += 1
+        curSample = sampled(i)
+        minRank += curSample.g
+      }
+    }
+    (sampled.length - 1, 0, sampled.last.value)
+  }
+
+  /**
+   * Runs a query for a given sequence of percentiles.
    * The result follows the approximation guarantees detailed above.
    * The query can only be run on a compressed summary: you need to call compress() before using
    * it.
    *
-   * @param quantile the target quantile
-   * @return
+   * @param percentiles the target percentiles
+   * @return the corresponding approximate quantiles, in the same order as the input
    */
-  def query(quantile: Double): Option[Double] = {
-    require(quantile >= 0 && quantile <= 1.0, "quantile should be in the range [0.0, 1.0]")
-    require(headSampled.isEmpty,
+  def query(percentiles: Seq[Double]): Option[Seq[Double]] = {
+    percentiles.foreach(p =>
+      require(p >= 0 && p <= 1.0, "percentile should be in the range [0.0, 1.0]"))
+    require(
+      headSampled.isEmpty,
       "Cannot operate on an uncompressed summary, call compress() first")
 
     if (sampled.isEmpty) return None
 
-    if (quantile <= relativeError) {
-      return Some(sampled.head.value)
-    }
+    val targetError = sampled.foldLeft(Long.MinValue)((currentMax, stats) =>
+      currentMax.max(stats.delta + stats.g)) / 2
 
-    if (quantile >= 1 - relativeError) {
-      return Some(sampled.last.value)
-    }
-
-    // Target rank
-    val rank = math.ceil(quantile * count).toLong
-    val targetError = sampled.map(s => s.delta + s.g).max / 2
+    // Index to track the current sample
+    var index = 0
     // Minimum rank at current sample
-    var minRank = 0L
-    var i = 0
-    while (i < sampled.length - 1) {
-      val curSample = sampled(i)
-      minRank += curSample.g
-      val maxRank = minRank + curSample.delta
-      if (maxRank - targetError <= rank && rank <= minRank + targetError) {
-        return Some(curSample.value)
-      }
-      i += 1
+    var minRank = sampled(0).g
+
+    val sortedPercentiles = percentiles.zipWithIndex.sortBy(_._1)
+    val result = Array.fill(percentiles.length)(0.0)
+    sortedPercentiles.foreach {
+      case (percentile, pos) =>
+        if (percentile <= relativeError) {
+          result(pos) = sampled.head.value
+        } else if (percentile >= 1 - relativeError) {
+          result(pos) = sampled.last.value
+        } else {
+          val (newIndex, newMinRank, approxQuantile) =
+            findApproxQuantile(index, minRank, targetError, percentile)
+          index = newIndex
+          minRank = newMinRank
+          result(pos) = approxQuantile
+        }
     }
-    Some(sampled.last.value)
+    Some(result)
   }
+
+  /**
+   * Runs a query for a given percentile.
+   * The result follows the approximation guarantees detailed above.
+   * The query can only be run on a compressed summary: you need to call compress() before using
+   * it.
+   *
+   * @param percentile the target percentile
+   * @return the corresponding approximate quantile
+   */
+  def query(percentile: Double): Option[Double] =
+    query(Seq(percentile)) match {
+      case Some(approxSeq) if approxSeq.nonEmpty => Some(approxSeq.head)
+      case _ => None
+    }
+
 }
 
 object QuantileSummaries {
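The heart of the new query(Seq[Double]) overload is its traversal pattern: sort the requested percentiles while remembering their original positions, walk the compressed sketch once from left to right, and write each answer back into its original slot. A standalone sketch of that pattern over a plain sorted array (exact quantiles rather than the GK approximation; none of these names come from the commit):

```scala
// Answer many quantile requests with one forward scan over sorted data.
def quantilesSinglePass(sorted: IndexedSeq[Double], percentiles: Seq[Double]): Seq[Double] = {
  val result = Array.fill(percentiles.length)(0.0)
  var index = 0
  // Sort the percentiles, but remember where each request came from.
  percentiles.zipWithIndex.sortBy(_._1).foreach { case (p, pos) =>
    val targetRank = math.ceil(p * sorted.length).toInt.max(1)
    // Requests arrive in increasing order, so the scan never moves backwards:
    // it simply resumes wherever the previous request stopped.
    while (index < sorted.length - 1 && index + 1 < targetRank) {
      index += 1
    }
    result(pos) = sorted(index)
  }
  result.toSeq
}

// Quartiles of 1..100, requested out of order, answered in the original order.
quantilesSinglePass((1 to 100).map(_.toDouble), Seq(0.75, 0.25, 0.5))
// => Seq(75.0, 25.0, 50.0)
```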

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala

+60 -19

@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.util
 
+import scala.collection.mutable.ArrayBuffer
 import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
@@ -54,25 +55,51 @@ class QuantileSummariesSuite extends SparkFunSuite {
     summary
   }
 
-  private def checkQuantile(quant: Double, data: Seq[Double], summary: QuantileSummaries): Unit = {
+  private def validateQuantileApproximation(
+      approx: Double,
+      percentile: Double,
+      data: Seq[Double],
+      summary: QuantileSummaries): Unit = {
+    assert(data.nonEmpty)
+
+    val rankOfValue = data.count(_ <= approx)
+    val rankOfPreValue = data.count(_ < approx)
+    // `rankOfValue` is the last position of the quantile value. If the input repeats the value
+    // chosen as the quantile, e.g. in (1,2,2,2,2,2,3), the 50% quantile is 2, then it's
+    // improper to choose the last position as its rank. Instead, we get the rank by averaging
+    // `rankOfValue` and `rankOfPreValue`.
+    val rank = math.ceil((rankOfValue + rankOfPreValue) / 2.0)
+    val lower = math.floor((percentile - summary.relativeError) * data.size)
+    val upper = math.ceil((percentile + summary.relativeError) * data.size)
+    val msg =
+      s"$rank not in [$lower $upper], requested percentile: $percentile, approx returned: $approx"
+    assert(rank >= lower, msg)
+    assert(rank <= upper, msg)
+  }
+
+  private def checkQuantile(
+      percentile: Double,
+      data: Seq[Double],
+      summary: QuantileSummaries): Unit = {
     if (data.nonEmpty) {
-      val approx = summary.query(quant).get
-      // Get the rank of the approximation.
-      val rankOfValue = data.count(_ <= approx)
-      val rankOfPreValue = data.count(_ < approx)
-      // `rankOfValue` is the last position of the quantile value. If the input repeats the value
-      // chosen as the quantile, e.g. in (1,2,2,2,2,2,3), the 50% quantile is 2, then it's
-      // improper to choose the last position as its rank. Instead, we get the rank by averaging
-      // `rankOfValue` and `rankOfPreValue`.
-      val rank = math.ceil((rankOfValue + rankOfPreValue) / 2.0)
-      val lower = math.floor((quant - summary.relativeError) * data.size)
-      val upper = math.ceil((quant + summary.relativeError) * data.size)
-      val msg =
-        s"$rank not in [$lower $upper], requested quantile: $quant, approx returned: $approx"
-      assert(rank >= lower, msg)
-      assert(rank <= upper, msg)
+      val approx = summary.query(percentile).get
+      validateQuantileApproximation(approx, percentile, data, summary)
+    } else {
+      assert(summary.query(percentile).isEmpty)
+    }
+  }
+
+  private def checkQuantiles(
+      percentiles: Seq[Double],
+      data: Seq[Double],
+      summary: QuantileSummaries): Unit = {
+    if (data.nonEmpty) {
+      val approx = summary.query(percentiles).get
+      for ((q, a) <- percentiles zip approx) {
+        validateQuantileApproximation(a, q, data, summary)
+      }
     } else {
-      assert(summary.query(quant).isEmpty)
+      assert(summary.query(percentiles).isEmpty)
     }
   }
 
@@ -98,6 +125,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
       checkQuantile(0.5, data, s)
       checkQuantile(0.1, data, s)
       checkQuantile(0.001, data, s)
+      checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
+      checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
     }
 
     test(s"Some quantile values with epsi=$epsi and seq=$seq_name, compression=$compression " +
@@ -109,6 +138,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
      checkQuantile(0.5, data, s)
      checkQuantile(0.1, data, s)
      checkQuantile(0.001, data, s)
+     checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
+     checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
     }
 
     test(s"Tests on empty data with epsi=$epsi and seq=$seq_name, compression=$compression") {
@@ -121,6 +152,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
       checkQuantile(0.5, emptyData, s)
       checkQuantile(0.1, emptyData, s)
       checkQuantile(0.001, emptyData, s)
+      checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), emptyData, s)
+      checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), emptyData, s)
     }
   }
 
@@ -149,6 +182,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
       checkQuantile(0.5, data, s)
       checkQuantile(0.1, data, s)
       checkQuantile(0.001, data, s)
+      checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
+      checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
     }
 
   val (data11, data12) = {
@@ -168,6 +203,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
       checkQuantile(0.5, data, s)
       checkQuantile(0.1, data, s)
      checkQuantile(0.001, data, s)
+     checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
+     checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
     }
 
    // length of data21 is 4 * length of data22
@@ -181,10 +218,14 @@ class QuantileSummariesSuite extends SparkFunSuite {
       val s2 = buildSummary(data22, epsi, compression)
       val s = s1.merge(s2)
       // Check all quantiles
+      val percentiles = ArrayBuffer[Double]()
       for (queryRank <- 1 to n) {
-        val queryQuantile = queryRank.toDouble / n.toDouble
-        checkQuantile(queryQuantile, data, s)
+        val percentile = queryRank.toDouble / n.toDouble
+        checkQuantile(percentile, data, s)
+        percentiles += percentile
       }
+      checkQuantiles(percentiles.toSeq, data, s)
+      checkQuantiles(percentiles.reverse.toSeq, data, s)
     }
   }
 }
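To make the acceptance bound in validateQuantileApproximation concrete, a worked example with illustrative numbers (not taken from the test data): for 1,000 data points, a requested percentile of 0.5 and relativeError = 0.01, the averaged rank of the returned value must land in [floor((0.5 - 0.01) * 1000), ceil((0.5 + 0.01) * 1000)] = [490, 510]. The two checkQuantiles calls then assert this for every percentile in the sequence, regardless of the order in which the percentiles were requested.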

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

+6 -1

@@ -102,7 +102,12 @@ object StatFunctions extends Logging {
     }
     val summaries = df.select(columns: _*).rdd.treeAggregate(emptySummaries)(apply, merge)
 
-    summaries.map { summary => probabilities.flatMap(summary.query) }
+    summaries.map {
+      summary => summary.query(probabilities) match {
+        case Some(q) => q
+        case None => Seq()
+      }
+    }
   }
 
   /** Calculate the Pearson Correlation Coefficient for the given columns */
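This helper is what the public DataFrame API ends up calling; a hedged end-to-end sketch (the session setup and column name are illustrative, not part of this commit):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("approx-quantiles").getOrCreate()
import spark.implicits._

val df = (1 to 10000).map(_.toDouble).toDF("value")

// approxQuantile funnels into the StatFunctions code above, which now answers
// all requested probabilities with one query over each column's sketch.
val Array(p25, median, p75) =
  df.stat.approxQuantile("value", Array(0.25, 0.5, 0.75), 0.01)

spark.stop()
```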

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

+1 -1

@@ -204,7 +204,7 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession {
     val e = intercept[IllegalArgumentException] {
       df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2, -0.1), epsilons.head)
     }
-    assert(e.getMessage.contains("quantile should be in the range [0.0, 1.0]"))
+    assert(e.getMessage.contains("percentile should be in the range [0.0, 1.0]"))
 
     // relativeError should be non-negative
     val e2 = intercept[IllegalArgumentException] {
