Skip to content

Commit ccca97d

Browse files
committed
Create new rewrites for elemwise
There is no need for an Elemwise Op if all inputs have rank 0. And we don't need to use scalar constants as inputs of the Elemwise; they can instead be passed directly as inputs to the scalar_op.
1 parent 16d1cbe commit ccca97d

File tree

5 files changed

+119
-34
lines changed

5 files changed

+119
-34
lines changed

pytensor/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
9292
env=env,
9393
stdout=subprocess.PIPE,
9494
stderr=(subprocess.PIPE if hide_stderr else None),
95-
**popen_kwargs
95+
**popen_kwargs,
9696
)
9797
break
9898
except OSError:

pytensor/compile/mode.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,16 @@ def apply(self, fgraph):
255255
"specialize_device", EquilibriumDB(), "fast_compile", "fast_run", position=48.6
256256
) # must be after gpu stuff at 48.5
257257

258+
# Must be before add_destroy_handler
259+
optdb.register(
260+
"elemwise_fusion",
261+
SequenceDB(),
262+
"fast_run",
263+
"fusion",
264+
"local_elemwise_fusion",
265+
position=49,
266+
)
267+
258268
# especially constant merge
259269
optdb.register("merge2", MergeOptimizer(), "fast_run", "merge", position=49)
260270

@@ -453,7 +463,10 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
453463
)
454464
NUMBA = Mode(
455465
NumbaLinker(),
456-
RewriteDatabaseQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
466+
RewriteDatabaseQuery(
467+
include=["fast_run", "fast_run_numba", "fast_compile_numba"],
468+
exclude=["cxx_only", "BlasOpt"],
469+
),
457470
)
458471

459472

pytensor/tensor/rewriting/elemwise.py

Lines changed: 101 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
in2out,
1919
node_rewriter,
2020
)
21-
from pytensor.graph.rewriting.db import SequenceDB
2221
from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
22+
from pytensor.tensor import as_tensor_variable
2323
from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value
2424
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
2525
from pytensor.tensor.exceptions import NotScalarConstantError
@@ -380,6 +380,99 @@ def is_dimshuffle_useless(new_order, input):
380380
return is_useless
381381

382382

383+
@node_rewriter([Elemwise])
384+
def local_elemwise_lift_scalars(fgraph, node):
385+
op = node.op
386+
387+
if not isinstance(op, Elemwise):
388+
return False
389+
390+
if not all(input.ndim == 0 for input in node.inputs):
391+
return False
392+
393+
scalars = [aes.as_scalar(input) for input in node.inputs]
394+
395+
# TODO Something like
396+
# copy_stack_trace(node.outputs[0], new_res)
397+
return [as_tensor_variable(out) for out in op.scalar_op.make_node(*scalars).outputs]
398+
399+
400+
compile.optdb["specialize"].register(
401+
"local_elemwise_lift_scalars",
402+
local_elemwise_lift_scalars,
403+
"fast_run_numba",
404+
"fast_compile_numba",
405+
)
406+
407+
408+
@node_rewriter([Elemwise])
409+
def push_elemwise_constants(fgraph, node):
410+
"""Push constant scalars from inputs to elemwise to inputs of the
411+
contained scalar op.
412+
"""
413+
op = node.op
414+
415+
if not isinstance(op, Elemwise):
416+
return False
417+
418+
if any(op.inplace_pattern):
419+
return False
420+
421+
if not isinstance(node.op.scalar_op, aes.Composite):
422+
return False
423+
424+
def is_constant_scalar(x):
425+
return isinstance(x, TensorConstant) and all(x.broadcastable)
426+
427+
push_idxs = []
428+
push_values = []
429+
keep_values = []
430+
for i, input in enumerate(node.inputs):
431+
if is_constant_scalar(input):
432+
push_idxs.append(i)
433+
val = input.value
434+
push_values.append(aes.constant(val.item(), dtype=val.dtype))
435+
elif (
436+
input.owner
437+
and isinstance(input.owner.op, DimShuffle)
438+
and is_constant_scalar(input.owner.inputs[0])
439+
):
440+
push_idxs.append(i)
441+
val = input.owner.inputs[0].value
442+
push_values.append(aes.constant(val.item(), dtype=val.dtype))
443+
else:
444+
keep_values.append(input)
445+
446+
if not push_values:
447+
return False
448+
449+
inner_graph = node.op.scalar_op.fgraph
450+
to_replace = [input for i, input in enumerate(inner_graph.inputs) if i in push_idxs]
451+
452+
# Clone the inner graph, it might be used somewhere else
453+
inner_graph, mapping = inner_graph.clone_get_equiv()
454+
inner_graph.replace_all(
455+
(mapping[old], new) for old, new in zip(to_replace, push_values)
456+
)
457+
458+
new_inputs = [
459+
input for i, input in enumerate(inner_graph.inputs) if i not in push_idxs
460+
]
461+
return (
462+
Elemwise(scalar_op=aes.Composite(new_inputs, inner_graph.outputs))
463+
.make_node(*keep_values)
464+
.outputs
465+
)
466+
467+
468+
compile.optdb["specialize"].register(
469+
"push_elemwise_constants",
470+
push_elemwise_constants,
471+
"fast_run_numba",
472+
"fast_compile_numba",
473+
)
474+
475+
383476
@register_canonicalize
384477
@register_specialize
385478
@node_rewriter([DimShuffle])
@@ -898,34 +991,13 @@ def print_profile(cls, stream, prof, level=0):
898991
print(blanc, " time_toposort", prof[7], file=stream)
899992

