Skip to content

Commit 566502f

Browse files
akrishna1995 (Ashwin Krishna) and Ashwin Krishna authored
feat: SDK defaults add disable profiler to createTrainingJob (#3936)
* feat: add disable profiler to createTrainingJob. Added profilerConfig to the training job schema for sagemaker_config; the disableProfiler attribute is injected via config in session and estimator.
* fix: Update comments based on PR
* fix: spaces, linting, based on black-check, flake8
* fix: comment on PR; add support for update_training_job; doc change; added unit tests
---------
Co-authored-by: Ashwin Krishna <[email protected]>
1 parent ca141c0 commit 566502f

File tree

9 files changed

+293
-5
lines changed

9 files changed

+293
-5
lines changed

src/sagemaker/config/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,8 @@
2424
TRAINING_JOB_VPC_CONFIG_PATH,
2525
TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH,
2626
TRAINING_JOB_RESOURCE_CONFIG_PATH,
27+
TRAINING_JOB_PROFILE_CONFIG_PATH,
28+
TRAINING_JOB_DISABLE_PROFILER_PATH,
2729
PROCESSING_JOB_INPUTS_PATH,
2830
PROCESSING_JOB,
2931
PROCESSING_JOB_ENVIRONMENT_PATH,
@@ -141,6 +143,8 @@
141143
CLUSTER_ROLE_ARN,
142144
PROCESSING_OUTPUT_CONFIG,
143145
PROCESSING_RESOURCES,
146+
PROFILER_CONFIG,
147+
DISABLE_PROFILER,
144148
RESOURCE_CONFIG,
145149
EXECUTION_ROLE_ARN,
146150
ASYNC_INFERENCE_CONFIG,

src/sagemaker/config/config_schema.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -98,6 +98,8 @@
9898
CONTAINERS = "Containers"
9999
PRIMARY_CONTAINER = "PrimaryContainer"
100100
INFERENCE_SPECIFICATION = "InferenceSpecification"
101+
PROFILER_CONFIG = "ProfilerConfig"
102+
DISABLE_PROFILER = "DisableProfiler"
101103

102104

103105
def _simple_path(*args: str):
@@ -128,6 +130,10 @@ def _simple_path(*args: str):
128130
TRAINING_JOB_VPC_CONFIG_PATH, SECURITY_GROUP_IDS
129131
)
130132
TRAINING_JOB_SUBNETS_PATH = _simple_path(TRAINING_JOB_VPC_CONFIG_PATH, SUBNETS)
133+
TRAINING_JOB_PROFILE_CONFIG_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, PROFILER_CONFIG)
134+
TRAINING_JOB_DISABLE_PROFILER_PATH = _simple_path(
135+
TRAINING_JOB_PROFILE_CONFIG_PATH, DISABLE_PROFILER
136+
)
131137
EDGE_PACKAGING_KMS_KEY_ID_PATH = _simple_path(
132138
SAGEMAKER, EDGE_PACKAGING_JOB, OUTPUT_CONFIG, KMS_KEY_ID
133139
)
@@ -1001,6 +1007,11 @@ def _simple_path(*args: str):
10011007
ADDITIONAL_PROPERTIES: False,
10021008
PROPERTIES: {VOLUME_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}},
10031009
},
1010+
PROFILER_CONFIG: {
1011+
TYPE: OBJECT,
1012+
ADDITIONAL_PROPERTIES: False,
1013+
PROPERTIES: {DISABLE_PROFILER: {TYPE: "boolean"}},
1014+
},
10041015
ROLE_ARN: {"$ref": "#/definitions/roleArn"},
10051016
VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"},
10061017
TAGS: {"$ref": "#/definitions/tags"},

