@@ -48,7 +48,7 @@ def mock_mlflow_client():
48
48
def test_encode ():
49
49
existing_names = set ()
50
50
assert encode ("test-name" , existing_names ) == "test-name"
51
- assert encode ("test:name" , existing_names ) == "test_3a_name "
51
+ assert encode ("test:name" , existing_names ) == "test:name "
52
52
assert encode ("test-name" , existing_names ) == "test-name_1"
53
53
54
54
@@ -183,6 +183,7 @@ def getenv_side_effect(arg, default=None):
183
183
spec = requests .Response
184
184
),
185
185
"https://test.sagemaker.aws/api/2.0/mlflow/runs/create" : Mock (spec = requests .Response ),
186
+ "https://test.sagemaker.aws/api/2.0/mlflow/runs/update" : Mock (spec = requests .Response ),
186
187
"https://test.sagemaker.aws/api/2.0/mlflow/runs/log-batch" : [
187
188
Mock (spec = requests .Response ),
188
189
Mock (spec = requests .Response ),
@@ -211,6 +212,11 @@ def getenv_side_effect(arg, default=None):
211
212
{"run_id" : "test_run_id" }
212
213
)
213
214
215
+ mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/update" ].status_code = 200
216
+ mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/update" ].text = json .dumps (
217
+ {"run_id" : "test_run_id" }
218
+ )
219
+
214
220
for mock_response in mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/log-batch" ]:
215
221
mock_response .status_code = 200
216
222
mock_response .text = json .dumps ({})
@@ -221,6 +227,7 @@ def getenv_side_effect(arg, default=None):
221
227
mock_request .side_effect = [
222
228
mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/experiments/get-by-name" ],
223
229
mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/create" ],
230
+ mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/update" ],
224
231
* mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/log-batch" ],
225
232
mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate" ],
226
233
]
@@ -231,7 +238,7 @@ def getenv_side_effect(arg, default=None):
231
238
232
239
log_to_mlflow (metrics , params , tags )
233
240
234
- assert mock_request .call_count == 6 # Total number of API calls
241
+ assert mock_request .call_count == 7 # Total number of API calls
235
242
236
243
237
244
@patch ("sagemaker.mlflow.forward_sagemaker_metrics.get_training_job_details" )
0 commit comments