|
2 | 2 | // Licensed under the MIT License. See License.txt in the project root for license information.
|
3 | 3 |
|
4 | 4 | using System;
|
| 5 | +using System.Collections.Concurrent; |
5 | 6 | using System.Collections.Generic;
|
6 | 7 | using System.IO;
|
7 | 8 | using System.Linq;
|
8 | 9 | using System.Security.Cryptography;
|
| 10 | +using System.Threading; |
9 | 11 | using System.Threading.Tasks;
|
10 | 12 | using Microsoft.Azure.WebJobs.Extensions.Http;
|
11 | 13 | using Microsoft.Azure.WebJobs.Logging;
|
|
14 | 16 | using Microsoft.Azure.WebJobs.Script.WebHost.Models;
|
15 | 17 | using Microsoft.Azure.WebJobs.Script.WebHost.Properties;
|
16 | 18 | using Microsoft.Azure.WebJobs.Script.WebHost.Security;
|
| 19 | +using Microsoft.Azure.WebJobs.Script.WebHost.Storage; |
17 | 20 | using Microsoft.Extensions.Logging;
|
18 | 21 | using Microsoft.WebJobs.Script.Tests;
|
19 | 22 | using Moq;
|
@@ -337,6 +340,55 @@ public async Task GetHostSecrets_UpdatesStaleSecrets()
|
337 | 340 | }
|
338 | 341 | }
|
339 | 342 |
|
| 343 | + [Fact] |
| 344 | + public async Task GetFunctionSecretsAsync_SecretGenerationIsSerialized() |
| 345 | + { |
| 346 | + var mockValueConverterFactory = GetConverterFactoryMock(false, false); |
| 347 | + var metricsLogger = new TestMetricsLogger(); |
| 348 | + var testRepository = new TestSecretsRepository(true); |
| 349 | + string testFunctionName = $"TestFunction"; |
| 350 | + |
| 351 | + using (var secretManager = new SecretManager(testRepository, mockValueConverterFactory.Object, _logger, metricsLogger, _hostNameProvider, _startupContextProvider)) |
| 352 | + { |
| 353 | + var tasks = new List<Task<IDictionary<string, string>>>(); |
| 354 | + for (int i = 0; i < 10; i++) |
| 355 | + { |
| 356 | + tasks.Add(secretManager.GetFunctionSecretsAsync(testFunctionName)); |
| 357 | + } |
| 358 | + |
| 359 | + await Task.WhenAll(tasks); |
| 360 | + |
| 361 | + // verify all calls return the same result |
| 362 | + Assert.Equal(1, testRepository.FunctionSecrets.Count); |
| 363 | + var functionSecrets = (FunctionSecrets)testRepository.FunctionSecrets[testFunctionName]; |
| 364 | + string defaultKeyValue = functionSecrets.Keys.Where(p => p.Name == "default").Single().Value; |
| 365 | + Assert.True(tasks.Select(p => p.Result).All(t => t["default"] == defaultKeyValue)); |
| 366 | + } |
| 367 | + } |
| 368 | + |
| 369 | + [Fact] |
| 370 | + public async Task GetHostSecretsAsync_SecretGenerationIsSerialized() |
| 371 | + { |
| 372 | + var mockValueConverterFactory = GetConverterFactoryMock(false, false); |
| 373 | + var metricsLogger = new TestMetricsLogger(); |
| 374 | + var testRepository = new TestSecretsRepository(true); |
| 375 | + |
| 376 | + using (var secretManager = new SecretManager(testRepository, mockValueConverterFactory.Object, _logger, metricsLogger, _hostNameProvider, _startupContextProvider)) |
| 377 | + { |
| 378 | + var tasks = new List<Task<HostSecretsInfo>>(); |
| 379 | + for (int i = 0; i < 10; i++) |
| 380 | + { |
| 381 | + tasks.Add(secretManager.GetHostSecretsAsync()); |
| 382 | + } |
| 383 | + |
| 384 | + await Task.WhenAll(tasks); |
| 385 | + |
| 386 | + // verify all calls return the same result |
| 387 | + var masterKey = tasks.First().Result.MasterKey; |
| 388 | + Assert.True(tasks.Select(p => p.Result).All(q => q.MasterKey == masterKey)); |
| 389 | + } |
| 390 | + } |
| 391 | + |
340 | 392 | [Fact]
|
341 | 393 | public async Task GetHostSecrets_WhenNoHostSecretFileExists_GeneratesSecretsAndPersistsFiles()
|
342 | 394 | {
|
@@ -1124,5 +1176,84 @@ private void CreateTestSecrets(string path)
|
1124 | 1176 | File.WriteAllText(Path.Combine(path, ScriptConstants.HostMetadataFileName), hostSecrets);
|
1125 | 1177 | File.WriteAllText(Path.Combine(path, "testfunction.json"), functionSecrets);
|
1126 | 1178 | }
|
| 1179 | + |
| 1180 | + private class TestSecretsRepository : ISecretsRepository |
| 1181 | + { |
| 1182 | + private int _writeCount = 0; |
| 1183 | + private Random _rand = new Random(); |
| 1184 | + private bool _enforceSerialWrites = false; |
| 1185 | + |
| 1186 | + public TestSecretsRepository(bool enforceSerialWrites) |
| 1187 | + { |
| 1188 | + _enforceSerialWrites = enforceSerialWrites; |
| 1189 | + } |
| 1190 | + |
| 1191 | + public event EventHandler<SecretsChangedEventArgs> SecretsChanged; |
| 1192 | + |
| 1193 | + public ConcurrentDictionary<string, ScriptSecrets> FunctionSecrets { get; } = new ConcurrentDictionary<string, ScriptSecrets>(StringComparer.OrdinalIgnoreCase); |
| 1194 | + |
| 1195 | + public ScriptSecrets HostSecrets { get; private set; } |
| 1196 | + |
| 1197 | + public bool IsEncryptionSupported => throw new NotImplementedException(); |
| 1198 | + |
| 1199 | + public Task<string[]> GetSecretSnapshots(ScriptSecretsType type, string functionName) |
| 1200 | + { |
| 1201 | + return Task.FromResult(new string[0]); |
| 1202 | + } |
| 1203 | + |
| 1204 | + public Task PurgeOldSecretsAsync(IList<string> currentFunctions, ILogger logger) |
| 1205 | + { |
| 1206 | + return Task.CompletedTask; |
| 1207 | + } |
| 1208 | + |
| 1209 | + public Task<ScriptSecrets> ReadAsync(ScriptSecretsType type, string functionName) |
| 1210 | + { |
| 1211 | + ScriptSecrets secrets = null; |
| 1212 | + |
| 1213 | + if (type == ScriptSecretsType.Function) |
| 1214 | + { |
| 1215 | + FunctionSecrets.TryGetValue(functionName, out secrets); |
| 1216 | + } |
| 1217 | + else |
| 1218 | + { |
| 1219 | + secrets = HostSecrets; |
| 1220 | + } |
| 1221 | + |
| 1222 | + return Task.FromResult(secrets); |
| 1223 | + } |
| 1224 | + |
| 1225 | + public async Task WriteAsync(ScriptSecretsType type, string functionName, ScriptSecrets secrets) |
| 1226 | + { |
| 1227 | + if (_enforceSerialWrites && _writeCount > 1) |
| 1228 | + { |
| 1229 | + throw new Exception("Concurrent writes detected!"); |
| 1230 | + } |
| 1231 | + |
| 1232 | + Interlocked.Increment(ref _writeCount); |
| 1233 | + |
| 1234 | + await Task.Delay(_rand.Next(100, 300)); |
| 1235 | + |
| 1236 | + if (type == ScriptSecretsType.Function) |
| 1237 | + { |
| 1238 | + FunctionSecrets[functionName] = secrets; |
| 1239 | + } |
| 1240 | + else |
| 1241 | + { |
| 1242 | + HostSecrets = secrets; |
| 1243 | + } |
| 1244 | + |
| 1245 | + Interlocked.Decrement(ref _writeCount); |
| 1246 | + |
| 1247 | + if (SecretsChanged != null) |
| 1248 | + { |
| 1249 | + SecretsChanged(this, new SecretsChangedEventArgs { SecretsType = type, Name = functionName }); |
| 1250 | + } |
| 1251 | + } |
| 1252 | + |
| 1253 | + public Task WriteSnapshotAsync(ScriptSecretsType type, string functionName, ScriptSecrets secrets) |
| 1254 | + { |
| 1255 | + return Task.CompletedTask; |
| 1256 | + } |
| 1257 | + } |
1127 | 1258 | }
|
1128 | 1259 | }
|
0 commit comments