@@ -336,15 +336,21 @@ def local_subtensor_of_dot(fgraph, node):
336
336
@register_useless
337
337
@register_canonicalize
338
338
@register_specialize
339
+ @register_stabilize
339
340
@node_rewriter ([Subtensor ])
340
341
def local_useless_slice (fgraph , node ):
341
342
"""
342
343
Remove Subtensor of the form:
343
344
1. X[0, :] -> X[0]
344
345
2. X[:] -> X
345
346
347
+ Also, rewrite Subtensor of the form:
348
+ X[0:7:1] -> X[None:None:None]
349
+ where X is a vector of length 7
350
+
346
351
"""
347
352
idxs = get_idx_list (node .inputs , node .op .idx_list )
353
+ x = node .inputs [0 ]
348
354
349
355
if not idxs :
350
356
return [node .inputs [0 ]]
@@ -364,74 +370,46 @@ def local_useless_slice(fgraph, node):
364
370
last_useless_slice -= 1
365
371
else :
366
372
break
367
- # check if we removed something
368
- if last_useless_slice < len (idxs ):
369
- new_idxs = idxs [:last_useless_slice ]
370
-
371
- if new_idxs :
372
- new_subtensor = Subtensor (new_idxs )
373
- new_subtensor_inputs = get_slice_elements (
374
- new_idxs , lambda x : isinstance (x , Variable )
375
- )
376
- out = new_subtensor (node .inputs [0 ], * new_subtensor_inputs )
377
- # Copy over previous output stacktrace
378
- copy_stack_trace (node .outputs , out )
379
- return [out ]
380
- else :
381
- # Subtensor is not needed at all
382
- return [node .inputs [0 ]]
383
-
384
-
385
- @register_useless
386
- @register_canonicalize
387
- @register_stabilize
388
- @register_specialize
389
- @node_rewriter ([Subtensor ])
390
- def local_replace_slice (fgraph , node ):
391
- """
392
- Rewrite Subtensor of the form:
393
- X[0:7:1] -> X[None:None:None]
394
- where X is a vector of length 7
395
-
396
- """
397
- idxs = get_idx_list (node .inputs , node .op .idx_list )
398
- x = node .inputs [0 ]
399
373
400
- if not idxs :
401
- return
402
-
403
- new_idxs = list (idxs )
404
- idx_flag = False
374
+ new_idxs = list (idxs )[:last_useless_slice ]
375
+ change_flag = False
405
376
for dim , s in enumerate (new_idxs ):
406
- if not isinstance (s , slice ):
377
+ if not isinstance (s , slice ) or s == slice ( None ) :
407
378
continue
408
379
409
380
start = s .start
410
381
stop = s .stop
411
382
step = s .step
412
- if extract_constant (start , only_process_constants = True ) == 0 :
413
- idx_flag = True
383
+ if (
384
+ start is not None
385
+ and extract_constant (start , only_process_constants = True ) == 0
386
+ ):
387
+ change_flag = True
414
388
start = None
415
389
416
390
if (
417
- x .type .shape [dim ] is not None
391
+ stop is not None
392
+ and x .type .shape [dim ] is not None
418
393
and extract_constant (stop , only_process_constants = True ) == x .type .shape [dim ]
419
394
):
420
- idx_flag = True
395
+ change_flag = True
421
396
stop = None
422
397
423
- if extract_constant (step , only_process_constants = True ) == 1 :
424
- idx_flag = True
398
+ if (
399
+ step is not None
400
+ and extract_constant (step , only_process_constants = True ) == 1
401
+ ):
402
+ change_flag = True
425
403
step = None
426
404
427
405
new_idxs [dim ] = slice (start , stop , step )
428
406
429
- if idx_flag is True :
407
+ if change_flag is True :
430
408
out = x [tuple (new_idxs )]
431
409
# Copy over previous output stacktrace
432
410
copy_stack_trace (node .outputs , out )
433
411
434
- return [out ]
412
+ return [out ][: last_useless_slice ]
435
413
436
414
437
415
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
0 commit comments