Skip to content

Commit 8ac275a

Browse files
committed
CSHARP-5459: Standardize on using AstExpression.RootVar and Context.CreateRootSymbol.
1 parent c8f0429 commit 8ac275a

21 files changed

+47
-60
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpressionExtensions.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ public static bool IsInt32Constant(this AstExpression expression, out int value)
3535
public static bool IsMaxInt32(this AstExpression expression)
3636
=> expression.IsInt32Constant(out var value) && value == int.MaxValue;
3737

38+
public static bool IsRootVar(this AstExpression expression)
39+
=> expression is AstVarExpression varExpression && varExpression.Name == "ROOT" && varExpression.IsCurrent;
40+
3841
public static bool IsZero(this AstExpression expression)
3942
=> expression is AstConstantExpression constantExpression && constantExpression.Value == 0;
4043
}

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,7 @@ static bool ProjectsRoot(AstProjectStage projectStage)
181181
return projectStage.Specifications.Any(
182182
specification =>
183183
specification is AstProjectStageSetFieldSpecification setFieldSpecification &&
184-
setFieldSpecification.Value is AstVarExpression varExpression &&
185-
varExpression.Name == "ROOT");
184+
setFieldSpecification.Value.IsRootVar());
186185
}
187186
}
188187
}
@@ -370,14 +369,12 @@ public override AstNode VisitMapExpression(AstMapExpression node)
370369
mapInputGetFieldExpression.FieldName is AstConstantExpression mapInputconstantFieldExpression &&
371370
mapInputconstantFieldExpression.Value.IsString &&
372371
mapInputconstantFieldExpression.Value.AsString == "_elements" &&
373-
mapInputGetFieldExpression.Input is AstVarExpression mapInputGetFieldVarExpression &&
374-
mapInputGetFieldVarExpression.Name == "ROOT")
372+
mapInputGetFieldExpression.Input.IsRootVar())
375373
{
376374
var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(node.In, (node.As, _element));
377375
var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.Push, rewrittenArg);
378376
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
379-
var root = AstExpression.Var("ROOT", isCurrent: true);
380-
return AstExpression.GetField(root, accumulatorFieldName);
377+
return AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName);
381378
}
382379

383380
return base.VisitMapExpression(node);
@@ -388,8 +385,7 @@ public override AstNode VisitPickExpression(AstPickExpression node)
388385
// { $pickOperator : { source : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", sortBy : s, selector : f(x) } }
389386
// => { __agg0 : { $pickAccumulatorOperator : { sortBy : s, selector : f(x => element) } } } + "$__agg0"
390387
if (node.Source is AstGetFieldExpression getFieldExpression &&
391-
getFieldExpression.Input is AstVarExpression varExpression &&
392-
varExpression.Name == "ROOT" &&
388+
getFieldExpression.Input.IsRootVar() &&
393389
getFieldExpression.FieldName is AstConstantExpression constantFieldNameExpression &&
394390
constantFieldNameExpression.Value.IsString &&
395391
constantFieldNameExpression.Value.AsString == "_elements")
@@ -398,17 +394,14 @@ getFieldExpression.FieldName is AstConstantExpression constantFieldNameExpressio
398394
var rewrittenSelector = (AstExpression)AstNodeReplacer.Replace(node.Selector, (node.As, _element));
399395
var accumulatorExpression = new AstPickAccumulatorExpression(@operator, node.SortBy, rewrittenSelector, node.N);
400396
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
401-
var root = AstExpression.Var("ROOT", isCurrent: true);
402-
return AstExpression.GetField(root, accumulatorFieldName);
397+
return AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName);
403398
}
404399

405400
return base.VisitPickExpression(node);
406401
}
407402

