@@ -381,49 +381,64 @@ def local_useless_slice(fgraph, node):
381
381
return [node .inputs [0 ]]
382
382
383
383
384
- @register_infer_shape
384
+ @register_useless
385
385
@register_canonicalize
386
+ @register_stabilize
386
387
@register_specialize
387
388
@node_rewriter ([Subtensor ])
388
389
def local_replace_slice (fgraph , node ):
389
390
"""
390
- Remove Subtensor of the form:
391
+ Rewrite Subtensor of the form:
391
392
1. X[0:-1:1] -> X[:-1]
392
393
2. X[0:-1] -> X[:-1]
393
394
3. X[:-1] -> X[:-1]
394
395
395
396
"""
396
397
idxs = get_idx_list (node .inputs , node .op .idx_list )
398
+ x = node .inputs [0 ]
397
399
398
400
if not idxs :
399
401
return [node .inputs [0 ]]
400
402
403
+ # pytensor.dprint(node)
404
+ # print(node.inputs[0].type.shape)
405
+
401
406
# last_slice = len(idxs)
407
+ new_idxs = list (idxs )
402
408
403
- for s in idxs [::- 1 ]:
409
+ # flag = False
410
+ # index = -1
411
+ # call s slice
412
+ for dim , s in enumerate (idxs ):
404
413
if (
405
414
isinstance (s , slice )
406
- and (
407
- s .start is None
408
- or extract_constant (s .start , only_process_constants = True ) == 0
409
- )
410
- and extract_constant (s .stop , only_process_constants = True ) == - 1
411
- and (
412
- s .step is None
413
- or extract_constant (s .step , only_process_constants = True ) == 1
414
- )
415
+ and (s .start is None or extract_constant (s .start , only_process_constants = True ) == 0 )
416
+ and (extract_constant (s .stop , only_process_constants = True ) == - 1 or extract_constant (s .stop , only_process_constants = True ) == node .inputs [0 ].type .shape [dim ])
417
+ and (s .step is None or extract_constant (s .step , only_process_constants = True ) == 1 )
415
418
):
416
- # This does not work.
417
- # I get the error that
418
- # ```
419
- # s.start = None
420
- # AttributeError: readonly attribute
421
- # ```
422
- s .start = None
423
- s .stop = - 1
424
- s .step = None
419
+ # break
420
+ if index != - 1 :
421
+ new_idxs [dim ] = slice (None , None , None )
425
422
else :
426
- break
423
+ # exchange with if
424
+ continue
425
+ # if nothing changewd, return None
426
+ # if index != -1:
427
+ # new_idxs[dim] = slice(None, None, None)
428
+
429
+ # new_subtensor = Subtensor(tuple(new_idxs))
430
+ # new_subtensor_inputs = get_slice_elements(
431
+ # new_idxs, lambda x: isinstance(x, Variable)
432
+ # )
433
+ # out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
434
+ # # Copy over previous output stacktrace
435
+ # copy_stack_trace(node.outputs, out)
436
+ # return [out]
437
+ if change :
438
+ x [tuple (new_idxs )]
439
+ else :
440
+ # Subtensor is not needed at all
441
+ return None
427
442
428
443
429
444
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
0 commit comments