Commit 24a381d

Support evaluating on endpoint without type (#13226)
Signed-off-by: B-Step62 <[email protected]>
1 parent 55348b8 commit 24a381d

2 files changed: +121, -7 lines
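
For context, a minimal usage sketch of what this commit enables, mirroring the new test below: mlflow.evaluate can now target a deployment endpoint whose metadata declares no endpoint type, provided each input row is already in the shape the endpoint expects (the payload is passed through unmodified and the raw response is returned). The endpoint name and data here are hypothetical.

import pandas as pd
import mlflow

# Hypothetical endpoint with no "llm/v1/..." type declared; each "inputs" row is
# already a chat-style payload because no wrapping is applied for untyped endpoints.
data = pd.DataFrame(
    {
        "inputs": [
            {"messages": [{"role": "user", "content": "What is MLflow?"}], "max_tokens": 10},
            {"messages": [{"role": "user", "content": "What is Spark?"}], "max_tokens": 10},
        ],
        "ground_truth": ["An ML platform.", "A data processing engine."],
    }
)

with mlflow.start_run():
    result = mlflow.evaluate(
        model="endpoints:/my-untyped-endpoint",  # hypothetical endpoint name
        data=data,
        model_type="question-answering",
        targets="ground_truth",
        inference_params={"temperature": 0.5},
    )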

mlflow/metrics/genai/model_utils.py

Lines changed: 35 additions & 5 deletions
@@ -113,6 +113,14 @@ def _call_openai_api(openai_uri, payload, eval_parameters):
     return _parse_chat_response_format(resp)
 
 
+_PREDICT_ERROR_MSG = """\
+Failed to call the deployment endpoint. Please check the deployment URL \
+is set correctly and the input payload is valid.\n
+- Error: {e}\n
+- Deployment URI: {uri}\n
+- Input payload: {payload}"""
+
+
 def _call_deployments_api(deployment_uri, payload, eval_parameters, wrap_payload=True):
     """Call the deployment endpoint with the given payload and parameters.
 
@@ -142,19 +150,41 @@ def _call_deployments_api(deployment_uri, payload, eval_parameters, wrap_payload
         if wrap_payload:
             payload = {"prompt": payload}
         chat_inputs = {**payload, **eval_parameters}
-        response = client.predict(endpoint=deployment_uri, inputs=chat_inputs)
+        try:
+            response = client.predict(endpoint=deployment_uri, inputs=chat_inputs)
+        except Exception as e:
+            raise MlflowException(
+                _PREDICT_ERROR_MSG.format(e=e, uri=deployment_uri, payload=chat_inputs)
+            ) from e
         return _parse_completions_response_format(response)
     elif endpoint_type == "llm/v1/chat":
         if wrap_payload:
             payload = {"messages": [{"role": "user", "content": payload}]}
         completion_inputs = {**payload, **eval_parameters}
-        response = client.predict(endpoint=deployment_uri, inputs=completion_inputs)
+        try:
+            response = client.predict(endpoint=deployment_uri, inputs=completion_inputs)
+        except Exception as e:
+            raise MlflowException(
+                _PREDICT_ERROR_MSG.format(e=e, uri=deployment_uri, payload=completion_inputs)
+            ) from e
         return _parse_chat_response_format(response)
-
+    elif endpoint_type is None:
+        # If the endpoint type is not specified, we don't assume any format
+        # and directly send the payload to the endpoint. This is primarily for Databricks
+        # Managed Agent Evaluation, where the endpoint type may not be specified but the
+        # eval harness ensures that the payload is formatted to the chat format, as well
+        # as parsing the response.
+        inputs = {**payload, **eval_parameters}
+        try:
+            return client.predict(endpoint=deployment_uri, inputs=inputs)
+        except Exception as e:
+            raise MlflowException(
+                _PREDICT_ERROR_MSG.format(e=e, uri=deployment_uri, payload=inputs)
+            ) from e
     else:
         raise MlflowException(
-            f"Unsupported endpoint type: {endpoint_type}. Use an "
-            "endpoint of type 'llm/v1/completions' or 'llm/v1/chat' instead.",
+            f"Unsupported endpoint type: {endpoint_type}. Endpoint type, if specified, "
+            "must be 'llm/v1/completions' or 'llm/v1/chat'.",
             error_code=INVALID_PARAMETER_VALUE,
         )
 
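
As a quick illustration of the error handling added above, the snippet below re-declares the _PREDICT_ERROR_MSG template exactly as in the diff and formats it the way the new except branches do; the error, URI, and payload are hypothetical placeholders, not output from a real endpoint.

# The template copied from the diff above, rendered with hypothetical values.
_PREDICT_ERROR_MSG = """\
Failed to call the deployment endpoint. Please check the deployment URL \
is set correctly and the input payload is valid.\n
- Error: {e}\n
- Deployment URI: {uri}\n
- Input payload: {payload}"""

message = _PREDICT_ERROR_MSG.format(
    e=ValueError("Invalid payload"),  # hypothetical underlying error
    uri="endpoints:/random",          # hypothetical deployment URI
    payload={"invalid": "payload", "max_tokens": 10},
)
print(message)  # Lists the error, the deployment URI, and the payload that was sent.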

tests/evaluate/test_evaluation.py

Lines changed: 86 additions & 2 deletions
@@ -2032,7 +2032,7 @@ def test_evaluate_on_chat_model_endpoint(mock_deploy_client, input_data, feature
             },
         ),
     ]
-    assert all(call in call_args_list for call in expected_calls)
+    assert call_args_list == expected_calls
 
     # Validate the evaluation metrics
     expected_metrics_subset = {"toxicity/v1/ratio", "ari_grade_level/v1/mean"}
@@ -2089,7 +2089,7 @@ def test_evaluate_on_completion_model_endpoint(mock_deploy_client, input_data, f
         mock.call(endpoint="completions", inputs={"prompt": "What is MLflow?", "max_tokens": 10}),
         mock.call(endpoint="completions", inputs={"prompt": "What is Spark?", "max_tokens": 10}),
     ]
-    assert all(call in call_args_list for call in expected_calls)
+    assert call_args_list == expected_calls
 
     # Validate the evaluation metrics
     expected_metrics_subset = {
@@ -2104,6 +2104,90 @@ def test_evaluate_on_completion_model_endpoint(mock_deploy_client, input_data, f
     assert eval_results_table["outputs"].equals(pd.Series(["This is a response"] * 2))
 
 
+@mock.patch("mlflow.deployments.get_deploy_client")
+def test_evaluate_on_model_endpoint_without_type(mock_deploy_client):
+    # An endpoint that does not have an endpoint type. For such an endpoint, we simply
+    # pass the input data to the endpoint without any modification and return
+    # the response as is.
+    mock_deploy_client.return_value.get_endpoint.return_value = {}
+    mock_deploy_client.return_value.predict.return_value = "This is a response"
+
+    input_data = pd.DataFrame(
+        {
+            "inputs": [
+                {
+                    "messages": [{"content": q, "role": "user"}],
+                    "max_tokens": 10,
+                }
+                for q in _TEST_QUERY_LIST
+            ],
+            "ground_truth": _TEST_GT_LIST,
+        }
+    )
+
+    with mlflow.start_run():
+        eval_result = mlflow.evaluate(
+            model="endpoints:/random",
+            data=input_data,
+            model_type="question-answering",
+            targets="ground_truth",
+            inference_params={"max_tokens": 10, "temperature": 0.5},
+        )
+
+    # Validate the endpoint is called with correct payloads
+    call_args_list = mock_deploy_client.return_value.predict.call_args_list
+    expected_calls = [
+        mock.call(
+            endpoint="random",
+            inputs={
+                "messages": [{"content": "What is MLflow?", "role": "user"}],
+                "max_tokens": 10,
+                "temperature": 0.5,
+            },
+        ),
+        mock.call(
+            endpoint="random",
+            inputs={
+                "messages": [{"content": "What is Spark?", "role": "user"}],
+                "max_tokens": 10,
+                "temperature": 0.5,
+            },
+        ),
+    ]
+    assert call_args_list == expected_calls
+
+    # Validate the evaluation metrics
+    expected_metrics_subset = {"toxicity/v1/ratio", "ari_grade_level/v1/mean", "exact_match/v1"}
+    assert expected_metrics_subset.issubset(set(eval_result.metrics.keys()))
+
+    # Validate the model output is passed to the evaluator in the correct format (string)
+    eval_results_table = eval_result.tables["eval_results_table"]
+    assert eval_results_table["outputs"].equals(pd.Series(["This is a response"] * 2))
+
+
+@mock.patch("mlflow.deployments.get_deploy_client")
+def test_evaluate_on_model_endpoint_invalid_payload(mock_deploy_client):
+    # An endpoint that does not have an endpoint type. When the endpoint call fails
+    # (e.g. because the payload is invalid), the error should surface as an
+    # MlflowException with a descriptive message.
+    mock_deploy_client.return_value.get_endpoint.return_value = {}
+    mock_deploy_client.return_value.predict.side_effect = ValueError("Invalid payload")
+
+    input_data = pd.DataFrame(
+        {
+            "inputs": [{"invalid": "payload"}],
+        }
+    )
+
+    with pytest.raises(MlflowException, match="Failed to call the deployment endpoint"):
+        mlflow.evaluate(
+            model="endpoints:/random",
+            data=input_data,
+            model_type="question-answering",
+            inference_params={"max_tokens": 10, "temperature": 0.5},
+        )
+
+
 @pytest.mark.parametrize(
     ("input_data", "error_message"),
     [
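
To see the pass-through behavior these tests exercise without spinning up MLflow, here is a minimal sketch with a hand-rolled fake client (FakeDeployClient and its return value are made up for illustration; the real tests mock mlflow.deployments.get_deploy_client instead). It mirrors the endpoint_type-is-None branch: the payload and inference parameters are merged, sent as-is, and the raw response is returned.

class FakeDeployClient:
    # Hypothetical stand-in for the deployments client; records calls and
    # returns a fixed string, like the mock in the tests above.
    def __init__(self):
        self.calls = []

    def predict(self, endpoint, inputs):
        self.calls.append((endpoint, inputs))
        return "This is a response"


client = FakeDeployClient()
payload = {"messages": [{"role": "user", "content": "What is MLflow?"}]}
eval_parameters = {"max_tokens": 10, "temperature": 0.5}

# Merge the payload with the evaluation parameters, call predict, and hand back
# whatever the endpoint returned, without assuming any response format.
response = client.predict(endpoint="random", inputs={**payload, **eval_parameters})
assert response == "This is a response"
assert client.calls[0][1] == {**payload, **eval_parameters}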
