Skip to content

Commit 53cec9b

Browse files
BowenBaoTharinduRusira
authored andcommitted
Fix 'get_attr' call in dynamo 'run_node' (pytorch#127696)
Fixes pytorch#124858 Pull Request resolved: pytorch#127696 Approved by: https://github.com/jansel ghstack dependencies: pytorch#127695
1 parent 3e3621a commit 53cec9b

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

test/dynamo/test_decorators.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,28 @@ def fn(a, b, c):
465465

466466
self.assertEqual(cnt.frame_count, 1)
467467

468+
def test_assume_constant_result_on_user_defined_fn(self):
469+
@torch._dynamo.assume_constant_result
470+
def const_fn(n, s):
471+
return torch.full([n], s)
472+
473+
def fn(B):
474+
B = const_fn(B.size(0), 13)
475+
X = B * 2
476+
return X.tolist()
477+
478+
B_list = [8] * 32
479+
480+
B = torch.tensor(B_list, dtype=torch.int32)
481+
torch._dynamo.decorators.mark_static(B, 0)
482+
483+
torch._dynamo.config.capture_scalar_outputs = True
484+
torch._dynamo.config.capture_dynamic_output_shape_ops = True
485+
486+
self.assertEqual(
487+
fn(B), torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B)
488+
)
489+
468490

469491
if __name__ == "__main__":
470492
from torch._dynamo.test_case import run_tests

torch/_dynamo/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1908,7 +1908,7 @@ def make_error_message(e):
19081908
assert nnmodule is not None
19091909
return nnmodule(*args, **kwargs)
19101910
elif op == "get_attr":
1911-
return tracer.get_submodule(node.target)
1911+
return tracer.output_graph.get_submodule(node.target)
19121912
elif op == "placeholder":
19131913
assert "example_value" in node.meta
19141914
return node.meta["example_value"]

0 commit comments

Comments
 (0)