Skip to content

Commit 90880e2

Browse files
author
Ian Schweer
committed
Allow function dispatch for constants
1 parent d4a2b2b commit 90880e2

File tree

3 files changed

+31
-3
lines changed

3 files changed

+31
-3
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,12 @@ def pytorch_typify_tensor(data, dtype=None, **kwargs):
3636

3737
@pytorch_typify.register(slice)
3838
@pytorch_typify.register(NoneType)
39-
@pytorch_typify.register(np.number)
4039
def pytorch_typify_no_conversion_needed(data, **kwargs):
4140
return data
4241

42+
@pytorch_typify.register(np.number)
43+
def pytorch_typify_extract(data, **kwargs):
44+
return data.item()
4345

4446
@singledispatch
4547
def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
@@ -57,11 +59,20 @@ def pytorch_funcify_FunctionGraph(
5759
conversion_func=pytorch_funcify,
5860
**kwargs,
5961
):
62+
def constants_wrapper(x, **kwargs):
63+
x = pytorch_typify(x)
64+
65+
@torch.compiler.assume_constant_result
66+
def torch_assume_constant(arg=x):
67+
return arg
68+
69+
return torch_assume_constant
70+
6071
built_kwargs = {"conversion_func": conversion_func, **kwargs}
6172
return fgraph_to_python(
6273
fgraph,
6374
conversion_func,
64-
type_conversion_fn=pytorch_typify,
75+
type_conversion_fn=constants_wrapper,
6576
fgraph_name=fgraph_name,
6677
**built_kwargs,
6778
)

pytensor/link/pytorch/linker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class wrapper:
5151
"""
5252

5353
def __init__(self, fn, gen_functors):
54+
self._fn = fn
5455
self.fn = torch.compile(fn)
5556
self.gen_functors = gen_functors.copy()
5657

pytensor/link/utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -749,9 +749,25 @@ def fgraph_to_python(
749749
)
750750
if input_storage[0] is not None or isinstance(i, Constant):
751751
# Constants need to be assigned locally and referenced
752-
global_env[local_input_name] = type_conversion_fn(
752+
getter_or_value = type_conversion_fn(
753753
input_storage[0], variable=i, storage=input_storage, **kwargs
754754
)
755+
if callable(getter_or_value):
756+
# we got passed a function, this could be used to indicate something
757+
# to the backend. We'll embed it
758+
new_output_name = unique_name(i)
759+
getter_unique_name = unique_name(getter_or_value)
760+
global_env[getter_unique_name] = getter_or_value
761+
assign_str = (
762+
f"{new_output_name} = {getter_unique_name}()"
763+
)
764+
body_assigns.append(assign_str)
765+
node_input_names.append(new_output_name)
766+
continue
767+
else:
768+
global_env[local_input_name] = type_conversion_fn(
769+
input_storage[0], variable=i, storage=input_storage, **kwargs
770+
)
755771
# TODO: We could attempt to use the storage arrays directly
756772
# E.g. `local_input_name = f"{local_input_name}[0]"`
757773
node_input_names.append(local_input_name)

0 commit comments

Comments
 (0)