Skip to content

Commit 79d9417

Browse files
authored
Serialize Function key generation (#7106) (#7124)
1 parent c026339 commit 79d9417

File tree

2 files changed

+177
-28
lines changed

2 files changed

+177
-28
lines changed

src/WebJobs.Script.WebHost/Security/KeyManagement/SecretManager.cs

+46-28
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ public class SecretManager : IDisposable, ISecretManager
3131
private ConcurrentDictionary<string, (string, AuthorizationLevel)> _authorizationCache = new ConcurrentDictionary<string, (string, AuthorizationLevel)>(StringComparer.OrdinalIgnoreCase);
3232
private HostSecretsInfo _hostSecrets;
3333
private SemaphoreSlim _hostSecretsLock = new SemaphoreSlim(1, 1);
34+
private ConcurrentDictionary<string, SemaphoreSlim> _functionSecretsLocks = new ConcurrentDictionary<string, SemaphoreSlim>(StringComparer.OrdinalIgnoreCase);
3435
private IMetricsLogger _metricsLogger;
3536
private string _repositoryClassName;
3637
private DateTime _lastCacheResetTime;
@@ -149,41 +150,51 @@ public async virtual Task<IDictionary<string, string>> GetFunctionSecretsAsync(s
149150
{
150151
using (_metricsLogger.LatencyEvent(GetMetricEventName(MetricEventNames.SecretManagerGetFunctionSecrets), functionName))
151152
{
152-
_logger.LogDebug($"Loading secrets for function '{functionName}'");
153+
var functionSecretsLock = GetFunctionSecretsLock(functionName);
154+
await functionSecretsLock.WaitAsync();
153155

154-
FunctionSecrets secrets = await LoadFunctionSecretsAsync(functionName);
155-
if (secrets == null)
156+
try
156157
{
157-
// no secrets exist for this function so generate them
158-
string message = string.Format(Resources.TraceFunctionSecretGeneration, functionName);
159-
_logger.LogDebug(message);
160-
secrets = GenerateFunctionSecrets();
158+
_logger.LogDebug($"Loading secrets for function '{functionName}'");
161159

162-
await PersistSecretsAsync(secrets, functionName);
163-
}
160+
FunctionSecrets secrets = await LoadFunctionSecretsAsync(functionName);
161+
if (secrets == null)
162+
{
163+
// no secrets exist for this function so generate them
164+
string message = string.Format(Resources.TraceFunctionSecretGeneration, functionName);
165+
_logger.LogDebug(message);
166+
secrets = GenerateFunctionSecrets();
164167

165-
try
166-
{
167-
// Read all secrets, which will run the keys through the appropriate readers
168-
secrets.Keys = secrets.Keys.Select(k => _keyValueConverterFactory.ReadKey(k)).ToList();
169-
}
170-
catch (CryptographicException ex)
171-
{
172-
string message = string.Format(Resources.TraceNonDecryptedFunctionSecretRefresh, functionName, ex);
173-
_logger.LogDebug(message);
174-
await PersistSecretsAsync(secrets, functionName, true);
175-
secrets = GenerateFunctionSecrets(secrets);
176-
await RefreshSecretsAsync(secrets, functionName);
177-
}
168+
await PersistSecretsAsync(secrets, functionName);
169+
}
178170

179-
if (secrets.HasStaleKeys)
171+
try
172+
{
173+
// Read all secrets, which will run the keys through the appropriate readers
174+
secrets.Keys = secrets.Keys.Select(k => _keyValueConverterFactory.ReadKey(k)).ToList();
175+
}
176+
catch (CryptographicException ex)
177+
{
178+
string message = string.Format(Resources.TraceNonDecryptedFunctionSecretRefresh, functionName, ex);
179+
_logger.LogDebug(message);
180+
await PersistSecretsAsync(secrets, functionName, true);
181+
secrets = GenerateFunctionSecrets(secrets);
182+
await RefreshSecretsAsync(secrets, functionName);
183+
}
184+
185+
if (secrets.HasStaleKeys)
186+
{
187+
_logger.LogDebug(string.Format(Resources.TraceStaleFunctionSecretRefresh, functionName));
188+
await RefreshSecretsAsync(secrets, functionName);
189+
}
190+
191+
var result = secrets.Keys.ToDictionary(s => s.Name, s => s.Value);
192+
functionSecrets = _functionSecrets.AddOrUpdate(functionName, result, (n, r) => result);
193+
}
194+
finally
180195
{
181-
_logger.LogDebug(string.Format(Resources.TraceStaleFunctionSecretRefresh, functionName));
182-
await RefreshSecretsAsync(secrets, functionName);
196+
functionSecretsLock.Release();
183197
}
184-
185-
var result = secrets.Keys.ToDictionary(s => s.Name, s => s.Value);
186-
functionSecrets = _functionSecrets.AddOrUpdate(functionName, result, (n, r) => result);
187198
}
188199
}
189200

@@ -727,5 +738,12 @@ private string GetEncryptionKeysHashes()
727738

728739
return result;
729740
}
741+
742+
private SemaphoreSlim GetFunctionSecretsLock(string functionName)
743+
{
744+
// We're only serializing access to secrets per-function, not across all functions,
745+
// so we need to ensure we're using a single shared lock per-function.
746+
return _functionSecretsLocks.GetOrAdd(functionName, k => new SemaphoreSlim(1, 1));
747+
}
730748
}
731749
}

