
Implement more meaningful Reshape operation #882

Closed
ricardoV94 opened this issue Jul 4, 2024 · 0 comments
Comments

@ricardoV94
Member

ricardoV94 commented Jul 4, 2024

Description

Analyzing graphs with reshape operations is rather complex because Reshape represents what we want, but not "what it means".

Except for esoteric cases where the Reshape shape comes from a complex computation or from the shapes of other variables, it is usually a case of multiplying some dimensions (merging) and dividing others (splitting). We could represent these cases with some sort of symbolic mapping:

x = tensor(shape=(4, 3, 2))
x.reshape(4, 6)  # JoinDims(0, (1, 2))

It almost begs for an extension of DimShuffle, which was brought up before: Theano/Theano#4640
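
As a sketch of the idea (JoinDims and the `groups` parameter are hypothetical, not an existing Op), the static output shape of such a join is fully determined by which input axes map to each output axis, with no shape graph involved:

# Hypothetical helper illustrating the metadata a JoinDims-style Op could carry:
# `groups` lists, per output axis, either a single input axis or a tuple of
# input axes that get merged into it.
def join_dims_shape(in_shape, groups):
    out_shape = []
    for group in groups:
        if isinstance(group, int):
            out_shape.append(in_shape[group])
        else:
            dim = 1
            for axis in group:
                if in_shape[axis] is None:
                    dim = None
                    break
                dim *= in_shape[axis]
            out_shape.append(dim)
    return tuple(out_shape)

join_dims_shape((4, 3, 2), (0, (1, 2)))     # (4, 6)
join_dims_shape((None, 3, 2), ((0, 1), 2))  # (None, 2)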

Splitting dims is trickier because there are many choices; we can split in different orders and sizes:

x = tensor(shape=(12,))
x.reshape(2, 2, 3)
x.reshape(2, 3, 2)
x.reshape(4, 3)
...

Still, an Op that achieves the same as splitting via reshape, but that knows which dims are going where (and in what sizes), would be more readable.
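
A hypothetical SplitDims(axis, sizes) Op (again just a sketch, not an existing API) could carry those choices explicitly, allowing at most one unknown size to play the role of reshape's -1:

# Hypothetical counterpart of the helper above: the Op's parameters say how
# `axis` is divided, with at most one `None` entry inferred from the input.
def split_dims_shape(in_shape, axis, sizes):
    known = 1
    for s in sizes:
        if s is not None:
            known *= s
    inferred = tuple(
        s
        if s is not None
        else (None if in_shape[axis] is None else in_shape[axis] // known)
        for s in sizes
    )
    return in_shape[:axis] + inferred + in_shape[axis + 1 :]

split_dims_shape((12,), 0, (2, None, 3))  # (2, 2, 3)
split_dims_shape((4, 12), 1, (4, 3))      # (4, 4, 3)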


An example where Reshape is currently hard to work with is during vectorization. If we have a common graph like reshape(x, x.shape[0] * x.shape[1], -1), we cannot eagerly return the desired output reshape(new_x, new_x.shape[0], new_x.shape[1] * new_x.shape[2], -1), because there is a chain of complex operations we must vectorize before we get to the Reshape node (Shape -> Subtensor -> Mul -> MakeVector). So we need to put it in a costly Blockwise and try our best to remove it during rewrites. This came up in #702 when vectorizing tensordot to get a batched tensordot.

Such a problem wouldn't exist with a symbolic reshape that is told what dims are being joined/split.
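
For instance, with the hypothetical JoinDims above, vectorizing is just a matter of shifting the stored axis indices by the number of new batch dimensions; nothing in the shape graph needs to be traversed and no Blockwise is needed:

# Sketch: vectorize a hypothetical JoinDims by offsetting its `groups`.
def vectorize_join_dims(groups, n_batch_dims):
    def shift(group):
        if isinstance(group, int):
            return group + n_batch_dims
        return tuple(axis + n_batch_dims for axis in group)

    # New batch dims pass through untouched; the original groups move right.
    return tuple(range(n_batch_dims)) + tuple(shift(g) for g in groups)

# reshape(x, x.shape[0] * x.shape[1], -1) on a 3d x is JoinDims(((0, 1), 2));
# adding one batch dimension turns it into JoinDims((0, (1, 2), 3)).
vectorize_join_dims(((0, 1), 2), 1)  # (0, (1, 2), 3)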

It would also make rewrites that remove/lift reshapes much simpler than they currently are:

def local_useless_reshape(fgraph, node):
    """Remove two kinds of useless `Reshape`.

    - Remove `Reshape` when both the input and output have a single dimension.
    - Remove `Reshape` when reshaping to the shape of the input.
    """
    inp = node.inputs[0]
    output = node.outputs[0]
    output_shape = node.inputs[1]

    if inp.type.ndim != output.type.ndim:
        return False

    # Simple case: both input and output have a single dimension.
    # TODO FIXME XXX: This could hide errors if the user provides inconsistent
    # shapes.
    if (
        inp.type.ndim == 1
        and output.type.ndim == 1
        and all(
            s1 == s2
            for s1, s2 in zip(inp.type.shape, output.type.shape)
            if s1 == 1 or s2 == 1
        )
    ):
        return [inp]

    # Second case: all the shapes match the input shape
    # Match Reshape(x, x.shape)
    if output_shape.owner and isinstance(output_shape.owner.op, Shape):
        shape_input = output_shape.owner.inputs[0]
        if shape_input == inp:
            return [inp]

    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
    # broadcastable and constant dimensions
    if output_shape.owner and isinstance(output_shape.owner.op, MakeVector):
        output_shape_is = output_shape.owner.inputs

        shape_feature = getattr(fgraph, "shape_feature", None)

        nb_m1 = 0
        shape_match = [False] * inp.type.ndim
        for dim in range(inp.type.ndim):
            outshp_i = output_shape_is[dim]
            # Match Shape_i{dim}(input)
            if (
                outshp_i.owner
                and isinstance(outshp_i.owner.op, Shape_i)
                and outshp_i.owner.op.i == dim
                and outshp_i.owner.inputs[0] == inp
            ):
                shape_match[dim] = True
                continue

            # Match Shape(input)[dim]
            if (
                outshp_i.owner
                and isinstance(outshp_i.owner.op, Subtensor)
                and len(outshp_i.owner.inputs) == 2
                and extract_constant(outshp_i.owner.inputs[1]) == dim
            ):
                subtensor_inp = outshp_i.owner.inputs[0]
                if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape):
                    shape_input_i = subtensor_inp.owner.inputs[0]
                    if shape_input_i == inp:
                        shape_match[dim] = True
                        continue

            # Match 1 if input.type.shape[dim] == 1
            cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
            if inp.type.shape[dim] == 1 and cst_outshp_i == 1:
                shape_match[dim] = True
                continue

            # Match -1
            if cst_outshp_i == -1:
                shape_match[dim] = True
                nb_m1 += 1
                continue

            # Match shape_of[input][dim] or its constant equivalent
            if shape_feature:
                inpshp_i = shape_feature.get_shape(inp, dim)
                if inpshp_i == outshp_i or (
                    extract_constant(inpshp_i, only_process_constants=1)
                    == extract_constant(outshp_i, only_process_constants=1)
                ):
                    shape_match[dim] = True
                    continue

        if all(shape_match) and nb_m1 <= 1:
            return [inp]

    # TODO later: if all the shapes except one match, we may want to
    # consider it useless as well, like we do in the 1-dim case.
    return False
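
For contrast, the equivalent rewrite for a hypothetical JoinDims/SplitDims (assuming a `groups` attribute like the one sketched above, listing input axes in order) would only need to look at the Op's parameters:

def local_useless_join_dims(fgraph, node):
    """Remove a JoinDims that doesn't actually merge anything (hypothetical Op)."""
    # If every output axis maps to a single input axis, in order, it is a no-op.
    if all(isinstance(group, int) for group in node.op.groups):
        return [node.inputs[0]]
    return False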


This is somewhat related to why we have both Second and Alloc. The first is easier to reason about because it tells us immediately that we are broadcasting with the shape of a specific variable, whereas Alloc specifies the desired output shape without its meaning (especially after some rewrites, where the shape graph may become dissociated from the original variable).
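
As a simplified illustration of the two forms in PyTensor, `second`/`fill` keeps the link to the variable that provides the shape, while `alloc` only records the shape values (here using x's shape directly for simplicity, rather than combining both shapes):

import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.vector("y")

# "Broadcast y like x": the graph keeps the reference to x itself.
z_second = pt.second(x, y)

# Same result, but only the target shape values are kept; after rewrites they
# may no longer be recognizable as "the shape of x".
z_alloc = pt.alloc(y, x.shape[0], x.shape[1])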

Notes
-----
There are two ways of broadcasting arrays:

second(x, y) == alloc(y, broadcast_shapes(x.shape, y.shape))

The second can be more efficient because x doesn't usually need to be computed when we only want its shape. It may also allow other rewrites that don't try to modify x when it has multiple clients (for fear of duplicating computation).

However, the first one is easier to reason about. Knowing we have such a graph allows us to do certain rewrites, such as "sinking" broadcasting operations below Elemwise. The same rewrites with alloc would be more complicated, as we would need to symbolically combine the shapes of each variable. As an example, contrast rewriting the following two equivalent graphs:

alloc(x, broadcast_shapes(x.shape, y.shape)) + alloc(y, broadcast_shapes(x.shape, y.shape)) -> x + y
second(y, x) + second(x, y) -> x + y

Theano developers (mostly) preferred to use the first form during canonicalization and to introduce the second form later, via rewrites like `local_fill_to_alloc` and the `alloc_like` helper inside rewrites. Many stabilization rewrites refuse to be applied when a variable has multiple clients, so this is important.

pymc-devs locked and limited conversation to collaborators Jul 4, 2024
ricardoV94 converted this issue into discussion #883 Jul 4, 2024

