From 5065232e16b88558a891d72e35796a6f60ba6264 Mon Sep 17 00:00:00 2001
From: Artemy Kolchinsky
Date: Wed, 20 Aug 2014 12:22:21 -0700
Subject: [PATCH] Rewrote/refactored get_schema to use methods from table
 classes (GH8232)

Fix creation of database indexes
Tests fix
Fixes
Remove _executefunc
Fixed tests to work with Python3
Schema rewriting
Fixes
Doc update
Fixed postgres schema errors
Fixes
Fixes
Moving if_exists checks to create
Removing executescript
Refactor insert differently
Code review fixes
Removed meta.schema from read_table
Merges
Merged #8208
Replacing izip with zip for Python3
---
 pandas/io/sql.py            | 220 ++++++++++++++++--------------------
 pandas/io/tests/test_sql.py |  13 ++-
 2 files changed, 106 insertions(+), 127 deletions(-)

diff --git a/pandas/io/sql.py b/pandas/io/sql.py
index 462179b442ac0..83b96d5186dd2 100644
--- a/pandas/io/sql.py
+++ b/pandas/io/sql.py
@@ -552,33 +552,19 @@ class PandasSQLTable(PandasObject):
     # TODO: support for multiIndex
     def __init__(self, name, pandas_sql_engine, frame=None, index=True,
                  if_exists='fail', prefix='pandas', index_label=None,
-                 schema=None):
+                 schema=None, keys=None):
         self.name = name
         self.pd_sql = pandas_sql_engine
         self.prefix = prefix
         self.frame = frame
         self.index = self._index_name(index, index_label)
         self.schema = schema
+        self.if_exists = if_exists
+        self.keys = keys
 
         if frame is not None:
-            # We want to write a frame
-            if self.pd_sql.has_table(self.name, self.schema):
-                if if_exists == 'fail':
-                    raise ValueError("Table '%s' already exists." % name)
-                elif if_exists == 'replace':
-                    self.pd_sql.drop_table(self.name, self.schema)
-                    self.table = self._create_table_setup()
-                    self.create()
-                elif if_exists == 'append':
-                    self.table = self.pd_sql.get_table(self.name, self.schema)
-                    if self.table is None:
-                        self.table = self._create_table_setup()
-                else:
-                    raise ValueError(
-                        "'{0}' is not valid for if_exists".format(if_exists))
-            else:
-                self.table = self._create_table_setup()
-                self.create()
+            # We want to initialize based on a dataframe
+            self.table = self._create_table_setup()
         else:
             # no data provided, read-only mode
             self.table = self.pd_sql.get_table(self.name, self.schema)
@@ -593,9 +579,26 @@ def sql_schema(self):
         from sqlalchemy.schema import CreateTable
         return str(CreateTable(self.table))
 
-    def create(self):
+    def _execute_create(self):
+        # Inserting table into database, add to MetaData object
+        self.table = self.table.tometadata(self.pd_sql.meta)
         self.table.create()
 
+    def create(self):
+        if self.exists():
+            if self.if_exists == 'fail':
+                raise ValueError("Table '%s' already exists." % self.name)
+            elif self.if_exists == 'replace':
+                self.pd_sql.drop_table(self.name, self.schema)
+                self._execute_create()
+            elif self.if_exists == 'append':
+                pass
+            else:
+                raise ValueError(
+                    "'{0}' is not valid for if_exists".format(self.if_exists))
+        else:
+            self._execute_create()
+
     def insert_statement(self):
         return self.table.insert()

@@ -634,9 +637,15 @@ def insert_data(self):
 
         return column_names, data_list
 
-    def insert(self, chunksize=None):
+    def get_session(self):
+        con = self.pd_sql.engine.connect()
+        return con.begin()
 
-        ins = self.insert_statement()
+    def _execute_insert(self, trans, keys, data_iter):
+        data = [dict( (k, v) for k, v in zip(keys, row) ) for row in data_iter]
+        trans.connection.execute(self.insert_statement(), data)
+
+    def insert(self, chunksize=None):
         keys, data_list = self.insert_data()
 
         nrows = len(self.frame)
@@ -644,18 +653,15 @@ def insert(self, chunksize=None):
             chunksize = nrows
         chunks = int(nrows / chunksize) + 1
 
-        con = self.pd_sql.engine.connect()
-        with con.begin() as trans:
+        with self.get_session() as trans:
             for i in range(chunks):
                 start_i = i * chunksize
                 end_i = min((i + 1) * chunksize, nrows)
                 if start_i >= end_i:
                     break
 
-                chunk_list = [arr[start_i:end_i] for arr in data_list]
-                insert_list = [dict((k, v) for k, v in zip(keys, row))
-                               for row in zip(*chunk_list)]
-                con.execute(ins, insert_list)
+                chunk_iter = zip(*[arr[start_i:end_i] for arr in data_list])
+                self._execute_insert(trans, keys, chunk_iter)
 
     def read(self, coerce_float=True, parse_dates=None, columns=None):

@@ -729,7 +735,7 @@ def _get_column_names_and_types(self, dtype_mapper):
         return column_names_and_types
 
     def _create_table_setup(self):
-        from sqlalchemy import Table, Column
+        from sqlalchemy import Table, Column, PrimaryKeyConstraint
 
         column_names_and_types = \
             self._get_column_names_and_types(self._sqlalchemy_type)
@@ -737,7 +743,19 @@ def _create_table_setup(self):
         columns = [Column(name, typ, index=is_index)
                    for name, typ, is_index in column_names_and_types]
 
-        return Table(self.name, self.pd_sql.meta, *columns, schema=self.schema)
+        if self.keys is not None:
+            columns.append(PrimaryKeyConstraint(self.keys,
+                                                name=self.name+'_pk'))
+
+
+        schema = self.schema or self.pd_sql.meta.schema
+
+        # At this point, attach to new metadata, only attach to self.meta
+        # once table is created.
+        from sqlalchemy.schema import MetaData
+        meta = MetaData(self.pd_sql, schema=schema)
+
+        return Table(self.name, meta, *columns, schema=schema)
 
     def _harmonize_columns(self, parse_dates=None):
         """ Make a data_frame's column type align with an sql_table
@@ -872,7 +890,6 @@ def execute(self, *args, **kwargs):
 
     def read_table(self, table_name, index_col=None, coerce_float=True,
                    parse_dates=None, columns=None, schema=None):
-
         table = PandasSQLTable(
             table_name, self, index=index_col, schema=schema)
         return table.read(coerce_float=coerce_float,
@@ -901,6 +918,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
         table = PandasSQLTable(
             name, self, frame=frame, index=index, if_exists=if_exists,
             index_label=index_label, schema=schema)
+        table.create()
         table.insert(chunksize)
         # check for potentially case sensitivity issues (GH7815)
         if name not in self.engine.table_names(schema=schema or self.meta.schema):
@@ -930,8 +948,9 @@ def drop_table(self, table_name, schema=None):
             self.get_table(table_name, schema).drop()
             self.meta.clear()
 
-    def _create_sql_schema(self, frame, table_name):
-        table = PandasSQLTable(table_name, self, frame=frame)
+    def _create_sql_schema(self, frame, table_name, keys=None):
+        table = PandasSQLTable(table_name, self, frame=frame, index=False,
+                               keys=keys)
         return str(table.sql_schema())


@@ -997,8 +1016,8 @@ class PandasSQLTableLegacy(PandasSQLTable):
     def sql_schema(self):
         return str(";\n".join(self.table))
 
-    def create(self):
-        with self.pd_sql.con:
+    def _execute_create(self):
+        with self.get_session():
             for stmt in self.table:
                 self.pd_sql.execute(stmt)

@@ -1019,28 +1038,12 @@ def insert_statement(self):
             self.name, col_names, wildcards)
         return insert_statement
 
-    def insert(self, chunksize=None):
-
-        ins = self.insert_statement()
-        keys, data_list = self.insert_data()
-
-        nrows = len(self.frame)
-        if chunksize is None:
-            chunksize = nrows
-        chunks = int(nrows / chunksize) + 1
+    def get_session(self):
+        return self.pd_sql.con
 
-        with self.pd_sql.con:
-            for i in range(chunks):
-                start_i = i * chunksize
-                end_i = min((i + 1) * chunksize, nrows)
-                if start_i >= end_i:
-                    break
-                chunk_list = [arr[start_i:end_i] for arr in data_list]
-                insert_list = [tuple((v for v in row))
-                               for row in zip(*chunk_list)]
-                cur = self.pd_sql.con.cursor()
-                cur.executemany(ins, insert_list)
-                cur.close()
+    def _execute_insert(self, trans, keys, data_iter):
+        data_list = list(data_iter)
+        trans.executemany(self.insert_statement(), data_list)
 
     def _create_table_setup(self):
         """Return a list of SQL statement that create a table reflecting the
@@ -1061,21 +1064,25 @@ def _create_table_setup(self):
         br_l = _SQL_SYMB[flv]['br_l']  # left val quote char
         br_r = _SQL_SYMB[flv]['br_r']  # right val quote char
 
-        col_template = br_l + '%s' + br_r + ' %s'
-
-        columns = ',\n  '.join(col_template % (cname, ctype)
-                               for cname, ctype, _ in column_names_and_types)
-        template = """CREATE TABLE %(name)s (
-                      %(columns)s
-                      )"""
-        create_stmts = [template % {'name': self.name, 'columns': columns}, ]
-
-        ix_tpl = "CREATE INDEX ix_{tbl}_{col} ON {tbl} ({br_l}{col}{br_r})"
-        for cname, _, is_index in column_names_and_types:
-            if not is_index:
-                continue
-            create_stmts.append(ix_tpl.format(tbl=self.name, col=cname,
-                                              br_l=br_l, br_r=br_r))
+        create_tbl_stmts = [(br_l + '%s' + br_r + ' %s') % (cname, ctype)
+                            for cname, ctype, _ in column_names_and_types]
+        if self.keys is not None and len(self.keys):
+            cnames_br = ",".join([br_l + c + br_r for c in self.keys])
+            create_tbl_stmts.append(
+                "CONSTRAINT {tbl}_pk PRIMARY KEY ({cnames_br})".format(
+                    tbl=self.name, cnames_br=cnames_br))
+
+        create_stmts = ["CREATE TABLE " + self.name + " (\n" +
+                        ',\n  '.join(create_tbl_stmts) + "\n)"]
+
+        ix_cols = [cname for cname, _, is_index in column_names_and_types
+                   if is_index]
+        if len(ix_cols):
+            cnames = "_".join(ix_cols)
+            cnames_br = ",".join([br_l + c + br_r for c in ix_cols])
+            create_stmts.append(
+                "CREATE INDEX ix_{tbl}_{cnames} ON {tbl} ({cnames_br})".format(
+                    tbl=self.name, cnames=cnames, cnames_br=cnames_br))
 
         return create_stmts

@@ -1172,16 +1179,28 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
         ----------
         frame: DataFrame
         name: name of SQL table
-        flavor: {'sqlite', 'mysql'}, default 'sqlite'
         if_exists: {'fail', 'replace', 'append'}, default 'fail'
             fail: If table exists, do nothing.
             replace: If table exists, drop it, recreate it, and insert data.
            append: If table exists, insert data. Create if does not exist.
+        index : boolean, default True
+            Write DataFrame index as a column
+        index_label : string or sequence, default None
+            Column label for index column(s). If None is given (default) and
+            `index` is True, then the index names are used.
+            A sequence should be given if the DataFrame uses MultiIndex.
+        schema : string, default None
+            Ignored parameter included for compatibility with SQLAlchemy version
+            of `to_sql`.
+        chunksize : int, default None
+            If not None, then rows will be written in batches of this size at a
+            time. If None, all rows will be written at once.
 
         """
         table = PandasSQLTableLegacy(
             name, self, frame=frame, index=index, if_exists=if_exists,
            index_label=index_label)
+        table.create()
         table.insert(chunksize)

@@ -1200,8 +1219,9 @@ def has_table(self, name, schema=None):
         drop_sql = "DROP TABLE %s" % name
         self.execute(drop_sql)
 
-    def _create_sql_schema(self, frame, table_name):
-        table = PandasSQLTableLegacy(table_name, self, frame=frame)
+    def _create_sql_schema(self, frame, table_name, keys=None):
+        table = PandasSQLTableLegacy(table_name, self, frame=frame, index=False,
+                                     keys=keys)
        return str(table.sql_schema())


@@ -1227,58 +1247,8 @@
 
     """
 
-    if con is None:
-        if flavor == 'mysql':
-            warnings.warn(_MYSQL_WARNING, FutureWarning)
-        return _get_schema_legacy(frame, name, flavor, keys)
-
     pandas_sql = pandasSQL_builder(con=con, flavor=flavor)
-    return pandas_sql._create_sql_schema(frame, name)
-
-
-def _get_schema_legacy(frame, name, flavor, keys=None):
-    """Old function from 0.13.1. To keep backwards compatibility.
-    When mysql legacy support is dropped, it should be possible to
-    remove this code
-    """
-
-    def get_sqltype(dtype, flavor):
-        pytype = dtype.type
-        pytype_name = "text"
-        if issubclass(pytype, np.floating):
-            pytype_name = "float"
-        elif issubclass(pytype, np.integer):
-            pytype_name = "int"
-        elif issubclass(pytype, np.datetime64) or pytype is datetime:
-            # Caution: np.datetime64 is also a subclass of np.number.
-            pytype_name = "datetime"
-        elif pytype is datetime.date:
-            pytype_name = "date"
-        elif issubclass(pytype, np.bool_):
-            pytype_name = "bool"
-
-        return _SQL_TYPES[pytype_name][flavor]
-
-    lookup_type = lambda dtype: get_sqltype(dtype, flavor)
-
-    column_types = lzip(frame.dtypes.index, map(lookup_type, frame.dtypes))
-    if flavor == 'sqlite':
-        columns = ',\n  '.join('[%s] %s' % x for x in column_types)
-    else:
-        columns = ',\n  '.join('`%s` %s' % x for x in column_types)
-
-    keystr = ''
-    if keys is not None:
-        if isinstance(keys, string_types):
-            keys = (keys,)
-        keystr = ', PRIMARY KEY (%s)' % ','.join(keys)
-    template = """CREATE TABLE %(name)s (
-                  %(columns)s
-                  %(keystr)s
-                  );"""
-    create_statement = template % {'name': name, 'columns': columns,
-                                   'keystr': keystr}
-    return create_statement
+    return pandas_sql._create_sql_schema(frame, name, keys=keys)


 # legacy names, with depreciation warnings and copied docs
diff --git a/pandas/io/tests/test_sql.py b/pandas/io/tests/test_sql.py
index 4a4b9da619b5f..80988ab2f5e1c 100644
--- a/pandas/io/tests/test_sql.py
+++ b/pandas/io/tests/test_sql.py
@@ -1449,6 +1449,15 @@ def _get_index_columns(self, tbl_name):
 
     def test_to_sql_save_index(self):
         self._to_sql_save_index()
+        for ix_name, ix_col in zip(ixs.Key_name, ixs.Column_name):
+            if ix_name not in ix_cols:
+                ix_cols[ix_name] = []
+            ix_cols[ix_name].append(ix_col)
+        return ix_cols.values()
+
+    def test_to_sql_save_index(self):
+        self._to_sql_save_index()
+

 #------------------------------------------------------------------------------
 #--- Old tests from 0.13.1 (before refactor using sqlalchemy)
@@ -1545,7 +1554,7 @@ def test_schema(self):
         frame = tm.makeTimeDataFrame()
         create_sql = sql.get_schema(frame, 'test', 'sqlite', keys=['A', 'B'],)
         lines = create_sql.splitlines()
-        self.assertTrue('PRIMARY KEY (A,B)' in create_sql)
+        self.assertTrue('PRIMARY KEY ([A],[B])' in create_sql)
         cur = self.db.cursor()
         cur.execute(create_sql)

@@ -1824,7 +1833,7 @@ def test_schema(self):
         drop_sql = "DROP TABLE IF EXISTS test"
         create_sql = sql.get_schema(frame, 'test', 'mysql', keys=['A', 'B'],)
         lines = create_sql.splitlines()
-        self.assertTrue('PRIMARY KEY (A,B)' in create_sql)
+        self.assertTrue('PRIMARY KEY (`A`,`B`)' in create_sql)
         cur = self.db.cursor()
         cur.execute(drop_sql)
         cur.execute(create_sql)
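
A minimal usage sketch of the refactored get_schema path (not part of the patch, assuming the pandas development version this patch targets); the DataFrame and its 'A'/'B' key columns are illustrative and mirror the updated legacy test_schema assertion above:

    import pandas as pd
    from pandas.io import sql

    # Illustrative frame; the 'A'/'B' key columns mirror test_schema above.
    frame = pd.DataFrame({'A': [1.0, 2.0], 'B': [3.0, 4.0]})

    # get_schema now delegates to the table classes via _create_sql_schema,
    # so the primary key is emitted as a named constraint with flavor-specific
    # quoting, e.g. "CONSTRAINT test_pk PRIMARY KEY ([A],[B])" for sqlite.
    create_sql = sql.get_schema(frame, 'test', 'sqlite', keys=['A', 'B'])
    print(create_sql)
    assert 'PRIMARY KEY ([A],[B])' in create_sql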