Skip to content

Commit 46c579a

Browse files
authored
ENH: add dtypes argument to read_gbq (#333)
Use this argument to override the default ``dtype`` for a particular column in the query results. For example, this can be used to select nullable integer columns as the ``Int64`` nullable integer pandas extension type: ``df = gbq.read_gbq("SELECT CAST(NULL AS INT64) AS null_integer", dtypes={"null_integer": "Int64"})``
1 parent a677d2e commit 46c579a

File tree

4 files changed

+108
-23
lines changed

4 files changed

+108
-23
lines changed

docs/source/changelog.rst

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,19 @@ Changelog
66
0.14.0 / TBD
77
------------
88

9+
- Add ``dtypes`` argument to ``read_gbq``. Use this argument to override the
10+
default ``dtype`` for a particular column in the query results. For
11+
example, this can be used to select nullable integer columns as the
12+
``Int64`` nullable integer pandas extension type. (:issue:`242`,
13+
:issue:`332`)
14+
15+
.. code-block:: python
16+
17+
df = gbq.read_gbq(
18+
"SELECT CAST(NULL AS INT64) AS null_integer",
19+
dtypes={"null_integer": "Int64"},
20+
)
21+
922
Dependency updates
1023
~~~~~~~~~~~~~~~~~~
1124

@@ -15,7 +28,7 @@ Dependency updates
1528
Internal changes
1629
~~~~~~~~~~~~~~~~
1730

18-
- Update tests to run against for Python 3.8. (:issue:`331`)
31+
- Update tests to run against Python 3.8. (:issue:`331`)
1932

2033

2134
.. _changelog-0.13.3:

pandas_gbq/gbq.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -526,19 +526,28 @@ def run_query(
526526
)
527527
)
528528

529+
dtypes = kwargs.get("dtypes")
529530
return self._download_results(
530531
query_reply,
531532
max_results=max_results,
532533
progress_bar_type=progress_bar_type,
534+
user_dtypes=dtypes,
533535
)
534536

535537
def _download_results(
536-
self, query_job, max_results=None, progress_bar_type=None
538+
self,
539+
query_job,
540+
max_results=None,
541+
progress_bar_type=None,
542+
user_dtypes=None,
537543
):
538544
# No results are desired, so don't bother downloading anything.
539545
if max_results == 0:
540546
return None
541547

548+
if user_dtypes is None:
549+
user_dtypes = {}
550+
542551
try:
543552
bqstorage_client = None
544553
if max_results is None:
@@ -555,9 +564,10 @@ def _download_results(
555564
)
556565

557566
schema_fields = [field.to_api_repr() for field in rows_iter.schema]
558-
nullsafe_dtypes = _bqschema_to_nullsafe_dtypes(schema_fields)
567+
conversion_dtypes = _bqschema_to_nullsafe_dtypes(schema_fields)
568+
conversion_dtypes.update(user_dtypes)
559569
df = rows_iter.to_dataframe(
560-
dtypes=nullsafe_dtypes,
570+
dtypes=conversion_dtypes,
561571
bqstorage_client=bqstorage_client,
562572
progress_bar_type=progress_bar_type,
563573
)
@@ -790,6 +800,7 @@ def read_gbq(
790800
verbose=None,
791801
private_key=None,
792802
progress_bar_type="tqdm",
803+
dtypes=None,
793804
):
794805
r"""Load data from Google BigQuery using google-cloud-python
795806
@@ -910,6 +921,10 @@ def read_gbq(
910921
``'tqdm_gui'``
911922
Use the :func:`tqdm.tqdm_gui` function to display a
912923
progress bar as a graphical dialog box.
924+
dtypes : dict, optional
925+
A dictionary of column names to pandas ``dtype``. The provided
926+
``dtype`` is used when constructing the series for the column
927+
specified. Otherwise, a default ``dtype`` is used.
913928
verbose : None, deprecated
914929
Deprecated in Pandas-GBQ 0.4.0. Use the `logging module
915930
to adjust verbosity instead
@@ -965,6 +980,7 @@ def read_gbq(
965980
configuration=configuration,
966981
max_results=max_results,
967982
progress_bar_type=progress_bar_type,
983+
dtypes=dtypes,
968984
)
969985

970986
# Reindex the DataFrame on the provided column

tests/system/test_gbq.py

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
import pandas.api.types
99
import pandas.util.testing as tm
1010
from pandas import DataFrame, NaT
11+
12+
try:
13+
import pkg_resources # noqa
14+
except ImportError:
15+
raise ImportError("Could not import pkg_resources (setuptools).")
1116
import pytest
1217
import pytz
1318

@@ -16,14 +21,14 @@
1621

1722

1823
TABLE_ID = "new_test"
24+
PANDAS_VERSION = pkg_resources.parse_version(pandas.__version__)
25+
NULLABLE_INT_PANDAS_VERSION = pkg_resources.parse_version("0.24.0")
26+
NULLABLE_INT_MESSAGE = (
27+
"Require pandas 0.24+ in order to use nullable integer type."
28+
)
1929

2030

2131
def test_imports():
22-
try:
23-
import pkg_resources # noqa
24-
except ImportError:
25-
raise ImportError("Could not import pkg_resources (setuptools).")
26-
2732
gbq._test_google_api_imports()
2833

2934

@@ -87,62 +92,92 @@ def test_should_properly_handle_null_strings(self, project_id):
8792
tm.assert_frame_equal(df, DataFrame({"null_string": [None]}))
8893

8994
def test_should_properly_handle_valid_integers(self, project_id):
90-
query = "SELECT INTEGER(3) AS valid_integer"
95+
query = "SELECT CAST(3 AS INT64) AS valid_integer"
9196
df = gbq.read_gbq(
9297
query,
9398
project_id=project_id,
9499
credentials=self.credentials,
95-
dialect="legacy",
100+
dialect="standard",
96101
)
97102
tm.assert_frame_equal(df, DataFrame({"valid_integer": [3]}))
98103

99104
def test_should_properly_handle_nullable_integers(self, project_id):
105+
if PANDAS_VERSION < NULLABLE_INT_PANDAS_VERSION:
106+
pytest.skip(msg=NULLABLE_INT_MESSAGE)
107+
100108
query = """SELECT * FROM
101-
(SELECT 1 AS nullable_integer),
102-
(SELECT NULL AS nullable_integer)"""
109+
UNNEST([1, NULL]) AS nullable_integer
110+
"""
103111
df = gbq.read_gbq(
104112
query,
105113
project_id=project_id,
106114
credentials=self.credentials,
107-
dialect="legacy",
115+
dialect="standard",
116+
dtypes={"nullable_integer": "Int64"},
117+
)
118+
tm.assert_frame_equal(
119+
df,
120+
DataFrame(
121+
{
122+
"nullable_integer": pandas.Series(
123+
[1, pandas.NA], dtype="Int64"
124+
)
125+
}
126+
),
108127
)
109-
tm.assert_frame_equal(df, DataFrame({"nullable_integer": [1, None]}))
110128

