Skip to content

Commit c2bb5d5

Browse files
authored
[red-knot] Fix equivalence of differently ordered unions that contain Callable types (#17145)
## Summary Fixes #17058. Equivalent callable types were not understood as equivalent when they appeared nested inside unions and intersections. This PR fixes that by ensuring that `Callable` elements nested inside unions, intersections and tuples have their representations normalized before one union type is compared with another for equivalence, or before one intersection type is compared with another for equivalence. The normalizations applied to a `Callable` type are: - the type of the default value is stripped from all parameters (only whether the parameter _has_ a default value is relevant to whether one `Callable` type is equivalent to another) - The names of the parameters are stripped from positional-only parameters, variadic parameters and keyword-variadic parameters - Unions and intersections that are present (top-level or nested) inside parameter annotations or return annotations are normalized. Adding a `CallableType::normalized()` method also allows us to simplify the implementation of `CallableType::is_equivalent_to()`. ### Should these normalizations be done eagerly as part of a `CallableType` constructor? I considered this. It's something that we could still consider doing in the future; this PR doesn't rule it out as a possibility. However, I didn't pursue it for now, for several reasons: 1. Our current `Display` implementation doesn't handle well the possibility that a parameter might not have a name or an annotated type. Callable types with parameters like this would be displayed as follows: ```py (, ,) -> None: ... ``` That's fixable! It could easily become something like `(Unknown, Unknown) -> None: ...`. But it also illustrates that we probably want to retain the parameter names when displaying the signature of a `lambda` function if you're hovering over a reference to the lambda in an IDE. Currently we don't have a `LambdaType` struct for representing `lambda` functions; if we wanted to eagerly normalize signatures when creating `CallableType`s, we'd probably have to add a `LambdaType` struct so that we would retain the full signature of a `lambda` function, rather than representing it as an eagerly simplified `CallableType`. 2. In order to ensure that it's impossible to create `CallableType`s without the parameters being normalized, I'd either have to create an alternative `SimplifiedSignature` struct (which would duplicate a lot of code), or move `CallableType` to a new module so that the only way of constructing a `CallableType` instance would be via a constructor method that performs the normalizations eagerly on the callable's signature. Again, this isn't a dealbreaker, and I think it's still an option, but it would be a lot of churn, and it didn't seem necessary for now. Doing it this way, at least to start with, felt like it would create a diff that's easier to review and felt like it would create fewer merge conflicts for others. ## Test Plan - Added a regression mdtest for #17058 - Ran `QUICKCHECK_TESTS=1000000 cargo test --release -p red_knot_python_semantic -- --ignored types::property_tests::stable`
1 parent cb7dae1 commit c2bb5d5

File tree

5 files changed

+98
-56
lines changed

5 files changed

+98
-56
lines changed

crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md

+3
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def f1(a: int = 1) -> None: ...
134134
def f2(a: int = 2) -> None: ...
135135

136136
static_assert(is_equivalent_to(CallableTypeOf[f1], CallableTypeOf[f2]))
137+
static_assert(is_equivalent_to(CallableTypeOf[f1] | bool | CallableTypeOf[f2], CallableTypeOf[f2] | bool | CallableTypeOf[f1]))
137138
```
138139

139140
The names of the positional-only, variadic and keyword-variadic parameters does not need to be the
@@ -144,6 +145,7 @@ def f3(a1: int, /, *args1: int, **kwargs2: int) -> None: ...
144145
def f4(a2: int, /, *args2: int, **kwargs1: int) -> None: ...
145146

146147
static_assert(is_equivalent_to(CallableTypeOf[f3], CallableTypeOf[f4]))
148+
static_assert(is_equivalent_to(CallableTypeOf[f3] | bool | CallableTypeOf[f4], CallableTypeOf[f4] | bool | CallableTypeOf[f3]))
147149
```
148150

149151
Putting it all together, the following two callables are equivalent:
@@ -153,6 +155,7 @@ def f5(a1: int, /, b: float, c: bool = False, *args1: int, d: int = 1, e: str, *
153155
def f6(a2: int, /, b: float, c: bool = True, *args2: int, d: int = 2, e: str, **kwargs2: float) -> None: ...
154156

155157
static_assert(is_equivalent_to(CallableTypeOf[f5], CallableTypeOf[f6]))
158+
static_assert(is_equivalent_to(CallableTypeOf[f5] | bool | CallableTypeOf[f6], CallableTypeOf[f6] | bool | CallableTypeOf[f5]))
156159
```
157160

158161
### Not equivalent

crates/red_knot_python_semantic/resources/mdtest/type_properties/is_gradual_equivalent_to.md

+3
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,9 @@ def f4(a=2): ...
147147
def f5(a): ...
148148

149149
static_assert(is_gradual_equivalent_to(CallableTypeOf[f3], CallableTypeOf[f4]))
150+
static_assert(
151+
is_gradual_equivalent_to(CallableTypeOf[f3] | bool | CallableTypeOf[f4], CallableTypeOf[f4] | bool | CallableTypeOf[f3])
152+
)
150153
static_assert(not is_gradual_equivalent_to(CallableTypeOf[f3], CallableTypeOf[f5]))
151154

152155
def f6(a, /): ...

crates/red_knot_python_semantic/src/types.rs

+44-35
Original file line numberDiff line numberDiff line change
@@ -597,19 +597,22 @@ impl<'db> Type<'db> {
597597
}
598598
}
599599

600-
/// Return a normalized version of `self` in which all unions and intersections are sorted
601-
/// according to a canonical order, no matter how "deeply" a union/intersection may be nested.
600+
/// Return a "normalized" version of `self` that ensures that equivalent types have the same Salsa ID.
601+
///
602+
/// A normalized type:
603+
/// - Has all unions and intersections sorted according to a canonical order,
604+
/// no matter how "deeply" a union/intersection may be nested.
605+
/// - Strips the names of positional-only parameters and variadic parameters from `Callable` types,
606+
/// as these are irrelevant to whether a callable type `X` is equivalent to a callable type `Y`.
607+
/// - Strips the types of default values from parameters in `Callable` types: only whether a parameter
608+
/// *has* or *does not have* a default value is relevant to whether two `Callable` types are equivalent.
602609
#[must_use]
603-
pub fn with_sorted_unions_and_intersections(self, db: &'db dyn Db) -> Self {
610+
pub fn normalized(self, db: &'db dyn Db) -> Self {
604611
match self {
605-
Type::Union(union) => Type::Union(union.to_sorted_union(db)),
606-
Type::Intersection(intersection) => {
607-
Type::Intersection(intersection.to_sorted_intersection(db))
608-
}
609-
Type::Tuple(tuple) => Type::Tuple(tuple.with_sorted_unions_and_intersections(db)),
610-
Type::Callable(callable) => {
611-
Type::Callable(callable.with_sorted_unions_and_intersections(db))
612-
}
612+
Type::Union(union) => Type::Union(union.normalized(db)),
613+
Type::Intersection(intersection) => Type::Intersection(intersection.normalized(db)),
614+
Type::Tuple(tuple) => Type::Tuple(tuple.normalized(db)),
615+
Type::Callable(callable) => Type::Callable(callable.normalized(db)),
613616
Type::LiteralString
614617
| Type::Instance(_)
615618
| Type::PropertyInstance(_)
@@ -4676,16 +4679,19 @@ impl<'db> CallableType<'db> {
46764679
)
46774680
}
46784681

4679-
fn with_sorted_unions_and_intersections(self, db: &'db dyn Db) -> Self {
4682+
/// Return a "normalized" version of this `Callable` type.
4683+
///
4684+
/// See [`Type::normalized`] for more details.
4685+
fn normalized(self, db: &'db dyn Db) -> Self {
46804686
let signature = self.signature(db);
46814687
let parameters = signature
46824688
.parameters()
46834689
.iter()
4684-
.map(|param| param.clone().with_sorted_unions_and_intersections(db))
4690+
.map(|param| param.normalized(db))
46854691
.collect();
46864692
let return_ty = signature
46874693
.return_ty
4688-
.map(|return_ty| return_ty.with_sorted_unions_and_intersections(db));
4694+
.map(|return_ty| return_ty.normalized(db));
46894695
CallableType::new(db, Signature::new(parameters, return_ty))
46904696
}
46914697

@@ -5447,13 +5453,15 @@ impl<'db> UnionType<'db> {
54475453
self.elements(db).iter().all(|ty| ty.is_fully_static(db))
54485454
}
54495455

5450-
/// Create a new union type with the elements sorted according to a canonical ordering.
5456+
/// Create a new union type with the elements normalized.
5457+
///
5458+
/// See [`Type::normalized`] for more details.
54515459
#[must_use]
5452-
pub fn to_sorted_union(self, db: &'db dyn Db) -> Self {
5460+
pub fn normalized(self, db: &'db dyn Db) -> Self {
54535461
let mut new_elements: Vec<Type<'db>> = self
54545462
.elements(db)
54555463
.iter()
5456-
.map(|element| element.with_sorted_unions_and_intersections(db))
5464+
.map(|element| element.normalized(db))
54575465
.collect();
54585466
new_elements.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r));
54595467
UnionType::new(db, new_elements.into_boxed_slice())
@@ -5487,13 +5495,13 @@ impl<'db> UnionType<'db> {
54875495
return true;
54885496
}
54895497

5490-
let sorted_self = self.to_sorted_union(db);
5498+
let sorted_self = self.normalized(db);
54915499

54925500
if sorted_self == other {
54935501
return true;
54945502
}
54955503

5496-
sorted_self == other.to_sorted_union(db)
5504+
sorted_self == other.normalized(db)
54975505
}
54985506

54995507
/// Return `true` if `self` has exactly the same set of possible static materializations as `other`
@@ -5510,13 +5518,13 @@ impl<'db> UnionType<'db> {
55105518
return false;
55115519
}
55125520

5513-
let sorted_self = self.to_sorted_union(db);
5521+
let sorted_self = self.normalized(db);
55145522

55155523
if sorted_self == other {
55165524
return true;
55175525
}
55185526

5519-
let sorted_other = other.to_sorted_union(db);
5527+
let sorted_other = other.normalized(db);
55205528

55215529
if sorted_self == sorted_other {
55225530
return true;
@@ -5547,17 +5555,17 @@ pub struct IntersectionType<'db> {
55475555

55485556
impl<'db> IntersectionType<'db> {
55495557
/// Return a new `IntersectionType` instance with the positive and negative types sorted
5550-
/// according to a canonical ordering.
5558+
/// according to a canonical ordering, and other normalizations applied to each element as applicable.
5559+
///
5560+
/// See [`Type::normalized`] for more details.
55515561
#[must_use]
5552-
pub fn to_sorted_intersection(self, db: &'db dyn Db) -> Self {
5562+
pub fn normalized(self, db: &'db dyn Db) -> Self {
55535563
fn normalized_set<'db>(
55545564
db: &'db dyn Db,
55555565
elements: &FxOrderSet<Type<'db>>,
55565566
) -> FxOrderSet<Type<'db>> {
5557-
let mut elements: FxOrderSet<Type<'db>> = elements
5558-
.iter()
5559-
.map(|ty| ty.with_sorted_unions_and_intersections(db))
5560-
.collect();
5567+
let mut elements: FxOrderSet<Type<'db>> =
5568+
elements.iter().map(|ty| ty.normalized(db)).collect();
55615569

55625570
elements.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r));
55635571
elements
@@ -5620,13 +5628,13 @@ impl<'db> IntersectionType<'db> {
56205628
return true;
56215629
}
56225630

5623-
let sorted_self = self.to_sorted_intersection(db);
5631+
let sorted_self = self.normalized(db);
56245632

56255633
if sorted_self == other {
56265634
return true;
56275635
}
56285636

5629-
sorted_self == other.to_sorted_intersection(db)
5637+
sorted_self == other.normalized(db)
56305638
}
56315639

56325640
/// Return `true` if `self` has exactly the same set of possible static materializations as `other`
@@ -5642,13 +5650,13 @@ impl<'db> IntersectionType<'db> {
56425650
return false;
56435651
}
56445652

5645-
let sorted_self = self.to_sorted_intersection(db);
5653+
let sorted_self = self.normalized(db);
56465654

56475655
if sorted_self == other {
56485656
return true;
56495657
}
56505658

5651-
let sorted_other = other.to_sorted_intersection(db);
5659+
let sorted_other = other.normalized(db);
56525660

56535661
if sorted_self == sorted_other {
56545662
return true;
@@ -5834,14 +5842,15 @@ impl<'db> TupleType<'db> {
58345842
Type::Tuple(Self::new(db, elements.into_boxed_slice()))
58355843
}
58365844

5837-
/// Return a normalized version of `self` in which all unions and intersections are sorted
5838-
/// according to a canonical order, no matter how "deeply" a union/intersection may be nested.
5845+
/// Return a normalized version of `self`.
5846+
///
5847+
/// See [`Type::normalized`] for more details.
58395848
#[must_use]
5840-
pub fn with_sorted_unions_and_intersections(self, db: &'db dyn Db) -> Self {
5849+
pub fn normalized(self, db: &'db dyn Db) -> Self {
58415850
let elements: Box<[Type<'db>]> = self
58425851
.elements(db)
58435852
.iter()
5844-
.map(|ty| ty.with_sorted_unions_and_intersections(db))
5853+
.map(|ty| ty.normalized(db))
58455854
.collect();
58465855
TupleType::new(db, elements)
58475856
}

crates/red_knot_python_semantic/src/types/signatures.rs

+39-16
Original file line numberDiff line numberDiff line change
@@ -606,31 +606,54 @@ impl<'db> Parameter<'db> {
606606
self
607607
}
608608

609-
pub(crate) fn with_sorted_unions_and_intersections(mut self, db: &'db dyn Db) -> Self {
610-
self.annotated_type = self
611-
.annotated_type
612-
.map(|ty| ty.with_sorted_unions_and_intersections(db));
613-
614-
self.kind = match self.kind {
615-
ParameterKind::PositionalOnly { name, default_type } => ParameterKind::PositionalOnly {
616-
name,
617-
default_type: default_type.map(|ty| ty.with_sorted_unions_and_intersections(db)),
609+
/// Strip information from the parameter so that two equivalent parameters compare equal.
610+
/// Normalize nested unions and intersections in the annotated type, if any.
611+
///
612+
/// See [`Type::normalized`] for more details.
613+
pub(crate) fn normalized(&self, db: &'db dyn Db) -> Self {
614+
let Parameter {
615+
annotated_type,
616+
kind,
617+
form,
618+
} = self;
619+
620+
// Ensure unions and intersections are ordered in the annotated type (if there is one)
621+
let annotated_type = annotated_type.map(|ty| ty.normalized(db));
622+
623+
// Ensure that parameter names are stripped from positional-only, variadic and keyword-variadic parameters.
624+
// Ensure that we only record whether a parameter *has* a default
625+
// (strip the precise *type* of the default from the parameter, replacing it with `Never`).
626+
let kind = match kind {
627+
ParameterKind::PositionalOnly {
628+
name: _,
629+
default_type,
630+
} => ParameterKind::PositionalOnly {
631+
name: None,
632+
default_type: default_type.map(|_| Type::Never),
618633
},
619634
ParameterKind::PositionalOrKeyword { name, default_type } => {
620635
ParameterKind::PositionalOrKeyword {
621-
name,
622-
default_type: default_type
623-
.map(|ty| ty.with_sorted_unions_and_intersections(db)),
636+
name: name.clone(),
637+
default_type: default_type.map(|_| Type::Never),
624638
}
625639
}
626640
ParameterKind::KeywordOnly { name, default_type } => ParameterKind::KeywordOnly {
627-
name,
628-
default_type: default_type.map(|ty| ty.with_sorted_unions_and_intersections(db)),
641+
name: name.clone(),
642+
default_type: default_type.map(|_| Type::Never),
643+
},
644+
ParameterKind::Variadic { name: _ } => ParameterKind::Variadic {
645+
name: Name::new_static("args"),
646+
},
647+
ParameterKind::KeywordVariadic { name: _ } => ParameterKind::KeywordVariadic {
648+
name: Name::new_static("kwargs"),
629649
},
630-
ParameterKind::Variadic { .. } | ParameterKind::KeywordVariadic { .. } => self.kind,
631650
};
632651

633-
self
652+
Self {
653+
annotated_type,
654+
kind,
655+
form: *form,
656+
}
634657
}
635658

636659
fn from_node_and_kind(

crates/red_knot_python_semantic/src/types/type_ordering.rs

+9-5
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,17 @@ pub(super) fn union_or_intersection_elements_ordering<'db>(
7373
(Type::WrapperDescriptor(_), _) => Ordering::Less,
7474
(_, Type::WrapperDescriptor(_)) => Ordering::Greater,
7575

76-
(Type::Callable(left), Type::Callable(right)) => left.cmp(right),
76+
(Type::Callable(left), Type::Callable(right)) => {
77+
debug_assert_eq!(*left, left.normalized(db));
78+
debug_assert_eq!(*right, right.normalized(db));
79+
left.cmp(right)
80+
}
7781
(Type::Callable(_), _) => Ordering::Less,
7882
(_, Type::Callable(_)) => Ordering::Greater,
7983

8084
(Type::Tuple(left), Type::Tuple(right)) => {
81-
debug_assert_eq!(*left, left.with_sorted_unions_and_intersections(db));
82-
debug_assert_eq!(*right, right.with_sorted_unions_and_intersections(db));
85+
debug_assert_eq!(*left, left.normalized(db));
86+
debug_assert_eq!(*right, right.normalized(db));
8387
left.cmp(right)
8488
}
8589
(Type::Tuple(_), _) => Ordering::Less,
@@ -271,8 +275,8 @@ pub(super) fn union_or_intersection_elements_ordering<'db>(
271275
}
272276

273277
(Type::Intersection(left), Type::Intersection(right)) => {
274-
debug_assert_eq!(*left, left.to_sorted_intersection(db));
275-
debug_assert_eq!(*right, right.to_sorted_intersection(db));
278+
debug_assert_eq!(*left, left.normalized(db));
279+
debug_assert_eq!(*right, right.normalized(db));
276280

277281
if left == right {
278282
return Ordering::Equal;

0 commit comments

Comments
 (0)