Skip to content

Commit d445dd0

Browse files
committed
BUG: handle positional-only parameters in @ normalize
1 parent 280c964 commit d445dd0

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

torch_np/_funcs.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,18 @@ def normalizer(func):
5959
def wrapped(*args, **kwds):
6060
sig = inspect.signature(func)
6161

62-
dct = {}
62+
lst, dct = [], {}
6363
# loop over positional parameters and actual arguments
6464
for arg, (name, parm) in zip(args, sig.parameters.items()):
6565
print(arg, name, parm.annotation)
6666
normalizer = normalizers.get(parm.annotation, None)
6767
if normalizer:
68-
dct[name] = normalizer(arg, name)
68+
# dct[name] = normalizer(arg, name)
69+
lst.append(normalizer(arg))
6970
else:
7071
# untyped arguments pass through
71-
dct[name] = arg
72+
# dct[name] = arg
73+
lst.append(arg)
7274

7375
# normalize keyword arguments
7476
for name, arg in kwds.items():
@@ -86,7 +88,7 @@ def wrapped(*args, **kwds):
8688
else:
8789
dct[name] = arg
8890

89-
ba = sig.bind(**dct)
91+
ba = sig.bind(*lst, **dct)
9092
ba.apply_defaults()
9193

9294
# Now that all parameters have been consumed, check:

torch_np/tests/numpy_tests/lib/test_shape_base_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,7 @@ def test_basic(self):
597597
assert type(res) is np.ndarray
598598

599599
aa = np.ones((3, 1, 4, 1, 1))
600-
assert aa.squeeze().base is aa
600+
assert aa.squeeze().get()._base is aa.get()
601601

602602
def test_squeeze_axis(self):
603603
A = [[[1, 1, 1], [2, 2, 2], [3, 3, 3]]]

0 commit comments

Comments
 (0)