@@ -664,20 +664,28 @@ def _index_name(self, index, index_label):
664
664
else :
665
665
return None
666
666
667
+ def _get_column_names_and_types (self , dtype_mapper ):
668
+ column_names_and_types = []
669
+ if self .index is not None :
670
+ for i , idx_label in enumerate (self .index ):
671
+ idx_type = dtype_mapper (
672
+ self .frame .index .get_level_values (i ).dtype )
673
+ column_names_and_types .append ((idx_label , idx_type ))
674
+
675
+ column_names_and_types += zip (
676
+ list (map (str , self .frame .columns )),
677
+ map (dtype_mapper , self .frame .dtypes )
678
+ )
679
+ return column_names_and_types
680
+
667
681
def _create_table_statement (self ):
668
682
from sqlalchemy import Table , Column
669
683
670
- columns = list ( map ( str , self . frame . columns ))
671
- column_types = map ( self ._sqlalchemy_type , self .frame . dtypes )
684
+ column_names_and_types = \
685
+ self ._get_column_names_and_types ( self ._sqlalchemy_type )
672
686
673
687
columns = [Column (name , typ )
674
- for name , typ in zip (columns , column_types )]
675
-
676
- if self .index is not None :
677
- for i , idx_label in enumerate (self .index [::- 1 ]):
678
- idx_type = self ._sqlalchemy_type (
679
- self .frame .index .get_level_values (i ))
680
- columns .insert (0 , Column (idx_label , idx_type , index = True ))
688
+ for name , typ in column_names_and_types ]
681
689
682
690
return Table (self .name , self .pd_sql .meta , * columns )
683
691
@@ -957,16 +965,13 @@ def insert(self):
957
965
def _create_table_statement (self ):
958
966
"Return a CREATE TABLE statement to suit the contents of a DataFrame."
959
967
960
- columns = list (map (str , self .frame .columns ))
968
+ column_names_and_types = \
969
+ self ._get_column_names_and_types (self ._sql_type_name )
970
+
961
971
pat = re .compile ('\s+' )
962
- if any (map (pat .search , columns )):
972
+ column_names = [col_name for col_name , _ in column_names_and_types ]
973
+ if any (map (pat .search , column_names )):
963
974
warnings .warn (_SAFE_NAMES_WARNING )
964
- column_types = [self ._sql_type_name (typ ) for typ in self .frame .dtypes ]
965
-
966
- if self .index is not None :
967
- for i , idx_label in enumerate (self .index [::- 1 ]):
968
- columns .insert (0 , idx_label )
969
- column_types .insert (0 , self ._sql_type_name (self .frame .index .get_level_values (i ).dtype ))
970
975
971
976
flv = self .pd_sql .flavor
972
977
@@ -976,7 +981,7 @@ def _create_table_statement(self):
976
981
col_template = br_l + '%s' + br_r + ' %s'
977
982
978
983
columns = ',\n ' .join (col_template %
979
- x for x in zip ( columns , column_types ) )
984
+ x for x in column_names_and_types )
980
985
template = """CREATE TABLE %(name)s (
981
986
%(columns)s
982
987
)"""
0 commit comments