@@ -336,98 +336,68 @@ def local_subtensor_of_dot(fgraph, node):
336
336
@register_useless
337
337
@register_canonicalize
338
338
@register_specialize
339
+ @register_stabilize
339
340
@node_rewriter ([Subtensor ])
340
341
def local_useless_slice (fgraph , node ):
341
342
"""
342
343
Remove Subtensor of the form:
343
344
1. X[0, :] -> X[0]
344
345
2. X[:] -> X
345
346
346
- """
347
- idxs = get_idx_list (node .inputs , node .op .idx_list )
348
-
349
- if not idxs :
350
- return [node .inputs [0 ]]
351
-
352
- last_useless_slice = len (idxs )
353
- for s in idxs [::- 1 ]:
354
- # check if slice and then check slice indices
355
- if (
356
- isinstance (s , slice )
357
- and s .start is None
358
- and s .stop is None
359
- and (
360
- s .step is None
361
- or extract_constant (s .step , only_process_constants = True ) == 1
362
- )
363
- ):
364
- last_useless_slice -= 1
365
- else :
366
- break
367
- # check if we removed something
368
- if last_useless_slice < len (idxs ):
369
- new_idxs = idxs [:last_useless_slice ]
370
-
371
- if new_idxs :
372
- new_subtensor = Subtensor (new_idxs )
373
- new_subtensor_inputs = get_slice_elements (
374
- new_idxs , lambda x : isinstance (x , Variable )
375
- )
376
- out = new_subtensor (node .inputs [0 ], * new_subtensor_inputs )
377
- # Copy over previous output stacktrace
378
- copy_stack_trace (node .outputs , out )
379
- return [out ]
380
- else :
381
- # Subtensor is not needed at all
382
- return [node .inputs [0 ]]
383
-
384
-
385
- @register_useless
386
- @register_canonicalize
387
- @register_stabilize
388
- @register_specialize
389
- @node_rewriter ([Subtensor ])
390
- def local_replace_slice (fgraph , node ):
391
- """
392
- Rewrite Subtensor of the form:
347
+ Also, rewrite Subtensor of the form:
393
348
X[0:7:1] -> X[None:None:None]
394
- where X is a vector of length 7
349
+ where X is a vector of length 7
395
350
396
351
"""
397
352
idxs = get_idx_list (node .inputs , node .op .idx_list )
398
353
x = node .inputs [0 ]
399
354
400
355
if not idxs :
401
- return
356
+ return [ node . inputs [ 0 ]]
402
357
403
358
new_idxs = list (idxs )
404
- idx_flag = False
359
+ change_flag = False
360
+ last_useful_idx = - 1
405
361
for dim , s in enumerate (new_idxs ):
406
362
if not isinstance (s , slice ):
363
+ last_useful_idx = dim
364
+ continue
365
+
366
+ if s == slice (None ):
407
367
continue
408
368
409
369
start = s .start
410
370
stop = s .stop
411
371
step = s .step
412
- if extract_constant (start , only_process_constants = True ) == 0 :
413
- idx_flag = True
372
+ if (
373
+ start is not None
374
+ and extract_constant (start , only_process_constants = True ) == 0
375
+ ):
376
+ change_flag = True
414
377
start = None
415
378
416
379
if (
417
- x .type .shape [dim ] is not None
380
+ stop is not None
381
+ and x .type .shape [dim ] is not None
418
382
and extract_constant (stop , only_process_constants = True ) == x .type .shape [dim ]
419
383
):
420
- idx_flag = True
384
+ change_flag = True
421
385
stop = None
422
386
423
- if extract_constant (step , only_process_constants = True ) == 1 :
424
- idx_flag = True
387
+ if (
388
+ step is not None
389
+ and extract_constant (step , only_process_constants = True ) == 1
390
+ ):
391
+ change_flag = True
425
392
step = None
426
393
394
+ if not (start is None and stop is None and step is None ):
395
+ last_useful_idx = dim
396
+
427
397
new_idxs [dim ] = slice (start , stop , step )
428
398
429
- if idx_flag is True :
430
- out = x [tuple (new_idxs )]
399
+ if change_flag or (( last_useful_idx + 1 ) < len ( idxs )) :
400
+ out = x [tuple (new_idxs [: last_useful_idx + 1 ] )]
431
401
# Copy over previous output stacktrace
432
402
copy_stack_trace (node .outputs , out )
433
403
0 commit comments