Skip to content

Commit e55bb04

Browse files
committed
WIP: Fix dictionary array factorize NA handling
1 parent 35a17b9 commit e55bb04

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

pandas/core/arrays/arrow/array.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1207,10 +1207,9 @@ def factorize(
12071207
# https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
12081208
data = data.cast(pa.int64())
12091209

1210-
if pa.types.is_dictionary(data.type):
1211-
encoded = data
1212-
else:
1213-
encoded = data.dictionary_encode(null_encoding=null_encoding)
1210+
if pa.types.is_dictionary(data.type) and null_encoding=='encode':
1211+
data = data.cast(data.type.value_type)
1212+
encoded = data.dictionary_encode(null_encoding=null_encoding)
12141213
if encoded.length() == 0:
12151214
indices = np.array([], dtype=np.intp)
12161215
uniques = type(self)(pa.chunked_array([], type=encoded.type.value_type))

pandas/tests/extension/test_arrow.py

+19
Original file line numberDiff line numberDiff line change
@@ -3330,6 +3330,25 @@ def test_factorize_chunked_dictionary():
33303330
tm.assert_index_equal(res_uniques, exp_uniques)
33313331

33323332

3333+
def test_factorize_dictionary_with_na():
3334+
# Test that factorize properly handles NA values in dictionary arrays
3335+
arr = pd.array(['a1', pd.NA], dtype=pd.ArrowDtype(pa.dictionary(pa.int32(), pa.utf8())))
3336+
3337+
# Test with use_na_sentinel=True (default)
3338+
indices, uniques = arr.factorize()
3339+
expected_indices = np.array([0, -1], dtype=np.intp)
3340+
tm.assert_numpy_array_equal(indices, expected_indices)
3341+
expected_uniques = pd.array(['a1'], dtype=arr.dtype)
3342+
tm.assert_extension_array_equal(uniques, expected_uniques)
3343+
3344+
# Test with use_na_sentinel=False
3345+
indices, uniques = arr.factorize(use_na_sentinel=False)
3346+
expected_indices = np.array([0, 1], dtype=np.intp)
3347+
tm.assert_numpy_array_equal(indices, expected_indices)
3348+
expected_uniques = pd.array(['a1', None], dtype=arr.dtype)
3349+
tm.assert_extension_array_equal(uniques, expected_uniques)
3350+
3351+
33333352
def test_dictionary_astype_categorical():
33343353
# GH#56672
33353354
arrs = [

0 commit comments

Comments
 (0)