Skip to content

Commit c6e9153

Browse files
michaelosthegeaseyboldttwiecki
authored
More moments (#5025)
* Improve error message for missing get_moment implementations * Add get_moment implementations for Normal, Uniform and Binomial Co-authored-by: Adrian Seyboldt <[email protected]> * Add tests for (Half)Flat moments with symbolic dimensionality Closes #4993 * Apply suggestions from code review * Update pymc/distributions/continuous.py Co-authored-by: Adrian Seyboldt <[email protected]> Co-authored-by: Thomas Wiecki <[email protected]>
1 parent 37ba9a3 commit c6e9153

File tree

4 files changed

+53
-3
lines changed

4 files changed

+53
-3
lines changed

pymc/distributions/continuous.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,11 @@ def logcdf(value, lower, upper):
336336
),
337337
)
338338

339+
def get_moment(value, size, lower, upper):
340+
lower = at.full(size, lower, dtype=aesara.config.floatX)
341+
upper = at.full(size, upper, dtype=aesara.config.floatX)
342+
return (lower + upper) / 2
343+
339344

340345
class FlatRV(RandomVariable):
341346
name = "flat"
@@ -366,7 +371,7 @@ def dist(cls, *, size=None, **kwargs):
366371
res.tag.test_value = np.full(size, floatX(0.0))
367372
return res
368373

369-
def get_moment(rv, size, *rv_inputs) -> np.ndarray:
374+
def get_moment(rv, size, *rv_inputs):
370375
return at.zeros(size, dtype=aesara.config.floatX)
371376

372377
def logp(value):
@@ -431,7 +436,7 @@ def dist(cls, *, size=None, **kwargs):
431436
res.tag.test_value = np.full(size, floatX(1.0))
432437
return res
433438

434-
def get_moment(value_var, size, *rv_inputs) -> np.ndarray:
439+
def get_moment(value_var, size, *rv_inputs):
435440
return at.ones(size, dtype=aesara.config.floatX)
436441

437442
def logp(value):
@@ -588,6 +593,9 @@ def logcdf(value, mu, sigma):
588593
0 < sigma,
589594
)
590595

596+
def get_moment(value_var, size, mu, sigma):
597+
return at.full(size, mu, dtype=aesara.config.floatX)
598+
591599

592600
class TruncatedNormalRV(RandomVariable):
593601
name = "truncated_normal"

pymc/distributions/discrete.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,10 @@ def logcdf(value, p):
394394
p <= 1,
395395
)
396396

397+
def get_moment(value, size, p):
398+
p = at.full(size, p)
399+
return at.switch(p < 0.5, at.zeros_like(value), at.ones_like(value))
400+
397401
def _distr_parameters_for_repr(self):
398402
return ["p"]
399403

pymc/distributions/distribution.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,9 @@ def dist(
351351

352352
@singledispatch
353353
def _get_moment(op, rv, size, *rv_inputs) -> TensorVariable:
354-
return None
354+
raise NotImplementedError(
355+
f"Random variable {rv} of type {op} has no get_moment implementation."
356+
)
355357

356358

357359
def get_moment(rv: TensorVariable) -> TensorVariable:

pymc/tests/test_initvals.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import aesara.tensor as at
1415
import numpy as np
1516
import pytest
1617

@@ -95,6 +96,11 @@ def test_automatically_assigned_test_values(self):
9596

9697
class TestMoment:
9798
def test_basic(self):
99+
# Standard distributions
100+
rv = pm.Normal.dist(mu=2.3)
101+
np.testing.assert_allclose(get_moment(rv).eval(), 2.3)
102+
103+
# Special distributions
98104
rv = pm.Flat.dist()
99105
assert get_moment(rv).eval() == np.zeros(())
100106
rv = pm.HalfFlat.dist()
@@ -103,3 +109,33 @@ def test_basic(self):
103109
assert np.all(get_moment(rv).eval() == np.zeros((2, 4)))
104110
rv = pm.HalfFlat.dist(size=(2, 4))
105111
assert np.all(get_moment(rv).eval() == np.ones((2, 4)))
112+
113+
@pytest.mark.xfail(reason="Test values are still used for initvals.")
114+
@pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat])
115+
def test_numeric_moment_shape(self, rv_cls):
116+
rv = rv_cls.dist(shape=(2,))
117+
assert not hasattr(rv.tag, "test_value")
118+
assert tuple(get_moment(rv).shape.eval()) == (2,)
119+
120+
@pytest.mark.xfail(reason="Test values are still used for initvals.")
121+
@pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat])
122+
def test_symbolic_moment_shape(self, rv_cls):
123+
s = at.scalar()
124+
rv = rv_cls.dist(shape=(s,))
125+
assert not hasattr(rv.tag, "test_value")
126+
assert tuple(get_moment(rv).shape.eval({s: 4})) == (4,)
127+
pass
128+
129+
@pytest.mark.xfail(reason="Test values are still used for initvals.")
130+
@pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat])
131+
def test_moment_from_dims(self, rv_cls):
132+
with pm.Model(
133+
coords={
134+
"year": [2019, 2020, 2021, 2022],
135+
"city": ["Bonn", "Paris", "Lisbon"],
136+
}
137+
):
138+
rv = rv_cls("rv", dims=("year", "city"))
139+
assert not hasattr(rv.tag, "test_value")
140+
assert tuple(get_moment(rv).shape.eval()) == (4, 3)
141+
pass

0 commit comments

Comments
 (0)