Skip to content

Commit 3c6eff4

Browse files
Progress: 2608 pass, 97 skip, 84 xfail, 6 xpass
With these changes (and pandas-dev/pandas#53970 and hgrecco/pint#1615) the test suite passes or xpasses everything (no failures or errors). Indeed, the uncertainties code has essentially doubled the scope of the test suite (to test with and without it). The biggest gotcha is that the EA for complex numbers is not compatible with the EA for uncertainties, due to incompatible hacks: The hack for complex numbers is to use np.nan (which is, technically, a complex number) for na_value across all numeric types. But that doesn't work for uncertainties, because uncertainties doesn't accept np.nan as an uncertain value. The hack for uncertainties is to use pd.NA for na_value. This works for Int64, Float64, and uncertainties, but doesn't work for complex (which cannot tolerate NAType). Some careful subclassing fills in what doesn't easily work, with fixtures to prevent the improper mixing of complex and uncertainty types in the same Python environment. Happy to discuss! Signed-off-by: Michael Tiemann <[email protected]>
1 parent dbf5ad1 commit 3c6eff4

File tree

3 files changed

+235
-124
lines changed

3 files changed

+235
-124
lines changed

pint_pandas/pint_array.py

+145-53
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __new__(cls, units=None):
7575
if not isinstance(units, _Unit):
7676
units = cls._parse_dtype_strict(units)
7777
# ureg.unit returns a quantity with a magnitude of 1
78-
# eg 1 mm. Initialising a quantity and taking it's unit
78+
# eg 1 mm. Initialising a quantity and taking its unit
7979
# TODO: Separate units from quantities in pint
8080
# to simplify this bit
8181
units = cls.ureg.Quantity(1, units).units
@@ -148,7 +148,10 @@ def name(self):
148148

149149
@property
150150
def na_value(self):
151-
return self.ureg.Quantity(pd.NA, self.units)
151+
if HAS_UNCERTAINTIES:
152+
return self.ureg.Quantity(pd.NA, self.units)
153+
else:
154+
return self.ureg.Quantity(np.nan, self.units)
152155

153156
def __hash__(self):
154157
# make myself hashable
@@ -318,12 +321,41 @@ def __setitem__(self, key, value):
318321
# doing nothing here seems to be ok
319322
return
320323

324+
master_scalar = None
325+
try:
326+
master_scalar = next(i for i in self._data if pd.notna(i))
327+
except StopIteration:
328+
pass
329+
321330
if isinstance(value, _Quantity):
322331
value = value.to(self.units).magnitude
323-
elif is_list_like(value) and len(value) > 0 and isinstance(value[0], _Quantity):
324-
value = [item.to(self.units).magnitude for item in value]
332+
elif is_list_like(value) and len(value) > 0:
333+
if isinstance(value[0], _Quantity):
334+
value = [item.to(self.units).magnitude for item in value]
335+
elif HAS_UNCERTAINTIES and isinstance(master_scalar, UFloat):
336+
if not all([isinstance(i, UFloat) or pd.isna(i) for i in value]):
337+
value = [
338+
i if isinstance(i, UFloat) or pd.isna(i) else ufloat(i, 0)
339+
for i in value
340+
]
341+
if len(value) == 1:
342+
value = value[0]
325343

326344
key = check_array_indexer(self, key)
345+
# Filter out invalid values for our array type(s)
346+
if HAS_UNCERTAINTIES:
347+
if isinstance(value, UFloat):
348+
pass
349+
elif is_list_like(value):
350+
from pandas.core.dtypes.common import is_scalar
351+
352+
if is_scalar(key):
353+
msg = "Value must be scalar. {}".format(value)
354+
raise ValueError(msg)
355+
elif type(value) is object:
356+
if pd.notna(value):
357+
msg = "Invalid object. {}".format(value)
358+
raise ValueError(msg)
327359
try:
328360
self._data[key] = value
329361
except IndexError as e:
@@ -535,45 +567,24 @@ def _from_sequence(cls, scalars, dtype=None, copy=False):
535567
if dtype is None and isinstance(master_scalar, _Quantity):
536568
dtype = PintType(master_scalar.units)
537569

538-
def quantify_nan(item, promote_to_ufloat):
539-
if pd.isna(item):
540-
return dtype.ureg.Quantity(item, dtype.units)
541-
# FIXME: most of this code is never executed (except the final return)
542-
if promote_to_ufloat:
543-
if type(item) is UFloat:
544-
return item * dtype.units
545-
if type(item) is float:
546-
if np.isnan(item):
547-
return _ufloat_nan * dtype.units
548-
else:
549-
return UFloat(item, 0) * dtype.units
550-
else:
551-
if type(item) is float:
552-
return item * dtype.units
553-
return item
554-
555570
if isinstance(master_scalar, _Quantity):
556-
# A quantified master_scalar does not guarantee that we don't have NA and/or np.nan values in our scalars
557-
if HAS_UNCERTAINTIES:
558-
promote_to_ufloat = any(
559-
[isinstance(item.m, UFloat) for item in scalars if pd.notna(item)]
560-
)
561-
else:
562-
promote_to_ufloat = False
563-
scalars = [
564-
item
565-
if isinstance(item, _Quantity)
566-
else quantify_nan(item, promote_to_ufloat)
567-
for item in scalars
568-
]
571+
promote_to_ufloat = False
569572
scalars = [
570573
(item.to(dtype.units).magnitude if hasattr(item, "to") else item)
571574
for item in scalars
572575
]
573576
elif HAS_UNCERTAINTIES:
574-
promote_to_ufloat = any([isinstance(item, UFloat) for item in scalars])
577+
# When creating empty arrays, make them large enough to hold UFloats in case we need to do so later
578+
if len(scalars) == 0:
579+
promote_to_ufloat = True
580+
else:
581+
promote_to_ufloat = any([isinstance(item, UFloat) for item in scalars])
575582
else:
576583
promote_to_ufloat = False
584+
if len(scalars) == 0:
585+
if promote_to_ufloat:
586+
return cls([_ufloat_nan], dtype=dtype, copy=copy)[1:]
587+
return cls(scalars, dtype=dtype, copy=copy)
577588
if promote_to_ufloat:
578589
scalars = [
579590
item
@@ -639,6 +650,10 @@ def factorize(
639650
# Complete control over factorization.
640651
if HAS_UNCERTAINTIES and self._data.dtype.kind == "O":
641652
arr, na_value = self._values_for_factorize()
653+
# Unique elements make it easy to partition on na_value if we need to
654+
arr_list = list(dict.fromkeys(arr))
655+
na_index = len(arr_list)
656+
arr = np.array(arr_list)
642657

643658
if not use_na_sentinel:
644659
# factorize can now handle differentiating various types of null values.
@@ -649,36 +664,51 @@ def factorize(
649664
if null_mask.any():
650665
# Don't modify (potentially user-provided) array
651666
arr = np.where(null_mask, na_value, arr)
652-
653-
codes = [-1] * len(self.data)
654-
# Note that item is a local variable provided in the loop below
667+
else:
668+
try:
669+
na_index = arr.tolist().index(na_value)
670+
except ValueError:
671+
# Keep as len(arr)
672+
pass
673+
codes = np.array([-1] * len(self.data), dtype=np.intp)
674+
# Note: item is a local variable provided in the loop below
675+
# Note: partitioning arr on pd.NA means item is never pd.NA
655676
vf = np.vectorize(
656-
lambda x: True
657-
if (x_na := pd.isna(x)) * (item_na := pd.isna(item))
658-
else (x_na == item_na and x == item),
677+
lambda x: False if pd.isna(x) else x == item,
659678
otypes=[bool],
660679
)
661-
for code, item in enumerate(arr):
680+
for code, item in enumerate(arr[: na_index + 1]):
662681
code_mask = vf(self._data)
682+
# Don't count the NA we have seen
663683
codes = np.where(code_mask, code, codes)
664-
665-
uniques_ea = self._from_factorized(arr, self)
684+
if use_na_sentinel and na_index < len(arr):
685+
for code, item in enumerate(arr[na_index:]):
686+
code_mask = vf(self._data)
687+
# Don't count the NA we have seen
688+
codes = np.where(code_mask, code, codes)
689+
uniques_ea = self._from_factorized(
690+
arr[:na_index].tolist() + arr[na_index + 1 :].tolist(), self
691+
)
692+
else:
693+
uniques_ea = self._from_factorized(arr, self)
666694
return codes, uniques_ea
667695
else:
668-
return super(PintArray, self).factorize(self, use_na_sentinel)
696+
return super(PintArray, self).factorize(use_na_sentinel)
669697

670698
@classmethod
671699
def _from_factorized(cls, values, original):
700+
from pandas._libs.lib import infer_dtype
701+
if infer_dtype(values) != "object":
702+
values = pd.array(values, copy=False)
672703
return cls(values, dtype=original.dtype)
673704

674705
def _values_for_factorize(self):
675706
arr = self._data
676-
if HAS_UNCERTAINTIES and arr.dtype.kind == "O":
677-
unique_data = []
678-
for item in arr:
679-
if item not in unique_data:
680-
unique_data.append(item)
681-
return np.array(unique_data), pd.NA
707+
if arr.dtype.kind == "O":
708+
if HAS_UNCERTAINTIES and arr.size > 0 and isinstance(arr[0], UFloat):
709+
# Canonicalize uncertain NaNs
710+
arr = np.where(unp.isnan(arr), self.dtype.na_value.m, arr)
711+
return np.array(arr, copy=False), self.dtype.na_value.m
682712
return arr._values_for_factorize()
683713

684714
def value_counts(self, dropna=True):
@@ -706,7 +736,7 @@ def value_counts(self, dropna=True):
706736
# compute counts on the data with no nans
707737
data = self._data
708738
nafilt = data.isna()
709-
na_value = pd.NA
739+
na_value = self.dtype.na_value.m
710740
data = data[~nafilt]
711741
if HAS_UNCERTAINTIES and data.dtype.kind == "O":
712742
unique_data = []
@@ -746,6 +776,68 @@ def unique(self):
746776
)
747777
return self._from_sequence(unique(data), dtype=self.dtype)
748778

779+
def shift(self, periods: int = 1, fill_value=None):
780+
"""
781+
Shift values by desired number.
782+
783+
Newly introduced missing values are filled with
784+
a missing value type consistent with the existing elements
785+
or ``self.dtype.na_value`` if none exist.
786+
787+
Parameters
788+
----------
789+
periods : int, default 1
790+
The number of periods to shift. Negative values are allowed
791+
for shifting backwards.
792+
793+
fill_value : object, optional
794+
The scalar value to use for newly introduced missing values.
795+
The default is ``self.dtype.na_value``.
796+
797+
Returns
798+
-------
799+
ExtensionArray
800+
Shifted.
801+
802+
Notes
803+
-----
804+
If ``self`` is empty or ``periods`` is 0, a copy of ``self`` is
805+
returned.
806+
807+
If ``periods > len(self)``, then an array of size
808+
len(self) is returned, with all values filled with
809+
``self.dtype.na_value``.
810+
811+
For 2-dimensional ExtensionArrays, we are always shifting along axis=0.
812+
"""
813+
if not len(self) or periods == 0:
814+
return self.copy()
815+
816+
if pd.isna(fill_value):
817+
fill_value = self.dtype.na_value.m
818+
819+
if HAS_UNCERTAINTIES:
820+
if self.data.dtype.kind == "O":
821+
try:
822+
notna_value = next(i for i in self._data if pd.notna(i))
823+
if isinstance(notna_value, UFloat):
824+
fill_value = _ufloat_nan
825+
except StopIteration:
826+
pass
827+
elif self.data.dtype.kind == "f":
828+
fill_value = np.nan
829+
830+
empty = self._from_sequence(
831+
[fill_value] * min(abs(periods), len(self)), dtype=self.dtype
832+
)
833+
if periods > 0:
834+
a = empty
835+
b = self[:-periods]
836+
else:
837+
a = self[abs(periods) :]
838+
b = empty
839+
return self._concat_same_type([a, b])
840+
749841
def __contains__(self, item) -> bool:
750842
if not isinstance(item, _Quantity):
751843
return False
@@ -895,7 +987,7 @@ def __array__(self, dtype=None, copy=False):
895987

896988
def _to_array_of_quantity(self, copy=False):
897989
qtys = [
898-
self._Q(item, self._dtype.units) if item is not pd.NA else item
990+
self._Q(item, self._dtype.units) if item is not self.dtype.na_value.m else item
899991
for item in self._data
900992
]
901993
with warnings.catch_warnings(record=True):

pint_pandas/testsuite/test_issues.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class TestIssue21(BaseExtensionTests):
6565
def test_offset_concat(self):
6666
q_a = ureg.Quantity(np.arange(5) + ufloat(0, 0), ureg.Unit("degC"))
6767
q_b = ureg.Quantity(np.arange(6) + ufloat(0, 0), ureg.Unit("degC"))
68-
q_a_ = np.append(q_a, ureg.Quantity(ufloat(np.nan, 0), ureg.Unit("degC")))
68+
q_a_ = np.append(q_a, ureg.Quantity(pd.NA, ureg.Unit("degC")))
6969

7070
a = pd.Series(PintArray(q_a))
7171
b = pd.Series(PintArray(q_b))
@@ -179,13 +179,10 @@ def test_issue_127():
179179
assert a == b
180180

181181

182+
@pytest.mark.skipif(
183+
not HAS_UNCERTAINTIES, reason="this test depends entirely on HAS_UNCERTAINTIES being True"
184+
)
182185
def test_issue_139():
183-
from pint.compat import HAS_UNCERTAINTIES
184-
185-
assert HAS_UNCERTAINTIES
186-
from uncertainties import ufloat
187-
from uncertainties import unumpy as unp
188-
189186
q1 = 1.234
190187
q2 = 5.678
191188
q_nan = np.nan

0 commit comments

Comments
 (0)