Skip to content

Commit 9ac28fd

Browse files
gohar94 (Gohar Irfan Chaudhry) and pragnagopa
authored
Shared memory data transfer between Functions Host and out-of-proc workers (#6836)
* Shared memory data transfer Co-authored-by: Gohar Irfan Chaudhry <[email protected]> Co-authored-by: pragnagopa <[email protected]>
1 parent 79d9417 commit 9ac28fd

27 files changed

+5928
-1850
lines changed

src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs

+141-6
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,26 @@
99
using System.IO;
1010
using System.Linq;
1111
using System.Reactive.Linq;
12+
using System.Text;
1213
using System.Threading.Tasks;
1314
using System.Threading.Tasks.Dataflow;
1415
using Microsoft.Azure.WebJobs.Script.Description;
1516
using Microsoft.Azure.WebJobs.Script.Diagnostics;
1617
using Microsoft.Azure.WebJobs.Script.Eventing;
1718
using Microsoft.Azure.WebJobs.Script.Grpc.Eventing;
19+
using Microsoft.Azure.WebJobs.Script.Grpc.Extensions;
1820
using Microsoft.Azure.WebJobs.Script.Grpc.Messages;
1921
using Microsoft.Azure.WebJobs.Script.ManagedDependencies;
2022
using Microsoft.Azure.WebJobs.Script.Workers;
2123
using Microsoft.Azure.WebJobs.Script.Workers.Rpc;
24+
using Microsoft.Azure.WebJobs.Script.Workers.SharedMemoryDataTransfer;
25+
using Microsoft.CodeAnalysis.VisualBasic.Syntax;
2226
using Microsoft.Extensions.Logging;
2327
using Microsoft.Extensions.Options;
2428
using static Microsoft.Azure.WebJobs.Script.Grpc.Messages.RpcLog.Types;
2529
using FunctionMetadata = Microsoft.Azure.WebJobs.Script.Description.FunctionMetadata;
2630
using MsgType = Microsoft.Azure.WebJobs.Script.Grpc.Messages.StreamingMessage.ContentOneofCase;
31+
using ParameterBindingType = Microsoft.Azure.WebJobs.Script.Grpc.Messages.ParameterBinding.RpcDataOneofCase;
2732

2833
namespace Microsoft.Azure.WebJobs.Script.Grpc
2934
{
@@ -35,6 +40,7 @@ internal class GrpcWorkerChannel : IRpcWorkerChannel, IDisposable
3540
private readonly string _runtime;
3641
private readonly IEnvironment _environment;
3742
private readonly IOptionsMonitor<ScriptApplicationHostOptions> _applicationHostOptions;
43+
private readonly ISharedMemoryManager _sharedMemoryManager;
3844

3945
private IDisposable _functionLoadRequestResponseEvent;
4046
private bool _disposed;
@@ -60,6 +66,7 @@ internal class GrpcWorkerChannel : IRpcWorkerChannel, IDisposable
6066
private TaskCompletionSource<bool> _reloadTask = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
6167
private TaskCompletionSource<bool> _workerInitTask = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
6268
private TimeSpan _functionLoadTimeout = TimeSpan.FromMinutes(10);
69+
private bool _isSharedMemoryDataTransferEnabled;
6370

6471
internal GrpcWorkerChannel(
6572
string workerId,
@@ -70,7 +77,8 @@ internal GrpcWorkerChannel(
7077
IMetricsLogger metricsLogger,
7178
int attemptCount,
7279
IEnvironment environment,
73-
IOptionsMonitor<ScriptApplicationHostOptions> applicationHostOptions)
80+
IOptionsMonitor<ScriptApplicationHostOptions> applicationHostOptions,
81+
ISharedMemoryManager sharedMemoryManager)
7482
{
7583
_workerId = workerId;
7684
_eventManager = eventManager;
@@ -81,6 +89,7 @@ internal GrpcWorkerChannel(
8189
_metricsLogger = metricsLogger;
8290
_environment = environment;
8391
_applicationHostOptions = applicationHostOptions;
92+
_sharedMemoryManager = sharedMemoryManager;
8493

8594
_workerCapabilities = new GrpcCapabilities(_workerChannelLogger);
8695

@@ -101,7 +110,7 @@ internal GrpcWorkerChannel(
101110
.Subscribe(msg => _eventManager.Publish(new HostRestartEvent())));
102111

103112
_eventSubscriptions.Add(_inboundWorkerEvents.Where(msg => msg.MessageType == MsgType.InvocationResponse)
104-
.Subscribe((msg) => InvokeResponse(msg.Message.InvocationResponse)));
113+
.Subscribe(async (msg) => await InvokeResponse(msg.Message.InvocationResponse)));
105114

106115
_inboundWorkerEvents.Where(msg => msg.MessageType == MsgType.WorkerStatusResponse)
107116
.Subscribe((msg) => ReceiveWorkerStatusResponse(msg.Message.RequestId, msg.Message.WorkerStatusResponse));
@@ -239,6 +248,7 @@ internal void WorkerInitResponse(GrpcEvent initEvent)
239248
}
240249
_state = _state | RpcWorkerChannelState.Initialized;
241250
_workerCapabilities.UpdateCapabilities(_initMessage.Capabilities);
251+
_isSharedMemoryDataTransferEnabled = IsSharedMemoryDataTransferEnabled();
242252
_workerInitTask.SetResult(true);
243253
}
244254

@@ -406,7 +416,7 @@ internal async Task SendInvocationRequest(ScriptInvocationContext context)
406416
context.ResultSource.SetCanceled();
407417
return;
408418
}
409-
var invocationRequest = await context.ToRpcInvocationRequest(_workerChannelLogger, _workerCapabilities);
419+
var invocationRequest = await context.ToRpcInvocationRequest(_workerChannelLogger, _workerCapabilities, _isSharedMemoryDataTransferEnabled, _sharedMemoryManager);
410420
_executingInvocations.TryAdd(invocationRequest.InvocationId, context);
411421

