46
46
from pyspark .sql .types import (
47
47
NumericType ,
48
48
StructField ,
49
- TimestampType ,
49
+ TimestampNTZType ,
50
+ DataType ,
50
51
)
51
52
52
53
from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm.
@@ -130,6 +131,13 @@ def _resamplekey_scol(self) -> Column:
130
131
else :
131
132
return self ._resamplekey .spark .column
132
133
134
@property
def _resamplekey_type(self) -> DataType:
    """
    Spark data type of the column used as the resample key.

    When no explicit resample key is set (``self._resamplekey`` is None),
    the type of the DataFrame's index is returned instead.
    """
    key_source = self._psdf.index if self._resamplekey is None else self._resamplekey
    return key_source.spark.data_type
140
+
133
141
@property
def _agg_columns_scols(self) -> List[Column]:
    """Spark columns backing each of the aggregation columns, in order."""
    spark_columns = [psser.spark.column for psser in self._agg_columns]
    return spark_columns
@@ -154,7 +162,8 @@ def get_make_interval( # type: ignore[return]
154
162
col = col ._jc if isinstance (col , Column ) else F .lit (col )._jc
155
163
return sql_utils .makeInterval (unit , col )
156
164
157
- def _bin_time_stamp (self , origin : pd .Timestamp , ts_scol : Column ) -> Column :
165
+ def _bin_timestamp (self , origin : pd .Timestamp , ts_scol : Column ) -> Column :
166
+ key_type = self ._resamplekey_type
158
167
origin_scol = F .lit (origin )
159
168
(rule_code , n ) = (self ._offset .rule_code , getattr (self ._offset , "n" ))
160
169
left_closed , right_closed = (self ._closed == "left" , self ._closed == "right" )
@@ -188,7 +197,7 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
188
197
F .year (ts_scol ) - (mod - n )
189
198
)
190
199
191
- return F .to_timestamp (
200
+ ret = F .to_timestamp (
192
201
F .make_date (
193
202
F .when (edge_cond , edge_label ).otherwise (non_edge_label ), F .lit (12 ), F .lit (31 )
194
203
)
@@ -227,7 +236,7 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
227
236
truncated_ts_scol - self .get_make_interval ("MONTH" , mod - n )
228
237
)
229
238
230
- return F .to_timestamp (
239
+ ret = F .to_timestamp (
231
240
F .last_day (F .when (edge_cond , edge_label ).otherwise (non_edge_label ))
232
241
)
233
242
@@ -242,15 +251,15 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
242
251
)
243
252
244
253
if left_closed and left_labeled :
245
- return F .date_trunc ("DAY" , ts_scol )
254
+ ret = F .date_trunc ("DAY" , ts_scol )
246
255
elif left_closed and right_labeled :
247
- return F .date_trunc ("DAY" , F .date_add (ts_scol , 1 ))
256
+ ret = F .date_trunc ("DAY" , F .date_add (ts_scol , 1 ))
248
257
elif right_closed and left_labeled :
249
- return F .when (edge_cond , F .date_trunc ("DAY" , F .date_sub (ts_scol , 1 ))).otherwise (
258
+ ret = F .when (edge_cond , F .date_trunc ("DAY" , F .date_sub (ts_scol , 1 ))).otherwise (
250
259
F .date_trunc ("DAY" , ts_scol )
251
260
)
252
261
else :
253
- return F .when (edge_cond , F .date_trunc ("DAY" , ts_scol )).otherwise (
262
+ ret = F .when (edge_cond , F .date_trunc ("DAY" , ts_scol )).otherwise (
254
263
F .date_trunc ("DAY" , F .date_add (ts_scol , 1 ))
255
264
)
256
265
@@ -272,13 +281,15 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
272
281
else :
273
282
non_edge_label = F .date_sub (truncated_ts_scol , mod - n )
274
283
275
- return F .when (edge_cond , edge_label ).otherwise (non_edge_label )
284
+ ret = F .when (edge_cond , edge_label ).otherwise (non_edge_label )
276
285
277
286
elif rule_code in ["H" , "T" , "S" ]:
278
287
unit_mapping = {"H" : "HOUR" , "T" : "MINUTE" , "S" : "SECOND" }
279
288
unit_str = unit_mapping [rule_code ]
280
289
281
290
truncated_ts_scol = F .date_trunc (unit_str , ts_scol )
291
+ if isinstance (key_type , TimestampNTZType ):
292
+ truncated_ts_scol = F .to_timestamp_ntz (truncated_ts_scol )
282
293
diff = timestampdiff (unit_str , origin_scol , truncated_ts_scol )
283
294
mod = F .lit (0 ) if n == 1 else (diff % F .lit (n ))
284
295
@@ -307,11 +318,16 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
307
318
truncated_ts_scol + self .get_make_interval (unit_str , n ),
308
319
).otherwise (truncated_ts_scol - self .get_make_interval (unit_str , mod - n ))
309
320
310
- return F .when (edge_cond , edge_label ).otherwise (non_edge_label )
321
+ ret = F .when (edge_cond , edge_label ).otherwise (non_edge_label )
311
322
312
323
else :
313
324
raise ValueError ("Got the unexpected unit {}" .format (rule_code ))
314
325
326
+ if isinstance (key_type , TimestampNTZType ):
327
+ return F .to_timestamp_ntz (ret )
328
+ else :
329
+ return ret
330
+
315
331
def _downsample (self , f : str ) -> DataFrame :
316
332
"""
317
333
Downsample the defined function.
@@ -374,12 +390,9 @@ def _downsample(self, f: str) -> DataFrame:
374
390
bin_col_label = verify_temp_column_name (self ._psdf , bin_col_name )
375
391
bin_col_field = InternalField (
376
392
dtype = np .dtype ("datetime64[ns]" ),
377
- struct_field = StructField (bin_col_name , TimestampType (), True ),
378
- )
379
- bin_scol = self ._bin_time_stamp (
380
- ts_origin ,
381
- self ._resamplekey_scol ,
393
+ struct_field = StructField (bin_col_name , self ._resamplekey_type , True ),
382
394
)
395
+ bin_scol = self ._bin_timestamp (ts_origin , self ._resamplekey_scol )
383
396
384
397
agg_columns = [
385
398
psser for psser in self ._agg_columns if (isinstance (psser .spark .data_type , NumericType ))
0 commit comments