Skip to content

Commit cf1e91b

Browse files
authored
[red-knot] simplify subtypes from unions (#13401)
Add `Type::is_subtype_of` method, and simplify subtypes out of unions.
1 parent 125eaaf commit cf1e91b

File tree

3 files changed

+82
-9
lines changed

3 files changed

+82
-9
lines changed

crates/red_knot_python_semantic/src/types.rs

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -388,16 +388,18 @@ impl<'db> Type<'db> {
388388
}
389389
}
390390

391-
/// Return true if this type is [assignable to] type `target`.
391+
/// Return true if this type is a [subtype of] type `target`.
392392
///
393-
/// [assignable to]: https://typing.readthedocs.io/en/latest/spec/concepts.html#the-assignable-to-or-consistent-subtyping-relation
394-
pub(crate) fn is_assignable_to(self, db: &'db dyn Db, target: Type<'db>) -> bool {
393+
/// [subtype of]: https://typing.readthedocs.io/en/latest/spec/concepts.html#subtype-supertype-and-type-equivalence
394+
pub(crate) fn is_subtype_of(self, db: &'db dyn Db, target: Type<'db>) -> bool {
395395
if self.is_equivalent_to(db, target) {
396396
return true;
397397
}
398398
match (self, target) {
399-
(Type::Unknown | Type::Any | Type::Never, _) => true,
400-
(_, Type::Unknown | Type::Any) => true,
399+
(Type::Unknown | Type::Any, _) => false,
400+
(_, Type::Unknown | Type::Any) => false,
401+
(Type::Never, _) => true,
402+
(_, Type::Never) => false,
401403
(Type::IntLiteral(_), Type::Instance(class))
402404
if class.is_stdlib_symbol(db, "builtins", "int") =>
403405
{
@@ -417,12 +419,28 @@ impl<'db> Type<'db> {
417419
(ty, Type::Union(union)) => union
418420
.elements(db)
419421
.iter()
420-
.any(|&elem_ty| ty.is_assignable_to(db, elem_ty)),
422+
.any(|&elem_ty| ty.is_subtype_of(db, elem_ty)),
421423
// TODO
422424
_ => false,
423425
}
424426
}
425427

428+
/// Return true if this type is [assignable to] type `target`.
429+
///
430+
/// [assignable to]: https://typing.readthedocs.io/en/latest/spec/concepts.html#the-assignable-to-or-consistent-subtyping-relation
431+
pub(crate) fn is_assignable_to(self, db: &'db dyn Db, target: Type<'db>) -> bool {
432+
match (self, target) {
433+
(Type::Unknown | Type::Any, _) => true,
434+
(_, Type::Unknown | Type::Any) => true,
435+
(ty, Type::Union(union)) => union
436+
.elements(db)
437+
.iter()
438+
.any(|&elem_ty| ty.is_assignable_to(db, elem_ty)),
439+
// TODO other types containing gradual forms (e.g. generics containing Any/Unknown)
440+
_ => self.is_subtype_of(db, target),
441+
}
442+
}
443+
426444
/// Return true if this type is equivalent to type `other`.
427445
pub(crate) fn is_equivalent_to(self, _db: &'db dyn Db, other: Type<'db>) -> bool {
428446
// TODO equivalent but not identical structural types, differently-ordered unions and
@@ -1132,6 +1150,31 @@ mod tests {
11321150
assert!(!from.into_type(&db).is_assignable_to(&db, to.into_type(&db)));
11331151
}
11341152

1153+
#[test_case(Ty::Never, Ty::IntLiteral(1))]
1154+
#[test_case(Ty::IntLiteral(1), Ty::BuiltinInstance("int"))]
1155+
#[test_case(Ty::StringLiteral("foo"), Ty::BuiltinInstance("str"))]
1156+
#[test_case(Ty::StringLiteral("foo"), Ty::LiteralString)]
1157+
#[test_case(Ty::LiteralString, Ty::BuiltinInstance("str"))]
1158+
#[test_case(Ty::BytesLiteral("foo"), Ty::BuiltinInstance("bytes"))]
1159+
#[test_case(Ty::IntLiteral(1), Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")]))]
1160+
fn is_subtype_of(from: Ty, to: Ty) {
1161+
let db = setup_db();
1162+
assert!(from.into_type(&db).is_subtype_of(&db, to.into_type(&db)));
1163+
}
1164+
1165+
#[test_case(Ty::Unknown, Ty::IntLiteral(1))]
1166+
#[test_case(Ty::Any, Ty::IntLiteral(1))]
1167+
#[test_case(Ty::IntLiteral(1), Ty::Unknown)]
1168+
#[test_case(Ty::IntLiteral(1), Ty::Any)]
1169+
#[test_case(Ty::IntLiteral(1), Ty::Union(vec![Ty::Unknown, Ty::BuiltinInstance("str")]))]
1170+
#[test_case(Ty::IntLiteral(1), Ty::BuiltinInstance("str"))]
1171+
#[test_case(Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"))]
1172+
#[test_case(Ty::BuiltinInstance("int"), Ty::IntLiteral(1))]
1173+
fn is_not_subtype_of(from: Ty, to: Ty) {
1174+
let db = setup_db();
1175+
assert!(!from.into_type(&db).is_subtype_of(&db, to.into_type(&db)));
1176+
}
1177+
11351178
#[test_case(
11361179
Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]),
11371180
Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)])

