Skip to content

Commit 1dc7e4c

Browse files
committed
feat: add parametric relu
1 parent 40f65e8 commit 1dc7e4c

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""
2+
Parametric Rectified Linear Unit (PReLU)
3+
4+
Use Case: PReLU addresses the problem of dying ReLU by allowing a small, learnable slope for negative values, which can improve model performance.
5+
6+
For more detailed information, you can refer to the following link:
7+
https://en.wikipedia.org/wiki/Rectifier_(neural_networks)#Parametric_ReLU
8+
"""
9+
10+
import numpy as np
11+
12+
13+
def parametric_rectified_linear_unit(
14+
vector: np.ndarray, alpha: np.ndarray
15+
) -> np.ndarray:
16+
"""
17+
Implements the Parametric ReLU (PReLU) activation function.
18+
19+
Parameters:
20+
vector (np.ndarray): The input array for PReLU activation.
21+
alpha (np.ndarray): The learnable slope for negative values, must be the same shape as vector.
22+
23+
Returns:
24+
np.ndarray: The input array after applying the PReLU activation.
25+
26+
Formula:
27+
f(x) = x if x > 0 else f(x) = alpha * x
28+
29+
Examples:
30+
>>> parametric_rectified_linear_unit(vector=np.array([2.3, 0.6, -2, -3.8]), alpha=np.array([0.3]))
31+
array([ 2.3 , 0.6 , -0.6 , -1.14])
32+
33+
>>> parametric_rectified_linear_unit(vector=np.array([-9.2, -0.3, 0.45, -4.56]), alpha=np.array([0.067]))
34+
array([-0.6164 , -0.0201 , 0.45 , -0.30552])
35+
36+
>>> parametric_rectified_linear_unit(vector=np.array([0, 0, 0]), alpha=np.array([0.1, 0.1, 0.1]))
37+
array([0., 0., 0.])
38+
39+
>>> parametric_rectified_linear_unit(vector=np.array([-1, -2, -3]), alpha=np.array([0.5, 1, 1.5]))
40+
array([-0.5, -2. , -4.5])
41+
42+
>>> parametric_rectified_linear_unit(vector=np.array([-1, 2, -3]), alpha=np.array([1, 0.5, 2]))
43+
array([-1., 2., -6.])
44+
45+
>>> parametric_rectified_linear_unit(vector=np.array([-5, -10]), alpha=np.array([2, 3]))
46+
array([-10, -30])
47+
48+
>>> parametric_rectified_linear_unit(vector=np.array([-1, -2]), alpha=np.array([1, 0]))
49+
array([-1, 0])
50+
51+
>>> parametric_rectified_linear_unit(vector=np.array([1, -1]), alpha=np.array([0.5, 2]))
52+
array([ 1., -2.])
53+
"""
54+
55+
return np.where(vector > 0, vector, alpha * vector)
56+
57+
58+
if __name__ == "__main__":
59+
import doctest
60+
61+
doctest.testmod()

0 commit comments

Comments
 (0)