@@ -237,7 +237,7 @@ def test_groupby_quantile_nullable_array(values, q):
237
237
idx = Index (["x" , "y" ], name = "a" )
238
238
true_quantiles = [0.5 ]
239
239
240
- expected = pd .Series (true_quantiles * 2 , index = idx , name = "b" )
240
+ expected = pd .Series (true_quantiles * 2 , index = idx , name = "b" , dtype = "Float64" )
241
241
tm .assert_series_equal (result , expected )
242
242
243
243
@@ -266,14 +266,21 @@ def test_groupby_quantile_NA_float(any_float_dtype):
266
266
df = DataFrame ({"x" : [1 , 1 ], "y" : [0.2 , np .nan ]}, dtype = any_float_dtype )
267
267
result = df .groupby ("x" )["y" ].quantile (0.5 )
268
268
exp_index = Index ([1.0 ], dtype = any_float_dtype , name = "x" )
269
- expected = pd .Series ([0.2 ], dtype = float , index = exp_index , name = "y" )
270
- tm .assert_series_equal (expected , result )
269
+
270
+ if any_float_dtype in ["Float32" , "Float64" ]:
271
+ expected_dtype = any_float_dtype
272
+ else :
273
+ expected_dtype = None
274
+
275
+ expected = pd .Series ([0.2 ], dtype = expected_dtype , index = exp_index , name = "y" )
276
+ tm .assert_series_equal (result , expected )
271
277
272
278
result = df .groupby ("x" )["y" ].quantile ([0.5 , 0.75 ])
273
279
expected = pd .Series (
274
280
[0.2 ] * 2 ,
275
281
index = pd .MultiIndex .from_product ((exp_index , [0.5 , 0.75 ]), names = ["x" , None ]),
276
282
name = "y" ,
283
+ dtype = expected_dtype ,
277
284
)
278
285
tm .assert_series_equal (result , expected )
279
286
@@ -283,12 +290,68 @@ def test_groupby_quantile_NA_int(any_int_ea_dtype):
283
290
df = DataFrame ({"x" : [1 , 1 ], "y" : [2 , 5 ]}, dtype = any_int_ea_dtype )
284
291
result = df .groupby ("x" )["y" ].quantile (0.5 )
285
292
expected = pd .Series (
286
- [3.5 ], dtype = float , index = Index ([1 ], name = "x" , dtype = any_int_ea_dtype ), name = "y"
293
+ [3.5 ],
294
+ dtype = "Float64" ,
295
+ index = Index ([1 ], name = "x" , dtype = any_int_ea_dtype ),
296
+ name = "y" ,
287
297
)
288
298
tm .assert_series_equal (expected , result )
289
299
290
300
result = df .groupby ("x" ).quantile (0.5 )
291
- expected = DataFrame ({"y" : 3.5 }, index = Index ([1 ], name = "x" , dtype = any_int_ea_dtype ))
301
+ expected = DataFrame (
302
+ {"y" : 3.5 }, dtype = "Float64" , index = Index ([1 ], name = "x" , dtype = any_int_ea_dtype )
303
+ )
304
+ tm .assert_frame_equal (result , expected )
305
+
306
+
307
+ @pytest .mark .parametrize (
308
+ "interpolation, val1, val2" , [("lower" , 2 , 2 ), ("higher" , 2 , 3 ), ("nearest" , 2 , 2 )]
309
+ )
310
+ def test_groupby_quantile_all_na_group_masked (
311
+ interpolation , val1 , val2 , any_numeric_ea_dtype
312
+ ):
313
+ # GH#37493
314
+ df = DataFrame (
315
+ {"a" : [1 , 1 , 1 , 2 ], "b" : [1 , 2 , 3 , pd .NA ]}, dtype = any_numeric_ea_dtype
316
+ )
317
+ result = df .groupby ("a" ).quantile (q = [0.5 , 0.7 ], interpolation = interpolation )
318
+ expected = DataFrame (
319
+ {"b" : [val1 , val2 , pd .NA , pd .NA ]},
320
+ dtype = any_numeric_ea_dtype ,
321
+ index = pd .MultiIndex .from_arrays (
322
+ [pd .Series ([1 , 1 , 2 , 2 ], dtype = any_numeric_ea_dtype ), [0.5 , 0.7 , 0.5 , 0.7 ]],
323
+ names = ["a" , None ],
324
+ ),
325
+ )
326
+ tm .assert_frame_equal (result , expected )
327
+
328
+
329
+ @pytest .mark .parametrize ("interpolation" , ["midpoint" , "linear" ])
330
+ def test_groupby_quantile_all_na_group_masked_interp (
331
+ interpolation , any_numeric_ea_dtype
332
+ ):
333
+ # GH#37493
334
+ df = DataFrame (
335
+ {"a" : [1 , 1 , 1 , 2 ], "b" : [1 , 2 , 3 , pd .NA ]}, dtype = any_numeric_ea_dtype
336
+ )
337
+ result = df .groupby ("a" ).quantile (q = [0.5 , 0.75 ], interpolation = interpolation )
338
+
339
+ if any_numeric_ea_dtype == "Float32" :
340
+ expected_dtype = any_numeric_ea_dtype
341
+ else :
342
+ expected_dtype = "Float64"
343
+
344
+ expected = DataFrame (
345
+ {"b" : [2.0 , 2.5 , pd .NA , pd .NA ]},
346
+ dtype = expected_dtype ,
347
+ index = pd .MultiIndex .from_arrays (
348
+ [
349
+ pd .Series ([1 , 1 , 2 , 2 ], dtype = any_numeric_ea_dtype ),
350
+ [0.5 , 0.75 , 0.5 , 0.75 ],
351
+ ],
352
+ names = ["a" , None ],
353
+ ),
354
+ )
292
355
tm .assert_frame_equal (result , expected )
293
356
294
357
@@ -298,7 +361,7 @@ def test_groupby_quantile_allNA_column(dtype):
298
361
df = DataFrame ({"x" : [1 , 1 ], "y" : [pd .NA ] * 2 }, dtype = dtype )
299
362
result = df .groupby ("x" )["y" ].quantile (0.5 )
300
363
expected = pd .Series (
301
- [np .nan ], dtype = float , index = Index ([1.0 ], dtype = dtype ), name = "y"
364
+ [np .nan ], dtype = dtype , index = Index ([1.0 ], dtype = dtype ), name = "y"
302
365
)
303
366
expected .index .name = "x"
304
367
tm .assert_series_equal (expected , result )
0 commit comments