Skip to content

Commit 47b3658

Browse files
Modifying cartesian product to allow for >2D input arrays (#4482)
* Modifying cartesian product to allow for more >2D input arrays * Assert for equality in cartesian test * Mention #4482 in new features Co-authored-by: Michael Osthege <[email protected]>
1 parent f0c823e commit 47b3658

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
+ ...
66

77
### New Features
8+
+ `pm.math.cartesian` can now handle inputs that are themselves >1D (see [#4482](https://github.com/pymc-devs/pymc3/pull/4482)).
89
+ ...
910

1011
### Maintenance

pymc3/math.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,17 @@ def cartesian(*arrays):
101101
102102
Parameters
103103
----------
104-
arrays: 1D array-like
105-
1D arrays where earlier arrays loop more slowly than later ones
104+
arrays: N-D array-like
105+
N-D arrays where earlier arrays loop more slowly than later ones
106106
"""
107107
N = len(arrays)
108-
return np.stack(np.meshgrid(*arrays, indexing="ij"), -1).reshape(-1, N)
108+
arrays_np = [np.asarray(x) for x in arrays]
109+
arrays_2d = [x[:, None] if np.asarray(x).ndim == 1 else x for x in arrays_np]
110+
arrays_integer = [np.arange(len(x)) for x in arrays_2d]
111+
product_integers = np.stack(np.meshgrid(*arrays_integer, indexing="ij"), -1).reshape(-1, N)
112+
return np.concatenate(
113+
[array[product_integers[:, i]] for i, array in enumerate(arrays_2d)], axis=-1
114+
)
109115

110116

111117
def kron_matrix_op(krons, m, op):

pymc3/tests/test_math.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,24 @@ def test_cartesian():
7171
]
7272
)
7373
auto_cart = cartesian(a, b, c)
74-
np.testing.assert_array_almost_equal(manual_cartesian, auto_cart)
74+
np.testing.assert_array_equal(manual_cartesian, auto_cart)
75+
76+
77+
def test_cartesian_2d():
78+
np.random.seed(1)
79+
a = [[1, 2], [3, 4]]
80+
b = [5, 6]
81+
c = [0]
82+
manual_cartesian = np.array(
83+
[
84+
[1, 2, 5, 0],
85+
[1, 2, 6, 0],
86+
[3, 4, 5, 0],
87+
[3, 4, 6, 0],
88+
]
89+
)
90+
auto_cart = cartesian(a, b, c)
91+
np.testing.assert_array_equal(manual_cartesian, auto_cart)
7592

7693

7794
def test_kron_dot():

0 commit comments

Comments
 (0)