src/sagemaker/estimator.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@
3737
TRAINING_JOB_ROLE_ARN_PATH,
3838
TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
3939
TRAINING_JOB_ENVIRONMENT_PATH,
40+
TRAINING_JOB_DISABLE_PROFILER_PATH,
4041
TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
4142
)
4243
from sagemaker.debugger import ( # noqa: F401 # pylint: disable=unused-import
@@ -157,7 +158,7 @@ def __init__(
157158
enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None,
158159
enable_network_isolation: Union[bool, PipelineVariable] = None,
159160
profiler_config: Optional[ProfilerConfig] = None,
160-
disable_profiler: bool = False,
161+
disable_profiler: bool = None,
161162
environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
162163
max_retry_attempts: Optional[Union[int, PipelineVariable]] = None,
163164
source_dir: Optional[Union[str, PipelineVariable]] = None,
@@ -687,7 +688,12 @@ def __init__(
687688
)
688689

689690
self.profiler_config = profiler_config
690-
self.disable_profiler = disable_profiler
691+
self.disable_profiler = resolve_value_from_config(
692+
direct_input=disable_profiler,
693+
config_path=TRAINING_JOB_DISABLE_PROFILER_PATH,
694+
default_value=False,
695+
sagemaker_session=self.sagemaker_session,
696+
)
691697

692698
self.environment = resolve_value_from_config(
693699
direct_input=environment,

src/sagemaker/session.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@
4545
TRAINING_JOB_VPC_CONFIG_PATH,
4646
TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH,
4747
TRAINING_JOB_RESOURCE_CONFIG_PATH,
48+
TRAINING_JOB_PROFILE_CONFIG_PATH,
4849
PROCESSING_JOB_INPUTS_PATH,
4950
PROCESSING_JOB,
5051
PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
@@ -833,6 +834,9 @@ def train( # noqa: C901
833834
inferred_resource_config = update_nested_dictionary_with_values_from_config(
834835
resource_config, TRAINING_JOB_RESOURCE_CONFIG_PATH, sagemaker_session=self
835836
)
837+
inferred_profiler_config = update_nested_dictionary_with_values_from_config(
838+
profiler_config, TRAINING_JOB_PROFILE_CONFIG_PATH, sagemaker_session=self
839+
)
836840
if (
837841
not customer_supplied_kms_key
838842
and "InstanceType" in inferred_resource_config
@@ -875,7 +879,7 @@ def train( # noqa: C901
875879
tensorboard_output_config=tensorboard_output_config,
876880
enable_sagemaker_metrics=enable_sagemaker_metrics,
877881
profiler_rule_configs=profiler_rule_configs,
878-
profiler_config=profiler_config,
882+
profiler_config=inferred_profiler_config,
879883
environment=environment,
880884
retry_strategy=retry_strategy,
881885
)
@@ -1152,11 +1156,13 @@ def update_training_job(
11521156
# No injections from sagemaker_config because the UpdateTrainingJob API's resource_config
11531157
# object accepts fewer parameters than the CreateTrainingJob API, and none that the
11541158
# sagemaker_config currently supports
1155-
1159+
inferred_profiler_config = update_nested_dictionary_with_values_from_config(
1160+
profiler_config, TRAINING_JOB_PROFILE_CONFIG_PATH, sagemaker_session=self
1161+
)
11561162
update_training_job_request = self._get_update_training_job_request(
11571163
job_name=job_name,
11581164
profiler_rule_configs=profiler_rule_configs,
1159-
profiler_config=profiler_config,
1165+
profiler_config=inferred_profiler_config,
11601166
resource_config=resource_config,
11611167
)
11621168
LOGGER.info("Updating training job with name %s", job_name)

tests/data/config/config.yaml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -167,6 +167,8 @@ SageMaker:
167167
ResourceConfig:
168168
VolumeKmsKeyId: 'volumekmskeyid1'
169169
RoleArn: 'arn:aws:iam::555555555555:role/IMRole'
170+
ProfilerConfig:
171+
DisableProfiler: false
170172
VpcConfig:
171173
SecurityGroupIds:
172174
- 'sg123'

tests/unit/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,8 @@
6868
PROCESSING_OUTPUT_CONFIG,
6969
PROCESSING_RESOURCES,
7070
TRAINING_JOB,
71+
PROFILER_CONFIG,
72+
DISABLE_PROFILER,
7173
RESOURCE_CONFIG,
7274
TRANSFORM_JOB,
7375
EXECUTION_ROLE_ARN,
@@ -327,6 +329,7 @@
327329
ENVIRONMENT: {"configEnvVar1": "value1", "configEnvVar2": "value2"},
328330
OUTPUT_DATA_CONFIG: {KMS_KEY_ID: "TestKms"},
329331
RESOURCE_CONFIG: {VOLUME_KMS_KEY_ID: "volumekey"},
332+
PROFILER_CONFIG: {DISABLE_PROFILER: False},
330333
ROLE_ARN: "arn:aws:iam::111111111111:role/ConfigRole",
331334
VPC_CONFIG: {SUBNETS: ["subnets-123"], SECURITY_GROUP_IDS: ["sg-123"]},
332335
TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}],

tests/unit/sagemaker/config/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -142,6 +142,7 @@ def valid_training_job_config(valid_iam_role_arn, valid_vpc_config, valid_enviro
142142
"Environment": valid_environment_config,
143143
"OutputDataConfig": {"KmsKeyId": "kmskeyid1"},
144144
"ResourceConfig": {"VolumeKmsKeyId": "volumekmskeyid1"},
145+
"ProfilerConfig": {"DisableProfiler": False},
145146
"RoleArn": valid_iam_role_arn,
146147
"VpcConfig": valid_vpc_config,
147148
}

tests/unit/test_estimator.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -407,6 +407,9 @@ def test_framework_initialization_with_sagemaker_config_injection(sagemaker_sess
407407
"TrainingJob"
408408
]["EnableInterContainerTrafficEncryption"]
409409
expected_environment = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"]["Environment"]
410+
expected_disable_profiler_attribute = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][
411+
"ProfilerConfig"
412+
]["DisableProfiler"]
410413
assert framework.role == expected_role_arn
411414
assert framework.enable_network_isolation() == expected_enable_network_isolation
412415
assert (
@@ -418,6 +421,7 @@ def test_framework_initialization_with_sagemaker_config_injection(sagemaker_sess
418421
assert framework.security_group_ids == expected_security_groups
419422
assert framework.subnets == expected_subnets
420423
assert framework.environment == expected_environment
424+
assert framework.disable_profiler == expected_disable_profiler_attribute
421425

422426

423427
def test_estimator_initialization_with_sagemaker_config_injection(sagemaker_session):
@@ -453,6 +457,9 @@ def test_estimator_initialization_with_sagemaker_config_injection(sagemaker_sess
453457
"TrainingJob"
454458
]["EnableInterContainerTrafficEncryption"]
455459
expected_environment = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"]["Environment"]
460+
expected_disable_profiler_attribute = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][
461+
"ProfilerConfig"
462+
]["DisableProfiler"]
456463
assert estimator.role == expected_role_arn
457464
assert estimator.enable_network_isolation() == expected_enable_network_isolation
458465
assert (
@@ -464,6 +471,7 @@ def test_estimator_initialization_with_sagemaker_config_injection(sagemaker_sess
464471
assert estimator.security_group_ids == expected_security_groups
465472
assert estimator.subnets == expected_subnets
466473
assert estimator.environment == expected_environment
474+
assert estimator.disable_profiler == expected_disable_profiler_attribute
467475

468476

469477
def test_estimator_initialization_with_sagemaker_config_injection_no_kms_supported(
@@ -498,6 +506,9 @@ def test_estimator_initialization_with_sagemaker_config_injection_no_kms_support
498506
expected_enable_inter_container_traffic_encryption = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"][
499507
"TrainingJob"
500508
]["EnableInterContainerTrafficEncryption"]
509+
expected_disable_profiler_attribute = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][
510+
"ProfilerConfig"
511+
]["DisableProfiler"]
501512
assert estimator.role == expected_role_arn
502513
assert estimator.enable_network_isolation() == expected_enable_network_isolation
503514
assert (
@@ -508,6 +519,7 @@ def test_estimator_initialization_with_sagemaker_config_injection_no_kms_support
508519
assert estimator.volume_kms_key is None
509520
assert estimator.security_group_ids == expected_security_groups
510521
assert estimator.subnets == expected_subnets
522+
assert estimator.disable_profiler == expected_disable_profiler_attribute
511523

512524

513525
def test_estimator_initialization_with_sagemaker_config_injection_partial_kms_support(
@@ -545,6 +557,9 @@ def test_estimator_initialization_with_sagemaker_config_injection_partial_kms_su
545557
expected_enable_inter_container_traffic_encryption = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"][
546558
"TrainingJob"
547559
]["EnableInterContainerTrafficEncryption"]
560+
expected_disable_profiler_attribute = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][
561+
"ProfilerConfig"
562+
]["DisableProfiler"]
548563
assert estimator.role == expected_role_arn
549564
assert estimator.enable_network_isolation() == expected_enable_network_isolation
550565
assert (
@@ -555,6 +570,7 @@ def test_estimator_initialization_with_sagemaker_config_injection_partial_kms_su
555570
assert estimator.volume_kms_key == expected_volume_kms_key_id
556571
assert estimator.security_group_ids == expected_security_groups
557572
assert estimator.subnets == expected_subnets
573+
assert estimator.disable_profiler == expected_disable_profiler_attribute
558574

559575

560576
def test_framework_with_heterogeneous_cluster(sagemaker_session):

0 commit comments

Comments
 (0)