16
16
17
17
import abc
18
18
import dataclasses
19
+ import functools
19
20
import itertools
20
21
import typing
21
22
from typing import Generator , Mapping , TypeVar , Union
22
23
23
24
import pandas as pd
24
25
26
+ from bigframes import dtypes
27
+ from bigframes .core import field
25
28
import bigframes .core .identifiers as ids
26
- import bigframes .dtypes as dtypes
27
29
import bigframes .operations
28
30
import bigframes .operations .aggregations as agg_ops
29
31
@@ -50,7 +52,7 @@ class Aggregation(abc.ABC):
50
52
51
53
@abc .abstractmethod
52
54
def output_type (
53
- self , input_types : dict [ids .ColumnId , dtypes . ExpressionType ]
55
+ self , input_fields : Mapping [ids .ColumnId , field . Field ]
54
56
) -> dtypes .ExpressionType :
55
57
...
56
58
@@ -72,7 +74,7 @@ class NullaryAggregation(Aggregation):
72
74
op : agg_ops .NullaryWindowOp = dataclasses .field ()
73
75
74
76
def output_type (
75
- self , input_types : dict [ids .ColumnId , bigframes . dtypes . Dtype ]
77
+ self , input_fields : Mapping [ids .ColumnId , field . Field ]
76
78
) -> dtypes .ExpressionType :
77
79
return self .op .output_type ()
78
80
@@ -86,13 +88,17 @@ def remap_column_refs(
86
88
87
89
@dataclasses .dataclass (frozen = True )
88
90
class UnaryAggregation (Aggregation ):
89
- op : agg_ops .UnaryWindowOp = dataclasses . field ()
90
- arg : Union [DerefOp , ScalarConstantExpression ] = dataclasses . field ()
91
+ op : agg_ops .UnaryWindowOp
92
+ arg : Union [DerefOp , ScalarConstantExpression ]
91
93
92
94
def output_type (
93
- self , input_types : dict [ids .ColumnId , bigframes . dtypes . Dtype ]
95
+ self , input_fields : Mapping [ids .ColumnId , field . Field ]
94
96
) -> dtypes .ExpressionType :
95
- return self .op .output_type (self .arg .output_type (input_types ))
97
+ # TODO(b/419300717) Remove resolutions once defers are cleaned up.
98
+ resolved_expr = bind_schema_fields (self .arg , input_fields )
99
+ assert resolved_expr .is_resolved
100
+
101
+ return self .op .output_type (resolved_expr .output_type )
96
102
97
103
@property
98
104
def column_references (self ) -> typing .Tuple [ids .ColumnId , ...]:
@@ -118,10 +124,16 @@ class BinaryAggregation(Aggregation):
118
124
right : Union [DerefOp , ScalarConstantExpression ] = dataclasses .field ()
119
125
120
126
def output_type (
121
- self , input_types : dict [ids .ColumnId , bigframes . dtypes . Dtype ]
127
+ self , input_fields : Mapping [ids .ColumnId , field . Field ]
122
128
) -> dtypes .ExpressionType :
129
+ # TODO(b/419300717) Remove resolutions once defers are cleaned up.
130
+ left_resolved_expr = bind_schema_fields (self .left , input_fields )
131
+ assert left_resolved_expr .is_resolved
132
+ right_resolved_expr = bind_schema_fields (self .right , input_fields )
133
+ assert right_resolved_expr .is_resolved
134
+
123
135
return self .op .output_type (
124
- self . left . output_type ( input_types ), self . right . output_type ( input_types )
136
+ # Each operand contributes its own resolved type: left first, then right.
+ left_resolved_expr . output_type , right_resolved_expr . output_type
125
137
)
126
138
127
139
@property
@@ -189,10 +201,17 @@ def remap_column_refs(
189
201
def is_const (self ) -> bool :
190
202
...
191
203
204
+ @property
192
205
@abc .abstractmethod
193
- def output_type (
194
- self , input_types : dict [ids .ColumnId , dtypes .ExpressionType ]
195
- ) -> dtypes .ExpressionType :
206
+ def is_resolved (self ) -> bool :
207
+ """
208
+ Returns true if and only if the expression's output type and nullability are available.
209
+ """
210
+ ...
211
+
212
+ @property
213
+ @abc .abstractmethod
214
+ def output_type (self ) -> dtypes .ExpressionType :
196
215
...
197
216
198
217
@abc .abstractmethod
@@ -256,9 +275,12 @@ def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
256
275
def nullable (self ) -> bool :
257
276
return pd .isna (self .value ) # type: ignore
258
277
259
- def output_type (
260
- self , input_types : dict [ids .ColumnId , bigframes .dtypes .Dtype ]
261
- ) -> dtypes .ExpressionType :
278
+ @property
279
+ def is_resolved (self ) -> bool :
280
+ return True
281
+
282
+ @property
283
+ def output_type (self ) -> dtypes .ExpressionType :
262
284
return self .dtype
263
285
264
286
def bind_variables (
@@ -308,9 +330,12 @@ def is_const(self) -> bool:
308
330
def column_references (self ) -> typing .Tuple [ids .ColumnId , ...]:
309
331
return ()
310
332
311
- def output_type (
312
- self , input_types : dict [ids .ColumnId , bigframes .dtypes .Dtype ]
313
- ) -> dtypes .ExpressionType :
333
+ @property
334
+ def is_resolved (self ) -> bool :
335
+ return False
336
+
337
+ @property
338
+ def output_type (self ) -> dtypes .ExpressionType :
314
339
raise ValueError (f"Type of variable { self .id } has not been fixed." )
315
340
316
341
def bind_refs (
@@ -340,7 +365,7 @@ def is_identity(self) -> bool:
340
365
341
366
@dataclasses .dataclass (frozen = True )
342
367
class DerefOp (Expression ):
343
- """A variable expression representing an unbound variable ."""
368
+ """An expression that refers to a column by ID ."""
344
369
345
370
id : ids .ColumnId
346
371
@@ -357,13 +382,13 @@ def nullable(self) -> bool:
357
382
# Safe default, need to actually bind input schema to determine
358
383
return True
359
384
360
- def output_type (
361
- self , input_types : dict [ ids . ColumnId , bigframes . dtypes . Dtype ]
362
- ) -> dtypes . ExpressionType :
363
- if self . id in input_types :
364
- return input_types [ self . id ]
365
- else :
366
- raise ValueError (f"Type of variable { self .id } has not been fixed." )
385
+ @ property
386
+ def is_resolved ( self ) -> bool :
387
+ return False
388
+
389
+ @ property
390
+ def output_type ( self ) -> dtypes . ExpressionType :
391
+ raise ValueError (f"Type of variable { self .id } has not been fixed." )
367
392
368
393
def bind_variables (
369
394
self , bindings : Mapping [str , Expression ], allow_partial_bindings : bool = False
@@ -390,6 +415,55 @@ def is_identity(self) -> bool:
390
415
return True
391
416
392
417
418
+ @dataclasses .dataclass (frozen = True )
419
+ class SchemaFieldRefExpression (Expression ):
420
+ """An expression representing a schema field. This is essentially a DerefOp with input schema bound."""
421
+
422
+ field : field .Field
423
+
424
+ @property
425
+ def column_references (self ) -> typing .Tuple [ids .ColumnId , ...]:
426
+ return (self .field .id ,)
427
+
428
+ @property
429
+ def is_const (self ) -> bool :
430
+ return False
431
+
432
+ @property
433
+ def nullable (self ) -> bool :
434
+ return self .field .nullable
435
+
436
+ @property
437
+ def is_resolved (self ) -> bool :
438
+ return True
439
+
440
+ @property
441
+ def output_type (self ) -> dtypes .ExpressionType :
442
+ return self .field .dtype
443
+
444
+ def bind_variables (
445
+ self , bindings : Mapping [str , Expression ], allow_partial_bindings : bool = False
446
+ ) -> Expression :
447
+ return self
448
+
449
+ def bind_refs (
450
+ self ,
451
+ bindings : Mapping [ids .ColumnId , Expression ],
452
+ allow_partial_bindings : bool = False ,
453
+ ) -> Expression :
454
+ # Substitute only when this field's column ID has a binding; otherwise the
+ # reference is returned unchanged. `allow_partial_bindings` is not consulted here.
+ if self .field .id in bindings :
455
+ return bindings [self .field .id ]
456
+ return self
457
+
458
+ @property
459
+ def is_bijective (self ) -> bool :
460
+ return True
461
+
462
+ @property
463
+ def is_identity (self ) -> bool :
464
+ return True
465
+
466
+
393
467
@dataclasses .dataclass (frozen = True )
394
468
class OpExpression (Expression ):
395
469
"""An expression representing a scalar operation applied to 1 or more argument sub-expressions."""
@@ -429,13 +503,18 @@ def nullable(self) -> bool:
429
503
)
430
504
return not null_free
431
505
432
- def output_type (
433
- self , input_types : dict [ids .ColumnId , dtypes .ExpressionType ]
434
- ) -> dtypes .ExpressionType :
435
- operand_types = tuple (
436
- map (lambda x : x .output_type (input_types = input_types ), self .inputs )
437
- )
438
- return self .op .output_type (* operand_types )
506
+ @functools .cached_property
507
+ def is_resolved (self ) -> bool :
508
+ return all (input .is_resolved for input in self .inputs )
509
+
510
+ @functools .cached_property
511
+ def output_type (self ) -> dtypes .ExpressionType :
512
+ if not self .is_resolved :
513
+ raise ValueError (f"Type of expression { self .op .name } has not been fixed." )
514
+
515
+ input_types = [input .output_type for input in self .inputs ]
516
+
517
+ return self .op .output_type (* input_types )
439
518
440
519
def bind_variables (
441
520
self , bindings : Mapping [str , Expression ], allow_partial_bindings : bool = False
@@ -475,4 +554,22 @@ def deterministic(self) -> bool:
475
554
)
476
555
477
556
557
+ def bind_schema_fields (
558
+ expr : Expression , field_by_id : Mapping [ids .ColumnId , field .Field ]
559
+ ) -> Expression :
560
+ """
561
+ Updates `DerefOp` expressions by replacing column IDs with actual schema fields (columns).
562
+
563
+ We can only deduce an expression's output type and nullability after binding schema fields to
564
+ all its deref expressions.
+
+ An expression that already reports `is_resolved` is returned unchanged.
565
+ """
566
+ if expr .is_resolved :
567
+ return expr
568
+
569
+ # Use names that do not shadow the builtin `id` or the imported `field` module.
+ expr_by_id = {
570
+ column_id : SchemaFieldRefExpression (schema_field )
+ for column_id , schema_field in field_by_id .items ()
571
+ }
572
+ return expr .bind_refs (expr_by_id )
573
+
574
+
478
575
RefOrConstant = Union [DerefOp , ScalarConstantExpression ]
0 commit comments