Skip to content

Commit 6cad754

Browse files
committed
Auto merge of #68358 - matthewjasper:spec-fix, r=nikomatsakis
Remove some unsound specializations This removes the unsound and exploitable specializations in the standard library * The `PartialEq` and `Hash` implementations for `RangeInclusive` are changed to avoid specialization. * The `PartialOrd` specialization for slices now specializes on a limited set of concrete types. * Added some tests for the soundness problems.
2 parents 8498c5f + a81c59f commit 6cad754

File tree

7 files changed

+168
-57
lines changed

7 files changed

+168
-57
lines changed

Diff for: src/libcore/iter/range.rs

+4
Original file line numberDiff line numberDiff line change
@@ -385,12 +385,14 @@ impl<A: Step> Iterator for ops::RangeInclusive<A> {
385385
}
386386
Some(Equal) => {
387387
self.is_empty = Some(true);
388+
self.start = plus_n.clone();
388389
return Some(plus_n);
389390
}
390391
_ => {}
391392
}
392393
}
393394

395+
self.start = self.end.clone();
394396
self.is_empty = Some(true);
395397
None
396398
}
@@ -477,12 +479,14 @@ impl<A: Step> DoubleEndedIterator for ops::RangeInclusive<A> {
477479
}
478480
Some(Equal) => {
479481
self.is_empty = Some(true);
482+
self.end = minus_n.clone();
480483
return Some(minus_n);
481484
}
482485
_ => {}
483486
}
484487
}
485488

489+
self.end = self.start.clone();
486490
self.is_empty = Some(true);
487491
None
488492
}

Diff for: src/libcore/ops/range.rs

+15-23
Original file line numberDiff line numberDiff line change
@@ -343,38 +343,21 @@ pub struct RangeInclusive<Idx> {
343343
pub(crate) is_empty: Option<bool>,
344344
// This field is:
345345
// - `None` when next() or next_back() was never called
346-
// - `Some(false)` when `start <= end` assuming no overflow
347-
// - `Some(true)` otherwise
346+
// - `Some(false)` when `start < end`
347+
// - `Some(true)` when `end < start`
348+
// - `Some(false)` when `start == end` and the range hasn't yet completed iteration
349+
// - `Some(true)` when `start == end` and the range has completed iteration
348350
// The field cannot be a simple `bool` because the `..=` constructor can
349351
// accept non-PartialOrd types, also we want the constructor to be const.
350352
}
351353

352-
trait RangeInclusiveEquality: Sized {
353-
fn canonicalized_is_empty(range: &RangeInclusive<Self>) -> bool;
354-
}
355-
356-
impl<T> RangeInclusiveEquality for T {
357-
#[inline]
358-
default fn canonicalized_is_empty(range: &RangeInclusive<Self>) -> bool {
359-
range.is_empty.unwrap_or_default()
360-
}
361-
}
362-
363-
impl<T: PartialOrd> RangeInclusiveEquality for T {
364-
#[inline]
365-
fn canonicalized_is_empty(range: &RangeInclusive<Self>) -> bool {
366-
range.is_empty()
367-
}
368-
}
369-
370354
#[stable(feature = "inclusive_range", since = "1.26.0")]
371355
impl<Idx: PartialEq> PartialEq for RangeInclusive<Idx> {
372356
#[inline]
373357
fn eq(&self, other: &Self) -> bool {
374358
self.start == other.start
375359
&& self.end == other.end
376-
&& RangeInclusiveEquality::canonicalized_is_empty(self)
377-
== RangeInclusiveEquality::canonicalized_is_empty(other)
360+
&& self.is_exhausted() == other.is_exhausted()
378361
}
379362
}
380363

@@ -386,7 +369,8 @@ impl<Idx: Hash> Hash for RangeInclusive<Idx> {
386369
fn hash<H: Hasher>(&self, state: &mut H) {
387370
self.start.hash(state);
388371
self.end.hash(state);
389-
RangeInclusiveEquality::canonicalized_is_empty(self).hash(state);
372+
// Ideally we would hash `is_exhausted` here as well, but there's no
373+
// way for us to call it.
390374
}
391375
}
392376

@@ -485,6 +469,14 @@ impl<Idx: fmt::Debug> fmt::Debug for RangeInclusive<Idx> {
485469
}
486470
}
487471

