Skip to content

Commit 3a19c08

Browse files
rohangujarathiRohan Gujarathimetrizable
authored
change: use sagemaker_session in workflow tests (#2152)
* change: use sagemaker_session in workflow tests * remove unused imports Co-authored-by: Rohan Gujarathi <[email protected]> Co-authored-by: Eric Johnson <[email protected]>
1 parent 4634951 commit 3a19c08

File tree

1 file changed

+4
-24
lines changed

1 file changed

+4
-24
lines changed

tests/integ/test_workflow.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,8 @@
1818
import time
1919
import uuid
2020

21-
import boto3
2221
import pytest
2322

24-
from botocore.config import Config
2523
from botocore.exceptions import WaiterError
2624
from sagemaker.debugger import (
2725
DebuggerHookConfig,
@@ -32,7 +30,7 @@
3230
from sagemaker.model import Model
3331
from sagemaker.processing import ProcessingInput, ProcessingOutput
3432
from sagemaker.pytorch.estimator import PyTorch
35-
from sagemaker.session import get_execution_role, Session
33+
from sagemaker.session import get_execution_role
3634
from sagemaker.sklearn.estimator import SKLearn
3735
from sagemaker.sklearn.processing import SKLearnProcessor
3836
from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
@@ -74,21 +72,6 @@ def role(sagemaker_session):
7472
return get_execution_role(sagemaker_session)
7573

7674

77-
@pytest.fixture(scope="module")
78-
def workflow_session(region_name):
79-
boto_session = boto3.Session(region_name=region_name)
80-
81-
sagemaker_client_config = dict()
82-
sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=2)))
83-
sagemaker_client = boto_session.client("sagemaker", **sagemaker_client_config)
84-
85-
return Session(
86-
boto_session=boto_session,
87-
sagemaker_client=sagemaker_client,
88-
sagemaker_runtime_client=None,
89-
)
90-
91-
9275
@pytest.fixture(scope="module")
9376
def script_dir():
9477
return os.path.join(DATA_DIR, "sklearn_processing")
@@ -119,7 +102,6 @@ def athena_dataset_definition(sagemaker_session):
119102

120103
def test_three_step_definition(
121104
sagemaker_session,
122-
workflow_session,
123105
region_name,
124106
role,
125107
script_dir,
@@ -205,7 +187,7 @@ def test_three_step_definition(
205187
name=pipeline_name,
206188
parameters=[instance_type, instance_count, output_prefix],
207189
steps=[step_process, step_train, step_model],
208-
sagemaker_session=workflow_session,
190+
sagemaker_session=sagemaker_session,
209191
)
210192

211193
definition = json.loads(pipeline.definition())
@@ -277,7 +259,6 @@ def test_three_step_definition(
277259

278260
def test_one_step_sklearn_processing_pipeline(
279261
sagemaker_session,
280-
workflow_session,
281262
role,
282263
sklearn_latest_version,
283264
cpu_instance_type,
@@ -313,7 +294,7 @@ def test_one_step_sklearn_processing_pipeline(
313294
name=pipeline_name,
314295
parameters=[instance_count],
315296
steps=[step_sklearn],
316-
sagemaker_session=workflow_session,
297+
sagemaker_session=sagemaker_session,
317298
)
318299

319300
try:
@@ -363,7 +344,6 @@ def test_one_step_sklearn_processing_pipeline(
363344

364345
def test_conditional_pytorch_training_model_registration(
365346
sagemaker_session,
366-
workflow_session,
367347
role,
368348
cpu_instance_type,
369349
pipeline_name,
@@ -433,7 +413,7 @@ def test_conditional_pytorch_training_model_registration(
433413
name=pipeline_name,
434414
parameters=[good_enough_input, instance_count, instance_type],
435415
steps=[step_cond],
436-
sagemaker_session=workflow_session,
416+
sagemaker_session=sagemaker_session,
437417
)
438418

439419
try:

0 commit comments

Comments
 (0)