Skip to content

Commit eeb9fa3

Browse files
Save keyword arguments in Pad OpFromGraph
1 parent 9c76a8f commit eeb9fa3

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

pytensor/tensor/pad.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,14 @@ class Pad(OpFromGraph):
356356
Wrapper Op for Pad graphs
357357
"""
358358

359+
def __init__(self, inputs, outputs, pad_mode, reflect_type=None, kind=None):
360+
self.pad_mode = pad_mode
361+
self.reflect_type = reflect_type
362+
self.kind = kind
363+
self.reflect_type = reflect_type
364+
365+
super().__init__(inputs=inputs, outputs=outputs)
366+
359367

360368
def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwargs):
361369
if any(value not in allowed_kwargs[mode] for value in kwargs.keys()):
@@ -388,9 +396,6 @@ def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwar
388396
stat_length = as_tensor(stat_length, name="stat_length")
389397
inputs += [stat_length]
390398

391-
attrs.update(
392-
{"stat_func": stat_func, "stat_length_input": stat_length is not None}
393-
)
394399
outputs = _stat_pad(x, pad_width, stat_func, stat_length)
395400

396401
elif mode == "linear_ramp":
@@ -401,15 +406,14 @@ def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwar
401406
outputs = _linear_ramp_pad(x, pad_width, end_values)
402407

403408
elif mode == "wrap":
404-
attrs.update({"kind": "wrap"})
405409
outputs = _looping_pad(x, pad_width, kind="wrap")
406410

407411
elif mode == "symmetric":
408412
reflect_type = kwargs.pop("reflect_type", "even")
409413
if reflect_type == "odd":
410414
raise NotImplementedError("Odd reflection not implemented")
411415

412-
attrs.update({"kind": reflect_type})
416+
attrs.update({"reflect_type": reflect_type})
413417
outputs = _looping_pad(x, pad_width, kind="symmetric")
414418

415419
elif mode == "reflect":
@@ -421,11 +425,7 @@ def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwar
421425
else:
422426
raise ValueError(f"Invalid mode: {mode}")
423427

424-
op = Pad(inputs=inputs, outputs=[outputs])(*inputs) # type: ignore
425-
426-
setattr(op, "pad_mode", mode)
427-
for pad_arg, value in attrs.items():
428-
setattr(op, pad_arg, value)
428+
op = Pad(inputs=inputs, outputs=[outputs], pad_mode=mode, **attrs)(*inputs)
429429
return op
430430

431431

0 commit comments

Comments
 (0)