Skip to content

Commit 8e71b03

Browse files
authored
refactor: Cache dtypes for scalar expressions for SQLGlot compiler (#1759)
* feat: include bq schema and query string in dry run results * rename key * fix tests * refactor: cache dtypes for scalar expressions: * fix deref expr type resolution bug * add test * remove dry_run changes from another branch * remove more changes from dry_run PR * rename DeferredDtype to AbsentDtype * removed absentDtype and reuse bind_refs * use a separate resolver for fields * fix lint * move field resolutions to a separate function * update helper function name * update doc and function names * bind schema at compile time for SQLGlot compiler * define a separate expression for field reference
1 parent f9c29c8 commit 8e71b03

File tree

8 files changed

+320
-82
lines changed

8 files changed

+320
-82
lines changed

bigframes/core/bigframe_node.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import typing
2323
from typing import Callable, Dict, Generator, Iterable, Mapping, Sequence, Set, Tuple
2424

25-
from bigframes.core import identifiers
25+
from bigframes.core import field, identifiers
2626
import bigframes.core.schema as schemata
2727
import bigframes.dtypes
2828

@@ -34,23 +34,6 @@
3434
T = typing.TypeVar("T")
3535

3636

37-
@dataclasses.dataclass(frozen=True)
38-
class Field:
39-
id: identifiers.ColumnId
40-
dtype: bigframes.dtypes.Dtype
41-
# Best effort, nullable=True if not certain
42-
nullable: bool = True
43-
44-
def with_nullable(self) -> Field:
45-
return Field(self.id, self.dtype, nullable=True)
46-
47-
def with_nonnull(self) -> Field:
48-
return Field(self.id, self.dtype, nullable=False)
49-
50-
def with_id(self, id: identifiers.ColumnId) -> Field:
51-
return Field(id, self.dtype, nullable=self.nullable)
52-
53-
5437
@dataclasses.dataclass(eq=False, frozen=True)
5538
class BigFrameNode:
5639
"""
@@ -162,7 +145,7 @@ def roots(self) -> typing.Set[BigFrameNode]:
162145
# TODO: Store some local data lazily for select, aggregate nodes.
163146
@property
164147
@abc.abstractmethod
165-
def fields(self) -> Sequence[Field]:
148+
def fields(self) -> Sequence[field.Field]:
166149
...
167150

168151
@property
@@ -292,7 +275,7 @@ def _dtype_lookup(self) -> dict[identifiers.ColumnId, bigframes.dtypes.Dtype]:
292275
return {field.id: field.dtype for field in self.fields}
293276

294277
@functools.cached_property
295-
def field_by_id(self) -> Mapping[identifiers.ColumnId, Field]:
278+
def field_by_id(self) -> Mapping[identifiers.ColumnId, field.Field]:
296279
return {field.id: field for field in self.fields}
297280

298281
# Plan algorithms

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2727
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
2828
import bigframes.core.ordering as bf_ordering
29+
from bigframes.core.rewrite import schema_binding
2930

3031

3132
class SQLGlotCompiler:
@@ -183,6 +184,6 @@ def compile_projection(
183184

184185
def _replace_unsupported_ops(node: nodes.BigFrameNode):
    """Apply the pre-compilation rewrite passes to the plan rooted at ``node``.

    Each pass is run through ``nodes.bottom_up`` and its result is fed to the
    next pass, in the order listed below.
    """
    rewrite_passes = (
        rewrite.rewrite_slice,
        schema_binding.bind_schema_to_expressions,
        rewrite.rewrite_range_rolling,
    )
    for rewrite_pass in rewrite_passes:
        node = nodes.bottom_up(node, rewrite_pass)
    return node

bigframes/core/compile/sqlglot/scalar_compiler.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ def compile_deref_expression(expr: expression.DerefOp) -> sge.Expression:
3535
return sge.ColumnDef(this=sge.to_identifier(expr.id.sql, quoted=True))
3636

3737

38+
@compile_scalar_expression.register
def compile_field_ref_expression(
    expr: expression.SchemaFieldRefExpression,
) -> sge.Expression:
    """Compile a schema-bound field reference into a quoted column identifier.

    Produces the same kind of node as the ``DerefOp`` compilation, but reads
    the column ID from the bound field (``expr.field.id``).
    """
    column_identifier = sge.to_identifier(expr.field.id.sql, quoted=True)
    return sge.ColumnDef(this=column_identifier)
43+
44+
3845
@compile_scalar_expression.register
3946
def compile_constant_expression(
4047
expr: expression.ScalarConstantExpression,

bigframes/core/expression.py

Lines changed: 130 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,16 @@
1616

1717
import abc
1818
import dataclasses
19+
import functools
1920
import itertools
2021
import typing
2122
from typing import Generator, Mapping, TypeVar, Union
2223

2324
import pandas as pd
2425

26+
from bigframes import dtypes
27+
from bigframes.core import field
2528
import bigframes.core.identifiers as ids
26-
import bigframes.dtypes as dtypes
2729
import bigframes.operations
2830
import bigframes.operations.aggregations as agg_ops
2931

@@ -50,7 +52,7 @@ class Aggregation(abc.ABC):
5052

5153
@abc.abstractmethod
5254
def output_type(
53-
self, input_types: dict[ids.ColumnId, dtypes.ExpressionType]
55+
self, input_fields: Mapping[ids.ColumnId, field.Field]
5456
) -> dtypes.ExpressionType:
5557
...
5658

@@ -72,7 +74,7 @@ class NullaryAggregation(Aggregation):
7274
op: agg_ops.NullaryWindowOp = dataclasses.field()
7375

7476
def output_type(
75-
self, input_types: dict[ids.ColumnId, bigframes.dtypes.Dtype]
77+
self, input_fields: Mapping[ids.ColumnId, field.Field]
7678
) -> dtypes.ExpressionType:
7779
return self.op.output_type()
7880

@@ -86,13 +88,17 @@ def remap_column_refs(
8688

8789
@dataclasses.dataclass(frozen=True)
8890
class UnaryAggregation(Aggregation):
89-
op: agg_ops.UnaryWindowOp = dataclasses.field()
90-
arg: Union[DerefOp, ScalarConstantExpression] = dataclasses.field()
91+
op: agg_ops.UnaryWindowOp
92+
arg: Union[DerefOp, ScalarConstantExpression]
9193

9294
def output_type(
    self, input_fields: Mapping[ids.ColumnId, field.Field]
) -> dtypes.ExpressionType:
    """Return the dtype produced by this aggregation for the given input schema.

    Args:
        input_fields: Schema fields keyed by column ID, used to bind the
            argument expression before its type is read.

    Returns:
        The result of ``self.op.output_type`` applied to the argument's type.
    """
    # TODO(b/419300717) Remove resolutions once defers are cleaned up.
    bound_arg = bind_schema_fields(self.arg, input_fields)
    assert bound_arg.is_resolved

    arg_dtype = bound_arg.output_type
    return self.op.output_type(arg_dtype)
96102

97103
@property
98104
def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
@@ -118,10 +124,16 @@ class BinaryAggregation(Aggregation):
118124
right: Union[DerefOp, ScalarConstantExpression] = dataclasses.field()
119125

120126
def output_type(
    self, input_fields: Mapping[ids.ColumnId, field.Field]
) -> dtypes.ExpressionType:
    """Return the dtype produced by this aggregation for the given input schema.

    Both operands are bound against ``input_fields`` first, because their
    types are only readable once they are resolved.

    Args:
        input_fields: Schema fields keyed by column ID.

    Returns:
        The result of ``self.op.output_type`` applied to the left and right
        operand types, in that order.

    Raises:
        AssertionError: If either operand is still unresolved after binding.
    """
    # TODO(b/419300717) Remove resolutions once defers are cleaned up.
    left_resolved_expr = bind_schema_fields(self.left, input_fields)
    assert left_resolved_expr.is_resolved
    right_resolved_expr = bind_schema_fields(self.right, input_fields)
    assert right_resolved_expr.is_resolved

    # The second argument must be the RIGHT operand's type. Passing the left
    # type twice silently mis-types aggregations whose operands differ.
    return self.op.output_type(
        left_resolved_expr.output_type, right_resolved_expr.output_type
    )
126138

127139
@property
@@ -189,10 +201,17 @@ def remap_column_refs(
189201
def is_const(self) -> bool:
190202
...
191203

204+
@property
192205
@abc.abstractmethod
193-
def output_type(
194-
self, input_types: dict[ids.ColumnId, dtypes.ExpressionType]
195-
) -> dtypes.ExpressionType:
206+
def is_resolved(self) -> bool:
207+
"""
208+
Returns True if and only if the expression's output type and nullability are available.
209+
"""
210+
...
211+
212+
@property
213+
@abc.abstractmethod
214+
def output_type(self) -> dtypes.ExpressionType:
196215
...
197216

198217
@abc.abstractmethod
@@ -256,9 +275,12 @@ def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
256275
def nullable(self) -> bool:
257276
return pd.isna(self.value) # type: ignore
258277

259-
def output_type(
260-
self, input_types: dict[ids.ColumnId, bigframes.dtypes.Dtype]
261-
) -> dtypes.ExpressionType:
278+
@property
279+
def is_resolved(self) -> bool:
280+
return True
281+
282+
@property
283+
def output_type(self) -> dtypes.ExpressionType:
262284
return self.dtype
263285

264286
def bind_variables(
@@ -308,9 +330,12 @@ def is_const(self) -> bool:
308330
def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
309331
return ()
310332

311-
def output_type(
312-
self, input_types: dict[ids.ColumnId, bigframes.dtypes.Dtype]
313-
) -> dtypes.ExpressionType:
333+
@property
334+
def is_resolved(self):
335+
return False
336+
337+
@property
338+
def output_type(self) -> dtypes.ExpressionType:
314339
raise ValueError(f"Type of variable {self.id} has not been fixed.")
315340

316341
def bind_refs(
@@ -340,7 +365,7 @@ def is_identity(self) -> bool:
340365

341366
@dataclasses.dataclass(frozen=True)
342367
class DerefOp(Expression):
343-
"""A variable expression representing an unbound variable."""
368+
"""An expression that refers to a column by ID."""
344369

345370
id: ids.ColumnId
346371

@@ -357,13 +382,13 @@ def nullable(self) -> bool:
357382
# Safe default, need to actually bind input schema to determine
358383
return True
359384

360-
def output_type(
361-
self, input_types: dict[ids.ColumnId, bigframes.dtypes.Dtype]
362-
) -> dtypes.ExpressionType:
363-
if self.id in input_types:
364-
return input_types[self.id]
365-
else:
366-
raise ValueError(f"Type of variable {self.id} has not been fixed.")
385+
@property
386+
def is_resolved(self) -> bool:
387+
return False
388+
389+
@property
390+
def output_type(self) -> dtypes.ExpressionType:
391+
raise ValueError(f"Type of variable {self.id} has not been fixed.")
367392

368393
def bind_variables(
369394
self, bindings: Mapping[str, Expression], allow_partial_bindings: bool = False
@@ -390,6 +415,55 @@ def is_identity(self) -> bool:
390415
return True
391416

392417

418+
@dataclasses.dataclass(frozen=True)
class SchemaFieldRefExpression(Expression):
    """An expression representing a schema field.

    This is essentially a ``DerefOp`` with the input schema bound: the output
    type and nullability are read directly from the attached field.
    """

    # The schema field this expression refers to.
    field: field.Field

    @property
    def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
        """The single column ID referenced by this expression."""
        return (self.field.id,)

    @property
    def is_const(self) -> bool:
        return False

    @property
    def nullable(self) -> bool:
        """Nullability as recorded on the bound field."""
        return self.field.nullable

    @property
    def is_resolved(self) -> bool:
        # The field carries the dtype, so the type is always available.
        return True

    @property
    def output_type(self) -> dtypes.ExpressionType:
        """The dtype recorded on the bound field."""
        return self.field.dtype

    def bind_variables(
        self, bindings: Mapping[str, Expression], allow_partial_bindings: bool = False
    ) -> Expression:
        """Return ``self`` unchanged; ``bindings`` is not consulted."""
        return self

    def bind_refs(
        self,
        bindings: Mapping[ids.ColumnId, Expression],
        allow_partial_bindings: bool = False,
    ) -> Expression:
        """Return the binding for this field's column ID, or ``self`` if none exists."""
        # Single lookup instead of a `.keys()` membership test plus indexing.
        return bindings.get(self.field.id, self)

    @property
    def is_bijective(self) -> bool:
        return True

    @property
    def is_identity(self) -> bool:
        return True
465+
466+
393467
@dataclasses.dataclass(frozen=True)
394468
class OpExpression(Expression):
395469
"""An expression representing a scalar operation applied to 1 or more argument sub-expressions."""
@@ -429,13 +503,18 @@ def nullable(self) -> bool:
429503
)
430504
return not null_free
431505

