Skip to content

Commit ad74cdf

Browse files
rstamBorisDog
authored andcommitted
CSHARP-4428: LINQ3 not handling down cast in UpdateDefinitionBuilder Set method.
1 parent 332b09a commit ad74cdf

File tree

3 files changed

+142
-100
lines changed

3 files changed

+142
-100
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs

Lines changed: 115 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -28,115 +28,61 @@ internal static class ConvertExpressionToFilterFieldTranslator
2828
{
2929
public static AstFilterField Translate(TranslationContext context, UnaryExpression expression)
3030
{
31-
if (expression.NodeType == ExpressionType.Convert)
31+
if (expression.NodeType == ExpressionType.Convert || expression.NodeType == ExpressionType.TypeAs)
3232
{
3333
var field = ExpressionToFilterFieldTranslator.Translate(context, expression.Operand);
34-
var fieldSerializer = field.Serializer;
35-
var fieldType = fieldSerializer.ValueType;
34+
var fieldType = field.Serializer.ValueType;
3635
var targetType = expression.Type;
3736

38-
if (fieldType.IsEnumOrNullableEnum(out _, out var underlyingType))
37+
if (IsConvertEnumToUnderlyingType(fieldType, targetType))
3938
{
40-
if (targetType.IsSameAsOrNullableOf(underlyingType))
41-
{
42-
IBsonSerializer enumSerializer;
43-
if (fieldType.IsNullable())
44-
{
45-
var nullableSerializer = (INullableSerializer)fieldSerializer;
46-
enumSerializer = nullableSerializer.ValueSerializer;
47-
}
48-
else
49-
{
50-
enumSerializer = fieldSerializer;
51-
}
52-
53-
IBsonSerializer targetSerializer;
54-
var enumUnderlyingTypeSerializer = EnumUnderlyingTypeSerializer.Create(enumSerializer);
55-
if (targetType.IsNullable())
56-
{
57-
targetSerializer = NullableSerializer.Create(enumUnderlyingTypeSerializer);
58-
}
59-
else
60-
{
61-
targetSerializer = enumUnderlyingTypeSerializer;
62-
}
63-
64-
return AstFilter.Field(field.Path, targetSerializer);
65-
}
39+
return TranslateConvertEnumToUnderlyingType(field, targetType);
6640
}
6741

68-
if (IsNumericType(targetType))
42+
if (IsNumericConversion(fieldType, targetType))
6943
{
70-
IBsonSerializer targetTypeSerializer = expression.Type switch
71-
{
72-
Type t when t == typeof(byte) => new ByteSerializer(),
73-
Type t when t == typeof(short) => new Int16Serializer(),
74-
Type t when t == typeof(ushort) => new UInt16Serializer(),
75-
Type t when t == typeof(int) => new Int32Serializer(),
76-
Type t when t == typeof(uint) => new UInt32Serializer(),
77-
Type t when t == typeof(long) => new Int64Serializer(),
78-
Type t when t == typeof(ulong) => new UInt64Serializer(),
79-
Type t when t == typeof(float) => new SingleSerializer(),
80-
Type t when t == typeof(double) => new DoubleSerializer(),
81-
Type t when t == typeof(decimal) => new DecimalSerializer(),
82-
_ => throw new ExpressionNotSupportedException(expression)
83-
};
84-
if (fieldSerializer is IRepresentationConfigurable representationConfigurableFieldSerializer &&
85-
targetTypeSerializer is IRepresentationConfigurable representationConfigurableTargetTypeSerializer)
86-
{
87-
var fieldRepresentation = representationConfigurableFieldSerializer.Representation;
88-
if (fieldRepresentation == BsonType.String)
89-
{
90-
targetTypeSerializer = representationConfigurableTargetTypeSerializer.WithRepresentation(fieldRepresentation);
91-
}
92-
}
93-
if (fieldSerializer is IRepresentationConverterConfigurable converterConfigurableFieldSerializer &&
94-
targetTypeSerializer is IRepresentationConverterConfigurable converterConfigurableTargetTypeSerializer)
95-
{
96-
targetTypeSerializer = converterConfigurableTargetTypeSerializer.WithConverter(converterConfigurableFieldSerializer.Converter);
97-
}
98-
return AstFilter.Field(field.Path, targetTypeSerializer);
44+
return TranslateNumericConversion(field, targetType);
9945
}
10046

101-
if (targetType.IsConstructedGenericType &&
102-
targetType.GetGenericTypeDefinition() == typeof(Nullable<>))
47+
if (IsConvertToNullable(fieldType, targetType))
10348
{
104-
var nullableValueType = targetType.GetGenericArguments()[0];
105-
if (nullableValueType == fieldType)
106-
{
107-
var nullableSerializerType = typeof(NullableSerializer<>).MakeGenericType(nullableValueType);
108-
var nullableSerializer = (IBsonSerializer)Activator.CreateInstance(nullableSerializerType, fieldSerializer);
109-
return AstFilter.Field(field.Path, nullableSerializer);
110-
}
111-
112-
if (fieldType.IsConstructedGenericType &&
113-
fieldType.GetGenericTypeDefinition() == typeof(Nullable<>))
114-
{
115-
var fieldValueType = fieldType.GetGenericArguments()[0];
116-
if (fieldValueType.IsEnum)
117-
{
118-
var enumUnderlyingType = fieldValueType.GetEnumUnderlyingType();
119-
if (nullableValueType == enumUnderlyingType)
120-
{
121-
var fieldSerializerType = fieldSerializer.GetType();
122-
if (fieldSerializerType.IsConstructedGenericType &&
123-
fieldSerializerType.GetGenericTypeDefinition() == typeof(NullableSerializer<>))
124-
{
125-
var enumSerializer = ((IChildSerializerConfigurable)fieldSerializer).ChildSerializer;
126-
var enumUnderlyingTypeSerializer = EnumUnderlyingTypeSerializer.Create(enumSerializer);
127-
var nullableSerializerType = typeof(NullableSerializer<>).MakeGenericType(nullableValueType);
128-
var nullableSerializer = (IBsonSerializer)Activator.CreateInstance(nullableSerializerType, enumUnderlyingTypeSerializer);
129-
return AstFilter.Field(field.Path, nullableSerializer);
130-
}
131-
}
132-
}
133-
}
49+
return TranslateConvertToNullable(field);
50+
}
51+
52+
if (IsConvertToDerivedType(fieldType, targetType))
53+
{
54+
return TranslateConvertToDerivedType(field, targetType);
13455
}
13556
}
13657

13758
throw new ExpressionNotSupportedException(expression);
13859
}
13960

61+
private static bool IsConvertEnumToUnderlyingType(Type fieldType, Type targetType)
62+
{
63+
return
64+
fieldType.IsEnumOrNullableEnum(out _, out var underlyingType) &&
65+
targetType.IsSameAsOrNullableOf(underlyingType);
66+
}
67+
68+
private static bool IsConvertToDerivedType(Type fieldType, Type targetType)
69+
{
70+
return targetType.IsSubclassOf(fieldType);
71+
}
72+
73+
private static bool IsConvertToNullable(Type fieldType, Type targetType)
74+
{
75+
return
76+
targetType.IsConstructedGenericType &&
77+
targetType.GetGenericTypeDefinition() == typeof(Nullable<>) &&
78+
targetType.GetGenericArguments()[0] == fieldType;
79+
}
80+
81+
private static bool IsNumericConversion(Type fieldType, Type targetType)
82+
{
83+
return IsNumericType(fieldType) && IsNumericType(targetType);
84+
}
85+
14086
private static bool IsNumericType(Type type)
14187
{
14288
switch (Type.GetTypeCode(type))
@@ -147,6 +93,7 @@ private static bool IsNumericType(Type type)
14793
case TypeCode.Int16:
14894
case TypeCode.Int32:
14995
case TypeCode.Int64:
96+
case TypeCode.SByte:
15097
case TypeCode.Single:
15198
case TypeCode.UInt16:
15299
case TypeCode.UInt32:
@@ -157,5 +104,81 @@ private static bool IsNumericType(Type type)
157104
return false;
158105
}
159106
}
107+
108+
private static AstFilterField TranslateConvertEnumToUnderlyingType(AstFilterField field, Type targetType)
109+
{
110+
var fieldSerializer = field.Serializer;
111+
var fieldType = fieldSerializer.ValueType;
112+
113+
IBsonSerializer enumSerializer;
114+
if (fieldType.IsNullable())
115+
{
116+
var nullableSerializer = (INullableSerializer)fieldSerializer;
117+
enumSerializer = nullableSerializer.ValueSerializer;
118+
}
119+
else
120+
{
121+
enumSerializer = fieldSerializer;
122+
}
123+
124+
IBsonSerializer targetSerializer;
125+
var enumUnderlyingTypeSerializer = EnumUnderlyingTypeSerializer.Create(enumSerializer);
126+
if (targetType.IsNullable())
127+
{
128+
targetSerializer = NullableSerializer.Create(enumUnderlyingTypeSerializer);
129+
}
130+
else
131+
{
132+
targetSerializer = enumUnderlyingTypeSerializer;
133+
}
134+
135+
return AstFilter.Field(field.Path, targetSerializer);
136+
}
137+
138+
private static AstFilterField TranslateConvertToDerivedType(AstFilterField field, Type targetType)
139+
{
140+
var targetSerializer = BsonSerializer.LookupSerializer(targetType);
141+
return AstFilter.Field(field.Path, targetSerializer);
142+
}
143+
144+
private static AstFilterField TranslateConvertToNullable(AstFilterField field)
145+
{
146+
var nullableSerializer = NullableSerializer.Create(field.Serializer);
147+
return AstFilter.Field(field.Path, nullableSerializer);
148+
}
149+
150+
private static AstFilterField TranslateNumericConversion(AstFilterField field, Type targetType)
151+
{
152+
IBsonSerializer targetTypeSerializer = targetType switch
153+
{
154+
Type t when t == typeof(byte) => new ByteSerializer(),
155+
Type t when t == typeof(sbyte) => new SByteSerializer(),
156+
Type t when t == typeof(short) => new Int16Serializer(),
157+
Type t when t == typeof(ushort) => new UInt16Serializer(),
158+
Type t when t == typeof(int) => new Int32Serializer(),
159+
Type t when t == typeof(uint) => new UInt32Serializer(),
160+
Type t when t == typeof(long) => new Int64Serializer(),
161+
Type t when t == typeof(ulong) => new UInt64Serializer(),
162+
Type t when t == typeof(float) => new SingleSerializer(),
163+
Type t when t == typeof(double) => new DoubleSerializer(),
164+
Type t when t == typeof(decimal) => new DecimalSerializer(),
165+
_ => throw new Exception($"Unexpected target type: {targetType}.")
166+
};
167+
if (field.Serializer is IRepresentationConfigurable representationConfigurableFieldSerializer &&
168+
targetTypeSerializer is IRepresentationConfigurable representationConfigurableTargetTypeSerializer)
169+
{
170+
var fieldRepresentation = representationConfigurableFieldSerializer.Representation;
171+
if (fieldRepresentation == BsonType.String)
172+
{
173+
targetTypeSerializer = representationConfigurableTargetTypeSerializer.WithRepresentation(fieldRepresentation);
174+
}
175+
}
176+
if (field.Serializer is IRepresentationConverterConfigurable converterConfigurableFieldSerializer &&
177+
targetTypeSerializer is IRepresentationConverterConfigurable converterConfigurableTargetTypeSerializer)
178+
{
179+
targetTypeSerializer = converterConfigurableTargetTypeSerializer.WithConverter(converterConfigurableFieldSerializer.Converter);
180+
}
181+
return AstFilter.Field(field.Path, targetTypeSerializer);
182+
}
160183
}
161184
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ExpressionToFilterFieldTranslator.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ public static AstFilterField Translate(TranslationContext context, Expression ex
3030
case ExpressionType.MemberAccess: return MemberExpressionToFilterFieldTranslator.Translate(context, (MemberExpression)expression);
3131
case ExpressionType.Call: return MethodCallExpressionToFilterFieldTranslator.Translate(context, (MethodCallExpression)expression);
3232
case ExpressionType.Parameter: return ParameterExpressionToFilterFieldTranslator.Translate(context, (ParameterExpression)expression);
33-
case ExpressionType.Convert: return ConvertExpressionToFilterFieldTranslator.Translate(context, (UnaryExpression)expression);
33+
34+
case ExpressionType.Convert:
35+
case ExpressionType.TypeAs:
36+
return ConvertExpressionToFilterFieldTranslator.Translate(context, (UnaryExpression)expression);
3437
}
3538

