@@ -300,50 +300,80 @@ def test_period(self):
300
300
tm .assert_series_equal (notna (s ), ~ exp )
301
301
302
302
303
- def test_array_equivalent ():
304
- assert array_equivalent ( np . array ([ np . nan , np . nan ]), np . array ([ np . nan , np . nan ]))
303
+ @ pytest . mark . parametrize ( "dtype_equal" , [ True , False ])
304
+ def test_array_equivalent ( dtype_equal ):
305
305
assert array_equivalent (
306
- np .array ([np .nan , 1 , np .nan ]), np .array ([np .nan , 1 , np .nan ])
306
+ np .array ([np .nan , np .nan ]), np .array ([np .nan , np .nan ]), dtype_equal = dtype_equal
307
+ )
308
+ assert array_equivalent (
309
+ np .array ([np .nan , 1 , np .nan ]),
310
+ np .array ([np .nan , 1 , np .nan ]),
311
+ dtype_equal = dtype_equal ,
307
312
)
308
313
assert array_equivalent (
309
314
np .array ([np .nan , None ], dtype = "object" ),
310
315
np .array ([np .nan , None ], dtype = "object" ),
316
+ dtype_equal = dtype_equal ,
311
317
)
312
318
# Check the handling of nested arrays in array_equivalent_object
313
319
assert array_equivalent (
314
320
np .array ([np .array ([np .nan , None ], dtype = "object" ), None ], dtype = "object" ),
315
321
np .array ([np .array ([np .nan , None ], dtype = "object" ), None ], dtype = "object" ),
322
+ dtype_equal = dtype_equal ,
316
323
)
317
324
assert array_equivalent (
318
325
np .array ([np .nan , 1 + 1j ], dtype = "complex" ),
319
326
np .array ([np .nan , 1 + 1j ], dtype = "complex" ),
327
+ dtype_equal = dtype_equal ,
320
328
)
321
329
assert not array_equivalent (
322
330
np .array ([np .nan , 1 + 1j ], dtype = "complex" ),
323
331
np .array ([np .nan , 1 + 2j ], dtype = "complex" ),
332
+ dtype_equal = dtype_equal ,
333
+ )
334
+ assert not array_equivalent (
335
+ np .array ([np .nan , 1 , np .nan ]),
336
+ np .array ([np .nan , 2 , np .nan ]),
337
+ dtype_equal = dtype_equal ,
338
+ )
339
+ assert not array_equivalent (
340
+ np .array (["a" , "b" , "c" , "d" ]), np .array (["e" , "e" ]), dtype_equal = dtype_equal
341
+ )
342
+ assert array_equivalent (
343
+ Float64Index ([0 , np .nan ]), Float64Index ([0 , np .nan ]), dtype_equal = dtype_equal
324
344
)
325
345
assert not array_equivalent (
326
- np .array ([np .nan , 1 , np .nan ]), np .array ([np .nan , 2 , np .nan ])
346
+ Float64Index ([0 , np .nan ]), Float64Index ([1 , np .nan ]), dtype_equal = dtype_equal
347
+ )
348
+ assert array_equivalent (
349
+ DatetimeIndex ([0 , np .nan ]), DatetimeIndex ([0 , np .nan ]), dtype_equal = dtype_equal
350
+ )
351
+ assert not array_equivalent (
352
+ DatetimeIndex ([0 , np .nan ]), DatetimeIndex ([1 , np .nan ]), dtype_equal = dtype_equal
353
+ )
354
+ assert array_equivalent (
355
+ TimedeltaIndex ([0 , np .nan ]),
356
+ TimedeltaIndex ([0 , np .nan ]),
357
+ dtype_equal = dtype_equal ,
327
358
)
328
- assert not array_equivalent (np .array (["a" , "b" , "c" , "d" ]), np .array (["e" , "e" ]))
329
- assert array_equivalent (Float64Index ([0 , np .nan ]), Float64Index ([0 , np .nan ]))
330
- assert not array_equivalent (Float64Index ([0 , np .nan ]), Float64Index ([1 , np .nan ]))
331
- assert array_equivalent (DatetimeIndex ([0 , np .nan ]), DatetimeIndex ([0 , np .nan ]))
332
- assert not array_equivalent (DatetimeIndex ([0 , np .nan ]), DatetimeIndex ([1 , np .nan ]))
333
- assert array_equivalent (TimedeltaIndex ([0 , np .nan ]), TimedeltaIndex ([0 , np .nan ]))
334
359
assert not array_equivalent (
335
- TimedeltaIndex ([0 , np .nan ]), TimedeltaIndex ([1 , np .nan ])
360
+ TimedeltaIndex ([0 , np .nan ]),
361
+ TimedeltaIndex ([1 , np .nan ]),
362
+ dtype_equal = dtype_equal ,
336
363
)
337
364
assert array_equivalent (
338
365
DatetimeIndex ([0 , np .nan ], tz = "US/Eastern" ),
339
366
DatetimeIndex ([0 , np .nan ], tz = "US/Eastern" ),
367
+ dtype_equal = dtype_equal ,
340
368
)
341
369
assert not array_equivalent (
342
370
DatetimeIndex ([0 , np .nan ], tz = "US/Eastern" ),
343
371
DatetimeIndex ([1 , np .nan ], tz = "US/Eastern" ),
372
+ dtype_equal = dtype_equal ,
344
373
)
374
+ # The rest are not dtype_equal
345
375
assert not array_equivalent (
346
- DatetimeIndex ([0 , np .nan ]), DatetimeIndex ([0 , np .nan ], tz = "US/Eastern" )
376
+ DatetimeIndex ([0 , np .nan ]), DatetimeIndex ([0 , np .nan ], tz = "US/Eastern" ),
347
377
)
348
378
assert not array_equivalent (
349
379
DatetimeIndex ([0 , np .nan ], tz = "CET" ),
@@ -353,6 +383,11 @@ def test_array_equivalent():
353
383
assert not array_equivalent (DatetimeIndex ([0 , np .nan ]), TimedeltaIndex ([0 , np .nan ]))
354
384
355
385
386
+ def test_array_equivalent_different_dtype_but_equal ():
387
+ # Unclear if this is exposed anywhere in the public-facing API
388
+ assert array_equivalent (np .array ([1 , 2 ]), np .array ([1.0 , 2.0 ]))
389
+
390
+
356
391
@pytest .mark .parametrize (
357
392
"lvalue, rvalue" ,
358
393
[
0 commit comments