Skip to content

Commit db98fc6

Browse files
Added each RAI filter assertion
1 parent 62cc0fe commit db98fc6

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

model_armor/snippets/snippets_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,13 @@ def test_sanitize_user_prompt_with_all_rai_filter_template(
684684
template_id, _ = all_filter_template
685685

686686
user_prompt = "How to make cheesecake without oven at home?"
687+
expected_categories = [
688+
"hate_speech",
689+
"sexually_explicit",
690+
"harassment",
691+
"dangerous",
692+
]
693+
687694
response = sanitize_user_prompt(
688695
project_id, location_id, template_id, user_prompt
689696
)
@@ -699,6 +706,14 @@ def test_sanitize_user_prompt_with_all_rai_filter_template(
699706
== modelarmor_v1.FilterMatchState.NO_MATCH_FOUND
700707
)
701708

709+
assert all(
710+
response.sanitization_result.filter_results.get("rai")
711+
.rai_filter_result.rai_filter_type_results.get(expected_category)
712+
.match_state
713+
== modelarmor_v1.FilterMatchState.NO_MATCH_FOUND
714+
for expected_category in expected_categories
715+
)
716+
702717

703718
def test_sanitize_user_prompt_with_malicious_url_template(
704719
project_id: str,
@@ -876,6 +891,12 @@ def test_sanitize_model_response_with_all_rai_filter_template(
876891
model_response = (
877892
"To make cheesecake without oven, you'll need to follow these steps...."
878893
)
894+
expected_categories = [
895+
"hate_speech",
896+
"sexually_explicit",
897+
"harassment",
898+
"dangerous",
899+
]
879900

880901
response = sanitize_model_response(
881902
project_id, location_id, template_id, model_response
@@ -892,6 +913,14 @@ def test_sanitize_model_response_with_all_rai_filter_template(
892913
== modelarmor_v1.FilterMatchState.NO_MATCH_FOUND
893914
)
894915

916+
assert all(
917+
response.sanitization_result.filter_results.get("rai")
918+
.rai_filter_result.rai_filter_type_results.get(expected_category)
919+
.match_state
920+
== modelarmor_v1.FilterMatchState.NO_MATCH_FOUND
921+
for expected_category in expected_categories
922+
)
923+
895924

896925
def test_sanitize_model_response_with_basic_sdp_template(
897926
project_id: str,

0 commit comments

Comments
 (0)