Skip to content

Commit 93addef

Browse files
committed
Fix unit tests
1 parent 033bd42 commit 93addef

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

tests/unit/test_amazon_estimator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ def test_data_location_does_not_call_default_bucket(sagemaker_session):
9898
assert not sagemaker_session.default_bucket.called
9999

100100

101-
def test_prepare_for_training():
102-
pca = PCA(num_components=55, **COMMON_ARGS)
101+
def test_prepare_for_training(sagemaker_session):
102+
pca = PCA(num_components=55, sagemaker_session=sagemaker_session, **COMMON_ARGS)
103103

104104
train = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 8.0], [44.0, 55.0, 66.0]]
105105
labels = [99, 85, 87, 2]
@@ -110,8 +110,8 @@ def test_prepare_for_training():
110110
assert pca.mini_batch_size == 1
111111

112112

113-
def test_prepare_for_training_list():
114-
pca = PCA(num_components=55, **COMMON_ARGS)
113+
def test_prepare_for_training_list(sagemaker_session):
114+
pca = PCA(num_components=55, sagemaker_session=sagemaker_session, **COMMON_ARGS)
115115

116116
train = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 8.0], [44.0, 55.0, 66.0]]
117117
labels = [99, 85, 87, 2]
@@ -122,8 +122,8 @@ def test_prepare_for_training_list():
122122
assert pca.mini_batch_size == 1
123123

124124

125-
def test_prepare_for_training_list_no_train_channel():
126-
pca = PCA(num_components=55, **COMMON_ARGS)
125+
def test_prepare_for_training_list_no_train_channel(sagemaker_session):
126+
pca = PCA(num_components=55, sagemaker_session=sagemaker_session, **COMMON_ARGS)
127127

128128
train = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 8.0], [44.0, 55.0, 66.0]]
129129
labels = [99, 85, 87, 2]

tests/unit/test_analytics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def test_abstract_base_class():
5555
AnalyticsMetricsBase()
5656

5757

58-
def test_tuner_name():
59-
tuner = HyperparameterTuningJobAnalytics("my-tuning-job")
58+
def test_tuner_name(sagemaker_session):
59+
tuner = HyperparameterTuningJobAnalytics("my-tuning-job", sagemaker_session=sagemaker_session)
6060
assert tuner.name == "my-tuning-job"
6161
assert str(tuner).find("my-tuning-job") != -1
6262

0 commit comments

Comments
 (0)