From 4b673872b88fa98b3f7d57574014b8ea49ff6888 Mon Sep 17 00:00:00 2001 From: Oleksandr Poliakov Date: Tue, 18 Mar 2025 17:06:23 -0700 Subject: [PATCH 1/4] CSHARP-4779: Support Dictionary(IEnumerable> collection) constructor in LINQ3 --- .../Ast/Optimizers/AstSimplifier.cs | 57 +++++++ .../Reflection/DictionaryConstructor.cs | 34 ++++ ...essionToAggregationExpressionTranslator.cs | 104 ++++++++++++ ...essionToAggregationExpressionTranslator.cs | 4 + ...nToAggregationExpressionTranslatorTests.cs | 155 ++++++++++++++++++ 5 files changed, 354 insertions(+) create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs create mode 100644 tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslatorTests.cs diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs index 206fd8b308c..5b1fd53d478 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs @@ -14,6 +14,7 @@ */ using System; +using System.Collections.Generic; using System.Linq; using MongoDB.Bson; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; @@ -454,8 +455,42 @@ public override AstNode VisitMapExpression(AstMapExpression node) } } + if (node.In is AstComputedDocumentExpression inComputedDocumentExpression && + inComputedDocumentExpression.Fields.All(f => f.Value is AstGetFieldExpression getFieldExpression && getFieldExpression.Input == node.As && getFieldExpression.CanBeConvertedToFieldPath())) + { + + // { $map : { input : { $map : { input : , as : "y", in : { A : "$$y.FieldA" } } }, as: "v", in : { B : '$$v.A' } } } => { $map : { input : , as: "v", in : { B : "$$v.FieldA" } } } + if (node.Input is AstMapExpression inputMapExpression && + inputMapExpression.In is AstComputedDocumentExpression innerInComputedDocumentExpression) + { + var simplified = AstExpression.Map( + inputMapExpression.Input, + inputMapExpression.As, + AstExpression.ComputedDocument(inComputedDocumentExpression.Fields.Select(f => RemapField(f, node.As.Name, innerInComputedDocumentExpression.Fields)))); + + return Visit(simplified); + } + + // { $map : { input : [{ A: "$$ROOT.FieldA" }], as : "v", in: { B : "$$v.A" } } } => [{ B : "$FieldA }] + if (node.Input is AstComputedArrayExpression inputArrayExpression && + inputArrayExpression.Items.All(i => i is AstComputedDocumentExpression)) + { + var simplified = AstExpression.ComputedArray(inputArrayExpression.Items.Select(i => + AstExpression.ComputedDocument(inComputedDocumentExpression.Fields.Select(f => RemapField(f, node.As.Name, ((AstComputedDocumentExpression)i).Fields))))); + return Visit(simplified); + } + } + return base.VisitMapExpression(node); + static AstComputedField RemapField(AstComputedField field, string @as, IEnumerable innerFields) + { + var fieldPath = ((AstGetFieldExpression)field.Value).ConvertToFieldPath().Replace($"$${@as}.", string.Empty); + var innerField = innerFields.Single(f => f.Path == fieldPath); + + return AstExpression.ComputedField(field.Path, innerField.Value); + } + static AstExpression UltimateGetFieldInput(AstGetFieldExpression getField) { if (getField.Input is AstGetFieldExpression nestedInputGetField) @@ -574,7 +609,29 @@ arg is AstBinaryExpression argBinaryExpression && return AstExpression.Binary(oppositeComparisonOperator, argBinaryExpression.Arg1, argBinaryExpression.Arg2); } + // { $arrayToObject : [[{ k : 'A', v : '$A' }, { k : 'B', v : '$B' }]] } => { A : '$A', B : '$B' } + if (node.Operator is AstUnaryOperator.ArrayToObject && + arg is AstComputedArrayExpression computedArrayExpression && + computedArrayExpression.Items.All( + i => i is AstComputedDocumentExpression computedDocumentExpression && + computedDocumentExpression.Fields.FirstOrDefault(f => f.Path == "k")?.Value is AstConstantExpression && + computedDocumentExpression.Fields.Any(f => f.Path == "v")) + ) + { + var fields = computedArrayExpression.Items.Select(KeyValuePairDocumentToComputedField); + return AstExpression.ComputedDocument(fields); + } + return node.Update(arg); + + static AstComputedField KeyValuePairDocumentToComputedField(AstExpression expression) + { + var documentExpression = (AstComputedDocumentExpression)expression; + var keyExpression = documentExpression.Fields.First(f => f.Path == "k").Value; + var valueExpression = documentExpression.Fields.First(f => f.Path == "v").Value; + + return AstExpression.ComputedField(((AstConstantExpression)keyExpression).Value.AsString, valueExpression); + } } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs new file mode 100644 index 00000000000..37be3984270 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs @@ -0,0 +1,34 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection +{ + internal static class DictionaryConstructor + { + // public static methods + public static bool IsIEnumerableKeyValuePairConstructor(ConstructorInfo ctor) + { + var parameters = ctor.GetParameters(); + return parameters.Length == 1 && + parameters[0].ParameterType.ImplementsIEnumerable(out var enumerableType) && + enumerableType.IsConstructedGenericType && + enumerableType.GetGenericTypeDefinition() == typeof(KeyValuePair<,>); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs new file mode 100644 index 00000000000..72dd19ccb18 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs @@ -0,0 +1,104 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Options; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators +{ + internal static class NewDictionaryExpressionToAggregationExpressionTranslator + { + public static TranslatedExpression Translate(TranslationContext context, NewExpression expression) + { + var arguments = expression.Arguments; + var collectionExpression = arguments.Single(); + var collectionTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, collectionExpression); + + if (collectionTranslation.Serializer is IBsonArraySerializer bsonArraySerializer && + bsonArraySerializer.TryGetItemSerializationInfo(out var itemSerializationInfo)) + { + IBsonSerializer keySerializer = null; + IBsonSerializer valueSerializer = null; + AstExpression collectionTranslationAst; + + if (itemSerializationInfo.Serializer is IRepresentationConfigurable { Representation: BsonType.Array }) + { + collectionTranslationAst = collectionTranslation.Ast; + } + else if (itemSerializationInfo.Serializer is IBsonDocumentSerializer itemDocumentSerializer) + { + if (!itemDocumentSerializer.TryGetMemberSerializationInfo("Key", out var keyMemberSerializationInfo) || + !itemDocumentSerializer.TryGetMemberSerializationInfo("Value", out var valueMemberSerializationInfo)) + { + throw new ExpressionNotSupportedException(expression, because: $"document serializer class {itemSerializationInfo.Serializer.GetType()} does not provide member serialization info for required fields."); + } + + if (keyMemberSerializationInfo.ElementName == "k" && valueMemberSerializationInfo.ElementName == "v") + { + collectionTranslationAst = collectionTranslation.Ast; + } + else + { + keySerializer = keyMemberSerializationInfo.Serializer; + valueSerializer = valueMemberSerializationInfo.Serializer; + + var pairVar = AstExpression.Var("pair"); + var computedDocumentAst = AstExpression.ComputedDocument([ + AstExpression.ComputedField("k", AstExpression.GetField(pairVar, keyMemberSerializationInfo.ElementName)), + AstExpression.ComputedField("v", AstExpression.GetField(pairVar, valueMemberSerializationInfo.ElementName)) + ]); + collectionTranslationAst = AstExpression.Map(collectionTranslation.Ast, pairVar, computedDocumentAst); + } + } + else + { + throw new ExpressionNotSupportedException(expression, because: $"document serializer class {itemSerializationInfo.Serializer.GetType()} does not implement {nameof(IBsonDocumentSerializer)}"); + } + + if (keySerializer is not IRepresentationConfigurable { Representation: BsonType.String }) + { + throw new ExpressionNotSupportedException(expression, because: "key did not serialize as a string"); + } + + var ast = AstExpression.Unary(AstUnaryOperator.ArrayToObject, collectionTranslationAst); + var resultSerializer = CreateDictionarySerializer(keySerializer, valueSerializer); + return new TranslatedExpression(expression, ast, resultSerializer); + } + + throw new ExpressionNotSupportedException(expression); + } + + public static bool CanTranslate(NewExpression expression) + => expression.Type.IsConstructedGenericType && + expression.Type.GetGenericTypeDefinition() == typeof(Dictionary<,>) && + DictionaryConstructor.IsIEnumerableKeyValuePairConstructor(expression.Constructor); + + private static IBsonSerializer CreateDictionarySerializer(IBsonSerializer keySerializer, IBsonSerializer valueSerializer) + { + var dictionaryType = typeof(Dictionary<,>).MakeGenericType(keySerializer.ValueType, valueSerializer.ValueType); + var serializerType = typeof(DictionaryInterfaceImplementerSerializer<,,>).MakeGenericType(dictionaryType, keySerializer.ValueType, valueSerializer.ValueType); + + return (IBsonSerializer)Activator.CreateInstance(serializerType, DictionaryRepresentation.Document, keySerializer, valueSerializer); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs index b54f431e516..af521d05658 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs @@ -50,6 +50,10 @@ public static TranslatedExpression Translate(TranslationContext context, NewExpr { return NewKeyValuePairExpressionToAggregationExpressionTranslator.Translate(context, expression); } + if (NewDictionaryExpressionToAggregationExpressionTranslator.CanTranslate(expression)) + { + return NewDictionaryExpressionToAggregationExpressionTranslator.Translate(context, expression); + } return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, expression, expression, Array.Empty()); } } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslatorTests.cs new file mode 100644 index 00000000000..59931e8f199 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslatorTests.cs @@ -0,0 +1,155 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#if NET6_0_OR_GREATER || NETCOREAPP3_1_OR_GREATER + +using System; +using System.Collections.Generic; +using System.Linq; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Driver.Linq; +using MongoDB.Driver.TestHelpers; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators +{ + public class NewDictionaryExpressionToAggregationExpressionTranslatorTests : LinqIntegrationTest + { + public NewDictionaryExpressionToAggregationExpressionTranslatorTests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void NewDictionary_with_KeyValuePairs_should_translate() + { + var collection = Fixture.Collection; + + var queryable = collection.AsQueryable() + .Select(d => new Dictionary( + new[] { new KeyValuePair("A", d.A), new KeyValuePair("B", d.B) })); + + var stages = Translate(collection, queryable); + + AssertStages(stages, "{ $project : { _v : { A : '$A', B: '$B' }, _id : 0 } }"); + + var result = queryable.Single(); + result.Should().Equal(new Dictionary{ ["A"] = "a", ["B"] = "b" }); + } + + [Fact] + public void NewDictionary_with_KeyValuePairs_Create_should_translate() + { + var collection = Fixture.Collection; + + var queryable = collection.AsQueryable() + .Select(d => new Dictionary( + new[] { KeyValuePair.Create("A", d.A), KeyValuePair.Create("B", d.B) })); + + var stages = Translate(collection, queryable); + + AssertStages(stages, "{ $project : { _v : { A : '$A', B: '$B' }, _id : 0 } }"); + + var result = queryable.Single(); + result.Should().Equal(new Dictionary{ ["A"] = "a", ["B"] = "b" }); + } + + [Fact] + public void NewDictionary_with_KeyValuePairs_should_translate_Guid_as_string_key() + { + var collection = Fixture.Collection; + + var queryable = collection.AsQueryable() + .Select(d => new Dictionary( + new[] { new KeyValuePair(d.GuidAsString, d.A) })); + + var stages = Translate(collection, queryable); + + AssertStages(stages, "{ $project : { _v : { $arrayToObject : [[{ k : '$GuidAsString', v : '$A' }]] }, _id : 0 } }"); + + var result = queryable.Single(); + result.Should().Equal(new Dictionary{ [Guid.Parse("3E9AE467-9705-4C17-9655-EE7730BCC2EE")] = "a" }); + } + + + [Fact] + public void NewDictionary_with_KeyValuePairs_should_translate_dynamic_array() + { + var collection = Fixture.Collection; + + var queryable = collection.AsQueryable() + .Select(d => new Dictionary( + d.Items.Select(i => new KeyValuePair(i.P, i.W)))); + + var stages = Translate(collection, queryable); + + AssertStages(stages, "{ $project : { _v : { $arrayToObject : { $map: { input: '$Items', as: 'i', in: { k: '$$i.P', v: '$$i.W' } } } }, _id : 0 } }"); + + var result = queryable.Single(); + result.Should().Equal(new Dictionary{ ["x"] = "y" }); + } + + [Fact] + public void NewDictionary_with_KeyValuePairs_throws_on_non_string_key() + { + var collection = Fixture.Collection; + + var queryable = collection.AsQueryable() + .Select(d => new Dictionary( + new[] { new KeyValuePair(42, d.A) })); + + var exception = Record.Exception(() => queryable.ToList()); + + exception.Should().NotBeNull(); + exception.Should().BeOfType(); + } + + public class C + { + public string A { get; set; } + + public string B { get; set; } + + [BsonRepresentation(BsonType.String)] + public Guid GuidAsString { get; set; } + + public Item[] Items { get; set; } + } + + public class Item + { + public string P { get; set; } + + public string W { get; set; } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new C + { + A = "a", + B = "b", + GuidAsString = Guid.Parse("3E9AE467-9705-4C17-9655-EE7730BCC2EE"), + Items = [ new Item { P = "x", W = "y" } ] + }, + ]; + } + } +} +#endif From f1a774adc81d290a320157a7737b9948fab4947a Mon Sep 17 00:00:00 2001 From: Oleksandr Poliakov Date: Tue, 15 Apr 2025 16:29:10 -0700 Subject: [PATCH 2/4] PR --- .../NewDictionaryExpressionToAggregationExpressionTranslator.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs index 72dd19ccb18..093e9e82301 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs @@ -41,7 +41,7 @@ public static TranslatedExpression Translate(TranslationContext context, NewExpr IBsonSerializer valueSerializer = null; AstExpression collectionTranslationAst; - if (itemSerializationInfo.Serializer is IRepresentationConfigurable { Representation: BsonType.Array }) + if (itemSerializationInfo.Serializer is IKeyValuePairSerializer { Representation: BsonType.Array }) { collectionTranslationAst = collectionTranslation.Ast; } From fca3c4f2fd694c6b99961acbe86fcdfd835c4e60 Mon Sep 17 00:00:00 2001 From: Oleksandr Poliakov Date: Wed, 30 Apr 2025 12:27:46 -0700 Subject: [PATCH 3/4] PR --- .../Ast/Optimizers/AstSimplifier.cs | 122 ++++++++++++------ .../Reflection/DictionaryConstructor.cs | 14 +- ...essionToAggregationExpressionTranslator.cs | 92 +++++++------ 3 files changed, 133 insertions(+), 95 deletions(-) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs index 5b1fd53d478..7ce6ff1e607 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs @@ -455,42 +455,81 @@ public override AstNode VisitMapExpression(AstMapExpression node) } } - if (node.In is AstComputedDocumentExpression inComputedDocumentExpression && - inComputedDocumentExpression.Fields.All(f => f.Value is AstGetFieldExpression getFieldExpression && getFieldExpression.Input == node.As && getFieldExpression.CanBeConvertedToFieldPath())) - { + // { $map : { input : { $map : { input : , as : "inner", in : { A : , B : , ... } } }, as: "outer", in : { F : '$$outer.A', G : "$$outer.B", ... } } } + // => { $map : { input : , as: "inner", in : { F : , G : , ... } } } + if (node.Input is AstMapExpression innerMapExpression && + node.As is var outerVar && + node.In is AstComputedDocumentExpression outerComputedDocumentExpression && + innerMapExpression.Input is var innerInput && + innerMapExpression.As is var innerVar && + innerMapExpression.In is AstComputedDocumentExpression innerComputedDocumentExpression && + outerComputedDocumentExpression.Fields.All(outerField => + outerField.Value is AstGetFieldExpression outerGetFieldExpression && + outerGetFieldExpression.Input == outerVar && + outerGetFieldExpression.FieldName is AstConstantExpression { Value : BsonString { Value : var matchingFieldName } } && + innerComputedDocumentExpression.Fields.Any(innerField => innerField.Path == matchingFieldName))) + { + var rewrittenOuterFields = new List(); + foreach (var outerField in outerComputedDocumentExpression.Fields) + { + var outerGetFieldExpression = (AstGetFieldExpression)outerField.Value; + var matchingFieldName = ((AstConstantExpression)outerGetFieldExpression.FieldName).Value.AsString; + var matchingInnerField = innerComputedDocumentExpression.Fields.Single(innerField => innerField.Path == matchingFieldName); + var rewrittenOuterField = AstExpression.ComputedField(outerField.Path, matchingInnerField.Value); + rewrittenOuterFields.Add(rewrittenOuterField); + } - // { $map : { input : { $map : { input : , as : "y", in : { A : "$$y.FieldA" } } }, as: "v", in : { B : '$$v.A' } } } => { $map : { input : , as: "v", in : { B : "$$v.FieldA" } } } - if (node.Input is AstMapExpression inputMapExpression && - inputMapExpression.In is AstComputedDocumentExpression innerInComputedDocumentExpression) - { - var simplified = AstExpression.Map( - inputMapExpression.Input, - inputMapExpression.As, - AstExpression.ComputedDocument(inComputedDocumentExpression.Fields.Select(f => RemapField(f, node.As.Name, innerInComputedDocumentExpression.Fields)))); + var simplified = AstExpression.Map( + input: innerInput, + @as: innerVar, + @in: AstExpression.ComputedDocument(rewrittenOuterFields)); + + return Visit(simplified); + } + + // { $map : { input : [{ A : , B : , ... }, { A : , B : , ... }, ...], as : "item", in: { F : "$$item.A", G : "$$item.B", ... } } } + // => [{ F : , G : ", ... }, { F : , G : , ... }, ...] + if (node.Input is AstComputedArrayExpression inputComputedArray && + inputComputedArray.Items.Count >= 1 && + inputComputedArray.Items[0] is AstComputedDocumentExpression firstComputedDocument && + firstComputedDocument.Fields.Select(inputField => inputField.Path).ToArray() is var inputFieldNames && + inputComputedArray.Items.Skip(1).All(otherItem => + otherItem is AstComputedDocumentExpression otherComputedDocument && + otherComputedDocument.Fields.Select(otherField => otherField.Path).SequenceEqual(inputFieldNames)) && + node.As is var itemVar && + node.In is AstComputedDocumentExpression mappedDocument && + mappedDocument.Fields.All(mappedField => + mappedField.Value is AstGetFieldExpression mappedGetField && + mappedGetField.Input == itemVar && + mappedGetField.FieldName is AstConstantExpression { Value : BsonString { Value : var matchingFieldName } } && + inputFieldNames.Contains(matchingFieldName))) + { + var rewrittenItems = new List(); + foreach (var inputItem in inputComputedArray.Items) + { + var inputDocument = (AstComputedDocumentExpression)inputItem; + + var rewrittenFields = new List(); + foreach (var mappedField in mappedDocument.Fields) + { + var mappedGetField = (AstGetFieldExpression)mappedField.Value; + var matchingFieldName = ((AstConstantExpression)mappedGetField.FieldName).Value.AsString; + var matchingInputField = inputDocument.Fields.Single(inputField => inputField.Path == matchingFieldName); + var rewrittenField = AstExpression.ComputedField(mappedField.Path, matchingInputField.Value); + rewrittenFields.Add(rewrittenField); + } - return Visit(simplified); + var rewrittenItem = AstExpression.ComputedDocument(rewrittenFields); + rewrittenItems.Add(rewrittenItem); } - // { $map : { input : [{ A: "$$ROOT.FieldA" }], as : "v", in: { B : "$$v.A" } } } => [{ B : "$FieldA }] - if (node.Input is AstComputedArrayExpression inputArrayExpression && - inputArrayExpression.Items.All(i => i is AstComputedDocumentExpression)) - { - var simplified = AstExpression.ComputedArray(inputArrayExpression.Items.Select(i => - AstExpression.ComputedDocument(inComputedDocumentExpression.Fields.Select(f => RemapField(f, node.As.Name, ((AstComputedDocumentExpression)i).Fields))))); - return Visit(simplified); - } + var simplified = AstExpression.ComputedArray(rewrittenItems); + + return Visit(simplified); } return base.VisitMapExpression(node); - static AstComputedField RemapField(AstComputedField field, string @as, IEnumerable innerFields) - { - var fieldPath = ((AstGetFieldExpression)field.Value).ConvertToFieldPath().Replace($"$${@as}.", string.Empty); - var innerField = innerFields.Single(f => f.Path == fieldPath); - - return AstExpression.ComputedField(field.Path, innerField.Value); - } - static AstExpression UltimateGetFieldInput(AstGetFieldExpression getField) { if (getField.Input is AstGetFieldExpression nestedInputGetField) @@ -609,28 +648,31 @@ arg is AstBinaryExpression argBinaryExpression && return AstExpression.Binary(oppositeComparisonOperator, argBinaryExpression.Arg1, argBinaryExpression.Arg2); } - // { $arrayToObject : [[{ k : 'A', v : '$A' }, { k : 'B', v : '$B' }]] } => { A : '$A', B : '$B' } - if (node.Operator is AstUnaryOperator.ArrayToObject && + // { $arrayToObject : [[{ k : 'A', v : }, { k : 'B', v : }, ...]] } => { A : , B : , ... } + if (node.Operator == AstUnaryOperator.ArrayToObject && arg is AstComputedArrayExpression computedArrayExpression && computedArrayExpression.Items.All( - i => i is AstComputedDocumentExpression computedDocumentExpression && - computedDocumentExpression.Fields.FirstOrDefault(f => f.Path == "k")?.Value is AstConstantExpression && - computedDocumentExpression.Fields.Any(f => f.Path == "v")) - ) + item => + item is AstComputedDocumentExpression computedDocumentExpression && + computedDocumentExpression.Fields.Count == 2 && + computedDocumentExpression.Fields[0].Path == "k" && + computedDocumentExpression.Fields[1].Path == "v" && + computedDocumentExpression.Fields[0].Value is AstConstantExpression { Value : { IsString : true } })) { - var fields = computedArrayExpression.Items.Select(KeyValuePairDocumentToComputedField); - return AstExpression.ComputedDocument(fields); + var computedFields = computedArrayExpression.Items.Select(KeyValuePairDocumentToComputedField); + return AstExpression.ComputedDocument(computedFields); } return node.Update(arg); static AstComputedField KeyValuePairDocumentToComputedField(AstExpression expression) { - var documentExpression = (AstComputedDocumentExpression)expression; - var keyExpression = documentExpression.Fields.First(f => f.Path == "k").Value; - var valueExpression = documentExpression.Fields.First(f => f.Path == "v").Value; + // caller has verified that expression is of the form: { k : , v : } + var keyValuePairDocumentExpression = (AstComputedDocumentExpression)expression; + var keyConstantExpression = (AstConstantExpression)keyValuePairDocumentExpression.Fields[0].Value; + var valueExpression = keyValuePairDocumentExpression.Fields[1].Value; - return AstExpression.ComputedField(((AstConstantExpression)keyExpression).Value.AsString, valueExpression); + return AstExpression.ComputedField(keyConstantExpression.Value.AsString, valueExpression); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs index 37be3984270..b44568d07bf 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs @@ -21,14 +21,14 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection { internal static class DictionaryConstructor { - // public static methods - public static bool IsIEnumerableKeyValuePairConstructor(ConstructorInfo ctor) + public static bool IsWithIEnumerableKeyValuePairConstructor(ConstructorInfo constructor) { - var parameters = ctor.GetParameters(); - return parameters.Length == 1 && - parameters[0].ParameterType.ImplementsIEnumerable(out var enumerableType) && - enumerableType.IsConstructedGenericType && - enumerableType.GetGenericTypeDefinition() == typeof(KeyValuePair<,>); + var parameters = constructor.GetParameters(); + return + parameters.Length == 1 && + parameters[0].ParameterType.ImplementsIEnumerable(out var enumerableType) && + enumerableType.IsConstructedGenericType && + enumerableType.GetGenericTypeDefinition() == typeof(KeyValuePair<,>); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs index 093e9e82301..9bdb3b2cf6c 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs @@ -15,85 +15,81 @@ using System; using System.Collections.Generic; -using System.Linq; using System.Linq.Expressions; using MongoDB.Bson; using MongoDB.Bson.Serialization; using MongoDB.Bson.Serialization.Options; using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Reflection; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { - internal static class NewDictionaryExpressionToAggregationExpressionTranslator + internal static class NewDictionaryExpressionToAggregationExpressionTranslator { + public static bool CanTranslate(NewExpression expression) + => expression.Type.IsConstructedGenericType && + expression.Type.GetGenericTypeDefinition() == typeof(Dictionary<,>) && + DictionaryConstructor.IsWithIEnumerableKeyValuePairConstructor(expression.Constructor); + public static TranslatedExpression Translate(TranslationContext context, NewExpression expression) { var arguments = expression.Arguments; - var collectionExpression = arguments.Single(); + + var collectionExpression = arguments[0]; var collectionTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, collectionExpression); + var itemSerializer = ArraySerializerHelper.GetItemSerializer(collectionTranslation.Serializer); - if (collectionTranslation.Serializer is IBsonArraySerializer bsonArraySerializer && - bsonArraySerializer.TryGetItemSerializationInfo(out var itemSerializationInfo)) - { - IBsonSerializer keySerializer = null; - IBsonSerializer valueSerializer = null; - AstExpression collectionTranslationAst; + IBsonSerializer keySerializer; + IBsonSerializer valueSerializer; + AstExpression collectionTranslationAst; - if (itemSerializationInfo.Serializer is IKeyValuePairSerializer { Representation: BsonType.Array }) + if (itemSerializer is IBsonDocumentSerializer itemDocumentSerializer) + { + if (!itemDocumentSerializer.TryGetMemberSerializationInfo("Key", out var keyMemberSerializationInfo)) { - collectionTranslationAst = collectionTranslation.Ast; + throw new ExpressionNotSupportedException(expression, because: $"serializer class {itemSerializer.GetType()} does not have a Key member"); } - else if (itemSerializationInfo.Serializer is IBsonDocumentSerializer itemDocumentSerializer) - { - if (!itemDocumentSerializer.TryGetMemberSerializationInfo("Key", out var keyMemberSerializationInfo) || - !itemDocumentSerializer.TryGetMemberSerializationInfo("Value", out var valueMemberSerializationInfo)) - { - throw new ExpressionNotSupportedException(expression, because: $"document serializer class {itemSerializationInfo.Serializer.GetType()} does not provide member serialization info for required fields."); - } + keySerializer = keyMemberSerializationInfo.Serializer; - if (keyMemberSerializationInfo.ElementName == "k" && valueMemberSerializationInfo.ElementName == "v") - { - collectionTranslationAst = collectionTranslation.Ast; - } - else - { - keySerializer = keyMemberSerializationInfo.Serializer; - valueSerializer = valueMemberSerializationInfo.Serializer; + if (!itemDocumentSerializer.TryGetMemberSerializationInfo("Value", out var valueMemberSerializationInfo)) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer class {itemSerializer.GetType()} does not have a Value member"); + } + valueSerializer = valueMemberSerializationInfo.Serializer; - var pairVar = AstExpression.Var("pair"); - var computedDocumentAst = AstExpression.ComputedDocument([ - AstExpression.ComputedField("k", AstExpression.GetField(pairVar, keyMemberSerializationInfo.ElementName)), - AstExpression.ComputedField("v", AstExpression.GetField(pairVar, valueMemberSerializationInfo.ElementName)) - ]); - collectionTranslationAst = AstExpression.Map(collectionTranslation.Ast, pairVar, computedDocumentAst); - } + if (keyMemberSerializationInfo.ElementName == "k" && valueMemberSerializationInfo.ElementName == "v") + { + collectionTranslationAst = collectionTranslation.Ast; } else { - throw new ExpressionNotSupportedException(expression, because: $"document serializer class {itemSerializationInfo.Serializer.GetType()} does not implement {nameof(IBsonDocumentSerializer)}"); - } + var pairVar = AstExpression.Var("pair"); + var computedDocumentAst = AstExpression.ComputedDocument([ + AstExpression.ComputedField("k", AstExpression.GetField(pairVar, keyMemberSerializationInfo.ElementName)), + AstExpression.ComputedField("v", AstExpression.GetField(pairVar, valueMemberSerializationInfo.ElementName)) + ]); - if (keySerializer is not IRepresentationConfigurable { Representation: BsonType.String }) - { - throw new ExpressionNotSupportedException(expression, because: "key did not serialize as a string"); + collectionTranslationAst = AstExpression.Map(collectionTranslation.Ast, pairVar, computedDocumentAst); } + } + else + { + throw new ExpressionNotSupportedException(expression); + } - var ast = AstExpression.Unary(AstUnaryOperator.ArrayToObject, collectionTranslationAst); - var resultSerializer = CreateDictionarySerializer(keySerializer, valueSerializer); - return new TranslatedExpression(expression, ast, resultSerializer); + if (keySerializer is not IRepresentationConfigurable { Representation: BsonType.String }) + { + throw new ExpressionNotSupportedException(expression, because: "key does not serialize as a string"); } - throw new ExpressionNotSupportedException(expression); + var ast = AstExpression.Unary(AstUnaryOperator.ArrayToObject, collectionTranslationAst); + var resultSerializer = CreateResultSerializer(keySerializer, valueSerializer); + return new TranslatedExpression(expression, ast, resultSerializer); } - public static bool CanTranslate(NewExpression expression) - => expression.Type.IsConstructedGenericType && - expression.Type.GetGenericTypeDefinition() == typeof(Dictionary<,>) && - DictionaryConstructor.IsIEnumerableKeyValuePairConstructor(expression.Constructor); - - private static IBsonSerializer CreateDictionarySerializer(IBsonSerializer keySerializer, IBsonSerializer valueSerializer) + private static IBsonSerializer CreateResultSerializer(IBsonSerializer keySerializer, IBsonSerializer valueSerializer) { var dictionaryType = typeof(Dictionary<,>).MakeGenericType(keySerializer.ValueType, valueSerializer.ValueType); var serializerType = typeof(DictionaryInterfaceImplementerSerializer<,,>).MakeGenericType(dictionaryType, keySerializer.ValueType, valueSerializer.ValueType); From 30bb7a4968589cfbf9dcc5c568bc9929e17d9016 Mon Sep 17 00:00:00 2001 From: Oleksandr Poliakov Date: Wed, 30 Apr 2025 15:36:30 -0700 Subject: [PATCH 4/4] PR --- .../Linq3Implementation/Reflection/DictionaryConstructor.cs | 3 +++ ...DictionaryExpressionToAggregationExpressionTranslator.cs | 6 ++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs index b44568d07bf..ca0ccc27664 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs @@ -23,8 +23,11 @@ internal static class DictionaryConstructor { public static bool IsWithIEnumerableKeyValuePairConstructor(ConstructorInfo constructor) { + var declaringType = constructor.DeclaringType; var parameters = constructor.GetParameters(); return + declaringType.IsConstructedGenericType && + declaringType.GetGenericTypeDefinition() == typeof(Dictionary<,>) && parameters.Length == 1 && parameters[0].ParameterType.ImplementsIEnumerable(out var enumerableType) && enumerableType.IsConstructedGenericType && diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs index 9bdb3b2cf6c..aee174ac38d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs @@ -26,12 +26,10 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { - internal static class NewDictionaryExpressionToAggregationExpressionTranslator + internal static class NewDictionaryExpressionToAggregationExpressionTranslator { public static bool CanTranslate(NewExpression expression) - => expression.Type.IsConstructedGenericType && - expression.Type.GetGenericTypeDefinition() == typeof(Dictionary<,>) && - DictionaryConstructor.IsWithIEnumerableKeyValuePairConstructor(expression.Constructor); + => DictionaryConstructor.IsWithIEnumerableKeyValuePairConstructor(expression.Constructor); public static TranslatedExpression Translate(TranslationContext context, NewExpression expression) {