diff --git a/tests/conftest.py b/tests/conftest.py index 7bab05dfb3..ceb2a03f51 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -269,7 +269,9 @@ def pytorch_training_py_version(pytorch_training_version, request): @pytest.fixture(scope="module", params=["py2", "py3"]) def pytorch_inference_py_version(pytorch_inference_version, request): - if Version(pytorch_inference_version) >= Version("2.0"): + if Version(pytorch_inference_version) >= Version("2.3"): + return "py311" + elif Version(pytorch_inference_version) >= Version("2.0"): return "py310" elif Version(pytorch_inference_version) >= Version("1.13"): return "py39"