@@ -30,7 +30,6 @@
     cast,
     constant,
     extract_constant,
-    fill,
     get_underlying_scalar_constant_value,
     ones_like,
     switch,
@@ -2041,8 +2040,6 @@ def local_zero_div(fgraph, node):
 @register_specialize
 @node_rewriter([at_pow])
 def local_pow_specialize(fgraph, node):
-    # here, we are past the point of canonicalization, so we don't want
-    # to put in un-necessary fills.
     if node.op == at_pow:
         # the idea here is that we have pow(x, y)
         odtype = node.outputs[0].dtype
@@ -2057,7 +2054,7 @@ def local_pow_specialize(fgraph, node):
             if np.all(y == 1):
                 rval = [xsym]
             if np.all(y == 0):
-                rval = [fill(xsym, np.asarray(1, dtype=odtype))]
+                rval = [alloc_like(1, xsym, fgraph)]
             if np.all(y == 0.5):
                 rval = [sqrt(xsym)]
             if np.all(y == -0.5):
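
The hunk above drops the old fill-based construction in local_pow_specialize in favor of a helper called alloc_like(value, template, fgraph). The sketch below only illustrates the assumed semantics of that call: broadcast a value to the shape and dtype of a reference tensor, which is what the x ** 0 case needs. The name alloc_like_sketch is hypothetical and works on NumPy arrays, not on symbolic variables or a FunctionGraph.

import numpy as np

def alloc_like_sketch(value, template):
    # Hypothetical NumPy stand-in for alloc_like(value, template, fgraph):
    # broadcast `value` to the shape of `template`, keeping its dtype.
    # The real helper operates on symbolic variables and the graph.
    return np.broadcast_to(np.asarray(value, dtype=template.dtype), template.shape)

x = np.random.rand(3, 4).astype("float32")
ones = alloc_like_sketch(1, x)  # what pow(x, 0) specializes to
assert ones.shape == x.shape and ones.dtype == x.dtype and np.all(ones == 1)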
@@ -2158,9 +2155,7 @@ def local_mul_specialize(fgraph, node):
     mul(-1, x, y) -/-> neg(mul(x, y))

     """
-    # here, we are past the point of canonicalization, so we don't
-    # want to put in un-necessary fills.
-    #
+
     # at this point [post canonicalize], mul() may have many inputs.
     if node.op == mul:
         # the idea here is that we have pow(x, y)
@@ -2221,16 +2216,7 @@ def local_mul_specialize(fgraph, node):

 @register_specialize
 @node_rewriter([add])
-def local_add_specialize(fgraph, node):
-    """Remove zeros from ``add``s.
-
-    TODO: This should be a canonicalization, no?
-    """
-    # here, we are past the point of canonicalization, so we don't want
-    # to put in un-necessary fills.
-    if node.op != add:
-        return False
-
+def local_add_remove_zeros(fgraph, node):
     new_inputs = []
     for inp in node.inputs:
         try:
@@ -2253,12 +2239,12 @@ def local_add_specialize(fgraph, node):
             # Reuse call to constant for cache()
             cst = constant(np.zeros((1,) * ndim, dtype=dtype))
             assert cst.type.broadcastable == (True,) * ndim
-            return [broadcast_arrays(cst, *node.inputs)[0]]
+            return [alloc_like(cst, node_output, fgraph)]

     if len(new_inputs) == 1:
-        ret = [broadcast_arrays(new_inputs[0], *node.inputs)[0]]
+        ret = [alloc_like(new_inputs[0], node_output, fgraph)]
     else:
-        ret = [broadcast_arrays(add(*new_inputs), *node.inputs)[0]]
+        ret = [alloc_like(add(*new_inputs), node_output, fgraph)]

     # The dtype should not be changed. It can happen if the input
     # that was forcing upcasting was equal to 0.
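
For context, here is a rough NumPy-only illustration of what the renamed local_add_remove_zeros rewrite in the two hunks above aims at, under the assumption that alloc_like broadcasts the surviving sum (or a zero constant) to the output's shape. remove_zero_terms_sketch is a hypothetical name, not part of the library.

import numpy as np

def remove_zero_terms_sketch(inputs):
    # Drop constant-zero terms from a sum and broadcast the remainder
    # (or zeros, if everything was dropped) to the full output shape.
    out_shape = np.broadcast_shapes(*(np.shape(i) for i in inputs))
    nonzero = [np.asarray(i) for i in inputs if not np.all(np.asarray(i) == 0.0)]
    if not nonzero:
        return np.zeros(out_shape)
    total = nonzero[0] if len(nonzero) == 1 else sum(nonzero)
    return np.broadcast_to(total, out_shape)

x = np.arange(6.0).reshape(2, 3)
assert np.array_equal(remove_zero_terms_sketch([x, 0.0, np.zeros(3)]), x)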
@@ -2376,7 +2362,7 @@ def local_log1p(fgraph, node):
                     ninp = nonconsts[0]
                 if ninp.dtype != log_arg.type.dtype:
                     ninp = ninp.astype(node.outputs[0].dtype)
-                return [broadcast_arrays(log1p(ninp), *scalar_inputs)[0]]
+                return [alloc_like(log1p(ninp), node.outputs[0], fgraph)]

     elif log_arg.owner and log_arg.owner.op == sub:
         one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True)
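
The local_log1p change above only swaps the broadcasting helper; the underlying identity is untouched. A quick NumPy check of that identity, and of why log1p is preferred for tiny arguments:

import numpy as np

x = np.array([1e-6, 0.25, 3.0])
assert np.allclose(np.log(1.0 + x), np.log1p(x))

# For very small x, 1.0 + x rounds to 1.0 and log() loses the signal,
# while log1p() keeps it:
print(np.log(1.0 + 1e-16), np.log1p(1e-16))  # 0.0 vs ~1e-16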
@@ -3572,10 +3558,11 @@ def local_reciprocal_1_plus_exp(fgraph, node):
             if nonconsts[0].owner and nonconsts[0].owner.op == exp:
                 if scalars_ and np.allclose(np.sum(scalars_), 1):
                     out = [
-                        broadcast_arrays(
+                        alloc_like(
                             sigmoid(neg(nonconsts[0].owner.inputs[0])),
-                            *scalar_inputs,
-                        )[0]
+                            node.outputs[0],
+                            fgraph,
+                        )
                     ]
                     # keep combined stack traces of
                     # exp(x): nonconsts[0],
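
Likewise, the local_reciprocal_1_plus_exp hunk only changes how the result is broadcast back to the output shape (alloc_like instead of broadcast_arrays); the stabilizing identity it implements, reciprocal(1 + exp(x)) -> sigmoid(-x), can be sanity-checked numerically:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

x = np.linspace(-5.0, 5.0, 11)
assert np.allclose(1.0 / (1.0 + np.exp(x)), sigmoid(-x))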