diff --git a/docs/source/api/math.rst b/docs/source/api/math.rst index 80741743cb..499001bd46 100644 --- a/docs/source/api/math.rst +++ b/docs/source/api/math.rst @@ -58,6 +58,7 @@ expressions rather than NumPy or Python code. logsumexp invlogit logit + softmax .. automodule:: pymc3.math :members: diff --git a/pymc3/math.py b/pymc3/math.py index dba3f938a6..e55c56ed42 100644 --- a/pymc3/math.py +++ b/pymc3/math.py @@ -195,6 +195,11 @@ def invlogit(x, eps=sys.float_info.epsilon): return (1.0 - 2.0 * eps) / (1.0 + tt.exp(-x)) + eps +def softmax(x): + """Generalization of the inverse logit function to multiple dimensions.""" + return tt.nnet.softmax(x) + + def logbern(log_p): if np.isnan(log_p): raise FloatingPointError("log_p can't be nan.")