Skip to content

Commit ca108c1

Browse files
CSHARP-4255: Fix bug and some tests.
1 parent def8eab commit ca108c1

File tree

4 files changed

+167
-25
lines changed

4 files changed

+167
-25
lines changed

src/MongoDB.Driver/Encryption/ClientEncryption.cs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -82,59 +82,59 @@ public Task<BsonDocument> AddAlternateKeyNameAsync(Guid id, string alternateKeyN
8282
/// <summary>
8383
/// Create encrypted collection.
8484
/// </summary>
85-
/// <param name="collectionNamespace">The collection namespace.</param>
85+
/// <param name="database">The database.</param>
86+
/// <param name="collectionName">The collectionName.</param>
8687
/// <param name="createCollectionOptions">The create collection options.</param>
8788
/// <param name="kmsProvider">The kms provider.</param>
8889
/// <param name="dataKeyOptions">The datakey options.</param>
8990
/// <param name="cancellationToken">The cancellation token.</param>
9091
/// <remarks>
9192
/// if EncryptionFields contains a keyId with a null value, a data key will be automatically generated and assigned to keyId value.
9293
/// </remarks>
93-
public void CreateEncryptedCollection<TCollection>(CollectionNamespace collectionNamespace, CreateCollectionOptions createCollectionOptions, string kmsProvider, DataKeyOptions dataKeyOptions, CancellationToken cancellationToken = default)
94+
public void CreateEncryptedCollection(IMongoDatabase database, string collectionName, CreateCollectionOptions createCollectionOptions, string kmsProvider, DataKeyOptions dataKeyOptions, CancellationToken cancellationToken = default)
9495
{
95-
Ensure.IsNotNull(collectionNamespace, nameof(collectionNamespace));
96+
Ensure.IsNotNull(database, nameof(database));
97+
Ensure.IsNotNull(collectionName, nameof(collectionName));
9698
Ensure.IsNotNull(createCollectionOptions, nameof(createCollectionOptions));
9799
Ensure.IsNotNull(dataKeyOptions, nameof(dataKeyOptions));
98100
Ensure.IsNotNull(kmsProvider, nameof(kmsProvider));
99101

100-
foreach (var fieldDocument in EncryptedCollectionHelper.IterateEmptyKeyIds(collectionNamespace, createCollectionOptions.EncryptedFields))
102+
foreach (var fieldDocument in EncryptedCollectionHelper.IterateEmptyKeyIds(new CollectionNamespace(database.DatabaseNamespace.DatabaseName, collectionName), createCollectionOptions.EncryptedFields))
101103
{
102104
var dataKey = CreateDataKey(kmsProvider, dataKeyOptions, cancellationToken);
103105
EncryptedCollectionHelper.ModifyEncryptedFields(fieldDocument, dataKey);
104106
}
105107

106-
var database = _libMongoCryptController.KeyVaultClient.GetDatabase(collectionNamespace.DatabaseNamespace.DatabaseName);
107-
108-
database.CreateCollection(collectionNamespace.CollectionName, createCollectionOptions, cancellationToken);
108+
database.CreateCollection(collectionName, createCollectionOptions, cancellationToken);
109109
}
110110

111111
/// <summary>
112112
/// Create encrypted collection.
113113
/// </summary>
114-
/// <param name="collectionNamespace">The collection namespace.</param>
114+
/// <param name="database">The database.</param>
115+
/// <param name="collectionName">The collectionName.</param>
115116
/// <param name="createCollectionOptions">The create collection options.</param>
116117
/// <param name="kmsProvider">The kms provider.</param>
117118
/// <param name="dataKeyOptions">The datakey options.</param>
118119
/// <param name="cancellationToken">The cancellation token.</param>
119120
/// <remarks>
120121
/// if EncryptionFields contains a keyId with a null value, a data key will be automatically generated and assigned to keyId value.
121122
/// </remarks>
122-
public async Task CreateEncryptedCollectionAsync<TCollection>(CollectionNamespace collectionNamespace, CreateCollectionOptions createCollectionOptions, string kmsProvider, DataKeyOptions dataKeyOptions, CancellationToken cancellationToken = default)
123+
public async Task CreateEncryptedCollectionAsync(IMongoDatabase database, string collectionName, CreateCollectionOptions createCollectionOptions, string kmsProvider, DataKeyOptions dataKeyOptions, CancellationToken cancellationToken = default)
123124
{
124-
Ensure.IsNotNull(collectionNamespace, nameof(collectionNamespace));
125+
Ensure.IsNotNull(database, nameof(database));
126+
Ensure.IsNotNull(collectionName, nameof(collectionName));
125127
Ensure.IsNotNull(createCollectionOptions, nameof(createCollectionOptions));
126128
Ensure.IsNotNull(dataKeyOptions, nameof(dataKeyOptions));
127129
Ensure.IsNotNull(kmsProvider, nameof(kmsProvider));
128130

129-
foreach (var fieldDocument in EncryptedCollectionHelper.IterateEmptyKeyIds(collectionNamespace, createCollectionOptions.EncryptedFields))
131+
foreach (var fieldDocument in EncryptedCollectionHelper.IterateEmptyKeyIds(new CollectionNamespace(database.DatabaseNamespace.DatabaseName, collectionName), createCollectionOptions.EncryptedFields))
130132
{
131133
var dataKey = await CreateDataKeyAsync(kmsProvider, dataKeyOptions, cancellationToken).ConfigureAwait(false);
132134
EncryptedCollectionHelper.ModifyEncryptedFields(fieldDocument, dataKey);
133135
}
134136

135-
var database = _libMongoCryptController.KeyVaultClient.GetDatabase(collectionNamespace.DatabaseNamespace.DatabaseName);
136-
137-
await database.CreateCollectionAsync(collectionNamespace.CollectionName, createCollectionOptions, cancellationToken).ConfigureAwait(false);
137+
await database.CreateCollectionAsync(collectionName, createCollectionOptions, cancellationToken).ConfigureAwait(false);
138138
}
139139

140140
/// <summary>

tests/MongoDB.Bson.TestHelpers/BsonValueEquivalencyComparer.cs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
* limitations under the License.
1414
*/
1515

16+
using System;
1617
using System.Collections.Generic;
1718

1819
namespace MongoDB.Bson.TestHelpers
@@ -22,15 +23,17 @@ public class BsonValueEquivalencyComparer : IEqualityComparer<BsonValue>
2223
#region static
2324
public static BsonValueEquivalencyComparer Instance { get; } = new BsonValueEquivalencyComparer();
2425

25-
public static bool Compare(BsonValue a, BsonValue b)
26+
public static bool Compare(BsonValue a, BsonValue b, Action<BsonValue, BsonValue> massageAction = null)
2627
{
28+
massageAction?.Invoke(a, b);
29+
2730
if (a.BsonType == BsonType.Document && b.BsonType == BsonType.Document)
2831
{
29-
return CompareDocuments((BsonDocument)a, (BsonDocument)b);
32+
return CompareDocuments((BsonDocument)a, (BsonDocument)b, massageAction);
3033
}
3134
else if (a.BsonType == BsonType.Array && b.BsonType == BsonType.Array)
3235
{
33-
return CompareArrays((BsonArray)a, (BsonArray)b);
36+
return CompareArrays((BsonArray)a, (BsonArray)b, massageAction);
3437
}
3538
else if (a.BsonType == b.BsonType)
3639
{
@@ -50,7 +53,7 @@ public static bool Compare(BsonValue a, BsonValue b)
5053
}
5154
}
5255

53-
private static bool CompareArrays(BsonArray a, BsonArray b)
56+
private static bool CompareArrays(BsonArray a, BsonArray b, Action<BsonValue, BsonValue> massageAction = null)
5457
{
5558
if (a.Count != b.Count)
5659
{
@@ -59,7 +62,7 @@ private static bool CompareArrays(BsonArray a, BsonArray b)
5962

6063
for (var i = 0; i < a.Count; i++)
6164
{
62-
if (!Compare(a[i], b[i]))
65+
if (!Compare(a[i], b[i], massageAction))
6366
{
6467
return false;
6568
}
@@ -68,7 +71,7 @@ private static bool CompareArrays(BsonArray a, BsonArray b)
6871
return true;
6972
}
7073

71-
private static bool CompareDocuments(BsonDocument a, BsonDocument b)
74+
private static bool CompareDocuments(BsonDocument a, BsonDocument b, Action<BsonValue, BsonValue> massageAction = null)
7275
{
7376
if (a.ElementCount != b.ElementCount)
7477
{
@@ -83,7 +86,7 @@ private static bool CompareDocuments(BsonDocument a, BsonDocument b)
8386
return false;
8487
}
8588

86-
if (!Compare(aElement.Value, bElement.Value))
89+
if (!Compare(aElement.Value, bElement.Value, massageAction))
8790
{
8891
return false;
8992
}

tests/MongoDB.Driver.Tests/Encryption/ClientEncryptionTests.cs

Lines changed: 140 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
using MongoDB.Driver.Tests.Specifications.client_side_encryption;
2626
using MongoDB.Libmongocrypt;
2727
using Xunit;
28+
using Moq;
29+
using System.Collections.Generic;
30+
using System.Threading;
2831

2932
namespace MongoDB.Driver.Tests.Encryption
3033
{
@@ -64,6 +67,141 @@ public async Task CreateDataKey_should_correctly_handle_input_arguments()
6467
}
6568
}
6669

70+
[Fact]
71+
public async Task CreateEncryptedCollection_should_handle_input_arguments()
72+
{
73+
const string kmsProvider = "local";
74+
const string collectionName = "collName";
75+
var createCollectionOptions = new CreateCollectionOptions();
76+
var database = Mock.Of<IMongoDatabase>();
77+
78+
var dataKeyOptions = new DataKeyOptions();
79+
80+
using (var subject = CreateSubject())
81+
{
82+
ShouldBeArgumentException(Record.Exception(() => subject.CreateEncryptedCollection(database: null, collectionName, createCollectionOptions, kmsProvider, dataKeyOptions)), expectedParamName: "database");
83+
ShouldBeArgumentException(await Record.ExceptionAsync(() => subject.CreateEncryptedCollectionAsync(database: null, collectionName, createCollectionOptions, kmsProvider, dataKeyOptions)), expectedParamName: "database");
84+
85+
ShouldBeArgumentException(Record.Exception(() => subject.CreateEncryptedCollection(database, collectionName: null, createCollectionOptions, kmsProvider, dataKeyOptions)), expectedParamName: "collectionName");
86+
ShouldBeArgumentException(await Record.ExceptionAsync(() => subject.CreateEncryptedCollectionAsync(database, collectionName: null, createCollectionOptions, kmsProvider, dataKeyOptions)), expectedParamName: "collectionName");
87+
88+
ShouldBeArgumentException(Record.Exception(() => subject.CreateEncryptedCollection(database, collectionName: collectionName, createCollectionOptions: null, kmsProvider, dataKeyOptions)), expectedParamName: "createCollectionOptions");
89+
ShouldBeArgumentException(await Record.ExceptionAsync(() => subject.CreateEncryptedCollectionAsync(database, collectionName, createCollectionOptions: null, kmsProvider, dataKeyOptions)), expectedParamName: "createCollectionOptions");
90+
91+
ShouldBeArgumentException(Record.Exception(() => subject.CreateEncryptedCollection(database, collectionName: collectionName, createCollectionOptions, kmsProvider: null, dataKeyOptions)), expectedParamName: "kmsProvider");
92+
ShouldBeArgumentException(await Record.ExceptionAsync(() => subject.CreateEncryptedCollectionAsync(database, collectionName, createCollectionOptions, kmsProvider: null, dataKeyOptions)), expectedParamName: "kmsProvider");
93+
94+
ShouldBeArgumentException(Record.Exception(() => subject.CreateEncryptedCollection(database, collectionName: collectionName, createCollectionOptions, kmsProvider, dataKeyOptions: null)), expectedParamName: "dataKeyOptions");
95+
ShouldBeArgumentException(await Record.ExceptionAsync(() => subject.CreateEncryptedCollectionAsync(database, collectionName, createCollectionOptions, kmsProvider, dataKeyOptions: null)), expectedParamName: "dataKeyOptions");
96+
}
97+
}
98+
99+
[Fact]
100+
public async Task CreateEncryptedCollection_should_handle_save_generated_key_when_second_key_failed()
101+
{
102+
const string kmsProvider = "local";
103+
const string collectionName = "collName";
104+
const string encryptedFieldsStr = "{ fields : [{ keyId : null }, { keyId : null }] }";
105+
var database = Mock.Of<IMongoDatabase>(d => d.DatabaseNamespace == new DatabaseNamespace("db"));
106+
107+
var dataKeyOptions = new DataKeyOptions();
108+
109+
var mockCollection = new Mock<IMongoCollection<BsonDocument>>();
110+
mockCollection
111+
.SetupSequence(c => c.InsertOne(It.IsAny<BsonDocument>(), It.IsAny<InsertOneOptions>(), It.IsAny<CancellationToken>()))
112+
.Pass()
113+
.Throws(new Exception("test"));
114+
mockCollection
115+
.SetupSequence(c => c.InsertOneAsync(It.IsAny<BsonDocument>(), It.IsAny<InsertOneOptions>(), It.IsAny<CancellationToken>()))
116+
.Returns(Task.CompletedTask)
117+
.Throws(new Exception("test"));
118+
var mockDatabase = new Mock<IMongoDatabase>();
119+
mockDatabase.Setup(c => c.GetCollection<BsonDocument>(It.IsAny<string>(), It.IsAny<MongoCollectionSettings>())).Returns(mockCollection.Object);
120+
var client = new Mock<IMongoClient>();
121+
client.Setup(c => c.GetDatabase(It.IsAny<string>(), It.IsAny<MongoDatabaseSettings>())).Returns(mockDatabase.Object);
122+
123+
using (var subject = CreateSubject(client.Object))
124+
{
125+
var createCollectionOptions = new CreateCollectionOptions() { EncryptedFields = BsonDocument.Parse(encryptedFieldsStr) };
126+
var exception = Record.Exception(() => subject.CreateEncryptedCollection(database, collectionName, createCollectionOptions, kmsProvider, dataKeyOptions));
127+
AssertResults(exception.InnerException, createCollectionOptions);
128+
129+
createCollectionOptions = new CreateCollectionOptions() { EncryptedFields = BsonDocument.Parse(encryptedFieldsStr) };
130+
exception = await Record.ExceptionAsync(() => subject.CreateEncryptedCollectionAsync(database, collectionName, createCollectionOptions, kmsProvider, dataKeyOptions));
131+
AssertResults(exception.InnerException, createCollectionOptions);
132+
}
133+
134+
void AssertResults(Exception ex, CreateCollectionOptions createCollectionOptions)
135+
{
136+
ex.Should().BeOfType<Exception>().Which.Message.Should().Be("test");
137+
var fields = createCollectionOptions.EncryptedFields["fields"].AsBsonArray;
138+
fields[0].AsBsonDocument["keyId"].Should().BeOfType<BsonBinaryData>(); // pass
139+
/*
140+
- If generating `D` resulted in an error `E`, the entire
141+
`CreateEncryptedCollection` must now fail with error `E`. Return the
142+
partially-formed `EF'` with the error so that the caller may know what
143+
datakeys have already been created by the helper.
144+
*/
145+
fields[1].AsBsonDocument["keyId"].Should().BeOfType<BsonNull>(); // throw
146+
}
147+
}
148+
149+
[Theory]
150+
[InlineData(null, "There are no encrypted fields defined for the collection.")]
151+
[InlineData("{}", "{}")]
152+
[InlineData("{ a : 1 }", "{ a : 1 }")]
153+
[InlineData("{ fields : { } }", "{ fields: { } }")]
154+
[InlineData("{ fields : [] }", "{ fields: [] }")]
155+
[InlineData("{ fields : [{ a : 1 }] }", "{ fields: [{ a : 1 }] }")]
156+
[InlineData("{ fields : [{ keyId : 1 }] }", "{ fields: [{ keyId : 1 }] }")]
157+
[InlineData("{ fields : [{ keyId : null }] }", "{ fields: [{ keyId : '#binary_generated#' }] }")]
158+
[InlineData("{ fields : [{ keyId : null }, { keyId : null }] }", "{ fields: [{ keyId : '#binary_generated#' }, { keyId : '#binary_generated#' }] }")]
159+
[InlineData("{ fields : [{ keyId : 3 }, { keyId : null }] }", "{ fields: [{ keyId : 3 }, { keyId : '#binary_generated#' }] }")]
160+
public async Task CreateEncryptedCollection_should_handle_various_encryptedFields(string encryptedFieldsStr, string expectedResult)
161+
{
162+
const string kmsProvider = "local";
163+
const string collectionName = "collName";
164+
var database = Mock.Of<IMongoDatabase>(d => d.DatabaseNamespace == new DatabaseNamespace("db"));
165+
166+
var dataKeyOptions = new DataKeyOptions();
167+
168+
using (var subject = CreateSubject())
169+
{
170+
if (BsonDocument.TryParse(expectedResult, out var encryptedFields))
171+
{
172+
var createCollectionOptions = new CreateCollectionOptions() { EncryptedFields = encryptedFieldsStr != null ? BsonDocument.Parse(encryptedFieldsStr) : null };
173+
subject.CreateEncryptedCollection(database, collectionName: collectionName, createCollectionOptions, kmsProvider: kmsProvider, dataKeyOptions);
174+
createCollectionOptions.EncryptedFields.WithComparer(new EncryptedFieldsComparer()).Should().Be(encryptedFields.DeepClone());
175+
176+
createCollectionOptions = new CreateCollectionOptions() { EncryptedFields = encryptedFieldsStr != null ? BsonDocument.Parse(encryptedFieldsStr) : null };
177+
await subject.CreateEncryptedCollectionAsync(database, collectionName: collectionName, createCollectionOptions, kmsProvider: kmsProvider, dataKeyOptions);
178+
createCollectionOptions.EncryptedFields.WithComparer(new EncryptedFieldsComparer()).Should().Be(encryptedFields.DeepClone());
179+
}
180+
else
181+
{
182+
var createCollectionOptions = new CreateCollectionOptions() { EncryptedFields = encryptedFieldsStr != null ? BsonDocument.Parse(encryptedFieldsStr) : null };
183+
AssertInvalidOperationException(Record.Exception(() => subject.CreateEncryptedCollection(database, collectionName: collectionName, createCollectionOptions, kmsProvider: kmsProvider, dataKeyOptions)), expectedResult);
184+
185+
createCollectionOptions = new CreateCollectionOptions() { EncryptedFields = encryptedFieldsStr != null ? BsonDocument.Parse(encryptedFieldsStr) : null };
186+
AssertInvalidOperationException(await Record.ExceptionAsync(() => subject.CreateEncryptedCollectionAsync(database, collectionName: collectionName, createCollectionOptions, kmsProvider: kmsProvider, dataKeyOptions)), expectedResult);
187+
}
188+
}
189+
190+
void AssertInvalidOperationException(Exception ex, string message) => ex.Should().BeOfType<InvalidOperationException>().Which.Message.Should().Be(message);
191+
}
192+
193+
private class EncryptedFieldsComparer : IEqualityComparer<BsonDocument>
194+
{
195+
public bool Equals(BsonDocument x, BsonDocument y) => BsonValueEquivalencyComparer.Compare(x, y, (a, b) =>
196+
{
197+
if (a is BsonDocument aDocument && aDocument.TryGetValue("keyId", out var aKeyId) && aKeyId.IsBsonBinaryData &&
198+
b is BsonDocument bDocument && bDocument.TryGetValue("keyId", out var bKeyId) && bKeyId == "#binary_generated#")
199+
{
200+
bDocument["keyId"] = aDocument["keyId"];
201+
}
202+
});
203+
public int GetHashCode(BsonDocument obj) => obj.GetHashCode();
204+
}
67205

68206
[Fact]
69207
public void CryptClient_should_be_initialized()
@@ -167,10 +305,10 @@ public async Task RewrapManyDataKey_should_correctly_handle_input_arguments()
167305
}
168306

169307
// private methods
170-
private ClientEncryption CreateSubject()
308+
private ClientEncryption CreateSubject(IMongoClient client = null)
171309
{
172310
var clientEncryptionOptions = new ClientEncryptionOptions(
173-
DriverTestConfiguration.Client,
311+
client ?? DriverTestConfiguration.Client,
174312
__keyVaultCollectionNamespace,
175313
kmsProviders: EncryptionTestHelper.GetKmsProviders(filter: "local"));
176314

0 commit comments

Comments
 (0)