Skip to content

Commit 50da472

Browse files
Kuldeep Borkarcclauss
Kuldeep Borkar
andauthored
Implemented Gelu Function (TheAlgorithms#7368)
* Implemented Gelu Function * Renamed file and added more description to function * Extended the name GELU * Update gaussian_error_linear_unit.py Co-authored-by: Christian Clauss <[email protected]>
1 parent b8281d7 commit 50da472

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

maths/gaussian_error_linear_unit.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""
2+
This script demonstrates an implementation of the Gaussian Error Linear Unit function.
3+
* https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions
4+
5+
The function takes a vector of K real numbers as input and returns x * sigmoid(1.702*x).
6+
Gaussian Error Linear Unit (GELU) is a high-performing neural network activation
7+
function.
8+
9+
This script is inspired by a corresponding research paper.
10+
* https://arxiv.org/abs/1606.08415
11+
"""
12+
13+
import numpy as np
14+
15+
16+
def sigmoid(vector: np.array) -> np.array:
17+
"""
18+
Mathematical function sigmoid takes a vector x of K real numbers as input and
19+
returns 1/ (1 + e^-x).
20+
https://en.wikipedia.org/wiki/Sigmoid_function
21+
22+
>>> sigmoid(np.array([-1.0, 1.0, 2.0]))
23+
array([0.26894142, 0.73105858, 0.88079708])
24+
"""
25+
return 1 / (1 + np.exp(-vector))
26+
27+
28+
def gaussian_error_linear_unit(vector: np.array) -> np.array:
29+
"""
30+
Implements the Gaussian Error Linear Unit (GELU) function
31+
32+
Parameters:
33+
vector (np.array): A numpy array of shape (1,n)
34+
consisting of real values
35+
36+
Returns:
37+
gelu_vec (np.array): The input numpy array, after applying
38+
gelu.
39+
40+
Examples:
41+
>>> gaussian_error_linear_unit(np.array([-1.0, 1.0, 2.0]))
42+
array([-0.15420423, 0.84579577, 1.93565862])
43+
44+
>>> gaussian_error_linear_unit(np.array([-3]))
45+
array([-0.01807131])
46+
"""
47+
return vector * sigmoid(1.702 * vector)
48+
49+
50+
if __name__ == "__main__":
51+
import doctest
52+
53+
doctest.testmod()

0 commit comments

Comments
 (0)