Skip to content

Commit 25a9e3d

Browse files
HyukjinKwon
authored and
vpolet
committed
[SPARK-44717][PYTHON][PS] Respect TimestampNTZ in resampling
### What changes were proposed in this pull request?

This PR proposes to respect `TimestampNTZ` type in resampling at pandas API on Spark.

### Why are the changes needed?

It still operates as if the timestamps are `TIMESTAMP_LTZ` even when `spark.sql.timestampType` is set to `TIMESTAMP_NTZ`, which is unexpected.

### Does this PR introduce _any_ user-facing change?

This fixes a bug so end users can use exactly the same behaviour as pandas with `TimestampNTZType` - pandas does not respect the local timezone with DST. While we might need to follow this even for `TimestampType`, this PR does not address that case as it might be controversial.

### How was this patch tested?

Unittest was added.

Closes apache#42392 from HyukjinKwon/SPARK-44717.

Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent c53d4e0 commit 25a9e3d

File tree

4 files changed

+88
-18
lines changed

4 files changed

+88
-18
lines changed

python/pyspark/pandas/frame.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -13155,7 +13155,9 @@ def resample(
1315513155

1315613156
if on is None and not isinstance(self.index, DatetimeIndex):
1315713157
raise NotImplementedError("resample currently works only for DatetimeIndex")
13158-
if on is not None and not isinstance(as_spark_type(on.dtype), TimestampType):
13158+
if on is not None and not isinstance(
13159+
as_spark_type(on.dtype), (TimestampType, TimestampNTZType)
13160+
):
1315913161
raise NotImplementedError("`on` currently works only for TimestampType")
1316013162

1316113163
agg_columns: List[ps.Series] = []

python/pyspark/pandas/resample.py

+28-15
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@
4646
from pyspark.sql.types import (
4747
NumericType,
4848
StructField,
49-
TimestampType,
49+
TimestampNTZType,
50+
DataType,
5051
)
5152

5253
from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm.
@@ -130,6 +131,13 @@ def _resamplekey_scol(self) -> Column:
130131
else:
131132
return self._resamplekey.spark.column
132133

134+
@property
135+
def _resamplekey_type(self) -> DataType:
136+
if self._resamplekey is None:
137+
return self._psdf.index.spark.data_type
138+
else:
139+
return self._resamplekey.spark.data_type
140+
133141
@property
134142
def _agg_columns_scols(self) -> List[Column]:
135143
return [s.spark.column for s in self._agg_columns]
@@ -154,7 +162,8 @@ def get_make_interval( # type: ignore[return]
154162
col = col._jc if isinstance(col, Column) else F.lit(col)._jc
155163
return sql_utils.makeInterval(unit, col)
156164

157-
def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
165+
def _bin_timestamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
166+
key_type = self._resamplekey_type
158167
origin_scol = F.lit(origin)
159168
(rule_code, n) = (self._offset.rule_code, getattr(self._offset, "n"))
160169
left_closed, right_closed = (self._closed == "left", self._closed == "right")
@@ -188,7 +197,7 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
188197
F.year(ts_scol) - (mod - n)
189198
)
190199

191-
return F.to_timestamp(
200+
ret = F.to_timestamp(
192201
F.make_date(
193202
F.when(edge_cond, edge_label).otherwise(non_edge_label), F.lit(12), F.lit(31)
194203
)
@@ -227,7 +236,7 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
227236
truncated_ts_scol - self.get_make_interval("MONTH", mod - n)
228237
)
229238

230-
return F.to_timestamp(
239+
ret = F.to_timestamp(
231240
F.last_day(F.when(edge_cond, edge_label).otherwise(non_edge_label))
232241
)
233242

@@ -242,15 +251,15 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
242251
)
243252

244253
if left_closed and left_labeled:
245-
return F.date_trunc("DAY", ts_scol)
254+
ret = F.date_trunc("DAY", ts_scol)
246255
elif left_closed and right_labeled:
247-
return F.date_trunc("DAY", F.date_add(ts_scol, 1))
256+
ret = F.date_trunc("DAY", F.date_add(ts_scol, 1))
248257
elif right_closed and left_labeled:
249-
return F.when(edge_cond, F.date_trunc("DAY", F.date_sub(ts_scol, 1))).otherwise(
258+
ret = F.when(edge_cond, F.date_trunc("DAY", F.date_sub(ts_scol, 1))).otherwise(
250259
F.date_trunc("DAY", ts_scol)
251260
)
252261
else:
253-
return F.when(edge_cond, F.date_trunc("DAY", ts_scol)).otherwise(
262+
ret = F.when(edge_cond, F.date_trunc("DAY", ts_scol)).otherwise(
254263
F.date_trunc("DAY", F.date_add(ts_scol, 1))
255264
)
256265

@@ -272,13 +281,15 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
272281
else:
273282
non_edge_label = F.date_sub(truncated_ts_scol, mod - n)
274283

275-
return F.when(edge_cond, edge_label).otherwise(non_edge_label)
284+
ret = F.when(edge_cond, edge_label).otherwise(non_edge_label)
276285

277286
elif rule_code in ["H", "T", "S"]:
278287
unit_mapping = {"H": "HOUR", "T": "MINUTE", "S": "SECOND"}
279288
unit_str = unit_mapping[rule_code]
280289

281290
truncated_ts_scol = F.date_trunc(unit_str, ts_scol)
291+
if isinstance(key_type, TimestampNTZType):
292+
truncated_ts_scol = F.to_timestamp_ntz(truncated_ts_scol)
282293
diff = timestampdiff(unit_str, origin_scol, truncated_ts_scol)
283294
mod = F.lit(0) if n == 1 else (diff % F.lit(n))
284295

@@ -307,11 +318,16 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
307318
truncated_ts_scol + self.get_make_interval(unit_str, n),
308319
).otherwise(truncated_ts_scol - self.get_make_interval(unit_str, mod - n))
309320

310-
return F.when(edge_cond, edge_label).otherwise(non_edge_label)
321+
ret = F.when(edge_cond, edge_label).otherwise(non_edge_label)
311322

312323
else:
313324
raise ValueError("Got the unexpected unit {}".format(rule_code))
314325

326+
if isinstance(key_type, TimestampNTZType):
327+
return F.to_timestamp_ntz(ret)
328+
else:
329+
return ret
330+
315331
def _downsample(self, f: str) -> DataFrame:
316332
"""
317333
Downsample the defined function.
@@ -374,12 +390,9 @@ def _downsample(self, f: str) -> DataFrame:
374390
bin_col_label = verify_temp_column_name(self._psdf, bin_col_name)
375391
bin_col_field = InternalField(
376392
dtype=np.dtype("datetime64[ns]"),
377-
struct_field=StructField(bin_col_name, TimestampType(), True),
378-
)
379-
bin_scol = self._bin_time_stamp(
380-
ts_origin,
381-
self._resamplekey_scol,
393+
struct_field=StructField(bin_col_name, self._resamplekey_type, True),
382394
)
395+
bin_scol = self._bin_timestamp(ts_origin, self._resamplekey_scol)
383396

384397
agg_columns = [
385398
psser for psser in self._agg_columns if (isinstance(psser.spark.data_type, NumericType))

python/pyspark/pandas/tests/connect/test_parity_resample.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,25 @@
1616
#
1717
import unittest
1818

19-
from pyspark.pandas.tests.test_resample import ResampleTestsMixin
19+
from pyspark.pandas.tests.test_resample import ResampleTestsMixin, ResampleWithTimezoneMixin
2020
from pyspark.testing.connectutils import ReusedConnectTestCase
2121
from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
2222

2323

24-
class ResampleTestsParityMixin(
24+
class ResampleParityTests(
2525
ResampleTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
2626
):
2727
pass
2828

2929

30+
class ResampleWithTimezoneTests(
31+
ResampleWithTimezoneMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
32+
):
33+
@unittest.skip("SPARK-44731: Support 'spark.sql.timestampType' in Python Spark Connect client")
34+
def test_series_resample_with_timezone(self):
35+
super().test_series_resample_with_timezone()
36+
37+
3038
if __name__ == "__main__":
3139
from pyspark.pandas.tests.connect.test_parity_resample import * # noqa: F401
3240

python/pyspark/pandas/tests/test_resample.py

+47
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import unittest
2020
import inspect
2121
import datetime
22+
import os
23+
2224
import numpy as np
2325
import pandas as pd
2426

@@ -283,10 +285,55 @@ def test_resample_on(self):
283285
)
284286

285287

288+
class ResampleWithTimezoneMixin:
289+
timezone = None
290+
291+
@classmethod
292+
def setUpClass(cls):
293+
cls.timezone = os.environ.get("TZ", None)
294+
os.environ["TZ"] = "America/New_York"
295+
super(ResampleWithTimezoneMixin, cls).setUpClass()
296+
297+
@classmethod
298+
def tearDownClass(cls):
299+
super(ResampleWithTimezoneMixin, cls).tearDownClass()
300+
if cls.timezone is not None:
301+
os.environ["TZ"] = cls.timezone
302+
303+
@property
304+
def pdf(self):
305+
np.random.seed(22)
306+
index = pd.date_range(start="2011-01-02", end="2022-05-01", freq="1D")
307+
return pd.DataFrame(np.random.rand(len(index), 2), index=index, columns=list("AB"))
308+
309+
@property
310+
def psdf(self):
311+
return ps.from_pandas(self.pdf)
312+
313+
def test_series_resample_with_timezone(self):
314+
with self.sql_conf(
315+
{
316+
"spark.sql.session.timeZone": "Asia/Seoul",
317+
"spark.sql.timestampType": "TIMESTAMP_NTZ",
318+
}
319+
):
320+
p_resample = self.pdf.resample(rule="1001H", closed="right", label="right")
321+
ps_resample = self.psdf.resample(rule="1001H", closed="right", label="right")
322+
self.assert_eq(
323+
p_resample.sum().sort_index(),
324+
ps_resample.sum().sort_index(),
325+
almost=True,
326+
)
327+
328+
286329
class ResampleTests(ResampleTestsMixin, PandasOnSparkTestCase, TestUtils):
287330
pass
288331

289332

333+
class ResampleWithTimezoneTests(ResampleWithTimezoneMixin, PandasOnSparkTestCase, TestUtils):
334+
pass
335+
336+
290337
if __name__ == "__main__":
291338
from pyspark.pandas.tests.test_resample import * # noqa: F401
292339

0 commit comments

Comments
 (0)