Skip to content

Commit bc3dda2

Browse files
committed
SI-6448 Collecting the spoils of PartialFun#runWith
Avoids calling both `isDefinedAt` and `apply`. The pathological case that would benefit the most looks like: xs collect { case x if {expensive(); true} => x } The typical change looks like: - for (x <- this) if (pf.isDefinedAt(x)) b += pf(x) + foreach(pf.runWith(b += _)) Incorporates feedback provided by Pavel Pavlov: retronym@ef5430 A few more opportunities for optimization are noted in the `PendingTests` section of the enclosed test. `Iterator.collect` would be nice, but a solution eludes me. Calling the guard less frequently does change the behaviour of these functions in an observable way, but does not contravene the documented semantics. That said, there is an alternative opinion in the comments of the ticket: https://issues.scala-lang.org/browse/SI-6448
1 parent 2c6777f commit bc3dda2

File tree

8 files changed

+110
-12
lines changed

8 files changed

+110
-12
lines changed

src/library/scala/Option.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ sealed abstract class Option[+A] extends Product with Serializable {
256256
* value (if possible), or $none.
257257
*/
258258
@inline final def collect[B](pf: PartialFunction[A, B]): Option[B] =
259-
if (!isEmpty && pf.isDefinedAt(this.get)) Some(pf(this.get)) else None
259+
if (!isEmpty) pf.lift(this.get) else None
260260

261261
/** Returns this $option if it is nonempty,
262262
* otherwise return the result of evaluating `alternative`.

src/library/scala/collection/TraversableLike.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ trait TraversableLike[+A, +Repr] extends Any
275275

276276
def collect[B, That](pf: PartialFunction[A, B])(implicit bf: CanBuildFrom[Repr, B, That]): That = {
277277
val b = bf(repr)
278-
for (x <- this) if (pf.isDefinedAt(x)) b += pf(x)
278+
foreach(pf.runWith(b += _))
279279
b.result
280280
}
281281

src/library/scala/collection/TraversableOnce.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,8 @@ trait TraversableOnce[+A] extends Any with GenTraversableOnce[A] {
128128
* @example `Seq("a", 1, 5L).collectFirst({ case x: Int => x*10 }) = Some(10)`
129129
*/
130130
def collectFirst[B](pf: PartialFunction[A, B]): Option[B] = {
131-
for (x <- self.toIterator) { // make sure to use an iterator or `seq`
132-
if (pf isDefinedAt x)
133-
return Some(pf(x))
134-
}
131+
// make sure to use an iterator or `seq`
132+
self.toIterator.foreach(pf.runWith(b => return Some(b)))
135133
None
136134
}
137135

src/library/scala/collection/immutable/Stream.scala

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -385,12 +385,17 @@ self =>
385385
// 1) stackoverflows (could be achieved with tailrec, too)
386386
// 2) out of memory errors for big streams (`this` reference can be eliminated from the stack)
387387
var rest: Stream[A] = this
388-
while (rest.nonEmpty && !pf.isDefinedAt(rest.head)) rest = rest.tail
388+
389+
// Avoids calling both `pf.isDefinedAt` and `pf.apply`.
390+
var newHead: B = null.asInstanceOf[B]
391+
val runWith = pf.runWith((b: B) => newHead = b)
392+
393+
while (rest.nonEmpty && !runWith(rest.head)) rest = rest.tail
389394

390395
// without the call to the companion object, a thunk is created for the tail of the new stream,
391396
// and the closure of the thunk will reference `this`
392397
if (rest.isEmpty) Stream.Empty.asInstanceOf[That]
393-
else Stream.collectedTail(rest, pf, bf).asInstanceOf[That]
398+
else Stream.collectedTail(newHead, rest, pf, bf).asInstanceOf[That]
394399
}
395400
}
396401

@@ -1170,8 +1175,8 @@ object Stream extends SeqFactory[Stream] {
11701175
cons(stream.head, stream.tail filter p)
11711176
}
11721177

1173-
private[immutable] def collectedTail[A, B, That](stream: Stream[A], pf: PartialFunction[A, B], bf: CanBuildFrom[Stream[A], B, That]) = {
1174-
cons(pf(stream.head), stream.tail.collect(pf)(bf).asInstanceOf[Stream[B]])
1178+
private[immutable] def collectedTail[A, B, That](head: B, stream: Stream[A], pf: PartialFunction[A, B], bf: CanBuildFrom[Stream[A], B, That]) = {
1179+
cons(head, stream.tail.collect(pf)(bf).asInstanceOf[Stream[B]])
11751180
}
11761181
}
11771182

