diff --git a/pymc/data.py b/pymc/data.py index c03461cc5b..356596ea87 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -485,6 +485,8 @@ def determine_coords( dim_name = value.index.name if dim_name is not None: coords[dim_name] = value.index + if dims is None: + dims = dims_name # If value is a df, we also interpret the columns as coords: if hasattr(value, "columns"): @@ -495,7 +497,31 @@ def determine_coords( dim_name = value.columns.name if dim_name is not None: coords[dim_name] = value.columns - + if dims is None: + dims = dims_name + + if isinstance(value, xr.DataArray): + dim_name = None + if dims is not None: + dim_name = dims[0] + if dim_name is None and value.dims[0] is not None: + dim_name = value.dims[0] + if dim_name is not None: + coords[dim_name] = dff.indexes.get(str(dff.dims[0])).values + if dims is None: + dims = dims_name + + if isinstance(value, xr.DataArray): + dim_name = None + if dims is not None: + dim_name = dims[1] + if dim_name is None and value.dims[1] is not None: + dim_name = value.dims[1] + if dim_name is not None: + coords[dim_name] = dff.indexes.get(str(dff.dims[1])).values + if dims is None: + dims = dims_name + if isinstance(value, np.ndarray) and dims is not None: if len(dims) != value.ndim: raise pm.exceptions.ShapeError( @@ -506,11 +532,7 @@ def determine_coords( for size, dim in zip(value.shape, dims): coord = model.coords.get(dim, None) if coord is None and dim is not None: - coords[dim] = range(size) - - if dims is None: - # TODO: Also determine dim names from the index - dims = [None] * np.ndim(value) + coords[dim] = range(size) return coords, dims