Skip to content

Commit 38295b7

Browse files
ricardoV94twiecki
authored andcommitted
Do not override unrelated xla flags in sampling_jax.py
1 parent 541dec4 commit 38295b7

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pymc/sampling_jax.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from aesara.graph import optimize_graph
1010
from aesara.tensor import TensorVariable
1111

12-
xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
13-
xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
14-
os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"])
12+
xla_flags = os.getenv("XLA_FLAGS", "")
13+
xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
14+
os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags)
1515

1616
import aesara.tensor as at
1717
import arviz as az

0 commit comments

Comments
 (0)