Skip to content

Commit 889cead

Browse files
committed
Merge pull request scala#1570 from retronym/ticket/6448
SI-6448 Collecting the spoils of PartialFun#runWith
2 parents 1cfb363 + bc3dda2 commit 889cead

File tree

8 files changed

+110
-12
lines changed

8 files changed

+110
-12
lines changed

src/library/scala/Option.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ sealed abstract class Option[+A] extends Product with Serializable {
256256
* value (if possible), or $none.
257257
*/
258258
@inline final def collect[B](pf: PartialFunction[A, B]): Option[B] =
259-
if (!isEmpty && pf.isDefinedAt(this.get)) Some(pf(this.get)) else None
259+
if (!isEmpty) pf.lift(this.get) else None
260260

261261
/** Returns this $option if it is nonempty,
262262
* otherwise return the result of evaluating `alternative`.

src/library/scala/collection/TraversableLike.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ trait TraversableLike[+A, +Repr] extends Any
275275

276276
def collect[B, That](pf: PartialFunction[A, B])(implicit bf: CanBuildFrom[Repr, B, That]): That = {
277277
val b = bf(repr)
278-
for (x <- this) if (pf.isDefinedAt(x)) b += pf(x)
278+
foreach(pf.runWith(b += _))
279279
b.result
280280
}
281281

src/library/scala/collection/TraversableOnce.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,8 @@ trait TraversableOnce[+A] extends Any with GenTraversableOnce[A] {
128128
* @example `Seq("a", 1, 5L).collectFirst({ case x: Int => x*10 }) = Some(10)`
129129
*/
130130
def collectFirst[B](pf: PartialFunction[A, B]): Option[B] = {
131-
for (x <- self.toIterator) { // make sure to use an iterator or `seq`
132-
if (pf isDefinedAt x)
133-
return Some(pf(x))
134-
}
131+
// make sure to use an iterator or `seq`
132+
self.toIterator.foreach(pf.runWith(b => return Some(b)))
135133
None
136134
}
137135

src/library/scala/collection/immutable/Stream.scala

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -385,12 +385,17 @@ self =>
385385
// 1) stackoverflows (could be achieved with tailrec, too)
386386
// 2) out of memory errors for big streams (`this` reference can be eliminated from the stack)
387387
var rest: Stream[A] = this
388-
while (rest.nonEmpty && !pf.isDefinedAt(rest.head)) rest = rest.tail
388+
389+
// Avoids calling both `pf.isDefined` and `pf.apply`.
390+
var newHead: B = null.asInstanceOf[B]
391+
val runWith = pf.runWith((b: B) => newHead = b)
392+
393+
while (rest.nonEmpty && !runWith(rest.head)) rest = rest.tail
389394

390395
// without the call to the companion object, a thunk is created for the tail of the new stream,
391396
// and the closure of the thunk will reference `this`
392397
if (rest.isEmpty) Stream.Empty.asInstanceOf[That]
393-
else Stream.collectedTail(rest, pf, bf).asInstanceOf[That]
398+
else Stream.collectedTail(newHead, rest, pf, bf).asInstanceOf[That]
394399
}
395400
}
396401

@@ -1170,8 +1175,8 @@ object Stream extends SeqFactory[Stream] {
11701175
cons(stream.head, stream.tail filter p)
11711176
}
11721177

1173-
private[immutable] def collectedTail[A, B, That](stream: Stream[A], pf: PartialFunction[A, B], bf: CanBuildFrom[Stream[A], B, That]) = {
1174-
cons(pf(stream.head), stream.tail.collect(pf)(bf).asInstanceOf[Stream[B]])
1178+
private[immutable] def collectedTail[A, B, That](head: B, stream: Stream[A], pf: PartialFunction[A, B], bf: CanBuildFrom[Stream[A], B, That]) = {
1179+
cons(head, stream.tail.collect(pf)(bf).asInstanceOf[Stream[B]])
11751180
}
11761181
}
11771182

