Skip to content

Commit a108c67

Browse files
authored
Add get_expression_type to CheckerPluginInterface (python#15369)
Fixes python#14845. p.s. In the issue above, I was concerned that adding this method would create an avenue for infinite recursions (if called carelessly), but in fact I haven't managed to induce it, e.g. FunctionSigContext has `args` but not the call expression itself.
1 parent e7b917e commit a108c67

File tree

5 files changed

+25
-28
lines changed

5 files changed

+25
-28
lines changed

mypy/checker.py

+3
Original file line numberDiff line numberDiff line change
@@ -6793,6 +6793,9 @@ def has_valid_attribute(self, typ: Type, name: str) -> bool:
67936793
)
67946794
return not watcher.has_new_errors()
67956795

6796+
def get_expression_type(self, node: Expression, type_context: Type | None = None) -> Type:
6797+
return self.expr_checker.accept(node, type_context=type_context)
6798+
67966799

67976800
class CollectArgTypeVarTypes(TypeTraverserVisitor):
67986801
"""Collects the non-nested argument types in a set."""

mypy/plugin.py

+5
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,11 @@ def named_generic_type(self, name: str, args: list[Type]) -> Instance:
250250
"""Construct an instance of a builtin type with given type arguments."""
251251
raise NotImplementedError
252252

253+
@abstractmethod
254+
def get_expression_type(self, node: Expression, type_context: Type | None = None) -> Type:
255+
"""Checks the type of the given expression."""
256+
raise NotImplementedError
257+
253258

254259
@trait
255260
class SemanticAnalyzerPluginInterface:

mypy/plugins/attrs.py

+4-15
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import mypy.plugin # To avoid circular imports.
1111
from mypy.applytype import apply_generic_arguments
12-
from mypy.checker import TypeChecker
1312
from mypy.errorcodes import LITERAL_REQ
1413
from mypy.expandtype import expand_type, expand_type_by_instance
1514
from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type
@@ -1048,13 +1047,7 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl
10481047
return ctx.default_signature # leave it to the type checker to complain
10491048

10501049
inst_arg = ctx.args[0][0]
1051-
1052-
# <hack>
1053-
assert isinstance(ctx.api, TypeChecker)
1054-
inst_type = ctx.api.expr_checker.accept(inst_arg)
1055-
# </hack>
1056-
1057-
inst_type = get_proper_type(inst_type)
1050+
inst_type = get_proper_type(ctx.api.get_expression_type(inst_arg))
10581051
inst_type_str = format_type_bare(inst_type, ctx.api.options)
10591052

10601053
attr_types = _get_expanded_attr_types(ctx, inst_type, inst_type, inst_type)
@@ -1074,14 +1067,10 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl
10741067

10751068
def fields_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType:
10761069
"""Provide the signature for `attrs.fields`."""
1077-
if not ctx.args or len(ctx.args) != 1 or not ctx.args[0] or not ctx.args[0][0]:
1070+
if len(ctx.args) != 1 or len(ctx.args[0]) != 1:
10781071
return ctx.default_signature
10791072

1080-
# <hack>
1081-
assert isinstance(ctx.api, TypeChecker)
1082-
inst_type = ctx.api.expr_checker.accept(ctx.args[0][0])
1083-
# </hack>
1084-
proper_type = get_proper_type(inst_type)
1073+
proper_type = get_proper_type(ctx.api.get_expression_type(ctx.args[0][0]))
10851074

10861075
# fields(Any) -> Any, fields(type[Any]) -> Any
10871076
if (
@@ -1098,7 +1087,7 @@ def fields_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl
10981087
inner = get_proper_type(proper_type.upper_bound)
10991088
if isinstance(inner, Instance):
11001089
# We need to work arg_types to compensate for the attrs stubs.
1101-
arg_types = [inst_type]
1090+
arg_types = [proper_type]
11021091
cls = inner.type
11031092
elif isinstance(proper_type, CallableType):
11041093
cls = proper_type.type_object()

test-data/unit/check-custom-plugin.test

+4-1
Original file line numberDiff line numberDiff line change
@@ -887,7 +887,10 @@ plugins=<ROOT>/test-data/unit/plugins/descriptor.py
887887
# flags: --config-file tmp/mypy.ini
888888

889889
def dynamic_signature(arg1: str) -> str: ...
890-
reveal_type(dynamic_signature(1)) # N: Revealed type is "builtins.int"
890+
a: int = 1
891+
reveal_type(dynamic_signature(a)) # N: Revealed type is "builtins.int"
892+
b: bytes = b'foo'
893+
reveal_type(dynamic_signature(b)) # N: Revealed type is "builtins.bytes"
891894
[file mypy.ini]
892895
\[mypy]
893896
plugins=<ROOT>/test-data/unit/plugins/function_sig_hook.py
+9-12
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,23 @@
1-
from mypy.plugin import CallableType, CheckerPluginInterface, FunctionSigContext, Plugin
2-
from mypy.types import Instance, Type
1+
from mypy.plugin import CallableType, FunctionSigContext, Plugin
2+
33

44
class FunctionSigPlugin(Plugin):
55
def get_function_signature_hook(self, fullname):
66
if fullname == '__main__.dynamic_signature':
77
return my_hook
88
return None
99

10-
def _str_to_int(api: CheckerPluginInterface, typ: Type) -> Type:
11-
if isinstance(typ, Instance):
12-
if typ.type.fullname == 'builtins.str':
13-
return api.named_generic_type('builtins.int', [])
14-
elif typ.args:
15-
return typ.copy_modified(args=[_str_to_int(api, t) for t in typ.args])
16-
17-
return typ
1810

1911
def my_hook(ctx: FunctionSigContext) -> CallableType:
12+
arg1_args = ctx.args[0]
13+
if len(arg1_args) != 1:
14+
return ctx.default_signature
15+
arg1_type = ctx.api.get_expression_type(arg1_args[0])
2016
return ctx.default_signature.copy_modified(
21-
arg_types=[_str_to_int(ctx.api, t) for t in ctx.default_signature.arg_types],
22-
ret_type=_str_to_int(ctx.api, ctx.default_signature.ret_type),
17+
arg_types=[arg1_type],
18+
ret_type=arg1_type,
2319
)
2420

21+
2522
def plugin(version):
2623
return FunctionSigPlugin

0 commit comments

Comments
 (0)