Skip to content

Record function returned types #78

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 5, 2020
Merged
20 changes: 19 additions & 1 deletion record_api/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,9 @@ class Signature(BaseModel):
var_kw: typing.Optional[typing.Tuple[str, Type]] = None

metadata: typing.Dict[str, int] = pydantic.Field(default_factory=dict)
return_type: typing.Optional[typing.Dict[str, typing.Union[str, typing.Dict]]] = pydantic.Field(
default_factory=dict
)

@pydantic.validator("pos_only_required")
@classmethod
Expand Down Expand Up @@ -408,6 +411,14 @@ def validate_keys_unique(cls, values) -> None:
raise ValueError(repr(all_keys))
return values

@property
def return_type_annotation(self) -> typing.Optional[cst.Annotation]:
return_type_annotation = None
if self.return_type:
return_type = create_type(self.return_type)
return_type_annotation = cst.Annotation(return_type.annotation)
return return_type_annotation

def function_def(
self,
name: str,
Expand All @@ -427,6 +438,7 @@ def function_def(
[cst.SimpleStatementLine([s]) for s in self.body(indent)]
),
decorators,
self.return_type_annotation
)

def body(self, indent: int) -> typing.Iterable[cst.BaseSmallStatement]:
Expand Down Expand Up @@ -508,13 +520,17 @@ def initial_args(self) -> typing.Iterator[Type]:

@classmethod
def from_params(
cls, args: typing.List[object] = [], kwargs: typing.Dict[str, object] = {}
cls,
args: typing.List[object] = [],
kwargs: typing.Dict[str, object] = {},
return_type: typing.Optional[typing.Dict[str, typing.Union[str, typing.Dict]]] = {},
) -> Signature:
# If we don't know what the args/kwargs are, assume the args are positional only
# and the kwargs and keyword only
return cls(
pos_only_required={f"_{i}": create_type(v) for i, v in enumerate(args)},
kw_only_required={k: create_type(v) for k, v in kwargs.items()},
return_type=return_type
)

@classmethod
Expand All @@ -525,6 +541,7 @@ def from_bound_params(
var_pos: typing.Optional[typing.Tuple[str, typing.List[object]]] = None,
kw_only: typing.Dict[str, object] = {},
var_kw: typing.Optional[typing.Tuple[str, typing.Dict[str, object]]] = None,
return_type: typing.Optional[typing.Dict[str, typing.Union[str, typing.Dict]]] = {},
) -> Signature:
return cls(
pos_only_required={k: create_type(v) for k, v in pos_only},
Expand All @@ -538,6 +555,7 @@ def from_bound_params(
if var_kw
else None
),
return_type=return_type
)

def content_equal(self, other: Signature) -> bool:
Expand Down
75 changes: 65 additions & 10 deletions record_api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@
context_manager: Optional[ContextManager] = None
write_line: Optional[Callable[[dict], None]] = None

FUNCTION_CALL_OP_NAMES = {
"CALL_METHOD",
"CALL_FUNCTION",
"CALL_FUNCTION_KW",
"CALL_FUNCTION_EX",
"LOAD_ATTR",
"BINARY_SUBSCR",
}