900993

901-
if config.tensor__local_elemwise_fusion:
902-
# Must be after gpu(48.5) and before AddDestroyHandler(49.5)
903-
fuse_seqopt = SequenceDB()
904-
fuse_seqopt.register(
905-
"composite_elemwise_fusion",
906-
FusionOptimizer(local_elemwise_fusion),
907-
"fast_run",
908-
"fusion",
909-
position=1,
910-
)
911-
compile.optdb.register( # type: ignore
912-
"elemwise_fusion",
913-
fuse_seqopt,
914-
"fast_run",
915-
"fusion",
916-
"local_elemwise_fusion",
917-
"FusionOptimizer",
918-
position=49,
919-
)
920-
else:
921-
compile.optdb.register( # type: ignore
922-
"elemwise_fusion",
923-
FusionOptimizer(local_elemwise_fusion),
924-
"fusion",
925-
"local_elemwise_fusion",
926-
"FusionOptimizer",
927-
position=49,
928-
)
994+
compile.optdb["elemwise_fusion"].register( # type: ignore
995+
"composite_elemwise_fusion",
996+
FusionOptimizer(local_elemwise_fusion),
997+
"fast_run",
998+
"fusion",
999+
position=1,
1000+
)
9291001

9301002

9311003
@register_canonicalize

pytensor/tensor/rewriting/math.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import pytensor.scalar.basic as aes
1010
import pytensor.scalar.math as aes_math
11+
from pytensor import compile
1112
from pytensor.graph.basic import Constant, Variable
1213
from pytensor.graph.rewriting.basic import (
1314
NodeRewriter,
@@ -91,7 +92,7 @@
9192
register_uncanonicalize,
9293
register_useless,
9394
)
94-
from pytensor.tensor.rewriting.elemwise import FusionOptimizer, fuse_seqopt
95+
from pytensor.tensor.rewriting.elemwise import FusionOptimizer
9596
from pytensor.tensor.shape import Shape, Shape_i
9697
from pytensor.tensor.subtensor import Subtensor
9798
from pytensor.tensor.type import (
@@ -2922,7 +2923,7 @@ def local_add_mul_fusion(fgraph, node):
29222923
return [output]
29232924

29242925

2925-
fuse_seqopt.register(
2926+
compile.optdb["elemwise_fusion"].register(
29262927
"local_add_mul_fusion",
29272928
FusionOptimizer(local_add_mul_fusion),
29282929
"fast_run",

pytensor/tensor/rewriting/subtensor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,6 @@ def local_subtensor_lift(fgraph, node):
469469
return [rbcast_subt_x]
470470

471471

472-
@register_canonicalize
473472
@register_specialize
474473
@node_rewriter([Subtensor])
475474
def local_subtensor_merge(fgraph, node):

0 commit comments

Comments (0)