Skip to content

Commit f120d65

Browse files
committed
Keep dtype whenever possible; add _update_array; docstring fixes
1 parent 07784f0 commit f120d65

File tree

5 files changed

+120
-29
lines changed

5 files changed

+120
-29
lines changed

pandas/core/generic.py

+36-15
Original file line numberDiff line numberDiff line change
@@ -4178,7 +4178,10 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
41784178
"""
41794179
Modify in place using non-NA values from another DataFrame.
41804180
4181-
Aligns on indices. There is no return value.
4181+
Series/DataFrame will be aligned on indexes, and whenever possible,
4182+
the dtype of the individual Series of the caller will be preserved.
4183+
4184+
There is no return value.
41824185
41834186
Parameters
41844187
----------
@@ -4198,7 +4201,7 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
41984201
* False: only update values that are NA in
41994202
the original DataFrame.
42004203
4201-
filter_func : callable(1d-array) -> boolean 1d-array, optional
4204+
filter_func : callable(1d-array) -> bool 1d-array, optional
42024205
Can choose to replace values other than NA. Return True for values
42034206
that should be updated.
42044207
errors : {'raise', 'ignore'}, default 'ignore'
@@ -4208,7 +4211,7 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
42084211
Raises
42094212
------
42104213
ValueError
4211-
When `raise_conflict` is True and there's overlapping non-NA data.
4214+
When `errors='ignore'` and there's overlapping non-NA data.
42124215
42134216
Returns
42144217
-------
@@ -4275,10 +4278,10 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
42754278
>>> new_df = pd.DataFrame({'B': [4, np.nan, 6]})
42764279
>>> df.update(new_df)
42774280
>>> df
4278-
A B
4279-
0 1 4.0
4280-
1 2 500.0
4281-
2 3 6.0
4281+
A B
4282+
0 1 4
4283+
1 2 500
4284+
2 3 6
42824285
"""
42834286
from pandas import Series, DataFrame
42844287
# TODO: Support other joins
@@ -4292,14 +4295,20 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
42924295
this = self.values
42934296
that = other.values
42944297

