@@ -749,10 +749,14 @@ def __init__(self, objs, axis=0, join='outer', join_axes=None,
749
749
self .new_axes = self ._get_new_axes ()
750
750
751
751
def get_result (self ):
752
- if self ._is_series :
752
+ if self ._is_series and self . axis == 0 :
753
753
new_data = np .concatenate ([x .values for x in self .objs ])
754
754
name = _consensus_name_attr (self .objs )
755
755
return Series (new_data , index = self .new_axes [0 ], name = name )
756
+ elif self ._is_series :
757
+ data = dict (zip (self .new_axes [1 ], self .objs ))
758
+ return DataFrame (data , index = self .new_axes [0 ],
759
+ columns = self .new_axes [1 ])
756
760
else :
757
761
new_data = self ._get_concatenated_data ()
758
762
return self .objs [0 ]._from_axes (new_data , self .new_axes )
@@ -864,8 +868,14 @@ def _concat_single_item(self, item):
864
868
assert (self .axis >= 1 )
865
869
return np .concatenate (to_concat , axis = self .axis - 1 )
866
870
871
+ def _get_result_dim (self ):
872
+ if self ._is_series and self .axis == 1 :
873
+ return 2
874
+ else :
875
+ return self .objs [0 ].ndim
876
+
867
877
def _get_new_axes (self ):
868
- ndim = self .objs [ 0 ]. ndim
878
+ ndim = self ._get_result_dim ()
869
879
new_axes = [None ] * ndim
870
880
871
881
if self .ignore_index :
@@ -879,11 +889,7 @@ def _get_new_axes(self):
879
889
for i in range (ndim ):
880
890
if i == self .axis :
881
891
continue
882
- all_indexes = [x ._data .axes [i ] for x in self .objs ]
883
- comb_axis = _get_combined_index (all_indexes ,
884
- intersect = self .intersect )
885
- new_axes [i ] = comb_axis
886
-
892
+ new_axes [i ] = self ._get_comb_axis (i )
887
893
else :
888
894
assert (len (self .join_axes ) == ndim - 1 )
889
895
@@ -896,9 +902,22 @@ def _get_new_axes(self):
896
902
897
903
return new_axes
898
904
905
+ def _get_comb_axis (self , i ):
906
+ if self ._is_series :
907
+ all_indexes = [x .index for x in self .objs ]
908
+ else :
909
+ all_indexes = [x ._data .axes [i ] for x in self .objs ]
910
+
911
+ return _get_combined_index (all_indexes , intersect = self .intersect )
912
+
899
913
def _get_concat_axis (self ):
900
914
if self ._is_series :
901
- indexes = [x .index for x in self .objs ]
915
+ if self .axis == 0 :
916
+ indexes = [x .index for x in self .objs ]
917
+ elif self .keys is None :
918
+ return Index (np .arange (len (self .objs )))
919
+ else :
920
+ return _ensure_index (self .keys )
902
921
else :
903
922
indexes = [x ._data .axes [self .axis ] for x in self .objs ]
904
923
0 commit comments