
Commit c65f396

eavidan authored and devin-petersohn committed
parallel to_sql() (modin-project#461)
* initial working implementation
* Fix issue in modin-project#453#issuecomment-461130825
* initial working implementation
* Fix issue in modin-project#453#issuecomment-461130825
* alignment with @devin-petersohn comments, documentation and unit tests
* to_sql removed from pandas_query_compiler, moved to io.py
* to_sql added to base/io.py to support other engines by defaulting to pandas
* align base io to_sql signature to contain qc as well
* fixed base io to_sql to first convert qc into pandas dataframe and then run to_sql()
* clean up tests to follow pytest best practices
* Undo unnecessary changes
* Fixed linting and made fixture into a factory
* more linting
1 parent 93131b0 commit c65f396
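
For context, the user-facing entry point being parallelized is the pandas-style DataFrame.to_sql. A minimal usage sketch (the file and table names are illustrative, and a SQLAlchemy-compatible URI is assumed, as in the tests below):

    import modin.pandas as pd

    df = pd.DataFrame({"col1": [0, 1, 2], "col2": [3, 4, 5]})
    # On the Ray engine, this commit makes each row partition write in parallel.
    df.to_sql("example_table", "sqlite:///example.db", index=False)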

File tree

6 files changed: +162 additions, -73 deletions


modin/data_management/factories.py

Lines changed: 8 additions & 0 deletions
@@ -154,6 +154,14 @@ def read_sql(cls, **kwargs):
     def _read_sql(cls, **kwargs):
         return cls.io_cls.read_sql(**kwargs)
 
+    @classmethod
+    def to_sql(cls, *args, **kwargs):
+        return cls._determine_engine()._to_sql(*args, **kwargs)
+
+    @classmethod
+    def _to_sql(cls, *args, **kwargs):
+        return cls.io_cls.to_sql(*args, **kwargs)
+
 
 class PandasOnRayFactory(BaseFactory):
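
These two classmethods mirror the existing read_* dispatch: the public entry point asks _determine_engine() for the active engine factory, and that factory's io_cls performs the actual write. A stripped-down sketch of the same double dispatch (FakeIO and FakeFactory are illustrative stand-ins, not Modin classes):

    import pandas


    class FakeIO:
        @classmethod
        def to_sql(cls, qc, **kwargs):
            # Stand-in for an engine-specific writer; here qc is just a pandas frame.
            qc.to_sql(**kwargs)


    class FakeFactory:
        io_cls = FakeIO

        @classmethod
        def _determine_engine(cls):
            return cls  # Modin selects this from the configured engine

        @classmethod
        def to_sql(cls, *args, **kwargs):
            return cls._determine_engine()._to_sql(*args, **kwargs)

        @classmethod
        def _to_sql(cls, *args, **kwargs):
            return cls.io_cls.to_sql(*args, **kwargs)


    FakeFactory.to_sql(pandas.DataFrame({"a": [1]}), name="demo", con="sqlite:///demo.db")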

modin/engines/base/io.py

Lines changed: 26 additions & 0 deletions
@@ -422,3 +422,29 @@ def read_sql(
                 chunksize=chunksize,
             )
         )
+
+    @classmethod
+    def to_sql(
+        cls,
+        qc,
+        name,
+        con,
+        schema=None,
+        if_exists="fail",
+        index=True,
+        index_label=None,
+        chunksize=None,
+        dtype=None,
+    ):
+        ErrorMessage.default_to_pandas("`to_sql`")
+        df = qc.to_pandas()
+        df.to_sql(
+            name=name,
+            con=con,
+            schema=schema,
+            if_exists=if_exists,
+            index=index,
+            index_label=index_label,
+            chunksize=chunksize,
+            dtype=dtype,
+        )
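
This default gives every engine a working, if sequential, to_sql: the query compiler is materialized into a single pandas DataFrame and pandas performs the whole write, with ErrorMessage.default_to_pandas emitting the usual defaulting-to-pandas warning. An engine whose IO class does not override the method simply inherits it; a hedged sketch, assuming the base class defined in this module is named BaseIO and using a made-up engine name:

    from modin.engines.base.io import BaseIO


    class SomeEngineIO(BaseIO):  # hypothetical engine, for illustration only
        # No override: calling to_sql warns, converts the query compiler with
        # to_pandas(), and writes the entire frame in one pandas to_sql call.
        pass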

modin/engines/ray/pandas_on_ray/io.py

Lines changed: 29 additions & 0 deletions
@@ -538,6 +538,35 @@ def read_feather(cls, path, nthreads=1, columns=None):
         )
         return new_query_compiler
 
+    @classmethod
+    def to_sql(cls, qc, **kwargs):
+        """Write records stored in a DataFrame to a SQL database.
+        Args:
+            qc: the query compiler of the DF that we want to run to_sql on
+            kwargs: parameters for pandas.to_sql(**kwargs)
+        """
+        # we first insert an empty DF in order to create the full table in the database
+        # This also helps to validate the input against pandas
+        # we would like to_sql() to complete only when all rows have been inserted into the database
+        # since the mapping operation is non-blocking, each partition will return an empty DF
+        # so at the end, the blocking operation will be this empty DF to_pandas
+
+        empty_df = qc.head(1).to_pandas().head(0)
+        empty_df.to_sql(**kwargs)
+        # so each partition will append its respective DF
+        kwargs["if_exists"] = "append"
+        columns = qc.columns
+
+        def func(df, **kwargs):
+            df.columns = columns
+            df.to_sql(**kwargs)
+            return pandas.DataFrame()
+
+        map_func = qc._prepare_method(func, **kwargs)
+        result = qc.map_across_full_axis(1, map_func)
+        # blocking operation
+        result.to_pandas()
+
 
 @ray.remote
 def get_index(index_name, *partition_indices):  # pragma: no cover
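
The strategy above is independent of Modin's internals and can be reproduced with plain pandas to see the mechanics: one zero-row write creates and validates the table, every partition then appends with if_exists="append", and a final blocking step waits for all appends to finish. A self-contained sketch under those assumptions (file, table, and function names are illustrative; Modin runs step 2 as one Ray task per row partition rather than a loop):

    import pandas


    def partitioned_to_sql(partitions, name, con, **kwargs):
        # 1. Create the table (and let pandas validate the arguments) with an empty frame.
        partitions[0].head(0).to_sql(name, con, **kwargs)
        # 2. Each partition appends its own rows; in Modin this is the parallel part.
        kwargs["if_exists"] = "append"
        for part in partitions:
            part.to_sql(name, con, **kwargs)
        # 3. Modin then materializes the workers' empty results to block until every
        #    append has finished; this sequential loop is already blocking.


    df = pandas.DataFrame({"col1": range(6), "col2": range(6, 12)})
    partitioned_to_sql([df.iloc[:3], df.iloc[3:]], "demo_table", "sqlite:///demo.db", index=False)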

modin/experimental/pandas/test/test_io_exp.py

Lines changed: 12 additions & 23 deletions
@@ -2,53 +2,42 @@
 import pandas
 import pytest
 import modin.experimental.pandas as pd
-
-
-from modin.pandas.test.test_io import (
-    setup_sql_file,
-    teardown_sql_file,
+from modin.pandas.test.test_io import (  # noqa: F401
     modin_df_equals_pandas,
+    make_sql_connection,
 )
 
 
-def test_from_sql_distributed():
+def test_from_sql_distributed(make_sql_connection):  # noqa: F811
     if os.environ.get("MODIN_ENGINE", "") == "Ray":
         filename = "test_from_sql_distributed.db"
-        teardown_sql_file(filename)
         table = "test_from_sql_distributed"
-        db_uri = "sqlite:///" + filename
-        setup_sql_file(db_uri, filename, table, True)
+        conn = make_sql_connection(filename, table)
         query = "select * from {0}".format(table)
 
-        pandas_df = pandas.read_sql(query, db_uri)
+        pandas_df = pandas.read_sql(query, conn)
         modin_df_from_query = pd.read_sql(
-            query, db_uri, partition_column="col1", lower_bound=0, upper_bound=6
+            query, conn, partition_column="col1", lower_bound=0, upper_bound=6
         )
         modin_df_from_table = pd.read_sql(
-            table, db_uri, partition_column="col1", lower_bound=0, upper_bound=6
+            table, conn, partition_column="col1", lower_bound=0, upper_bound=6
         )
 
         assert modin_df_equals_pandas(modin_df_from_query, pandas_df)
         assert modin_df_equals_pandas(modin_df_from_table, pandas_df)
 
-        teardown_sql_file(filename)
 
-
-def test_from_sql_defaults():
+def test_from_sql_defaults(make_sql_connection):  # noqa: F811
     filename = "test_from_sql_distributed.db"
-    teardown_sql_file(filename)
     table = "test_from_sql_distributed"
-    db_uri = "sqlite:///" + filename
-    setup_sql_file(db_uri, filename, table, True)
+    conn = make_sql_connection(filename, table)
     query = "select * from {0}".format(table)
 
-    pandas_df = pandas.read_sql(query, db_uri)
+    pandas_df = pandas.read_sql(query, conn)
    with pytest.warns(UserWarning):
-        modin_df_from_query = pd.read_sql(query, db_uri)
+        modin_df_from_query = pd.read_sql(query, conn)
     with pytest.warns(UserWarning):
-        modin_df_from_table = pd.read_sql(table, db_uri)
+        modin_df_from_table = pd.read_sql(table, conn)
 
     assert modin_df_equals_pandas(modin_df_from_query, pandas_df)
     assert modin_df_equals_pandas(modin_df_from_table, pandas_df)
-
-    teardown_sql_file(filename)
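
A note on the noqa codes, since they carry the reasoning behind this refactor: the fixture now lives in modin/pandas/test/test_io.py and is only referenced by name in test signatures here, so flake8 would otherwise report the import as unused (F401) and the test parameter as a redefinition of that import (F811), even though pytest resolves the fixture at collection time. The pattern in miniature (module and fixture names as in the diff; the test body is illustrative):

    from modin.pandas.test.test_io import make_sql_connection  # noqa: F401


    def test_something(make_sql_connection):  # noqa: F811
        conn = make_sql_connection("some_file.db")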

modin/pandas/dataframe.py

Lines changed: 22 additions & 13 deletions
@@ -4188,25 +4188,34 @@ def to_sql(
         self,
         name,
         con,
-        flavor=None,
         schema=None,
         if_exists="fail",
         index=True,
         index_label=None,
         chunksize=None,
         dtype=None,
-    ):  # pragma: no cover
-        return self._default_to_pandas(
-            pandas.DataFrame.to_sql,
-            name,
-            con,
-            flavor,
-            schema,
-            if_exists,
-            index,
-            index_label,
-            chunksize,
-            dtype,
+    ):
+        new_query_compiler = self._query_compiler
+        # writing the index to the database by inserting it to the DF
+        if index:
+            if not index_label:
+                index_label = "index"
+            new_query_compiler = new_query_compiler.insert(0, index_label, self.index)
+            # so pandas._to_sql will not write the index to the database as well
+            index = False
+
+        from modin.data_management.factories import BaseFactory
+
+        BaseFactory.to_sql(
+            new_query_compiler,
+            name=name,
+            con=con,
+            schema=schema,
+            if_exists=if_exists,
+            index=index,
+            index_label=index_label,
+            chunksize=chunksize,
+            dtype=dtype,
         )
 
     def to_stata(
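
The index handling is the subtle part: each Ray partition only sees its own slice, so letting pandas write the index from inside a partition would record meaningless per-partition positions. Instead the real index is inserted up front as an ordinary column, named index_label or "index" by default, and index=False is forwarded to the engines. A hedged round-trip sketch of the user-visible effect (file and table names are illustrative):

    import modin.pandas as pd
    import pandas

    mdf = pd.DataFrame({"col1": [10, 20, 30]}, index=["a", "b", "c"])
    mdf.to_sql("tbl_with_index", "sqlite:///demo.db")  # index stored as the "index" column
    back = pandas.read_sql("tbl_with_index", "sqlite:///demo.db", index_col="index")
    # back now carries the original "a"/"b"/"c" labels restored from that column.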

modin/pandas/test/test_io.py

Lines changed: 65 additions & 37 deletions
@@ -10,7 +10,6 @@
 from pathlib import Path
 import pyarrow as pa
 import os
-import sqlite3
 import sys
 
 # needed to resolve ray-project/ray#3744
@@ -47,7 +46,6 @@ def setup_parquet_file(row_size, force=False):
     ).to_parquet(TEST_PARQUET_FILENAME)
 
 
-@pytest.fixture
 def create_test_ray_dataframe():
     df = pd.DataFrame(
         {
@@ -62,7 +60,6 @@ def create_test_ray_dataframe():
     return df
 
 
-@pytest.fixture
 def create_test_pandas_dataframe():
     df = pandas.DataFrame(
         {
@@ -261,26 +258,41 @@ def teardown_pickle_file():
 
 
 @pytest.fixture
-def setup_sql_file(conn, filename, table, force=False):
-    if os.path.exists(filename) and not force:
-        pass
-    else:
-        df = pandas.DataFrame(
-            {
-                "col1": [0, 1, 2, 3, 4, 5, 6],
-                "col2": [7, 8, 9, 10, 11, 12, 13],
-                "col3": [14, 15, 16, 17, 18, 19, 20],
-                "col4": [21, 22, 23, 24, 25, 26, 27],
-                "col5": [0, 0, 0, 0, 0, 0, 0],
-            }
-        )
-        df.to_sql(table, conn)
-
-
-@pytest.fixture
-def teardown_sql_file(filename):
-    if os.path.exists(filename):
-        os.remove(filename)
+def make_sql_connection():
+    """Sets up sql connections and takes them down after the caller is done.
+
+    Yields:
+        Factory that generates sql connection objects
+    """
+    filenames = []
+
+    def _sql_connection(filename, table=""):
+        # Remove file if exists
+        if os.path.exists(filename):
+            os.remove(filename)
+        filenames.append(filename)
+
+        # Create connection and, if needed, table
+        conn = "sqlite:///{}".format(filename)
+        if table:
+            df = pandas.DataFrame(
+                {
+                    "col1": [0, 1, 2, 3, 4, 5, 6],
+                    "col2": [7, 8, 9, 10, 11, 12, 13],
+                    "col3": [14, 15, 16, 17, 18, 19, 20],
+                    "col4": [21, 22, 23, 24, 25, 26, 27],
+                    "col5": [0, 0, 0, 0, 0, 0, 0],
+                }
+            )
+            df.to_sql(table, conn)
+        return conn
+
+    yield _sql_connection
+
+    # Takedown the fixture
+    for filename in filenames:
+        if os.path.exists(filename):
+            os.remove(filename)
 
 
 def test_from_parquet():
@@ -460,21 +472,17 @@ def test_from_pickle():
     teardown_pickle_file()
 
 
-def test_from_sql():
+def test_from_sql(make_sql_connection):
     filename = "test_from_sql.db"
-    teardown_sql_file(filename)
-    conn = sqlite3.connect(filename)
     table = "test_from_sql"
-    setup_sql_file(conn, filename, table, True)
+    conn = make_sql_connection(filename, table)
     query = "select * from {0}".format(table)
 
     pandas_df = pandas.read_sql(query, conn)
     modin_df = pd.read_sql(query, conn)
 
     assert modin_df_equals_pandas(modin_df, pandas_df)
 
-    teardown_sql_file(filename)
-
 
 @pytest.mark.skip(reason="No SAS write methods in Pandas")
 def test_from_sas():
@@ -750,20 +758,40 @@ def test_to_pickle():
     teardown_test_file(TEST_PICKLE_DF_FILENAME)
 
 
-def test_to_sql():
+def test_to_sql_without_index(make_sql_connection):
+    table_name = "tbl_without_index"
     modin_df = create_test_ray_dataframe()
     pandas_df = create_test_pandas_dataframe()
 
-    TEST_SQL_DF_FILENAME = "test_df.sql"
-    TEST_SQL_pandas_FILENAME = "test_pandas.sql"
+    # We do not pass the table name so the fixture won't generate a table
+    conn = make_sql_connection("test_to_sql.db")
+    modin_df.to_sql(table_name, conn, index=False)
+    df_modin_sql = pandas.read_sql(table_name, con=conn)
+
+    # We do not pass the table name so the fixture won't generate a table
+    conn = make_sql_connection("test_to_sql_pandas.db")
+    pandas_df.to_sql(table_name, conn, index=False)
+    df_pandas_sql = pandas.read_sql(table_name, con=conn)
+
+    assert df_modin_sql.sort_index().equals(df_pandas_sql.sort_index())
+
+
+def test_to_sql_with_index(make_sql_connection):
+    table_name = "tbl_with_index"
+    modin_df = create_test_ray_dataframe()
+    pandas_df = create_test_pandas_dataframe()
 
-    modin_df.to_pickle(TEST_SQL_DF_FILENAME)
-    pandas_df.to_pickle(TEST_SQL_pandas_FILENAME)
+    # We do not pass the table name so the fixture won't generate a table
+    conn = make_sql_connection("test_to_sql.db")
+    modin_df.to_sql(table_name, conn)
+    df_modin_sql = pandas.read_sql(table_name, con=conn, index_col="index")
 
-    assert test_files_eq(TEST_SQL_DF_FILENAME, TEST_SQL_pandas_FILENAME)
+    # We do not pass the table name so the fixture won't generate a table
+    conn = make_sql_connection("test_to_sql_pandas.db")
+    pandas_df.to_sql(table_name, conn)
+    df_pandas_sql = pandas.read_sql(table_name, con=conn, index_col="index")
 
-    teardown_test_file(TEST_SQL_DF_FILENAME)
-    teardown_test_file(TEST_SQL_pandas_FILENAME)
+    assert df_modin_sql.sort_index().equals(df_pandas_sql.sort_index())
 
 
 def test_to_stata():
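
make_sql_connection is a factory fixture: it yields a function rather than a single connection, so one test can create several databases while all cleanup stays in one place after the yield. A minimal standalone version of the pattern (names shortened; behavior mirrors the fixture above):

    import os

    import pytest


    @pytest.fixture
    def tmp_sqlite():
        created = []

        def _make(filename):
            created.append(filename)
            return "sqlite:///{}".format(filename)

        yield _make
        # Teardown: runs once the test using the fixture has finished.
        for filename in created:
            if os.path.exists(filename):
                os.remove(filename)


    def test_two_databases(tmp_sqlite):
        first = tmp_sqlite("a.db")
        second = tmp_sqlite("b.db")
        assert first != second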
