Skip to content

Commit 4eb7828

Browse files
Intermediate draft changes
1 parent f0f43f0 commit 4eb7828

File tree

2 files changed

+42
-28
lines changed

2 files changed

+42
-28
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -381,49 +381,64 @@ def local_useless_slice(fgraph, node):
381381
return [node.inputs[0]]
382382

383383

384-
@register_infer_shape
384+
@register_useless
385385
@register_canonicalize
386+
@register_stabilize
386387
@register_specialize
387388
@node_rewriter([Subtensor])
388389
def local_replace_slice(fgraph, node):
389390
"""
390-
Remove Subtensor of the form:
391+
Rewrite Subtensor of the form:
391392
1. X[0:-1:1] -> X[:-1]
392393
2. X[0:-1] -> X[:-1]
393394
3. X[:-1] -> X[:-1]
394395
395396
"""
396397
idxs = get_idx_list(node.inputs, node.op.idx_list)
398+
x = node.inputs[0]
397399

398400
if not idxs:
399401
return [node.inputs[0]]
400402

403+
# pytensor.dprint(node)
404+
# print(node.inputs[0].type.shape)
405+
401406
# last_slice = len(idxs)
407+
new_idxs = list(idxs)
402408

403-
for s in idxs[::-1]:
409+
# flag = False
410+
# index = -1
411+
# call s slice
412+
for dim, s in enumerate(idxs):
404413
if (
405414
isinstance(s, slice)
406-
and (
407-
s.start is None
408-
or extract_constant(s.start, only_process_constants=True) == 0
409-
)
410-
and extract_constant(s.stop, only_process_constants=True) == -1
411-
and (
412-
s.step is None
413-
or extract_constant(s.step, only_process_constants=True) == 1
414-
)
415+
and (s.start is None or extract_constant(s.start, only_process_constants=True) == 0)
416+
and (extract_constant(s.stop, only_process_constants=True) == -1 or extract_constant(s.stop, only_process_constants=True) == node.inputs[0].type.shape[dim])
417+
and (s.step is None or extract_constant(s.step, only_process_constants=True) == 1)
415418
):
416-
# This does not work.
417-
# I get the error that
418-
# ```
419-
# s.start = None
420-
# AttributeError: readonly attribute
421-
# ```
422-
s.start = None
423-
s.stop = -1
424-
s.step = None
419+
# break
420+
if index != -1:
421+
new_idxs[dim] = slice(None, None, None)
425422
else:
426-
break
423+
# exchange with if
424+
continue
425+
# if nothing changewd, return None
426+
# if index != -1:
427+
# new_idxs[dim] = slice(None, None, None)
428+
429+
# new_subtensor = Subtensor(tuple(new_idxs))
430+
# new_subtensor_inputs = get_slice_elements(
431+
# new_idxs, lambda x: isinstance(x, Variable)
432+
# )
433+
# out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
434+
# # Copy over previous output stacktrace
435+
# copy_stack_trace(node.outputs, out)
436+
# return [out]
437+
if change:
438+
x[tuple(new_idxs)]
439+
else:
440+
# Subtensor is not needed at all
441+
return None
427442

428443

429444
# fast_compile to allow opt subtensor(cast{float32}(make_vector))

tests/tensor/rewriting/test_subtensor.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2406,11 +2406,10 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
24062406

24072407
def test_slice_canonicalize():
24082408
# Dummy test. Need to and will be changed
2409-
x = pt.vector("x")
2410-
y1 = x[:-1]
2411-
y2 = x[0:-1]
2412-
y1 = x[0:-1:1]
2409+
x = tensor("x", shape=(3, 5, None, 9))
2410+
y = x[:-1, 0:5, 0:7, 0:-1:-1]
24132411

2414-
f = pytensor.function([x], [y1, y2, y3])
2412+
f = pytensor.function([x], [y])
2413+
# use equal_computations
24152414
pytensor.dprint(f)
2416-
# assert 0
2415+
assert 0

0 commit comments

Comments
 (0)