Skip to content

fix: add describe_transform_job in session class #1507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 19, 2020
12 changes: 12 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2573,6 +2573,18 @@ def wait_for_tuning_job(self, job, poll=5):
self._check_job_status(job, desc, "HyperParameterTuningJobStatus")
return desc

def describe_transform_job(self, job_name):
"""Calls the DescribeTransformJob API for the given job name
and returns the response.

Args:
job_name (str): The name of the transform job to describe.

Returns:
dict: A dictionary response with the transform job description.
"""
return self.sagemaker_client.describe_transform_job(TransformJobName=job_name)

def wait_for_transform_job(self, job, poll=5):
"""Wait for an Amazon SageMaker transform job to complete.

Expand Down
12 changes: 5 additions & 7 deletions tests/integ/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def test_transform_mxnet(
):
transformer.wait()

job_desc = transformer.sagemaker_session.sagemaker_client.describe_transform_job(
TransformJobName=transformer.latest_transform_job.name
job_desc = transformer.sagemaker_session.describe_transform_job(
job_name=transformer.latest_transform_job.name
)
assert kms_key_arn == job_desc["TransformResources"]["VolumeKmsKeyId"]
assert output_filter == job_desc["DataProcessing"]["OutputFilter"]
Expand Down Expand Up @@ -323,8 +323,8 @@ def test_stop_transform_job(mxnet_estimator, mxnet_transform_input, cpu_instance

transformer.stop_transform_job()

desc = transformer.latest_transform_job.sagemaker_session.sagemaker_client.describe_transform_job(
TransformJobName=latest_transform_job_name
desc = transformer.latest_transform_job.sagemaker_session.describe_transform_job(
job_name=latest_transform_job_name
)
assert desc["TransformJobStatus"] == "Stopped"

Expand Down Expand Up @@ -393,9 +393,7 @@ def test_transform_tf_kms_network_isolation(sagemaker_session, cpu_instance_type
)
assert model_desc["EnableNetworkIsolation"]

job_desc = sagemaker_session.sagemaker_client.describe_transform_job(
TransformJobName=job_name
)
job_desc = sagemaker_session.describe_transform_job(job_name=job_name)
assert job_desc["TransformOutput"]["S3OutputPath"] == output_path
assert job_desc["TransformOutput"]["KmsKeyId"] == kms_key
assert job_desc["TransformResources"]["VolumeKmsKeyId"] == kms_key
Expand Down