test/WebJobs.Script.Tests/Security/SecretManagerTests.cs

+131
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
// Licensed under the MIT License. See License.txt in the project root for license information.
33

44
using System;
5+
using System.Collections.Concurrent;
56
using System.Collections.Generic;
67
using System.IO;
78
using System.Linq;
89
using System.Security.Cryptography;
10+
using System.Threading;
911
using System.Threading.Tasks;
1012
using Microsoft.Azure.WebJobs.Extensions.Http;
1113
using Microsoft.Azure.WebJobs.Logging;
@@ -14,6 +16,7 @@
1416
using Microsoft.Azure.WebJobs.Script.WebHost.Models;
1517
using Microsoft.Azure.WebJobs.Script.WebHost.Properties;
1618
using Microsoft.Azure.WebJobs.Script.WebHost.Security;
19+
using Microsoft.Azure.WebJobs.Script.WebHost.Storage;
1720
using Microsoft.Extensions.Logging;
1821
using Microsoft.WebJobs.Script.Tests;
1922
using Moq;
@@ -337,6 +340,55 @@ public async Task GetHostSecrets_UpdatesStaleSecrets()
337340
}
338341
}
339342

343+
[Fact]
344+
public async Task GetFunctionSecretsAsync_SecretGenerationIsSerialized()
345+
{
346+
var mockValueConverterFactory = GetConverterFactoryMock(false, false);
347+
var metricsLogger = new TestMetricsLogger();
348+
var testRepository = new TestSecretsRepository(true);
349+
string testFunctionName = $"TestFunction";
350+
351+
using (var secretManager = new SecretManager(testRepository, mockValueConverterFactory.Object, _logger, metricsLogger, _hostNameProvider, _startupContextProvider))
352+
{
353+
var tasks = new List<Task<IDictionary<string, string>>>();
354+
for (int i = 0; i < 10; i++)
355+
{
356+
tasks.Add(secretManager.GetFunctionSecretsAsync(testFunctionName));
357+
}
358+
359+
await Task.WhenAll(tasks);
360+
361+
// verify all calls return the same result
362+
Assert.Equal(1, testRepository.FunctionSecrets.Count);
363+
var functionSecrets = (FunctionSecrets)testRepository.FunctionSecrets[testFunctionName];
364+
string defaultKeyValue = functionSecrets.Keys.Where(p => p.Name == "default").Single().Value;
365+
Assert.True(tasks.Select(p => p.Result).All(t => t["default"] == defaultKeyValue));
366+
}
367+
}
368+
369+
[Fact]
370+
public async Task GetHostSecretsAsync_SecretGenerationIsSerialized()
371+
{
372+
var mockValueConverterFactory = GetConverterFactoryMock(false, false);
373+
var metricsLogger = new TestMetricsLogger();
374+
var testRepository = new TestSecretsRepository(true);
375+
376+
using (var secretManager = new SecretManager(testRepository, mockValueConverterFactory.Object, _logger, metricsLogger, _hostNameProvider, _startupContextProvider))
377+
{
378+
var tasks = new List<Task<HostSecretsInfo>>();
379+
for (int i = 0; i < 10; i++)
380+
{
381+
tasks.Add(secretManager.GetHostSecretsAsync());
382+
}
383+
384+
await Task.WhenAll(tasks);
385+
386+
// verify all calls return the same result
387+
var masterKey = tasks.First().Result.MasterKey;
388+
Assert.True(tasks.Select(p => p.Result).All(q => q.MasterKey == masterKey));
389+
}
390+
}
391+
340392
[Fact]
341393
public async Task GetHostSecrets_WhenNoHostSecretFileExists_GeneratesSecretsAndPersistsFiles()
342394
{
@@ -1124,5 +1176,84 @@ private void CreateTestSecrets(string path)
11241176
File.WriteAllText(Path.Combine(path, ScriptConstants.HostMetadataFileName), hostSecrets);
11251177
File.WriteAllText(Path.Combine(path, "testfunction.json"), functionSecrets);
11261178
}
1179+
1180+
private class TestSecretsRepository : ISecretsRepository
1181+
{
1182+
private int _writeCount = 0;
1183+
private Random _rand = new Random();
1184+
private bool _enforceSerialWrites = false;
1185+
1186+
public TestSecretsRepository(bool enforceSerialWrites)
1187+
{
1188+
_enforceSerialWrites = enforceSerialWrites;
1189+
}
1190+
1191+
public event EventHandler<SecretsChangedEventArgs> SecretsChanged;
1192+
1193+
public ConcurrentDictionary<string, ScriptSecrets> FunctionSecrets { get; } = new ConcurrentDictionary<string, ScriptSecrets>(StringComparer.OrdinalIgnoreCase);
1194+
1195+
public ScriptSecrets HostSecrets { get; private set; }
1196+
1197+
public bool IsEncryptionSupported => throw new NotImplementedException();
1198+
1199+
public Task<string[]> GetSecretSnapshots(ScriptSecretsType type, string functionName)
1200+
{
1201+
return Task.FromResult(new string[0]);
1202+
}
1203+
1204+
public Task PurgeOldSecretsAsync(IList<string> currentFunctions, ILogger logger)
1205+
{
1206+
return Task.CompletedTask;
1207+
}
1208+
1209+
public Task<ScriptSecrets> ReadAsync(ScriptSecretsType type, string functionName)
1210+
{
1211+
ScriptSecrets secrets = null;
1212+
1213+
if (type == ScriptSecretsType.Function)
1214+
{
1215+
FunctionSecrets.TryGetValue(functionName, out secrets);
1216+
}
1217+
else
1218+
{
1219+
secrets = HostSecrets;
1220+
}
1221+
1222+
return Task.FromResult(secrets);
1223+
}
1224+
1225+
public async Task WriteAsync(ScriptSecretsType type, string functionName, ScriptSecrets secrets)
1226+
{
1227+
if (_enforceSerialWrites && _writeCount > 1)
1228+
{
1229+
throw new Exception("Concurrent writes detected!");
1230+
}
1231+
1232+
Interlocked.Increment(ref _writeCount);
1233+
1234+
await Task.Delay(_rand.Next(100, 300));
1235+
1236+
if (type == ScriptSecretsType.Function)
1237+
{
1238+
FunctionSecrets[functionName] = secrets;
1239+
}
1240+
else
1241+
{
1242+
HostSecrets = secrets;
1243+
}
1244+
1245+
Interlocked.Decrement(ref _writeCount);
1246+
1247+
if (SecretsChanged != null)
1248+
{
1249+
SecretsChanged(this, new SecretsChangedEventArgs { SecretsType = type, Name = functionName });
1250+
}
1251+
}
1252+
1253+
public Task WriteSnapshotAsync(ScriptSecretsType type, string functionName, ScriptSecrets secrets)
1254+
{
1255+
return Task.CompletedTask;
1256+
}
1257+
}
11271258
}
11281259
}

0 commit comments

Comments
 (0)