Skip to content

Commit ad5a950

Browse files
erikvanoostenjulienrf
authored andcommitted
Iterator.takeUntilException (#62)
1 parent 590098f commit ad5a950

File tree

2 files changed

+104
-0
lines changed

2 files changed

+104
-0
lines changed

src/main/scala/scala/collection/decorators/IteratorDecorator.scala

+41
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package scala.collection
22
package decorators
33

44
import scala.annotation.tailrec
5+
import scala.util.control.NonFatal
56

67
/** Enriches Iterator with additional methods.
78
*
@@ -227,4 +228,44 @@ class IteratorDecorator[A](val `this`: Iterator[A]) extends AnyVal {
227228
}
228229
}
229230
}
231+
232+
/** Gives elements from the source iterator until the source iterator ends or throws a NonFatal exception.
233+
*
234+
* @return an iterator that takes items until the source iterator ends or throws a NonFatal exception
235+
* @see scala.util.control.NonFatal
236+
* @note Reuse: $consumesAndProducesIterator
237+
*/
238+
def takeUntilException: Iterator[A] = {
239+
takeUntilException(_ => ())
240+
}
241+
242+
/** Gives elements from the source iterator until the source iterator ends or throws a NonFatal exception.
243+
*
244+
* @param exceptionCaught a callback invoked from `hasNext` when the source iterator throws a NonFatal exception
245+
* @return an iterator that takes items until the wrapped iterator ends or throws a NonFatal exception
246+
* @see scala.util.control.NonFatal
247+
* @note Reuse: $consumesAndProducesIterator
248+
*/
249+
def takeUntilException(exceptionCaught: Throwable => Unit): Iterator[A] = {
250+
new AbstractIterator[A] {
251+
private val wrapped = `this`.buffered
252+
253+
override def hasNext: Boolean = {
254+
try {
255+
val n = wrapped.hasNext
256+
// By already invoking `head` (and therefore also `next` on `this`),
257+
// we are sure that `wrapped.next` will not throw when it is used from
258+
// `next`.
259+
if (n) wrapped.head
260+
n
261+
} catch {
262+
case NonFatal(t) =>
263+
exceptionCaught(t)
264+
false
265+
}
266+
}
267+
268+
override def next(): A = wrapped.next
269+
}
270+
}
230271
}

src/test/scala/scala/collection/decorators/IteratorDecoratorTest.scala

+63
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,67 @@ class IteratorDecoratorTest {
136136
Iterator.from(0).map(_ / 3).splitBy(identity).take(3).toSeq
137137
)
138138
}
139+
140+
@Test
141+
def takeUntilExceptionShouldWrapAnyNonThrowingIterator(): Unit = {
142+
Assert.assertEquals(Seq(1, 2, 3, 4, 5), Iterator(1, 2, 3, 4, 5).takeUntilException.toSeq)
143+
Assert.assertEquals(Seq(1, 2, 3, 4, 5), Iterator(1, 2, 3, 4, 5).takeUntilException(_ => ()).toSeq)
144+
Assert.assertEquals(Seq.empty, Iterator.empty.takeUntilException.toSeq)
145+
Assert.assertEquals(Seq.empty, Iterator.empty.takeUntilException(_ => ()).toSeq)
146+
// Works with infinite iterators:
147+
Assert.assertEquals(Seq(1, 2, 3, 4, 5), Iterator.from(1).takeUntilException.take(5).toSeq)
148+
Assert.assertEquals(Seq(1, 2, 3, 4, 5), Iterator.from(1).takeUntilException(_ => ()).take(5).toSeq)
149+
}
150+
151+
@Test
152+
def takeUntilExceptionShouldTakeTillAnExceptionFromHasNext(): Unit = {
153+
val toThrow = new RuntimeException("~expected exception~")
154+
def brokenIterator: Iterator[Int] = new AbstractIterator[Int] {
155+
private var previousPosition = 0
156+
157+
override def hasNext: Boolean = {
158+
if (previousPosition == 3) {
159+
throw toThrow
160+
} else {
161+
true
162+
}
163+
}
164+
165+
override def next(): Int = {
166+
previousPosition += 1
167+
previousPosition
168+
}
169+
}
170+
171+
Assert.assertEquals(Seq(1, 2, 3), brokenIterator.takeUntilException.toSeq)
172+
173+
var caught: Throwable = null
174+
Assert.assertEquals(Seq(1, 2, 3), brokenIterator.takeUntilException(caught = _).toSeq)
175+
Assert.assertSame(toThrow, caught)
176+
}
177+
178+
@Test
179+
def takeUntilExceptionShouldTakeTillAnExceptionFromNext(): Unit = {
180+
val toThrow = new RuntimeException("~expected exception~")
181+
def brokenIterator: Iterator[Int] = new AbstractIterator[Int] {
182+
private var previousPosition = 0
183+
184+
override def hasNext: Boolean = true
185+
186+
override def next(): Int = {
187+
if (previousPosition == 3) {
188+
throw toThrow
189+
} else {
190+
previousPosition += 1
191+
previousPosition
192+
}
193+
}
194+
}
195+
196+
Assert.assertEquals(Seq(1, 2, 3), brokenIterator.takeUntilException.toSeq)
197+
198+
var caught: Throwable = null
199+
Assert.assertEquals(Seq(1, 2, 3), brokenIterator.takeUntilException(caught = _).toSeq)
200+
Assert.assertSame(toThrow, caught)
201+
}
139202
}

0 commit comments

Comments
 (0)