Skip to content

Commit 4f94380

Browse files
authored
Add clip\clamp activation (qubvel-org#518)
1 parent a288d33 commit 4f94380

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

segmentation_models_pytorch/base/modules.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,15 @@ def forward(self, x):
7373
return torch.argmax(x, dim=self.dim)
7474

7575

76+
class Clamp(nn.Module):
77+
def __init__(self, min=0, max=1):
78+
super().__init__()
79+
self.min, self.max = min, max
80+
81+
def forward(self, x):
82+
return torch.clamp(x, self.min, self.max)
83+
84+
7685
class Activation(nn.Module):
7786

7887
def __init__(self, name, **params):
@@ -95,6 +104,8 @@ def __init__(self, name, **params):
95104
self.activation = ArgMax(**params)
96105
elif name == 'argmax2d':
97106
self.activation = ArgMax(dim=1, **params)
107+
elif name == 'clamp':
108+
self.activation = Clamp(**params)
98109
elif callable(name):
99110
self.activation = name(**params)
100111
else:

0 commit comments

Comments
 (0)