Skip to content

Commit f8e382c

Browse files
committed
COMPAT: For pandas 0.21 CategoricalDtype
1 parent 4778ee2 commit f8e382c

File tree

8 files changed

+104
-53
lines changed

8 files changed

+104
-53
lines changed

dask/array/percentile.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def _percentile(a, q, interpolation='linear'):
1717
return None
1818
if isinstance(q, Iterator):
1919
q = list(q)
20-
if str(a.dtype) == 'category':
20+
if a.dtype.name == 'category':
2121
result = np.percentile(a.codes, q, interpolation=interpolation)
2222
import pandas as pd
2323
return pd.Categorical.from_codes(result, a.categories, a.ordered)
@@ -100,7 +100,7 @@ def merge_percentiles(finalq, qs, vals, Ns, interpolation='lower'):
100100

101101
# TODO: Perform this check above in percentile once dtype checking is easy
102102
# Here we silently change meaning
103-
if str(vals[0].dtype) == 'category':
103+
if vals[0].dtype.name == 'category':
104104
result = merge_percentiles(finalq, qs, [v.codes for v in vals], Ns, interpolation)
105105
import pandas as pd
106106
return pd.Categorical.from_codes(result, vals[0].categories, vals[0].ordered)

dask/dataframe/io/tests/test_io.py

+7-14
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
import pytest
66
from threading import Lock
7-
8-
import threading
7+
from multiprocessing.pool import ThreadPool
98

109
import dask.array as da
1110
import dask.dataframe as dd
@@ -15,7 +14,7 @@
1514
from dask.utils import tmpfile
1615
from dask.local import get_sync
1716

18-
from dask.dataframe.utils import assert_eq
17+
from dask.dataframe.utils import assert_eq, is_categorical_dtype
1918

2019

2120
####################
@@ -119,13 +118,14 @@ def test_from_array_with_record_dtype():
119118

120119
def test_from_bcolz_multiple_threads():
121120
bcolz = pytest.importorskip('bcolz')
121+
pool = ThreadPool(processes=5)
122122

123-
def check():
123+
def check(i):
124124
t = bcolz.ctable([[1, 2, 3], [1., 2., 3.], ['a', 'b', 'a']],
125125
names=['x', 'y', 'a'])
126126
d = dd.from_bcolz(t, chunksize=2)
127127
assert d.npartitions == 2
128-
assert str(d.dtypes['a']) == 'category'
128+
assert is_categorical_dtype(d.dtypes['a'])
129129
assert list(d.x.compute(get=get_sync)) == [1, 2, 3]
130130
assert list(d.a.compute(get=get_sync)) == ['a', 'b', 'a']
131131

@@ -139,14 +139,7 @@ def check():
139139
assert (sorted(dd.from_bcolz(t, chunksize=2).dask) !=
140140
sorted(dd.from_bcolz(t, chunksize=3).dask))
141141

142-
threads = []
143-
for i in range(5):
144-
thread = threading.Thread(target=check)
145-
thread.start()
146-
threads.append(thread)
147-
148-
for thread in threads:
149-
thread.join()
142+
pool.map(check, range(5))
150143

151144

152145
def test_from_bcolz():
@@ -156,7 +149,7 @@ def test_from_bcolz():
156149
names=['x', 'y', 'a'])
157150
d = dd.from_bcolz(t, chunksize=2)
158151
assert d.npartitions == 2
159-
assert str(d.dtypes['a']) == 'category'
152+
assert is_categorical_dtype(d.dtypes['a'])
160153
assert list(d.x.compute(get=get_sync)) == [1, 2, 3]
161154
assert list(d.a.compute(get=get_sync)) == ['a', 'b', 'a']
162155
L = list(d.index.compute(get=get_sync))

dask/dataframe/partitionquantiles.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
from ..utils import random_state_data
8080
from ..base import tokenize
8181
from .core import Series
82+
from .utils import is_categorical_dtype
8283
from dask.compatibility import zip
8384

8485

@@ -363,7 +364,7 @@ def process_val_weights(vals_and_weights, npartitions, dtype_info):
363364
rv = np.concatenate([trimmed, jumbo_vals])
364365
rv.sort()
365366

