Skip to content

Commit 3e3621a

Browse files
BowenBaoTharinduRusira
authored andcommitted
Fix assume_constant_result for UnspecializedNNModuleVariable methods (pytorch#127695)
Fixes pytorch#127509 Pull Request resolved: pytorch#127695 Approved by: https://github.com/jansel
1 parent 77f73d3 commit 3e3621a

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

test/dynamo/test_export.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,6 +1509,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
15091509
graph, guards = torch._dynamo.export(model)(inp)
15101510
self.assertEqual(model(inp), graph(inp))
15111511

1512+
def test_export_with_constant_in_unspecialized_nn_module(self):
1513+
class Module(torch.nn.Module):
1514+
def __init__(self, y):
1515+
super().__init__()
1516+
self.y = y
1517+
1518+
@torch._dynamo.assume_constant_result
1519+
def check(self):
1520+
return self.y[0].item() == 1
1521+
1522+
def forward(self, x):
1523+
# This line leads to module obj being tracked as UnspecializedNNModuleVariable in dynamo
1524+
self.device = x.device
1525+
1526+
if self.check():
1527+
return x + 1
1528+
else:
1529+
return x + 2
1530+
1531+
model = Module(torch.tensor([1]))
1532+
inp = torch.ones(3, 4)
1533+
graph, _ = torch._dynamo.export(model)(inp)
1534+
self.assertEqual(model(inp), graph(inp))
1535+
15121536
def test_export_decomp(self):
15131537
def f(x):
15141538
return x.t() + x.t()

torch/_dynamo/variables/functions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,9 @@ def call_function(
339339
return self.obj.call_method(
340340
tx, self.fn.__name__, args, kwargs, constant=self.is_constant
341341
)
342+
if self.is_constant:
343+
fn = getattr(self.obj.value, self.fn.__name__)
344+
return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs)
342345
return super().call_function(tx, args, kwargs)
343346

344347
def inspect_parameter_names(self):

0 commit comments

Comments
 (0)