Skip to content

Commit bac025f

Browse files
Merge pull request #10386 from jorisvandenbossche/get-schema-keys
BUG: fix multiple columns as primary key in io.sql.get_schema (GH10385)
2 parents 5455aca + f174c98 commit bac025f

File tree

3 files changed

+34
-12
lines changed

3 files changed

+34
-12
lines changed

doc/source/whatsnew/v0.17.0.txt

+2 -1
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,8 @@ Bug Fixes
120120

121121
- Bug in ``DataFrame.interpolate`` with ``axis=1`` and ``inplace=True`` (:issue:`10395`)
122122

123-
123+
- Bug in ``io.sql.get_schema`` when specifying multiple columns as primary
124+
key (:issue:`10385`).
124125

125126

126127
- Bug in ``test_categorical`` on big-endian builds (:issue:`10425`)

pandas/io/sql.py

+16 -8
Original file line number | Diff line number | Diff line change
@@ -834,7 +834,11 @@ def _create_table_setup(self):
834834
for name, typ, is_index in column_names_and_types]
835835

836836
if self.keys is not None:
837-
pkc = PrimaryKeyConstraint(self.keys, name=self.name + '_pk')
837+
if not com.is_list_like(self.keys):
838+
keys = [self.keys]
839+
else:
840+
keys = self.keys
841+
pkc = PrimaryKeyConstraint(*keys, name=self.name + '_pk')
838842
columns.append(pkc)
839843

840844
schema = self.schema or self.pd_sql.meta.schema
@@ -899,8 +903,8 @@ def _harmonize_columns(self, parse_dates=None):
899903

900904
def _get_notnull_col_dtype(self, col):
901905
"""
902-
Infer datatype of the Series col. In case the dtype of col is 'object'
903-
and it contains NA values, this infers the datatype of the not-NA
906+
Infer datatype of the Series col. In case the dtype of col is 'object'
907+
and it contains NA values, this infers the datatype of the not-NA
904908
values. Needed for inserting typed data containing NULLs, GH8778.
905909
"""
906910
col_for_inference = col
@@ -1272,7 +1276,7 @@ def _get_unicode_name(name):
12721276
return uname
12731277

12741278
def _get_valid_mysql_name(name):
1275-
# Filter for unquoted identifiers
1279+
# Filter for unquoted identifiers
12761280
# See http://dev.mysql.com/doc/refman/5.0/en/identifiers.html
12771281
uname = _get_unicode_name(name)
12781282
if not len(uname):
@@ -1293,7 +1297,7 @@ def _get_valid_sqlite_name(name):
12931297
# Ensure the string does not include any NUL characters.
12941298
# Replace all " with "".
12951299
# Wrap the entire thing in double quotes.
1296-
1300+
12971301
uname = _get_unicode_name(name)
12981302
if not len(uname):
12991303
raise ValueError("Empty table or column name specified")
@@ -1377,7 +1381,11 @@ def _create_table_setup(self):
13771381
for cname, ctype, _ in column_names_and_types]
13781382

13791383
if self.keys is not None and len(self.keys):
1380-
cnames_br = ",".join([escape(c) for c in self.keys])
1384+
if not com.is_list_like(self.keys):
1385+
keys = [self.keys]
1386+
else:
1387+
keys = self.keys
1388+
cnames_br = ", ".join([escape(c) for c in keys])
13811389
create_tbl_stmts.append(
13821390
"CONSTRAINT {tbl}_pk PRIMARY KEY ({cnames_br})".format(
13831391
tbl=self.name, cnames_br=cnames_br))
@@ -1391,7 +1399,7 @@ def _create_table_setup(self):
13911399
cnames = "_".join(ix_cols)
13921400
cnames_br = ",".join([escape(c) for c in ix_cols])
13931401
create_stmts.append(
1394-
"CREATE INDEX " + escape("ix_"+self.name+"_"+cnames) +
1402+
"CREATE INDEX " + escape("ix_"+self.name+"_"+cnames) +
13951403
"ON " + escape(self.name) + " (" + cnames_br + ")")
13961404

13971405
return create_stmts
@@ -1416,7 +1424,7 @@ def _sql_type_name(self, col):
14161424

14171425
elif col_type == "complex":
14181426
raise ValueError('Complex datatypes not supported')
1419-
1427+
14201428
if col_type not in _SQL_TYPES:
14211429
col_type = "string"
14221430

pandas/io/tests/test_sql.py

+16 -3
Original file line number | Diff line number | Diff line change
@@ -703,6 +703,19 @@ def test_get_schema_dtypes(self):
703703
self.assertTrue('CREATE' in create_sql)
704704
self.assertTrue('INTEGER' in create_sql)
705705

706+
def test_get_schema_keys(self):
707+
frame = DataFrame({'Col1':[1.1,1.2], 'Col2':[2.1,2.2]})
708+
create_sql = sql.get_schema(frame, 'test', 'sqlite',
709+
con=self.conn, keys='Col1')
710+
constraint_sentence = 'CONSTRAINT test_pk PRIMARY KEY ("Col1")'
711+
self.assertTrue(constraint_sentence in create_sql)
712+
713+
# multiple columns as key (GH10385)
714+
create_sql = sql.get_schema(self.test_frame1, 'test', 'sqlite',
715+
con=self.conn, keys=['A', 'B'])
716+
constraint_sentence = 'CONSTRAINT test_pk PRIMARY KEY ("A", "B")'
717+
self.assertTrue(constraint_sentence in create_sql)
718+
706719
def test_chunksize_read(self):
707720
df = DataFrame(np.random.randn(22, 5), columns=list('abcde'))
708721
df.to_sql('test_chunksize', self.conn, index=False)
@@ -1851,7 +1864,7 @@ def test_illegal_names(self):
18511864
df2 = DataFrame([[1, 2], [3, 4]], columns=['a', ok_name])
18521865
c_tbl = 'test_ok_col_name%d'%ndx
18531866
df2.to_sql(c_tbl, self.conn, flavor=self.flavor, index=False,
1854-
if_exists='replace')
1867+
if_exists='replace')
18551868
self.conn.cursor().execute("DROP TABLE `%s`" % c_tbl)
18561869
self.conn.commit()
18571870

@@ -1962,7 +1975,7 @@ def test_schema(self):
19621975
frame = tm.makeTimeDataFrame()
19631976
create_sql = sql.get_schema(frame, 'test', 'sqlite', keys=['A', 'B'],)
19641977
lines = create_sql.splitlines()
1965-
self.assertTrue('PRIMARY KEY ("A","B")' in create_sql)
1978+
self.assertTrue('PRIMARY KEY ("A", "B")' in create_sql)
19661979
cur = self.db.cursor()
19671980
cur.execute(create_sql)
19681981

@@ -2277,7 +2290,7 @@ def test_schema(self):
22772290
drop_sql = "DROP TABLE IF EXISTS test"
22782291
create_sql = sql.get_schema(frame, 'test', 'mysql', keys=['A', 'B'],)
22792292
lines = create_sql.splitlines()
2280-
self.assertTrue('PRIMARY KEY (`A`,`B`)' in create_sql)
2293+
self.assertTrue('PRIMARY KEY (`A`, `B`)' in create_sql)
22812294
cur = self.db.cursor()
22822295
cur.execute(drop_sql)
22832296
cur.execute(create_sql)

0 commit comments

Comments
 (0)