@@ -389,56 +389,51 @@ def local_useless_slice(fgraph, node):
389
389
def local_replace_slice (fgraph , node ):
390
390
"""
391
391
Rewrite Subtensor of the form:
392
- 1. X[0:-1:1] -> X[:-1 ]
393
- 2. X[0:-1] -> X[:-1 ]
394
- 3. X[:-1] -> X[:-1 ]
392
+ 1. X[0:-1:1] -> X[None:None:None ]
393
+ 2. X[0:-1:2 ] -> X[None:None:2 ]
394
+ 3. X[3 :-1] -> X[3:None:None ]
395
395
396
396
"""
397
397
idxs = get_idx_list (node .inputs , node .op .idx_list )
398
398
x = node .inputs [0 ]
399
399
400
400
if not idxs :
401
- return [node .inputs [0 ]]
402
-
403
- # pytensor.dprint(node)
404
- # print(node.inputs[0].type.shape)
401
+ return [x ]
405
402
406
- # last_slice = len(idxs)
407
403
new_idxs = list (idxs )
404
+ idx_flag = False
405
+ for dim , s in enumerate (new_idxs ):
406
+ if not isinstance (s , slice ):
407
+ continue
408
+
409
+ flag = False
410
+ start = s .start
411
+ stop = s .stop
412
+ step = s .step
413
+ if start is None or extract_constant (start , only_process_constants = True ) == 0 :
414
+ flag = True
415
+ start = None
408
416
409
- # flag = False
410
- # index = -1
411
- # call s slice
412
- for dim , s in enumerate (idxs ):
413
417
if (
414
- isinstance (s , slice )
415
- and (s .start is None or extract_constant (s .start , only_process_constants = True ) == 0 )
416
- and (extract_constant (s .stop , only_process_constants = True ) == - 1 or extract_constant (s .stop , only_process_constants = True ) == node .inputs [0 ].type .shape [dim ])
417
- and (s .step is None or extract_constant (s .step , only_process_constants = True ) == 1 )
418
+ extract_constant (stop , only_process_constants = True ) == - 1
419
+ or extract_constant (stop , only_process_constants = True ) == x .type .shape [dim ]
418
420
):
419
- # break
420
- if index != - 1 :
421
- new_idxs [dim ] = slice (None , None , None )
422
- else :
423
- # exchange with if
424
- continue
425
- # if nothing changewd, return None
426
- # if index != -1:
427
- # new_idxs[dim] = slice(None, None, None)
428
-
429
- # new_subtensor = Subtensor(tuple(new_idxs))
430
- # new_subtensor_inputs = get_slice_elements(
431
- # new_idxs, lambda x: isinstance(x, Variable)
432
- # )
433
- # out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
434
- # # Copy over previous output stacktrace
435
- # copy_stack_trace(node.outputs, out)
436
- # return [out]
437
- if change :
438
- x [tuple (new_idxs )]
421
+ flag = True
422
+ stop = None
423
+
424
+ if step is None or extract_constant (step , only_process_constants = True ) == 1 :
425
+ flag = True
426
+ step = None
427
+
428
+ if flag :
429
+ idx_flag = True
430
+ new_idxs [dim ] = slice (start , stop , step )
431
+
432
+ if idx_flag is True :
433
+ return [x [tuple (new_idxs )]]
439
434
else :
440
435
# Subtensor is not needed at all
441
- return None
436
+ return [ x ]
442
437
443
438
444
439
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
0 commit comments