4295-
# missing.update_array returns an np.ndarray
4296-
updated_values = missing.update_array(this, that,
4298+
# will return None if "this" remains unchanged
4299+
updated_array = missing._update_array(this, that,
42974300
overwrite=overwrite,
42984301
filter_func=filter_func,
42994302
errors=errors)
43004303
# don't overwrite unnecessarily
4301-
if updated_values is not None:
4302-
self._update_inplace(Series(updated_values, index=self.index))
4304+
if updated_array is not None:
4305+
# avoid unnecessary upcasting (introduced by alignment)
4306+
try:
4307+
updated = Series(updated_array, index=self.index,
4308+
dtype=this.dtype)
4309+
except ValueError:
4310+
updated = Series(updated_array, index=self.index)
4311+
self._update_inplace(updated)
43034312
else: # DataFrame
43044313
if not isinstance(other, ABCDataFrame):
43054314
other = DataFrame(other)
@@ -4310,11 +4319,23 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
43104319
this = self[col].values
43114320
that = other[col].values
43124321

4313-
updated = missing.update_array(this, that, overwrite=overwrite,
4314-
filter_func=filter_func,
4315-
errors=errors)
4322+
# will return None if "this" remains unchanged
4323+
updated_array = missing._update_array(this, that,
4324+
overwrite=overwrite,
4325+
filter_func=filter_func,
4326+
errors=errors)
43164327
# don't overwrite unnecessarily
4317-
if updated is not None:
4328+
if updated_array is not None:
4329+
# no problem to set DataFrame column with array
4330+
updated = updated_array
4331+
4332+
if updated_array.dtype != this.dtype:
4333+
# avoid unnecessary upcasting (introduced by alignment)
4334+
try:
4335+
updated = Series(updated_array, index=self.index,
4336+
dtype=this.dtype)
4337+
except ValueError:
4338+
pass
43184339
self[col] = updated
43194340

43204341
def filter(self, items=None, like=None, regex=None, axis=None):

pandas/core/missing.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,25 @@ def update_array(this, that, overwrite=True, filter_func=None,
106106
107107
Returns
108108
-------
109-
updated : np.ndarray (one-dimensional) or None
110-
The updated array. Return None if `this` remains unchanged
109+
updated : np.ndarray (one-dimensional)
110+
The updated array.
111111
112112
See Also
113113
--------
114114
Series.update : Similar method for `Series`.
115115
DataFrame.update : Similar method for `DataFrame`.
116116
dict.update : Similar method for `dict`.
117117
"""
118+
updated = _update_array(this, that, overwrite=overwrite,
119+
filter_func=filter_func, errors=errors)
120+
return this if updated is None else updated
121+
122+
123+
def _update_array(this, that, overwrite=True, filter_func=None,
124+
errors='ignore'):
125+
"""
126+
Same as update_array, except we return None if `this` is not updated.
127+
"""
118128
import pandas.core.computation.expressions as expressions
119129

120130
if filter_func is not None:

pandas/core/series.py

+19-7
Original file line numberDiff line numberDiff line change
@@ -2390,7 +2390,10 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
23902390
"""
23912391
Modify Series in place using non-NA values from passed Series.
23922392
2393-
Aligns on index.
2393+
Series will be aligned on indexes, and whenever possible, the dtype of
2394+
the caller will be preserved.
2395+
2396+
There is no return value.
23942397
23952398
Parameters
23962399
----------
@@ -2411,7 +2414,7 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
24112414
the original DataFrame.
24122415
24132416
.. versionadded:: 0.24.0
2414-
filter_func : callable(1d-array) -> boolean 1d-array, optional
2417+
filter_func : callable(1d-array) -> bool 1d-array, optional
24152418
Can choose to replace values other than NA. Return True for values
24162419
that should be updated.
24172420
@@ -2422,10 +2425,19 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
24222425
24232426
.. versionadded:: 0.24.0
24242427
2428+
Raises
2429+
------
2430+
ValueError
2431+
When `errors='ignore'` and there's overlapping non-NA data.
2432+
2433+
Returns
2434+
-------
2435+
Nothing, the Series is modified inplace.
2436+
24252437
See Also
24262438
--------
24272439
DataFrame.update : Similar method for `DataFrame`.
2428-
dict.update : Similar method for `dict`
2440+
dict.update : Similar method for `dict`.
24292441
24302442
Examples
24312443
--------
@@ -2459,10 +2471,10 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
24592471
>>> s = pd.Series([1, 2, 3])
24602472
>>> s.update(pd.Series([4, np.nan, 6]))
24612473
>>> s
2462-
0 4.0
2463-
1 2.0
2464-
2 6.0
2465-
dtype: float64
2474+
0 4
2475+
1 2
2476+
2 6
2477+
dtype: int64
24662478
"""
24672479
super(Series, self).update(other, join=join, overwrite=overwrite,
24682480
filter_func=filter_func,

pandas/tests/frame/test_combine_concat.py

+19
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,25 @@ def test_update_dtypes(self):
279279
columns=['A', 'B', 'bool1', 'bool2'])
280280
assert_frame_equal(df, expected)
281281

282+
df = DataFrame([[10, 100], [11, 101], [12, 102]], columns=['A', 'B'])
283+
other = DataFrame([[61, 601], [63, 603]], columns=['A', 'B'],
284+
index=[1, 3])
285+
df.update(other)
286+
287+
expected = DataFrame([[10, 100], [61, 601], [12, 102]],
288+
columns=['A', 'B'])
289+
assert_frame_equal(df, expected)
290+
291+
# we always try to keep original dtype, even if other has different one
292+
df.update(other.astype(float))
293+
assert_frame_equal(df, expected)
294+
295+
# if keeping the dtype is not possible, we allow upcasting
296+
df.update(other + 0.1)
297+
expected = DataFrame([[10., 100.], [61.1, 601.1], [12., 102.]],
298+
columns=['A', 'B'])
299+
assert_frame_equal(df, expected)
300+
282301
def test_update_nooverwrite(self):
283302
df = DataFrame([[1.5, nan, 3.],
284303
[1.5, nan, 3.],

pandas/tests/series/test_combine_concat.py

+34-5
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pandas as pd
1111
from pandas import DataFrame, DatetimeIndex, Series, compat, date_range
1212
import pandas.util.testing as tm
13-
from pandas.util.testing import assert_series_equal
13+
from pandas.util.testing import assert_series_equal, assert_frame_equal
1414

1515

1616
class TestSeriesCombine():
@@ -105,8 +105,8 @@ def test_combine_first(self):
105105
assert_series_equal(s, result)
106106

107107
def test_update(self):
108-
s = Series([1.5, nan, 3., 4., nan])
109-
s2 = Series([nan, 3.5, nan, 5.])
108+
s = Series([1.5, np.nan, 3., 4., np.nan])
109+
s2 = Series([np.nan, 3.5, np.nan, 5.])
110110
s.update(s2)
111111

112112
expected = Series([1.5, 3.5, 3., 5., np.nan])
@@ -116,8 +116,35 @@ def test_update(self):
116116
df = DataFrame([{"a": 1}, {"a": 3, "b": 2}])
117117
df['c'] = np.nan
118118

119-
# this will fail as long as series is a sub-class of ndarray
120-
# df['c'].update(Series(['foo'],index=[0])) #####
119+
df['c'].update(Series(['foo'], index=[0]))
120+
expected = DataFrame([[1, np.nan, 'foo'], [3, 2., np.nan]],
121+
columns=['a', 'b', 'c'])
122+
assert_frame_equal(df, expected)
123+
124+
def test_update_dtypes(self):
125+
s = Series([1., 2., False, True])
126+
127+
other = Series([45])
128+
s.update(other)
129+
130+
expected = Series([45., 2., False, True])
131+
assert_series_equal(s, expected)
132+
133+
s = Series([10, 11, 12])
134+
other = Series([61, 63], index=[1, 3])
135+
s.update(other)
136+
137+
expected = Series([10, 61, 12])
138+
assert_series_equal(s, expected)
139+
140+
# we always try to keep original dtype, even if other has different one
141+
s.update(other.astype(float))
142+
assert_series_equal(s, expected)
143+
144+
# if keeping the dtype is not possible, we allow upcasting
145+
s.update(other + 0.1)
146+
expected = Series([10., 61.1, 12.])
147+
assert_series_equal(s, expected)
121148

122149
def test_update_nooverwrite(self):
123150
s = Series([0, 1, 2, np.nan, np.nan, 5, 6, np.nan])
@@ -129,6 +156,8 @@ def test_update_nooverwrite(self):
129156
assert_series_equal(s, expected)
130157

131158
def test_update_filtered(self):
159+
# for small values, np.arange defaults to int32,
160+
# but pandas default (e.g. for "expected" below) is int64
132161
s = Series(np.arange(8), dtype='int64')
133162
other = Series(np.arange(8), dtype='int64') + 10
134163

0 commit comments

Comments
 (0)