|
2 | 2 | # pylint: disable-msg=W0612,E1101
|
3 | 3 |
|
4 | 4 | from pandas import DataFrame, Series
|
5 |
| -from pandas.core.sparse import SparseDataFrame |
6 | 5 | import pandas as pd
|
7 | 6 |
|
8 | 7 | from numpy import nan
|
@@ -234,26 +233,31 @@ def test_basic_types(self):
|
234 | 233 | 'b': ['A', 'A', 'B', 'C', 'C'],
|
235 | 234 | 'c': [2, 3, 3, 3, 2]})
|
236 | 235 |
|
| 236 | + expected = DataFrame({'a': [1, 0, 0], |
| 237 | + 'b': [0, 1, 0], |
| 238 | + 'c': [0, 0, 1]}, |
| 239 | + dtype='uint8', |
| 240 | + columns=list('abc')) |
237 | 241 | if not self.sparse:
|
238 |
| - exp_df_type = DataFrame |
239 |
| - exp_blk_type = pd.core.internals.IntBlock |
| 242 | + compare = tm.assert_frame_equal |
240 | 243 | else:
|
241 |
| - exp_df_type = SparseDataFrame |
242 |
| - exp_blk_type = pd.core.internals.SparseBlock |
243 |
| - |
244 |
| - self.assertEqual( |
245 |
| - type(get_dummies(s_list, sparse=self.sparse)), exp_df_type) |
246 |
| - self.assertEqual( |
247 |
| - type(get_dummies(s_series, sparse=self.sparse)), exp_df_type) |
248 |
| - |
249 |
| - r = get_dummies(s_df, sparse=self.sparse, columns=s_df.columns) |
250 |
| - self.assertEqual(type(r), exp_df_type) |
251 |
| - |
252 |
| - r = get_dummies(s_df, sparse=self.sparse, columns=['a']) |
253 |
| - exp_blk_type = pd.core.internals.IntBlock |
254 |
| - self.assertEqual(type(r[['a_0']]._data.blocks[0]), exp_blk_type) |
255 |
| - self.assertEqual(type(r[['a_1']]._data.blocks[0]), exp_blk_type) |
256 |
| - self.assertEqual(type(r[['a_2']]._data.blocks[0]), exp_blk_type) |
| 244 | + expected = expected.to_sparse(fill_value=0, kind='integer') |
| 245 | + compare = tm.assert_sp_frame_equal |
| 246 | + |
| 247 | + result = get_dummies(s_list, sparse=self.sparse) |
| 248 | + compare(result, expected) |
| 249 | + |
| 250 | + result = get_dummies(s_series, sparse=self.sparse) |
| 251 | + compare(result, expected) |
| 252 | + |
| 253 | + result = get_dummies(s_df, sparse=self.sparse, columns=s_df.columns) |
| 254 | + tm.assert_series_equal(result.get_dtype_counts(), |
| 255 | + Series({'uint8': 8})) |
| 256 | + |
| 257 | + result = get_dummies(s_df, sparse=self.sparse, columns=['a']) |
| 258 | + expected = Series({'uint8': 3, 'int64': 1, 'object': 1}).sort_values() |
| 259 | + tm.assert_series_equal(result.get_dtype_counts().sort_values(), |
| 260 | + expected) |
257 | 261 |
|
258 | 262 | def test_just_na(self):
|
259 | 263 | just_na_list = [np.nan]
|
|
0 commit comments