22
22
from typing import Dict , Optional , Sequence , Tuple , Union , cast
23
23
24
24
import numpy as np
25
+ import pandas as pd
25
26
import pytensor
26
27
import pytensor .tensor as at
28
+ import xarray as xr
27
29
28
30
from pytensor .compile .sharedvalue import SharedVariable
29
31
from pytensor .raise_op import Assert
@@ -205,17 +207,17 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
205
207
206
208
def determine_coords (
207
209
model ,
208
- value ,
210
+ value : Union [ pd . DataFrame , pd . Series , xr . DataArray ] ,
209
211
dims : Optional [Sequence [Optional [str ]]] = None ,
210
- coords : Optional [Dict [str , Sequence ]] = None ,
211
- ) -> Tuple [Dict [str , Sequence ], Sequence [Optional [str ]]]:
212
+ coords : Optional [Dict [str , Union [ Sequence , np . ndarray ] ]] = None ,
213
+ ) -> Tuple [Dict [str , Union [ Sequence , np . ndarray ] ], Sequence [Optional [str ]]]:
212
214
"""Determines coordinate values from data or the model (via ``dims``)."""
213
215
if coords is None :
214
216
coords = {}
215
217
218
+ dim_name = None
216
219
# If value is a df or a series, we interpret the index as coords:
217
220
if hasattr (value , "index" ):
218
- dim_name = None
219
221
if dims is not None :
220
222
dim_name = dims [0 ]
221
223
if dim_name is None and value .index .name is not None :
@@ -225,14 +227,20 @@ def determine_coords(
225
227
226
228
# If value is a df, we also interpret the columns as coords:
227
229
if hasattr (value , "columns" ):
228
- dim_name = None
229
230
if dims is not None :
230
231
dim_name = dims [1 ]
231
232
if dim_name is None and value .columns .name is not None :
232
233
dim_name = value .columns .name
233
234
if dim_name is not None :
234
235
coords [dim_name ] = value .columns
235
236
237
+ if isinstance (value , xr .DataArray ):
238
+ if dims is not None :
239
+ for dim in dims :
240
+ dim_name = dim
241
+ # ``dims`` entries may be None, so str() is applied to keep every coords key a string
242
+ coords [str (dim_name )] = value [dim ].to_numpy ()
243
+
236
244
if isinstance (value , np .ndarray ) and dims is not None :
237
245
if len (dims ) != value .ndim :
238
246
raise pm .exceptions .ShapeError (
@@ -257,21 +265,29 @@ def ConstantData(
257
265
value ,
258
266
* ,
259
267
dims : Optional [Sequence [str ]] = None ,
260
- coords : Optional [Dict [str , Sequence ]] = None ,
268
+ coords : Optional [Dict [str , Union [ Sequence , np . ndarray ] ]] = None ,
261
269
export_index_as_coords = False ,
270
+ infer_dims_and_coords = False ,
262
271
** kwargs ,
263
272
) -> TensorConstant :
264
273
"""Alias for ``pm.Data(..., mutable=False)``.
265
274
266
275
Registers the ``value`` as a :class:`~pytensor.tensor.TensorConstant` with the model.
267
276
For more information, please reference :class:`pymc.Data`.
268
277
"""
278
+ if export_index_as_coords :
279
+ infer_dims_and_coords = export_index_as_coords
280
+ warnings .warn (
281
+ "Deprecation warning: 'export_index_as_coords; is deprecated and will be removed in future versions. Please use 'infer_dims_and_coords' instead." ,
282
+ DeprecationWarning ,
283
+ )
284
+
269
285
var = Data (
270
286
name ,
271
287
value ,
272
288
dims = dims ,
273
289
coords = coords ,
274
- export_index_as_coords = export_index_as_coords ,
290
+ infer_dims_and_coords = infer_dims_and_coords ,
275
291
mutable = False ,
276
292
** kwargs ,
277
293
)
@@ -283,21 +299,29 @@ def MutableData(
283
299
value ,
284
300
* ,
285
301
dims : Optional [Sequence [str ]] = None ,
286
- coords : Optional [Dict [str , Sequence ]] = None ,
302
+ coords : Optional [Dict [str , Union [ Sequence , np . ndarray ] ]] = None ,
287
303
export_index_as_coords = False ,
304
+ infer_dims_and_coords = False ,
288
305
** kwargs ,
289
306
) -> SharedVariable :
290
307
"""Alias for ``pm.Data(..., mutable=True)``.
291
308
292
309
Registers the ``value`` as a :class:`~pytensor.compile.sharedvalue.SharedVariable`
293
310
with the model. For more information, please reference :class:`pymc.Data`.
294
311
"""
312
+ if export_index_as_coords :
313
+ infer_dims_and_coords = export_index_as_coords
314
+ warnings .warn (
315
+ "Deprecation warning: 'export_index_as_coords; is deprecated and will be removed in future versions. Please use 'infer_dims_and_coords' instead." ,
316
+ DeprecationWarning ,
317
+ )
318
+
295
319
var = Data (
296
320
name ,
297
321
value ,
298
322
dims = dims ,
299
323
coords = coords ,
300
- export_index_as_coords = export_index_as_coords ,
324
+ infer_dims_and_coords = infer_dims_and_coords ,
301
325
mutable = True ,
302
326
** kwargs ,
303
327
)
@@ -309,8 +333,9 @@ def Data(
309
333
value ,
310
334
* ,
311
335
dims : Optional [Sequence [str ]] = None ,
312
- coords : Optional [Dict [str , Sequence ]] = None ,
336
+ coords : Optional [Dict [str , Union [ Sequence , np . ndarray ] ]] = None ,
313
337
export_index_as_coords = False ,
338
+ infer_dims_and_coords = False ,
314
339
mutable : Optional [bool ] = None ,
315
340
** kwargs ,
316
341
) -> Union [SharedVariable , TensorConstant ]:
@@ -347,7 +372,9 @@ def Data(
347
372
names.
348
373
coords : dict, optional
349
374
Coordinate values to set for new dimensions introduced by this ``Data`` variable.
350
- export_index_as_coords : bool, default=False
375
+ export_index_as_coords : bool
376
+ Deprecated alias of ``infer_dims_and_coords``; use ``infer_dims_and_coords`` instead.
377
+ infer_dims_and_coords : bool, default=False
351
378
If True, the ``Data`` container will try to infer what the coordinates
352
379
and dimension names should be if there is an index in ``value``.
353
380
mutable : bool, optional
@@ -427,6 +454,13 @@ def Data(
427
454
428
455
# Optionally infer coords and dims from the input value.
429
456
if export_index_as_coords :
457
+ infer_dims_and_coords = export_index_as_coords
458
+ warnings .warn (
459
+ "Deprecation warning: 'export_index_as_coords; is deprecated and will be removed in future versions. Please use 'infer_dims_and_coords' instead." ,
460
+ DeprecationWarning ,
461
+ )
462
+
463
+ if infer_dims_and_coords :
430
464
coords , dims = determine_coords (model , value , dims )
431
465
432
466
if dims :
0 commit comments