Skip to content

Commit ad5b98c

Browse files
authored
refactor: compile the selected_cols for the ResultNode (#1765)
1 parent 393425e commit ad5b98c

File tree

7 files changed

+106
-38
lines changed
  • bigframes/core/compile/sqlglot
  • tests/unit/core/compile/sqlglot/snapshots
    • test_compile_projection/test_compile_projection
    • test_compile_readlocal
      • test_compile_readlocal
      • test_compile_readlocal_w_json_df
      • test_compile_readlocal_w_lists_df
      • test_compile_readlocal_w_structs_df

7 files changed

+106
-38
lines changed

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,14 @@ def _remap_variables(self, node: nodes.ResultNode) -> nodes.ResultNode:
120120

121121
def _compile_result_node(self, root: nodes.ResultNode) -> str:
122122
sqlglot_ir = self.compile_node(root.child)
123-
# TODO: add order_by, limit, and selections to sqlglot_expr
123+
124+
selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
125+
(name, scalar_compiler.compile_scalar_expression(ref))
126+
for ref, name in root.output_cols
127+
)
128+
sqlglot_ir = sqlglot_ir.select(selected_cols)
129+
130+
# TODO: add order_by, limit to sqlglot_expr
124131
return sqlglot_ir.sql
125132

126133
@functools.lru_cache(maxsize=5000)

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,15 +128,22 @@ def select(
128128
self,
129129
selected_cols: tuple[tuple[str, sge.Expression], ...],
130130
) -> SQLGlotIR:
131-
cols_expr = [
131+
selections = [
132132
sge.Alias(
133133
this=expr,
134134
alias=sge.to_identifier(id, quoted=self.quoted),
135135
)
136136
for id, expr in selected_cols
137137
]
138-
new_expr = self._encapsulate_as_cte().select(*cols_expr, append=False)
139-
return SQLGlotIR(expr=new_expr)
138+
# Attempts to simplify selected columns when the original and new column
139+
# names are simply aliases of each other.
140+
squashed_selections = _squash_selections(self.expr.expressions, selections)
141+
if squashed_selections != []:
142+
new_expr = self.expr.select(*squashed_selections, append=False)
143+
return SQLGlotIR(expr=new_expr)
144+
else:
145+
new_expr = self._encapsulate_as_cte().select(*selections, append=False)
146+
return SQLGlotIR(expr=new_expr)
140147

141148
def project(
142149
self,
@@ -199,7 +206,7 @@ def _encapsulate_as_cte(
199206
this=select_expr,
200207
alias=new_cte_name,
201208
)
202-
new_with_clause = sge.With(expressions=existing_ctes + [new_cte])
209+
new_with_clause = sge.With(expressions=[*existing_ctes, new_cte])
203210
new_select_expr = (
204211
sge.Select().select(sge.Star()).from_(sge.Table(this=new_cte_name))
205212
)
@@ -254,3 +261,62 @@ def _table(table: bigquery.TableReference) -> sge.Table:
254261
db=sg.to_identifier(table.dataset_id, quoted=True),
255262
catalog=sg.to_identifier(table.project, quoted=True),
256263
)
264+
265+
266+
def _squash_selections(
267+
old_expr: list[sge.Expression], new_expr: list[sge.Alias]
268+
) -> list[sge.Alias]:
269+
"""
270+
Simplifies the select column expressions if existing (old_expr) and
271+
new (new_expr) selected columns are both simple aliases of column definitions.
272+
273+
Example:
274+
old_expr: [A AS X, B AS Y]
275+
new_expr: [X AS P, Y AS Q]
276+
Result: [A AS P, B AS Q]
277+
"""
278+
old_alias_map: typing.Dict[str, str] = {}
279+
for selected in old_expr:
280+
column_alias_pair = _get_column_alias_pair(selected)
281+
if column_alias_pair is None:
282+
return []
283+
else:
284+
old_alias_map[column_alias_pair[1]] = column_alias_pair[0]
285+
286+
new_selected_cols: typing.List[sge.Alias] = []
287+
for selected in new_expr:
288+
column_alias_pair = _get_column_alias_pair(selected)
289+
if column_alias_pair is None or column_alias_pair[0] not in old_alias_map:
290+
return []
291+
else:
292+
new_alias_expr = sge.Alias(
293+
this=sge.ColumnDef(
294+
this=sge.to_identifier(
295+
old_alias_map[column_alias_pair[0]], quoted=True
296+
)
297+
),
298+
alias=sg.to_identifier(column_alias_pair[1], quoted=True),
299+
)
300+
new_selected_cols.append(new_alias_expr)
301+
return new_selected_cols
302+
303+
304+
def _get_column_alias_pair(
305+
expr: sge.Expression,
306+
) -> typing.Optional[typing.Tuple[str, str]]:
307+
"""Checks if an expression is a simple alias of a column definition
308+
(e.g., "column_name AS alias_name").
309+
If it is, returns a tuple containing the alias name and original column name.
310+
Returns `None` otherwise.
311+
"""
312+
if not isinstance(expr, sge.Alias):
313+
return None
314+
if not isinstance(expr.this, sge.ColumnDef):
315+
return None
316+
317+
column_def_expr: sge.ColumnDef = expr.this
318+
if not isinstance(column_def_expr.this, sge.Identifier):
319+
return None
320+
321+
original_identifier: sge.Identifier = column_def_expr.this
322+
return (original_identifier.this, expr.alias)

tests/unit/core/compile/sqlglot/snapshots/test_compile_projection/test_compile_projection/out.sql

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ WITH `bfcte_0` AS (
66
FROM UNNEST(ARRAY<STRUCT<`bfcol_0` INT64, `bfcol_1` INT64, `bfcol_2` INT64>>[STRUCT(0, 123456789, 0), STRUCT(1, -987654321, 1), STRUCT(2, 314159, 2), STRUCT(3, CAST(NULL AS INT64), 3), STRUCT(4, -234892, 4), STRUCT(5, 55555, 5), STRUCT(6, 101202303, 6), STRUCT(7, -214748367, 7), STRUCT(8, 2, 8)])
77
)
88
SELECT
9-
`bfcol_3` AS `bfcol_5`,
10-
`bfcol_4` AS `bfcol_6`,
11-
`bfcol_2` AS `bfcol_7`
9+
`bfcol_3` AS `rowindex`,
10+
`bfcol_4` AS `int64_col`
1211
FROM `bfcte_0`

tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal/out.sql

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -157,20 +157,19 @@ WITH `bfcte_0` AS (
157157
)])
158158
)
159159
SELECT
160-
`bfcol_0` AS `bfcol_16`,
161-
`bfcol_1` AS `bfcol_17`,
162-
`bfcol_2` AS `bfcol_18`,
163-
`bfcol_3` AS `bfcol_19`,
164-
`bfcol_4` AS `bfcol_20`,
165-
`bfcol_5` AS `bfcol_21`,
166-
`bfcol_6` AS `bfcol_22`,
167-
`bfcol_7` AS `bfcol_23`,
168-
`bfcol_8` AS `bfcol_24`,
169-
`bfcol_9` AS `bfcol_25`,
170-
`bfcol_10` AS `bfcol_26`,
171-
`bfcol_11` AS `bfcol_27`,
172-
`bfcol_12` AS `bfcol_28`,
173-
`bfcol_13` AS `bfcol_29`,
174-
`bfcol_14` AS `bfcol_30`,
175-
`bfcol_15` AS `bfcol_31`
160+
`bfcol_0` AS `rowindex`,
161+
`bfcol_1` AS `bool_col`,
162+
`bfcol_2` AS `bytes_col`,
163+
`bfcol_3` AS `date_col`,
164+
`bfcol_4` AS `datetime_col`,
165+
`bfcol_5` AS `geography_col`,
166+
`bfcol_6` AS `int64_col`,
167+
`bfcol_7` AS `int64_too`,
168+
`bfcol_8` AS `numeric_col`,
169+
`bfcol_9` AS `float64_col`,
170+
`bfcol_10` AS `rowindex_1`,
171+
`bfcol_11` AS `rowindex_2`,
172+
`bfcol_12` AS `string_col`,
173+
`bfcol_13` AS `time_col`,
174+
`bfcol_14` AS `timestamp_col`
176175
FROM `bfcte_0`

tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_json_df/out.sql

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,5 @@ WITH `bfcte_0` AS (
44
FROM UNNEST(ARRAY<STRUCT<`bfcol_0` JSON, `bfcol_1` INT64>>[STRUCT(PARSE_JSON('null'), 0), STRUCT(PARSE_JSON('true'), 1), STRUCT(PARSE_JSON('100'), 2), STRUCT(PARSE_JSON('0.98'), 3), STRUCT(PARSE_JSON('"a string"'), 4), STRUCT(PARSE_JSON('[]'), 5), STRUCT(PARSE_JSON('[1,2,3]'), 6), STRUCT(PARSE_JSON('[{"a":1},{"a":2},{"a":null},{}]'), 7), STRUCT(PARSE_JSON('"100"'), 8), STRUCT(PARSE_JSON('{"date":"2024-07-16"}'), 9), STRUCT(PARSE_JSON('{"int_value":2,"null_filed":null}'), 10), STRUCT(PARSE_JSON('{"list_data":[10,20,30]}'), 11)])
55
)
66
SELECT
7-
`bfcol_0` AS `bfcol_2`,
8-
`bfcol_1` AS `bfcol_3`
7+
`bfcol_0` AS `json_col`
98
FROM `bfcte_0`

tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_lists_df/out.sql

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,12 @@ WITH `bfcte_0` AS (
3434
)])
3535
)
3636
SELECT
37-
`bfcol_0` AS `bfcol_9`,
38-
`bfcol_1` AS `bfcol_10`,
39-
`bfcol_2` AS `bfcol_11`,
40-
`bfcol_3` AS `bfcol_12`,
41-
`bfcol_4` AS `bfcol_13`,
42-
`bfcol_5` AS `bfcol_14`,
43-
`bfcol_6` AS `bfcol_15`,
44-
`bfcol_7` AS `bfcol_16`,
45-
`bfcol_8` AS `bfcol_17`
37+
`bfcol_0` AS `rowindex`,
38+
`bfcol_1` AS `int_list_col`,
39+
`bfcol_2` AS `bool_list_col`,
40+
`bfcol_3` AS `float_list_col`,
41+
`bfcol_4` AS `date_list_col`,
42+
`bfcol_5` AS `date_time_list_col`,
43+
`bfcol_6` AS `numeric_list_col`,
44+
`bfcol_7` AS `string_list_col`
4645
FROM `bfcte_0`

tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_structs_df/out.sql

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ WITH `bfcte_0` AS (
2020
)])
2121
)
2222
SELECT
23-
`bfcol_0` AS `bfcol_3`,
24-
`bfcol_1` AS `bfcol_4`,
25-
`bfcol_2` AS `bfcol_5`
23+
`bfcol_0` AS `id`,
24+
`bfcol_1` AS `person`
2625
FROM `bfcte_0`

0 commit comments

Comments
 (0)