diff --git a/src/main/scala/scala/collection/decorators/IteratorDecorator.scala b/src/main/scala/scala/collection/decorators/IteratorDecorator.scala index 6800a56..52b0192 100644 --- a/src/main/scala/scala/collection/decorators/IteratorDecorator.scala +++ b/src/main/scala/scala/collection/decorators/IteratorDecorator.scala @@ -2,6 +2,7 @@ package scala.collection package decorators import scala.annotation.tailrec +import scala.util.control.NonFatal /** Enriches Iterator with additional methods. * @@ -227,4 +228,44 @@ class IteratorDecorator[A](val `this`: Iterator[A]) extends AnyVal { } } } + + /** Gives elements from the source iterator until the source iterator ends or throws a NonFatal exception. + * + * @return an iterator that takes items until the source iterator ends or throws a NonFatal exception + * @see scala.util.control.NonFatal + * @note Reuse: $consumesAndProducesIterator + */ + def takeUntilException: Iterator[A] = { + takeUntilException(_ => ()) + } + + /** Gives elements from the source iterator until the source iterator ends or throws a NonFatal exception. + * + * @param exceptionCaught a callback invoked from `hasNext` when the source iterator throws a NonFatal exception + * @return an iterator that takes items until the wrapped iterator ends or throws a NonFatal exception + * @see scala.util.control.NonFatal + * @note Reuse: $consumesAndProducesIterator + */ + def takeUntilException(exceptionCaught: Throwable => Unit): Iterator[A] = { + new AbstractIterator[A] { + private val wrapped = `this`.buffered + + override def hasNext: Boolean = { + try { + val n = wrapped.hasNext + // By already invoking `head` (and therefore also `next` on `this`), + // we are sure that `wrapped.next` will not throw when it is used from + // `next`. + if (n) wrapped.head + n + } catch { + case NonFatal(t) => + exceptionCaught(t) + false + } + } + + override def next(): A = wrapped.next + } + } } diff --git a/src/test/scala/scala/collection/decorators/IteratorDecoratorTest.scala b/src/test/scala/scala/collection/decorators/IteratorDecoratorTest.scala index ff80e94..6a860bb 100644 --- a/src/test/scala/scala/collection/decorators/IteratorDecoratorTest.scala +++ b/src/test/scala/scala/collection/decorators/IteratorDecoratorTest.scala @@ -136,4 +136,67 @@ class IteratorDecoratorTest { Iterator.from(0).map(_ / 3).splitBy(identity).take(3).toSeq ) } + + @Test + def takeUntilExceptionShouldWrapAnyNonThrowingIterator(): Unit = { + Assert.assertEquals(Seq(1, 2, 3, 4, 5), Iterator(1, 2, 3, 4, 5).takeUntilException.toSeq) + Assert.assertEquals(Seq(1, 2, 3, 4, 5), Iterator(1, 2, 3, 4, 5).takeUntilException(_ => ()).toSeq) + Assert.assertEquals(Seq.empty, Iterator.empty.takeUntilException.toSeq) + Assert.assertEquals(Seq.empty, Iterator.empty.takeUntilException(_ => ()).toSeq) + // Works with infinite iterators: + Assert.assertEquals(Seq(1, 2, 3, 4, 5), Iterator.from(1).takeUntilException.take(5).toSeq) + Assert.assertEquals(Seq(1, 2, 3, 4, 5), Iterator.from(1).takeUntilException(_ => ()).take(5).toSeq) + } + + @Test + def takeUntilExceptionShouldTakeTillAnExceptionFromHasNext(): Unit = { + val toThrow = new RuntimeException("~expected exception~") + def brokenIterator: Iterator[Int] = new AbstractIterator[Int] { + private var previousPosition = 0 + + override def hasNext: Boolean = { + if (previousPosition == 3) { + throw toThrow + } else { + true + } + } + + override def next(): Int = { + previousPosition += 1 + previousPosition + } + } + + Assert.assertEquals(Seq(1, 2, 3), brokenIterator.takeUntilException.toSeq) + + var caught: Throwable = null + Assert.assertEquals(Seq(1, 2, 3), brokenIterator.takeUntilException(caught = _).toSeq) + Assert.assertSame(toThrow, caught) + } + + @Test + def takeUntilExceptionShouldTakeTillAnExceptionFromNext(): Unit = { + val toThrow = new RuntimeException("~expected exception~") + def brokenIterator: Iterator[Int] = new AbstractIterator[Int] { + private var previousPosition = 0 + + override def hasNext: Boolean = true + + override def next(): Int = { + if (previousPosition == 3) { + throw toThrow + } else { + previousPosition += 1 + previousPosition + } + } + } + + Assert.assertEquals(Seq(1, 2, 3), brokenIterator.takeUntilException.toSeq) + + var caught: Throwable = null + Assert.assertEquals(Seq(1, 2, 3), brokenIterator.takeUntilException(caught = _).toSeq) + Assert.assertSame(toThrow, caught) + } }