111129
def test_should_properly_handle_valid_longs(self, project_id):
112130
query = "SELECT 1 << 62 AS valid_long"
113131
df = gbq.read_gbq(
114132
query,
115133
project_id=project_id,
116134
credentials=self.credentials,
117-
dialect="legacy",
135+
dialect="standard",
118136
)
119137
tm.assert_frame_equal(df, DataFrame({"valid_long": [1 << 62]}))
120138

121139
def test_should_properly_handle_nullable_longs(self, project_id):
140+
if PANDAS_VERSION < NULLABLE_INT_PANDAS_VERSION:
141+
pytest.skip(msg=NULLABLE_INT_MESSAGE)
142+
122143
query = """SELECT * FROM
123-
(SELECT 1 << 62 AS nullable_long),
124-
(SELECT NULL AS nullable_long)"""
144+
UNNEST([1 << 62, NULL]) AS nullable_long
145+
"""
125146
df = gbq.read_gbq(
126147
query,
127148
project_id=project_id,
128149
credentials=self.credentials,
129-
dialect="legacy",
150+
dialect="standard",
151+
dtypes={"nullable_long": "Int64"},
130152
)
131153
tm.assert_frame_equal(
132-
df, DataFrame({"nullable_long": [1 << 62, None]})
154+
df,
155+
DataFrame(
156+
{
157+
"nullable_long": pandas.Series(
158+
[1 << 62, pandas.NA], dtype="Int64"
159+
)
160+
}
161+
),
133162
)
134163

135164
def test_should_properly_handle_null_integers(self, project_id):
136-
query = "SELECT INTEGER(NULL) AS null_integer"
165+
if PANDAS_VERSION < NULLABLE_INT_PANDAS_VERSION:
166+
pytest.skip(msg=NULLABLE_INT_MESSAGE)
167+
168+
query = "SELECT CAST(NULL AS INT64) AS null_integer"
137169
df = gbq.read_gbq(
138170
query,
139171
project_id=project_id,
140172
credentials=self.credentials,
141-
dialect="legacy",
173+
dialect="standard",
174+
dtypes={"null_integer": "Int64"},
142175
)
143176
tm.assert_frame_equal(
144177
df,
145-
DataFrame({"null_integer": pandas.Series([None], dtype="object")}),
178+
DataFrame(
179+
{"null_integer": pandas.Series([pandas.NA], dtype="Int64")}
180+
),
146181
)
147182

148183
def test_should_properly_handle_valid_floats(self, project_id):

tests/unit/test_gbq.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,27 @@ def test_load_does_not_modify_schema_arg(mock_bigquery_client):
476476
assert original_schema == original_schema_cp
477477

478478

479+
def test_read_gbq_passes_dtypes(
480+
mock_bigquery_client, mock_service_account_credentials
481+
):
482+
mock_service_account_credentials.project_id = "service_account_project_id"
483+
df = gbq.read_gbq(
484+
"SELECT 1 AS int_col",
485+
dialect="standard",
486+
credentials=mock_service_account_credentials,
487+
dtypes={"int_col": "my-custom-dtype"},
488+
)
489+
assert df is not None
490+
491+
mock_list_rows = mock_bigquery_client.list_rows("dest", max_results=100)
492+
493+
mock_list_rows.to_dataframe.assert_called_once_with(
494+
dtypes={"int_col": "my-custom-dtype"},
495+
bqstorage_client=mock.ANY,
496+
progress_bar_type=mock.ANY,
497+
)
498+
499+
479500
def test_read_gbq_calls_tqdm(
480501
mock_bigquery_client, mock_service_account_credentials
481502
):

0 commit comments

Comments
 (0)