Skip to content

Commit c60b4d7

Browse files
[ty] Add subtyping between Callable types and class literals with __init__ (#17638)
## Summary Allow classes with `__init__` to be subtypes of `Callable` Fixes astral-sh/ty#358 ## Test Plan Update is_subtype_of.md --------- Co-authored-by: Carl Meyer <[email protected]>
1 parent 16621fa commit c60b4d7

File tree

5 files changed

+329
-46
lines changed

5 files changed

+329
-46
lines changed

crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,11 +608,49 @@ c: Callable[[Any], str] = A().g
608608
```py
609609
from typing import Any, Callable
610610

611+
c: Callable[[object], type] = type
611612
c: Callable[[str], Any] = str
612613
c: Callable[[str], Any] = int
613614

614615
# error: [invalid-assignment]
615616
c: Callable[[str], Any] = object
617+
618+
class A:
619+
def __init__(self, x: int) -> None: ...
620+
621+
a: Callable[[int], A] = A
622+
623+
class C:
624+
def __new__(cls, *args, **kwargs) -> "C":
625+
return super().__new__(cls)
626+
627+
def __init__(self, x: int) -> None: ...
628+
629+
c: Callable[[int], C] = C
630+
```
631+
632+
### Generic class literal types
633+
634+
```toml
635+
[environment]
636+
python-version = "3.12"
637+
```
638+
639+
```py
640+
from typing import Callable
641+
642+
class B[T]:
643+
def __init__(self, x: T) -> None: ...
644+
645+
b: Callable[[int], B[int]] = B[int]
646+
647+
class C[T]:
648+
def __new__(cls, *args, **kwargs) -> "C[T]":
649+
return super().__new__(cls)
650+
651+
def __init__(self, x: T) -> None: ...
652+
653+
c: Callable[[int], C[int]] = C[int]
616654
```
617655

618656
### Overloads

crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1219,7 +1219,7 @@ static_assert(is_subtype_of(TypeOf[C], Callable[[], str]))
12191219
#### Classes with `__new__`
12201220

12211221
```py
1222-
from typing import Callable
1222+
from typing import Callable, overload
12231223
from ty_extensions import TypeOf, static_assert, is_subtype_of
12241224

12251225
class A:
@@ -1244,6 +1244,20 @@ static_assert(is_subtype_of(TypeOf[E], Callable[[], C]))
12441244
static_assert(is_subtype_of(TypeOf[E], Callable[[], B]))
12451245
static_assert(not is_subtype_of(TypeOf[D], Callable[[], C]))
12461246
static_assert(is_subtype_of(TypeOf[D], Callable[[], B]))
1247+
1248+
class F:
1249+
@overload
1250+
def __new__(cls) -> int: ...
1251+
@overload
1252+
def __new__(cls, x: int) -> "F": ...
1253+
def __new__(cls, x: int | None = None) -> "int | F":
1254+
return 1 if x is None else object.__new__(cls)
1255+
1256+
def __init__(self, y: str) -> None: ...
1257+
1258+
static_assert(is_subtype_of(TypeOf[F], Callable[[int], F]))
1259+
static_assert(is_subtype_of(TypeOf[F], Callable[[], int]))
1260+
static_assert(not is_subtype_of(TypeOf[F], Callable[[str], F]))
12471261
```
12481262

12491263
#### Classes with `__call__` and `__new__`
@@ -1266,6 +1280,123 @@ static_assert(is_subtype_of(TypeOf[F], Callable[[], int]))
12661280
static_assert(not is_subtype_of(TypeOf[F], Callable[[], str]))
12671281
```
12681282

1283+
#### Classes with `__init__`
1284+
1285+
```py
1286+
from typing import Callable, overload
1287+
from ty_extensions import TypeOf, static_assert, is_subtype_of
1288+
1289+
class A:
1290+
def __init__(self, a: int) -> None: ...
1291+
1292+
static_assert(is_subtype_of(TypeOf[A], Callable[[int], A]))
1293+
static_assert(not is_subtype_of(TypeOf[A], Callable[[], A]))
1294+
1295+
class B:
1296+
@overload
1297+
def __init__(self, a: int) -> None: ...
1298+
@overload
1299+
def __init__(self) -> None: ...
1300+
def __init__(self, a: int | None = None) -> None: ...
1301+
1302+
static_assert(is_subtype_of(TypeOf[B], Callable[[int], B]))
1303+
static_assert(is_subtype_of(TypeOf[B], Callable[[], B]))
1304+
1305+
class C: ...
1306+
1307+
# TODO: This assertion should be true once we understand `Self`
1308+
# error: [static-assert-error] "Static assertion error: argument evaluates to `False`"
1309+
static_assert(is_subtype_of(TypeOf[C], Callable[[], C]))
1310+
1311+
class D[T]:
1312+
def __init__(self, x: T) -> None: ...
1313+
1314+
static_assert(is_subtype_of(TypeOf[D[int]], Callable[[int], D[int]]))
1315+
static_assert(not is_subtype_of(TypeOf[D[int]], Callable[[str], D[int]]))
1316+
```
1317+
1318+
#### Classes with `__init__` and `__new__`
1319+
1320+
```py
1321+
from typing import Callable, overload, Self
1322+
from ty_extensions import TypeOf, static_assert, is_subtype_of
1323+
1324+
class A:
1325+
def __new__(cls, a: int) -> Self:
1326+
return super().__new__(cls)
1327+
1328+
def __init__(self, a: int) -> None: ...
1329+
1330+
static_assert(is_subtype_of(TypeOf[A], Callable[[int], A]))
1331+
static_assert(not is_subtype_of(TypeOf[A], Callable[[], A]))
1332+
1333+
class B:
1334+
def __new__(cls, a: int) -> int:
1335+
return super().__new__(cls)
1336+
1337+
def __init__(self, a: str) -> None: ...
1338+
1339+
static_assert(is_subtype_of(TypeOf[B], Callable[[int], int]))
1340+
static_assert(not is_subtype_of(TypeOf[B], Callable[[str], B]))
1341+
1342+
class C:
1343+
def __new__(cls, *args, **kwargs) -> "C":
1344+
return super().__new__(cls)
1345+
1346+
def __init__(self, x: int) -> None: ...
1347+
1348+
# Not subtype because __new__ signature is not fully static
1349+
static_assert(not is_subtype_of(TypeOf[C], Callable[[int], C]))
1350+
static_assert(not is_subtype_of(TypeOf[C], Callable[[], C]))
1351+
1352+
class D: ...
1353+
1354+
class E:
1355+
@overload
1356+
def __new__(cls) -> int: ...
1357+
@overload
1358+
def __new__(cls, x: int) -> D: ...
1359+
def __new__(cls, x: int | None = None) -> int | D:
1360+
return D()
1361+
1362+
def __init__(self, y: str) -> None: ...
1363+
1364+
static_assert(is_subtype_of(TypeOf[E], Callable[[int], D]))
1365+
static_assert(is_subtype_of(TypeOf[E], Callable[[], int]))
1366+
1367+
class F[T]:
1368+
def __new__(cls, x: T) -> "F[T]":
1369+
return super().__new__(cls)
1370+
1371+
def __init__(self, x: T) -> None: ...
1372+
1373+
static_assert(is_subtype_of(TypeOf[F[int]], Callable[[int], F[int]]))
1374+
static_assert(not is_subtype_of(TypeOf[F[int]], Callable[[str], F[int]]))
1375+
```
1376+
1377+
#### Classes with `__call__`, `__new__` and `__init__`
1378+
1379+
If `__call__`, `__new__` and `__init__` are all present, `__call__` takes precedence.
1380+
1381+
```py
1382+
from typing import Callable
1383+
from ty_extensions import TypeOf, static_assert, is_subtype_of
1384+
1385+
class MetaWithIntReturn(type):
1386+
def __call__(cls) -> int:
1387+
return super().__call__()
1388+
1389+
class F(metaclass=MetaWithIntReturn):
1390+
def __new__(cls) -> str:
1391+
return super().__new__(cls)
1392+
1393+
def __init__(self, x: int) -> None: ...
1394+
1395+
static_assert(is_subtype_of(TypeOf[F], Callable[[], int]))
1396+
static_assert(not is_subtype_of(TypeOf[F], Callable[[], str]))
1397+
static_assert(not is_subtype_of(TypeOf[F], Callable[[int], F]))
1398+
```
1399+
12691400
### Bound methods
12701401

12711402
```py

crates/ty_python_semantic/src/types.rs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,12 +1351,15 @@ impl<'db> Type<'db> {
13511351
}
13521352

13531353
(Type::ClassLiteral(class_literal), Type::Callable(_)) => {
1354-
if let Some(callable) = class_literal.into_callable(db) {
1355-
return callable.is_subtype_of(db, target);
1356-
}
1357-
false
1354+
ClassType::NonGeneric(class_literal)
1355+
.into_callable(db)
1356+
.is_subtype_of(db, target)
13581357
}
13591358

1359+
(Type::GenericAlias(alias), Type::Callable(_)) => ClassType::Generic(alias)
1360+
.into_callable(db)
1361+
.is_subtype_of(db, target),
1362+
13601363
// `Literal[str]` is a subtype of `type` because the `str` class object is an instance of its metaclass `type`.
13611364
// `Literal[abc.ABC]` is a subtype of `abc.ABCMeta` because the `abc.ABC` class object
13621365
// is an instance of its metaclass `abc.ABCMeta`.
@@ -1656,12 +1659,15 @@ impl<'db> Type<'db> {
16561659
}
16571660

16581661
(Type::ClassLiteral(class_literal), Type::Callable(_)) => {
1659-
if let Some(callable) = class_literal.into_callable(db) {
1660-
return callable.is_assignable_to(db, target);
1661-
}
1662-
false
1662+
ClassType::NonGeneric(class_literal)
1663+
.into_callable(db)
1664+
.is_assignable_to(db, target)
16631665
}
16641666

1667+
(Type::GenericAlias(alias), Type::Callable(_)) => ClassType::Generic(alias)
1668+
.into_callable(db)
1669+
.is_assignable_to(db, target),
1670+
16651671
(Type::FunctionLiteral(self_function_literal), Type::Callable(_)) => {
16661672
self_function_literal
16671673
.into_callable_type(db)

0 commit comments

Comments
 (0)