diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index 8bdd7c8ae3..92273f2c9a 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -44,7 +44,7 @@ nbformat>=5.9,<6 accelerate>=0.24.1,<=0.27.0 schema==0.7.5 tensorflow>=2.16.2,<=2.18.0 -mlflow>=2.12.2,<2.13 +mlflow>=2.14.2,<3 huggingface_hub==0.26.2 uvicorn>=0.30.1 fastapi==0.115.4 diff --git a/tests/unit/sagemaker/mlflow/test_forward_sagemaker_metrics.py b/tests/unit/sagemaker/mlflow/test_forward_sagemaker_metrics.py index 4b53c93ad4..14502880c3 100644 --- a/tests/unit/sagemaker/mlflow/test_forward_sagemaker_metrics.py +++ b/tests/unit/sagemaker/mlflow/test_forward_sagemaker_metrics.py @@ -48,7 +48,7 @@ def mock_mlflow_client(): def test_encode(): existing_names = set() assert encode("test-name", existing_names) == "test-name" - assert encode("test:name", existing_names) == "test_3a_name" + assert encode("test:name", existing_names) == "test:name" assert encode("test-name", existing_names) == "test-name_1" @@ -183,6 +183,7 @@ def getenv_side_effect(arg, default=None): spec=requests.Response ), "https://test.sagemaker.aws/api/2.0/mlflow/runs/create": Mock(spec=requests.Response), + "https://test.sagemaker.aws/api/2.0/mlflow/runs/update": Mock(spec=requests.Response), "https://test.sagemaker.aws/api/2.0/mlflow/runs/log-batch": [ Mock(spec=requests.Response), Mock(spec=requests.Response), @@ -211,6 +212,11 @@ def getenv_side_effect(arg, default=None): {"run_id": "test_run_id"} ) + mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/update"].status_code = 200 + mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/update"].text = json.dumps( + {"run_id": "test_run_id"} + ) + for mock_response in mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/log-batch"]: mock_response.status_code = 200 mock_response.text = json.dumps({}) @@ -221,6 +227,7 @@ def getenv_side_effect(arg, default=None): mock_request.side_effect = [ mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/experiments/get-by-name"], mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/create"], + mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/update"], *mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/log-batch"], mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate"], ] @@ -231,7 +238,7 @@ def getenv_side_effect(arg, default=None): log_to_mlflow(metrics, params, tags) - assert mock_request.call_count == 6 # Total number of API calls + assert mock_request.call_count == 7 # Total number of API calls @patch("sagemaker.mlflow.forward_sagemaker_metrics.get_training_job_details")