Skip to content

Commit f617180

Browse files
Juan Orduzmichaelosthege
Juan Orduz
andauthored
Add _as_tensor_variable converter for pandas objects (#5920)
* Add pd.Series and pd.DataFrame support * Explicitly add pandas into requirements.txt * Add pd global import in tests Co-authored-by: Michael Osthege <[email protected]>
1 parent a568762 commit f617180

File tree

3 files changed

+44
-2
lines changed

3 files changed

+44
-2
lines changed

pymc/aesaraf.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import aesara
3030
import aesara.tensor as at
3131
import numpy as np
32+
import pandas as pd
3233
import scipy.sparse as sps
3334

3435
from aeppl.abstract import MeasurableVariable
@@ -50,6 +51,7 @@
5051
from aesara.graph.op import Op, compute_test_value
5152
from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream
5253
from aesara.scalar.basic import Cast
54+
from aesara.tensor.basic import _as_tensor_variable
5355
from aesara.tensor.elemwise import Elemwise
5456
from aesara.tensor.random.op import RandomVariable
5557
from aesara.tensor.random.var import (
@@ -142,6 +144,12 @@ def convert_observed_data(data):
142144
return floatX(ret)
143145

144146

147+
@_as_tensor_variable.register(pd.Series)
148+
@_as_tensor_variable.register(pd.DataFrame)
149+
def dataframe_to_tensor_variable(df: pd.DataFrame, *args, **kwargs) -> TensorVariable:
150+
return at.as_tensor_variable(df.to_numpy(), *args, **kwargs)
151+
152+
145153
def change_rv_size(
146154
rv: TensorVariable,
147155
new_size: PotentialShapeType,

pymc/tests/test_aesaraf.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919
import numpy.ma as ma
2020
import numpy.testing as npt
21+
import pandas as pd
2122
import pytest
2223
import scipy.sparse as sps
2324

@@ -47,6 +48,40 @@
4748
from pymc.vartypes import int_types
4849

4950

51+
@pytest.mark.parametrize(
52+
argnames="np_array",
53+
argvalues=[
54+
np.array([[1.0], [2.0], [-1.0]]),
55+
np.array([[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]]),
56+
np.ones(shape=(10, 1)),
57+
],
58+
)
59+
def test_pd_dataframe_as_tensor_variable(np_array: np.ndarray) -> None:
60+
df = pd.DataFrame(np_array)
61+
np.testing.assert_array_equal(x=at.as_tensor_variable(x=df).eval(), y=np_array)
62+
63+
64+
@pytest.mark.parametrize(
65+
argnames="np_array",
66+
argvalues=[np.array([1.0, 2.0, -1.0]), np.ones(shape=4), np.zeros(shape=10), [1, 2, 3, 4]],
67+
)
68+
def test_pd_series_as_tensor_variable(np_array: np.ndarray) -> None:
69+
df = pd.Series(np_array)
70+
np.testing.assert_array_equal(x=at.as_tensor_variable(x=df).eval(), y=np_array)
71+
72+
73+
def test_pd_as_tensor_variable_multiindex() -> None:
74+
75+
tuples = [("L", "Q"), ("L", "I"), ("O", "L"), ("O", "I")]
76+
77+
index = pd.MultiIndex.from_tuples(tuples, names=["Id1", "Id2"])
78+
79+
df = pd.DataFrame({"A": [12.0, 80.0, 30.0, 20.0], "B": [120.0, 700.0, 30.0, 20.0]}, index=index)
80+
np_array = np.array([[12.0, 80.0, 30.0, 20.0], [120.0, 700.0, 30.0, 20.0]]).T
81+
assert isinstance(df.index, pd.MultiIndex)
82+
np.testing.assert_array_equal(x=at.as_tensor_variable(x=df).eval(), y=np_array)
83+
84+
5085
def test_change_rv_size():
5186
loc = at.as_tensor_variable([1, 2])
5287
rv = normal(loc=loc)
@@ -224,7 +259,6 @@ def test_convert_observed_data(input_dtype):
224259
Ensure that convert_observed_data returns the dense array, masked array,
225260
graph variable, TensorVariable, or sparse matrix as appropriate.
226261
"""
227-
pd = pytest.importorskip("pandas")
228262
# Create the various inputs to the function
229263
sparse_input = sps.csr_matrix(np.eye(3)).astype(input_dtype)
230264
dense_input = np.arange(9).reshape((3, 3)).astype(input_dtype)
@@ -300,7 +334,6 @@ def test_convert_observed_data(input_dtype):
300334

301335

302336
def test_pandas_to_array_pandas_index():
303-
pd = pytest.importorskip("pandas")
304337
data = pd.Index([1, 2, 3])
305338
result = convert_observed_data(data)
306339
expected = np.array([1, 2, 3])

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ cachetools>=4.2.1
55
cloudpickle
66
fastprogress>=0.2.0
77
numpy>=1.15.0
8+
pandas>=0.24.0
89
scipy>=1.4.1
910
typing-extensions>=3.7.4

0 commit comments

Comments
 (0)