src/library/scala/collection/parallel/RemainsIterator.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,10 @@ private[collection] trait AugmentedIterableIterator[+T] extends RemainsIterator[
123123

124124
def collect2combiner[S, That](pf: PartialFunction[T, S], cb: Combiner[S, That]): Combiner[S, That] = {
125125
//val cb = pbf(repr)
126+
val runWith = pf.runWith(cb += _)
126127
while (hasNext) {
127128
val curr = next
128-
if (pf.isDefinedAt(curr)) cb += pf(curr)
129+
runWith(curr)
129130
}
130131
cb
131132
}

src/library/scala/collection/parallel/mutable/ParArray.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,9 +405,10 @@ self =>
405405

406406
private def collect2combiner_quick[S, That](pf: PartialFunction[T, S], a: Array[Any], cb: Builder[S, That], ntil: Int, from: Int) {
407407
var j = from
408+
val runWith = pf.runWith(b => cb += b)
408409
while (j < ntil) {
409410
val curr = a(j).asInstanceOf[T]
410-
if (pf.isDefinedAt(curr)) cb += pf(curr)
411+
runWith(curr)
411412
j += 1
412413
}
413414
}

test/files/run/t6448.check

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
2+
=List.collect=
3+
f(1)
4+
f(2)
5+
List(1)
6+
7+
=List.collectFirst=
8+
f(1)
9+
Some(1)
10+
11+
=Option.collect=
12+
f(1)
13+
Some(1)
14+
15+
=Option.collect=
16+
f(2)
17+
None
18+
19+
=Stream.collect=
20+
f(1)
21+
f(2)
22+
List(1)
23+
24+
=Stream.collectFirst=
25+
f(1)
26+
Some(1)
27+
28+
=ParVector.collect=
29+
(ParVector(1),2)
30+
31+
=ParArray.collect=
32+
(ParArray(1),2)

test/files/run/t6448.scala

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Tests to show that various `collect` functions avoid calling
2+
// both `PartialFunction#isDefinedAt` and `PartialFunction#apply`.
3+
//
4+
object Test {
5+
def f(i: Int) = { println("f(" + i + ")"); true }
6+
class Counter {
7+
var count = 0
8+
def apply(i: Int) = synchronized {count += 1; true}
9+
}
10+
11+
def testing(label: String)(body: => Any) {
12+
println(s"\n=$label=")
13+
println(body)
14+
}
15+
16+
def main(args: Array[String]) {
17+
testing("List.collect")(List(1, 2) collect { case x if f(x) && x < 2 => x})
18+
testing("List.collectFirst")(List(1, 2) collectFirst { case x if f(x) && x < 2 => x})
19+
testing("Option.collect")(Some(1) collect { case x if f(x) && x < 2 => x})
20+
testing("Option.collect")(Some(2) collect { case x if f(x) && x < 2 => x})
21+
testing("Stream.collect")((Stream(1, 2).collect { case x if f(x) && x < 2 => x}).toList)
22+
testing("Stream.collectFirst")(Stream.continually(1) collectFirst { case x if f(x) && x < 2 => x})
23+
24+
import collection.parallel.ParIterable
25+
import collection.parallel.immutable.ParVector
26+
import collection.parallel.mutable.ParArray
27+
testing("ParVector.collect") {
28+
val counter = new Counter()
29+
(ParVector(1, 2) collect { case x if counter(x) && x < 2 => x}, counter.synchronized(counter.count))
30+
}
31+
32+
testing("ParArray.collect") {
33+
val counter = new Counter()
34+
(ParArray(1, 2) collect { case x if counter(x) && x < 2 => x}, counter.synchronized(counter.count))
35+
}
36+
37+
object PendingTests {
38+
testing("Iterator.collect")((Iterator(1, 2) collect { case x if f(x) && x < 2 => x}).toList)
39+
40+
testing("List.view.collect")((List(1, 2).view collect { case x if f(x) && x < 2 => x}).force)
41+
42+
// This would do the trick in Future.collect, but I haven't added this yet as there is a tradeoff
43+
// with extra allocations to consider.
44+
//
45+
// pf.lift(v) match {
46+
// case Some(x) => p success x
47+
// case None => fail(v)
48+
// }
49+
testing("Future.collect") {
50+
import concurrent.ExecutionContext.Implicits.global
51+
import concurrent.Await
52+
import concurrent.duration.Duration
53+
val result = concurrent.future(1) collect { case x if f(x) => x}
54+
Await.result(result, Duration.Inf)
55+
}
56+
57+
// TODO Future.{onSuccess, onFailure, recoverWith, andThen}
58+
}
59+
60+
}
61+
}

0 commit comments

Comments
 (0)