432-
def output_type(
433-
self, input_types: dict[ids.ColumnId, dtypes.ExpressionType]
434-
) -> dtypes.ExpressionType:
435-
operand_types = tuple(
436-
map(lambda x: x.output_type(input_types=input_types), self.inputs)
437-
)
438-
return self.op.output_type(*operand_types)
506+
@functools.cached_property
def is_resolved(self) -> bool:
    """True when every input sub-expression is resolved (cached after first access)."""
    # `operand` rather than `input`, to avoid shadowing the builtin.
    return all(operand.is_resolved for operand in self.inputs)
509+
510+
@functools.cached_property
def output_type(self) -> dtypes.ExpressionType:
    """Return the dtype of applying ``self.op`` to the input types (cached after first access).

    Raises:
        ValueError: If any input is unresolved, so its type is not yet known.
    """
    if not self.is_resolved:
        raise ValueError(f"Type of expression {self.op.name} has not been fixed.")

    # `operand` rather than `input`, to avoid shadowing the builtin.
    operand_types = [operand.output_type for operand in self.inputs]

    return self.op.output_type(*operand_types)
439518

440519
def bind_variables(
441520
self, bindings: Mapping[str, Expression], allow_partial_bindings: bool = False
@@ -475,4 +554,22 @@ def deterministic(self) -> bool:
475554
)
476555