412422
SendStreamingMessage(new StreamingMessage
@@ -421,7 +431,41 @@ internal async Task SendInvocationRequest(ScriptInvocationContext context)
421431
}
422432
}
423433

424-
internal void InvokeResponse(InvocationResponse invokeResponse)
434+
/// <summary>
/// Materializes the value of a single output binding returned by the worker.
/// </summary>
/// <param name="binding">The <see cref="ParameterBinding"/> whose payload should be read.</param>
/// <param name="invocationId">Id of the invocation that produced the binding; forwarded to the shared memory reader.</param>
/// <returns>The object read from shared memory or converted from the inline RPC data.</returns>
/// <exception cref="InvalidOperationException">The binding carries neither shared memory nor inline RPC data.</exception>
private async Task<object> GetBindingDataAsync(ParameterBinding binding, string invocationId)
{
    ParameterBindingType transferKind = binding.RpcDataCase;

    if (transferKind == ParameterBindingType.RpcSharedMemory)
    {
        // The worker handed the payload over through a shared memory map
        return await binding.RpcSharedMemory.ToObjectAsync(_workerChannelLogger, invocationId, _sharedMemoryManager);
    }

    if (transferKind == ParameterBindingType.Data)
    {
        // The worker sent the payload inline in the RPC message
        return binding.Data.ToObject();
    }

    throw new InvalidOperationException("Unknown ParameterBindingType");
}
448+
449+
/// <summary>
/// Collects the names of the shared memory maps referenced by the worker's output bindings.
/// </summary>
/// <param name="bindings">The <see cref="ParameterBinding"/> entries the worker returned as output.</param>
/// <returns>Names of the shared memory maps found in <paramref name="bindings"/>; empty when none use shared memory.</returns>
private IList<string> GetOutputMaps(IList<ParameterBinding> bindings)
{
    // Only bindings that carry an RpcSharedMemory payload contribute a map name
    return bindings
        .Where(binding => binding.RpcSharedMemory != null)
        .Select(binding => binding.RpcSharedMemory.Name)
        .ToList();
}
467+
468+
internal async Task InvokeResponse(InvocationResponse invokeResponse)
425469
{
426470
_workerChannelLogger.LogDebug("InvocationResponse received for invocation id: {Id}", invokeResponse.InvocationId);
427471

@@ -430,8 +474,29 @@ internal void InvokeResponse(InvocationResponse invokeResponse)
430474
{
431475
try
432476
{
433-
IDictionary<string, object> bindingsDictionary = invokeResponse.OutputData
434-
.ToDictionary(binding => binding.Name, binding => binding.Data.ToObject());
477+
StringBuilder logBuilder = new StringBuilder();
478+
bool usedSharedMemory = false;
479+
480+
foreach (ParameterBinding binding in invokeResponse.OutputData)
481+
{
482+
switch (binding.RpcDataCase)
483+
{
484+
case ParameterBindingType.RpcSharedMemory:
485+
logBuilder.AppendFormat("{0}:{1},", binding.Name, binding.RpcSharedMemory.Count);
486+
usedSharedMemory = true;
487+
break;
488+
default:
489+
break;
490+
}
491+
}
492+
493+
if (usedSharedMemory)
494+
{
495+
_workerChannelLogger.LogDebug("Shared memory usage for response of invocation Id: {Id} is {SharedMemoryUsage}", invokeResponse.InvocationId, logBuilder.ToString());
496+
}
497+
498+
IDictionary<string, object> bindingsDictionary = await invokeResponse.OutputData
499+
.ToDictionaryAsync(binding => binding.Name, binding => GetBindingDataAsync(binding, invokeResponse.InvocationId));
435500

436501
var result = new ScriptInvocationResult()
437502
{
@@ -444,9 +509,40 @@ internal void InvokeResponse(InvocationResponse invokeResponse)
444509
{
445510
context.ResultSource.TrySetException(responseEx);
446511
}
512+
finally
513+
{
514+
// Free memory allocated by the host (for input bindings)
515+
if (!_sharedMemoryManager.TryFreeSharedMemoryMapsForInvocation(invokeResponse.InvocationId))
516+
{
517+
_workerChannelLogger.LogWarning($"Cannot free all shared memory resources for invocation: {invokeResponse.InvocationId}");
518+
}
519+
520+
// List of shared memory maps that were produced by the worker (for output bindings)
521+
IList<string> outputMaps = GetOutputMaps(invokeResponse.OutputData);
522+
if (outputMaps.Count > 0)
523+
{
524+
// If this invocation was using any shared memory maps produced by the worker, close them to free memory
525+
SendCloseSharedMemoryResourcesForInvocationRequest(outputMaps);
526+
}
527+
}
447528
}
448529
}
449530

531+
/// <summary>
/// Sends the worker a request to close the shared memory maps it created for output bindings.
/// </summary>
/// <param name="outputMaps">Names of the shared memory maps the worker should close.</param>
internal void SendCloseSharedMemoryResourcesForInvocationRequest(IList<string> outputMaps)
{
    var closeRequest = new CloseSharedMemoryResourcesRequest();
    closeRequest.MapNames.AddRange(outputMaps);

    var message = new StreamingMessage
    {
        CloseSharedMemoryResourcesRequest = closeRequest
    };

    SendStreamingMessage(message);
}
545+
450546
internal void Log(GrpcEvent msg)
451547
{
452548
var rpcLog = msg.Message.RpcLog;
@@ -616,5 +712,44 @@ public bool TryFailExecutions(Exception workerException)
616712
}
617713
return true;
618714
}
715+
716+
/// <summary>
/// Decides whether shared memory data transfer may be used with this worker.
/// Both of the following must hold:
/// 1) The <see cref="RpcWorkerConstants.FunctionsWorkerSharedMemoryDataTransferEnabledSettingName"/> environment variable (AppSetting)
///    is set to a boolean <c>true</c> or to the integer <c>1</c>.
/// 2) The worker reported a non-empty state for the <see cref="RpcWorkerConstants.SharedMemoryDataTransfer"/> capability.
/// </summary>
/// <returns><see langword="true"/> if shared memory data transfer is enabled, <see langword="false"/> otherwise.</returns>
internal bool IsSharedMemoryDataTransferEnabled()
{
    string settingValue = _environment.GetEnvironmentVariable(RpcWorkerConstants.FunctionsWorkerSharedMemoryDataTransferEnabledSettingName);
    if (string.IsNullOrEmpty(settingValue))
    {
        // Setting is absent: the feature is opt-in
        return false;
    }

    // Accept the setting either as a bool (true/false) or as an int where only 1 means enabled
    bool settingEnabled = bool.TryParse(settingValue, out bool parsedBool)
        ? parsedBool
        : int.TryParse(settingValue, out int parsedInt) && parsedInt == 1;

    if (!settingEnabled)
    {
        return false;
    }

    // The setting alone is not enough; the worker must also advertise the capability
    string capabilityState = _workerCapabilities.GetCapabilityState(RpcWorkerConstants.SharedMemoryDataTransfer);
    bool workerSupportsFeature = !string.IsNullOrEmpty(capabilityState);
    _workerChannelLogger.LogDebug("IsSharedMemoryDataTransferEnabled: {SharedMemoryDataTransferEnabled}", workerSupportsFeature);
    return workerSupportsFeature;
}
619754
}
620755
}