3639
throw new ExpressionNotSupportedException(expression);

tests/MongoDB.Driver.Tests/UpdateDefinitionBuilderTests.cs

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
using System;
1717
using System.Collections.Generic;
1818
using System.Linq;
19+
using System.Linq.Expressions;
1920
using FluentAssertions;
2021
using MongoDB.Bson;
2122
using MongoDB.Bson.Serialization;
@@ -595,23 +596,23 @@ public void Set_Typed_with_cast()
595596
{
596597
var subject = CreateSubject<Message>();
597598

598-
Assert(subject.Set(x => ((SmsMessage)x).PhoneNumber, "1234567890"), "{$set: {pn: '1234567890'}}");
599+
Assert(subject.Set(x => ((SmsMessage)x).PhoneNumber, "1234567890"), "{$set: {pn: '1234567890'}}", LinqProvider.V3);
599600

600601
var subject2 = CreateSubject<Person>();
601602

602-
Assert(subject2.Set(x => ((SmsMessage)x.Message).PhoneNumber, "1234567890"), "{$set: {'m.pn': '1234567890'}}");
603+
Assert(subject2.Set(x => ((SmsMessage)x.Message).PhoneNumber, "1234567890"), "{$set: {'m.pn': '1234567890'}}", LinqProvider.V3);
603604
}
604605

