Skip to content

Commit 1922a11

Browse files
authored
Support Azure ML managed identity (#21851)
1 parent 6da3b75 commit 1922a11

File tree

5 files changed

+87
-3
lines changed

5 files changed

+87
-3
lines changed

sdk/azidentity/CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## 1.5.0-beta.2 (Unreleased)
44

55
### Features Added
6+
* `DefaultAzureCredential` and `ManagedIdentityCredential` support Azure ML managed identity
67

78
### Breaking Changes
89

sdk/azidentity/assets.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
"AssetsRepo": "Azure/azure-sdk-assets",
33
"AssetsRepoPrefixPath": "go",
44
"TagPrefix": "go/azidentity",
5-
"Tag": "go/azidentity_ae45facec3"
5+
"Tag": "go/azidentity_db4a26f583"
66
}

sdk/azidentity/live_test.go

+22
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,14 @@ const (
6060
recordingDirectory = "sdk/azidentity/testdata"
6161
azidentityRunManualTests = "AZIDENTITY_RUN_MANUAL_TESTS"
6262
fakeClientID = "fake-client-id"
63+
fakeMIEndpoint = "https://fake.local"
6364
fakeResourceID = "/fake/resource/ID"
6465
fakeTenantID = "fake-tenant"
6566
fakeUsername = "fake@user"
6667
fakeAdfsAuthority = "fake.adfs.local"
6768
fakeAdfsScope = "fake.adfs.local/fake-scope/.default"
6869
liveTestScope = "https://management.core.windows.net//.default"
70+
redacted = "redacted"
6971
)
7072