472+
impl<Idx: PartialEq<Idx>> RangeInclusive<Idx> {
473+
// Returns true if this is a range that started non-empty, and was iterated
474+
// to exhaustion.
475+
fn is_exhausted(&self) -> bool {
476+
Some(true) == self.is_empty && self.start == self.end
477+
}
478+
}
479+
488480
impl<Idx: PartialOrd<Idx>> RangeInclusive<Idx> {
489481
/// Returns `true` if `item` is contained in the range.
490482
///

Diff for: src/libcore/slice/mod.rs

+51-29
Original file line numberDiff line numberDiff line change
@@ -5584,21 +5584,18 @@ where
55845584

55855585
#[doc(hidden)]
55865586
// intermediate trait for specialization of slice's PartialOrd
5587-
trait SlicePartialOrd<B> {
5588-
fn partial_compare(&self, other: &[B]) -> Option<Ordering>;
5587+
trait SlicePartialOrd: Sized {
5588+
fn partial_compare(left: &[Self], right: &[Self]) -> Option<Ordering>;
55895589
}
55905590

5591-
impl<A> SlicePartialOrd<A> for [A]
5592-
where
5593-
A: PartialOrd,
5594-
{
5595-
default fn partial_compare(&self, other: &[A]) -> Option<Ordering> {
5596-
let l = cmp::min(self.len(), other.len());
5591+
impl<A: PartialOrd> SlicePartialOrd for A {
5592+
default fn partial_compare(left: &[A], right: &[A]) -> Option<Ordering> {
5593+
let l = cmp::min(left.len(), right.len());
55975594

55985595
// Slice to the loop iteration range to enable bound check
55995596
// elimination in the compiler
5600-
let lhs = &self[..l];
5601-
let rhs = &other[..l];
5597+
let lhs = &left[..l];
5598+
let rhs = &right[..l];
56025599

56035600
for i in 0..l {
56045601
match lhs[i].partial_cmp(&rhs[i]) {
@@ -5607,36 +5604,61 @@ where
56075604
}
56085605
}
56095606

5610-
self.len().partial_cmp(&other.len())
5607+
left.len().partial_cmp(&right.len())
56115608
}
56125609
}
56135610

5614-
impl<A> SlicePartialOrd<A> for [A]
5611+
// This is the impl that we would like to have. Unfortunately it's not sound.
5612+
// See `partial_ord_slice.rs`.
5613+
/*
5614+
impl<A> SlicePartialOrd for A
56155615
where
56165616
A: Ord,
56175617
{
5618-
default fn partial_compare(&self, other: &[A]) -> Option<Ordering> {
5619-
Some(SliceOrd::compare(self, other))
5618+
default fn partial_compare(left: &[A], right: &[A]) -> Option<Ordering> {
5619+
Some(SliceOrd::compare(left, right))
5620+
}
5621+
}
5622+
*/
5623+
5624+
impl<A: AlwaysApplicableOrd> SlicePartialOrd for A {
5625+
fn partial_compare(left: &[A], right: &[A]) -> Option<Ordering> {
5626+
Some(SliceOrd::compare(left, right))
5627+
}
5628+
}
5629+
5630+
trait AlwaysApplicableOrd: SliceOrd + Ord {}
5631+
5632+
macro_rules! always_applicable_ord {
5633+
($([$($p:tt)*] $t:ty,)*) => {
5634+
$(impl<$($p)*> AlwaysApplicableOrd for $t {})*
56205635
}
56215636
}
56225637

5638+
always_applicable_ord! {
5639+
[] u8, [] u16, [] u32, [] u64, [] u128, [] usize,
5640+
[] i8, [] i16, [] i32, [] i64, [] i128, [] isize,
5641+
[] bool, [] char,
5642+
[T: ?Sized] *const T, [T: ?Sized] *mut T,
5643+
[T: AlwaysApplicableOrd] &T,
5644+
[T: AlwaysApplicableOrd] &mut T,
5645+
[T: AlwaysApplicableOrd] Option<T>,
5646+
}
5647+
56235648
#[doc(hidden)]
56245649
// intermediate trait for specialization of slice's Ord
5625-
trait SliceOrd<B> {
5626-
fn compare(&self, other: &[B]) -> Ordering;
5650+
trait SliceOrd: Sized {
5651+
fn compare(left: &[Self], right: &[Self]) -> Ordering;
56275652
}
56285653

5629-
impl<A> SliceOrd<A> for [A]
5630-
where
5631-
A: Ord,
5632-
{
5633-
default fn compare(&self, other: &[A]) -> Ordering {
5634-
let l = cmp::min(self.len(), other.len());
5654+
impl<A: Ord> SliceOrd for A {
5655+
default fn compare(left: &[Self], right: &[Self]) -> Ordering {
5656+
let l = cmp::min(left.len(), right.len());
56355657

56365658
// Slice to the loop iteration range to enable bound check
56375659
// elimination in the compiler
5638-
let lhs = &self[..l];
5639-
let rhs = &other[..l];
5660+
let lhs = &left[..l];
5661+
let rhs = &right[..l];
56405662

56415663
for i in 0..l {
56425664
match lhs[i].cmp(&rhs[i]) {
@@ -5645,19 +5667,19 @@ where
56455667
}
56465668
}
56475669

5648-
self.len().cmp(&other.len())
5670+
left.len().cmp(&right.len())
56495671
}
56505672
}
56515673

56525674
// memcmp compares a sequence of unsigned bytes lexicographically.
56535675
// this matches the order we want for [u8], but no others (not even [i8]).
5654-
impl SliceOrd<u8> for [u8] {
5676+
impl SliceOrd for u8 {
56555677
#[inline]
5656-
fn compare(&self, other: &[u8]) -> Ordering {
5678+
fn compare(left: &[Self], right: &[Self]) -> Ordering {
56575679
let order =
5658-
unsafe { memcmp(self.as_ptr(), other.as_ptr(), cmp::min(self.len(), other.len())) };
5680+
unsafe { memcmp(left.as_ptr(), right.as_ptr(), cmp::min(left.len(), right.len())) };
56595681
if order == 0 {
5660-
self.len().cmp(&other.len())
5682+
left.len().cmp(&right.len())
56615683
} else if order < 0 {
56625684
Less
56635685
} else {

Diff for: src/libcore/str/mod.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use self::pattern::{DoubleEndedSearcher, ReverseSearcher, SearchStep, Searcher};
1212
use crate::char;
1313
use crate::fmt::{self, Write};
1414
use crate::iter::{Chain, FlatMap, Flatten};
15-
use crate::iter::{Cloned, Filter, FusedIterator, Map, TrustedLen, TrustedRandomAccess};
15+
use crate::iter::{Copied, Filter, FusedIterator, Map, TrustedLen, TrustedRandomAccess};
1616
use crate::mem;
1717
use crate::ops::Try;
1818
use crate::option;
@@ -750,7 +750,7 @@ impl<'a> CharIndices<'a> {
750750
/// [`str`]: ../../std/primitive.str.html
751751
#[stable(feature = "rust1", since = "1.0.0")]
752752
#[derive(Clone, Debug)]
753-
pub struct Bytes<'a>(Cloned<slice::Iter<'a, u8>>);
753+
pub struct Bytes<'a>(Copied<slice::Iter<'a, u8>>);
754754

755755
#[stable(feature = "rust1", since = "1.0.0")]
756756
impl Iterator for Bytes<'_> {
@@ -2778,7 +2778,7 @@ impl str {
27782778
#[stable(feature = "rust1", since = "1.0.0")]
27792779
#[inline]
27802780
pub fn bytes(&self) -> Bytes<'_> {
2781-
Bytes(self.as_bytes().iter().cloned())
2781+
Bytes(self.as_bytes().iter().copied())
27822782
}
27832783

27842784
/// Splits a string slice by whitespace.
@@ -3895,7 +3895,7 @@ impl str {
38953895
debug_assert_eq!(
38963896
start, 0,
38973897
"The first search step from Searcher \
3898-
must include the first character"
3898+
must include the first character"
38993899
);
39003900
// SAFETY: `Searcher` is known to return valid indices.
39013901
unsafe { Some(self.get_unchecked(len..)) }
@@ -3934,7 +3934,7 @@ impl str {
39343934
end,
39353935
self.len(),
39363936
"The first search step from ReverseSearcher \
3937-
must include the last character"
3937+
must include the last character"
39383938
);
39393939
// SAFETY: `Searcher` is known to return valid indices.
39403940
unsafe { Some(self.get_unchecked(..start)) }

Diff for: src/libcore/tests/iter.rs

+16
Original file line numberDiff line numberDiff line change
@@ -1956,11 +1956,19 @@ fn test_range_inclusive_exhaustion() {
19561956
assert_eq!(r.next(), None);
19571957
assert_eq!(r.next(), None);
19581958

1959+
assert_eq!(*r.start(), 10);
1960+
assert_eq!(*r.end(), 10);
1961+
assert_ne!(r, 10..=10);
1962+
19591963
let mut r = 10..=10;
19601964
assert_eq!(r.next_back(), Some(10));
19611965
assert!(r.is_empty());
19621966
assert_eq!(r.next_back(), None);
19631967

1968+
assert_eq!(*r.start(), 10);
1969+
assert_eq!(*r.end(), 10);
1970+
assert_ne!(r, 10..=10);
1971+
19641972
let mut r = 10..=12;
19651973
assert_eq!(r.next(), Some(10));
19661974
assert_eq!(r.next(), Some(11));
@@ -2078,6 +2086,9 @@ fn test_range_inclusive_nth() {
20782086
assert_eq!((10..=15).nth(5), Some(15));
20792087
assert_eq!((10..=15).nth(6), None);
20802088

2089+
let mut exhausted_via_next = 10_u8..=20;
2090+
while exhausted_via_next.next().is_some() {}
2091+
20812092
let mut r = 10_u8..=20;
20822093
assert_eq!(r.nth(2), Some(12));
20832094
assert_eq!(r, 13..=20);
@@ -2087,6 +2098,7 @@ fn test_range_inclusive_nth() {
20872098
assert_eq!(ExactSizeIterator::is_empty(&r), false);
20882099
assert_eq!(r.nth(10), None);
20892100
assert_eq!(r.is_empty(), true);
2101+
assert_eq!(r, exhausted_via_next);
20902102
assert_eq!(ExactSizeIterator::is_empty(&r), true);
20912103
}
20922104

@@ -2098,6 +2110,9 @@ fn test_range_inclusive_nth_back() {
20982110
assert_eq!((10..=15).nth_back(6), None);
20992111
assert_eq!((-120..=80_i8).nth_back(200), Some(-120));
21002112

2113+
let mut exhausted_via_next_back = 10_u8..=20;
2114+
while exhausted_via_next_back.next_back().is_some() {}
2115+
21012116
let mut r = 10_u8..=20;
21022117
assert_eq!(r.nth_back(2), Some(18));
21032118
assert_eq!(r, 10..=17);
@@ -2107,6 +2122,7 @@ fn test_range_inclusive_nth_back() {
21072122
assert_eq!(ExactSizeIterator::is_empty(&r), false);
21082123
assert_eq!(r.nth_back(10), None);
21092124
assert_eq!(r.is_empty(), true);
2125+
assert_eq!(r, exhausted_via_next_back);
21102126
assert_eq!(ExactSizeIterator::is_empty(&r), true);
21112127
}
21122128

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// run-pass
2+
3+
use std::cell::RefCell;
4+
use std::cmp::Ordering;
5+
6+
struct Evil<'a, 'b> {
7+
values: RefCell<Vec<&'a str>>,
8+
to_insert: &'b String,
9+
}
10+
11+
impl<'a, 'b> PartialEq for Evil<'a, 'b> {
12+
fn eq(&self, _other: &Self) -> bool {
13+
true
14+
}
15+
}
16+
17+
impl<'a> PartialOrd for Evil<'a, 'a> {
18+
fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
19+
self.values.borrow_mut().push(self.to_insert);
20+
None
21+
}
22+
}
23+
24+
fn main() {
25+
let e;
26+
let values;
27+
{
28+
let to_insert = String::from("Hello, world!");
29+
e = Evil { values: RefCell::new(Vec::new()), to_insert: &to_insert };
30+
let range = &e..=&e;
31+
let _ = range == range;
32+
values = e.values;
33+
}
34+
assert_eq!(*values.borrow(), Vec::<&str>::new());
35+
}

0 commit comments

Comments
 (0)