We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 541dec4 commit 38295b7Copy full SHA for 38295b7
pymc/sampling_jax.py
@@ -9,9 +9,9 @@
9
from aesara.graph import optimize_graph
10
from aesara.tensor import TensorVariable
11
12
-xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
13
-xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
14
-os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"])
+xla_flags = os.getenv("XLA_FLAGS", "")
+xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
+os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags)
15
16
import aesara.tensor as at
17
import arviz as az
0 commit comments