Skip to content

Commit 0b9ab2d

Browse files
authored
Refactor nanops (#2236)
* Inhouse nanops * Cleanup nanops * remove NAT_TYPES * flake8. * another flake8 * recover nat types * remove keep_dims option from nanops (to make them compatible with numpy==1.11). * Test aggregation over multiple dimensions * Remove print. * Docs. More cleanup. * flake8 * Bug fix. Better test coverage. * using isnull, where_method. Remove unnecessary conditional branching. * More refactoring based on the comments * remove dtype from nanmedian * Fix for nanmedian * Add tests for dataset * Add tests with resample. * lint * updated whatsnew * Revise from comments. * Use .any and .all method instead of np.any / np.all * Avoid using numpy methods * Avoid casting to int for dask array * Update whatsnew
1 parent 5155ef9 commit 0b9ab2d

11 files changed

+519
-183
lines changed

doc/whats-new.rst

+9
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ Documentation
3636
Enhancements
3737
~~~~~~~~~~~~
3838

39+
- min_count option is newly supported in :py:meth:`~xarray.DataArray.sum`,
40+
:py:meth:`~xarray.DataArray.prod` and :py:meth:`~xarray.Dataset.sum`, and
41+
:py:meth:`~xarray.Dataset.prod`.
42+
(:issue:`2230`)
43+
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
44+
3945
- :py:meth:`plot()` now accepts the kwargs ``xscale, yscale, xlim, ylim, xticks, yticks`` just like Pandas. Also ``xincrease=False, yincrease=False`` now use matplotlib's axis inverting methods instead of setting limits.
4046
By `Deepak Cherian <https://github.com/dcherian>`_. (:issue:`2224`)
4147

@@ -78,6 +84,9 @@ Bug fixes
7884
- Tests can be run in parallel with pytest-xdist
7985
By `Tony Tung <https://github.com/ttung>`_.
8086

87+
- Follow up on the renaming in dask from ``dask.ghost`` to ``dask.overlap``.
88+
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
89+
8190
- Now raises a ValueError when there is a conflict between dimension names and
8291
level names of MultiIndex. (:issue:`2299`)
8392
By `Keisuke Fujii <https://github.com/fujiisoup>`_.

xarray/core/common.py

+20-19
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import warnings
44
from distutils.version import LooseVersion
5+
from textwrap import dedent
56

67
import numpy as np
78
import pandas as pd
@@ -27,20 +28,20 @@ def wrapped_func(self, dim=None, axis=None, keep_attrs=False,
2728
allow_lazy=True, **kwargs)
2829
return wrapped_func
2930

30-
_reduce_extra_args_docstring = \
31-
"""dim : str or sequence of str, optional
31+
_reduce_extra_args_docstring = dedent("""\
32+
dim : str or sequence of str, optional
3233
Dimension(s) over which to apply `{name}`.
3334
axis : int or sequence of int, optional
3435
Axis(es) over which to apply `{name}`. Only one of the 'dim'
3536
and 'axis' arguments can be supplied. If neither are supplied, then
36-
`{name}` is calculated over axes."""
37+
`{name}` is calculated over axes.""")
3738

38-
_cum_extra_args_docstring = \
39-
"""dim : str or sequence of str, optional
39+
_cum_extra_args_docstring = dedent("""\
40+
dim : str or sequence of str, optional
4041
Dimension over which to apply `{name}`.
4142
axis : int or sequence of int, optional
4243
Axis over which to apply `{name}`. Only one of the 'dim'
43-
and 'axis' arguments can be supplied."""
44+
and 'axis' arguments can be supplied.""")
4445

4546