crates/red_knot_python_semantic/src/types/builder.rs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,23 @@ impl<'db> UnionBuilder<'db> {
4646
pub(crate) fn add(mut self, ty: Type<'db>) -> Self {
4747
match ty {
4848
Type::Union(union) => {
49-
self.elements.extend(union.elements(self.db));
49+
for element in union.elements(self.db) {
50+
self = self.add(*element);
51+
}
5052
}
5153
Type::Never => {}
5254
_ => {
55+
let mut remove = vec![];
56+
for element in &self.elements {
57+
if ty.is_subtype_of(self.db, *element) {
58+
return self;
59+
} else if element.is_subtype_of(self.db, ty) {
60+
remove.push(*element);
61+
}
62+
}
63+
for element in remove {
64+
self.elements.remove(&element);
65+
}
5366
self.elements.insert(ty);
5467
}
5568
}
@@ -368,6 +381,24 @@ mod tests {
368381
assert_eq!(union.elements_vec(&db), &[t0, t1, t2]);
369382
}
370383

384+
#[test]
385+
fn build_union_simplify_subtype() {
386+
let db = setup_db();
387+
let t0 = builtins_symbol_ty(&db, "str").to_instance(&db);
388+
let t1 = Type::LiteralString;
389+
let t2 = Type::Unknown;
390+
let u0 = UnionType::from_elements(&db, [t0, t1]);
391+
let u1 = UnionType::from_elements(&db, [t1, t0]);
392+
let u2 = UnionType::from_elements(&db, [t0, t1, t2]);
393+
394+
assert_eq!(u0, t0);
395+
assert_eq!(u1, t0);
396+
assert_eq!(u2.expect_union().elements_vec(&db), &[t0, t2]);
397+
}
398+
399+
#[test]
400+
fn build_union_no_simplify_any() {}
401+
371402
impl<'db> IntersectionType<'db> {
372403
fn pos_vec(self, db: &'db TestDb) -> Vec<Type<'db>> {
373404
self.positive(db).into_iter().copied().collect()

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5800,8 +5800,7 @@ mod tests {
58005800
.unwrap();
58015801
db.write_file("/src/c.pyi", "x: int").unwrap();
58025802

5803-
// TODO this should simplify to just 'int'
5804-
assert_public_ty(&db, "/src/a.py", "x", "int | Literal[1]");
5803+
assert_public_ty(&db, "/src/a.py", "x", "int");
58055804
}
58065805

58075806
// Incremental inference tests

0 commit comments

Comments
 (0)