# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import absolute_import, annotations, print_function
import json
import logging
import os
import re
import sys
import time
import typing
import warnings
import uuid
import datetime
from copy import deepcopy
from typing import List, Dict, Any, Sequence, Optional
import boto3
import botocore
import botocore.config
from botocore.exceptions import ClientError
import six
from sagemaker.utils import instance_supports_kms, create_paginator_config
import sagemaker.logs
from sagemaker import vpc_utils, s3_utils
from sagemaker._studio import _append_project_tags
from sagemaker.config import load_sagemaker_config, validate_sagemaker_config
from sagemaker.config import (
KEY,
TRAINING_JOB,
TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
TRAINING_JOB_ROLE_ARN_PATH,
TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
TRAINING_JOB_ENVIRONMENT_PATH,
TRAINING_JOB_VPC_CONFIG_PATH,
TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH,
TRAINING_JOB_RESOURCE_CONFIG_PATH,
TRAINING_JOB_PROFILE_CONFIG_PATH,
PROCESSING_JOB_INPUTS_PATH,
PROCESSING_JOB,
PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
PROCESSING_JOB_ENVIRONMENT_PATH,
PROCESSING_JOB_ROLE_ARN_PATH,
PROCESSING_JOB_NETWORK_CONFIG_PATH,
PROCESSING_OUTPUT_CONFIG_PATH,
PROCESSING_JOB_PROCESSING_RESOURCES_PATH,
MONITORING_JOB_ENVIRONMENT_PATH,
MONITORING_JOB_ROLE_ARN_PATH,
MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH,
MONITORING_JOB_NETWORK_CONFIG_PATH,
MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH,
MONITORING_SCHEDULE,
MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH,
AUTO_ML_ROLE_ARN_PATH,
AUTO_ML_V2_ROLE_ARN_PATH,
AUTO_ML_OUTPUT_CONFIG_PATH,
AUTO_ML_V2_OUTPUT_CONFIG_PATH,
AUTO_ML_JOB_CONFIG_PATH,
AUTO_ML_JOB,
AUTO_ML_JOB_V2,
COMPILATION_JOB_ROLE_ARN_PATH,
COMPILATION_JOB_OUTPUT_CONFIG_PATH,
COMPILATION_JOB_VPC_CONFIG_PATH,
COMPILATION_JOB,
EDGE_PACKAGING_ROLE_ARN_PATH,
EDGE_PACKAGING_OUTPUT_CONFIG_PATH,
EDGE_PACKAGING_RESOURCE_KEY_PATH,
EDGE_PACKAGING_JOB,
TRANSFORM_JOB,
TRANSFORM_JOB_ENVIRONMENT_PATH,
TRANSFORM_JOB_KMS_KEY_ID_PATH,
TRANSFORM_OUTPUT_KMS_KEY_ID_PATH,
VOLUME_KMS_KEY_ID,
TRANSFORM_JOB_VOLUME_KMS_KEY_ID_PATH,
MODEL,
MODEL_CONTAINERS_PATH,
MODEL_EXECUTION_ROLE_ARN_PATH,
MODEL_ENABLE_NETWORK_ISOLATION_PATH,
MODEL_PRIMARY_CONTAINER_PATH,
MODEL_PRIMARY_CONTAINER_ENVIRONMENT_PATH,
MODEL_VPC_CONFIG_PATH,
MODEL_PACKAGE_VALIDATION_ROLE_PATH,
VALIDATION_ROLE,
VALIDATION_PROFILES,
MODEL_PACKAGE_INFERENCE_SPECIFICATION_CONTAINERS_PATH,
MODEL_PACKAGE_VALIDATION_PROFILES_PATH,
ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH,
KMS_KEY_ID,
ENDPOINT_CONFIG_KMS_KEY_ID_PATH,
ENDPOINT_CONFIG,
ENDPOINT_CONFIG_DATA_CAPTURE_PATH,
ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
ENDPOINT_CONFIG_VPC_CONFIG_PATH,
ENDPOINT_CONFIG_ENABLE_NETWORK_ISOLATION_PATH,
ENDPOINT_CONFIG_EXECUTION_ROLE_ARN_PATH,
ENDPOINT,
INFERENCE_COMPONENT,
SAGEMAKER,
FEATURE_GROUP,
TAGS,
FEATURE_GROUP_ROLE_ARN_PATH,
FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH,
FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH,
SESSION_DEFAULT_S3_BUCKET_PATH,
SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
)
from sagemaker.config.config_utils import _log_sagemaker_config_merge
from sagemaker.deprecations import deprecated_class
from sagemaker.enums import EndpointType
from sagemaker.inputs import ShuffleConfig, TrainingInput, BatchDataCaptureConfig
from sagemaker.user_agent import get_user_agent_extra_suffix
from sagemaker.utils import (
name_from_image,
secondary_training_status_changed,
secondary_training_status_message,
sts_regional_endpoint,
retries,
resolve_value_from_config,
get_sagemaker_config_value,
resolve_class_attribute_from_config,
resolve_nested_dict_value_from_config,
update_nested_dictionary_with_values_from_config,
update_list_of_dicts_with_values_from_config,
format_tags,
Tags,
TagsDict,
)
from sagemaker import exceptions
from sagemaker.session_settings import SessionSettings
from sagemaker.utils import can_model_package_source_uri_autopopulate
# Setting LOGGER for backward compatibility, in case users import it...
logger = LOGGER = logging.getLogger("sagemaker")
NOTEBOOK_METADATA_FILE = "/opt/ml/metadata/resource-metadata.json"
MODEL_MONITOR_ONE_TIME_SCHEDULE = "NOW"
_STATUS_CODE_TABLE = {
"COMPLETED": "Completed",
"INPROGRESS": "InProgress",
"IN_PROGRESS": "InProgress",
"FAILED": "Failed",
"STOPPED": "Stopped",
"STOPPING": "Stopping",
"STARTING": "Starting",
"PENDING": "Pending",
}
EP_LOGGER_POLL = 10
DEFAULT_EP_POLL = 30
class LogState(object):
"""Placeholder docstring"""
STARTING = 1
WAIT_IN_PROGRESS = 2
TAILING = 3
JOB_COMPLETE = 4
COMPLETE = 5
class Session(object): # pylint: disable=too-many-public-methods
"""Manage interactions with the Amazon SageMaker APIs and any other AWS services needed.
This class provides convenient methods for manipulating entities and resources that Amazon
SageMaker uses, such as training jobs, endpoints, and input datasets in S3.
AWS service calls are delegated to an underlying Boto3 session, which by default
is initialized using the AWS configuration chain. When you make an Amazon SageMaker API call
that accesses an S3 bucket location and one is not specified, the ``Session`` creates a default
bucket based on a naming convention which includes the current AWS account ID.
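Example, a minimal sketch of constructing a ``Session`` (the region and bucket
name below are illustrative):
.. code:: python
    import boto3
    import sagemaker

    boto_sess = boto3.Session(region_name="us-west-2")
    session = sagemaker.Session(
        boto_session=boto_sess,
        default_bucket="my-custom-bucket",  # hypothetical bucket name
    )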
"""
def __init__(
self,
boto_session=None,
sagemaker_client=None,
sagemaker_runtime_client=None,
sagemaker_featurestore_runtime_client=None,
default_bucket=None,
settings=None,
sagemaker_metrics_client=None,
sagemaker_config: dict = None,
default_bucket_prefix: str = None,
):
"""Initialize a SageMaker ``Session``.
Args:
boto_session (boto3.session.Session): The underlying Boto3 session which AWS service
calls are delegated to (default: None). If not provided, one is created with
default AWS configuration chain.
sagemaker_client (boto3.SageMaker.Client): Client which makes Amazon SageMaker service
calls other than ``InvokeEndpoint`` (default: None). Estimators created using this
``Session`` use this client. If not provided, one will be created using this
instance's ``boto_session``.
sagemaker_runtime_client (boto3.SageMakerRuntime.Client): Client which makes
``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created
using this ``Session`` use this client. If not provided, one will be created using
this instance's ``boto_session``.
sagemaker_featurestore_runtime_client (boto3.SageMakerFeatureStoreRuntime.Client):
Client which makes SageMaker FeatureStore record related calls to Amazon SageMaker
(default: None). If not provided, one will be created using
this instance's ``boto_session``.
default_bucket (str): The default Amazon S3 bucket to be used by this session.
This will be created the next time an Amazon S3 bucket is needed (by calling
:func:`default_bucket`).
If not provided, it will be fetched from the sagemaker_config. If not configured
there either, a default bucket will be created based on the following format:
"sagemaker-{region}-{aws-account-id}".
Example: "sagemaker-my-custom-bucket".
settings (sagemaker.session_settings.SessionSettings): Optional. Set of optional
parameters to apply to the session.
sagemaker_metrics_client (boto3.SageMakerMetrics.Client):
Client which makes SageMaker Metrics related calls to Amazon SageMaker
(default: None). If not provided, one will be created using
this instance's ``boto_session``.
sagemaker_config (dict): A dictionary containing default values for the
SageMaker Python SDK. (default: None). The dictionary must adhere to the schema
defined at `~sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA`.
If sagemaker_config is not provided and configuration files exist (at the default
paths for admins and users, or paths set through the environment variables
SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE),
a new dictionary will be generated from those configuration files. Alternatively,
this dictionary can be generated by calling
:func:`~sagemaker.config.load_sagemaker_config` and then be provided to the
Session.
default_bucket_prefix (str): The default prefix to use for S3 Object Keys. (default:
None). If provided and where applicable, it will be used by the SDK to construct
default S3 URIs, in the format:
`s3://{default_bucket}/{default_bucket_prefix}/<rest of object key>`
This parameter can also be specified via `{sagemaker_config}` instead of here. If
not provided here or within `{sagemaker_config}`, default S3 URIs will have the
format: `s3://{default_bucket}/<rest of object key>`
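Example, a minimal sketch of supplying a pre-loaded config and a bucket prefix
(the prefix below is illustrative):
.. code:: python
    import sagemaker
    from sagemaker.config import load_sagemaker_config

    config = load_sagemaker_config()  # reads admin/user config files, if any
    session = sagemaker.Session(
        sagemaker_config=config,
        default_bucket_prefix="my-team/experiments",  # hypothetical prefix
    )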
"""
# sagemaker_config is validated and initialized inside :func:`_initialize`,
# so if default_bucket is None and the sagemaker_config has a default S3 bucket configured,
# _default_bucket_name_override will be set again inside :func:`_initialize`.
self.endpoint_arn = None
self._default_bucket = None
self._default_bucket_name_override = default_bucket
# this may also be set again inside :func:`_initialize` if it is None
self.default_bucket_prefix = default_bucket_prefix
self._default_bucket_set_by_sdk = False
self.s3_resource = None
self.s3_client = None
self.resource_groups_client = None
self.resource_group_tagging_client = None
self._config = None
self.lambda_client = None
self.settings = settings if settings else SessionSettings()
self._initialize(
boto_session=boto_session,
sagemaker_client=sagemaker_client,
sagemaker_runtime_client=sagemaker_runtime_client,
sagemaker_featurestore_runtime_client=sagemaker_featurestore_runtime_client,
sagemaker_metrics_client=sagemaker_metrics_client,
sagemaker_config=sagemaker_config,
)
def _initialize(
self,
boto_session,
sagemaker_client,
sagemaker_runtime_client,
sagemaker_featurestore_runtime_client,
sagemaker_metrics_client,
sagemaker_config: dict = None,
):
"""Initialize this SageMaker Session.
Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client.
Sets the region_name.
"""
self.boto_session = boto_session or boto3.DEFAULT_SESSION or boto3.Session()
self._region_name = self.boto_session.region_name
if self._region_name is None:
raise ValueError(
"Must setup local AWS configuration with a region supported by SageMaker."
)
# Make use of user_agent_extra field of the botocore_config object
# to append SageMaker Python SDK specific user_agent suffix
# to the current User-Agent header value from boto3
# This config will also make sure that user_agent never fails to log the User-Agent string
# even if boto User-Agent header format is updated in the future
# Ref: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
botocore_config = botocore.config.Config(user_agent_extra=get_user_agent_extra_suffix())
# Create sagemaker_client with the botocore_config object
# This config is customized to append SageMaker Python SDK specific user_agent suffix
self.sagemaker_client = sagemaker_client or self.boto_session.client(
"sagemaker", config=botocore_config
)
if sagemaker_runtime_client is not None:
self.sagemaker_runtime_client = sagemaker_runtime_client
else:
config = botocore.config.Config(
read_timeout=80, user_agent_extra=get_user_agent_extra_suffix()
)
self.sagemaker_runtime_client = self.boto_session.client(
"runtime.sagemaker", config=config
)
if sagemaker_featurestore_runtime_client:
self.sagemaker_featurestore_runtime_client = sagemaker_featurestore_runtime_client
else:
self.sagemaker_featurestore_runtime_client = self.boto_session.client(
"sagemaker-featurestore-runtime"
)
if sagemaker_metrics_client:
self.sagemaker_metrics_client = sagemaker_metrics_client
else:
self.sagemaker_metrics_client = self.boto_session.client(
"sagemaker-metrics", config=botocore_config
)
self.s3_client = self.boto_session.client("s3", region_name=self.boto_region_name)
self.s3_resource = self.boto_session.resource("s3", region_name=self.boto_region_name)
self.local_mode = False
if sagemaker_config:
validate_sagemaker_config(sagemaker_config)
self.sagemaker_config = sagemaker_config
else:
# self.s3_resource might be None. If it is None, load_sagemaker_config will
# create a default S3 resource, but only if it needs to fetch from S3
self.sagemaker_config = load_sagemaker_config(s3_resource=self.s3_resource)
# after sagemaker_config initialization, update self._default_bucket_name_override if needed
self._default_bucket_name_override = resolve_value_from_config(
direct_input=self._default_bucket_name_override,
config_path=SESSION_DEFAULT_S3_BUCKET_PATH,
sagemaker_session=self,
)
# after sagemaker_config initialization, update self.default_bucket_prefix if needed
self.default_bucket_prefix = resolve_value_from_config(
direct_input=self.default_bucket_prefix,
config_path=SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
sagemaker_session=self,
)
@property
def config(self) -> Dict | None:
"""The config for the local mode, unused in a normal session"""
return self._config
@config.setter
def config(self, value: Dict | None):
"""The config for the local mode, unused in a normal session"""
self._config = value
@property
def boto_region_name(self):
"""Placeholder docstring"""
return self._region_name
def upload_data(self, path, bucket=None, key_prefix="data", callback=None, extra_args=None):
"""Upload local file or directory to S3.
If a single file is specified for upload, the resulting S3 object key is
``{key_prefix}/{filename}`` (filename does not include the local path, if any specified).
If a directory is specified for upload, the API uploads all content, recursively,
preserving relative structure of subdirectories. The resulting object key names are:
``{key_prefix}/{relative_subdirectory_path}/filename``.
Args:
path (str): Path (absolute or relative) of local file or directory to upload.
bucket (str): Name of the S3 Bucket to upload to (default: None). If not specified, the
default bucket of the ``Session`` is used (if default bucket does not exist, the
``Session`` creates it).
key_prefix (str): Optional S3 object key name prefix (default: 'data'). S3 uses the
prefix to create a directory structure for the bucket content that it displays in
the S3 console.
callback (callable): Optional callback for the underlying boto3 upload, invoked
periodically with the number of bytes transferred so far (default: None).
extra_args (dict): Optional extra arguments that may be passed to the upload operation.
Similar to ExtraArgs parameter in S3 upload_file function. Please refer to the
ExtraArgs parameter documentation here:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-uploading-files.html#the-extraargs-parameter
Returns:
str: The S3 URI of the uploaded file(s). If a file is specified in the path argument,
the URI format is: ``s3://{bucket name}/{key_prefix}/{original_file_name}``.
If a directory is specified in the path argument, the URI format is
``s3://{bucket name}/{key_prefix}``.
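Example, a minimal sketch for a single file (bucket and paths are illustrative):
.. code:: python
    import sagemaker

    session = sagemaker.Session()
    s3_uri = session.upload_data(
        path="data/train.csv",  # hypothetical local file
        bucket="my-bucket",
        key_prefix="datasets/train",
    )
    # s3_uri == "s3://my-bucket/datasets/train/train.csv"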
"""
bucket, key_prefix = s3_utils.determine_bucket_and_prefix(
bucket=bucket, key_prefix=key_prefix, sagemaker_session=self
)
# Generate a tuple for each file that we want to upload of the form (local_path, s3_key).
files = []
key_suffix = None
if os.path.isdir(path):
for dirpath, _, filenames in os.walk(path):
for name in filenames:
local_path = os.path.join(dirpath, name)
s3_relative_prefix = (
"" if path == dirpath else os.path.relpath(dirpath, start=path) + "/"
)
s3_key = "{}/{}{}".format(key_prefix, s3_relative_prefix, name)
files.append((local_path, s3_key))
else:
_, name = os.path.split(path)
s3_key = "{}/{}".format(key_prefix, name)
files.append((path, s3_key))
key_suffix = name
if self.s3_resource is None:
s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_resource
for local_path, s3_key in files:
s3.Object(bucket, s3_key).upload_file(
local_path, Callback=callback, ExtraArgs=extra_args
)
s3_uri = "s3://{}/{}".format(bucket, key_prefix)
# If a specific file was used as input (instead of a directory), we return the full S3 key
# of the uploaded object. This prevents unintentionally using other files under the same
# prefix during training.
if key_suffix:
s3_uri = "{}/{}".format(s3_uri, key_suffix)
return s3_uri
def upload_string_as_file_body(self, body, bucket, key, kms_key=None):
"""Upload a string as a file body.
Args:
body (str): String representing the body of the file.
bucket (str): Name of the S3 bucket to upload to.
key (str): S3 object key. This is the S3 path to the file.
kms_key (str): The KMS key to use for encrypting the file.
Returns:
str: The S3 URI of the uploaded file.
The URI format is: ``s3://{bucket name}/{key}``.
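Example, a minimal sketch, given an existing ``session`` (bucket and key are
illustrative):
.. code:: python
    s3_uri = session.upload_string_as_file_body(
        body='{"k": "v"}',
        bucket="my-bucket",
        key="configs/example.json",
    )
    # s3_uri == "s3://my-bucket/configs/example.json"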
"""
if self.s3_resource is None:
s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_resource
s3_object = s3.Object(bucket_name=bucket, key=key)
if kms_key is not None:
s3_object.put(Body=body, SSEKMSKeyId=kms_key, ServerSideEncryption="aws:kms")
else:
s3_object.put(Body=body)
s3_uri = "s3://{}/{}".format(bucket, key)
return s3_uri
def download_data(self, path, bucket, key_prefix="", extra_args=None):
"""Download file or directory from S3.
Args:
path (str): Local path where the file or directory should be downloaded to.
bucket (str): Name of the S3 Bucket to download from.
key_prefix (str): Optional S3 object key name prefix.
extra_args (dict): Optional extra arguments that may be passed to the
download operation. Please refer to the ExtraArgs parameter in the boto3
documentation here:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-example-download-file.html
Returns:
list[str]: List of local paths of downloaded files
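Example, a minimal sketch, given an existing ``session`` (bucket and prefix are
illustrative):
.. code:: python
    local_paths = session.download_data(
        path="./downloads",
        bucket="my-bucket",
        key_prefix="datasets/train",
    )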
"""
# Initialize the S3 client.
if self.s3_client is None:
s3 = self.boto_session.client("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_client
# Initialize the variables used to loop through the contents of the S3 bucket.
keys = []
directories = []
next_token = ""
base_parameters = {"Bucket": bucket, "Prefix": key_prefix}
# Loop through the contents of the bucket, 1,000 objects at a time. Gathering all keys into
# a "keys" list.
while next_token is not None:
request_parameters = base_parameters.copy()
if next_token != "":
request_parameters.update({"ContinuationToken": next_token})
response = s3.list_objects_v2(**request_parameters)
contents = response.get("Contents", None)
if not contents:
logger.info(
"Nothing to download from bucket: %s, key_prefix: %s.", bucket, key_prefix
)
return []
# For each object, save its key or directory.
for s3_object in contents:
key: str = s3_object.get("Key")
obj_size = s3_object.get("Size")
if key.endswith("/") and int(obj_size) == 0:
directories.append(os.path.join(path, key))
else:
keys.append(key)
next_token = response.get("NextContinuationToken")
# For each object key, create the directory on the local machine if needed, and then
# download the file.
downloaded_paths = []
for dir_path in directories:
os.makedirs(os.path.dirname(dir_path), exist_ok=True)
for key in keys:
tail_s3_uri_path = os.path.basename(key)
if not os.path.splitext(key_prefix)[1]:
tail_s3_uri_path = os.path.relpath(key, key_prefix)
destination_path = os.path.join(path, tail_s3_uri_path)
if not os.path.exists(os.path.dirname(destination_path)):
os.makedirs(os.path.dirname(destination_path), exist_ok=True)
s3.download_file(
Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args
)
downloaded_paths.append(destination_path)
return downloaded_paths
def read_s3_file(self, bucket, key_prefix):
"""Read a single file from S3.
Args:
bucket (str): Name of the S3 Bucket to download from.
key_prefix (str): S3 object key name prefix.
Returns:
str: The body of the S3 file as a string.
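Example, a minimal sketch, given an existing ``session`` (bucket and key are
illustrative):
.. code:: python
    body = session.read_s3_file(
        bucket="my-bucket", key_prefix="configs/example.json"
    )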
"""
if self.s3_client is None:
s3 = self.boto_session.client("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_client
# Explicitly passing a None kms_key to boto3 throws a validation error.
s3_object = s3.get_object(Bucket=bucket, Key=key_prefix)
return s3_object["Body"].read().decode("utf-8")
def list_s3_files(self, bucket, key_prefix):
"""Lists the S3 files given an S3 bucket and key.
Args:
bucket (str): Name of the S3 Bucket to download from.
key_prefix (str): S3 object key name prefix.
Returns:
[str]: The list of files at the S3 path.
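Example, a minimal sketch, given an existing ``session`` (bucket and prefix are
illustrative):
.. code:: python
    keys = session.list_s3_files(bucket="my-bucket", key_prefix="datasets/")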
"""
if self.s3_resource is None:
s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_resource
s3_bucket = s3.Bucket(name=bucket)
s3_objects = s3_bucket.objects.filter(Prefix=key_prefix).all()
return [s3_object.key for s3_object in s3_objects]
def default_bucket(self):
"""Return the name of the default bucket to use in relevant Amazon SageMaker interactions.
This function will create the S3 bucket if it does not exist.
Returns:
str: The name of the default bucket. If the name was not explicitly specified through
the Session or sagemaker_config, the bucket will take the form:
``sagemaker-{region}-{AWS account ID}``.
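Example, a minimal sketch (the region and account ID in the result are
illustrative):
.. code:: python
    import sagemaker

    session = sagemaker.Session()
    bucket = session.default_bucket()
    # e.g. "sagemaker-us-west-2-111122223333"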
"""
if self._default_bucket:
return self._default_bucket
region = self.boto_session.region_name
default_bucket = self._default_bucket_name_override
if not default_bucket:
default_bucket = generate_default_sagemaker_bucket_name(self.boto_session)
self._default_bucket_set_by_sdk = True
self._create_s3_bucket_if_it_does_not_exist(
bucket_name=default_bucket,
region=region,
)
self._default_bucket = default_bucket
return self._default_bucket
def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
"""Creates an S3 Bucket if it does not exist.
Also swallows a few common exceptions that indicate that the bucket already exists or
that it is being created.
Args:
bucket_name (str): Name of the S3 bucket to be created.
region (str): The region in which to create the bucket.
Raises:
botocore.exceptions.ClientError: If S3 throws an unexpected exception during bucket
creation.
If the exception is due to the bucket already existing or
already being created, no exception is raised.
"""
if self.s3_resource is None:
s3 = self.boto_session.resource("s3", region_name=region)
else:
s3 = self.s3_resource
bucket = s3.Bucket(name=bucket_name)
if bucket.creation_date is None:
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True)
elif self._default_bucket_set_by_sdk:
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False)
expected_bucket_owner_id = self.account_id()
self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id)
def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id):
"""Checks if the bucket belongs to a particular owner and throws a Client Error if it is not
Args:
bucket_name (str): Name of the S3 bucket
s3 (str): S3 object from boto session
expected_bucket_owner_id (str): Owner ID string
"""
try:
if self.default_bucket_prefix:
s3.meta.client.list_objects_v2(
Bucket=bucket_name,
Prefix=self.default_bucket_prefix,
ExpectedBucketOwner=expected_bucket_owner_id,
)
else:
s3.meta.client.head_bucket(
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
if error_code == "403" and message == "Forbidden":
LOGGER.error(
"Since default_bucket param was not set, SageMaker Python SDK tried to use "
"%s bucket. "
"This bucket cannot be configured to use as it is not owned by Account %s. "
"To unblock it's recommended to use custom default_bucket "
"parameter in sagemaker.Session",
bucket_name,
expected_bucket_owner_id,
)
raise
def general_bucket_check_if_user_has_permission(
self, bucket_name, s3, bucket, region, bucket_creation_date_none
):
"""Checks if the person running has the permissions to the bucket
If there is any other error that comes up with calling head bucket, it is raised up here
If there is no bucket , it will create one
Args:
bucket_name (str): Name of the S3 bucket
s3 (str): S3 object from boto session
region (str): The region in which to create the bucket.
bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not
"""
try:
if self.default_bucket_prefix:
s3.meta.client.list_objects_v2(
Bucket=bucket_name, Prefix=self.default_bucket_prefix
)
else:
s3.meta.client.head_bucket(Bucket=bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
# bucket does not exist or forbidden to access
if bucket_creation_date_none:
if error_code == "404" and message == "Not Found":
self.create_bucket_for_not_exist_error(bucket_name, region, s3)
elif error_code == "403" and message == "Forbidden":
LOGGER.error(
"Bucket %s exists, but access is forbidden. Please try again after "
"adding appropriate access.",
bucket.name,
)
raise
else:
raise
def create_bucket_for_not_exist_error(self, bucket_name, region, s3):
"""Creates the S3 bucket in the given region
Args:
bucket_name (str): Name of the S3 bucket
s3 (S3.ServiceResource): S3 resource from the boto session.
region (str): The region in which to create the bucket.
"""
# bucket does not exist, create one
try:
if region == "us-east-1":
# 'us-east-1' cannot be specified because it is the default region:
# https://github.com/boto/boto3/issues/125
s3.create_bucket(Bucket=bucket_name)
else:
s3.create_bucket(
Bucket=bucket_name,
CreateBucketConfiguration={"LocationConstraint": region},
)
logger.info("Created S3 bucket: %s", bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
if error_code == "OperationAborted" and "conflicting conditional operation" in message:
# If this bucket is already being concurrently created,
# we don't need to create it again.
pass
else:
raise
def _append_sagemaker_config_tags(self, tags: List[TagsDict], config_path_to_tags: str):
"""Appends tags specified in the sagemaker_config to the given list of tags.
To minimize the chance of duplicate tags being applied, this is intended to be used
immediately before calls to sagemaker_client, rather than during initialization of
classes like EstimatorBase.
Args:
tags: The list of tags to append to.
config_path_to_tags: The path to look up tags in the config.
Returns:
A list of tags.
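Example, a minimal sketch of the merge behavior, given an existing ``session``
whose config defines tags at the given path (all tag values are illustrative):
.. code:: python
    # Suppose the config at "SageMaker.TrainingJob.Tags" contains
    # [{"Key": "team", "Value": "ml"}, {"Key": "env", "Value": "dev"}].
    user_tags = [{"Key": "env", "Value": "prod"}]
    merged = session._append_sagemaker_config_tags(
        user_tags, "SageMaker.TrainingJob.Tags"
    )
    # merged == [{"Key": "env", "Value": "prod"}, {"Key": "team", "Value": "ml"}]
    # The user-provided "env" tag wins; only the config-only "team" tag is appended.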
"""
config_tags = get_sagemaker_config_value(self, config_path_to_tags)
if config_tags is None or len(config_tags) == 0:
return tags
all_tags = tags or []
for config_tag in config_tags:
config_tag_key = config_tag[KEY]
if not any(tag.get("Key", None) == config_tag_key for tag in all_tags):
# This check prevents new tags with duplicate keys from being added
# (to prevent API failure and/or overwriting of tags). If there is a conflict,
# the user-provided tag should take precedence over the config-provided tag.
# Note: this does not check user-provided tags for conflicts with other
# user-provided tags.
all_tags.append(config_tag)
_log_sagemaker_config_merge(
source_value=tags,
config_value=config_tags,
merged_source_and_config_value=all_tags,
config_key_path=config_path_to_tags,
)
return all_tags
def train( # noqa: C901
self,
input_mode,
input_config,
role=None,
job_name=None,
output_config=None,
resource_config=None,
vpc_config=None,
hyperparameters=None,
stop_condition=None,
tags=None,
metric_definitions=None,
enable_network_isolation=None,
image_uri=None,
training_image_config=None,
infra_check_config=None,
container_entry_point=None,
container_arguments=None,
algorithm_arn=None,
encrypt_inter_container_traffic=None,
use_spot_instances=False,
checkpoint_s3_uri=None,
checkpoint_local_path=None,
experiment_config=None,
debugger_rule_configs=None,
debugger_hook_config=None,
tensorboard_output_config=None,
enable_sagemaker_metrics=None,
profiler_rule_configs=None,
profiler_config=None,
environment: Optional[Dict[str, str]] = None,
retry_strategy=None,
remote_debug_config=None,
session_chaining_config=None,
):
"""Create an Amazon SageMaker training job.
Args:
input_mode (str): The input mode that the algorithm supports. Valid modes:
* 'File' - Amazon SageMaker copies the training dataset from the S3 location to
a directory in the Docker container.
* 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a
Unix-named pipe.
* 'FastFile' - Amazon SageMaker streams data from S3 on demand instead of
downloading the entire dataset before training begins.
input_config (list): A list of Channel objects. Each channel is a named input source.
Please refer to the format details described:
https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training
jobs and APIs that create Amazon SageMaker endpoints use this role to access
training data and model artifacts. You must grant sufficient permissions to this
role.
job_name (str): Name of the training job being created.
output_config (dict): The S3 URI where you want to store the training results and
optional KMS key ID.
resource_config (dict): Contains values for ResourceConfig:
* instance_count (int): Number of EC2 instances to use for training.
The key in resource_config is 'InstanceCount'.
* instance_type (str): Type of EC2 instance to use for training, for example,
'ml.c4.xlarge'. The key in resource_config is 'InstanceType'.
vpc_config (dict): Contains values for VpcConfig:
* subnets (list[str]): List of subnet ids.
The key in vpc_config is 'Subnets'.
* security_group_ids (list[str]): List of security group ids.
The key in vpc_config is 'SecurityGroupIds'.
hyperparameters (dict): Hyperparameters for model training. The hyperparameters are
made accessible as a dict[str, str] to the training code on SageMaker. For
convenience, this accepts other types for keys and values, but ``str()`` will be
called to convert them before training.
stop_condition (dict): Defines when training shall finish. Contains entries that can
be understood by the service like ``MaxRuntimeInSeconds``.
tags (Optional[Tags]): Tags for labeling a training job. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
metric_definitions (list[dict]): A list of dictionaries that defines the metric(s)
used to evaluate the training jobs. Each dictionary contains two keys: 'Name' for
the name of the metric, and 'Regex' for the regular expression used to extract the
metric from the logs.
enable_network_isolation (bool): Whether to request that the training job run
with network isolation.
image_uri (str): Docker image containing training code.
training_image_config(dict): Training image configuration.
Optionally, the dict can contain 'TrainingRepositoryAccessMode' and
'TrainingRepositoryCredentialsProviderArn' (under 'TrainingRepositoryAuthConfig').
For example,
.. code:: python
training_image_config = {
"TrainingRepositoryAccessMode": "Vpc",
"TrainingRepositoryAuthConfig": {
"TrainingRepositoryCredentialsProviderArn":
"arn:aws:lambda:us-west-2:1234567890:function:test"
},
}
If TrainingRepositoryAccessMode is set to Vpc, the training image is accessed
through a private Docker registry in the customer's VPC. If it's set to Platform
or None, the training image is accessed through ECR.
If TrainingRepositoryCredentialsProviderArn is provided, the credentials used to
authenticate to the private Docker registry will be retrieved from this AWS Lambda
function (default: ``None``). When it's set to None, SageMaker will not
authenticate before pulling the image from the private Docker registry.
container_entry_point (List[str]): Optional. The entrypoint script for a Docker
container used to run a training job. This script takes precedence over
the default train processing instructions.
container_arguments (List[str]): Optional. The arguments for a container used to run
a training job.
algorithm_arn (str): Algorithm Arn from Marketplace.
encrypt_inter_container_traffic (bool): Specifies whether traffic between training
containers is encrypted for the training job (default: ``False``).
use_spot_instances (bool): whether to use spot instances for training.
checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints
that the algorithm persists (if any) during training. (default:
``None``).
checkpoint_local_path (str): The local path that the algorithm
writes its checkpoints to. SageMaker will persist all files
under this path to `checkpoint_s3_uri` continually during
training. On job startup the reverse happens - data from the
s3 location is downloaded to this path before the algorithm is
started. If the path is unset then SageMaker assumes the
checkpoints will be provided under `/opt/ml/checkpoints/`.
(default: ``None``).
experiment_config (dict[str, str]): Experiment management configuration.
Optionally, the dict can contain four keys:
'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'.
The behavior of setting these keys is as follows:
* If `ExperimentName` is supplied but `TrialName` is not, a Trial will be
automatically created and the job's Trial Component associated with the Trial.
* If `TrialName` is supplied and the Trial already exists, the job's Trial Component
will be associated with the Trial.
* If neither `ExperimentName` nor `TrialName` is supplied, the Trial Component
will remain unassociated.
* `TrialComponentDisplayName` is used for display in Studio.
* `RunName` is used to record an experiment run.
enable_sagemaker_metrics (bool): Enable SageMaker Metrics Time
Series. For more information, see:
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
(default: ``None``).
profiler_rule_configs (list[dict]): A list of profiler rule
configurations.
profiler_config (dict): Configuration for how profiling information is emitted
with SageMaker Profiler. (default: ``None``).
remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``)
The dict can contain 'EnableRemoteDebug'(bool).
For example,
.. code:: python
remote_debug_config = {
"EnableRemoteDebug": True,
}
session_chaining_config(dict): Configuration for SessionChaining. (default: ``None``)
The dict can contain 'EnableSessionTagChaining'(bool).
For example,
.. code:: python
session_chaining_config = {
"EnableSessionTagChaining": True,
}
environment (dict[str, str]): Environment variables to be set for
use during the training job (default: ``None``).
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
* max_retry_attempts (int): Number of times a job should be retried.
The key in RetryStrategy is 'MaxRetryAttempts'.
infra_check_config(dict): Infra check configuration.
Optionally, the dict can contain 'EnableInfraCheck'(bool).
For example,
.. code:: python
infra_check_config = {
"EnableInfraCheck": True,
}
Returns:
str: ARN of the training job, if it is created.
Raises:
- botocore.exceptions.ClientError: If SageMaker throws an exception while creating
a training job.
- ValueError: If both image_uri and algorithm_arn are provided, or if neither is provided.
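Example, a minimal sketch, given an existing ``session`` (all names, URIs, and
ARNs are illustrative):
.. code:: python
    arn = session.train(
        input_mode="File",
        input_config=[
            {
                "ChannelName": "training",
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": "s3://my-bucket/train/",
                        "S3DataDistributionType": "FullyReplicated",
                    }
                },
            }
        ],
        role="arn:aws:iam::111122223333:role/SageMakerRole",
        job_name="example-training-job",
        output_config={"S3OutputPath": "s3://my-bucket/output/"},
        resource_config={
            "InstanceCount": 1,
            "InstanceType": "ml.m5.xlarge",
            "VolumeSizeInGB": 30,
        },
        stop_condition={"MaxRuntimeInSeconds": 3600},
        image_uri="111122223333.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",
    )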
"""
tags = _append_project_tags(format_tags(tags))
tags = self._append_sagemaker_config_tags(
tags, "{}.{}.{}".format(SAGEMAKER, TRAINING_JOB, TAGS)
)
_encrypt_inter_container_traffic = resolve_value_from_config(
direct_input=encrypt_inter_container_traffic,
config_path=TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
default_value=False,
sagemaker_session=self,
)
role = resolve_value_from_config(role, TRAINING_JOB_ROLE_ARN_PATH, sagemaker_session=self)
enable_network_isolation = resolve_value_from_config(
direct_input=enable_network_isolation,
config_path=TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
default_value=False,
sagemaker_session=self,
)
inferred_vpc_config = update_nested_dictionary_with_values_from_config(
vpc_config, TRAINING_JOB_VPC_CONFIG_PATH, sagemaker_session=self
)
inferred_output_config = update_nested_dictionary_with_values_from_config(
output_config, TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH, sagemaker_session=self
)
customer_supplied_kms_key = "VolumeKmsKeyId" in resource_config
inferred_resource_config = update_nested_dictionary_with_values_from_config(
resource_config, TRAINING_JOB_RESOURCE_CONFIG_PATH, sagemaker_session=self
)
inferred_profiler_config = update_nested_dictionary_with_values_from_config(
profiler_config, TRAINING_JOB_PROFILE_CONFIG_PATH, sagemaker_session=self
)