366-
if str(dtype) == 'category':
367+
if is_categorical_dtype(dtype):
367368
rv = pd.Categorical.from_codes(rv, info[0], info[1])
368369
elif 'datetime64' in str(dtype):
369370
rv = pd.DatetimeIndex(rv, dtype=dtype)
@@ -398,7 +399,7 @@ def percentiles_summary(df, num_old, num_new, upsample, state):
398399
qs = sample_percentiles(num_old, num_new, length, upsample, random_state)
399400
data = df.values
400401
interpolation = 'linear'
401-
if str(data.dtype) == 'category':
402+
if is_categorical_dtype(data):
402403
data = data.codes
403404
interpolation = 'nearest'
404405
vals = _percentile(data, qs, interpolation=interpolation)
@@ -410,7 +411,7 @@ def percentiles_summary(df, num_old, num_new, upsample, state):
410411

411412
def dtype_info(df):
412413
info = None
413-
if str(df.dtype) == 'category':
414+
if is_categorical_dtype(df):
414415
data = df.values
415416
info = (data.categories, data.ordered)
416417
return df.dtype, info

dask/dataframe/tests/test_categorical.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def test_categorize_index():
202202
@pytest.mark.parametrize('shuffle', ['disk', 'tasks'])
203203
def test_categorical_set_index(shuffle):
204204
df = pd.DataFrame({'x': [1, 2, 3, 4], 'y': ['a', 'b', 'b', 'c']})
205-
df['y'] = df.y.astype('category', ordered=True)
205+
df['y'] = pd.Categorical(df['y'], categories=['a', 'b', 'c'], ordered=True)
206206
a = dd.from_pandas(df, npartitions=2)
207207

208208
with dask.set_options(get=dask.get, shuffle=shuffle):

dask/dataframe/tests/test_dataframe.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2202,9 +2202,9 @@ def test_categorize_info():
22022202
"Int64Index: 4 entries, 0 to 3\n"
22032203
"Data columns (total 3 columns):\n"
22042204
"x 4 non-null int64\n"
2205-
"y 4 non-null category\n"
2205+
"y 4 non-null CategoricalDtype(categories=['a', 'b', 'c'], ordered=False)\n" # noqa
22062206
"z 4 non-null object\n"
2207-
"dtypes: category(1), object(1), int64(1)")
2207+
"dtypes: CategoricalDtype(categories=['a', 'b', 'c'], ordered=False)(1), object(1), int64(1)") # noqa
22082208

22092209

22102210
def test_gh_1301():

dask/dataframe/tests/test_format.py

+52-20
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# coding: utf-8
22
import pandas as pd
3+
from textwrap import dedent
34

45
import dask.dataframe as dd
56
from dask.dataframe.utils import PANDAS_VERSION
@@ -440,14 +441,26 @@ def test_index_format():
440441
s = pd.Series([1, 2, 3, 4, 5, 6, 7, 8],
441442
index=pd.CategoricalIndex([1, 2, 3, 4, 5, 6, 7, 8], name='YYY'))
442443
ds = dd.from_pandas(s, 3)
443-
exp = """Dask Index Structure:
444-
npartitions=3
445-
1 category[known]
446-
4 ...
447-
7 ...
448-
8 ...
449-
Name: YYY, dtype: category
450-
Dask Name: from_pandas, 6 tasks"""
444+
if PANDAS_VERSION >= '0.21.0':
445+
exp = dedent("""\
446+
Dask Index Structure:
447+
npartitions=3
448+
1 category[known]
449+
4 ...
450+
7 ...
451+
8 ...
452+
Name: YYY, dtype: CategoricalDtype(categories=[1, 2, 3, 4, 5, 6, 7, 8], ordered=False)
453+
Dask Name: from_pandas, 6 tasks""")
454+
else:
455+
exp = dedent("""\
456+
Dask Index Structure:
457+
npartitions=3
458+
1 category[known]
459+
4 ...
460+
7 ...
461+
8 ...
462+
Name: YYY, dtype: category
463+
Dask Name: from_pandas, 6 tasks""")
451464
assert repr(ds.index) == exp
452465
assert str(ds.index) == exp
453466