src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannelFactory.cs

+6-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using Microsoft.Azure.WebJobs.Script.Eventing;
1010
using Microsoft.Azure.WebJobs.Script.Workers;
1111
using Microsoft.Azure.WebJobs.Script.Workers.Rpc;
12+
using Microsoft.Azure.WebJobs.Script.Workers.SharedMemoryDataTransfer;
1213
using Microsoft.Extensions.Logging;
1314
using Microsoft.Extensions.Options;
1415

@@ -21,15 +22,17 @@ public class GrpcWorkerChannelFactory : IRpcWorkerChannelFactory
2122
private readonly IScriptEventManager _eventManager = null;
2223
private readonly IEnvironment _environment = null;
2324
private readonly IOptionsMonitor<ScriptApplicationHostOptions> _applicationHostOptions = null;
25+
private readonly ISharedMemoryManager _sharedMemoryManager = null;
2426

2527
public GrpcWorkerChannelFactory(IScriptEventManager eventManager, IEnvironment environment, IRpcServer rpcServer, ILoggerFactory loggerFactory, IOptionsMonitor<LanguageWorkerOptions> languageWorkerOptions,
26-
IOptionsMonitor<ScriptApplicationHostOptions> applicationHostOptions, IRpcWorkerProcessFactory rpcWorkerProcessManager)
28+
IOptionsMonitor<ScriptApplicationHostOptions> applicationHostOptions, IRpcWorkerProcessFactory rpcWorkerProcessManager, ISharedMemoryManager sharedMemoryManager)
2729
{
2830
_eventManager = eventManager;
2931
_loggerFactory = loggerFactory;
3032
_rpcWorkerProcessFactory = rpcWorkerProcessManager;
3133
_environment = environment;
3234
_applicationHostOptions = applicationHostOptions;
35+
_sharedMemoryManager = sharedMemoryManager;
3336
}
3437

