Skip to content

Commit be1d5e3

Browse files
authored
[red-knot] Add Type::bool and boolean expression inference (#13449)
1 parent 03503f7 commit be1d5e3

File tree

2 files changed

+252
-14
lines changed

2 files changed

+252
-14
lines changed

crates/red_knot_python_semantic/src/types.rs

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,54 @@ impl<'db> Type<'db> {
521521
}
522522
}
523523

524+
/// Resolves the boolean value of a type.
525+
///
526+
/// This is used to determine the value that would be returned
527+
/// when `bool(x)` is called on an object `x`.
528+
fn bool(&self, db: &'db dyn Db) -> Truthiness {
529+
match self {
530+
Type::Any | Type::Never | Type::Unknown | Type::Unbound => Truthiness::Ambiguous,
531+
Type::None => Truthiness::AlwaysFalse,
532+
Type::Function(_) | Type::RevealTypeFunction(_) => Truthiness::AlwaysTrue,
533+
Type::Module(_) => Truthiness::AlwaysTrue,
534+
Type::Class(_) => {
535+
// TODO: lookup `__bool__` and `__len__` methods on the class's metaclass
536+
// More info in https://docs.python.org/3/library/stdtypes.html#truth-value-testing
537+
Truthiness::Ambiguous
538+
}
539+
Type::Instance(_) => {
540+
// TODO: lookup `__bool__` and `__len__` methods on the instance's class
541+
// More info in https://docs.python.org/3/library/stdtypes.html#truth-value-testing
542+
Truthiness::Ambiguous
543+
}
544+
Type::Union(union) => {
545+
let union_elements = union.elements(db);
546+
let first_element_truthiness = union_elements[0].bool(db);
547+
if first_element_truthiness.is_ambiguous() {
548+
return Truthiness::Ambiguous;
549+
}
550+
if !union_elements
551+
.iter()
552+
.skip(1)
553+
.all(|element| element.bool(db) == first_element_truthiness)
554+
{
555+
return Truthiness::Ambiguous;
556+
}
557+
first_element_truthiness
558+
}
559+
Type::Intersection(_) => {
560+
// TODO
561+
Truthiness::Ambiguous
562+
}
563+
Type::IntLiteral(num) => Truthiness::from(*num != 0),
564+
Type::BooleanLiteral(bool) => Truthiness::from(*bool),
565+
Type::StringLiteral(str) => Truthiness::from(!str.value(db).is_empty()),
566+
Type::LiteralString => Truthiness::Ambiguous,
567+
Type::BytesLiteral(bytes) => Truthiness::from(!bytes.value(db).is_empty()),
568+
Type::Tuple(items) => Truthiness::from(!items.elements(db).is_empty()),
569+
}
570+
}
571+
524572
/// Return the type resulting from calling an object of this type.
525573
///
526574
/// Returns `None` if `self` is not a callable type.
@@ -873,6 +921,50 @@ impl<'db> IterationOutcome<'db> {
873921
}
874922
}
875923

924+
/// What is statically known about the result of `bool(x)` for an object `x` of some type.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Truthiness {
    /// `bool(x)` returns `True` for every object `x` of the type.
    AlwaysTrue,
    /// `bool(x)` returns `False` for every object `x` of the type.
    AlwaysFalse,
    /// `bool(x)` may return `True` or `False` depending on the object `x`.
    Ambiguous,
}
933+
934+
impl Truthiness {
935+
const fn is_ambiguous(self) -> bool {
936+
matches!(self, Truthiness::Ambiguous)
937+
}
938+
939+
#[allow(unused)]
940+
const fn negate(self) -> Self {
941+
match self {
942+
Self::AlwaysTrue => Self::AlwaysFalse,
943+
Self::AlwaysFalse => Self::AlwaysTrue,
944+
Self::Ambiguous => Self::Ambiguous,
945+
}
946+
}
947+
948+
#[allow(unused)]
949+
fn into_type(self, db: &dyn Db) -> Type {
950+
match self {
951+
Self::AlwaysTrue => Type::BooleanLiteral(true),
952+
Self::AlwaysFalse => Type::BooleanLiteral(false),
953+
Self::Ambiguous => builtins_symbol_ty(db, "bool").to_instance(db),
954+
}
955+
}
956+
}
957+
958+
impl From<bool> for Truthiness {
959+
fn from(value: bool) -> Self {
960+
if value {
961+
Truthiness::AlwaysTrue
962+
} else {
963+
Truthiness::AlwaysFalse
964+
}
965+
}
966+
}
967+
876968
#[salsa::interned]
877969
pub struct FunctionType<'db> {
878970
/// name of the function at definition
@@ -1075,7 +1167,10 @@ pub struct TupleType<'db> {
10751167