605606
[Fact]
606607
public void Set_Typed_with_type_as()
607608
{
608609
var subject = CreateSubject<Message>();
609610

610-
Assert(subject.Set(x => (x as SmsMessage).PhoneNumber, "1234567890"), "{$set: {pn: '1234567890'}}");
611+
Assert(subject.Set(x => (x as SmsMessage).PhoneNumber, "1234567890"), "{$set: {pn: '1234567890'}}", LinqProvider.V3);
611612

612613
var subject2 = CreateSubject<Person>();
613614

614-
Assert(subject2.Set(x => (x.Message as SmsMessage).PhoneNumber, "1234567890"), "{$set: {'m.pn': '1234567890'}}");
615+
Assert(subject2.Set(x => (x.Message as SmsMessage).PhoneNumber, "1234567890"), "{$set: {'m.pn': '1234567890'}}", LinqProvider.V3);
615616
}
616617

617618
[Fact]
@@ -650,7 +651,12 @@ public void Unset_Typed()
650651

651652
private void Assert<TDocument>(UpdateDefinition<TDocument> update, BsonDocument expected)
652653
{
653-
var renderedUpdate = Render(update).AsBsonDocument;
654+
Assert(update, expected, LinqProvider.V2);
655+
}
656+
657+
private void Assert<TDocument>(UpdateDefinition<TDocument> update, BsonDocument expected, LinqProvider linqProvider)
658+
{
659+
var renderedUpdate = Render(update, linqProvider).AsBsonDocument;
654660

655661
renderedUpdate.Should().Be(expected);
656662
}
@@ -665,7 +671,12 @@ private void Assert<TDocument>(UpdateDefinition<TDocument> update, string[] expe
665671

666672
private void Assert<TDocument>(UpdateDefinition<TDocument> update, string expected)
667673
{
668-
Assert(update, BsonDocument.Parse(expected));
674+
Assert(update, expected, LinqProvider.V2);
675+
}
676+
677+
private void Assert<TDocument>(UpdateDefinition<TDocument> update, string expected, LinqProvider linqProvider)
678+
{
679+
Assert(update, BsonDocument.Parse(expected), linqProvider);
669680
}
670681

671682
private void AssertThrow<TDocument, TException>(UpdateDefinition<TDocument> update, string errorMessage) where TException : Exception
@@ -681,9 +692,14 @@ private UpdateDefinitionBuilder<TDocument> CreateSubject<TDocument>()
681692
}
682693

683694
private BsonValue Render<TDocument>(UpdateDefinition<TDocument> update)
695+
{
696+
return Render(update, LinqProvider.V2);
697+
}
698+
699+
private BsonValue Render<TDocument>(UpdateDefinition<TDocument> update, LinqProvider linqProvider)
684700
{
685701
var documentSerializer = BsonSerializer.SerializerRegistry.GetSerializer<TDocument>();
686-
return update.Render(documentSerializer, BsonSerializer.SerializerRegistry, LinqProvider.V2);
702+
return update.Render(documentSerializer, BsonSerializer.SerializerRegistry, linqProvider);
687703
}
688704

689705
private class Person

0 commit comments

Comments
 (0)