Skip to content

Commit af988bf

Browse files
authored
[red-knot] Detect division-by-zero in unions and intersections (#17157)
## Summary With this PR, we emit a diagnostic for this case where previously didn't: ```py from typing import Literal def f(m: int, n: Literal[-1, 0, 1]): # error: [division-by-zero] "Cannot divide object of type `int` by zero" return m / n ``` ## Test Plan New Markdown test
1 parent f989c2c commit af988bf

File tree

2 files changed

+61
-24
lines changed

2 files changed

+61
-24
lines changed

crates/red_knot_python_semantic/resources/mdtest/binary/unions.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,11 @@ def f4(x: float, y: float):
4949
reveal_type(x // y) # revealed: int | float
5050
reveal_type(x % y) # revealed: int | float
5151
```
52+
53+
If any of the union elements leads to a division by zero, we will report an error:
54+
55+
```py
56+
def f5(m: int, n: Literal[-1, 0, 1]):
57+
# error: [division-by-zero] "Cannot divide object of type `int` by zero"
58+
return m / n
59+
```

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -971,32 +971,39 @@ impl<'db> TypeInferenceBuilder<'db> {
971971
/// Raise a diagnostic if the given type cannot be divided by zero.
972972
///
973973
/// Expects the resolved type of the left side of the binary expression.
974-
fn check_division_by_zero(&mut self, expr: &ast::ExprBinOp, left: Type<'db>) {
974+
fn check_division_by_zero(
975+
&mut self,
976+
node: AnyNodeRef<'_>,
977+
op: ast::Operator,
978+
left: Type<'db>,
979+
) -> bool {
975980
match left {
976981
Type::BooleanLiteral(_) | Type::IntLiteral(_) => {}
977982
Type::Instance(instance)
978983
if matches!(
979984
instance.class().known(self.db()),
980985
Some(KnownClass::Float | KnownClass::Int | KnownClass::Bool)
981986
) => {}
982-
_ => return,
987+
_ => return false,
983988
};
984989

985-
let (op, by_zero) = match expr.op {
990+
let (op, by_zero) = match op {
986991
ast::Operator::Div => ("divide", "by zero"),
987992
ast::Operator::FloorDiv => ("floor divide", "by zero"),
988993
ast::Operator::Mod => ("reduce", "modulo zero"),
989-
_ => return,
994+
_ => return false,
990995
};
991996

992997
self.context.report_lint(
993998
&DIVISION_BY_ZERO,
994-
expr,
999+
node,
9951000
format_args!(
9961001
"Cannot {op} object of type `{}` {by_zero}",
9971002
left.display(self.db())
9981003
),
9991004
);
1005+
1006+
true
10001007
}
10011008

10021009
fn add_binding(&mut self, node: AnyNodeRef, binding: Definition<'db>, ty: Type<'db>) {
@@ -2858,7 +2865,7 @@ impl<'db> TypeInferenceBuilder<'db> {
28582865

28592866
// Fall back to non-augmented binary operator inference.
28602867
let mut binary_return_ty = || {
2861-
self.infer_binary_expression_type(target_type, value_type, op)
2868+
self.infer_binary_expression_type(assignment.into(), false, target_type, value_type, op)
28622869
.unwrap_or_else(|| {
28632870
report_unsupported_augmented_op(&mut self.context);
28642871
Type::unknown()
@@ -4495,19 +4502,7 @@ impl<'db> TypeInferenceBuilder<'db> {
44954502
let left_ty = self.infer_expression(left);
44964503
let right_ty = self.infer_expression(right);
44974504

4498-
// Check for division by zero; this doesn't change the inferred type for the expression, but
4499-
// may emit a diagnostic
4500-
if matches!(
4501-
(op, right_ty),
4502-
(
4503-
ast::Operator::Div | ast::Operator::FloorDiv | ast::Operator::Mod,
4504-
Type::IntLiteral(0) | Type::BooleanLiteral(false)
4505-
)
4506-
) {
4507-
self.check_division_by_zero(binary, left_ty);
4508-
}
4509-
4510-
self.infer_binary_expression_type(left_ty, right_ty, *op)
4505+
self.infer_binary_expression_type(binary.into(), false, left_ty, right_ty, *op)
45114506
.unwrap_or_else(|| {
45124507
self.context.report_lint(
45134508
&UNSUPPORTED_OPERATOR,
@@ -4524,23 +4519,51 @@ impl<'db> TypeInferenceBuilder<'db> {
45244519

45254520
fn infer_binary_expression_type(
45264521
&mut self,
4522+
node: AnyNodeRef<'_>,
4523+
mut emitted_division_by_zero_diagnostic: bool,
45274524
left_ty: Type<'db>,
45284525
right_ty: Type<'db>,
45294526
op: ast::Operator,
45304527
) -> Option<Type<'db>> {
4528+
// Check for division by zero; this doesn't change the inferred type for the expression, but
4529+
// may emit a diagnostic
4530+
if !emitted_division_by_zero_diagnostic
4531+
&& matches!(
4532+
(op, right_ty),
4533+
(
4534+
ast::Operator::Div | ast::Operator::FloorDiv | ast::Operator::Mod,
4535+
Type::IntLiteral(0) | Type::BooleanLiteral(false)
4536+
)
4537+
)
4538+
{
4539+
emitted_division_by_zero_diagnostic = self.check_division_by_zero(node, op, left_ty);
4540+
}
4541+
45314542
match (left_ty, right_ty, op) {
45324543
(Type::Union(lhs_union), rhs, _) => {
45334544
let mut union = UnionBuilder::new(self.db());
45344545
for lhs in lhs_union.elements(self.db()) {
4535-
let result = self.infer_binary_expression_type(*lhs, rhs, op)?;
4546+
let result = self.infer_binary_expression_type(
4547+
node,
4548+
emitted_division_by_zero_diagnostic,
4549+
*lhs,
4550+
rhs,
4551+
op,
4552+
)?;
45364553
union = union.add(result);
45374554
}
45384555
Some(union.build())
45394556
}
45404557
(lhs, Type::Union(rhs_union), _) => {
45414558
let mut union = UnionBuilder::new(self.db());
45424559
for rhs in rhs_union.elements(self.db()) {
4543-
let result = self.infer_binary_expression_type(lhs, *rhs, op)?;
4560+
let result = self.infer_binary_expression_type(
4561+
node,
4562+
emitted_division_by_zero_diagnostic,
4563+
lhs,
4564+
*rhs,
4565+
op,
4566+
)?;
45444567
union = union.add(result);
45454568
}
45464569
Some(union.build())
@@ -4659,13 +4682,19 @@ impl<'db> TypeInferenceBuilder<'db> {
46594682
}
46604683

46614684
(Type::BooleanLiteral(bool_value), right, op) => self.infer_binary_expression_type(
4685+
node,
4686+
emitted_division_by_zero_diagnostic,
46624687
Type::IntLiteral(i64::from(bool_value)),
46634688
right,
46644689
op,
46654690
),
4666-
(left, Type::BooleanLiteral(bool_value), op) => {
4667-
self.infer_binary_expression_type(left, Type::IntLiteral(i64::from(bool_value)), op)
4668-
}
4691+
(left, Type::BooleanLiteral(bool_value), op) => self.infer_binary_expression_type(
4692+
node,
4693+
emitted_division_by_zero_diagnostic,
4694+
left,
4695+
Type::IntLiteral(i64::from(bool_value)),
4696+
op,
4697+
),
46694698

46704699
(Type::Tuple(lhs), Type::Tuple(rhs), ast::Operator::Add) => {
46714700
// Note: this only works on heterogeneous tuples.

0 commit comments

Comments
 (0)