|
| 1 | +""" |
| 2 | +This script demonstrates the implementation of the Sigmoid Linear Unit (SiLU) |
| 3 | +or swish function. |
| 4 | +* https://en.wikipedia.org/wiki/Rectifier_(neural_networks) |
| 5 | +* https://en.wikipedia.org/wiki/Swish_function |
| 6 | +
|
| 7 | +The function takes a vector x of K real numbers as input and returns x * sigmoid(x). |
| 8 | +Swish is a smooth, non-monotonic function defined as f(x) = x * sigmoid(x). |
| 9 | +Extensive experiments show that Swish consistently matches or outperforms ReLU |
| 10 | +on deep networks applied to a variety of challenging domains such as |
| 11 | +image classification and machine translation. |
| 12 | +
|
| 13 | +This script is inspired by a corresponding research paper. |
| 14 | +* https://arxiv.org/abs/1710.05941 |
| 15 | +""" |
| 16 | + |
| 17 | +import numpy as np |
| 18 | + |
| 19 | + |
def sigmoid(vector: np.ndarray) -> np.ndarray:
    """
    Apply the logistic sigmoid function element-wise.

    Computes 1 / (1 + e^-x) for every element x of the input.
    https://en.wikipedia.org/wiki/Sigmoid_function

    Parameters:
        vector (np.ndarray): A numpy array consisting of real values.

    Returns:
        np.ndarray: An array of the same shape as the input, holding the
        sigmoid of each element.

    Examples:
    >>> sigmoid(np.array([-1.0, 1.0, 2.0]))
    array([0.26894142, 0.73105858, 0.88079708])
    """
    # Annotated with np.ndarray (the array type); np.array is only the
    # factory function and is not a valid type for annotations.
    return 1 / (1 + np.exp(-vector))
| 30 | + |
| 31 | + |
def sigmoid_linear_unit(vector: np.ndarray) -> np.ndarray:
    """
    Implements the Sigmoid Linear Unit (SiLU) or swish function.

    Each element x of the input is mapped to x * sigmoid(x).

    Parameters:
        vector (np.ndarray): A numpy array consisting of real values.

    Returns:
        swish_vec (np.ndarray): The input numpy array, after applying swish.

    Examples:
    >>> sigmoid_linear_unit(np.array([-1.0, 1.0, 2.0]))
    array([-0.26894142,  0.73105858,  1.76159416])

    >>> sigmoid_linear_unit(np.array([-2]))
    array([-0.23840584])
    """
    # Annotated with np.ndarray (the array type); np.array is only the
    # factory function and is not a valid type for annotations.
    return vector * sigmoid(vector)
| 52 | + |
| 53 | + |
if __name__ == "__main__":
    # When executed as a script, run the examples embedded in the docstrings
    # of this module as doctests.
    from doctest import testmod

    testmod()
0 commit comments