7173
var adfsLiveSP = struct {
@@ -159,6 +161,9 @@ func run(m *testing.M) int {
159161
strings.TrimPrefix(adfsScope, "https://"): fakeAdfsScope,
160162
strings.TrimPrefix(adfsAuthority, "https://"): fakeAdfsAuthority,
161163
}
164+
if id := os.Getenv(defaultIdentityClientID); id != "" {
165+
pathVars[id] = fakeClientID
166+
}
162167
for target, replacement := range pathVars {
163168
if target != "" {
164169
err := recording.AddURISanitizer(replacement, target, nil)
@@ -184,6 +189,23 @@ func run(m *testing.M) int {
184189
if err != nil {
185190
panic(err)
186191
}
192+
// some managed identity requests include a "secret" header. It isn't dangerous
193+
// to record the value, however it must be static for matching to work in playback
194+
err = recording.AddHeaderRegexSanitizer("secret", redacted, "", nil)
195+
if err != nil {
196+
panic(err)
197+
}
198+
if url, ok := os.LookupEnv(msiEndpoint); ok {
199+
err = recording.AddURISanitizer(fakeMIEndpoint, url, nil)
200+
if err == nil {
201+
if clientID, ok := os.LookupEnv(defaultIdentityClientID); ok {
202+
err = recording.AddURISanitizer(fakeClientID, clientID, nil)
203+
}
204+
}
205+
if err != nil {
206+
panic(err)
207+
}
208+
}
187209
// redact secrets returned by Microsoft Entra ID
188210
for _, key := range []string{"access_token", "device_code", "message", "refresh_token", "user_code"} {
189211
err = recording.AddBodyKeySanitizer("$."+key, "redacted", "", nil)

sdk/azidentity/managed_identity_client.go

+35-2
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@ import (
2828

2929
const (
3030
arcIMDSEndpoint = "IMDS_ENDPOINT"
31+
defaultIdentityClientID = "DEFAULT_IDENTITY_CLIENT_ID"
3132
identityEndpoint = "IDENTITY_ENDPOINT"
3233
identityHeader = "IDENTITY_HEADER"
3334
identityServerThumbprint = "IDENTITY_SERVER_THUMBPRINT"
3435
headerMetadata = "Metadata"
3536
imdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
3637
msiEndpoint = "MSI_ENDPOINT"
38+
msiSecret = "MSI_SECRET"
3739
imdsAPIVersion = "2018-02-01"
3840
azureArcAPIVersion = "2019-08-15"
3941
serviceFabricAPIVersion = "2019-07-01-preview"
@@ -47,6 +49,7 @@ type msiType int
4749
const (
4850
msiTypeAppService msiType = iota
4951
msiTypeAzureArc
52+
msiTypeAzureML
5053
msiTypeCloudShell
5154
msiTypeIMDS
5255
msiTypeServiceFabric
@@ -135,9 +138,14 @@ func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) (*manag
135138
c.msiType = msiTypeAzureArc
136139
}
137140
} else if endpoint, ok := os.LookupEnv(msiEndpoint); ok {
138-
env = "Cloud Shell"
139141
c.endpoint = endpoint
140-
c.msiType = msiTypeCloudShell
142+
if _, ok := os.LookupEnv(msiSecret); ok {
143+
env = "Azure ML"
144+
c.msiType = msiTypeAzureML
145+
} else {
146+
env = "Cloud Shell"
147+
c.msiType = msiTypeCloudShell
148+
}
141149
} else {
142150
setIMDSRetryOptionDefaults(&cp.Retry)
143151
}
@@ -247,6 +255,8 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, id Manage
247255
return nil, newAuthenticationFailedError(credNameManagedIdentity, msg, nil, err)
248256
}
249257
return c.createAzureArcAuthRequest(ctx, id, scopes, key)
258+
case msiTypeAzureML:
259+
return c.createAzureMLAuthRequest(ctx, id, scopes)
250260
case msiTypeServiceFabric:
251261
return c.createServiceFabricAuthRequest(ctx, id, scopes)
252262
case msiTypeCloudShell:
@@ -296,6 +306,29 @@ func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context,
296306
return request, nil
297307
}
298308

309+
func (c *managedIdentityClient) createAzureMLAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
310+
request, err := runtime.NewRequest(ctx, http.MethodGet, c.endpoint)
311+
if err != nil {
312+
return nil, err
313+
}
314+
request.Raw().Header.Set("secret", os.Getenv(msiSecret))
315+
q := request.Raw().URL.Query()
316+
q.Add("api-version", "2017-09-01")
317+
q.Add("resource", strings.Join(scopes, " "))
318+
q.Add("clientid", os.Getenv(defaultIdentityClientID))
319+
if id != nil {
320+
if id.idKind() == miResourceID {
321+
log.Write(EventAuthentication, "WARNING: Azure ML doesn't support specifying a managed identity by resource ID")
322+
q.Set("clientid", "")
323+
q.Set(qpResID, id.String())
324+
} else {
325+
q.Set("clientid", id.String())
326+
}
327+
}
328+
request.Raw().URL.RawQuery = q.Encode()
329+
return request, nil
330+
}
331+
299332
func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
300333
request, err := runtime.NewRequest(ctx, http.MethodGet, c.endpoint)
301334
if err != nil {

sdk/azidentity/managed_identity_credential_test.go

+28
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,34 @@ func TestManagedIdentityCredential_AzureArcErrors(t *testing.T) {
156156
})
157157
}
158158

159+
func TestManagedIdentityCredential_AzureMLLive(t *testing.T) {
160+
switch recording.GetRecordMode() {
161+
case recording.LiveMode:
162+
t.Skip("this test doesn't run in live mode because it can't pass in CI")
163+
case recording.PlaybackMode:
164+
t.Setenv(defaultIdentityClientID, fakeClientID)
165+
t.Setenv(msiEndpoint, fakeMIEndpoint)
166+
t.Setenv(msiSecret, redacted)
167+
case recording.RecordingMode:
168+
missing := []string{}
169+
for _, v := range []string{defaultIdentityClientID, msiEndpoint, msiSecret} {
170+
if len(os.Getenv(v)) == 0 {
171+
missing = append(missing, v)
172+
}
173+
}
174+
if len(missing) > 0 {
175+
t.Skip("no value for " + strings.Join(missing, ", "))
176+
}
177+
}
178+
opts, stop := initRecording(t)
179+
defer stop()
180+
cred, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ClientOptions: opts})
181+
if err != nil {
182+
t.Fatal(err)
183+
}
184+
testGetTokenSuccess(t, cred)
185+
}
186+
159187
func TestManagedIdentityCredential_CloudShell(t *testing.T) {
160188
validateReq := func(req *http.Request) *http.Response {
161189
err := req.ParseForm()

0 commit comments

Comments
 (0)