11
11
import pandas .core .algorithms as algos
12
12
import pandas .core .nanops as nanops
13
13
from pandas .compat import zip
14
+ from pandas import to_timedelta , to_datetime
15
+ from pandas .types .common import is_datetime64_dtype , is_timedelta64_dtype
14
16
15
17
import numpy as np
16
18
@@ -81,14 +83,17 @@ def cut(x, bins, right=True, labels=None, retbins=False, precision=3,
81
83
array([1, 1, 1, 1, 1], dtype=int64)
82
84
"""
83
85
# NOTE: this binning code is changed a bit from histogram for var(x) == 0
86
+
87
+ # for handling the cut for datetime and timedelta objects
88
+ x_is_series , series_index , name , x = _preprocess_for_cut (x )
89
+ x , dtype = _coerce_to_type (x )
90
+
84
91
if not np .iterable (bins ):
85
92
if is_scalar (bins ) and bins < 1 :
86
93
raise ValueError ("`bins` should be a positive integer." )
87
- try : # for array-like
88
- sz = x .size
89
- except AttributeError :
90
- x = np .asarray (x )
91
- sz = x .size
94
+
95
+ sz = x .size
96
+
92
97
if sz == 0 :
93
98
raise ValueError ('Cannot cut empty array' )
94
99
# handle empty arrays. Can't determine range, so use 0-1.
@@ -114,9 +119,12 @@ def cut(x, bins, right=True, labels=None, retbins=False, precision=3,
114
119
if (np .diff (bins ) < 0 ).any ():
115
120
raise ValueError ('bins must increase monotonically.' )
116
121
117
- return _bins_to_cuts (x , bins , right = right , labels = labels ,
118
- retbins = retbins , precision = precision ,
119
- include_lowest = include_lowest )
122
+ fac , bins = _bins_to_cuts (x , bins , right = right , labels = labels ,
123
+ precision = precision ,
124
+ include_lowest = include_lowest , dtype = dtype )
125
+
126
+ return _postprocess_for_cut (fac , bins , retbins , x_is_series ,
127
+ series_index , name )
120
128
121
129
122
130
def qcut (x , q , labels = None , retbins = False , precision = 3 ):
@@ -166,26 +174,26 @@ def qcut(x, q, labels=None, retbins=False, precision=3):
166
174
>>> pd.qcut(range(5), 4, labels=False)
167
175
array([0, 0, 1, 2, 3], dtype=int64)
168
176
"""
177
+ x_is_series , series_index , name , x = _preprocess_for_cut (x )
178
+
179
+ x , dtype = _coerce_to_type (x )
180
+
169
181
if is_integer (q ):
170
182
quantiles = np .linspace (0 , 1 , q + 1 )
171
183
else :
172
184
quantiles = q
173
185
bins = algos .quantile (x , quantiles )
174
- return _bins_to_cuts (x , bins , labels = labels , retbins = retbins ,
175
- precision = precision , include_lowest = True )
186
+ fac , bins = _bins_to_cuts (x , bins , labels = labels ,
187
+ precision = precision , include_lowest = True ,
188
+ dtype = dtype )
176
189
190
+ return _postprocess_for_cut (fac , bins , retbins , x_is_series ,
191
+ series_index , name )
177
192
178
- def _bins_to_cuts (x , bins , right = True , labels = None , retbins = False ,
179
- precision = 3 , name = None , include_lowest = False ):
180
- x_is_series = isinstance (x , Series )
181
- series_index = None
182
-
183
- if x_is_series :
184
- series_index = x .index
185
- if name is None :
186
- name = x .name
187
193
188
- x = np .asarray (x )
194
+ def _bins_to_cuts (x , bins , right = True , labels = None ,
195
+ precision = 3 , include_lowest = False ,
196
+ dtype = None ):
189
197
190
198
side = 'left' if right else 'right'
191
199
ids = bins .searchsorted (x , side = side )
@@ -205,7 +213,8 @@ def _bins_to_cuts(x, bins, right=True, labels=None, retbins=False,
205
213
while True :
206
214
try :
207
215
levels = _format_levels (bins , precision , right = right ,
208
- include_lowest = include_lowest )
216
+ include_lowest = include_lowest ,
217
+ dtype = dtype )
209
218
except ValueError :
210
219
increases += 1
211
220
precision += 1
@@ -229,18 +238,12 @@ def _bins_to_cuts(x, bins, right=True, labels=None, retbins=False,
229
238
fac = fac .astype (np .float64 )
230
239
np .putmask (fac , na_mask , np .nan )
231
240
232
- if x_is_series :
233
- fac = Series (fac , index = series_index , name = name )
234
-
235
- if not retbins :
236
- return fac
237
-
238
241
return fac , bins
239
242
240
243
241
244
def _format_levels (bins , prec , right = True ,
242
- include_lowest = False ):
243
- fmt = lambda v : _format_label (v , precision = prec )
245
+ include_lowest = False , dtype = None ):
246
+ fmt = lambda v : _format_label (v , precision = prec , dtype = dtype )
244
247
if right :
245
248
levels = []
246
249
for a , b in zip (bins , bins [1 :]):
@@ -258,12 +261,16 @@ def _format_levels(bins, prec, right=True,
258
261
else :
259
262
levels = ['[%s, %s)' % (fmt (a ), fmt (b ))
260
263
for a , b in zip (bins , bins [1 :])]
261
-
262
264
return levels
263
265
264
266
265
- def _format_label (x , precision = 3 ):
267
+ def _format_label (x , precision = 3 , dtype = None ):
266
268
fmt_str = '%%.%dg' % precision
269
+
270
+ if is_datetime64_dtype (dtype ):
271
+ return to_datetime (x , unit = 'ns' )
272
+ if is_timedelta64_dtype (dtype ):
273
+ return to_timedelta (x , unit = 'ns' )
267
274
if np .isinf (x ):
268
275
return str (x )
269
276
elif is_float (x ):
@@ -300,3 +307,55 @@ def _trim_zeros(x):
300
307
if len (x ) > 1 and x [- 1 ] == '.' :
301
308
x = x [:- 1 ]
302
309
return x
310
+
311
+
312
def _coerce_to_type(x):
    """
    Convert datetime/timedelta data to a plain numeric array so that the
    binning code can handle it like any other numeric input.

    Parameters
    ----------
    x : array-like
        Data about to be binned.

    Returns
    -------
    x : array-like
        The input itself when it is neither datetime64 nor timedelta64.
        Otherwise the underlying integer values as an int64 array, or as a
        float64 array with NaN in place of NaT when missing values are
        present.
    dtype : type or None
        ``np.timedelta64`` or ``np.datetime64`` when a conversion took
        place, None otherwise.  It is passed on so labels can be rendered
        back as timedeltas/datetimes.
    """
    if is_timedelta64_dtype(x):
        converted, dtype = to_timedelta(x), np.timedelta64
    elif is_datetime64_dtype(x):
        converted, dtype = to_datetime(x), np.datetime64
    else:
        return x, None

    # Reinterpret the datetime64/timedelta64 storage as raw integers.
    ints = converted.values.view(np.int64)

    # NaT is stored as the smallest int64.  Left as-is it would look like a
    # huge negative observation to any min/max or quantile computation, so
    # expose it as NaN instead.  The float conversion only happens when
    # missing values exist, keeping exact int64 output in the common case.
    missing = ints == np.iinfo(np.int64).min
    if missing.any():
        ints = np.where(missing, np.nan, ints)

    return ints, dtype
328
+
329
+
330
def _preprocess_for_cut(x):
    """
    Prepare the input of cut/qcut for binning.

    The data is converted to a numpy array; when a Series was passed, its
    index and name are captured separately so they can be re-attached to
    the result afterwards.

    Returns
    -------
    tuple
        ``(x_is_series, series_index, name, x)`` where ``series_index`` and
        ``name`` are None unless the input was a Series, and ``x`` is the
        result of ``np.asarray`` on the input.
    """
    x_is_series = isinstance(x, Series)

    if x_is_series:
        # remember the labelling before it is stripped by np.asarray
        series_index, name = x.index, x.name
    else:
        series_index = name = None

    return x_is_series, series_index, name, np.asarray(x)
347
+
348
+
349
def _postprocess_for_cut(fac, bins, retbins, x_is_series, series_index, name):
    """
    Finish off the result of cut/qcut.

    When the original input was a Series, the binned values are wrapped in
    a Series carrying the original index and name.  The bins are returned
    alongside the result only when ``retbins`` is truthy.

    Returns
    -------
    ``fac`` (possibly wrapped in a Series), or the tuple ``(fac, bins)``
    when ``retbins`` is set.
    """
    result = Series(fac, index=series_index, name=name) if x_is_series else fac
    return (result, bins) if retbins else result
0 commit comments