408403
public override AstNode VisitUnaryExpression(AstUnaryExpression node)
409404
{
410-
var root = AstExpression.Var("ROOT", isCurrent: true);
411-
412405
if (TryOptimizeSizeOfElements(out var optimizedExpression))
413406
{
414407
return optimizedExpression;
@@ -438,7 +431,7 @@ argGetFieldExpression.FieldName is AstConstantExpression constantFieldNameExpres
438431
{
439432
var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.Sum, 1);
440433
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
441-
optimizedExpression = AstExpression.GetField(root, accumulatorFieldName);
434+
optimizedExpression = AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName);
442435
return true;
443436
}
444437
}
@@ -455,12 +448,11 @@ node.Arg is AstGetFieldExpression getFieldExpression &&
455448
getFieldExpression.FieldName is AstConstantExpression getFieldConstantFieldNameExpression &&
456449
getFieldConstantFieldNameExpression.Value.IsString &&
457450
getFieldConstantFieldNameExpression.Value == "_elements" &&
458-
getFieldExpression.Input is AstVarExpression getFieldInputVarExpression &&
459-
getFieldInputVarExpression.Name == "ROOT")
451+
getFieldExpression.Input.IsRootVar())
460452
{
461453
var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, _element);
462454
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
463-
optimizedExpression = AstExpression.GetField(root, accumulatorFieldName);
455+
optimizedExpression = AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName);
464456
return true;
465457
}
466458

@@ -478,13 +470,12 @@ mapExpression.Input is AstGetFieldExpression mapInputGetFieldExpression &&
478470
mapInputGetFieldExpression.FieldName is AstConstantExpression mapInputconstantFieldExpression &&
479471
mapInputconstantFieldExpression.Value.IsString &&
480472
mapInputconstantFieldExpression.Value.AsString == "_elements" &&
481-
mapInputGetFieldExpression.Input is AstVarExpression mapInputGetFieldVarExpression &&
482-
mapInputGetFieldVarExpression.Name == "ROOT")
473+
mapInputGetFieldExpression.Input.IsRootVar())
483474
{
484475
var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(mapExpression.In, (mapExpression.As, _element));
485476
var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, rewrittenArg);
486477
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
487-
optimizedExpression = AstExpression.GetField(root, accumulatorFieldName);
478+
optimizedExpression = AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName);
488479
return true;
489480
}
490481

src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ protected override AstStage RenderGroupingStage(
112112
var valueSerializer = (IBsonSerializer<TValue>)groupByTranslation.Serializer;
113113
var serializedBoundaries = SerializationHelper.SerializeValues(valueSerializer, _boundaries);
114114
var serializedDefault = _options != null && _options.DefaultBucket.HasValue ? SerializationHelper.SerializeValue(valueSerializer, _options.DefaultBucket.Value) : null;
115-
var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.Var("ROOT", isCurrent: true));
115+
var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.RootVar);
116116
groupingOutputSerializer = IGroupingSerializer.Create(valueSerializer, inputSerializer);
117117

118118
return AstStage.Bucket(
@@ -156,7 +156,7 @@ protected override AstStage RenderGroupingStage(
156156
var valueSerializer = (IBsonSerializer<TValue>)groupByTranslation.Serializer;
157157
var keySerializer = AggregateBucketAutoResultIdSerializer.Create(valueSerializer);
158158
var serializedGranularity = _options != null && _options.Granularity.HasValue ? _options.Granularity.Value.Value : null;
159-
var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.Var("ROOT", isCurrent: true));
159+
var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.RootVar);
160160
groupingOutputSerializer = IGroupingSerializer.Create(keySerializer, inputSerializer);
161161