@@ -456,17 +469,36 @@ def test_categorical_format():
456469
s = pd.Series(['a', 'b', 'c']).astype('category')
457470
known = dd.from_pandas(s, npartitions=1)
458471
unknown = known.cat.as_unknown()
459-
exp = ("Dask Series Structure:\n"
460-
"npartitions=1\n"
461-
"0 category[known]\n"
462-
"2 ...\n"
463-
"dtype: category\n"
464-
"Dask Name: from_pandas, 1 tasks")
472+
if PANDAS_VERSION >= '0.21.0':
473+
exp = dedent("""\
474+
Dask Series Structure:
475+
npartitions=1
476+
0 category[known]
477+
2 ...
478+
dtype: CategoricalDtype(categories=['a', 'b', 'c'], ordered=False)
479+
Dask Name: from_pandas, 1 tasks""")
480+
else:
481+
exp = ("Dask Series Structure:\n"
482+
"npartitions=1\n"
483+
"0 category[known]\n"
484+
"2 ...\n"
485+
"dtype: category\n"
486+
"Dask Name: from_pandas, 1 tasks")
465487
assert repr(known) == exp
466-
exp = ("Dask Series Structure:\n"
467-
"npartitions=1\n"
468-
"0 category[unknown]\n"
469-
"2 ...\n"
470-
"dtype: category\n"
471-
"Dask Name: from_pandas, 1 tasks")
488+
if PANDAS_VERSION >= '0.21.0':
489+
exp = dedent("""\
490+
Dask Series Structure:
491+
npartitions=1
492+
0 category[unknown]
493+
2 ...
494+
dtype: CategoricalDtype(categories=['__UNKNOWN_CATEGORIES__'], ordered=False)
495+
Dask Name: from_pandas, 1 tasks""")
496+
497+
else:
498+
exp = ("Dask Series Structure:\n"
499+
"npartitions=1\n"
500+
"0 category[unknown]\n"
501+
"2 ...\n"
502+
"dtype: category\n"
503+
"Dask Name: from_pandas, 1 tasks")
472504
assert repr(unknown) == exp

dask/dataframe/tests/test_utils_dataframe.py

+29-11
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ def test_meta_nonempty_empty_categories():
163163
# Series
164164
s = idx.to_series()
165165
res = meta_nonempty(s)
166-
assert res.dtype == s.dtype
166+
assert res.dtype == 'category'
167+
assert s.dtype == 'category'
167168
assert type(res.cat.categories) is type(s.cat.categories)
168169
assert res.cat.ordered == s.cat.ordered
169170
assert res.name == s.name
@@ -302,13 +303,30 @@ def test_check_meta():
302303
df2 = df[['a', 'b', 'd', 'e']]
303304
with pytest.raises(ValueError) as err:
304305
check_meta(df2, meta2, funcname='from_delayed')
305-
assert str(err.value) == ('Metadata mismatch found in `from_delayed`.\n'
306-
'\n'
307-
'Partition type: `DataFrame`\n'
308-
'+--------+----------+----------+\n'
309-
'| Column | Found | Expected |\n'
310-
'+--------+----------+----------+\n'
311-
'| a | object | category |\n'
312-
'| c | - | float64 |\n'
313-
'| e | category | - |\n'
314-
'+--------+----------+----------+')
306+
307+
if PANDAS_VERSION >= '0.21.0':
308+
exp = (
309+
'Metadata mismatch found in `from_delayed`.\n'
310+
'\n'
311+
'Partition type: `DataFrame`\n'
312+
'+--------+-------------------------------------------------------------+------------------------------------------------+\n' # noqa
313+
'| Column | Found | Expected |\n' # noqa
314+
'+--------+-------------------------------------------------------------+------------------------------------------------+\n' # noqa
315+
'| a | object | CategoricalDtype(categories=[], ordered=False) |\n' # noqa
316+
'| c | - | float64 |\n' # noqa
317+
"| e | CategoricalDtype(categories=['x', 'y', 'z'], ordered=False) | - |\n" # noqa
318+
'+--------+-------------------------------------------------------------+------------------------------------------------+' # noqa
319+
)
320+
else:
321+
exp = (
322+
'Metadata mismatch found in `from_delayed`.\n'
323+
'\n'
324+
'Partition type: `DataFrame`\n'
325+
'+--------+----------+----------+\n'
326+
'| Column | Found | Expected |\n'
327+
'+--------+----------+----------+\n'
328+
'| a | object | category |\n'
329+
'| c | - | float64 |\n'
330+
'| e | category | - |\n'
331+
'+--------+----------+----------+')
332+
assert str(err.value) == exp

dask/dataframe/utils.py

+7
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,13 @@ def equal_dtypes(a, b):
454454
return False
455455
if (a is '-' or b is '-'):
456456
return False
457+
if is_categorical_dtype(a) and is_categorical_dtype(b):
458+
# Pandas 0.21 CategoricalDtype compat
459+
if (PANDAS_VERSION >= '0.21.0' and
460+
(UNKNOWN_CATEGORIES in a.categories or
461+
UNKNOWN_CATEGORIES in b.categories)):
462+
return True
463+
return a == b
457464
return (a.kind in eq_types and b.kind in eq_types) or (a == b)
458465

459466
if not isinstance(meta, (pd.Series, pd.Index, pd.DataFrame)):

0 commit comments

Comments
 (0)