Skip to content

Commit b26cf21

Browse files
authored
Merge branch 'master' into issue-4856
2 parents f5eeeac + 0a86e60 commit b26cf21

File tree

9 files changed

+103
-50
lines changed

9 files changed

+103
-50
lines changed

doc/amazon_sagemaker_model_building_pipeline.rst

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -408,21 +408,39 @@ Example:
408408
step_args=step_args_register_model,
409409
)
410410
411-
CreateModelStep
411+
ModelStep
412412
````````````````
413413
Referable Property List:
414414

415415
- `DescribeModel`_
416416

417+
OR
418+
- `DescribeModelPackage`_
419+
417420
.. _DescribeModel: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribeModel.html#API_DescribeModel_ResponseSyntax
421+
.. _DescribeModelPackage: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribeModelPackage.html#API_DescribeModelPackage_ResponseSyntax
418422

419423
Example:
420424

425+
For model creation usecase:
426+
421427
.. code-block:: python
422428
423-
step_model = CreateModelStep(...)
424-
model_data = step_model.PrimaryContainer.ModelDataUrl
429+
create_model_step = ModelStep(
430+
name="MyModelCreationStep",
431+
step_args = model.create(...)
432+
)
433+
model_data = create_model_step.properties.PrimaryContainer.ModelDataUrl
434+
435+
For model registration usercase:
436+
437+
.. code-block:: python
425438
439+
register_model_step = ModelStep(
440+
name="MyModelRegistrationStep",
441+
step_args=model.register(...)
442+
)
443+
approval_status=register_model_step.properties.ModelApprovalStatus
426444
427445
LambdaStep
428446
`````````````

src/sagemaker/image_uri_config/instance_gpu_info.json

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
2424
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
2525
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
26-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
26+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
2727
},
2828
"ap-east-1": {
2929
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -49,7 +49,7 @@
4949
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
5050
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
5151
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
52-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
52+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
5353
},
5454
"ap-northeast-1": {
5555
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -75,7 +75,7 @@
7575
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
7676
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
7777
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
78-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
78+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
7979
},
8080
"ap-northeast-2": {
8181
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -101,7 +101,7 @@
101101
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
102102
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
103103
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
104-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
104+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
105105
},
106106
"ap-northeast-3": {
107107
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -127,7 +127,7 @@
127127
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
128128
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
129129
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
130-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
130+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
131131
},
132132
"ap-south-1": {
133133
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -153,7 +153,7 @@
153153
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
154154
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
155155
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
156-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
156+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
157157
},
158158
"ap-southeast-1": {
159159
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -179,7 +179,7 @@
179179
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
180180
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
181181
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
182-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
182+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
183183
},
184184
"ap-southeast-2": {
185185
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -205,7 +205,7 @@
205205
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
206206
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
207207
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
208-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
208+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
209209
},
210210
"ap-southeast-3": {
211211
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -231,7 +231,7 @@
231231
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
232232
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
233233
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
234-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
234+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
235235
},
236236
"ca-central-1": {
237237
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -257,7 +257,7 @@
257257
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
258258
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
259259
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
260-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
260+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
261261
},
262262
"cn-north-1": {
263263
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -283,7 +283,7 @@
283283
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
284284
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
285285
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
286-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
286+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
287287
},
288288
"cn-northwest-1": {
289289
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -309,7 +309,7 @@
309309
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
310310
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
311311
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
312-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
312+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
313313
},
314314
"eu-central-1": {
315315
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -335,7 +335,7 @@
335335
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
336336
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
337337
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
338-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
338+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
339339
},
340340
"eu-central-2": {
341341
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -361,7 +361,7 @@
361361
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
362362
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
363363
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
364-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
364+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
365365
},
366366
"eu-north-1": {
367367
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -387,7 +387,7 @@
387387
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
388388
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
389389
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
390-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
390+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
391391
},
392392
"eu-south-1": {
393393
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -413,7 +413,7 @@
413413
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
414414
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
415415
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
416-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
416+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
417417
},
418418
"eu-south-2": {
419419
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -439,7 +439,7 @@
439439
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
440440
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
441441
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
442-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
442+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
443443
},
444444
"eu-west-1": {
445445
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -465,7 +465,7 @@
465465
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
466466
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
467467
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
468-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
468+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
469469
},
470470
"eu-west-2": {
471471
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -491,7 +491,7 @@
491491
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
492492
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
493493
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
494-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
494+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
495495
},
496496
"eu-west-3": {
497497
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -517,7 +517,7 @@
517517
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
518518
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
519519
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
520-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
520+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
521521
},
522522
"il-central-1": {
523523
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -543,7 +543,7 @@
543543
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
544544
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
545545
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
546-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
546+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
547547
},
548548
"me-central-1": {
549549
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -569,7 +569,7 @@
569569
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
570570
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
571571
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
572-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
572+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
573573
},
574574
"me-south-1": {
575575
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -595,7 +595,7 @@
595595
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
596596
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
597597
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
598-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
598+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
599599
},
600600
"sa-east-1": {
601601
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -621,7 +621,7 @@
621621
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
622622
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
623623
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
624-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
624+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
625625
},
626626
"us-east-1": {
627627
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -647,7 +647,7 @@
647647
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
648648
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
649649
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
650-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
650+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
651651
},
652652
"us-east-2": {
653653
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -673,7 +673,7 @@
673673
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
674674
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
675675
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
676-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
676+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
677677
},
678678
"us-gov-east-1": {
679679
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -699,7 +699,7 @@
699699
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
700700
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
701701
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
702-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
702+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
703703
},
704704
"us-gov-west-1": {
705705
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -725,7 +725,7 @@
725725
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
726726
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
727727
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
728-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
728+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
729729
},
730730
"us-west-1": {
731731
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -751,7 +751,7 @@
751751
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
752752
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
753753
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
754-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
754+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
755755
},
756756
"us-west-2": {
757757
"ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360},
@@ -777,6 +777,6 @@
777777
"ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576},
778778
"ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
779779
"ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304},
780-
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608}
780+
"ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104}
781781
}
782782
}

src/sagemaker/jumpstart/accessors.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sagemaker.jumpstart.hub.utils import (
2626
construct_hub_model_arn_from_inputs,
2727
construct_hub_model_reference_arn_from_inputs,
28+
generate_hub_arn_for_init_kwargs,
2829
)
2930
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
3031
from sagemaker.session import Session
@@ -291,6 +292,10 @@ def get_model_specs(
291292
# Users only input model id, not contentType, so first try to describe with ModelReference, then with Model
292293
if hub_arn:
293294
try:
295+
hub_arn = generate_hub_arn_for_init_kwargs(
296+
hub_name=hub_arn, region=region, session=sagemaker_session
297+
)
298+
294299
hub_model_arn = construct_hub_model_reference_arn_from_inputs(
295300
hub_arn=hub_arn, model_name=model_id, version=version
296301
)

src/sagemaker/jumpstart/estimator.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
validate_model_id_and_get_type,
4242
resolve_model_sagemaker_config_field,
4343
verify_model_region_and_return_specs,
44-
remove_env_var_from_estimator_kwargs_if_accept_eula_present,
44+
remove_env_var_from_estimator_kwargs_if_model_access_config_present,
4545
get_model_access_config,
4646
get_hub_access_config,
4747
)
@@ -616,6 +616,7 @@ def _validate_model_id_and_get_type_hook():
616616
self.tolerate_vulnerable_model = estimator_init_kwargs.tolerate_vulnerable_model
617617
self.instance_count = estimator_init_kwargs.instance_count
618618
self.region = estimator_init_kwargs.region
619+
self.environment = estimator_init_kwargs.environment
619620
self.orig_predictor_cls = None
620621
self.role = estimator_init_kwargs.role
621622
self.sagemaker_session = estimator_init_kwargs.sagemaker_session
@@ -693,7 +694,7 @@ def fit(
693694
accept the end-user license agreement (EULA) that some
694695
models require. (Default: None).
695696
"""
696-
self.model_access_config = get_model_access_config(accept_eula)
697+
self.model_access_config = get_model_access_config(accept_eula, self.environment)
697698
self.hub_access_config = get_hub_access_config(
698699
hub_content_arn=self.init_kwargs.get("model_reference_arn", None)
699700
)
@@ -713,7 +714,9 @@ def fit(
713714
config_name=self.config_name,
714715
hub_access_config=self.hub_access_config,
715716
)
716-
remove_env_var_from_estimator_kwargs_if_accept_eula_present(self.init_kwargs, accept_eula)
717+
remove_env_var_from_estimator_kwargs_if_model_access_config_present(
718+
self.init_kwargs, self.model_access_config
719+
)
717720

718721
return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict())
719722

0 commit comments

Comments
 (0)