162162
return AstStage.BucketAuto(
@@ -190,7 +190,7 @@ protected override AstStage RenderGroupingStage(
190190
var partiallyEvaluatedGroupBy = (Expression<Func<TInput, TValue>>)PartialEvaluator.EvaluatePartially(_groupBy);
191191
var context = TranslationContext.Create(partiallyEvaluatedGroupBy, translationOptions);
192192
var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true);
193-
var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.Var("ROOT", isCurrent: true));
193+
var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.RootVar);
194194
var groupBySerializer = (IBsonSerializer<TValue>)groupByTranslation.Serializer;
195195
groupingOutputSerializer = IGroupingSerializer.Create(groupBySerializer, inputSerializer);
196196

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/LambdaExpressionExtensions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public static string TranslateToDottedFieldName(this LambdaExpression fieldSelec
3838
{
3939
throw new ArgumentException($"ValueType '{parameterSerializer.ValueType.FullName}' of parameterSerializer does not match parameter type '{parameterExpression.Type.FullName}'.", nameof(parameterSerializer));
4040
}
41-
var parameterSymbol = context.CreateSymbolWithVarName(parameterExpression, varName: "ROOT", parameterSerializer, isCurrent: true);
41+
var parameterSymbol = context.CreateRootSymbol(parameterExpression, parameterSerializer);
4242
var lambdaContext = context.WithSymbol(parameterSymbol);
4343
var lambdaBody = ConvertHelper.RemoveConvertToObject(fieldSelectorLambda.Body);
4444
var fieldSelectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(lambdaContext, lambdaBody);

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ProjectionHelper.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ private static (string, IBsonSerializer) CreateGetFieldChainWithSafeFieldNamesPr
100100
var wrappedValueSerializer = WrappedValueSerializer.Create(fieldName, serializer);
101101
var input = getFieldExpression.Input;
102102

103-
if (input is AstVarExpression varExpression && varExpression.Name == "ROOT")
103+
if (input.IsRootVar())
104104
{
105105
return (fieldName, wrappedValueSerializer);
106106
}
@@ -132,7 +132,7 @@ private static bool IsGetFieldChainWithSafeFieldNames(AstGetFieldExpression getF
132132
return
133133
getFieldExpression.HasSafeFieldName(out _) &&
134134
(
135-
(getFieldExpression.Input is AstVarExpression varExpression && varExpression.Name == "ROOT") ||
135+
(getFieldExpression.Input.IsRootVar()) ||
136136
(getFieldExpression.Input is AstGetFieldExpression nestedGetFieldExpression && IsGetFieldChainWithSafeFieldNames(nestedGetFieldExpression))
137137
);
138138
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ public static AggregationExpression TranslateLambdaBody(
122122
}
123123
var parameterSymbol =
124124
asRoot ?
125-
context.CreateSymbolWithVarName(parameterExpression, varName: "ROOT", parameterSerializer, isCurrent: true) :
125+
context.CreateRootSymbol(parameterExpression, parameterSerializer) :
126126
context.CreateSymbol(parameterExpression, parameterSerializer, isCurrent: false);
127127

128128
return TranslateLambdaBody(context, lambdaExpression, parameterSymbol);

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PickMethodToAggregationExpressionTranslator.cs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,18 +260,16 @@ private static bool IsGroupingSource(AstExpression source)
260260
{
261261
return
262262
source is AstGetFieldExpression getFieldExpression &&
263-
getFieldExpression.Input is AstVarExpression inputVarExpression &&
264-
inputVarExpression.Name == "ROOT" &&
263+
getFieldExpression.Input.IsRootVar() &&
265264
getFieldExpression.FieldName is AstConstantExpression fieldNameConstantExpression &&
266265
fieldNameConstantExpression.Value == "_elements";
267266
}
268267

269268
private static bool IsValidKey(AggregationExpression keyTranslation)
270269
{
271270
if (keyTranslation.Ast is AstGetFieldExpression getFieldExpression &&
272-
getFieldExpression.Input is AstVarExpression inputVarExpression &&
271+
getFieldExpression.Input.IsRootVar() &&
273272
getFieldExpression.FieldName is AstConstantExpression constantFieldName &&
274-
inputVarExpression.Name == "ROOT" &&
275273
constantFieldName.Value == "_id")
276274
{
277275
return true;

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/AverageMethodToExecutableQueryTranslator.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,7 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
133133
else
134134
{
135135
Ensure.That(sourceSerializer is IWrappedValueSerializer, "Expected sourceSerializer to be an IWrappedValueSerializer.", nameof(sourceSerializer));
136-
var root = AstExpression.Var("ROOT", isCurrent: true);
137-
valueExpression = AstExpression.GetField(root, "_v");
136+
valueExpression = AstExpression.GetField(AstExpression.RootVar, "_v");
138137
}
139138

140139
IBsonSerializer outputValueSerializer = expression.GetResultType() switch

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ContainsMethodToExecutableQueryTranslator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ public static ExecutableQuery<TDocument, bool> Translate<TDocument>(MongoQueryPr
6363
wrappedValueSerializer,
6464
AstStage.Project(
6565
AstProject.ExcludeId(),
66-
AstProject.Set("_v", AstExpression.Var("ROOT"))));
66+
AstProject.Set("_v", AstExpression.RootVar)));
6767
}
6868

6969
var itemExpression = arguments[1];

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/LastMethodToExecutableQueryTranslator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
7979
pipeline.OutputSerializer,
8080
AstStage.Group(
8181
id: BsonNull.Value,
82-
fields: AstExpression.AccumulatorField("_last", AstUnaryAccumulatorOperator.Last, AstExpression.Var("ROOT"))));
82+
fields: AstExpression.AccumulatorField("_last", AstUnaryAccumulatorOperator.Last, AstExpression.RootVar)));
8383

