Skip to content

Commit adb04d7

Browse files
committed
Unit test fixes
1 parent 69955d4 commit adb04d7

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

tests/unit/sagemaker/mlflow/test_forward_sagemaker_metrics.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def mock_mlflow_client():
4848
def test_encode():
4949
existing_names = set()
5050
assert encode("test-name", existing_names) == "test-name"
51-
assert encode("test:name", existing_names) == "test_3a_name"
51+
assert encode("test:name", existing_names) == "test:name"
5252
assert encode("test-name", existing_names) == "test-name_1"
5353

5454

@@ -183,6 +183,7 @@ def getenv_side_effect(arg, default=None):
183183
spec=requests.Response
184184
),
185185
"https://test.sagemaker.aws/api/2.0/mlflow/runs/create": Mock(spec=requests.Response),
186+
"https://test.sagemaker.aws/api/2.0/mlflow/runs/update": Mock(spec=requests.Response),
186187
"https://test.sagemaker.aws/api/2.0/mlflow/runs/log-batch": [
187188
Mock(spec=requests.Response),
188189
Mock(spec=requests.Response),
@@ -211,6 +212,11 @@ def getenv_side_effect(arg, default=None):
211212
{"run_id": "test_run_id"}
212213
)
213214

215+
mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/update"].status_code = 200
216+
mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/update"].text = json.dumps(
217+
{"run_id": "test_run_id"}
218+
)
219+
214220
for mock_response in mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/log-batch"]:
215221
mock_response.status_code = 200
216222
mock_response.text = json.dumps({})
@@ -221,6 +227,7 @@ def getenv_side_effect(arg, default=None):
221227
mock_request.side_effect = [
222228
mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/experiments/get-by-name"],
223229
mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/create"],
230+
mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/update"],
224231
*mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/log-batch"],
225232
mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate"],
226233
]
@@ -231,7 +238,7 @@ def getenv_side_effect(arg, default=None):
231238

232239
log_to_mlflow(metrics, params, tags)
233240

234-
assert mock_request.call_count == 6 # Total number of API calls
241+
assert mock_request.call_count == 7 # Total number of API calls
235242

236243

237244
@patch("sagemaker.mlflow.forward_sagemaker_metrics.get_training_job_details")

0 commit comments

Comments
 (0)