10761168
#[cfg(test)]
10771169
mod tests {
1078-
use super::{builtins_symbol_ty, BytesLiteralType, StringLiteralType, Type, UnionType};
1170+
use super::{
1171+
builtins_symbol_ty, BytesLiteralType, StringLiteralType, Truthiness, TupleType, Type,
1172+
UnionType,
1173+
};
10791174
use crate::db::tests::TestDb;
10801175
use crate::program::{Program, SearchPathSettings};
10811176
use crate::python_version::PythonVersion;
@@ -1116,6 +1211,7 @@ mod tests {
11161211
BytesLiteral(&'static str),
11171212
BuiltinInstance(&'static str),
11181213
Union(Vec<Ty>),
1214+
Tuple(Vec<Ty>),
11191215
}
11201216

11211217
impl Ty {
@@ -1136,6 +1232,10 @@ mod tests {
11361232
Ty::Union(tys) => {
11371233
UnionType::from_elements(db, tys.into_iter().map(|ty| ty.into_type(db)))
11381234
}
1235+
Ty::Tuple(tys) => {
1236+
let elements = tys.into_iter().map(|ty| ty.into_type(db)).collect();
1237+
Type::Tuple(TupleType::new(db, elements))
1238+
}
11391239
}
11401240
}
11411241
}
@@ -1205,4 +1305,32 @@ mod tests {
12051305

12061306
assert!(from.into_type(&db).is_equivalent_to(&db, to.into_type(&db)));
12071307
}
1308+
1309+
#[test_case(Ty::IntLiteral(1); "is_int_literal_truthy")]
1310+
#[test_case(Ty::IntLiteral(-1))]
1311+
#[test_case(Ty::StringLiteral("foo"))]
1312+
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(0)]))]
1313+
#[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]))]
1314+
fn is_truthy(ty: Ty) {
1315+
let db = setup_db();
1316+
assert_eq!(ty.into_type(&db).bool(&db), Truthiness::AlwaysTrue);
1317+
}
1318+
1319+
#[test_case(Ty::Tuple(vec![]))]
1320+
#[test_case(Ty::IntLiteral(0))]
1321+
#[test_case(Ty::StringLiteral(""))]
1322+
#[test_case(Ty::Union(vec![Ty::IntLiteral(0), Ty::IntLiteral(0)]))]
1323+
fn is_falsy(ty: Ty) {
1324+
let db = setup_db();
1325+
assert_eq!(ty.into_type(&db).bool(&db), Truthiness::AlwaysFalse);
1326+
}
1327+
1328+
#[test_case(Ty::BuiltinInstance("str"))]
1329+
#[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(0)]))]
1330+
#[test_case(Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::IntLiteral(0)]))]
1331+
#[test_case(Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::IntLiteral(1)]))]
1332+
fn boolean_value_is_unknown(ty: Ty) {
1333+
let db = setup_db();
1334+
assert_eq!(ty.into_type(&db).bool(&db), Truthiness::Ambiguous);
1335+
}
12081336
}

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 123 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,13 @@
2828
//! definitions once the rest of the types in the scope have been inferred.
2929
use std::num::NonZeroU32;
3030

