1
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ // SPDX-License-Identifier: MIT-0
3
+
4
+ using System ;
5
+ using System . Reflection ;
6
+ using System . Net . Http ;
7
+ using System . Text ;
8
+ using System . Text . Json ;
9
+ using System . Text . Json . Serialization ;
10
+ using System . Linq ;
11
+ using System . Threading . Tasks ;
12
+ using System . Threading ;
13
+
14
+ namespace csharp_example_extension
15
+ {
16
+ /// <summary>
17
+ /// Lambda Extension API client
18
+ /// </summary>
19
+ internal class ExtensionClient : IDisposable
20
+ {
21
+ #region HTTP header key names
22
+
23
+ /// <summary>
24
+ /// HTTP header that is used to register a new extension name with Extension API
25
+ /// </summary>
26
+ private const string LambdaExtensionNameHeader = "Lambda-Extension-Name" ;
27
+
28
+ /// <summary>
29
+ /// HTTP header used to provide extension registration id
30
+ /// </summary>
31
+ /// <remarks>
32
+ /// Registration endpoint reply will have this header value with a new id, assigned to this extension by the API.
33
+ /// All other endpoints will expect HTTP calls to have id header attached to all requests.
34
+ /// </remarks>
35
+ private const string LambdaExtensionIdHeader = "Lambda-Extension-Identifier" ;
36
+
37
+ /// <summary>
38
+ /// HTTP header to report Lambda Extension error type string.
39
+ /// </summary>
40
+ /// <remarks>
41
+ /// This header is used to report additional error details for Init and Shutdown errors.
42
+ /// </remarks>
43
+ private const string LambdaExtensionFunctionErrorTypeHeader = "Lambda-Extension-Function-Error-Type" ;
44
+
45
+ #endregion
46
+
47
+ #region Environment variable names
48
+
49
+ /// <summary>
50
+ /// Environment variable that holds server name and port number for Extension API endpoints
51
+ /// </summary>
52
+ private const string LambdaRuntimeApiAddress = "AWS_LAMBDA_RUNTIME_API" ;
53
+
54
+ #endregion
55
+
56
+ #region Instance properties
57
+
58
+ /// <summary>
59
+ /// Extension id, which is assigned to this extension after the registration
60
+ /// </summary>
61
+ public string Id { get ; private set ; }
62
+
63
+ #endregion
64
+
65
+ #region Constructor and readonly variables
66
+
67
+ /// <summary>
68
+ /// Http client instance
69
+ /// </summary>
70
+ /// <remarks>This is an IDisposable object that must be properly disposed of,
71
+ /// thus <see cref="ExtensionClient"/> implements <see cref="IDisposable"/> interface too.</remarks>
72
+ private readonly HttpClient httpClient = new HttpClient ( ) ;
73
+
74
+ /// <summary>
75
+ /// Extension name, calculated from the current executing assembly name
76
+ /// </summary>
77
+ private readonly string extensionName ;
78
+
79
+ /// <summary>
80
+ /// Extension registration URL
81
+ /// </summary>
82
+ private readonly Uri registerUrl ;
83
+
84
+ /// <summary>
85
+ /// Next event long poll URL
86
+ /// </summary>
87
+ private readonly Uri nextUrl ;
88
+
89
+ /// <summary>
90
+ /// Extension initialization error reporting URL
91
+ /// </summary>
92
+ private readonly Uri initErrorUrl ;
93
+
94
+ /// <summary>
95
+ /// Extension shutdown error reporting URL
96
+ /// </summary>
97
+ private readonly Uri shutdownErrorUrl ;
98
+
99
+ /// <summary>
100
+ /// Constructor
101
+ /// </summary>
102
+ public ExtensionClient ( string extensionName )
103
+ {
104
+ this . extensionName = extensionName ?? throw new ArgumentNullException ( nameof ( extensionName ) , "Extension name cannot be null" ) ;
105
+
106
+ // Set infinite timeout so that underlying connection is kept alive
107
+ this . httpClient . Timeout = Timeout . InfiniteTimeSpan ;
108
+ // Get Extension API service base URL from the environment variable
109
+ var apiUri = new UriBuilder ( Environment . GetEnvironmentVariable ( LambdaRuntimeApiAddress ) ) . Uri ;
110
+ // Common path for all Extension API URLs
111
+ var basePath = "2020-01-01/extension" ;
112
+
113
+ // Calculate all Extension API endpoints' URLs
114
+ this . registerUrl = new Uri ( apiUri , $ "{ basePath } /register") ;
115
+ this . nextUrl = new Uri ( apiUri , $ "{ basePath } /event/next") ;
116
+ this . initErrorUrl = new Uri ( apiUri , $ "{ basePath } /init/error") ;
117
+ this . shutdownErrorUrl = new Uri ( apiUri , $ "{ basePath } /exit/error") ;
118
+ }
119
+
120
+ #endregion
121
+
122
+ #region Public interface
123
+
124
+ /// <summary>
125
+ /// Extension registration and event loop handling
126
+ /// </summary>
127
+ /// <param name="onInit">Optional lambda extension that is invoked when extension has been successfully registered with AWS Lambda Extension API.
128
+ /// This function will be called exactly once if it is defined and ignored if this parameter is null.</param>
129
+ /// <param name="onInvoke">Optional lambda extension that is invoked every time AWS Lambda Extension API reports a new <see cref="ExtensionEvent.INVOKE"/> event.
130
+ /// This function will be called once for each <see cref="ExtensionEvent.INVOKE"/> event during the entire lifetime of AWS Lambda function instance.</param>
131
+ /// <param name="onShutdown">Optional lambda extension that is invoked when extension receives <see cref="ExtensionEvent.SHUTDOWN"/> event from AWS LAmbda Extension API.
132
+ /// This function will be called exactly once if it is defined and ignored if this parameter is null.</param>
133
+ /// <returns>Awaitable void</returns>
134
+ /// <remarks>Unhandled exceptions thrown by <paramref name="onInit"/> and <paramref name="onShutdown"/> functions will be reported to AWS Lambda API with
135
+ /// <c>/init/error</c> and <c>/exit/error</c> calls, in any case <see cref="ProcessEvents"/> will immediately exit after reporting the error.
136
+ /// Unhandled <paramref name="onInvoke"/> exceptions are logged to console and ignored, so that extension execution can continue.
137
+ /// </remarks>
138
+ public async Task ProcessEvents ( Func < string , Task > onInit = null , Func < string , Task > onInvoke = null , Func < string , Task > onShutdown = null )
139
+ {
140
+ // Register extension with AWS Lambda Extension API to handle both INVOKE and SHUTDOWN events
141
+ await RegisterExtensionAsync ( ExtensionEvent . INVOKE , ExtensionEvent . SHUTDOWN ) ;
142
+
143
+ // If onInit function is defined, invoke it and report any unhandled exceptions
144
+ if ( ! await SafeInvoke ( onInit , this . Id , ex => ReportErrorAsync ( this . initErrorUrl , "Fatal.Unhandled" , ex ) ) ) return ;
145
+
146
+ // loop till SHUTDOWN event is received
147
+ var hasNext = true ;
148
+ while ( hasNext )
149
+ {
150
+ // get the next event type and details
151
+ var ( type , payload ) = await GetNextAsync ( ) ;
152
+
153
+ switch ( type )
154
+ {
155
+ case ExtensionEvent . INVOKE :
156
+ // invoke onInit function if one is defined and log unhandled exceptions
157
+ // event loop will continue even if there was an exception
158
+ await SafeInvoke ( onInvoke , payload , onException : ex => {
159
+ Console . WriteLine ( $ "[{ this . extensionName } ] Invoke handler threw an exception") ;
160
+ return Task . CompletedTask ;
161
+ } ) ;
162
+ break ;
163
+ case ExtensionEvent . SHUTDOWN :
164
+ // terminate the loop, invoke onShutdown function if there is any and report any unhandled exceptions to AWS Extension API
165
+ hasNext = false ;
166
+ await SafeInvoke ( onShutdown , this . Id , ex => ReportErrorAsync ( this . shutdownErrorUrl , "Fatal.Unhandled" , ex ) ) ;
167
+ break ;
168
+ default :
169
+ throw new ApplicationException ( $ "Unexpected event type: { type } ") ;
170
+ }
171
+ }
172
+ }
173
+
174
+ #endregion
175
+
176
+ #region Private methods
177
+
178
+ /// <summary>
179
+ /// Register extension with Extension API
180
+ /// </summary>
181
+ /// <param name="events">Event types to by notified with</param>
182
+ /// <returns>Awaitable void</returns>
183
+ /// <remarks>This method is expected to be called just once when extension is being registered with the Extension API.</remarks>
184
+ private async Task RegisterExtensionAsync ( params ExtensionEvent [ ] events )
185
+ {
186
+ // custom options for JsonSerializer to serialize ExtensionEvent enum values as strings, rather than integers
187
+ // thus we produce strongly typed code, which doesn't rely on strings
188
+ var options = new JsonSerializerOptions ( ) ;
189
+ options . Converters . Add ( new JsonStringEnumConverter ( ) ) ;
190
+
191
+ // create Json content for this extension registration
192
+ using var content = new StringContent ( JsonSerializer . Serialize ( new {
193
+ events
194
+ } , options ) , Encoding . UTF8 , "application/json" ) ;
195
+
196
+ // add extension name header value
197
+ content . Headers . Add ( LambdaExtensionNameHeader , this . extensionName ) ;
198
+
199
+ // POST call to Extension API
200
+ using var response = await this . httpClient . PostAsync ( this . registerUrl , content ) ;
201
+
202
+ // if POST call didn't succeed
203
+ if ( ! response . IsSuccessStatusCode )
204
+ {
205
+ // log details
206
+ Console . WriteLine ( $ "[{ this . extensionName } ] Error response received for registration request: { await response . Content . ReadAsStringAsync ( ) } ") ;
207
+ // throw an unhandled exception, so that extension is terminated by Lambda runtime
208
+ response . EnsureSuccessStatusCode ( ) ;
209
+ }
210
+
211
+ // get registration id from the response header
212
+ this . Id = response . Headers . GetValues ( LambdaExtensionIdHeader ) . FirstOrDefault ( ) ;
213
+ // if registration id is empty
214
+ if ( string . IsNullOrEmpty ( this . Id ) )
215
+ {
216
+ // throw an exception
217
+ throw new ApplicationException ( "Extension API register call didn't return a valid identifier." ) ;
218
+ }
219
+ // configure all HttpClient to send registration id header along with all subsequent requests
220
+ this . httpClient . DefaultRequestHeaders . Add ( LambdaExtensionIdHeader , this . Id ) ;
221
+ }
222
+
223
+ /// <summary>
224
+ /// Long poll for the next event from Extension API
225
+ /// </summary>
226
+ /// <returns>Awaitable tuple having event type and event details fields</returns>
227
+ /// <remarks>It is important to have httpClient.Timeout set to some value, that is longer than any expected wait time,
228
+ /// otherwise HttpClient will throw an exception when getting the next event details from the server.</remarks>
229
+ private async Task < ( ExtensionEvent type , string payload ) > GetNextAsync ( )
230
+ {
231
+ // use GET request to long poll for the next event
232
+ var contentBody = await this . httpClient . GetStringAsync ( this . nextUrl ) ;
233
+
234
+ // use JsonDocument instead of JsonSerializer, since there is no need to construct the entire object
235
+ using var doc = JsonDocument . Parse ( contentBody ) ;
236
+
237
+ // extract eventType from the reply, convert it to ExtensionEvent enum and reply with the typed event type and event content details.
238
+ return new ( Enum . Parse < ExtensionEvent > ( doc . RootElement . GetProperty ( "eventType" ) . GetString ( ) ) , contentBody ) ;
239
+ }
240
+
241
+ /// <summary>
242
+ /// Report initialization or shutdown error
243
+ /// </summary>
244
+ /// <param name="url"><see cref="initErrorUrl"/> or <see cref="shutdownErrorUrl"/></param>
245
+ /// <param name="errorType">Error type string, e.g. Fatal.ConnectionError or any other meaningful type</param>
246
+ /// <param name="exception">Exception details</param>
247
+ /// <returns>Awaitable void</returns>
248
+ /// <remarks>This implementation will append <paramref name="exception"/> name to <paramref name="errorType"/> for demonstration purposes</remarks>
249
+ private async Task ReportErrorAsync ( Uri url , string errorType , Exception exception )
250
+ {
251
+ using var content = new StringContent ( string . Empty ) ;
252
+ content . Headers . Add ( LambdaExtensionIdHeader , this . Id ) ;
253
+ content . Headers . Add ( LambdaExtensionFunctionErrorTypeHeader , $ "{ errorType } .{ exception . GetType ( ) . Name } ") ;
254
+
255
+ using var response = await this . httpClient . PostAsync ( url , content ) ;
256
+ if ( ! response . IsSuccessStatusCode )
257
+ {
258
+ Console . WriteLine ( $ "[{ this . extensionName } ] Error response received for { url . PathAndQuery } : { await response . Content . ReadAsStringAsync ( ) } ") ;
259
+ response . EnsureSuccessStatusCode ( ) ;
260
+ }
261
+ }
262
+
263
+ /// <summary>
264
+ /// Try to invoke <paramref name="func"/> and call <paramref name="onException"/> if <paramref name="func"/> threw an exception
265
+ /// </summary>
266
+ /// <param name="func">Function to be invoked. Do nothing if it is null.</param>
267
+ /// <param name="param">Parameter to pass to the <paramref name="func"/></param>
268
+ /// <param name="onException">Exception reporting function to be called in case of an exception. Can be null.</param>
269
+ /// <returns>Awaitable boolean value. True if <paramref name="func"/> succeeded and False otherwise.</returns>
270
+ private async Task < bool > SafeInvoke ( Func < string , Task > func , string param , Func < Exception , Task > onException )
271
+ {
272
+ try
273
+ {
274
+ await func ? . Invoke ( param ) ;
275
+ return true ;
276
+ }
277
+ catch ( Exception ex )
278
+ {
279
+ await onException ? . Invoke ( ex ) ;
280
+ return false ;
281
+ }
282
+ }
283
+
284
+ #endregion
285
+
286
+ #region IDisposable implementation
287
+
288
+ /// <summary>
289
+ /// Dispose of instance Disposable variables
290
+ /// </summary>
291
+ public void Dispose ( )
292
+ {
293
+ // Quick and dirty implementation to propagate Dispose call to HttpClient instance
294
+ ( ( IDisposable ) httpClient ) . Dispose ( ) ;
295
+ }
296
+
297
+ #endregion
298
+ }
299
+ }
0 commit comments