Skip to content

Commit 5e620a9

Browse files
authored
Fix SeamlessM4Tv2ModelIntegrationTest (#27911)
change dtype of some integration tests
1 parent e96c1de commit 5e620a9

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1014,8 +1014,9 @@ def input_audio(self):
10141014
)
10151015

10161016
def factory_test_task(self, class1, class2, inputs, class1_kwargs, class2_kwargs):
1017-
model1 = class1.from_pretrained(self.repo_id).to(torch_device)
1018-
model2 = class2.from_pretrained(self.repo_id).to(torch_device)
1017+
# half-precision loading to limit GPU usage
1018+
model1 = class1.from_pretrained(self.repo_id, torch_dtype=torch.float16).to(torch_device)
1019+
model2 = class2.from_pretrained(self.repo_id, torch_dtype=torch.float16).to(torch_device)
10191020

10201021
set_seed(0)
10211022
output_1 = model1.generate(**inputs, **class1_kwargs)

0 commit comments

Comments
 (0)