Skip to content

Commit 931a5af

Browse files
lhelleckesmichaelosthege
authored andcommitted
Refactor convert_observed data to simplify typing
1 parent 7836447 commit 931a5af

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

pymc/pytensorf.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,12 @@
7878
]
7979

8080

81-
def convert_observed_data(data):
81+
def convert_observed_data(data) -> np.ndarray | Variable:
8282
"""Convert user provided dataset to accepted formats."""
8383

84+
if isgenerator(data):
85+
return floatX(generator(data))
86+
8487
if hasattr(data, "to_numpy") and hasattr(data, "isnull"):
8588
# typically, but not limited to pandas objects
8689
vals = data.to_numpy()
@@ -116,8 +119,6 @@ def convert_observed_data(data):
116119
ret = data
117120
elif sps.issparse(data):
118121
ret = data
119-
elif isgenerator(data):
120-
ret = generator(data)
121122
else:
122123
ret = np.asarray(data)
123124

0 commit comments

Comments
 (0)