def get_tracer() -> Tracer:
global TRACER
Expand Down Expand Up @@ -273,6 +282,7 @@ def log_call(
fn: Callable,
args: Iterable = (),
kwargs: Mapping[str, Any] = {},
return_type: Any = None,
) -> None:
bound = Bound.create(fn, args, kwargs)
line: Dict = {"location": location, "function": preprocess(fn)}
Expand All @@ -284,6 +294,8 @@ def log_call(
line["params"]["kwargs"] = {k: preprocess(v) for k, v in kwargs.items()}
else:
line["bound_params"] = bound.as_dict()
if return_type:
line['return_type'] = return_type
assert write_line
write_line(line)

Expand All @@ -295,11 +307,16 @@ class Stack:
NULL: ClassVar[object] = object()
current_i: int = dataclasses.field(init=False, default=0)
opcode: int = dataclasses.field(init=False)
previous_stack: Optional[Stack] = None
log_call_args: Tuple = ()

def __post_init__(self):
self.op_stack = get_stack.OpStack(self.frame)
self.opcode = self.frame.f_code.co_code[self.frame.f_lasti]

if self.previous_stack and self.previous_stack.previous_stack:
del self.previous_stack.previous_stack

@property
def oparg(self):
# sort of replicates logic in dis._unpack_opargs but doesn't account for extended
Expand Down Expand Up @@ -360,14 +377,24 @@ def pop_n(self, n: int) -> List:
return l

def process(
self, keyed_args: Tuple, fn: Callable, args: Iterable, kwargs: Mapping = {}
self,
keyed_args: Tuple,
fn: Callable,
args: Iterable,
kwargs: Mapping = {},
delay: bool = False
) -> None:
# Note: This take args as an iterable, instead of as a varargs, so that if we don't trace we don't have to expand the iterable

# Note: This take args as an iterable, instead of as a varargs, so that if
# we don't trace we don't have to expand the iterable
if self.tracer.should_trace(*keyed_args):
filename = self.frame.f_code.co_filename
line = self.frame.f_lineno
# Don't pass kwargs if not used, so we can more easily test mock calls
log_call(f"{filename}:{line}", fn, tuple(args), *((kwargs,) if kwargs else ()))
if not delay:
log_call(f"{filename}:{line}", fn, tuple(args), *((kwargs,) if kwargs else ()))
else:
self.log_call_args = (filename, line, fn, tuple(args), kwargs)

def __call__(self) -> None:
"""
Expand All @@ -383,14 +410,34 @@ def __call__(self) -> None:
(self.TOS, self.TOS1), BINARY_OPS[opname], (self.TOS1, self.TOS)
)

if self.previous_stack and self.previous_stack.opname in FUNCTION_CALL_OP_NAMES:
self.log_called_method()

method_name = f"op_{opname}"
if hasattr(self, method_name):
getattr(self, method_name)()
return None

def log_called_method(self):
if self.previous_stack.log_call_args:
tos = self.TOS
if type(tos) is type and issubclass(tos, Exception):
# Don't record exception
return
return_type = type(tos) if type(tos) != type else tos
filename, line, fn, args, *kwargs = self.previous_stack.log_call_args
kwargs = kwargs[0] if kwargs else {}
log_call(
f"{filename}:{line}",
fn,
tuple(args),
*((kwargs,) if kwargs else ()),
return_type=return_type,
)

# special case subscr b/c we only check first arg, not both
def op_BINARY_SUBSCR(self):
self.process((self.TOS1,), op.getitem, (self.TOS1, self.TOS))
self.process((self.TOS1,), op.getitem, (self.TOS1, self.TOS), delay=True)

def op_STORE_SUBSCR(self):
self.process((self.TOS1,), op.setitem, (self.TOS1, self.TOS, self.TOS2))
Expand All @@ -399,7 +446,7 @@ def op_DELETE_SUBSCR(self):
self.process((self.TOS1,), op.delitem, (self.TOS1, self.TOS))

def op_LOAD_ATTR(self):
self.process((self.TOS,), getattr, (self.TOS, self.opvalname))
self.process((self.TOS,), getattr, (self.TOS, self.opvalname), delay=True)

def op_STORE_ATTR(self):
self.process((self.TOS,), setattr, (self.TOS, self.opvalname, self.TOS1))
Expand Down Expand Up @@ -458,7 +505,7 @@ def op_COMPARE_OP(self):
def op_CALL_FUNCTION(self):
args = self.pop_n(self.oparg)
fn = self.pop()
self.process((fn,), fn, args)
self.process((fn,), fn, args, delay=True)

def op_CALL_FUNCTION_KW(self):
kwargs_keys = self.pop()
Expand All @@ -468,7 +515,7 @@ def op_CALL_FUNCTION_KW(self):
args = self.pop_n(self.oparg - n_kwargs)
fn = self.pop()

self.process((fn,), fn, args, kwargs)
self.process((fn,), fn, args, kwargs, delay=True)

def op_CALL_FUNCTION_EX(self):
has_kwarg = self.oparg & int("01", 2)
Expand All @@ -482,20 +529,21 @@ def op_CALL_FUNCTION_EX(self):
fn = self.pop()
if inspect.isgenerator(args):
return
self.process((fn,), fn, args, kwargs)
self.process((fn,), fn, args, kwargs, delay=True)

def op_CALL_METHOD(self):
args = self.pop_n(self.oparg)
function_or_self = self.pop()
null_or_method = self.pop()
if null_or_method is self.NULL:
function = function_or_self
self.process((function,), function, args)
self.process((function,), function, args, delay=True)
else:
self_ = function_or_self
method = null_or_method
self.process(
(self_,), method, itertools.chain((self_,), args),
delay=True
)


Expand Down Expand Up @@ -548,6 +596,7 @@ class Tracer:
calls_to_modules: List[str]
# the modules we should trace calls from
calls_from_modules: List[str]
previous_stack: Optional[Stack] = None

def __enter__(self):
sys.settrace(self)
Expand Down Expand Up @@ -577,7 +626,13 @@ def __call__(self, frame, event, arg) -> Optional[Tracer]:
return None

if self.should_trace_frame(frame):
Stack(self, frame)()
stack = Stack(
self,
frame,
previous_stack=self.previous_stack,
)
stack()
self.previous_stack = stack if stack.log_call_args else None
return None

def should_trace_frame(self, frame) -> bool:
Expand Down
10 changes: 7 additions & 3 deletions record_api/infer_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,16 @@ def __main__():


def parse_line(
n: int, function: object, params=None, bound_params=None,
n: int,
function: object,
params=None,
bound_params=None,
return_type: typing.Optional[typing.Dict[str, typing.Union[str, typing.Dict]]] = None
) -> typing.Optional[API]:
if bound_params is not None:
signature = Signature.from_bound_params(**bound_params)
signature = Signature.from_bound_params(**bound_params, return_type=return_type)
else:
signature = Signature.from_params(**params)
signature = Signature.from_params(**params, return_type=return_type)
signature.metadata[f"usage.{LABEL}"] = n
return process_function(create_type(function), s=signature)

Expand Down
Loading