src/library/scala/collection/parallel/RemainsIterator.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,10 @@ private[collection] trait AugmentedIterableIterator[+T] extends RemainsIterator[
123123

124124
def collect2combiner[S, That](pf: PartialFunction[T, S], cb: Combiner[S, That]): Combiner[S, That] = {
125125
//val cb = pbf(repr)
126+
val runWith = pf.runWith(cb += _)
126127
while (hasNext) {
127128
val curr = next
128-
if (pf.isDefinedAt(curr)) cb += pf(curr)
129+
runWith(curr)
129130
}
130131
cb
131132
}

src/library/scala/collection/parallel/mutable/ParArray.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,9 +405,10 @@ self =>
405405

406406
private def collect2combiner_quick[S, That](pf: PartialFunction[T, S], a: Array[Any], cb: Builder[S, That], ntil: Int, from: Int) {
407407
var j = from
408+
val runWith = pf.runWith(b => cb += b)
408409
while (j < ntil) {
409410
val curr = a(j).asInstanceOf[T]
410-
if (pf.isDefinedAt(curr)) cb += pf(curr)
411+
runWith(curr)
411412
j += 1
412413
}
413414
}

test/files/run/t6448.check

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
2+
=List.collect=
3+
f(1)
4+
f(2)
5+
List(1)
6+
7+
=List.collectFirst=
8+
f(1)
9+
Some(1)
10+
11+
=Option.collect=
12+
f(1)
13+
Some(1)
14+
15+
=Option.collect=
16+
f(2)
17+
None
18+
19+
=Stream.collect=
20+
f(1)
21+
f(2)
22+
List(1)
23+
24+
=Stream.collectFirst=
25+
f(1)
26+
Some(1)
27+
28+
=ParVector.collect=
29+
(ParVector(1),2)
30+
31+
=ParArray.collect=
32+
(ParArray(1),2)

test/files/run/t6448.scala

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Tests to show that various `collect` functions avoid calling
2+
// both `PartialFunction#isDefinedAt` and `PartialFunction#apply`.
3+
//
4+
object Test {
5+
def f(i: Int) = { println("f(" + i + ")"); true }
6+
class Counter {
7+
var count = 0
8+
def apply(i: Int) = synchronized {count += 1; true}
9+
}
10+
11+
def testing(label: String)(body: => Any) {
12+
println(s"\n=$label=")
13+
println(body)
14+
}
15+
16+
def main(args: Array[String]) {
17+
testing("List.collect")(List(1, 2) collect { case x if f(x) && x < 2 => x})
18+
testing("List.collectFirst")(List(1, 2) collectFirst { case x if f(x) && x < 2 => x})
19+
testing("Option.collect")(Some(1) collect { case x if f(x) && x < 2 => x})
20+
testing("Option.collect")(Some(2) collect { case x if f(x) && x < 2 => x})
21+
testing("Stream.collect")((Stream(1, 2).collect { case x if f(x) && x < 2 => x}).toList)
22+
testing("Stream.collectFirst")(Stream.continually(1) collectFirst { case x if f(x) && x < 2 => x})
23+
24+
import collection.parallel.ParIterable
25+
import collection.parallel.immutable.ParVector
26+
import collection.parallel.mutable.ParArray
27+
testing("ParVector.collect") {
28+
val counter = new Counter()
29+
(ParVector(1, 2) collect { case x if counter(x) && x < 2 => x}, counter.synchronized(counter.count))
30+
}
31+
32+
testing("ParArray.collect") {
33+
val counter = new Counter()
34+
(ParArray(1, 2) collect { case x if counter(x) && x < 2 => x}, counter.synchronized(counter.count))
35+
}
36+
37+
object PendingTests {
38+
testing("Iterator.collect")((Iterator(1, 2) collect { case x if f(x) && x < 2 => x}).toList)
39+
40+
testing("List.view.collect")((List(1, 2).view collect { case x if f(x) && x < 2 => x}).force)
41+
42+
// This would do the trick in Future.collect, but I haven't added this yet as there is a tradeoff
43+
// with extra allocations to consider.
44+
//
45+
// pf.lift(v) match {
46+
// case Some(x) => p success x
47+
// case None => fail(v)
48+
// }
49+
testing("Future.collect") {
50+
import concurrent.ExecutionContext.Implicits.global
51+
import concurrent.Await
52+
import concurrent.duration.Duration
53+
val result = concurrent.future(1) collect { case x if f(x) => x}
54+
Await.result(result, Duration.Inf)
55+
}
56+
57+
// TODO Future.{onSuccess, onFailure, recoverWith, andThen}
58+
}
59+
60+
}
61+
}

0 commit comments

Comments
 (0)