4647
class ImplementsDatasetReduce(object):
@@ -308,12 +309,12 @@ def assign_coords(self, **kwargs):
308309
assigned : same type as caller
309310
A new object with the new coordinates in addition to the existing
310311
data.
311-
312+
312313
Examples
313314
--------
314-
315+
315316
Convert longitude coordinates from 0-359 to -180-179:
316-
317+
317318
>>> da = xr.DataArray(np.random.rand(4),
318319
... coords=[np.array([358, 359, 0, 1])],
319320
... dims='lon')
@@ -445,11 +446,11 @@ def groupby(self, group, squeeze=True):
445446
grouped : GroupBy
446447
A `GroupBy` object patterned after `pandas.GroupBy` that can be
447448
iterated over in the form of `(unique_value, grouped_array)` pairs.
448-
449+
449450
Examples
450451
--------
451452
Calculate daily anomalies for daily data:
452-
453+
453454
>>> da = xr.DataArray(np.linspace(0, 1826, num=1827),
454455
... coords=[pd.date_range('1/1/2000', '31/12/2004',
455456
... freq='D')],
@@ -465,7 +466,7 @@ def groupby(self, group, squeeze=True):
465466
Coordinates:
466467
* time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 ...
467468
dayofyear (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 ...
468-
469+
469470
See Also
470471
--------
471472
core.groupby.DataArrayGroupBy
@@ -589,7 +590,7 @@ def resample(self, freq=None, dim=None, how=None, skipna=None,
589590
closed=None, label=None, base=0, keep_attrs=False, **indexer):
590591
"""Returns a Resample object for performing resampling operations.
591592
592-
Handles both downsampling and upsampling. If any intervals contain no
593+
Handles both downsampling and upsampling. If any intervals contain no
593594
values from the original object, they will be given the value ``NaN``.
594595
595596
Parameters
@@ -616,11 +617,11 @@ def resample(self, freq=None, dim=None, how=None, skipna=None,
616617
-------
617618
resampled : same type as caller
618619
This object resampled.
619-
620+
620621
Examples
621622
--------
622623
Downsample monthly time-series data to seasonal data:
623-
624+
624625
>>> da = xr.DataArray(np.linspace(0, 11, num=12),
625626
... coords=[pd.date_range('15/12/1999',
626627
... periods=12, freq=pd.DateOffset(months=1))],
@@ -637,13 +638,13 @@ def resample(self, freq=None, dim=None, how=None, skipna=None,
637638
* time (time) datetime64[ns] 1999-12-01 2000-03-01 2000-06-01 2000-09-01
638639
639640
Upsample monthly time-series data to daily data:
640-
641+
641642
>>> da.resample(time='1D').interpolate('linear')
642643
<xarray.DataArray (time: 337)>
643644
array([ 0. , 0.032258, 0.064516, ..., 10.935484, 10.967742, 11. ])
644645
Coordinates:
645646
* time (time) datetime64[ns] 1999-12-15 1999-12-16 1999-12-17 ...
646-
647+
647648
References
648649
----------
649650
@@ -957,8 +958,8 @@ def contains_cftime_datetimes(var):
957958
sample = sample.item()
958959
return isinstance(sample, cftime_datetime)
959960
else:
960-
return False
961-
961+
return False
962+
962963

963964
def _contains_datetime_like_objects(var):
964965
"""Check if a variable contains datetime like objects (either

xarray/core/dtypes.py

+3
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ def maybe_promote(dtype):
9898
return np.dtype(dtype), fill_value
9999

100100

101+
NAT_TYPES = (np.datetime64('NaT'), np.timedelta64('NaT'))
102+
103+
101104
def get_fill_value(dtype):
102105
"""Return an appropriate fill value for this dtype.
103106

xarray/core/duck_array_ops.py

+44-142
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,6 @@
1717
from .nputils import nanfirst, nanlast
1818
from .pycompat import dask_array_type
1919

20-
try:
21-
import bottleneck as bn
22-
has_bottleneck = True
23-
except ImportError:
24-
# use numpy methods instead
25-
bn = np
26-
has_bottleneck = False
27-
2820
try:
2921
import dask.array as dask_array
3022
from . import dask_array_compat
@@ -175,7 +167,7 @@ def array_notnull_equiv(arr1, arr2):
175167
def count(data, axis=None):
176168
"""Count the number of non-NA in this array along the given axis or axes
177169
"""
178-
return sum(~isnull(data), axis=axis)
170+
return np.sum(~isnull(data), axis=axis)
179171

180172

181173
def where(condition, x, y):
@@ -213,159 +205,69 @@ def _ignore_warnings_if(condition):
213205
yield
214206

215207

216-
def _nansum_object(value, axis=None, **kwargs):
217-
""" In house nansum for object array """
218-
value = fillna(value, 0)
219-
return _dask_or_eager_func('sum')(value, axis=axis, **kwargs)
220-
221-
222-
def _nan_minmax_object(func, get_fill_value, value, axis=None, **kwargs):
223-
""" In house nanmin and nanmax for object array """
224-
fill_value = get_fill_value(value.dtype)
225-
valid_count = count(value, axis=axis)
226-
filled_value = fillna(value, fill_value)
227-
data = _dask_or_eager_func(func)(filled_value, axis=axis, **kwargs)
228-
if not hasattr(data, 'dtype'): # scalar case
229-
data = dtypes.fill_value(value.dtype) if valid_count == 0 else data
230-
return np.array(data, dtype=value.dtype)
231-
return where_method(data, valid_count != 0)
232-
233-
234-
def _nan_argminmax_object(func, get_fill_value, value, axis=None, **kwargs):
235-
""" In house nanargmin, nanargmax for object arrays. Always return integer
236-
type """
237-
fill_value = get_fill_value(value.dtype)
238-
valid_count = count(value, axis=axis)
239-
value = fillna(value, fill_value)
240-
data = _dask_or_eager_func(func)(value, axis=axis, **kwargs)
241-
# dask seems return non-integer type
242-
if isinstance(value, dask_array_type):
243-
data = data.astype(int)
244-
245-
if (valid_count == 0).any():
246-
raise ValueError('All-NaN slice encountered')
247-
248-
return np.array(data, dtype=int)
249-
250-
251-
def _nanmean_ddof_object(ddof, value, axis=None, **kwargs):
252-
""" In house nanmean. ddof argument will be used in _nanvar method """
253-
valid_count = count(value, axis=axis)
254-
value = fillna(value, 0)
255-
# As dtype inference is impossible for object dtype, we assume float
256-
# https://github.com/dask/dask/issues/3162
257-
dtype = kwargs.pop('dtype', None)
258-
if dtype is None and value.dtype.kind == 'O':
259-
dtype = value.dtype if value.dtype.kind in ['cf'] else float
260-
261-
data = _dask_or_eager_func('sum')(value, axis=axis, dtype=dtype, **kwargs)
262-
data = data / (valid_count - ddof)
263-
return where_method(data, valid_count != 0)
264-
265-
266-
def _nanvar_object(value, axis=None, **kwargs):
267-
ddof = kwargs.pop('ddof', 0)
268-
kwargs_mean = kwargs.copy()
269-
kwargs_mean.pop('keepdims', None)
270-
value_mean = _nanmean_ddof_object(ddof=0, value=value, axis=axis,
271-
keepdims=True, **kwargs_mean)
272-
squared = (value.astype(value_mean.dtype) - value_mean)**2
273-
return _nanmean_ddof_object(ddof, squared, axis=axis, **kwargs)
274-
275-
276-
_nan_object_funcs = {
277-
'sum': _nansum_object,
278-
'min': partial(_nan_minmax_object, 'min', dtypes.get_pos_infinity),
279-
'max': partial(_nan_minmax_object, 'max', dtypes.get_neg_infinity),
280-
'argmin': partial(_nan_argminmax_object, 'argmin',
281-
dtypes.get_pos_infinity),
282-
'argmax': partial(_nan_argminmax_object, 'argmax',
283-
dtypes.get_neg_infinity),
284-
'mean': partial(_nanmean_ddof_object, 0),
285-
'var': _nanvar_object,
286-
}
287-
288-
289-
def _create_nan_agg_method(name, numeric_only=False, np_compat=False,
290-
no_bottleneck=False, coerce_strings=False):
208+
def _create_nan_agg_method(name, coerce_strings=False):
209+
from . import nanops
210+
291211
def f(values, axis=None, skipna=None, **kwargs):
292212
if kwargs.pop('out', None) is not None:
293213
raise TypeError('`out` is not valid for {}'.format(name))
294214

295-
# If dtype is supplied, we use numpy's method.
296-
dtype = kwargs.get('dtype', None)
297215
values = asarray(values)
298216

299-
# dask requires dtype argument for object dtype
300-
if (values.dtype == 'object' and name in ['sum', ]):
301-
kwargs['dtype'] = values.dtype if dtype is None else dtype
302-
303217
if coerce_strings and values.dtype.kind in 'SU':
304218
values = values.astype(object)
305219

220+
func = None
306221
if skipna or (skipna is None and values.dtype.kind in 'cfO'):
307-
if values.dtype.kind not in ['u', 'i', 'f', 'c']:
308-
func = _nan_object_funcs.get(name, None)
309-
using_numpy_nan_func = True
310-
if func is None or values.dtype.kind not in 'Ob':
311-
raise NotImplementedError(
312-
'skipna=True not yet implemented for %s with dtype %s'
313-
% (name, values.dtype))
314-
else:
315-
nanname = 'nan' + name
316-
if (isinstance(axis, tuple) or not values.dtype.isnative or
317-
no_bottleneck or (dtype is not None and
318-
np.dtype(dtype) != values.dtype)):
319-
# bottleneck can't handle multiple axis arguments or
320-
# non-native endianness
321-
if np_compat:
322-
eager_module = npcompat
323-
else:
324-
eager_module = np
325-
else:
326-
kwargs.pop('dtype', None)
327-
eager_module = bn
328-
func = _dask_or_eager_func(nanname, eager_module)
329-
using_numpy_nan_func = (eager_module is np or
330-
eager_module is npcompat)
222+
nanname = 'nan' + name
223+
func = getattr(nanops, nanname)
331224
else:
332225
func = _dask_or_eager_func(name)
333-
using_numpy_nan_func = False
334-
with _ignore_warnings_if(using_numpy_nan_func):
335-
try:
336-
return func(values, axis=axis, **kwargs)
337-
except AttributeError:
338-
if isinstance(values, dask_array_type):
339-
try: # dask/dask#3133 dask sometimes needs dtype argument
340-
return func(values, axis=axis, dtype=values.dtype,
341-
**kwargs)
342-
except AttributeError:
343-
msg = '%s is not yet implemented on dask arrays' % name
344-
else:
345-
assert using_numpy_nan_func
346-
msg = ('%s is not available with skipna=False with the '
347-
'installed version of numpy; upgrade to numpy 1.12 '
348-
'or newer to use skipna=True or skipna=None' % name)
349-
raise NotImplementedError(msg)
350-
f.numeric_only = numeric_only
226+
227+
try:
228+
return func(values, axis=axis, **kwargs)
229+
except AttributeError:
230+
if isinstance(values, dask_array_type):
231+
try: # dask/dask#3133 dask sometimes needs dtype argument
232+
# if func does not accept dtype, then raises TypeError
233+
return func(values, axis=axis, dtype=values.dtype,
234+
**kwargs)
235+
except (AttributeError, TypeError):
236+
msg = '%s is not yet implemented on dask arrays' % name
237+
else:
238+
msg = ('%s is not available with skipna=False with the '
239+
'installed version of numpy; upgrade to numpy 1.12 '
240+
'or newer to use skipna=True or skipna=None' % name)
241+
raise NotImplementedError(msg)
242+
351243
f.__name__ = name
352244
return f
353245

354246

247+
# Attributes `numeric_only` and `available_min_count` are used for docs.
248+
# See ops.inject_reduce_methods
355249
argmax = _create_nan_agg_method('argmax', coerce_strings=True)
356250
argmin = _create_nan_agg_method('argmin', coerce_strings=True)
357251
max = _create_nan_agg_method('max', coerce_strings=True)
358252
min = _create_nan_agg_method('min', coerce_strings=True)
359-
sum = _create_nan_agg_method('sum', numeric_only=True)
360-
mean = _create_nan_agg_method('mean', numeric_only=True)
361-
std = _create_nan_agg_method('std', numeric_only=True)
362-
var = _create_nan_agg_method('var', numeric_only=True)
363-
median = _create_nan_agg_method('median', numeric_only=True)
364-
prod = _create_nan_agg_method('prod', numeric_only=True, no_bottleneck=True)
365-
cumprod_1d = _create_nan_agg_method(
366-
'cumprod', numeric_only=True, no_bottleneck=True)
367-
cumsum_1d = _create_nan_agg_method(
368-
'cumsum', numeric_only=True, no_bottleneck=True)
253+
sum = _create_nan_agg_method('sum')
254+
sum.numeric_only = True
255+
sum.available_min_count = True
256+
mean = _create_nan_agg_method('mean')
257+
mean.numeric_only = True
258+
std = _create_nan_agg_method('std')
259+
std.numeric_only = True
260+
var = _create_nan_agg_method('var')
261+
var.numeric_only = True
262+
median = _create_nan_agg_method('median')
263+
median.numeric_only = True
264+
prod = _create_nan_agg_method('prod')
265+
prod.numeric_only = True
266+
prod.available_min_count = True
267+
cumprod_1d = _create_nan_agg_method('cumprod')
268+
cumprod_1d.numeric_only = True
269+
cumsum_1d = _create_nan_agg_method('cumsum')
270+
cumsum_1d.numeric_only = True
369271

370272

371273
def _nd_cum_func(cum_func, array, axis, **kwargs):

0 commit comments

Comments
 (0)