Skip to content

Commit 9619b37

Browse files
BowenBaoignaciobartol
authored andcommitted
Fix 'get_attr' call in dynamo 'run_node' (pytorch#127696)
Fixes pytorch#124858 Pull Request resolved: pytorch#127696 Approved by: https://github.com/jansel ghstack dependencies: pytorch#127695
1 parent 7e24f3d commit 9619b37

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

test/dynamo/test_decorators.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,28 @@ def fn(a, b, c):
464464

465465
self.assertEqual(cnt.frame_count, 1)
466466

467+
def test_assume_constant_result_on_user_defined_fn(self):
468+
@torch._dynamo.assume_constant_result
469+
def const_fn(n, s):
470+
return torch.full([n], s)
471+
472+
def fn(B):
473+
B = const_fn(B.size(0), 13)
474+
X = B * 2
475+
return X.tolist()
476+
477+
B_list = [8] * 32
478+
479+
B = torch.tensor(B_list, dtype=torch.int32)
480+
torch._dynamo.decorators.mark_static(B, 0)
481+
482+
torch._dynamo.config.capture_scalar_outputs = True
483+
torch._dynamo.config.capture_dynamic_output_shape_ops = True
484+
485+
self.assertEqual(
486+
fn(B), torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B)
487+
)
488+
467489

468490
if __name__ == "__main__":
469491
from torch._dynamo.test_case import run_tests

torch/_dynamo/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1906,7 +1906,7 @@ def make_error_message(e):
19061906
assert nnmodule is not None
19071907
return nnmodule(*args, **kwargs)
19081908
elif op == "get_attr":
1909-
return tracer.get_submodule(node.target)
1909+
return tracer.output_graph.get_submodule(node.target)
19101910
elif op == "placeholder":
19111911
assert "example_value" in node.meta
19121912
return node.meta["example_value"]

0 commit comments

Comments
 (0)