@@ -356,6 +356,14 @@ class Pad(OpFromGraph):
356
356
Wrapper Op for Pad graphs
357
357
"""
358
358
359
+ def __init__ (self , inputs , outputs , pad_mode , reflect_type = None , kind = None ):
360
+ self .pad_mode = pad_mode
361
+ self .reflect_type = reflect_type
362
+ self .kind = kind
363
+ self .reflect_type = reflect_type
364
+
365
+ super ().__init__ (inputs = inputs , outputs = outputs )
366
+
359
367
360
368
def pad (x : TensorLike , pad_width : TensorLike , mode : PadMode = "constant" , ** kwargs ):
361
369
if any (value not in allowed_kwargs [mode ] for value in kwargs .keys ()):
@@ -388,9 +396,6 @@ def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwar
388
396
stat_length = as_tensor (stat_length , name = "stat_length" )
389
397
inputs += [stat_length ]
390
398
391
- attrs .update (
392
- {"stat_func" : stat_func , "stat_length_input" : stat_length is not None }
393
- )
394
399
outputs = _stat_pad (x , pad_width , stat_func , stat_length )
395
400
396
401
elif mode == "linear_ramp" :
@@ -401,15 +406,14 @@ def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwar
401
406
outputs = _linear_ramp_pad (x , pad_width , end_values )
402
407
403
408
elif mode == "wrap" :
404
- attrs .update ({"kind" : "wrap" })
405
409
outputs = _looping_pad (x , pad_width , kind = "wrap" )
406
410
407
411
elif mode == "symmetric" :
408
412
reflect_type = kwargs .pop ("reflect_type" , "even" )
409
413
if reflect_type == "odd" :
410
414
raise NotImplementedError ("Odd reflection not implemented" )
411
415
412
- attrs .update ({"kind " : reflect_type })
416
+ attrs .update ({"reflect_type " : reflect_type })
413
417
outputs = _looping_pad (x , pad_width , kind = "symmetric" )
414
418
415
419
elif mode == "reflect" :
@@ -421,11 +425,7 @@ def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwar
421
425
else :
422
426
raise ValueError (f"Invalid mode: { mode } " )
423
427
424
- op = Pad (inputs = inputs , outputs = [outputs ])(* inputs ) # type: ignore
425
-
426
- setattr (op , "pad_mode" , mode )
427
- for pad_arg , value in attrs .items ():
428
- setattr (op , pad_arg , value )
428
+ op = Pad (inputs = inputs , outputs = [outputs ], pad_mode = mode , ** attrs )(* inputs )
429
429
return op
430
430
431
431
0 commit comments