477556

557+
def bind_schema_fields(
    expr: Expression, field_by_id: Mapping[ids.ColumnId, field.Field]
) -> Expression:
    """
    Updates `DerefOp` expressions by replacing column IDs with actual schema fields (columns).

    We can only deduce an expression's output type and nullability after binding schema fields to
    all its deref expressions.

    Args:
        expr: The expression to bind. Returned as-is when it is already resolved.
        field_by_id: Schema fields keyed by column ID.

    Returns:
        ``expr`` itself if it is already resolved, otherwise the result of
        ``expr.bind_refs`` with one ``SchemaFieldRefExpression`` per schema field.
    """
    if expr.is_resolved:
        return expr

    # Loop names chosen to avoid shadowing the builtin `id` and the imported
    # `field` module.
    bound_refs = {
        column_id: SchemaFieldRefExpression(schema_field)
        for column_id, schema_field in field_by_id.items()
    }
    return expr.bind_refs(bound_refs)
573+
574+
478575
RefOrConstant = Union[DerefOp, ScalarConstantExpression]

bigframes/core/field.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import dataclasses
18+
19+
from bigframes import dtypes
20+
from bigframes.core import identifiers
21+
22+
23+
@dataclasses.dataclass(frozen=True)
class Field:
    """Immutable description of one schema column: its ID, dtype and nullability.

    The ``with_*`` helpers never mutate; each returns a new ``Field``.
    """

    id: identifiers.ColumnId
    dtype: dtypes.Dtype
    # Best effort, nullable=True if not certain
    nullable: bool = True

    def with_nullable(self) -> Field:
        """Return a copy of this field marked as nullable."""
        return Field(id=self.id, dtype=self.dtype, nullable=True)

    def with_nonnull(self) -> Field:
        """Return a copy of this field marked as non-nullable."""
        return Field(id=self.id, dtype=self.dtype, nullable=False)

    def with_id(self, id: identifiers.ColumnId) -> Field:
        """Return a copy of this field under a different column ID."""
        return Field(id=id, dtype=self.dtype, nullable=self.nullable)

0 commit comments

Comments
 (0)