|
103 | 103 | local_mul_canonizer,
|
104 | 104 | local_mul_switch_sink,
|
105 | 105 | local_reduce_chain,
|
| 106 | + local_reduce_join, |
106 | 107 | local_sum_prod_of_mul_or_div,
|
107 | 108 | mul_canonizer,
|
108 | 109 | parse_mul_tree,
|
@@ -3415,6 +3416,24 @@ def test_not_supported_unequal_shapes(self):
|
3415 | 3416 | f(x, y), np.sum(np.concatenate([x, y], axis=0), axis=0)
|
3416 | 3417 | )
|
3417 | 3418 |
|
| 3419 | + def test_non_ds_inputs(self): |
| 3420 | + """Make sure rewrite works when inputs to join are not the usual DimShuffle. |
| 3421 | +
|
| 3422 | + Sum{axis=1} [id A] <Vector(float64, shape=(3,))> |
| 3423 | + └─ Join [id B] <Matrix(float64, shape=(3, 3))> |
| 3424 | + ├─ 1 [id C] <Scalar(int8, shape=())> |
| 3425 | + ├─ ExpandDims{axis=1} [id D] <Matrix(float64, shape=(3, 1))> |
| 3426 | + ├─ Sub [id E] <Matrix(float64, shape=(3, 1))> |
| 3427 | + └─ Sub [id F] <Matrix(float64, shape=(3, 1))> |
| 3428 | + """ |
| 3429 | + x = vector("x") |
| 3430 | + out = join(0, exp(x[None]), log(x[None])).sum(axis=0) |
| 3431 | + |
| 3432 | + fg = FunctionGraph([x], [out], clone=False) |
| 3433 | + [rewritten_out] = local_reduce_join.transform(fg, out.owner) |
| 3434 | + expected_out = add(exp(x), log(x)) |
| 3435 | + assert equal_computations([rewritten_out], [expected_out]) |
| 3436 | + |
3418 | 3437 |
|
3419 | 3438 | def test_local_useless_adds():
|
3420 | 3439 | default_mode = get_default_mode()
|
|
0 commit comments