Skip to content

Commit 68161af

Browse files
committed
Add size comparison methods to IterableOps
Add IterableOps.sizeCompare(Int), .sizeCompare(Iterable[_]), and .sizeIs
1 parent 4f0908f commit 68161af

File tree

5 files changed

+150
-26
lines changed

5 files changed

+150
-26
lines changed

src/library/scala/collection/IndexedSeq.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,16 @@ trait IndexedSeqOps[+A, +CC[_], +C] extends Any with SeqOps[A, CC, C] { self =>
5252

5353
override def last: A = apply(length - 1)
5454

55-
override def lengthCompare(len: Int): Int = Integer.compare(length, len)
55+
override final def lengthCompare(len: Int): Int = Integer.compare(length, len)
5656

5757
final override def knownSize: Int = length
5858

59+
override final def sizeCompare(that: Iterable[_]): Int = {
60+
val res = that.sizeCompare(length)
61+
// can't just invert the result, because `-Int.MinValue == Int.MinValue`
62+
if (res == Int.MinValue) 1 else -res
63+
}
64+
5965
override def search[B >: A](elem: B)(implicit ord: Ordering[B]): SearchResult =
6066
binarySearch(elem, 0, length)(ord)
6167

src/library/scala/collection/Iterable.scala

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,88 @@ trait IterableOps[+A, +CC[_], +C] extends Any with IterableOnce[A] with Iterable
228228
/** A view over the elements of this collection. */
229229
def view: View[A] = View.fromIteratorProvider(() => iterator)
230230

231+
/** Compares the size of this $coll to a test value.
232+
*
233+
* @param otherSize the test value that gets compared with the size.
234+
* @return A value `x` where
235+
* {{{
236+
* x < 0 if this.size < otherSize
237+
* x == 0 if this.size == otherSize
238+
* x > 0 if this.size > otherSize
239+
* }}}
240+
* The method as implemented here does not call `size` directly; its running time
241+
* is `O(size min _size)` instead of `O(size)`. The method should be overwritten
242+
* if computing `size` is cheap.
243+
*/
244+
def sizeCompare(otherSize: Int): Int = {
245+
if (otherSize < 0) 1
246+
else {
247+
val known = knownSize
248+
if (known >= 0) Integer.compare(known, otherSize)
249+
else {
250+
var i = 0
251+
val it = iterator
252+
while (it.hasNext) {
253+
if (i == otherSize) return if (it.hasNext) 1 else 0
254+
it.next()
255+
i += 1
256+
}
257+
i - otherSize
258+
}
259+
}
260+
}
261+
262+
/** Returns a value class containing operations for comparing the size of this $coll to a test value.
263+
*
264+
* These operations are implemented in terms of [[sizeCompare(Int) `sizeCompare(Int)`]], and
265+
* allow the following more readable usages:
266+
*
267+
* {{{
268+
* this.sizeIs < size // this.sizeCompare(size) < 0
269+
* this.sizeIs <= size // this.sizeCompare(size) <= 0
270+
* this.sizeIs == size // this.sizeCompare(size) == 0
271+
* this.sizeIs != size // this.sizeCompare(size) != 0
272+
* this.sizeIs >= size // this.sizeCompare(size) >= 0
273+
* this.sizeIs > size // this.sizeCompare(size) > 0
274+
* }}}
275+
*/
276+
@inline final def sizeIs: IterableOps.SizeCompareOps = new IterableOps.SizeCompareOps(this)
277+
278+
/** Compares the size of this $coll to the size of another `Iterable`.
279+
*
280+
* @param that the `Iterable` whose size is compared with this $coll's size.
281+
* {{{
282+
* x < 0 if this.size < that.size
283+
* x == 0 if this.size == that.size
284+
* x > 0 if this.size > that.size
285+
* }}}
286+
* The method as implemented here does not call `size` directly; its running time
287+
* is `O(this.size min that.size)` instead of `O(this.size + that.size)`.
288+
* The method should be overwritten if computing `size` is cheap.
289+
*/
290+
def sizeCompare(that: Iterable[_]): Int = {
291+
val thatKnownSize = that.knownSize
292+
293+
if (thatKnownSize >= 0) this sizeCompare thatKnownSize
294+
else {
295+
val thisKnownSize = this.knownSize
296+
297+
if (thisKnownSize >= 0) {
298+
val res = that sizeCompare thisKnownSize
299+
// can't just invert the result, because `-Int.MinValue == Int.MinValue`
300+
if (res == Int.MinValue) 1 else -res
301+
} else {
302+
val thisIt = this.iterator
303+
val thatIt = that.iterator
304+
while (thisIt.hasNext && thatIt.hasNext) {
305+
thisIt.next()
306+
thatIt.next()
307+
}
308+
java.lang.Boolean.compare(thisIt.hasNext, thatIt.hasNext)
309+
}
310+
}
311+
}
312+
231313
/** A view over a slice of the elements of this collection. */
232314
@deprecated("Use .view.slice(from, until) instead of .view(from, until)", "2.13.0")
233315
@`inline` final def view(from: Int, until: Int): View[A] = view.slice(from, until)
@@ -690,6 +772,26 @@ trait IterableOps[+A, +CC[_], +C] extends Any with IterableOnce[A] with Iterable
690772

691773
object IterableOps {
692774

775+
/** Operations for comparing the size of a collection to a test value.
776+
*
777+
* These operations are implemented in terms of
778+
* [[scala.collection.IterableOps.sizeCompare(Int) `sizeCompare(Int)`]].
779+
*/
780+
final class SizeCompareOps private[collection](val it: IterableOps[_, AnyConstr, _]) extends AnyVal {
781+
/** Tests if the size of the collection is less than some value. */
782+
@inline def <(size: Int): Boolean = it.sizeCompare(size) < 0
783+
/** Tests if the size of the collection is less than or equal to some value. */
784+
@inline def <=(size: Int): Boolean = it.sizeCompare(size) <= 0
785+
/** Tests if the size of the collection is equal to some value. */
786+
@inline def ==(size: Int): Boolean = it.sizeCompare(size) == 0
787+
/** Tests if the size of the collection is not equal to some value. */
788+
@inline def !=(size: Int): Boolean = it.sizeCompare(size) != 0
789+
/** Tests if the size of the collection is greater than or equal to some value. */
790+
@inline def >=(size: Int): Boolean = it.sizeCompare(size) >= 0
791+
/** Tests if the size of the collection is greater than some value. */
792+
@inline def >(size: Int): Boolean = it.sizeCompare(size) > 0
793+
}
794+
693795
/** A trait that contains just the `map`, `flatMap`, `foreach` and `withFilter` methods
694796
* of trait `Iterable`.
695797
*

src/library/scala/collection/Seq.scala

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,8 @@ trait SeqOps[+A, +CC[_], +C] extends Any
661661
*/
662662
def indices: Range = Range(0, length)
663663

664+
override final def sizeCompare(_size: Int): Int = lengthCompare(_size)
665+
664666
/** Compares the length of this $coll to a test value.
665667
*
666668
* @param len the test value that gets compared with the length.
@@ -674,23 +676,7 @@ trait SeqOps[+A, +CC[_], +C] extends Any
674676
* is `O(length min len)` instead of `O(length)`. The method should be overwritten
675677
* if computing `length` is cheap.
676678
*/
677-
def lengthCompare(len: Int): Int = {
678-
if (len < 0) 1
679-
else {
680-
val known = knownSize
681-
if (known >= 0) Integer.compare(known, len)
682-
else {
683-
var i = 0
684-
val it = iterator
685-
while (it.hasNext) {
686-
if (i == len) return if (it.hasNext) 1 else 0
687-
it.next()
688-
i += 1
689-
}
690-
i - len
691-
}
692-
}
693-
}
679+
def lengthCompare(len: Int): Int = super.sizeCompare(len)
694680

695681
/** Returns a value class containing operations for comparing the length of this $coll to a test value.
696682
*
@@ -706,7 +692,7 @@ trait SeqOps[+A, +CC[_], +C] extends Any
706692
* this.lengthIs > len // this.lengthCompare(len) > 0
707693
* }}}
708694
*/
709-
@inline final def lengthIs: SeqOps.LengthCompareOps = new SeqOps.LengthCompareOps(this)
695+
@inline final def lengthIs: IterableOps.SizeCompareOps = new IterableOps.SizeCompareOps(this)
710696

711697
override def isEmpty: Boolean = lengthCompare(0) == 0
712698

test/benchmarks/src/main/scala/scala/collection/LengthCompareOpsBenchmark.scala renamed to test/benchmarks/src/main/scala/scala/collection/SizeCompareOpsBenchmark.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ import scala.util.Random
1313
@Measurement(iterations = 10)
1414
@OutputTimeUnit(TimeUnit.NANOSECONDS)
1515
@State(Scope.Benchmark)
16-
class LengthCompareOpsBenchmark {
16+
class SizeCompareOpsBenchmark {
1717
@Param(Array("0", "1", "10", "100", "1000"))
1818
var size: Int = _
1919

2020
@Param(Array("1", "100", "10000"))
21-
var len: Int = _
21+
var cmpTo: Int = _
2222

2323
var values: List[Int] = _
2424

@@ -27,11 +27,11 @@ class LengthCompareOpsBenchmark {
2727
values = List.fill(size)(Random.nextInt())
2828
}
2929

30-
@Benchmark def lengthCompareUgly: Any = {
31-
values.lengthCompare(len) == 0
30+
@Benchmark def sizeCompareUgly: Any = {
31+
values.sizeCompare(cmpTo) == 0
3232
}
3333

34-
@Benchmark def lengthComparePretty: Any = {
35-
values.lengthIs == len
34+
@Benchmark def sizeComparePretty: Any = {
35+
values.sizeIs == cmpTo
3636
}
3737
}

test/junit/scala/collection/IterableTest.scala

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ package scala.collection
33
import org.junit.{Assert, Test}
44
import org.junit.runner.RunWith
55
import org.junit.runners.JUnit4
6-
import scala.collection.immutable.{ArraySeq, List, Range}
6+
7+
import scala.collection.immutable.{ArraySeq, List, Range, Vector}
8+
import scala.language.higherKinds
79
import scala.tools.testing.AssertUtil._
810

911
@RunWith(classOf[JUnit4])
@@ -58,6 +60,34 @@ class IterableTest {
5860
Assert.assertEquals(expected, occurrences(xs))
5961
}
6062

63+
@Test
64+
def sizeCompareInt(): Unit = {
65+
val seq = Seq(1, 2, 3)
66+
assert(seq.sizeCompare(2) > 0)
67+
assert(seq.sizeCompare(3) == 0)
68+
assert(seq.sizeCompare(4) < 0)
69+
}
70+
71+
@Test
72+
def sizeCompareIterable(): Unit = {
73+
def check[I1[X] <: Iterable[X], I2[X] <: Iterable[X]]
74+
(f1: IterableFactory[I1], f2: IterableFactory[I2]): Unit = {
75+
val it = f1(1, 2, 3)
76+
assert(it.sizeCompare(f2(1, 2)) > 0)
77+
assert(it.sizeCompare(f2(1, 2, 3)) == 0)
78+
assert(it.sizeCompare(f2(1, 2, 3, 4)) < 0)
79+
}
80+
81+
// factories for `Seq`s with known and unknown size
82+
val known: IterableFactory[IndexedSeq] = Vector
83+
val unknown: IterableFactory[LinearSeq] = List
84+
85+
check(known, known)
86+
check(known, unknown)
87+
check(unknown, known)
88+
check(unknown, unknown)
89+
}
90+
6191
@Test def copyToArray(): Unit = {
6292
def check(a: Array[Int], start: Int, end: Int) = {
6393
var i = 0

0 commit comments

Comments
 (0)