@@ -684,6 +684,13 @@ def test_sanitize_user_prompt_with_all_rai_filter_template(
684
684
template_id , _ = all_filter_template
685
685
686
686
user_prompt = "How to make cheesecake without oven at home?"
687
+ expected_categories = [
688
+ "hate_speech" ,
689
+ "sexually_explicit" ,
690
+ "harassment" ,
691
+ "dangerous" ,
692
+ ]
693
+
687
694
response = sanitize_user_prompt (
688
695
project_id , location_id , template_id , user_prompt
689
696
)
@@ -699,6 +706,14 @@ def test_sanitize_user_prompt_with_all_rai_filter_template(
699
706
== modelarmor_v1 .FilterMatchState .NO_MATCH_FOUND
700
707
)
701
708
709
+ assert all (
710
+ response .sanitization_result .filter_results .get ("rai" )
711
+ .rai_filter_result .rai_filter_type_results .get (expected_category )
712
+ .match_state
713
+ == modelarmor_v1 .FilterMatchState .NO_MATCH_FOUND
714
+ for expected_category in expected_categories
715
+ )
716
+
702
717
703
718
def test_sanitize_user_prompt_with_malicious_url_template (
704
719
project_id : str ,
@@ -876,6 +891,12 @@ def test_sanitize_model_response_with_all_rai_filter_template(
876
891
model_response = (
877
892
"To make cheesecake without oven, you'll need to follow these steps...."
878
893
)
894
+ expected_categories = [
895
+ "hate_speech" ,
896
+ "sexually_explicit" ,
897
+ "harassment" ,
898
+ "dangerous" ,
899
+ ]
879
900
880
901
response = sanitize_model_response (
881
902
project_id , location_id , template_id , model_response
@@ -892,6 +913,14 @@ def test_sanitize_model_response_with_all_rai_filter_template(
892
913
== modelarmor_v1 .FilterMatchState .NO_MATCH_FOUND
893
914
)
894
915
916
+ assert all (
917
+ response .sanitization_result .filter_results .get ("rai" )
918
+ .rai_filter_result .rai_filter_type_results .get (expected_category )
919
+ .match_state
920
+ == modelarmor_v1 .FilterMatchState .NO_MATCH_FOUND
921
+ for expected_category in expected_categories
922
+ )
923
+
895
924
896
925
def test_sanitize_model_response_with_basic_sdp_template (
897
926
project_id : str ,
0 commit comments