Skip to content

Commit 2eac041

Browse files
committed
Check ensures on early return due to Try / Yeet
Expand these two expressions to include a call to contract checking
1 parent 5b294f4 commit 2eac041

File tree

3 files changed

+84
-31
lines changed

3 files changed

+84
-31
lines changed

Diff for: compiler/rustc_ast_lowering/src/expr.rs

+31-17
Original file line numberDiff line numberDiff line change
@@ -314,21 +314,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
314314
hir::ExprKind::Continue(self.lower_jump_destination(e.id, *opt_label))
315315
}
316316
ExprKind::Ret(e) => {
317-
let mut e = e.as_ref().map(|x| self.lower_expr(x));
318-
if let Some(Some((span, fresh_ident))) = self
319-
.contract
320-
.as_ref()
321-
.map(|c| c.ensures.as_ref().map(|e| (e.expr.span, e.fresh_ident)))
322-
{
323-
let checker_fn = self.expr_ident(span, fresh_ident.0, fresh_ident.2);
324-
let args = if let Some(e) = e {
325-
std::slice::from_ref(e)
326-
} else {
327-
std::slice::from_ref(self.expr_unit(span))
328-
};
329-
e = Some(self.expr_call(span, checker_fn, args));
330-
}
331-
hir::ExprKind::Ret(e)
317+
let expr = e.as_ref().map(|x| self.lower_expr(x));
318+
self.checked_return(expr)
332319
}
333320
ExprKind::Yeet(sub_expr) => self.lower_expr_yeet(e.span, sub_expr.as_deref()),
334321
ExprKind::Become(sub_expr) => {
@@ -395,6 +382,32 @@ impl<'hir> LoweringContext<'_, 'hir> {
395382
})
396383
}
397384

385+
/// Create an `ExprKind::Ret` that is preceded by a call to check contract ensures clause.
386+
fn checked_return(&mut self, opt_expr: Option<&'hir hir::Expr<'hir>>) -> hir::ExprKind<'hir> {
387+
let checked_ret = if let Some(Some((span, fresh_ident))) =
388+
self.contract.as_ref().map(|c| c.ensures.as_ref().map(|e| (e.expr.span, e.fresh_ident)))
389+
{
390+
let expr = opt_expr.unwrap_or_else(|| self.expr_unit(span));
391+
Some(self.inject_ensures_check(expr, span, fresh_ident.0, fresh_ident.2))
392+
} else {
393+
opt_expr
394+
};
395+
hir::ExprKind::Ret(checked_ret)
396+
}
397+
398+
/// Wraps an expression with a call to the ensures check before it gets returned.
399+
pub(crate) fn inject_ensures_check(
400+
&mut self,
401+
expr: &'hir hir::Expr<'hir>,
402+
span: Span,
403+
check_ident: Ident,
404+
check_hir_id: HirId,
405+
) -> &'hir hir::Expr<'hir> {
406+
let checker_fn = self.expr_ident(span, check_ident, check_hir_id);
407+
let span = self.mark_span_with_reason(DesugaringKind::Contract, span, None);
408+
self.expr_call(span, checker_fn, std::slice::from_ref(expr))
409+
}
410+
398411
pub(crate) fn lower_const_block(&mut self, c: &AnonConst) -> hir::ConstBlock {
399412
self.with_new_scopes(c.value.span, |this| {
400413
let def_id = this.local_def_id(c.id);
@@ -1983,7 +1996,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
19831996
),
19841997
))
19851998
} else {
1986-
self.arena.alloc(self.expr(try_span, hir::ExprKind::Ret(Some(from_residual_expr))))
1999+
let ret_expr = self.checked_return(Some(from_residual_expr));
2000+
self.arena.alloc(self.expr(try_span, ret_expr))
19872001
};
19882002
self.lower_attrs(ret_expr.hir_id, &attrs);
19892003

@@ -2032,7 +2046,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
20322046
let target_id = Ok(catch_id);
20332047
hir::ExprKind::Break(hir::Destination { label: None, target_id }, Some(from_yeet_expr))
20342048
} else {
2035-
hir::ExprKind::Ret(Some(from_yeet_expr))
2049+
self.checked_return(Some(from_yeet_expr))
20362050
}
20372051
}
20382052

