@@ -604,7 +604,7 @@ def str_extract(arr, pat, flags=0, expand=None):
604
604
return _str_extract_frame (arr ._orig , pat , flags = flags )
605
605
else :
606
606
result , name = _str_extract_noexpand (arr ._data , pat , flags = flags )
607
- return arr ._wrap_result (result , name = name )
607
+ return arr ._wrap_result (result , name = name , expand = expand )
608
608
609
609
610
610
def str_extractall (arr , pat , flags = 0 ):
@@ -1292,7 +1292,10 @@ def __iter__(self):
1292
1292
i += 1
1293
1293
g = self .get (i )
1294
1294
1295
- def _wrap_result (self , result , use_codes = True , name = None ):
1295
+ def _wrap_result (self , result , use_codes = True ,
1296
+ name = None , expand = None ):
1297
+
1298
+ from pandas .core .index import Index , MultiIndex
1296
1299
1297
1300
# for category, we do the stuff on the categories, so blow it up
1298
1301
# to the full series again
@@ -1302,48 +1305,42 @@ def _wrap_result(self, result, use_codes=True, name=None):
1302
1305
if use_codes and self ._is_categorical :
1303
1306
result = take_1d (result , self ._orig .cat .codes )
1304
1307
1305
- # leave as it is to keep extract and get_dummies results
1306
- # can be merged to _wrap_result_expand in v0.17
1307
- from pandas .core .series import Series
1308
- from pandas .core .frame import DataFrame
1309
- from pandas .core .index import Index
1310
-
1311
- if not hasattr (result , 'ndim' ):
1308
+ if not hasattr (result , 'ndim' ) or not hasattr (result , 'dtype' ):
1312
1309
return result
1310
+ assert result .ndim < 3
1313
1311
1314
- if result .ndim == 1 :
1315
- # Wait until we are sure result is a Series or Index before
1316
- # checking attributes (GH 12180)
1317
- name = name or getattr (result , 'name' , None ) or self ._orig .name
1318
- if isinstance (self ._orig , Index ):
1319
- # if result is a boolean np.array, return the np.array
1320
- # instead of wrapping it into a boolean Index (GH 8875)
1321
- if is_bool_dtype (result ):
1322
- return result
1323
- return Index (result , name = name )
1324
- return Series (result , index = self ._orig .index , name = name )
1325
- else :
1326
- assert result .ndim < 3
1327
- return DataFrame (result , index = self ._orig .index )
1312
+ if expand is None :
1313
+ # infer from ndim if expand is not specified
1314
+ expand = False if result .ndim == 1 else True
1315
+
1316
+ elif expand is True and not isinstance (self ._orig , Index ):
1317
+ # required when expand=True is explicitly specified
1318
+ # not needed when infered
1319
+
1320
+ def cons_row (x ):
1321
+ if is_list_like (x ):
1322
+ return x
1323
+ else :
1324
+ return [x ]
1325
+
1326
+ result = [cons_row (x ) for x in result ]
1328
1327
1329
- def _wrap_result_expand (self , result , expand = False ):
1330
1328
if not isinstance (expand , bool ):
1331
1329
raise ValueError ("expand must be True or False" )
1332
1330
1333
- # for category, we do the stuff on the categories, so blow it up
1334
- # to the full series again
1335
- if self ._is_categorical :
1336
- result = take_1d (result , self ._orig .cat .codes )
1337
-
1338
- from pandas .core .index import Index , MultiIndex
1339
- if not hasattr (result , 'ndim' ):
1340
- return result
1331
+ if name is None :
1332
+ name = getattr (result , 'name' , None )
1333
+ if name is None :
1334
+ # do not use logical or, _orig may be a DataFrame
1335
+ # which has "name" column
1336
+ name = self ._orig .name
1341
1337
1338
+ # Wait until we are sure result is a Series or Index before
1339
+ # checking attributes (GH 12180)
1342
1340
if isinstance (self ._orig , Index ):
1343
- name = getattr (result , 'name' , None )
1344
1341
# if result is a boolean np.array, return the np.array
1345
1342
# instead of wrapping it into a boolean Index (GH 8875)
1346
- if hasattr ( result , 'dtype' ) and is_bool_dtype (result ):
1343
+ if is_bool_dtype (result ):
1347
1344
return result
1348
1345
1349
1346
if expand :
@@ -1354,18 +1351,10 @@ def _wrap_result_expand(self, result, expand=False):
1354
1351
else :
1355
1352
index = self ._orig .index
1356
1353
if expand :
1357
-
1358
- def cons_row (x ):
1359
- if is_list_like (x ):
1360
- return x
1361
- else :
1362
- return [x ]
1363
-
1364
1354
cons = self ._orig ._constructor_expanddim
1365
- data = [cons_row (x ) for x in result ]
1366
- return cons (data , index = index )
1355
+ return cons (result , index = index )
1367
1356
else :
1368
- name = getattr ( result , 'name' , None )
1357
+ # Must a Series
1369
1358
cons = self ._orig ._constructor
1370
1359
return cons (result , name = name , index = index )
1371
1360
@@ -1380,12 +1369,12 @@ def cat(self, others=None, sep=None, na_rep=None):
1380
1369
@copy (str_split )
1381
1370
def split (self , pat = None , n = - 1 , expand = False ):
1382
1371
result = str_split (self ._data , pat , n = n )
1383
- return self ._wrap_result_expand (result , expand = expand )
1372
+ return self ._wrap_result (result , expand = expand )
1384
1373
1385
1374
@copy (str_rsplit )
1386
1375
def rsplit (self , pat = None , n = - 1 , expand = False ):
1387
1376
result = str_rsplit (self ._data , pat , n = n )
1388
- return self ._wrap_result_expand (result , expand = expand )
1377
+ return self ._wrap_result (result , expand = expand )
1389
1378
1390
1379
_shared_docs ['str_partition' ] = ("""
1391
1380
Split the string at the %(side)s occurrence of `sep`, and return 3 elements
@@ -1440,7 +1429,7 @@ def rsplit(self, pat=None, n=-1, expand=False):
1440
1429
def partition (self , pat = ' ' , expand = True ):
1441
1430
f = lambda x : x .partition (pat )
1442
1431
result = _na_map (f , self ._data )
1443
- return self ._wrap_result_expand (result , expand = expand )
1432
+ return self ._wrap_result (result , expand = expand )
1444
1433
1445
1434
@Appender (_shared_docs ['str_partition' ] % {
1446
1435
'side' : 'last' ,
@@ -1451,7 +1440,7 @@ def partition(self, pat=' ', expand=True):
1451
1440
def rpartition (self , pat = ' ' , expand = True ):
1452
1441
f = lambda x : x .rpartition (pat )
1453
1442
result = _na_map (f , self ._data )
1454
- return self ._wrap_result_expand (result , expand = expand )
1443
+ return self ._wrap_result (result , expand = expand )
1455
1444
1456
1445
@copy (str_get )
1457
1446
def get (self , i ):
@@ -1597,7 +1586,8 @@ def get_dummies(self, sep='|'):
1597
1586
# methods available for making the dummies...
1598
1587
data = self ._orig .astype (str ) if self ._is_categorical else self ._data
1599
1588
result = str_get_dummies (data , sep )
1600
- return self ._wrap_result (result , use_codes = (not self ._is_categorical ))
1589
+ return self ._wrap_result (result , use_codes = (not self ._is_categorical ),
1590
+ expand = True )
1601
1591
1602
1592
@copy (str_translate )
1603
1593
def translate (self , table , deletechars = None ):
0 commit comments