Skip to content

BUG: fix multiple columns as primary key in io.sql.get_schema (GH10385) #10386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/source/whatsnew/v0.17.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ Bug Fixes

- Bug in ``DataFrame.interpolate`` with ``axis=1`` and ``inplace=True`` (:issue:`10395`)


- Bug in ``io.sql.get_schema`` when specifying multiple columns as primary
key (:issue:`10385`).


- Bug in ``test_categorical`` on big-endian builds (:issue:`10425`)
Expand Down
24 changes: 16 additions & 8 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,11 @@ def _create_table_setup(self):
for name, typ, is_index in column_names_and_types]

if self.keys is not None:
pkc = PrimaryKeyConstraint(self.keys, name=self.name + '_pk')
if not com.is_list_like(self.keys):
keys = [self.keys]
else:
keys = self.keys
pkc = PrimaryKeyConstraint(*keys, name=self.name + '_pk')
columns.append(pkc)

schema = self.schema or self.pd_sql.meta.schema
Expand Down Expand Up @@ -899,8 +903,8 @@ def _harmonize_columns(self, parse_dates=None):

def _get_notnull_col_dtype(self, col):
"""
Infer datatype of the Series col. In case the dtype of col is 'object'
and it contains NA values, this infers the datatype of the not-NA
Infer datatype of the Series col. In case the dtype of col is 'object'
and it contains NA values, this infers the datatype of the not-NA
values. Needed for inserting typed data containing NULLs, GH8778.
"""
col_for_inference = col
Expand Down Expand Up @@ -1272,7 +1276,7 @@ def _get_unicode_name(name):
return uname

def _get_valid_mysql_name(name):
# Filter for unquoted identifiers
# Filter for unquoted identifiers
# See http://dev.mysql.com/doc/refman/5.0/en/identifiers.html
uname = _get_unicode_name(name)
if not len(uname):
Expand All @@ -1293,7 +1297,7 @@ def _get_valid_sqlite_name(name):
# Ensure the string does not include any NUL characters.
# Replace all " with "".
# Wrap the entire thing in double quotes.

uname = _get_unicode_name(name)
if not len(uname):
raise ValueError("Empty table or column name specified")
Expand Down Expand Up @@ -1377,7 +1381,11 @@ def _create_table_setup(self):
for cname, ctype, _ in column_names_and_types]

if self.keys is not None and len(self.keys):
cnames_br = ",".join([escape(c) for c in self.keys])
if not com.is_list_like(self.keys):
keys = [self.keys]
else:
keys = self.keys
cnames_br = ", ".join([escape(c) for c in keys])
create_tbl_stmts.append(
"CONSTRAINT {tbl}_pk PRIMARY KEY ({cnames_br})".format(
tbl=self.name, cnames_br=cnames_br))
Expand All @@ -1391,7 +1399,7 @@ def _create_table_setup(self):
cnames = "_".join(ix_cols)
cnames_br = ",".join([escape(c) for c in ix_cols])
create_stmts.append(
"CREATE INDEX " + escape("ix_"+self.name+"_"+cnames) +
"CREATE INDEX " + escape("ix_"+self.name+"_"+cnames) +
"ON " + escape(self.name) + " (" + cnames_br + ")")

return create_stmts
Expand All @@ -1416,7 +1424,7 @@ def _sql_type_name(self, col):

elif col_type == "complex":
raise ValueError('Complex datatypes not supported')

if col_type not in _SQL_TYPES:
col_type = "string"

Expand Down
19 changes: 16 additions & 3 deletions pandas/io/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,19 @@ def test_get_schema_dtypes(self):
self.assertTrue('CREATE' in create_sql)
self.assertTrue('INTEGER' in create_sql)

def test_get_schema_keys(self):
    """get_schema should emit a PRIMARY KEY constraint for a single key
    column and for a list of key columns (GH10385)."""
    df = DataFrame({'Col1': [1.1, 1.2], 'Col2': [2.1, 2.2]})
    ddl = sql.get_schema(df, 'test', 'sqlite',
                         con=self.conn, keys='Col1')
    # single column passed as a bare string
    self.assertTrue('CONSTRAINT test_pk PRIMARY KEY ("Col1")' in ddl)

    # multiple columns as key (GH10385)
    ddl = sql.get_schema(self.test_frame1, 'test', 'sqlite',
                         con=self.conn, keys=['A', 'B'])
    self.assertTrue('CONSTRAINT test_pk PRIMARY KEY ("A", "B")' in ddl)

def test_chunksize_read(self):
df = DataFrame(np.random.randn(22, 5), columns=list('abcde'))
df.to_sql('test_chunksize', self.conn, index=False)
Expand Down Expand Up @@ -1851,7 +1864,7 @@ def test_illegal_names(self):
df2 = DataFrame([[1, 2], [3, 4]], columns=['a', ok_name])
c_tbl = 'test_ok_col_name%d'%ndx
df2.to_sql(c_tbl, self.conn, flavor=self.flavor, index=False,
if_exists='replace')
if_exists='replace')
self.conn.cursor().execute("DROP TABLE `%s`" % c_tbl)
self.conn.commit()

Expand Down Expand Up @@ -1962,7 +1975,7 @@ def test_schema(self):
frame = tm.makeTimeDataFrame()
create_sql = sql.get_schema(frame, 'test', 'sqlite', keys=['A', 'B'],)
lines = create_sql.splitlines()
self.assertTrue('PRIMARY KEY ("A","B")' in create_sql)
self.assertTrue('PRIMARY KEY ("A", "B")' in create_sql)
cur = self.db.cursor()
cur.execute(create_sql)

Expand Down Expand Up @@ -2277,7 +2290,7 @@ def test_schema(self):
drop_sql = "DROP TABLE IF EXISTS test"
create_sql = sql.get_schema(frame, 'test', 'mysql', keys=['A', 'B'],)
lines = create_sql.splitlines()
self.assertTrue('PRIMARY KEY (`A`,`B`)' in create_sql)
self.assertTrue('PRIMARY KEY (`A`, `B`)' in create_sql)
cur = self.db.cursor()
cur.execute(drop_sql)
cur.execute(create_sql)
Expand Down