|
18 | 18 | in2out,
|
19 | 19 | node_rewriter,
|
20 | 20 | )
|
21 | | -from pytensor.graph.rewriting.db import SequenceDB
22 | 21 | from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
|
| 22 | +from pytensor.tensor import as_tensor_variable
23 | 23 | from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value
|
24 | 24 | from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
|
25 | 25 | from pytensor.tensor.exceptions import NotScalarConstantError
|
@@ -380,6 +380,99 @@ def is_dimshuffle_useless(new_order, input):
|
380 | 380 | return is_useless
|
381 | 381 |
|
382 | 382 |
|
| 383 | +@node_rewriter([Elemwise])
| 384 | +def local_elemwise_lift_scalars(fgraph, node):
| 385 | +    op = node.op
| 386 | +
| 387 | +    if not isinstance(op, Elemwise):
| 388 | +        return False
| 389 | +
| 390 | +    if not all(input.ndim == 0 for input in node.inputs):
| 391 | +        return False
| 392 | +
| 393 | +    scalars = [aes.as_scalar(input) for input in node.inputs]
| 394 | +
| 395 | +    # TODO Something like
| 396 | +    # copy_stack_trace(node.outputs[0], new_res)
| 397 | +    return [as_tensor_variable(out) for out in op.scalar_op.make_node(*scalars).outputs]
| 398 | +
| 399 | +
| 400 | +compile.optdb["specialize"].register(
| 401 | +    "local_elemwise_lift_scalars",
| 402 | +    local_elemwise_lift_scalars,
| 403 | +    "fast_run_numba",
| 404 | +    "fast_compile_numba",
| 405 | +)
| 406 | +
| 407 | +
| 408 | +@node_rewriter([Elemwise])
| 409 | +def push_elemwise_constants(fgraph, node):
| 410 | +    """Push constant scalars from inputs to elemwise to inputs of the
| 411 | +    contained scalar op.
| 412 | +    """
| 413 | +    op = node.op
| 414 | +
| 415 | +    if not isinstance(op, Elemwise):
| 416 | +        return False
| 417 | +
| 418 | +    if any(op.inplace_pattern):
| 419 | +        return False
| 420 | +
| 421 | +    if not isinstance(node.op.scalar_op, aes.Composite):
| 422 | +        return False
| 423 | +
| 424 | +    def is_constant_scalar(x):
| 425 | +        return isinstance(x, TensorConstant) and all(x.broadcastable)
| 426 | +
| 427 | +    push_idxs = []
| 428 | +    push_values = []
| 429 | +    keep_values = []
| 430 | +    for i, input in enumerate(node.inputs):
| 431 | +        if is_constant_scalar(input):
| 432 | +            push_idxs.append(i)
| 433 | +            val = input.value
| 434 | +            push_values.append(aes.constant(val.item(), dtype=val.dtype))
| 435 | +        elif (
| 436 | +            input.owner
| 437 | +            and isinstance(input.owner.op, DimShuffle)
| 438 | +            and is_constant_scalar(input.owner.inputs[0])
| 439 | +        ):
| 440 | +            push_idxs.append(i)
| 441 | +            val = input.owner.inputs[0].value
| 442 | +            push_values.append(aes.constant(val.item(), dtype=val.dtype))
| 443 | +        else:
| 444 | +            keep_values.append(input)
| 445 | +
| 446 | +    if not push_values:
| 447 | +        return False
| 448 | +
| 449 | +    inner_graph = node.op.scalar_op.fgraph
| 450 | +    to_replace = [input for i, input in enumerate(inner_graph.inputs) if i in push_idxs]
| 451 | +
| 452 | +    # Clone the inner graph, it might be used somewhere else
| 453 | +    inner_graph, mapping = inner_graph.clone_get_equiv()
| 454 | +    inner_graph.replace_all(
| 455 | +        (mapping[old], new) for old, new in zip(to_replace, push_values)
| 456 | +    )
| 457 | +
| 458 | +    new_inputs = [
| 459 | +        input for i, input in enumerate(inner_graph.inputs) if i not in push_idxs
| 460 | +    ]
| 461 | +    return (
| 462 | +        Elemwise(scalar_op=aes.Composite(new_inputs, inner_graph.outputs))
| 463 | +        .make_node(*keep_values)
| 464 | +        .outputs
| 465 | +    )
| 466 | +
| 467 | +
| 468 | +compile.optdb["specialize"].register(
| 469 | +    "push_elemwise_constants",
| 470 | +    push_elemwise_constants,
| 471 | +    "fast_run_numba",
| 472 | +    "fast_compile_numba",
| 473 | +)
| 474 | +
| 475 | +
383 | 476 | @register_canonicalize
|
384 | 477 | @register_specialize
|
385 | 478 | @node_rewriter([DimShuffle])
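
To make the first rewrite added above concrete: `local_elemwise_lift_scalars` fires when every input of an `Elemwise` node is a 0-d tensor, and rebuilds the node from the underlying scalar op, wrapping the results back with `as_tensor_variable`. The following is only an illustrative sketch, not part of the commit; the variable names and the use of `pytensor.function` with the default mode are assumptions.

    import pytensor
    import pytensor.tensor as pt

    # All inputs are 0-d tensor variables, so each Elemwise node built here
    # is a candidate for being replaced by its scalar_op applied to
    # ScalarVariables and wrapped back into tensor variables.
    x = pt.scalar("x")
    y = pt.scalar("y")
    z = pt.exp(x) + y

    fn = pytensor.function([x, y], z)
    print(fn(0.0, 1.0))  # -> 2.0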
|
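`push_elemwise_constants`, the second rewrite above, targets `Elemwise` nodes whose scalar op is already a fused `aes.Composite`: broadcasted scalar `TensorConstant` inputs are folded into the Composite's inner graph so the outer node carries fewer inputs. A rough sketch of a graph in which such constants appear (illustrative only; the literals and shapes are assumptions):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    # The literals 2.0 and 1.0 enter the graph as broadcasted TensorConstants;
    # once elemwise fusion has produced a Composite, this rewrite can push
    # them into the Composite's inner graph as scalar constants.
    y = pt.exp(x) * 2.0 + 1.0

    fn = pytensor.function([x], y)
    print(fn(np.zeros(3)))  # -> [3. 3. 3.]
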
@@ -898,34 +991,13 @@ def print_profile(cls, stream, prof, level=0):
|
898 | 991 | print(blanc, " time_toposort", prof[7], file=stream)
|
899 | 992 |
|
900 | 993 |
|
901 | | -if config.tensor__local_elemwise_fusion:
902 | | -    # Must be after gpu(48.5) and before AddDestroyHandler(49.5)
903 | | -    fuse_seqopt = SequenceDB()
904 | | -    fuse_seqopt.register(
905 | | -        "composite_elemwise_fusion",
906 | | -        FusionOptimizer(local_elemwise_fusion),
907 | | -        "fast_run",
908 | | -        "fusion",
909 | | -        position=1,
910 | | -    )
911 | | -    compile.optdb.register(  # type: ignore
912 | | -        "elemwise_fusion",
913 | | -        fuse_seqopt,
914 | | -        "fast_run",
915 | | -        "fusion",
916 | | -        "local_elemwise_fusion",
917 | | -        "FusionOptimizer",
918 | | -        position=49,
919 | | -    )
920 | | -else:
921 | | -    compile.optdb.register(  # type: ignore
922 | | -        "elemwise_fusion",
923 | | -        FusionOptimizer(local_elemwise_fusion),
924 | | -        "fusion",
925 | | -        "local_elemwise_fusion",
926 | | -        "FusionOptimizer",
927 | | -        position=49,
928 | | -    )
| 994 | +compile.optdb["elemwise_fusion"].register(  # type: ignore
| 995 | +    "composite_elemwise_fusion",
| 996 | +    FusionOptimizer(local_elemwise_fusion),
| 997 | +    "fast_run",
| 998 | +    "fusion",
| 999 | +    position=1,
| 1000 | +)
929 | 1001 |
|
930 | 1002 |
|
931 | 1003 | @register_canonicalize
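
The second hunk drops the `config.tensor__local_elemwise_fusion` branch: the fusion rewrite is now registered unconditionally into the existing `compile.optdb["elemwise_fusion"]` database instead of through a locally created `SequenceDB` (hence the removed import at the top). As a sketch of how fusion can still be disabled per compilation under this scheme (assumed usage of the standard `Mode` tag API, not part of the commit):

    import pytensor
    import pytensor.tensor as pt
    from pytensor.compile.mode import get_default_mode

    # Exclude every rewrite tagged "fusion" for this particular function.
    mode_no_fusion = get_default_mode().excluding("fusion")

    x = pt.vector("x")
    fn = pytensor.function([x], pt.exp(x) + 1.0, mode=mode_no_fusion)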
|
|