|
7 | 7 | from mypy.indirection import TypeIndirectionVisitor
|
8 | 8 | from mypy.join import join_simple, join_types
|
9 | 9 | from mypy.meet import meet_types, narrow_declared_type
|
10 |
| -from mypy.nodes import ARG_OPT, ARG_POS, ARG_STAR, ARG_STAR2, CONTRAVARIANT, COVARIANT, INVARIANT |
| 10 | +from mypy.nodes import ( |
| 11 | + ARG_NAMED, |
| 12 | + ARG_OPT, |
| 13 | + ARG_POS, |
| 14 | + ARG_STAR, |
| 15 | + ARG_STAR2, |
| 16 | + CONTRAVARIANT, |
| 17 | + COVARIANT, |
| 18 | + INVARIANT, |
| 19 | + ArgKind, |
| 20 | + CallExpr, |
| 21 | + Expression, |
| 22 | + NameExpr, |
| 23 | +) |
| 24 | +from mypy.plugins.common import find_shallow_matching_overload_item |
11 | 25 | from mypy.state import state
|
12 | 26 | from mypy.subtypes import is_more_precise, is_proper_subtype, is_same_type, is_subtype
|
13 | 27 | from mypy.test.helpers import Suite, assert_equal, assert_type, skip
|
@@ -1287,3 +1301,135 @@ def assert_union_result(self, t: ProperType, expected: list[Type]) -> None:
|
1287 | 1301 | t2 = remove_instance_last_known_values(t)
|
1288 | 1302 | assert type(t2) is UnionType
|
1289 | 1303 | assert t2.items == expected
|
| 1304 | + |
| 1305 | + |
| 1306 | +class ShallowOverloadMatchingSuite(Suite): |
| 1307 | + def setUp(self) -> None: |
| 1308 | + self.fx = TypeFixture() |
| 1309 | + |
| 1310 | + def test_simple(self) -> None: |
| 1311 | + fx = self.fx |
| 1312 | + ov = self.make_overload([[("x", fx.anyt, ARG_NAMED)], [("y", fx.anyt, ARG_NAMED)]]) |
| 1313 | + # Match first only |
| 1314 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "x")), 0) |
| 1315 | + # Match second only |
| 1316 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "y")), 1) |
| 1317 | + # No match -- invalid keyword arg name |
| 1318 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "z")), 1) |
| 1319 | + # No match -- missing arg |
| 1320 | + self.assert_find_shallow_matching_overload_item(ov, make_call(), 1) |
| 1321 | + # No match -- extra arg |
| 1322 | + self.assert_find_shallow_matching_overload_item( |
| 1323 | + ov, make_call(("foo", "x"), ("foo", "z")), 1 |
| 1324 | + ) |
| 1325 | + |
| 1326 | + def test_match_using_types(self) -> None: |
| 1327 | + fx = self.fx |
| 1328 | + ov = self.make_overload( |
| 1329 | + [ |
| 1330 | + [("x", fx.nonet, ARG_POS)], |
| 1331 | + [("x", fx.lit_false, ARG_POS)], |
| 1332 | + [("x", fx.lit_true, ARG_POS)], |
| 1333 | + [("x", fx.anyt, ARG_POS)], |
| 1334 | + ] |
| 1335 | + ) |
| 1336 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 0) |
| 1337 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("builtins.False", None)), 1) |
| 1338 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("builtins.True", None)), 2) |
| 1339 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", None)), 3) |
| 1340 | + |
| 1341 | + def test_none_special_cases(self) -> None: |
| 1342 | + fx = self.fx |
| 1343 | + ov = self.make_overload( |
| 1344 | + [[("x", fx.callable(fx.nonet), ARG_POS)], [("x", fx.nonet, ARG_POS)]] |
| 1345 | + ) |
| 1346 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 1) |
| 1347 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0) |
| 1348 | + ov = self.make_overload([[("x", fx.str_type, ARG_POS)], [("x", fx.nonet, ARG_POS)]]) |
| 1349 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 1) |
| 1350 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0) |
| 1351 | + ov = self.make_overload( |
| 1352 | + [[("x", UnionType([fx.str_type, fx.a]), ARG_POS)], [("x", fx.nonet, ARG_POS)]] |
| 1353 | + ) |
| 1354 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 1) |
| 1355 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0) |
| 1356 | + ov = self.make_overload([[("x", fx.o, ARG_POS)], [("x", fx.nonet, ARG_POS)]]) |
| 1357 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 0) |
| 1358 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0) |
| 1359 | + ov = self.make_overload( |
| 1360 | + [[("x", UnionType([fx.str_type, fx.nonet]), ARG_POS)], [("x", fx.nonet, ARG_POS)]] |
| 1361 | + ) |
| 1362 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 0) |
| 1363 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0) |
| 1364 | + ov = self.make_overload([[("x", fx.anyt, ARG_POS)], [("x", fx.nonet, ARG_POS)]]) |
| 1365 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 0) |
| 1366 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0) |
| 1367 | + |
| 1368 | + def test_optional_arg(self) -> None: |
| 1369 | + fx = self.fx |
| 1370 | + ov = self.make_overload( |
| 1371 | + [[("x", fx.anyt, ARG_NAMED)], [("y", fx.anyt, ARG_OPT)], [("z", fx.anyt, ARG_NAMED)]] |
| 1372 | + ) |
| 1373 | + self.assert_find_shallow_matching_overload_item(ov, make_call(), 1) |
| 1374 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "x")), 0) |
| 1375 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "y")), 1) |
| 1376 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "z")), 2) |
| 1377 | + |
| 1378 | + def test_two_args(self) -> None: |
| 1379 | + fx = self.fx |
| 1380 | + ov = self.make_overload( |
| 1381 | + [ |
| 1382 | + [("x", fx.nonet, ARG_OPT), ("y", fx.anyt, ARG_OPT)], |
| 1383 | + [("x", fx.anyt, ARG_OPT), ("y", fx.anyt, ARG_OPT)], |
| 1384 | + ] |
| 1385 | + ) |
| 1386 | + self.assert_find_shallow_matching_overload_item(ov, make_call(), 0) |
| 1387 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("None", "x")), 0) |
| 1388 | + self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "x")), 1) |
| 1389 | + self.assert_find_shallow_matching_overload_item( |
| 1390 | + ov, make_call(("foo", "y"), ("None", "x")), 0 |
| 1391 | + ) |
| 1392 | + self.assert_find_shallow_matching_overload_item( |
| 1393 | + ov, make_call(("foo", "y"), ("bar", "x")), 1 |
| 1394 | + ) |
| 1395 | + |
| 1396 | + def assert_find_shallow_matching_overload_item( |
| 1397 | + self, ov: Overloaded, call: CallExpr, expected_index: int |
| 1398 | + ) -> None: |
| 1399 | + c = find_shallow_matching_overload_item(ov, call) |
| 1400 | + assert c in ov.items |
| 1401 | + assert ov.items.index(c) == expected_index |
| 1402 | + |
| 1403 | + def make_overload(self, items: list[list[tuple[str, Type, ArgKind]]]) -> Overloaded: |
| 1404 | + result = [] |
| 1405 | + for item in items: |
| 1406 | + arg_types = [] |
| 1407 | + arg_names = [] |
| 1408 | + arg_kinds = [] |
| 1409 | + for name, typ, kind in item: |
| 1410 | + arg_names.append(name) |
| 1411 | + arg_types.append(typ) |
| 1412 | + arg_kinds.append(kind) |
| 1413 | + result.append( |
| 1414 | + CallableType( |
| 1415 | + arg_types, arg_kinds, arg_names, ret_type=NoneType(), fallback=self.fx.o |
| 1416 | + ) |
| 1417 | + ) |
| 1418 | + return Overloaded(result) |
| 1419 | + |
| 1420 | + |
| 1421 | +def make_call(*items: tuple[str, str | None]) -> CallExpr: |
| 1422 | + args: list[Expression] = [] |
| 1423 | + arg_names = [] |
| 1424 | + arg_kinds = [] |
| 1425 | + for arg, name in items: |
| 1426 | + shortname = arg.split(".")[-1] |
| 1427 | + n = NameExpr(shortname) |
| 1428 | + n.fullname = arg |
| 1429 | + args.append(n) |
| 1430 | + arg_names.append(name) |
| 1431 | + if name: |
| 1432 | + arg_kinds.append(ARG_NAMED) |
| 1433 | + else: |
| 1434 | + arg_kinds.append(ARG_POS) |
| 1435 | + return CallExpr(NameExpr("f"), args, arg_kinds, arg_names) |
0 commit comments