diff --git a/doc/source/whatsnew/v0.17.0.txt b/doc/source/whatsnew/v0.17.0.txt index 09a39a6d9b2f5..cd41c4fc82146 100644 --- a/doc/source/whatsnew/v0.17.0.txt +++ b/doc/source/whatsnew/v0.17.0.txt @@ -120,7 +120,8 @@ Bug Fixes - Bug in ``DataFrame.interpolate`` with ``axis=1`` and ``inplace=True`` (:issue:`10395`) - +- Bug in ``io.sql.get_schema`` when specifying multiple columns as primary + key (:issue:`10385`). - Bug in ``test_categorical`` on big-endian builds (:issue:`10425`) diff --git a/pandas/io/sql.py b/pandas/io/sql.py index b4e8c7de2b4e1..8d8768c08fe02 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -834,7 +834,11 @@ def _create_table_setup(self): for name, typ, is_index in column_names_and_types] if self.keys is not None: - pkc = PrimaryKeyConstraint(self.keys, name=self.name + '_pk') + if not com.is_list_like(self.keys): + keys = [self.keys] + else: + keys = self.keys + pkc = PrimaryKeyConstraint(*keys, name=self.name + '_pk') columns.append(pkc) schema = self.schema or self.pd_sql.meta.schema @@ -899,8 +903,8 @@ def _harmonize_columns(self, parse_dates=None): def _get_notnull_col_dtype(self, col): """ - Infer datatype of the Series col. In case the dtype of col is 'object' - and it contains NA values, this infers the datatype of the not-NA + Infer datatype of the Series col. In case the dtype of col is 'object' + and it contains NA values, this infers the datatype of the not-NA values. Needed for inserting typed data containing NULLs, GH8778. """ col_for_inference = col @@ -1272,7 +1276,7 @@ def _get_unicode_name(name): return uname def _get_valid_mysql_name(name): - # Filter for unquoted identifiers + # Filter for unquoted identifiers # See http://dev.mysql.com/doc/refman/5.0/en/identifiers.html uname = _get_unicode_name(name) if not len(uname): @@ -1293,7 +1297,7 @@ def _get_valid_sqlite_name(name): # Ensure the string does not include any NUL characters. # Replace all " with "". # Wrap the entire thing in double quotes. - + uname = _get_unicode_name(name) if not len(uname): raise ValueError("Empty table or column name specified") @@ -1377,7 +1381,11 @@ def _create_table_setup(self): for cname, ctype, _ in column_names_and_types] if self.keys is not None and len(self.keys): - cnames_br = ",".join([escape(c) for c in self.keys]) + if not com.is_list_like(self.keys): + keys = [self.keys] + else: + keys = self.keys + cnames_br = ", ".join([escape(c) for c in keys]) create_tbl_stmts.append( "CONSTRAINT {tbl}_pk PRIMARY KEY ({cnames_br})".format( tbl=self.name, cnames_br=cnames_br)) @@ -1391,7 +1399,7 @@ def _create_table_setup(self): cnames = "_".join(ix_cols) cnames_br = ",".join([escape(c) for c in ix_cols]) create_stmts.append( - "CREATE INDEX " + escape("ix_"+self.name+"_"+cnames) + + "CREATE INDEX " + escape("ix_"+self.name+"_"+cnames) + "ON " + escape(self.name) + " (" + cnames_br + ")") return create_stmts @@ -1416,7 +1424,7 @@ def _sql_type_name(self, col): elif col_type == "complex": raise ValueError('Complex datatypes not supported') - + if col_type not in _SQL_TYPES: col_type = "string" diff --git a/pandas/io/tests/test_sql.py b/pandas/io/tests/test_sql.py index 33ea63ba41f1f..d8bc3c61f68f0 100644 --- a/pandas/io/tests/test_sql.py +++ b/pandas/io/tests/test_sql.py @@ -703,6 +703,19 @@ def test_get_schema_dtypes(self): self.assertTrue('CREATE' in create_sql) self.assertTrue('INTEGER' in create_sql) + def test_get_schema_keys(self): + frame = DataFrame({'Col1':[1.1,1.2], 'Col2':[2.1,2.2]}) + create_sql = sql.get_schema(frame, 'test', 'sqlite', + con=self.conn, keys='Col1') + constraint_sentence = 'CONSTRAINT test_pk PRIMARY KEY ("Col1")' + self.assertTrue(constraint_sentence in create_sql) + + # multiple columns as key (GH10385) + create_sql = sql.get_schema(self.test_frame1, 'test', 'sqlite', + con=self.conn, keys=['A', 'B']) + constraint_sentence = 'CONSTRAINT test_pk PRIMARY KEY ("A", "B")' + self.assertTrue(constraint_sentence in create_sql) + def test_chunksize_read(self): df = DataFrame(np.random.randn(22, 5), columns=list('abcde')) df.to_sql('test_chunksize', self.conn, index=False) @@ -1851,7 +1864,7 @@ def test_illegal_names(self): df2 = DataFrame([[1, 2], [3, 4]], columns=['a', ok_name]) c_tbl = 'test_ok_col_name%d'%ndx df2.to_sql(c_tbl, self.conn, flavor=self.flavor, index=False, - if_exists='replace') + if_exists='replace') self.conn.cursor().execute("DROP TABLE `%s`" % c_tbl) self.conn.commit() @@ -1962,7 +1975,7 @@ def test_schema(self): frame = tm.makeTimeDataFrame() create_sql = sql.get_schema(frame, 'test', 'sqlite', keys=['A', 'B'],) lines = create_sql.splitlines() - self.assertTrue('PRIMARY KEY ("A","B")' in create_sql) + self.assertTrue('PRIMARY KEY ("A", "B")' in create_sql) cur = self.db.cursor() cur.execute(create_sql) @@ -2277,7 +2290,7 @@ def test_schema(self): drop_sql = "DROP TABLE IF EXISTS test" create_sql = sql.get_schema(frame, 'test', 'mysql', keys=['A', 'B'],) lines = create_sql.splitlines() - self.assertTrue('PRIMARY KEY (`A`,`B`)' in create_sql) + self.assertTrue('PRIMARY KEY (`A`, `B`)' in create_sql) cur = self.db.cursor() cur.execute(drop_sql) cur.execute(create_sql)