Skip to content

Commit 190390b

Browse files
authored
fix: support JSON and STRUCT for bbq.sql_scalar (#1754)
Fixes internal issue 416015997
1 parent c51d2b1 commit 190390b

File tree

6 files changed

+136
-42
lines changed

bigframes/bigquery/_operations/sql.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import google.cloud.bigquery
2222

23+
import bigframes.core.compile.sqlglot.sqlglot_ir as sqlglot_ir
2324
import bigframes.core.sql
2425
import bigframes.dataframe
2526
import bigframes.dtypes
@@ -72,16 +73,16 @@ def sql_scalar(
7273
# Another benefit of this is that if there is a syntax error in the SQL
7374
# template, then this will fail with an error earlier in the process,
7475
# aiding users in debugging.
75-
base_series = columns[0]
76-
literals = [
77-
bigframes.dtypes.bigframes_dtype_to_literal(column.dtype) for column in columns
76+
literals_sql = [
77+
sqlglot_ir._literal(None, column.dtype).sql(dialect="bigquery")
78+
for column in columns
7879
]
79-
literals_sql = [bigframes.core.sql.simple_literal(literal) for literal in literals]
80+
select_sql = sql_template.format(*literals_sql)
81+
dry_run_sql = f"SELECT {select_sql}"
8082

8183
# Use the executor directly, because we want the original column IDs, not
8284
# the user-friendly column names that block.to_sql_query() would produce.
83-
select_sql = sql_template.format(*literals_sql)
84-
dry_run_sql = f"SELECT {select_sql}"
85+
base_series = columns[0]
8586
bqclient = base_series._session.bqclient
8687
job = bqclient.query(
8788
dry_run_sql, job_config=google.cloud.bigquery.QueryJobConfig(dry_run=True)

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import typing
1919

2020
from google.cloud import bigquery
21+
import numpy as np
2122
import pyarrow as pa
2223
import sqlglot as sg
2324
import sqlglot.dialects.bigquery
@@ -213,7 +214,11 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
213214
elif dtype == dtypes.BYTES_DTYPE:
214215
return _cast(str(value), sqlglot_type)
215216
elif dtypes.is_time_like(dtype):
217+
if isinstance(value, np.generic):
218+
value = value.item()
216219
return _cast(sge.convert(value.isoformat()), sqlglot_type)
220+
elif dtype in (dtypes.NUMERIC_DTYPE, dtypes.BIGNUMERIC_DTYPE):
221+
return _cast(sge.convert(value), sqlglot_type)
217222
elif dtypes.is_geo_like(dtype):
218223
wkt = value if isinstance(value, str) else to_wkt(value)
219224
return sge.func("ST_GEOGFROMTEXT", sge.convert(wkt))
@@ -234,6 +239,8 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
234239
)
235240
return values if len(value) > 0 else _cast(values, sqlglot_type)
236241
else:
242+
if isinstance(value, np.generic):
243+
value = value.item()
237244
return sge.convert(value)
238245

239246

bigframes/dtypes.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -499,33 +499,6 @@ def bigframes_dtype_to_arrow_dtype(
499499
)
500500

501501

502-
def bigframes_dtype_to_literal(
503-
bigframes_dtype: Dtype,
504-
) -> Any:
505-
"""Create a representative literal value for a bigframes dtype.
506-
507-
The inverse of infer_literal_type().
508-
"""
509-
if isinstance(bigframes_dtype, pd.ArrowDtype):
510-
arrow_type = bigframes_dtype.pyarrow_dtype
511-
return arrow_type_to_literal(arrow_type)
512-
513-
if isinstance(bigframes_dtype, pd.Float64Dtype):
514-
return 1.0
515-
if isinstance(bigframes_dtype, pd.Int64Dtype):
516-
return 1
517-
if isinstance(bigframes_dtype, pd.BooleanDtype):
518-
return True
519-
if isinstance(bigframes_dtype, pd.StringDtype):
520-
return "string"
521-
if isinstance(bigframes_dtype, gpd.array.GeometryDtype):
522-
return shapely.geometry.Point((0, 0))
523-
524-
raise TypeError(
525-
f"No literal conversion for {bigframes_dtype}. {constants.FEEDBACK_LINK}"
526-
)
527-
528-
529502
def arrow_type_to_literal(
530503
arrow_type: pa.DataType,
531504
) -> Any:

tests/system/small/bigquery/test_sql.py

Lines changed: 114 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import bigframes.bigquery
15+
import pandas as pd
16+
import pytest
1617

18+
import bigframes.bigquery as bbq
19+
import bigframes.dtypes as dtypes
20+
import bigframes.pandas as bpd
1721

18-
def test_sql_scalar_on_scalars_null_index(scalars_df_null_index):
19-
series = bigframes.bigquery.sql_scalar(
22+
23+
def test_sql_scalar_for_all_scalar_types(scalars_df_null_index):
24+
series = bbq.sql_scalar(
2025
"""
2126
CAST({0} AS INT64)
2227
+ BYTE_LENGTH({1})
@@ -48,3 +53,109 @@ def test_sql_scalar_on_scalars_null_index(scalars_df_null_index):
4853
)
4954
result = series.to_pandas()
5055
assert len(result) == len(scalars_df_null_index)
56+
57+
58+
def test_sql_scalar_for_bool_series(scalars_df_index):
59+
series: bpd.Series = scalars_df_index["bool_col"]
60+
result = bbq.sql_scalar("CAST({0} AS INT64)", [series])
61+
expected = series.astype(dtypes.INT_DTYPE)
62+
expected.name = None
63+
pd.testing.assert_series_equal(result.to_pandas(), expected.to_pandas())
64+
65+
66+
@pytest.mark.parametrize(
67+
("column_name"),
68+
[
69+
pytest.param("bool_col"),
70+
pytest.param("bytes_col"),
71+
pytest.param("date_col"),
72+
pytest.param("datetime_col"),
73+
pytest.param("geography_col"),
74+
pytest.param("int64_col"),
75+
pytest.param("numeric_col"),
76+
pytest.param("float64_col"),
77+
pytest.param("string_col"),
78+
pytest.param("time_col"),
79+
pytest.param("timestamp_col"),
80+
],
81+
)
82+
def test_sql_scalar_outputs_all_scalar_types(scalars_df_index, column_name):
83+
series: bpd.Series = scalars_df_index[column_name]
84+
result = bbq.sql_scalar("{0}", [series])
85+
expected = series
86+
expected.name = None
87+
pd.testing.assert_series_equal(result.to_pandas(), expected.to_pandas())
88+
89+
90+
def test_sql_scalar_for_array_series(repeated_df):
91+
result = bbq.sql_scalar(
92+
"""
93+
ARRAY_LENGTH({0}) + ARRAY_LENGTH({1}) + ARRAY_LENGTH({2})
94+
+ ARRAY_LENGTH({3}) + ARRAY_LENGTH({4}) + ARRAY_LENGTH({5})
95+
+ ARRAY_LENGTH({6})
96+
""",
97+
[
98+
repeated_df["int_list_col"],
99+
repeated_df["bool_list_col"],
100+
repeated_df["float_list_col"],
101+
repeated_df["date_list_col"],
102+
repeated_df["date_time_list_col"],
103+
repeated_df["numeric_list_col"],
104+
repeated_df["string_list_col"],
105+
],
106+
)
107+
108+
expected = (
109+
repeated_df["int_list_col"].list.len()
110+
+ repeated_df["bool_list_col"].list.len()
111+
+ repeated_df["float_list_col"].list.len()
112+
+ repeated_df["date_list_col"].list.len()
113+
+ repeated_df["date_time_list_col"].list.len()
114+
+ repeated_df["numeric_list_col"].list.len()
115+
+ repeated_df["string_list_col"].list.len()
116+
)
117+
pd.testing.assert_series_equal(result.to_pandas(), expected.to_pandas())
118+
119+
120+
def test_sql_scalar_outputs_array_series(repeated_df):
121+
result = bbq.sql_scalar("{0}", [repeated_df["int_list_col"]])
122+
expected = repeated_df["int_list_col"]
123+
expected.name = None
124+
pd.testing.assert_series_equal(result.to_pandas(), expected.to_pandas())
125+
126+
127+
def test_sql_scalar_for_struct_series(nested_structs_df):
128+
result = bbq.sql_scalar(
129+
"CHAR_LENGTH({0}.name) + {0}.age",
130+
[nested_structs_df["person"]],
131+
)
132+
expected = nested_structs_df["person"].struct.field(
133+
"name"
134+
).str.len() + nested_structs_df["person"].struct.field("age")
135+
pd.testing.assert_series_equal(result.to_pandas(), expected.to_pandas())
136+
137+
138+
def test_sql_scalar_outputs_struct_series(nested_structs_df):
139+
result = bbq.sql_scalar("{0}", [nested_structs_df["person"]])
140+
expected = nested_structs_df["person"]
141+
expected.name = None
142+
pd.testing.assert_series_equal(result.to_pandas(), expected.to_pandas())
143+
144+
145+
def test_sql_scalar_for_json_series(json_df):
146+
result = bbq.sql_scalar(
147+
"""JSON_VALUE({0}, '$.int_value')""",
148+
[
149+
json_df["json_col"],
150+
],
151+
)
152+
expected = bbq.json_value(json_df["json_col"], "$.int_value")
153+
expected.name = None
154+
pd.testing.assert_series_equal(result.to_pandas(), expected.to_pandas())
155+
156+
157+
def test_sql_scalar_outputs_json_series(json_df):
158+
result = bbq.sql_scalar("{0}", [json_df["json_col"]])
159+
expected = json_df["json_col"]
160+
expected.name = None
161+
pd.testing.assert_series_equal(result.to_pandas(), expected.to_pandas())

tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal/out.sql

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ WITH `bfcte_0` AS (
1010
ST_GEOGFROMTEXT('POINT (-122.0838511 37.3860517)'),
1111
123456789,
1212
0,
13-
1.234567890,
13+
CAST(1.234567890 AS NUMERIC),
1414
1.25,
1515
0,
1616
0,
@@ -27,7 +27,7 @@ WITH `bfcte_0` AS (
2727
ST_GEOGFROMTEXT('POINT (-71.104 42.315)'),
2828
-987654321,
2929
1,
30-
1.234567890,
30+
CAST(1.234567890 AS NUMERIC),
3131
2.51,
3232
1,
3333
1,
@@ -44,7 +44,7 @@ WITH `bfcte_0` AS (
4444
ST_GEOGFROMTEXT('POINT (-0.124474760143016 51.5007826749545)'),
4545
314159,
4646
0,
47-
101.101010100,
47+
CAST(101.101010100 AS NUMERIC),
4848
25000000000.0,
4949
2,
5050
2,
@@ -95,7 +95,7 @@ WITH `bfcte_0` AS (
9595
CAST(NULL AS GEOGRAPHY),
9696
55555,
9797
0,
98-
5.555555000,
98+
CAST(5.555555000 AS NUMERIC),
9999
555.555,
100100
5,
101101
5,
@@ -112,7 +112,7 @@ WITH `bfcte_0` AS (
112112
ST_GEOGFROMTEXT('LINESTRING (-0.127959 51.507728, -0.127026 51.507473)'),
113113
101202303,
114114
2,
115-
-10.090807000,
115+
CAST(-10.090807000 AS NUMERIC),
116116
-123.456,
117117
6,
118118
6,
@@ -129,7 +129,7 @@ WITH `bfcte_0` AS (
129129
CAST(NULL AS GEOGRAPHY),
130130
-214748367,
131131
2,
132-
11111111.100000000,
132+
CAST(11111111.100000000 AS NUMERIC),
133133
42.42,
134134
7,
135135
7,

third_party/bigframes_vendored/ibis/backends/bigquery/datatypes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ def from_ibis(cls, dtype: dt.DataType) -> str:
5353
)
5454
elif dtype.is_integer():
5555
return "INT64"
56+
elif dtype.is_boolean():
57+
return "BOOLEAN"
5658
elif dtype.is_binary():
5759
return "BYTES"
5860
elif dtype.is_string():

0 commit comments

Comments (0)