8484
var finalizer = method.Name == "LastOrDefault" ? __singleOrDefaultFinalizer : __singleFinalizer;
8585

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/MaxMethodToExecutableQueryTranslator.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
6666
ClientSideProjectionHelper.ThrowIfClientSideProjection(expression, pipeline, method);
6767

6868
var sourceSerializer = pipeline.OutputSerializer;
69-
var root = AstExpression.Var("ROOT", isCurrent: true);
7069
AstExpression valueAst;
7170
IBsonSerializer valueSerializer;
7271
if (method.IsOneOf(__maxWithSelectorMethods))
@@ -86,7 +85,7 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
8685
}
8786
else
8887
{
89-
valueAst = root;
88+
valueAst = AstExpression.RootVar;
9089
valueSerializer = pipeline.OutputSerializer;
9190
}
9291

@@ -95,7 +94,7 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
9594
AstStage.Group(
9695
id: BsonNull.Value,
9796
fields: AstExpression.AccumulatorField("_max", AstUnaryAccumulatorOperator.Max, valueAst)),
98-
AstStage.ReplaceRoot(AstExpression.GetField(root, "_max")));
97+
AstStage.ReplaceRoot(AstExpression.GetField(AstExpression.RootVar, "_max")));
9998

10099
return ExecutableQuery.Create(
101100
provider,

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/MinMethodToExecutableQueryTranslator.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
6666
ClientSideProjectionHelper.ThrowIfClientSideProjection(expression, pipeline, method);
6767

6868
var sourceSerializer = pipeline.OutputSerializer;
69-
var root = AstExpression.Var("ROOT", isCurrent: true);
7069
AstExpression valueAst;
7170
IBsonSerializer valueSerializer;
7271
if (method.IsOneOf(__minWithSelectorMethods))
@@ -86,7 +85,7 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
8685
}
8786
else
8887
{
89-
valueAst = root;
88+
valueAst = AstExpression.RootVar;
9089
valueSerializer = pipeline.OutputSerializer;
9190
}
9291

@@ -95,7 +94,7 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
9594
AstStage.Group(
9695
id: BsonNull.Value,
9796
fields: AstExpression.AccumulatorField("_min", AstUnaryAccumulatorOperator.Min, valueAst)),
98-
AstStage.ReplaceRoot(AstExpression.GetField(root, "_min")));
97+
AstStage.ReplaceRoot(AstExpression.GetField(AstExpression.RootVar, "_min")));
9998

10099
return ExecutableQuery.Create(
101100
provider,

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/StandardDeviationMethodToExecutableQueryTranslator.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,7 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
281281
}
282282
else
283283
{
284-
var root = AstExpression.Var("ROOT", isCurrent: true);
285-
valueAst = AstExpression.GetField(root, "_v");
284+
valueAst = AstExpression.GetField(AstExpression.RootVar, "_v");
286285
}
287286
var outputValueType = expression.GetResultType();
288287
var outputValueSerializer = BsonSerializer.LookupSerializer(outputValueType);

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/SumMethodToExecutableQueryTranslator.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
115115
var arguments = expression.Arguments;
116116

117117
if (method.IsOneOf(__sumMethods))
118-
{
118+
{
119119
var sourceExpression = arguments[0];
120120
var pipeline = ExpressionToPipelineTranslator.Translate(context, sourceExpression);
121121
ClientSideProjectionHelper.ThrowIfClientSideProjection(expression, pipeline, method);
@@ -131,8 +131,7 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
131131
else
132132
{
133133
Ensure.That(sourceSerializer is IWrappedValueSerializer, "Expected sourceSerializer to be an IWrappedValueSerializer.", nameof(sourceSerializer));
134-
var rootVar = AstExpression.Var("ROOT", isCurrent: true);
135-
valueAst = AstExpression.GetField(rootVar, "_v");
134+
valueAst = AstExpression.GetField(AstExpression.RootVar, "_v");
136135
}
137136

138137
var outputValueType = expression.GetResultType();

0 commit comments

Comments
 (0)