@@ -28,12 +28,14 @@ import (
28
28
29
29
const (
30
30
arcIMDSEndpoint = "IMDS_ENDPOINT"
31
+ defaultIdentityClientID = "DEFAULT_IDENTITY_CLIENT_ID"
31
32
identityEndpoint = "IDENTITY_ENDPOINT"
32
33
identityHeader = "IDENTITY_HEADER"
33
34
identityServerThumbprint = "IDENTITY_SERVER_THUMBPRINT"
34
35
headerMetadata = "Metadata"
35
36
imdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
36
37
msiEndpoint = "MSI_ENDPOINT"
38
+ msiSecret = "MSI_SECRET"
37
39
imdsAPIVersion = "2018-02-01"
38
40
azureArcAPIVersion = "2019-08-15"
39
41
serviceFabricAPIVersion = "2019-07-01-preview"
@@ -47,6 +49,7 @@ type msiType int
47
49
const (
48
50
msiTypeAppService msiType = iota
49
51
msiTypeAzureArc
52
+ msiTypeAzureML
50
53
msiTypeCloudShell
51
54
msiTypeIMDS
52
55
msiTypeServiceFabric
@@ -135,9 +138,14 @@ func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) (*manag
135
138
c .msiType = msiTypeAzureArc
136
139
}
137
140
} else if endpoint , ok := os .LookupEnv (msiEndpoint ); ok {
138
- env = "Cloud Shell"
139
141
c .endpoint = endpoint
140
- c .msiType = msiTypeCloudShell
142
+ if _ , ok := os .LookupEnv (msiSecret ); ok {
143
+ env = "Azure ML"
144
+ c .msiType = msiTypeAzureML
145
+ } else {
146
+ env = "Cloud Shell"
147
+ c .msiType = msiTypeCloudShell
148
+ }
141
149
} else {
142
150
setIMDSRetryOptionDefaults (& cp .Retry )
143
151
}
@@ -247,6 +255,8 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, id Manage
247
255
return nil , newAuthenticationFailedError (credNameManagedIdentity , msg , nil , err )
248
256
}
249
257
return c .createAzureArcAuthRequest (ctx , id , scopes , key )
258
+ case msiTypeAzureML :
259
+ return c .createAzureMLAuthRequest (ctx , id , scopes )
250
260
case msiTypeServiceFabric :
251
261
return c .createServiceFabricAuthRequest (ctx , id , scopes )
252
262
case msiTypeCloudShell :
@@ -296,6 +306,29 @@ func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context,
296
306
return request , nil
297
307
}
298
308
309
+ func (c * managedIdentityClient ) createAzureMLAuthRequest (ctx context.Context , id ManagedIDKind , scopes []string ) (* policy.Request , error ) {
310
+ request , err := runtime .NewRequest (ctx , http .MethodGet , c .endpoint )
311
+ if err != nil {
312
+ return nil , err
313
+ }
314
+ request .Raw ().Header .Set ("secret" , os .Getenv (msiSecret ))
315
+ q := request .Raw ().URL .Query ()
316
+ q .Add ("api-version" , "2017-09-01" )
317
+ q .Add ("resource" , strings .Join (scopes , " " ))
318
+ q .Add ("clientid" , os .Getenv (defaultIdentityClientID ))
319
+ if id != nil {
320
+ if id .idKind () == miResourceID {
321
+ log .Write (EventAuthentication , "WARNING: Azure ML doesn't support specifying a managed identity by resource ID" )
322
+ q .Set ("clientid" , "" )
323
+ q .Set (qpResID , id .String ())
324
+ } else {
325
+ q .Set ("clientid" , id .String ())
326
+ }
327
+ }
328
+ request .Raw ().URL .RawQuery = q .Encode ()
329
+ return request , nil
330
+ }
331
+
299
332
func (c * managedIdentityClient ) createServiceFabricAuthRequest (ctx context.Context , id ManagedIDKind , scopes []string ) (* policy.Request , error ) {
300
333
request , err := runtime .NewRequest (ctx , http .MethodGet , c .endpoint )
301
334
if err != nil {
0 commit comments