Diff for: compiler/rustc_ast_lowering/src/item.rs

+5-14
Original file line numberDiff line numberDiff line change
@@ -1093,9 +1093,9 @@ impl<'hir> LoweringContext<'_, 'hir> {
10931093

10941094
// { body }
10951095
// ==>
1096-
// { rustc_contract_requires(PRECOND); { body } }
1096+
// { contract_requires(PRECOND); { body } }
10971097
let Some(contract) = opt_contract else { return (params, result) };
1098-
1098+
let result_ref = this.arena.alloc(result);
10991099
let lit_unit = |this: &mut LoweringContext<'_, 'hir>| {
11001100
this.expr(contract.span, hir::ExprKind::Tup(&[]))
11011101
};
@@ -1130,26 +1130,17 @@ impl<'hir> LoweringContext<'_, 'hir> {
11301130
this.arena.alloc(checker_binding_pat),
11311131
hir::LocalSource::Contract,
11321132
),
1133-
{
1134-
let checker_fn = this.expr_ident(ens.span, fresh_ident.0, fresh_ident.2);
1135-
let span =
1136-
this.mark_span_with_reason(DesugaringKind::Contract, ens.span, None);
1137-
this.expr_call_mut(
1138-
span,
1139-
checker_fn,
1140-
std::slice::from_ref(this.arena.alloc(result)),
1141-
)
1142-
},
1133+
this.inject_ensures_check(result_ref, ens.span, fresh_ident.0, fresh_ident.2),
11431134
)
11441135
} else {
11451136
let u = lit_unit(this);
1146-
(this.stmt_expr(contract.span, u), result)
1137+
(this.stmt_expr(contract.span, u), &*result_ref)
11471138
};
11481139

11491140
let block = this.block_all(
11501141
contract.span,
11511142
arena_vec![this; precond, postcond_checker],
1152-
Some(this.arena.alloc(result)),
1143+
Some(result),
11531144
);
11541145
(params, this.expr_block(block))
11551146
})
+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
//@ revisions: unchk_pass chk_pass chk_fail_try chk_fail_ret chk_fail_yeet
2+
//
3+
//@ [unchk_pass] run-pass
4+
//@ [chk_pass] run-pass
5+
//@ [chk_fail_try] run-fail
6+
//@ [chk_fail_ret] run-fail
7+
//@ [chk_fail_yeet] run-fail
8+
//
9+
//@ [unchk_pass] compile-flags: -Zcontract-checks=no
10+
//@ [chk_pass] compile-flags: -Zcontract-checks=yes
11+
//@ [chk_fail_try] compile-flags: -Zcontract-checks=yes
12+
//@ [chk_fail_ret] compile-flags: -Zcontract-checks=yes
13+
//@ [chk_fail_yeet] compile-flags: -Zcontract-checks=yes
14+
//! This test ensures that ensures clauses are checked for different return points of a function.
15+
16+
#![feature(rustc_contracts)]
17+
#![feature(yeet_expr)]
18+
19+
/// This ensures will fail in different return points depending on the input.
20+
#[core::contracts::ensures(|ret: &Option<u32>| ret.is_some())]
21+
fn try_sum(x: u32, y: u32, z: u32) -> Option<u32> {
22+
// Use Yeet to return early.
23+
if x == u32::MAX && (y > 0 || z > 0) { do yeet }
24+
25+
// Use `?` to early return.
26+
let partial = x.checked_add(y)?;
27+
28+
// Explicitly use `return` clause.
29+
if u32::MAX - partial < z {
30+
return None;
31+
}
32+
33+
Some(partial + z)
34+
}
35+
36+
fn main() {
37+
// This should always succeed
38+
assert_eq!(try_sum(0, 1, 2), Some(3));
39+
40+
#[cfg(any(unchk_pass, chk_fail_yeet))]
41+
assert_eq!(try_sum(u32::MAX, 1, 1), None);
42+
43+
#[cfg(any(unchk_pass, chk_fail_try))]
44+
assert_eq!(try_sum(u32::MAX - 10, 12, 0), None);
45+
46+
#[cfg(any(unchk_pass, chk_fail_ret))]
47+
assert_eq!(try_sum(u32::MAX - 10, 2, 100), None);
48+
}

0 commit comments

Comments
 (0)