|
41 | 41 | )
42 | 42 | from pytensor.graph.rewriting.db import RewriteDatabase
43 | 43 | from pytensor.raise_op import Assert, CheckAndRaise, assert_op
   | 44 | +from pytensor.scalar.basic import Second
44 | 45 | from pytensor.tensor.basic import (
45 | 46 |     Alloc,
46 | 47 |     AllocEmpty,

@@ -320,56 +321,52 @@ def dimshuffled_alloc(i):

320 | 321 |     return new_outs
321 | 322 |
322 | 323 |
323 |     | -@register_canonicalize("shape_unsafe")
324 | 324 | @node_rewriter([Elemwise])
325 | 325 | def local_fill_sink(fgraph, node):
326 | 326 |     """
327 | 327 |     f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e)))
328 | 328 |     f need to be an elemwise that isn't a fill.
329 | 329 |     """
330 |     | -    if not hasattr(node, "op") or not isinstance(node.op, Elemwise) or node.op == fill:
    | 330 | +    if isinstance(node.op.scalar_op, Second):
331 | 331 |         return False
    | 332 | +
332 | 333 |     models = []
333 | 334 |     inputs = []
334 | 335 |     for inp in node.inputs:
335 | 336 |         if inp.owner and inp.owner.op == fill:
336 |     | -            models.append(inp.owner.inputs[0])
337 |     | -            inputs.append(inp.owner.inputs[1])
    | 337 | +            a, b = inp.owner.inputs
    | 338 | +            if b.type.dtype != inp.dtype:
    | 339 | +                # The input was implicitly casted by the fill operation
    | 340 | +                b = b.cast(inp.dtype)
    | 341 | +            models.append(a)
    | 342 | +            inputs.append(b)
338 | 343 |         else:
339 | 344 |             inputs.append(inp)
    | 345 | +
340 | 346 |     if not models:
341 | 347 |         return False
342 |     | -    c = node.op(*inputs)
343 |     | -    for model in models:
344 |     | -        if (
345 |     | -            model.type.dtype != c.type.dtype
346 |     | -            or model.type.broadcastable != c.type.broadcastable
347 |     | -        ):
348 |     | -            c = fill(model, c)
349 | 348 |
350 |     | -    # The newly created node c doesn't has 'clients',
351 |     | -    # so this iteration is took place with node.outputs[0]
352 |     | -    # TODO: This should just be a WalkingGraphRewrite!
353 |     | -    replacements = {node.outputs[0]: c}
354 |     | -    for client, cl_idx in fgraph.clients[node.outputs[0]]:
355 |     | -        if (
356 |     | -            hasattr(client, "op")
357 |     | -            and isinstance(client.op, Elemwise)
358 |     | -            and client.op != fill
359 |     | -        ):
360 |     | -            client_inputs = client.inputs[:]
361 |     | -            client_inputs[cl_idx] = c
362 |     | -            new_client = client.op(*client_inputs)
363 |     | -
364 |     | -            # Add clients to new_client
365 |     | -            fgraph.clients[new_client.owner.outputs[0]] = fgraph.clients[
366 |     | -                client.outputs[0]
367 |     | -            ]
368 |     | -            r = local_fill_sink.transform(fgraph, new_client.owner)
369 |     | -            if not r:
370 |     | -                continue
371 |     | -            replacements.update(r)
372 |     | -    return replacements
    | 349 | +    outputs = node.op.make_node(*inputs).outputs
    | 350 | +
    | 351 | +    # Check if we need to propagate the fill to the new outputs
    | 352 | +    # It's enough to check the first output, as Elemwise outputs must all have the same shapes
    | 353 | +    # Note: There are orderings that may require fewer fills.
    | 354 | +    old_bcast_pattern = node.outputs[0].type.broadcastable
    | 355 | +    models_iter = iter(models)
    | 356 | +    while old_bcast_pattern != outputs[0].type.broadcastable:
    | 357 | +        model = next(models_iter)
    | 358 | +        # Only apply this model if it would actually do anything
    | 359 | +        if broadcasted_by(outputs[0], model):
    | 360 | +            outputs = [fill(model, output) for output in outputs]
    | 361 | +
    | 362 | +    return outputs
    | 363 | +
    | 364 | +
    | 365 | +# The rewrite is wrapped in an in2out GraphRewriter
    | 366 | +# so that fill can be sinked until the terminal nodes in a single pass through the graph
    | 367 | +# without triggering other rewrites after each local substitution
    | 368 | +topological_fill_sink = in2out(local_fill_sink)
    | 369 | +register_canonicalize(topological_fill_sink, "shape_unsafe")
373 | 370 |
374 | 371 |
375 | 372 | @register_specialize("shape_unsafe")
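For context, here is a minimal sketch (not part of the diff) of how the wrapped rewrite could be exercised on its own graph. It assumes `topological_fill_sink` is importable from `pytensor.tensor.rewriting.basic`, the module this diff touches, and that it is applied to a `FunctionGraph` through the generic `GraphRewriter.rewrite` entry point.

    import pytensor
    import pytensor.tensor as pt
    from pytensor.graph.fg import FunctionGraph
    from pytensor.tensor.rewriting.basic import topological_fill_sink  # assumed import path

    x = pt.vector("x")
    m = pt.matrix("m")

    # fill(m, x) broadcasts x to m's shape; exp and add then operate on the filled value
    y = pt.exp(pt.fill(m, x)) + 1

    fgraph = FunctionGraph(inputs=[x, m], outputs=[y], clone=True)

    # A single in2out pass sinks the fill below exp and add,
    # leaving something equivalent to fill(m, exp(x) + 1)
    topological_fill_sink.rewrite(fgraph)
    pytensor.dprint(fgraph)

A single in_to_out pass is enough to push the fill past both exp and add, which is why the rewrite is now registered as an in2out GraphRewriter rather than as a standalone node rewriter.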
|
|