23
23
"""
24
24
25
25
import logging
26
- from typing import TYPE_CHECKING , Optional , Union
26
+ from typing import Union
27
27
28
28
import numpy as np
29
29
65
65
)
66
66
from pytensor .tensor .elemwise import DimShuffle , Elemwise
67
67
from pytensor .tensor .exceptions import NotScalarConstantError
68
- from pytensor .tensor .extra_ops import broadcast_shape , broadcast_to
68
+ from pytensor .tensor .extra_ops import broadcast_arrays
69
69
from pytensor .tensor .math import Sum , add
70
70
from pytensor .tensor .math import all as at_all
71
71
from pytensor .tensor .math import eq
72
- from pytensor .tensor .shape import Shape_i
72
+ from pytensor .tensor .shape import Shape_i , shape_padleft
73
73
from pytensor .tensor .sort import TopKOp
74
74
from pytensor .tensor .type import DenseTensorType , TensorType
75
75
from pytensor .tensor .var import TensorConstant , TensorVariable
76
76
from pytensor .utils import NoDuplicateOptWarningFilter
77
77
78
78
79
- if TYPE_CHECKING :
80
- from pytensor .tensor .rewriting .shape import ShapeFeature
81
-
82
-
83
79
_logger = logging .getLogger ("pytensor.tensor.rewriting.basic" )
84
80
_logger .addFilter (NoDuplicateOptWarningFilter ())
85
81
@@ -261,31 +257,16 @@ def local_scalar_tensor_scalar(fgraph, node):
261
257
def local_elemwise_alloc (fgraph , node ):
262
258
r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s.
263
259
264
- `Alloc`\s are effectively a type of `Elemwise` operation
265
- (e.g. ``Elemwise{second}(y, x)`` is the same as ``Alloc(x, *y.shape)``), so
266
- this rewrite uses that fact to reduce `Elemwise`\s on `Alloc`\s to
267
- `Elemwise`\s of the `Alloc`\s first/value input (i.e. the value it
268
- broadcasts).
269
-
270
- In other words, this rewrite causes `Elemwise` `Op`\s to "absorb" redundant
271
- `Alloc`\s.
272
-
273
260
The rewrite essentially performs the following replacement:
274
- ``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``,
275
- when ``y.shape`` for some input ``y`` (or the combined shapes of the
276
- non-`Alloc`\s) is sufficient to maintain the same/correct output shape.
261
+ ``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``
277
262
278
- In it's current form, it also explicitly accounts for `DimShuffle`\s of
263
+ In its current form, it also explicitly accounts for `DimShuffle`\s of
279
264
`Alloc`\s. This is largely due to `local_alloc_sink_dimshuffle`, which
280
265
introduces them as a canonicalization of `Alloc`'s with leading
281
266
broadcastable dimensions.
282
267
"""
283
- # Rewrite is only applicable when there are at least two inputs
284
268
if len (node .inputs ) == 1 :
285
- return False
286
-
287
- if len (node .outputs ) > 1 :
288
- return False
269
+ return None
289
270
290
271
def dimshuffled_alloc (i ):
291
272
return (
@@ -305,76 +286,40 @@ def dimshuffled_alloc(i):
305
286
if len (alloc_idxs ) == 0 :
306
287
return False
307
288
308
- # Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a
309
- # baseline for the dimensions.
310
- ref_var_idx = None
311
- for idx , i in enumerate (node .inputs ):
312
- if i .type .broadcastable == node .outputs [0 ].type .broadcastable :
313
- # Prefer an input that is not an `Alloc` nor a `DimShuffle` of an
314
- # `Alloc`, so that all `Alloc`s can be rewritten.
315
- if idx not in alloc_idxs :
316
- ref_var_idx = idx
317
- break
318
-
319
- # If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one
320
- if ref_var_idx is None :
321
- for idx , i in enumerate (node .inputs ):
322
- # XXX: This broadcastable comparison doesn't work
323
- if (
324
- i .type .broadcastable == node .outputs [0 ].type .broadcastable
325
- ) and idx in alloc_idxs :
326
- ref_var_idx = idx
327
- break
328
-
329
- if not hasattr (fgraph , "shape_feature" ):
330
- return False
331
-
332
- input_shapes = [
333
- tuple (fgraph .shape_feature .get_shape (i , j ) for j in range (i .type .ndim ))
334
- for i in node .inputs
335
- ]
336
- bcasted_shape = broadcast_shape (
337
- * input_shapes ,
338
- arrays_are_shapes = True ,
339
- )
340
-
341
289
new_inputs = list (node .inputs )
342
290
for idx in alloc_idxs :
343
291
i = node .inputs [idx ]
344
292
345
- # Remove `Alloc`
293
+ # Remove simple `Alloc`
346
294
if isinstance (i .owner .op , Alloc ):
347
- new_alloc = broadcast_to ( i .owner .inputs [0 ], bcasted_shape )
295
+ new_inp = i .owner .inputs [0 ]
348
296
349
- # TODO FIXME: This shouldn't be handled here.
350
- # `DimShuffle`s should be lifted through `Alloc`s
351
- # by other, more general rewrites.
352
- # Remove `Alloc` in `DimShuffle`
297
+ # Remove `Dimshuffle(Alloc)`
353
298
elif isinstance (i .owner .op , DimShuffle ):
354
299
old_alloc = i .owner .inputs [0 ]
355
- new_alloc = old_alloc .owner .inputs [0 ]
300
+ old_alloc_inp = old_alloc .owner .inputs [0 ]
301
+ missing_ndims = old_alloc .type .ndim - old_alloc_inp .type .ndim
302
+ if missing_ndims > 0 :
303
+ # The `Alloc` added new dimensions to the left.
304
+ # We replace those cases with a `DimShuffle` here.
305
+ # Nested dimshuffles will be merged later by other rewrites.
306
+ old_alloc_inp = shape_padleft (old_alloc_inp , missing_ndims )
356
307
# We need to keep the old `DimShuffle`. It could swap axes or
357
308
# add dimensions anywhere.
358
- if new_alloc .ndim != old_alloc .ndim :
359
- # The `Alloc` can add dimensions to the value.
360
- # We replace those cases with a `DimShuffle` here.
361
- nb_dim_to_add = old_alloc .ndim - new_alloc .ndim
362
- new_alloc = new_alloc .dimshuffle (
363
- ["x" ] * nb_dim_to_add + list (range (new_alloc .ndim ))
364
- )
365
- new_alloc = broadcast_to (i .owner .op (new_alloc ), bcasted_shape )
309
+ new_inp = i .owner .op (old_alloc_inp )
366
310
367
- copy_stack_trace (i , new_alloc )
368
- new_inputs [idx ] = new_alloc
311
+ copy_stack_trace (i , new_inp )
312
+ new_inputs [idx ] = new_inp
369
313
370
- # If this assert is triggered, it means we are recreating an equivalent graph
371
- # which would result in cyclical merge rewrites.
372
- if all (new is old for new , old in zip (new_inputs , node .inputs )):
373
- return
314
+ new_outs = node .op (* new_inputs , return_list = True )
374
315
375
- ret = node .op (* new_inputs , return_list = True )
376
- copy_stack_trace (node .outputs , ret )
377
- return ret
316
+ if new_outs [0 ].type .broadcastable != node .outputs [0 ].type .broadcastable :
317
+ new_outs = [
318
+ alloc_like (new_out , node .outputs [0 ], fgraph ) for new_out in new_outs
319
+ ]
320
+
321
+ copy_stack_trace (node .outputs , new_outs )
322
+ return new_outs
378
323
379
324
380
325
@register_canonicalize ("shape_unsafe" )
@@ -406,6 +351,7 @@ def local_fill_sink(fgraph, node):
406
351
407
352
# The newly created node c doesn't has 'clients',
408
353
# so this iteration is took place with node.outputs[0]
354
+ # TODO: This should just be a WalkingGraphRewrite!
409
355
replacements = {node .outputs [0 ]: c }
410
356
for client , cl_idx in fgraph .clients [node .outputs [0 ]]:
411
357
if (
@@ -438,23 +384,15 @@ def local_fill_to_alloc(fgraph, node):
438
384
with their dependencies on those tensors' shapes, and sometimes those
439
385
shapes can be computed without needing to compute the tensors themselves.
440
386
441
- XXX: This rewrite can produce inconsistent results, so do *not* consider
442
- making it a canonicalization until those inconsistencies are
443
- resolved/justified.
387
+ Like `local_fill_sink` this rewrites assumes non-broadcastable shapes are equivalent,
388
+ which could mask shape errors.
444
389
"""
445
390
shape_ref , values_ref = node .inputs
446
391
out_type = node .outputs [0 ].type
447
392
448
393
if values_ref .type .broadcastable == out_type .broadcastable :
449
394
# The assumption here is that `values_ref` already has the same shape
450
395
# as `shape_ref`, so a `fill`/`Alloc` is unnecessary.
451
-
452
- # XXX FIXME TODO: The only way this can be determined is if one
453
- # absolutely knows that the shapes of `shape_ref` and `values_ref` are
454
- # equal.
455
- # This is an old rewrite, and it's only a
456
- # "specialization/stabilization", so we're going to leave it be for
457
- # now.
458
396
return [values_ref ]
459
397
460
398
if shape_ref .type .broadcastable == out_type .broadcastable :
@@ -465,6 +403,9 @@ def local_fill_to_alloc(fgraph, node):
465
403
copy_stack_trace (node .outputs [0 ], o )
466
404
return [o ]
467
405
406
+ # The case that is not covered is when `shape_ref` is broadcasted by `values_ref`
407
+ # TODO: Return broadcast_to(values_ref, broadcast_shapes(values_ref.shape, shape_ref.shape))
408
+
468
409
return
469
410
470
411
@@ -1014,36 +955,30 @@ def local_sum_make_vector(fgraph, node):
1014
955
return [element_sum ]
1015
956
1016
957
1017
- @register_useless ("local_remove_switch_const_cond " )
1018
- @register_canonicalize ("fast_compile" , "local_remove_switch_const_cond " )
1019
- @register_specialize
1020
- @node_rewriter ([Elemwise ])
958
+ @register_useless ("shape_unsafe " )
959
+ @register_canonicalize ("fast_compile" , "shape_unsafe " )
960
+ @register_specialize ( "shape_unsafe" )
961
+ @node_rewriter ([switch ])
1021
962
def local_useless_switch (fgraph , node ):
1022
963
"""
1023
964
This rewrite makes the following changes in a graph:
1024
965
1025
- at. switch(cond, left, right) ->
1026
- if cond is constant and cond == 0: right
1027
- if cond is constant and cond != 0: left
1028
- if left is right -> left
966
+ switch(cond, left, right) ->
967
+ if cond is constant and cond == 0: right
968
+ if cond is constant and cond != 0: left
969
+ if left is right -> left
1029
970
1030
971
and
1031
972
1032
- at. switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
973
+ switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
1033
974
1034
975
"""
1035
- if not isinstance (node .op .scalar_op , aes .Switch ):
1036
- return False
1037
-
1038
- shape_feature : Optional ["ShapeFeature" ] = getattr (fgraph , "shape_feature" , None )
1039
-
1040
- if shape_feature is None :
1041
- return False
1042
976
1043
977
left = node .inputs [1 ]
1044
978
right = node .inputs [2 ]
1045
979
cond_var = node .inputs [0 ]
1046
980
cond = extract_constant (cond_var , only_process_constants = True )
981
+ out_bcast = node .outputs [0 ].type .broadcastable
1047
982
1048
983
if (isinstance (cond , np .ndarray ) and cond .ndim == 0 ) or isinstance (
1049
984
cond , (np .number , np .bool_ )
@@ -1058,14 +993,8 @@ def local_useless_switch(fgraph, node):
1058
993
else :
1059
994
out = correct_out
1060
995
1061
- input_shapes = [
1062
- tuple (shape_feature .get_shape (inp , i ) for i in range (inp .type .ndim ))
1063
- for inp in node .inputs
1064
- ]
1065
-
1066
- out_shape = broadcast_shape (* input_shapes , arrays_are_shapes = True )
1067
-
1068
- out = alloc (out , * out_shape )
996
+ if out .type .broadcastable != out_bcast :
997
+ out = broadcast_arrays (out , * node .inputs )[0 ]
1069
998
1070
999
# Copy over stacktrace from selected output to new output
1071
1000
copy_stack_trace (node .outputs + correct_out , out )
@@ -1075,10 +1004,10 @@ def local_useless_switch(fgraph, node):
1075
1004
if left == right :
1076
1005
# Note: No need to copy over stacktrace, because the input node
1077
1006
# already has its own stacktrace
1078
- if cond . type . is_super ( left .type ) :
1007
+ if left .type . broadcastable == out_bcast :
1079
1008
return [left ]
1080
1009
1081
- ret = fill ( cond , left )
1010
+ ret = broadcast_arrays ( left , cond )[ 0 ]
1082
1011
1083
1012
# Copy over stacktrace from switch output and correct branch
1084
1013
copy_stack_trace (node .outputs + left , ret )
0 commit comments