Commit 50dca4c

feat: support dtype parameter in read_csv for bigquery engine (#1749)
Fixes internal issue 404530013
1 parent 1f6442e commit 50dca4c
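
A minimal sketch of what this change enables, assuming the public bigframes.pandas entry point and a placeholder GCS path (bucket and object names are hypothetical):

import bigframes.pandas as bpd
import pandas as pd

# Before this commit, passing `dtype` together with engine="bigquery"
# raised NotImplementedError. Now a dict-like `dtype` is accepted and
# each listed column is cast after the BigQuery load job finishes.
df = bpd.read_csv(
    "gs://my-bucket/data.csv",  # hypothetical path
    engine="bigquery",
    dtype={"bool_col": pd.BooleanDtype(), "int64_col": pd.Float64Dtype()},
)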

File tree: 4 files changed, +64 −15 lines

bigframes/session/__init__.py (10 additions & 7 deletions)

@@ -61,7 +61,7 @@
 import bigframes._config.bigquery_options as bigquery_options
 import bigframes.clients
 import bigframes.constants
-from bigframes.core import blocks, log_adapter
+from bigframes.core import blocks, log_adapter, utils
 import bigframes.core.pyformat
 
 # Even though the ibis.backends.bigquery import is unused, it's needed
@@ -1108,11 +1108,8 @@ def _read_csv_w_bigquery_engine(
         native CSV loading capabilities, making it suitable for large datasets
         that may not fit into local memory.
         """
-        if dtype is not None:
-            raise NotImplementedError(
-                f"BigQuery engine does not support the `dtype` argument."
-                f"{constants.FEEDBACK_LINK}"
-            )
+        if dtype is not None and not utils.is_dict_like(dtype):
+            raise ValueError("dtype should be a dict-like object.")
 
         if names is not None:
             if len(names) != len(set(names)):
@@ -1167,10 +1164,16 @@ def _read_csv_w_bigquery_engine(
             job_config.skip_leading_rows = header + 1
 
         table_id = self._loader.load_file(filepath_or_buffer, job_config=job_config)
-        return self._loader.read_gbq_table(
+        df = self._loader.read_gbq_table(
             table_id, index_col=index_col, columns=columns, names=names
         )
 
+        if dtype is not None:
+            for column, dtype in dtype.items():
+                if column in df.columns:
+                    df[column] = df[column].astype(dtype)
+        return df
+
     def read_pickle(
         self,
         filepath_or_buffer: FilePath | ReadPickleBuffer,
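
The post-load cast above is ordinary per-column astype. A standalone pandas sketch of the same pattern follows (column names are illustrative). Note that the committed loop rebinds `dtype` inside `for column, dtype in dtype.items()`, which is safe only because `dtype.items()` is evaluated once before iteration begins; the sketch uses a distinct name instead:

import pandas as pd

def _cast_columns(df: pd.DataFrame, dtype: dict) -> pd.DataFrame:
    # Mirrors the committed loop: columns absent from the frame are
    # silently skipped rather than raising.
    for column, target_dtype in dtype.items():
        if column in df.columns:
            df[column] = df[column].astype(target_dtype)
    return df

frame = pd.DataFrame({"bool_col": [1, 0], "int64_col": [1, 2]})
frame = _cast_columns(
    frame, {"bool_col": pd.BooleanDtype(), "int64_col": pd.Float64Dtype()}
)
print(frame.dtypes)  # bool_col -> boolean, int64_col -> Float64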

bigframes/session/loader.py (4 additions & 3 deletions)

@@ -663,9 +663,10 @@ def read_gbq_table(
             renamed_cols: Dict[str, str] = {
                 col: new_name for col, new_name in zip(array_value.column_ids, names)
             }
-            index_names = [
-                renamed_cols.get(index_col, index_col) for index_col in index_cols
-            ]
+            if index_col != bigframes.enums.DefaultIndexKind.SEQUENTIAL_INT64:
+                index_names = [
+                    renamed_cols.get(index_col, index_col) for index_col in index_cols
+                ]
             value_columns = [renamed_cols.get(col, col) for col in value_columns]
 
         block = blocks.Block(
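
The new guard in loader.py skips the index rename when the caller asked for a synthetic sequential index, since `index_cols` then names no user-supplied columns. A hedged sketch of a call that exercises this branch (the path is a placeholder; `DefaultIndexKind.SEQUENTIAL_INT64` is the enum the diff compares against):

import bigframes.enums
import bigframes.pandas as bpd

# With a sequential default index, `names` applies only to value columns;
# after this commit the loader no longer tries to map index columns
# through the rename table built from `names`.
df = bpd.read_csv(
    "gs://my-bucket/data.csv",  # hypothetical path
    engine="bigquery",
    names=["a", "b", "c"],
    index_col=bigframes.enums.DefaultIndexKind.SEQUENTIAL_INT64,
)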

tests/system/small/test_session.py (39 additions & 0 deletions)

@@ -1369,6 +1369,45 @@ def test_read_csv_for_names_and_index_col(
     )
 
 
+def test_read_csv_for_dtype(session, df_and_gcs_csv_for_two_columns):
+    _, path = df_and_gcs_csv_for_two_columns
+
+    dtype = {"bool_col": pd.BooleanDtype(), "int64_col": pd.Float64Dtype()}
+    bf_df = session.read_csv(path, engine="bigquery", dtype=dtype)
+
+    # Convert default pandas dtypes to match BigQuery DataFrames dtypes.
+    pd_df = session.read_csv(path, dtype=dtype)
+
+    assert bf_df.shape == pd_df.shape
+    assert bf_df.columns.tolist() == pd_df.columns.tolist()
+
+    # BigFrames requires `sort_index()` because BigQuery doesn't preserve row IDs
+    # (b/280889935) or guarantee row ordering.
+    bf_df = bf_df.set_index("rowindex").sort_index()
+    pd_df = pd_df.set_index("rowindex")
+    pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas())
+
+
+def test_read_csv_for_dtype_w_names(session, df_and_gcs_csv_for_two_columns):
+    _, path = df_and_gcs_csv_for_two_columns
+
+    names = ["a", "b", "c"]
+    dtype = {"b": pd.BooleanDtype(), "c": pd.Float64Dtype()}
+    bf_df = session.read_csv(path, engine="bigquery", names=names, dtype=dtype)
+
+    # Convert default pandas dtypes to match BigQuery DataFrames dtypes.
+    pd_df = session.read_csv(path, names=names, dtype=dtype)
+
+    assert bf_df.shape == pd_df.shape
+    assert bf_df.columns.tolist() == pd_df.columns.tolist()
+
+    # BigFrames requires `sort_index()` because BigQuery doesn't preserve row IDs
+    # (b/280889935) or guarantee row ordering.
+    bf_df = bf_df.set_index("a").sort_index()
+    pd_df = pd_df.set_index("a")
+    pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas())
+
+
 @pytest.mark.parametrize(
     ("kwargs", "match"),
     [

tests/unit/session/test_session.py (11 additions & 5 deletions)

@@ -108,11 +108,6 @@
 @pytest.mark.parametrize(
     ("kwargs", "match"),
     [
-        pytest.param(
-            {"engine": "bigquery", "dtype": {}},
-            "BigQuery engine does not support the `dtype` argument",
-            id="with_dtype",
-        ),
         pytest.param(
             {"engine": "bigquery", "usecols": [1, 2]},
             "BigQuery engine only supports an iterable of strings for `usecols`.",
@@ -215,6 +210,17 @@ def test_read_csv_w_bigquery_engine_raises_error_for_invalid_names(
         session.read_csv("path/to/csv.csv", engine="bigquery", names=names)
 
 
+def test_read_csv_w_bigquery_engine_raises_error_for_invalid_dtypes():
+    session = mocks.create_bigquery_session()
+
+    with pytest.raises(ValueError, match="dtype should be a dict-like object."):
+        session.read_csv(
+            "path/to/csv.csv",
+            engine="bigquery",
+            dtype=["a", "b", "c"],  # type: ignore[arg-type]
+        )
+
+
 @pytest.mark.parametrize("missing_parts_table_id", [(""), ("table")])
 def test_read_gbq_missing_parts(missing_parts_table_id):
     session = mocks.create_bigquery_session()
