Skip to content

Commit cf8c27c

Browse files
committed
Expand record return values for all function calls
1 parent e30db28 commit cf8c27c

File tree

2 files changed

+27
-18
lines changed

2 files changed

+27
-18
lines changed

record_api/core.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@
2929
context_manager: Optional[ContextManager] = None
3030
write_line: Optional[Callable[[dict], None]] = None
3131

32+
FUNCTION_CALL_OP_NAMES = {
33+
"CALL_METHOD",
34+
"CALL_FUNCTION",
35+
"CALL_FUNCTION_KW",
36+
"CALL_FUNCTION_EX",
37+
}
38+
3239

3340
def get_tracer() -> Tracer:
3441
global TRACER
@@ -400,7 +407,8 @@ def __call__(self) -> None:
400407
return self.process(
401408
(self.TOS, self.TOS1), BINARY_OPS[opname], (self.TOS1, self.TOS)
402409
)
403-
if self.previous_stack and self.previous_stack.opname == "CALL_METHOD":
410+
411+
if self.previous_stack and self.previous_stack.opname in FUNCTION_CALL_OP_NAMES:
404412
self.log_called_method()
405413

406414
method_name = f"op_{opname}"
@@ -409,15 +417,16 @@ def __call__(self) -> None:
409417
return None
410418

411419
def log_called_method(self):
412-
filename, line, fn, args, *kwargs = self.previous_stack.log_call_args
413-
kwargs = kwargs[0] if kwargs else {}
414-
log_call(
415-
f"{filename}:{line}",
416-
fn,
417-
tuple(args),
418-
*((kwargs,) if kwargs else ()),
419-
return_type=type(self.TOS),
420-
)
420+
if self.previous_stack.log_call_args:
421+
filename, line, fn, args, *kwargs = self.previous_stack.log_call_args
422+
kwargs = kwargs[0] if kwargs else {}
423+
log_call(
424+
f"{filename}:{line}",
425+
fn,
426+
tuple(args),
427+
*((kwargs,) if kwargs else ()),
428+
return_type=type(self.TOS),
429+
)
421430

422431
# special case subscr b/c we only check first arg, not both
423432
def op_BINARY_SUBSCR(self):
@@ -489,7 +498,7 @@ def op_COMPARE_OP(self):
489498
def op_CALL_FUNCTION(self):
490499
args = self.pop_n(self.oparg)
491500
fn = self.pop()
492-
self.process((fn,), fn, args)
501+
self.process((fn,), fn, args, delay=True)
493502

494503
def op_CALL_FUNCTION_KW(self):
495504
kwargs_keys = self.pop()
@@ -499,7 +508,7 @@ def op_CALL_FUNCTION_KW(self):
499508
args = self.pop_n(self.oparg - n_kwargs)
500509
fn = self.pop()
501510

502-
self.process((fn,), fn, args, kwargs)
511+
self.process((fn,), fn, args, kwargs, delay=True)
503512

504513
def op_CALL_FUNCTION_EX(self):
505514
has_kwarg = self.oparg & int("01", 2)
@@ -513,7 +522,7 @@ def op_CALL_FUNCTION_EX(self):
513522
fn = self.pop()
514523
if inspect.isgenerator(args):
515524
return
516-
self.process((fn,), fn, args, kwargs)
525+
self.process((fn,), fn, args, kwargs, delay=True)
517526

518527
def op_CALL_METHOD(self):
519528
args = self.pop_n(self.oparg)

record_api/test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,20 +139,20 @@ def test_sort(self):
139139
self.trace("self.a.sort(axis=0)")
140140
self.assertCalls(
141141
call(ANY, getattr, (self.a, "sort")),
142-
call(ANY, self.a.sort, (), {"axis": 0}),
142+
call(ANY, self.a.sort, (), {"axis": 0}, return_type=type(None)),
143143
)
144144

145145
def test_eye(self):
146146
self.trace("np.eye(10, order='F')")
147147
self.assertCalls(
148-
call(ANY, getattr, (np, "eye")), call(ANY, np.eye, (10,), {"order": "F"}),
148+
call(ANY, getattr, (np, "eye")), call(ANY, np.eye, (10,), {"order": "F"}, return_type=np.ndarray),
149149
)
150150

151151
def test_linspace(self):
152152
self.trace("np.linspace(3, 4, endpoint=False)")
153153
self.assertCalls(
154154
call(ANY, getattr, (np, "linspace",)),
155-
call(ANY, np.linspace, (3, 4,), {"endpoint": False}),
155+
call(ANY, np.linspace, (3, 4,), {"endpoint": False}, return_type=np.ndarray),
156156
)
157157

158158
def test_reshape(self):
@@ -167,7 +167,7 @@ def test_concatenate(self):
167167
self.trace("np.concatenate((self.a, self.a), axis=0)")
168168
self.assertCalls(
169169
call(ANY, getattr, (np, "concatenate",)),
170-
call(ANY, np.concatenate, ((self.a, self.a),), {"axis": 0}),
170+
call(ANY, np.concatenate, ((self.a, self.a),), {"axis": 0}, return_type=np.ndarray),
171171
)
172172

173173
def test_ravel_list(self):
@@ -196,7 +196,7 @@ def test_numpy_array_constructor(self):
196196
self.trace("np.ndarray(dtype='int64', shape=tuple())")
197197
self.assertCalls(
198198
call(ANY, getattr, (np, "ndarray")),
199-
call(ANY, np.ndarray, (), {"dtype": "int64", "shape": tuple()}),
199+
call(ANY, np.ndarray, (), {"dtype": "int64", "shape": tuple()}, return_type=np.ndarray),
200200
)
201201

202202
def test_not_contains(self):

0 commit comments

Comments
 (0)