3538
public IRpcWorkerChannel Create(string scriptRootPath, string runtime, IMetricsLogger metricsLogger, int attemptCount, IEnumerable<RpcWorkerConfig> workerConfigs)
@@ -51,7 +54,8 @@ public IRpcWorkerChannel Create(string scriptRootPath, string runtime, IMetricsL
5154
metricsLogger,
5255
attemptCount,
5356
_environment,
54-
_applicationHostOptions);
57+
_applicationHostOptions,
58+
_sharedMemoryManager);
5559
}
5660
}
5761
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the MIT License. See License.txt in the project root for license information.
3+
4+
using System.IO;
5+
using System.Threading.Tasks;
6+
using Microsoft.Azure.WebJobs.Script.Grpc.Messages;
7+
using Microsoft.Azure.WebJobs.Script.Workers.SharedMemoryDataTransfer;
8+
using Microsoft.Extensions.Logging;
9+
10+
namespace Microsoft.Azure.WebJobs.Script.Grpc.Extensions
11+
{
12+
/// <summary>
/// Extension methods that convert between host objects and <see cref="RpcSharedMemory"/> messages
/// by writing to and reading from shared memory through an <see cref="ISharedMemoryManager"/>.
/// </summary>
internal static class RpcSharedMemoryDataExtensions
{
    /// <summary>
    /// Writes <paramref name="value"/> to shared memory and builds the <see cref="RpcSharedMemory"/> message describing it.
    /// </summary>
    /// <param name="value">The object to transfer.</param>
    /// <param name="logger">Logger used for trace messages.</param>
    /// <param name="invocationId">Id of the invocation the shared memory map is registered against.</param>
    /// <param name="sharedMemoryManager">Manager used to write the object and track the created map.</param>
    /// <returns>
    /// The <see cref="RpcSharedMemory"/> describing the written content, or <see langword="null"/> when the value is not
    /// supported by the manager, has no matching <see cref="RpcDataType"/>, or could not be written.
    /// </returns>
    internal static async Task<RpcSharedMemory> ToRpcSharedMemoryAsync(this object value, ILogger logger, string invocationId, ISharedMemoryManager sharedMemoryManager)
    {
        if (!sharedMemoryManager.IsSupported(value))
        {
            return null;
        }

        // Resolve the RPC data type before touching shared memory, so that a value which cannot be
        // described to the worker never causes a map to be written and registered for nothing.
        RpcDataType? dataType = GetRpcDataType(value);
        if (!dataType.HasValue)
        {
            logger.LogTrace("Cannot get shared memory data type for invocation id: {Id}", invocationId);
            return null;
        }

        // Put the content into shared memory and get the name of the shared memory map written to
        SharedMemoryMetadata putResponse = await sharedMemoryManager.PutObjectAsync(value);
        if (putResponse == null)
        {
            logger.LogTrace("Cannot write to shared memory for invocation id: {Id}", invocationId);
            return null;
        }

        // If written to shared memory successfully, add this shared memory map to the list of maps for this invocation
        sharedMemoryManager.AddSharedMemoryMapForInvocation(invocationId, putResponse.Name);

        // Generate a response
        RpcSharedMemory sharedMem = new RpcSharedMemory()
        {
            Name = putResponse.Name,
            Offset = 0,
            Count = putResponse.Count,
            Type = dataType.Value
        };

        logger.LogTrace("Put object in shared memory for invocation id: {Id}", invocationId);
        return sharedMem;
    }

    /// <summary>
    /// Reads the object described by <paramref name="sharedMem"/> out of shared memory.
    /// </summary>
    /// <param name="sharedMem">Message naming the shared memory map, the offset and byte count to read, and the data type.</param>
    /// <param name="logger">Logger used for trace and error messages.</param>
    /// <param name="invocationId">Id of the invocation the data belongs to; used in log messages.</param>
    /// <param name="sharedMemoryManager">Manager used to read the object.</param>
    /// <returns>The object read from shared memory, requested as a <see cref="T:byte[]"/> or a <see cref="string"/> depending on the message's type.</returns>
    /// <exception cref="InvalidDataException">The message's data type is not supported.</exception>
    /// <exception cref="System.OverflowException">The offset or count in the message does not fit in an <see cref="int"/>.</exception>
    internal static async Task<object> ToObjectAsync(this RpcSharedMemory sharedMem, ILogger logger, string invocationId, ISharedMemoryManager sharedMemoryManager)
    {
        // Data was transferred by the worker using shared memory
        string mapName = sharedMem.Name;

        // Checked conversions: a value outside the int range must fail loudly instead of
        // silently wrapping and making the host read the wrong region of the map.
        int offset = checked((int)sharedMem.Offset);
        int count = checked((int)sharedMem.Count);
        logger.LogTrace("Shared memory data transfer for invocation id: {Id} with shared memory map name: {MapName} and size: {Size} bytes", invocationId, mapName, count);

        switch (sharedMem.Type)
        {
            case RpcDataType.Bytes:
                return await sharedMemoryManager.GetObjectAsync(mapName, offset, count, typeof(byte[]));
            case RpcDataType.String:
                return await sharedMemoryManager.GetObjectAsync(mapName, offset, count, typeof(string));
            default:
                logger.LogError("Unsupported shared memory data type: {SharedMemDataType} for invocation id: {Id}", sharedMem.Type, invocationId);
                throw new InvalidDataException($"Unsupported shared memory data type: {sharedMem.Type}");
        }
    }

    /// <summary>
    /// Maps a runtime value to the <see cref="RpcDataType"/> used to describe it in an <see cref="RpcSharedMemory"/> message.
    /// </summary>
    /// <param name="value">The value to classify.</param>
    /// <returns>
    /// <see cref="RpcDataType.Bytes"/> for a byte array, <see cref="RpcDataType.String"/> for a string,
    /// or <see langword="null"/> for anything else.
    /// </returns>
    private static RpcDataType? GetRpcDataType(object value)
    {
        if (value is byte[])
        {
            return RpcDataType.Bytes;
        }
        else if (value is string)
        {
            return RpcDataType.String;
        }

        return null;
    }
}
86+
}

0 commit comments

Comments
 (0)