Skip to content

Commit 06aa182

Browse files
JukkaLwesleywright
andcommitted
[dataclass_transform] support implicit default for "init" parameter in field specifiers (#15010)
(Basic functionality was implemented by @wesleywright in #14870. I added overload resolution.) This note from PEP 681 was missed in the initial implementation of field specifiers: > If unspecified, init defaults to True. Field specifier functions can use overloads that implicitly specify the value of init using a literal bool value type (Literal[False] or Literal[True]). This commit adds support for reading a default from the declared type of the `init` parameter if possible. Otherwise, it continues to use the typical default of `True`. The implementation was non-trivial, since regular overload resolution can't be used in the dataclass plugin, which is applied before type checking. As a workaround, I added a simple overload resolution helper that should be enough to support typical use cases. It doesn't do full overload resolution using types, but it knows about `None`, `Literal[True]` and `Literal[False]` and a few other things. --------- Co-authored-by: Wesley Collin Wright <[email protected]>
1 parent 7beaec2 commit 06aa182

File tree

7 files changed

+299
-12
lines changed

7 files changed

+299
-12
lines changed

mypy/plugins/common.py

+75-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from mypy.argmap import map_actuals_to_formals
34
from mypy.fixup import TypeFixer
45
from mypy.nodes import (
56
ARG_POS,
@@ -13,6 +14,7 @@
1314
Expression,
1415
FuncDef,
1516
JsonDict,
17+
NameExpr,
1618
Node,
1719
PassStmt,
1820
RefExpr,
@@ -22,20 +24,27 @@
2224
from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface
2325
from mypy.semanal_shared import (
2426
ALLOW_INCOMPATIBLE_OVERRIDE,
27+
parse_bool,
2528
require_bool_literal_argument,
2629
set_callable_name,
2730
)
2831
from mypy.typeops import ( # noqa: F401 # Part of public API
2932
try_getting_str_literals as try_getting_str_literals,
3033
)
3134
from mypy.types import (
35+
AnyType,
3236
CallableType,
37+
Instance,
38+
LiteralType,
39+
NoneType,
3340
Overloaded,
3441
Type,
42+
TypeOfAny,
3543
TypeType,
3644
TypeVarType,
3745
deserialize_type,
3846
get_proper_type,
47+
is_optional,
3948
)
4049
from mypy.typevars import fill_typevars
4150
from mypy.util import get_unique_redefinition_name
@@ -87,6 +96,71 @@ def _get_argument(call: CallExpr, name: str) -> Expression | None:
8796
return None
8897

8998

99+
def find_shallow_matching_overload_item(overload: Overloaded, call: CallExpr) -> CallableType:
100+
"""Perform limited lookup of a matching overload item.
101+
102+
Full overload resolution is only supported during type checking, but plugins
103+
sometimes need to resolve overloads. This can be used in some such use cases.
104+
105+
Resolve overloads based on these things only:
106+
107+
* Match using argument kinds and names
108+
* If formal argument has type None, only accept the "None" expression in the callee
109+
* If formal argument has type Literal[True] or Literal[False], only accept the
110+
relevant bool literal
111+
112+
Return the first matching overload item, or the last one if nothing matches.
113+
"""
114+
for item in overload.items[:-1]:
115+
ok = True
116+
mapped = map_actuals_to_formals(
117+
call.arg_kinds,
118+
call.arg_names,
119+
item.arg_kinds,
120+
item.arg_names,
121+
lambda i: AnyType(TypeOfAny.special_form),
122+
)
123+
124+
# Look for extra actuals
125+
matched_actuals = set()
126+
for actuals in mapped:
127+
matched_actuals.update(actuals)
128+
if any(i not in matched_actuals for i in range(len(call.args))):
129+
ok = False
130+
131+
for arg_type, kind, actuals in zip(item.arg_types, item.arg_kinds, mapped):
132+
if kind.is_required() and not actuals:
133+
# Missing required argument
134+
ok = False
135+
break
136+
elif actuals:
137+
args = [call.args[i] for i in actuals]
138+
arg_type = get_proper_type(arg_type)
139+
arg_none = any(isinstance(arg, NameExpr) and arg.name == "None" for arg in args)
140+
if isinstance(arg_type, NoneType):
141+
if not arg_none:
142+
ok = False
143+
break
144+
elif (
145+
arg_none
146+
and not is_optional(arg_type)
147+
and not (
148+
isinstance(arg_type, Instance)
149+
and arg_type.type.fullname == "builtins.object"
150+
)
151+
and not isinstance(arg_type, AnyType)
152+
):
153+
ok = False
154+
break
155+
elif isinstance(arg_type, LiteralType) and type(arg_type.value) is bool:
156+
if not any(parse_bool(arg) == arg_type.value for arg in args):
157+
ok = False
158+
break
159+
if ok:
160+
return item
161+
return overload.items[-1]
162+
163+
90164
def _get_callee_type(call: CallExpr) -> CallableType | None:
91165
"""Return the type of the callee, regardless of its syntatic form."""
92166

@@ -103,8 +177,7 @@ def _get_callee_type(call: CallExpr) -> CallableType | None:
103177
if isinstance(callee_node, (Var, SYMBOL_FUNCBASE_TYPES)) and callee_node.type:
104178
callee_node_type = get_proper_type(callee_node.type)
105179
if isinstance(callee_node_type, Overloaded):
106-
# We take the last overload.
107-
return callee_node_type.items[-1]
180+
return find_shallow_matching_overload_item(callee_node_type, call)
108181
elif isinstance(callee_node_type, CallableType):
109182
return callee_node_type
110183

mypy/plugins/dataclasses.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
)
4141
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface
4242
from mypy.plugins.common import (
43+
_get_callee_type,
4344
_get_decorator_bool_argument,
4445
add_attribute_to_class,
4546
add_method_to_class,
@@ -48,7 +49,7 @@
4849
from mypy.semanal_shared import find_dataclass_transform_spec, require_bool_literal_argument
4950
from mypy.server.trigger import make_wildcard_trigger
5051
from mypy.state import state
51-
from mypy.typeops import map_type_from_supertype
52+
from mypy.typeops import map_type_from_supertype, try_getting_literals_from_type
5253
from mypy.types import (
5354
AnyType,
5455
CallableType,
@@ -517,7 +518,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
517518

518519
is_in_init_param = field_args.get("init")
519520
if is_in_init_param is None:
520-
is_in_init = True
521+
is_in_init = self._get_default_init_value_for_field_specifier(stmt.rvalue)
521522
else:
522523
is_in_init = bool(self._api.parse_bool(is_in_init_param))
523524

@@ -760,6 +761,30 @@ def _get_bool_arg(self, name: str, default: bool) -> bool:
760761
return require_bool_literal_argument(self._api, expression, name, default)
761762
return default
762763

764+
def _get_default_init_value_for_field_specifier(self, call: Expression) -> bool:
765+
"""
766+
Find a default value for the `init` parameter of the specifier being called. If the
767+
specifier's type signature includes an `init` parameter with a type of `Literal[True]` or
768+
`Literal[False]`, return the appropriate boolean value from the literal. Otherwise,
769+
fall back to the standard default of `True`.
770+
"""
771+
if not isinstance(call, CallExpr):
772+
return True
773+
774+
specifier_type = _get_callee_type(call)
775+
if specifier_type is None:
776+
return True
777+
778+
parameter = specifier_type.argument_by_name("init")
779+
if parameter is None:
780+
return True
781+
782+
literals = try_getting_literals_from_type(parameter.typ, bool, "builtins.bool")
783+
if literals is None or len(literals) != 1:
784+
return True
785+
786+
return literals[0]
787+
763788
def _infer_dataclass_attr_init_type(
764789
self, sym: SymbolTableNode, name: str, context: Context
765790
) -> Type | None:

mypy/semanal.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@
216216
calculate_tuple_fallback,
217217
find_dataclass_transform_spec,
218218
has_placeholder,
219+
parse_bool,
219220
require_bool_literal_argument,
220221
set_callable_name as set_callable_name,
221222
)
@@ -6465,12 +6466,8 @@ def is_initial_mangled_global(self, name: str) -> bool:
64656466
return name == unmangle(name) + "'"
64666467

64676468
def parse_bool(self, expr: Expression) -> bool | None:
6468-
if isinstance(expr, NameExpr):
6469-
if expr.fullname == "builtins.True":
6470-
return True
6471-
if expr.fullname == "builtins.False":
6472-
return False
6473-
return None
6469+
# This wrapper is preserved for plugins.
6470+
return parse_bool(expr)
64746471

64756472
def parse_str_literal(self, expr: Expression) -> str | None:
64766473
"""Attempt to find the string literal value of the given expression. Returns `None` if no

mypy/semanal_shared.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Decorator,
1919
Expression,
2020
FuncDef,
21+
NameExpr,
2122
Node,
2223
OverloadedFuncDef,
2324
RefExpr,
@@ -451,11 +452,20 @@ def require_bool_literal_argument(
451452
default: bool | None = None,
452453
) -> bool | None:
453454
"""Attempt to interpret an expression as a boolean literal, and fail analysis if we can't."""
454-
value = api.parse_bool(expression)
455+
value = parse_bool(expression)
455456
if value is None:
456457
api.fail(
457458
f'"{name}" argument must be a True or False literal', expression, code=LITERAL_REQ
458459
)
459460
return default
460461

461462
return value
463+
464+
465+
def parse_bool(expr: Expression) -> bool | None:
466+
if isinstance(expr, NameExpr):
467+
if expr.fullname == "builtins.True":
468+
return True
469+
if expr.fullname == "builtins.False":
470+
return False
471+
return None

mypy/test/testtypes.py

+147-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,21 @@
77
from mypy.indirection import TypeIndirectionVisitor
88
from mypy.join import join_simple, join_types
99
from mypy.meet import meet_types, narrow_declared_type
10-
from mypy.nodes import ARG_OPT, ARG_POS, ARG_STAR, ARG_STAR2, CONTRAVARIANT, COVARIANT, INVARIANT
10+
from mypy.nodes import (
11+
ARG_NAMED,
12+
ARG_OPT,
13+
ARG_POS,
14+
ARG_STAR,
15+
ARG_STAR2,
16+
CONTRAVARIANT,
17+
COVARIANT,
18+
INVARIANT,
19+
ArgKind,
20+
CallExpr,
21+
Expression,
22+
NameExpr,
23+
)
24+
from mypy.plugins.common import find_shallow_matching_overload_item
1125
from mypy.state import state
1226
from mypy.subtypes import is_more_precise, is_proper_subtype, is_same_type, is_subtype
1327
from mypy.test.helpers import Suite, assert_equal, assert_type, skip
@@ -1287,3 +1301,135 @@ def assert_union_result(self, t: ProperType, expected: list[Type]) -> None:
12871301
t2 = remove_instance_last_known_values(t)
12881302
assert type(t2) is UnionType
12891303
assert t2.items == expected
1304+
1305+
1306+
class ShallowOverloadMatchingSuite(Suite):
1307+
def setUp(self) -> None:
1308+
self.fx = TypeFixture()
1309+
1310+
def test_simple(self) -> None:
1311+
fx = self.fx
1312+
ov = self.make_overload([[("x", fx.anyt, ARG_NAMED)], [("y", fx.anyt, ARG_NAMED)]])
1313+
# Match first only
1314+
self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "x")), 0)
1315+
# Match second only
1316+
self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "y")), 1)
1317+
# No match -- invalid keyword arg name
1318+
self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "z")), 1)
1319+
# No match -- missing arg
1320+
self.assert_find_shallow_matching_overload_item(ov, make_call(), 1)
1321+
# No match -- extra arg
1322+
self.assert_find_shallow_matching_overload_item(
1323+
ov, make_call(("foo", "x"), ("foo", "z")), 1
1324+
)
1325+
1326+
def test_match_using_types(self) -> None:
1327+
fx = self.fx
1328+
ov = self.make_overload(
1329+
[
1330+
[("x", fx.nonet, ARG_POS)],
1331+
[("x", fx.lit_false, ARG_POS)],
1332+
[("x", fx.lit_true, ARG_POS)],
1333+
[("x", fx.anyt, ARG_POS)],
1334+
]
1335+
)
1336+
self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 0)
1337+
self.assert_find_shallow_matching_overload_item(ov, make_call(("builtins.False", None)), 1)
1338+
self.assert_find_shallow_matching_overload_item(ov, make_call(("builtins.True", None)), 2)
1339+
self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", None)), 3)
1340+
1341+
def test_none_special_cases(self) -> None:
1342+
fx = self.fx
1343+
ov = self.make_overload(
1344+
[[("x", fx.callable(fx.nonet), ARG_POS)], [("x", fx.nonet, ARG_POS)]]
1345+
)
1346+
self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 1)
1347+
self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0)
1348+
ov = self.make_overload([[("x", fx.str_type, ARG_POS)], [("x", fx.nonet, ARG_POS)]])
1349+
self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 1)
1350+
self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0)
1351+
ov = self.make_overload(
1352+
[[("x", UnionType([fx.str_type, fx.a]), ARG_POS)], [("x", fx.nonet, ARG_POS)]]
1353+
)
1354+
self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 1)
1355+
self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0)
1356+
ov = self.make_overload([[("x", fx.o, ARG_POS)], [("x", fx.nonet, ARG_POS)]])
1357+
self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 0)
1358+
self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0)
1359+
ov = self.make_overload(
1360+
[[("x", UnionType([fx.str_type, fx.nonet]), ARG_POS)], [("x", fx.nonet, ARG_POS)]]
1361+
)
1362+
self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 0)
1363+
self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0)
1364+
ov = self.make_overload([[("x", fx.anyt, ARG_POS)], [("x", fx.nonet, ARG_POS)]])
1365+
self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 0)
1366+
self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0)
1367+
1368+
def test_optional_arg(self) -> None:
1369+
fx = self.fx
1370+
ov = self.make_overload(
1371+
[[("x", fx.anyt, ARG_NAMED)], [("y", fx.anyt, ARG_OPT)], [("z", fx.anyt, ARG_NAMED)]]
1372+
)
1373+
self.assert_find_shallow_matching_overload_item(ov, make_call(), 1)
1374+
self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "x")), 0)
1375+
self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "y")), 1)
1376+
self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "z")), 2)
1377+
1378+
def test_two_args(self) -> None:
1379+
fx = self.fx
1380+
ov = self.make_overload(
1381+
[
1382+
[("x", fx.nonet, ARG_OPT), ("y", fx.anyt, ARG_OPT)],
1383+
[("x", fx.anyt, ARG_OPT), ("y", fx.anyt, ARG_OPT)],
1384+
]
1385+
)
1386+
self.assert_find_shallow_matching_overload_item(ov, make_call(), 0)
1387+
self.assert_find_shallow_matching_overload_item(ov, make_call(("None", "x")), 0)
1388+
self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "x")), 1)
1389+
self.assert_find_shallow_matching_overload_item(
1390+
ov, make_call(("foo", "y"), ("None", "x")), 0
1391+
)
1392+
self.assert_find_shallow_matching_overload_item(
1393+
ov, make_call(("foo", "y"), ("bar", "x")), 1
1394+
)
1395+
1396+
def assert_find_shallow_matching_overload_item(
1397+
self, ov: Overloaded, call: CallExpr, expected_index: int
1398+
) -> None:
1399+
c = find_shallow_matching_overload_item(ov, call)
1400+
assert c in ov.items
1401+
assert ov.items.index(c) == expected_index
1402+
1403+
def make_overload(self, items: list[list[tuple[str, Type, ArgKind]]]) -> Overloaded:
1404+
result = []
1405+
for item in items:
1406+
arg_types = []
1407+
arg_names = []
1408+
arg_kinds = []
1409+
for name, typ, kind in item:
1410+
arg_names.append(name)
1411+
arg_types.append(typ)
1412+
arg_kinds.append(kind)
1413+
result.append(
1414+
CallableType(
1415+
arg_types, arg_kinds, arg_names, ret_type=NoneType(), fallback=self.fx.o
1416+
)
1417+
)
1418+
return Overloaded(result)
1419+
1420+
1421+
def make_call(*items: tuple[str, str | None]) -> CallExpr:
1422+
args: list[Expression] = []
1423+
arg_names = []
1424+
arg_kinds = []
1425+
for arg, name in items:
1426+
shortname = arg.split(".")[-1]
1427+
n = NameExpr(shortname)
1428+
n.fullname = arg
1429+
args.append(n)
1430+
arg_names.append(name)
1431+
if name:
1432+
arg_kinds.append(ARG_NAMED)
1433+
else:
1434+
arg_kinds.append(ARG_POS)
1435+
return CallExpr(NameExpr("f"), args, arg_kinds, arg_names)

0 commit comments

Comments
 (0)