31-
use rustc_hash::FxHashMap;
32-
use salsa;
33-
use salsa::plumbing::AsId;
34-
3531
use ruff_db::files::File;
3632
use ruff_db::parsed::parsed_module;
3733
use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext, UnaryOp};
3834
use ruff_text_size::Ranged;
35+
use rustc_hash::FxHashMap;
36+
use salsa;
37+
use salsa::plumbing::AsId;
3938

4039
use crate::module_name::ModuleName;
4140
use crate::module_resolver::{file_to_module, resolve_module};
@@ -52,7 +51,7 @@ use crate::types::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics};
5251
use crate::types::{
5352
bindings_ty, builtins_symbol_ty, declarations_ty, global_symbol_ty, symbol_ty,
5453
typing_extensions_symbol_ty, BytesLiteralType, ClassType, FunctionType, StringLiteralType,
55-
TupleType, Type, TypeArrayDisplay, UnionType,
54+
Truthiness, TupleType, Type, TypeArrayDisplay, UnionType,
5655
};
5756
use crate::Db;
5857

@@ -2318,16 +2317,35 @@ impl<'db> TypeInferenceBuilder<'db> {
23182317
fn infer_boolean_expression(&mut self, bool_op: &ast::ExprBoolOp) -> Type<'db> {
23192318
let ast::ExprBoolOp {
23202319
range: _,
2321-
op: _,
2320+
op,
23222321
values,
23232322
} = bool_op;
2324-
2325-
for value in values {
2326-
self.infer_expression(value);
2327-
}
2328-
2329-
// TODO resolve bool op
2330-
Type::Unknown
2323+
let mut done = false;
2324+
UnionType::from_elements(
2325+
self.db,
2326+
values.iter().enumerate().map(|(i, value)| {
2327+
// We need to infer the type of every expression (that's an invariant maintained by
2328+
// type inference), even if we can short-circuit boolean evaluation of some of
2329+
// those types.
2330+
let value_ty = self.infer_expression(value);
2331+
if done {
2332+
Type::Never
2333+
} else {
2334+
let is_last = i == values.len() - 1;
2335+
match (value_ty.bool(self.db), is_last, op) {
2336+
(Truthiness::Ambiguous, _, _) => value_ty,
2337+
(Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never,
2338+
(Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never,
2339+
(Truthiness::AlwaysFalse, _, ast::BoolOp::And)
2340+
| (Truthiness::AlwaysTrue, _, ast::BoolOp::Or) => {
2341+
done = true;
2342+
value_ty
2343+
}
2344+
(_, true, _) => value_ty,
2345+
}
2346+
}
2347+
}),
2348+
)
23312349
}
23322350

23332351
fn infer_compare_expression(&mut self, compare: &ast::ExprCompare) -> Type<'db> {
@@ -6048,4 +6066,96 @@ mod tests {
60486066
);
60496067
Ok(())
60506068
}
6069+
6070+
// Checks the public types inferred for `or` chains written to `/src/a.py`:
// chains of boolean/string literals (`a`-`f`) and chains that start with a call to
// `foo()`, which is declared to return `str` (`g`, `h`).
#[test]
6071+
fn boolean_or_expression() -> anyhow::Result<()> {
6072+
let mut db = setup_db();
6073+
6074+
db.write_dedented(
6075+
"/src/a.py",
6076+
"
6077+
def foo() -> str:
6078+
pass
6079+
6080+
a = True or False
6081+
b = 'x' or 'y' or 'z'
6082+
c = '' or 'y' or 'z'
6083+
d = False or 'z'
6084+
e = False or True
6085+
f = False or False
6086+
g = foo() or False
6087+
h = foo() or True
6088+
",
6089+
)?;
6090+
6091+
// Literal operands: the expected type is a single literal in every case.
assert_public_ty(&db, "/src/a.py", "a", "Literal[True]");
6092+
assert_public_ty(&db, "/src/a.py", "b", r#"Literal["x"]"#);
6093+
assert_public_ty(&db, "/src/a.py", "c", r#"Literal["y"]"#);
6094+
assert_public_ty(&db, "/src/a.py", "d", r#"Literal["z"]"#);
6095+
assert_public_ty(&db, "/src/a.py", "e", "Literal[True]");
6096+
assert_public_ty(&db, "/src/a.py", "f", "Literal[False]");
6097+
// `foo()` operands: the expected type is a union that keeps both operands.
assert_public_ty(&db, "/src/a.py", "g", "str | Literal[False]");
6098+
assert_public_ty(&db, "/src/a.py", "h", "str | Literal[True]");
6099+
6100+
Ok(())
6101+
}
6102+
6103+
// Checks the public types inferred for `and` chains written to `/src/a.py`:
// boolean literals (`a`, `b`), chains that start with a call to `foo()`, which is
// declared to return `str` (`c`, `d`), and chains of string literals (`e`-`g`).
#[test]
6104+
fn boolean_and_expression() -> anyhow::Result<()> {
6105+
let mut db = setup_db();
6106+
6107+
db.write_dedented(
6108+
"/src/a.py",
6109+
"
6110+
def foo() -> str:
6111+
pass
6112+
6113+
a = True and False
6114+
b = False and True
6115+
c = foo() and False
6116+
d = foo() and True
6117+
e = 'x' and 'y' and 'z'
6118+
f = 'x' and 'y' and ''
6119+
g = '' and 'y'
6120+
",
6121+
)?;
6122+
6123+
assert_public_ty(&db, "/src/a.py", "a", "Literal[False]");
6124+
assert_public_ty(&db, "/src/a.py", "b", "Literal[False]");
6125+
// `foo()` operands: the expected type is a union that keeps both operands.
assert_public_ty(&db, "/src/a.py", "c", "str | Literal[False]");
6126+
assert_public_ty(&db, "/src/a.py", "d", "str | Literal[True]");
6127+
// String-literal chains: the expected type is a single string literal.
assert_public_ty(&db, "/src/a.py", "e", r#"Literal["z"]"#);
6128+
assert_public_ty(&db, "/src/a.py", "f", r#"Literal[""]"#);
6129+
assert_public_ty(&db, "/src/a.py", "g", r#"Literal[""]"#);
6130+
Ok(())
6131+
}
6132+
6133+
// Checks the public types inferred for expressions that mix `and` and `or` over string
// literals. Python gives `and` higher precedence than `or`, so for example `a` is
// `("x" and "y") or "z"` and `b` is `"x" or ("y" and "z")`.
// Note: the snippet defines `foo`, but none of the assignments below use it.
#[test]
6134+
fn boolean_complex_expression() -> anyhow::Result<()> {
6135+
let mut db = setup_db();
6136+
6137+
db.write_dedented(
6138+
"/src/a.py",
6139+
r#"
6140+
def foo() -> str:
6141+
pass
6142+
6143+
a = "x" and "y" or "z"
6144+
b = "x" or "y" and "z"
6145+
c = "" and "y" or "z"
6146+
d = "" or "y" and "z"
6147+
e = "x" and "y" or ""
6148+
f = "x" or "y" and ""
6149+
6150+
"#,
6151+
)?;
6152+
6153+
assert_public_ty(&db, "/src/a.py", "a", r#"Literal["y"]"#);
6154+
assert_public_ty(&db, "/src/a.py", "b", r#"Literal["x"]"#);
6155+
assert_public_ty(&db, "/src/a.py", "c", r#"Literal["z"]"#);
6156+
assert_public_ty(&db, "/src/a.py", "d", r#"Literal["z"]"#);
6157+
assert_public_ty(&db, "/src/a.py", "e", r#"Literal["y"]"#);
6158+
assert_public_ty(&db, "/src/a.py", "f", r#"Literal["x"]"#);
6159+
Ok(())
6160+
}
60516161
}

0 commit comments

Comments
 (0)