Skip to content

Commit 287817a

Browse files
move check for datetime tz to hashing function
1 parent 3bd0404 commit 287817a

File tree

3 files changed

+20
-7
lines changed

3 files changed

+20
-7
lines changed

pandas/core/dtypes/cast.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def maybe_promote(dtype, fill_value=np.nan):
333333
return dtype, fill_value
334334

335335

336-
def infer_dtype_from_scalar(val, pandas_dtype=False, use_datetimetz=True):
336+
def infer_dtype_from_scalar(val, pandas_dtype=False):
337337
"""
338338
interpret the dtype from a scalar
339339
@@ -368,7 +368,7 @@ def infer_dtype_from_scalar(val, pandas_dtype=False, use_datetimetz=True):
368368

369369
elif isinstance(val, (np.datetime64, datetime)):
370370
val = tslib.Timestamp(val)
371-
if val is tslib.NaT or val.tz is None or not use_datetimetz:
371+
if val is tslib.NaT or val.tz is None:
372372
dtype = np.dtype('M8[ns]')
373373
else:
374374
if pandas_dtype:

pandas/core/util/hashing.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import itertools
55

66
import numpy as np
7-
from pandas._libs import hashing
7+
from pandas._libs import hashing, tslib
88
from pandas.core.dtypes.generic import (
99
ABCMultiIndex,
1010
ABCIndexClass,
@@ -317,7 +317,15 @@ def _hash_scalar(val, encoding='utf8', hash_key=None):
317317
# this is to be consistent with the _hash_categorical implementation
318318
return np.array([np.iinfo(np.uint64).max], dtype='u8')
319319

320-
dtype, val = infer_dtype_from_scalar(val, use_datetimetz=False)
320+
if getattr(val, 'tzinfo', None) is not None:
321+
# for tz-aware datetimes, we need the underlying naive UTC value and
322+
# not the tz aware object or pd extension type (as
323+
# infer_dtype_from_scalar would do)
324+
if not isinstance(val, tslib.Timestamp):
325+
val = tslib.Timestamp(val)
326+
val = val.tz_convert(None)
327+
328+
dtype, val = infer_dtype_from_scalar(val)
321329
vals = np.array([val], dtype=dtype)
322330

323331
return hash_array(vals, hash_key=hash_key, encoding=encoding,

pandas/tests/util/test_hashing.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import datetime
23

34
from warnings import catch_warnings
45
import numpy as np
@@ -81,16 +82,20 @@ def test_hash_tuples(self):
8182

8283
def test_hash_tuple(self):
8384
# test equivalence between hash_tuples and hash_tuple
84-
for tup in [(1, 'one'), (1, np.nan), (1.0, pd.NaT, 'A')]:
85+
for tup in [(1, 'one'), (1, np.nan), (1.0, pd.NaT, 'A'),
86+
('A', pd.Timestamp("2012-01-01"))]:
8587
result = hash_tuple(tup)
8688
expected = hash_tuples([tup])[0]
8789
assert result == expected
8890

8991
def test_hash_scalar(self):
9092
for val in [1, 1.4, 'A', b'A', u'A', pd.Timestamp("2012-01-01"),
9193
pd.Timestamp("2012-01-01", tz='Europe/Brussels'),
92-
pd.Period('2012-01-01', freq='D'), pd.Timedelta('1 days'),
93-
pd.Interval(0, 1), np.nan, pd.NaT, None]:
94+
datetime.datetime(2012, 1, 1),
95+
pd.Timestamp("2012-01-01", tz='EST').to_pydatetime(),
96+
pd.Timedelta('1 days'), datetime.timedelta(1),
97+
pd.Period('2012-01-01', freq='D'), pd.Interval(0, 1),
98+
np.nan, pd.NaT, None]:
9499
result = _hash_scalar(val)
95100
expected = hash_array(np.array([val], dtype=object),
96101
categorize=True)

0 commit comments

Comments
 (0)