diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9677615206..46800a1e13 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -82,6 +82,7 @@ jobs: install-numba: [0] install-jax: [0] install-torch: [0] + install-mlx: [0] part: - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse" - "tests/scan" @@ -115,6 +116,7 @@ jobs: install-numba: 0 install-jax: 0 install-torch: 0 + install-mlx: 0 - install-numba: 1 os: "ubuntu-latest" python-version: "3.10" @@ -150,6 +152,13 @@ jobs: fast-compile: 0 float32: 0 part: "tests/link/pytorch" + - install-mlx: 1 + os: "ubuntu-latest" + python-version: "3.10" + numpy-version: ">=2.0" + fast-compile: 0 + float32: 0 + part: "tests/link/mlx" - os: macos-15 python-version: "3.13" numpy-version: ">=2.0" @@ -196,6 +205,7 @@ jobs: if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi + if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" mlx; fi pip install pytest-sphinx pip install -e ./ @@ -212,6 +222,7 @@ jobs: INSTALL_NUMBA: ${{ matrix.install-numba }} INSTALL_JAX: ${{ matrix.install-jax }} INSTALL_TORCH: ${{ matrix.install-torch}} + INSTALL_MLX: ${{ matrix.install-mlx }} OS: ${{ matrix.os}} - name: Run tests diff --git a/.gitignore b/.gitignore index dfe862b868..ebe8e61bd0 100644 --- a/.gitignore +++ b/.gitignore @@ -27,7 +27,6 @@ __pycache__ \#*\# build compiled/*.cpp -core.* cutils_ext.cpp dist doc/.build/ diff --git a/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb b/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb new file mode 100644 index 0000000000..a5e1b2dfa4 --- /dev/null +++ b/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb @@ -0,0 +1,399 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "import numpy as np\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "import pytensor\n", + "import pytensor.tensor as pt\n", + "from pytensor.compile.function import function\n", + "from pytensor.compile.mode import Mode\n", + "from pytensor.graph import RewriteDatabaseQuery\n", + "from pytensor.link.jax import JAXLinker\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Configure JAX to use float32 for consistency with MLX\n", + "jax.config.update(\"jax_enable_x64\", False)\n", + "\n", + "# Set up PyTensor JAX mode\n", + "jax_optimizer = RewriteDatabaseQuery(include=[\"jax\"], exclude=[])\n", + "pytensor_jax_mode = \"JAX\"\n", + "\n", + "# Try to set up MLX mode\n", + "try:\n", + " from pytensor.link.mlx import MLXLinker\n", + " import mlx.core as mx\n", + " mlx_optimizer = RewriteDatabaseQuery(include=[\"mlx\"], exclude=[])\n", + " pytensor_mlx_mode = \"MLX\"\n", + " MLX_AVAILABLE = True\n", + "except ImportError:\n", + " MLX_AVAILABLE = False\n", + "\n", + "def timer_jax(func, N=1000):\n", + " \"\"\"Time function execution with proper JAX synchronization, repeated N times\"\"\"\n", + " def wrapper(*args, **kwargs):\n", + " times = []\n", + " for _ in range(N):\n", + " start = 
time.perf_counter()\n", + " result = func(*args, **kwargs)\n", + " if hasattr(result, 'block_until_ready'):\n", + " result.block_until_ready()\n", + " elif isinstance(result, (list, tuple)):\n", + " for r in result:\n", + " if hasattr(r, 'block_until_ready'):\n", + " r.block_until_ready()\n", + " end = time.perf_counter()\n", + " times.append(end - start)\n", + " \n", + " mean_time = np.mean(times)\n", + " std_time = np.std(times)\n", + " return result, mean_time, std_time\n", + " return wrapper\n", + "\n", + "def timer_mlx(func, N=1000):\n", + " \"\"\"Time function execution with proper MLX synchronization, repeated N times\"\"\"\n", + " def wrapper(*args, **kwargs):\n", + " times = []\n", + " for _ in range(N):\n", + " start = time.perf_counter()\n", + " result = func(*args, **kwargs)\n", + " # For MLX, we need to use mx.eval() to force computation\n", + " if MLX_AVAILABLE:\n", + " if isinstance(result, (list, tuple)):\n", + " mx.eval(*result)\n", + " else:\n", + " mx.eval(result)\n", + " end = time.perf_counter()\n", + " times.append(end - start)\n", + " \n", + " mean_time = np.mean(times)\n", + " std_time = np.std(times)\n", + " return result, mean_time, std_time\n", + " return wrapper\n", + "\n", + "def run_benchmark(N=1000):\n", + " \"\"\"Run comprehensive benchmark comparing PyTensor JAX vs MLX backends\"\"\"\n", + " import pandas as pd\n", + " \n", + " sizes = [2, 4, 2000, 4000]\n", + " results = []\n", + " \n", + " print(f\"Running benchmarks with N={N} repetitions per test...\")\n", + " \n", + " for size in sizes:\n", + " print(f\"Testing {size}x{size} matrices...\")\n", + " \n", + " # Generate test matrices with fixed seed for reproducibility\n", + " np.random.seed(42)\n", + " A = np.random.randn(size, size).astype(np.float32)\n", + " B = np.random.randn(size, size).astype(np.float32)\n", + " C = np.random.randn(size, size).astype(np.float32)\n", + "\n", + " pt_A = pt.matrix('A', dtype='float32')\n", + " pt_B = pt.matrix('B', dtype='float32') \n", + " pt_C = pt.matrix('C', dtype='float32')\n", + " result = pt.dot(pt.dot(pt_A, pt_B), pt_C)\n", + "\n", + "\n", + " f_jax = function([pt_A, pt_B, pt_C], result, mode=pytensor_jax_mode, trust_input=True)\n", + " f_mlx = function([pt_A, pt_B, pt_C], result, mode=pytensor_mlx_mode, trust_input=True)\n", + " f_jax(A, B, C)\n", + " f_mlx(A, B, C)\n", + " \n", + " # === TEST 1: Matrix Multiplication Chain ===\n", + " # PyTensor + JAX backend\n", + " @timer_jax\n", + " def pytensor_jax_matmul():\n", + " return f_jax(A, B, C)\n", + " \n", + " # PyTensor + MLX backend\n", + " @timer_mlx\n", + " def pytensor_mlx_matmul():\n", + " if not MLX_AVAILABLE:\n", + " return None, float('inf'), 0\n", + " return f_mlx(A, B, C)\n", + " \n", + " # Run matrix multiplication test\n", + " _, jax_mean, jax_std = pytensor_jax_matmul()\n", + " try:\n", + " _, mlx_mean, mlx_std = pytensor_mlx_matmul()\n", + " except Exception as e:\n", + " print(f\"MLX matmul error: {e}\")\n", + " mlx_mean, mlx_std = float('inf'), 0\n", + " \n", + " # Calculate percentage improvement (positive = MLX is faster, negative = MLX is slower)\n", + " if mlx_mean != float('inf') and mlx_mean > 0:\n", + " speedup_percentage = ((jax_mean - mlx_mean) / jax_mean) * 100\n", + " speedup_str = f'{speedup_percentage:+.1f}%'\n", + " else:\n", + " speedup_str = 'N/A'\n", + " \n", + " results.append({\n", + " 'Size': f'{size}x{size}',\n", + " 'Operation': 'Matrix Chain (A @ B @ C)',\n", + " 'PyTensor+JAX Mean (s)': f'{jax_mean:.6f}',\n", + " 'PyTensor+JAX Std (s)': f'{jax_std:.6f}',\n", + " 
'PyTensor+MLX Mean (s)': f'{mlx_mean:.6f}' if mlx_mean != float('inf') else 'Error',\n", + " 'PyTensor+MLX Std (s)': f'{mlx_std:.6f}' if mlx_mean != float('inf') else 'N/A',\n", + " 'MLX Performance': speedup_str\n", + " })\n", + " \n", + " # === TEST 2: Element-wise Operations ===\n", + " # PyTensor + JAX\n", + " result = pt.sin(pt_A) + pt.cos(pt_B)\n", + " f_jax = function([pt_A, pt_B], result, mode=pytensor_jax_mode, trust_input=True)\n", + " f_mlx = function([pt_A, pt_B], result, mode=pytensor_mlx_mode, trust_input=True)\n", + " f_jax(A, B)\n", + " f_mlx(A, B)\n", + "\n", + " @timer_jax\n", + " def pytensor_jax_elemwise():\n", + " return f_jax(A, B)\n", + " \n", + " # PyTensor + MLX\n", + " @timer_mlx\n", + " def pytensor_mlx_elemwise():\n", + " if not MLX_AVAILABLE:\n", + " return None, float('inf'), 0\n", + " return f_mlx(A, B)\n", + " \n", + " # Run element-wise test\n", + " _, jax_mean, jax_std = pytensor_jax_elemwise()\n", + " try:\n", + " _, mlx_mean, mlx_std = pytensor_mlx_elemwise()\n", + " except Exception as e:\n", + " print(f\"MLX elemwise error: {e}\")\n", + " mlx_mean, mlx_std = float('inf'), 0\n", + " \n", + " # Calculate percentage improvement\n", + " if mlx_mean != float('inf') and mlx_mean > 0:\n", + " speedup_percentage = ((jax_mean - mlx_mean) / jax_mean) * 100\n", + " speedup_str = f'{speedup_percentage:+.1f}%'\n", + " else:\n", + " speedup_str = 'N/A'\n", + " \n", + " results.append({\n", + " 'Size': f'{size}x{size}',\n", + " 'Operation': 'Element-wise (sin(A) + cos(B))',\n", + " 'PyTensor+JAX Mean (s)': f'{jax_mean:.6f}',\n", + " 'PyTensor+JAX Std (s)': f'{jax_std:.6f}',\n", + " 'PyTensor+MLX Mean (s)': f'{mlx_mean:.6f}' if mlx_mean != float('inf') else 'Error',\n", + " 'PyTensor+MLX Std (s)': f'{mlx_std:.6f}' if mlx_mean != float('inf') else 'N/A',\n", + " 'MLX Performance': speedup_str\n", + " })\n", + " \n", + " # === TEST 3: Matrix Addition with Broadcasting ===\n", + " # PyTensor + JAX\n", + " result = pt_A + pt_B.T\n", + " f_jax = function([pt_A, pt_B], result, mode=pytensor_jax_mode, trust_input=True)\n", + " f_mlx = function([pt_A, pt_B], result, mode=pytensor_mlx_mode, trust_input=True)\n", + " f_jax(A, B)\n", + " f_mlx(A, B)\n", + " @timer_jax\n", + " def pytensor_jax_broadcast():\n", + " return f_jax(A, B)\n", + " \n", + " # PyTensor + MLX\n", + " @timer_mlx\n", + " def pytensor_mlx_broadcast():\n", + " if not MLX_AVAILABLE:\n", + " return None, float('inf'), 0\n", + " return f_mlx(A, B)\n", + " \n", + " # Run broadcasting test\n", + " _, jax_mean, jax_std = pytensor_jax_broadcast()\n", + " try:\n", + " _, mlx_mean, mlx_std = pytensor_mlx_broadcast()\n", + " except Exception as e:\n", + " print(f\"MLX broadcast error: {e}\")\n", + " mlx_mean, mlx_std = float('inf'), 0\n", + " \n", + " # Calculate percentage improvement\n", + " if mlx_mean != float('inf') and mlx_mean > 0:\n", + " speedup_percentage = ((jax_mean - mlx_mean) / jax_mean) * 100\n", + " speedup_str = f'{speedup_percentage:+.1f}%'\n", + " else:\n", + " speedup_str = 'N/A'\n", + " \n", + " results.append({\n", + " 'Size': f'{size}x{size}',\n", + " 'Operation': 'Broadcasting (A + B.T)',\n", + " 'PyTensor+JAX Mean (s)': f'{jax_mean:.6f}',\n", + " 'PyTensor+JAX Std (s)': f'{jax_std:.6f}',\n", + " 'PyTensor+MLX Mean (s)': f'{mlx_mean:.6f}' if mlx_mean != float('inf') else 'Error',\n", + " 'PyTensor+MLX Std (s)': f'{mlx_std:.6f}' if mlx_mean != float('inf') else 'N/A',\n", + " 'MLX Performance': speedup_str\n", + " })\n", + " \n", + " # Create and display results table\n", + " df = 
pd.DataFrame(results)\n", + " return df\n", + "\n", + "def main(N=1000):\n", + " \"\"\"Main benchmark execution\"\"\"\n", + " # Display system info\n", + " system_info = {\n", + " 'JAX version': jax.__version__,\n", + " 'PyTensor version': pytensor.__version__,\n", + " 'MLX Available': 'Yes' if MLX_AVAILABLE else 'No',\n", + " 'Platform': 'Apple Silicon' if MLX_AVAILABLE else 'Generic',\n", + " 'Repetitions (N)': N\n", + " }\n", + " \n", + " if MLX_AVAILABLE:\n", + " system_info['MLX version'] = mx.__version__\n", + " \n", + " import pandas as pd\n", + " info_df = pd.DataFrame([system_info])\n", + " \n", + " # Then run benchmarks\n", + " results_df = run_benchmark(N=N)\n", + " \n", + " return info_df, results_df\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running benchmarks with N=150 repetitions per test...\n", + "Testing 2x2 matrices...\n", + "Testing 4x4 matrices...\n", + "Testing 2000x2000 matrices...\n", + "Testing 4000x4000 matrices...\n" + ] + } + ], + "source": [ + "iteration=150\n", + "_, results = main(N=iteration)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Benchmark Results over 150 repetitions:\n", + " Size Operation PyTensor+JAX Mean (s) PyTensor+JAX Std (s) PyTensor+MLX Mean (s) PyTensor+MLX Std (s) MLX Performance\n", + " 2x2 Matrix Chain (A @ B @ C) 0.000009 0.000002 0.000311 0.000266 -3277.7%\n", + " 2x2 Element-wise (sin(A) + cos(B)) 0.000008 0.000003 0.000233 0.000105 -2830.3%\n", + " 2x2 Broadcasting (A + B.T) 0.000007 0.000003 0.000253 0.000151 -3429.1%\n", + " 4x4 Matrix Chain (A @ B @ C) 0.000011 0.000008 0.000285 0.000111 -2537.7%\n", + " 4x4 Element-wise (sin(A) + cos(B)) 0.000007 0.000001 0.000235 0.000124 -3217.0%\n", + " 4x4 Broadcasting (A + B.T) 0.000007 0.000002 0.000202 0.000077 -2755.8%\n", + "2000x2000 Matrix Chain (A @ B @ C) 0.024714 0.000919 0.004166 0.003531 +83.1%\n", + "2000x2000 Element-wise (sin(A) + cos(B)) 0.009464 0.000417 0.000844 0.000284 +91.1%\n", + "2000x2000 Broadcasting (A + B.T) 0.000690 0.000022 0.000821 0.000093 -19.0%\n", + "4000x4000 Matrix Chain (A @ B @ C) 0.196587 0.008780 0.027411 0.001132 +86.1%\n", + "4000x4000 Element-wise (sin(A) + cos(B)) 0.037744 0.001247 0.003355 0.000467 +91.1%\n", + "4000x4000 Broadcasting (A + B.T) 0.012233 0.000421 0.003323 0.000370 +72.8%\n" + ] + } + ], + "source": [ + "print(f\"\\nBenchmark Results over {iteration} repetitions:\")\n", + "print(results.to_string(index=False))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# # Additional timing analysis - separate compilation vs execution time\n", + "# if MLX_AVAILABLE:\n", + "# print(\"\\n=== Detailed MLX Timing Analysis ===\")\n", + " \n", + "# # Test with medium-sized matrix\n", + "# np.random.seed(42)\n", + "# A = np.random.randn(512, 512).astype(np.float32)\n", + "# B = np.random.randn(512, 512).astype(np.float32)\n", + "# C = np.random.randn(512, 512).astype(np.float32)\n", + " \n", + "# # Create PyTensor function (compilation time)\n", + "# start = time.perf_counter()\n", + "# pt_A = pt.matrix('A', dtype='float32')\n", + "# pt_B = pt.matrix('B', dtype='float32')\n", + "# pt_C = pt.matrix('C', dtype='float32')\n", + "# result_expr = pt_A @ pt_B @ pt_C\n", + "# f_mlx = function([pt_A, pt_B, pt_C], result_expr, mode=pytensor_mlx_mode)\n", + "# 
compilation_time = time.perf_counter() - start\n", + " \n", + "# # First execution (may include additional compilation/optimization)\n", + "# start = time.perf_counter()\n", + "# result = f_mlx(A, B, C)\n", + "# mx.eval(result) # Force evaluation\n", + "# first_exec_time = time.perf_counter() - start\n", + " \n", + "# # Subsequent executions (should be faster)\n", + "# exec_times = []\n", + "# for _ in range(1000):\n", + "# start = time.perf_counter()\n", + "# result = f_mlx(A, B, C)\n", + "# mx.eval(result)\n", + "# exec_times.append(time.perf_counter() - start)\n", + " \n", + "# avg_exec_time = np.mean(exec_times)\n", + "# std_exec_time = np.std(exec_times)\n", + " \n", + "# print(f\"Compilation time: {compilation_time:.4f}s\")\n", + "# print(f\"First execution: {first_exec_time:.4f}s\")\n", + "# print(f\"Average execution (5 runs): {avg_exec_time:.4f}s ± {std_exec_time:.4f}s\")\n", + "# print(f\"Individual execution times: {[f'{t:.4f}' for t in exec_times]}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mlx_env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index f80dfaaf5c..8dc7c742bc 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -27,6 +27,7 @@ from pytensor.link.basic import Linker, PerformLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.link.jax.linker import JAXLinker +from pytensor.link.mlx.linker import MLXLinker from pytensor.link.numba.linker import NumbaLinker from pytensor.link.pytorch.linker import PytorchLinker from pytensor.link.vm import VMLinker @@ -50,6 +51,7 @@ "jax": JAXLinker(), "pytorch": PytorchLinker(), "numba": NumbaLinker(), + "mlx": MLXLinker(), } @@ -494,6 +496,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): ), ) +MLX = Mode( + MLXLinker(), + RewriteDatabaseQuery( + include=["fast_run"], + exclude=[ + "cxx_only", + "BlasOpt", + "fusion", + "inplace", + "scan_save_mem_prealloc", + ], + ), +) + predefined_modes = { "FAST_COMPILE": FAST_COMPILE, @@ -501,6 +517,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): "JAX": JAX, "NUMBA": NUMBA, "PYTORCH": PYTORCH, + "MLX": MLX, } _CACHED_RUNTIME_MODES: dict[str, Mode] = {} diff --git a/pytensor/link/mlx/__init__.py b/pytensor/link/mlx/__init__.py new file mode 100644 index 0000000000..d5a6ab19ff --- /dev/null +++ b/pytensor/link/mlx/__init__.py @@ -0,0 +1 @@ +from pytensor.link.mlx.linker import MLXLinker diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py new file mode 100644 index 0000000000..f039263a37 --- /dev/null +++ b/pytensor/link/mlx/dispatch/__init__.py @@ -0,0 +1,13 @@ +# isort: off +from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify + +import pytensor.link.mlx.dispatch.math +import pytensor.link.mlx.dispatch.basic +import pytensor.link.mlx.dispatch.elemwise +import pytensor.link.mlx.dispatch.shape +import pytensor.link.mlx.dispatch.subtensor +import pytensor.link.mlx.dispatch.core +import pytensor.link.mlx.dispatch.signal +import pytensor.link.mlx.dispatch.signal.conv 
+import pytensor.link.mlx.dispatch.blockwise +# isort: on diff --git a/pytensor/link/mlx/dispatch/basic.py b/pytensor/link/mlx/dispatch/basic.py new file mode 100644 index 0000000000..7b9a7e68f9 --- /dev/null +++ b/pytensor/link/mlx/dispatch/basic.py @@ -0,0 +1,84 @@ +import warnings +from copy import deepcopy +from functools import singledispatch +from types import NoneType + +import mlx.core as mx +import numpy as np + +from pytensor.compile.ops import DeepCopyOp +from pytensor.graph.fg import FunctionGraph +from pytensor.link.utils import fgraph_to_python +from pytensor.raise_op import Assert, CheckAndRaise + + +@singledispatch +def mlx_typify(data, **kwargs): + raise NotImplementedError(f"mlx_typify is not implemented for {type(data)}") + + +@mlx_typify.register(np.ndarray) +def mlx_typify_tensor(data, dtype=None, **kwargs): + return mx.array(data, dtype=dtype) + + +@mlx_typify.register(slice) +@mlx_typify.register(NoneType) +@mlx_typify.register(np.number) +@mlx_typify.register(mx.array) +def mlx_typify_no_conversion_needed(data, **kwargs): + return data + + +@mlx_typify.register(int) +@mlx_typify.register(float) +def mlx_typify_python_scalar(data, **kwargs): + return mx.array(data) + + +@singledispatch +def mlx_funcify(op, node=None, storage_map=None, **kwargs): + """Create a MLX compatible function from an PyTensor `Op`.""" + raise NotImplementedError( + f"No MLX conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/1350` for progress or to request we prioritize this operation" + ) + + +@mlx_funcify.register(FunctionGraph) +def mlx_funcify_FunctionGraph( + fgraph, + node=None, + fgraph_name="mlx_funcified_fgraph", + conversion_func=mlx_funcify, + **kwargs, +): + built_kwargs = {"conversion_func": conversion_func, **kwargs} + return fgraph_to_python( + fgraph, + conversion_func, + type_conversion_fn=mlx_typify, + fgraph_name=fgraph_name, + **built_kwargs, + ) + + +@mlx_funcify.register(DeepCopyOp) +def mlx_funcify_DeepCopyOp(op, **kwargs): + def deepcopyop(x): + return deepcopy(x) + + return deepcopyop + + +@mlx_funcify.register(Assert) +@mlx_funcify.register(CheckAndRaise) +def mlx_funcify_CheckAndRaise(op, **kwargs): + warnings.warn( + f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as MLX tracing would remove it.""", + stacklevel=2, + ) + + def assert_fn(x, *inputs): + return x + + return assert_fn diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py new file mode 100644 index 0000000000..7a6bda8a66 --- /dev/null +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -0,0 +1,107 @@ +import mlx.core as mx + +from pytensor.link.mlx.dispatch import mlx_funcify +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.signal.conv import Conv1d + + +def blockwise_conv1d(op, node, **kwargs): + """ + Custom implementation of Blockwise.conv1d for MLX. + """ + + def batched_conv1d( + x: mx.array, + kernels: mx.array, + mode: str = op.core_op.mode, + stride: int = 1, + dilation: int = 1, + ) -> mx.array: + """ + Apply B separate 1D convolutions (full or valid) to B sequences in parallel. + + Parameters + ---------- + x : array of shape (B, T) + B sequences of length T. + kernels : array of shape (B, K) + B kernels of length K. 
+ mode : {"valid", "full"} + "valid" → no padding, output length = T - K + 1 + "full" → zero-pad so output length = T + K - 1 + stride : int, convolution stride (default=1) + dilation : int, convolution dilation (default=1) + + Returns + ------- + out : array of shape (B, L) + where L = + - T - K + 1 if mode="valid" + - T + K - 1 if mode="full" + """ + # --- 1) shape checks --- + B, T = x.shape + Bk, K = kernels.shape + if B != Bk: + raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}") + + # --- 2) flip kernels for convolution --- + kernels_flipped = kernels[:, ::-1] # shape (B, K) + + # --- 3) decide padding --- + if mode == "valid": + pad = 0 + elif mode == "full": + pad = (K - 1) * dilation + else: + raise ValueError(f"Unsupported mode {mode!r}: choose 'valid' or 'full'") + + # --- 4) reshape into MLX conv1d form --- + # input: (N=1, H=T, C_in=B) + x_in = x.T[None, :, :] + + # weight: (C_out=B, H_f=K, C_in=1) + w = kernels_flipped[:, :, None] + + # --- 5) run grouped conv1d --- + y = mx.conv1d(x_in, w, stride=stride, padding=pad, dilation=dilation, groups=B) + # y shape: (1, H_out, B) + + # --- 6) return shape (B, H_out) --- + return y[0].T + + return batched_conv1d + + +@mlx_funcify.register(Blockwise) +def funcify_Blockwise(op: Blockwise, node, **kwargs): + # 1) If it's a Conv1d Blockwise, use the custom implementation + if isinstance(op.core_op, Conv1d): + return blockwise_conv1d(op, node, **kwargs) + + # 2) Otherwise, get the core python function for this Blockwise + core_node = op._create_dummy_core_node(node.inputs) + core_f = mlx_funcify(op.core_op, core_node) + + # 3) Determine how many inputs correspond to batch dimensions + n_batch = op.batch_ndim(node) + + # 4) Build in_axes: map only the first n_batch args, keep the rest static + in_axes = tuple(0 if i < n_batch else None for i in range(len(node.inputs))) + + # 5) Handle case where no vectorization is needed + if n_batch == 0 or all(axis is None for axis in in_axes): + # No batch dimensions, just return the core function + def blockwise_fun(*inputs): + return core_f(*inputs) + + return blockwise_fun + + # 6) Vectorize (vmap) with in_axes + blockwise_f = mx.vmap(core_f, in_axes=in_axes) + + # 7) Return the mapped function + def blockwise_fun(*inputs): + return blockwise_f(*inputs) + + return blockwise_fun diff --git a/pytensor/link/mlx/dispatch/core.py b/pytensor/link/mlx/dispatch/core.py new file mode 100644 index 0000000000..f1d960e760 --- /dev/null +++ b/pytensor/link/mlx/dispatch/core.py @@ -0,0 +1,279 @@ +""" +pytensor/link/mlx/dispatch/basic.py +----------------------------------- + +First-cut MLX translations for the most common tensor Ops. + +The structure intentionally follows pytensor's JAX dispatcher so that +once these kernels stabilise they can be optimised further (e.g. fusing +element-wise graphs, adding in-place updates, RNG thinning, etc.). 
+""" + +from __future__ import annotations + +import mlx.core as mx +import numpy as np + +from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.tensor import get_vector_length +from pytensor.tensor.basic import ( + Alloc, + AllocEmpty, + ExtractDiag, + Eye, + Join, + MakeVector, + ScalarFromTensor, + Split, + TensorFromScalar, + Tri, + get_scalar_constant_value, +) +from pytensor.tensor.exceptions import NotScalarConstantError + + +@mlx_funcify.register(Join) +def mlx_funcify_Join(op, **kwargs): + def join(axis, *tensors): + return mx.concatenate(tensors, axis=axis) + + return join + + +@mlx_funcify.register(Split) +def mlx_funcify_Split(op: Split, node, **kwargs): + _, axis_sym, splits_sym = node.inputs + + try: + constant_axis = get_scalar_constant_value(axis_sym) + except NotScalarConstantError: + constant_axis = None + + try: + constant_splits = np.array( + [ + get_scalar_constant_value(splits_sym[i]) + for i in range(get_vector_length(splits_sym)) + ] + ) + except (ValueError, NotScalarConstantError): + constant_splits = None + + def split(x, axis, splits): + # Resolve constants for significant performance improvement (14x speedup) + if constant_axis is not None: + axis = int(constant_axis) + + if constant_splits is not None: + splits = constant_splits + cumsum_splits = np.cumsum(splits[:-1]) + else: + # Dynamic case - use MLX operations + splits_arr = mx.array(splits) + cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist() + + # Validation checks + if len(splits) != op.len_splits: + raise ValueError("Length of 'splits' is not equal to n_splits") + if np.sum(np.asarray(splits)) != x.shape[axis]: + raise ValueError( + "Split sizes do not sum to the input length on the chosen axis." + ) + if np.any(np.asarray(splits) < 0): + raise ValueError("Split sizes cannot be negative.") + + return mx.split(x, cumsum_splits, axis=axis) + + return split + + +@mlx_funcify.register(ExtractDiag) +def mlx_funcify_ExtractDiag(op, **kwargs): + offset, axis1, axis2 = op.offset, op.axis1, op.axis2 + + def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2): + return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) + + return extract_diag + + +@mlx_funcify.register(Eye) +def mlx_funcify_Eye(op, node, **kwargs): + # Extract constants for performance optimization + const_args = [getattr(inp, "data", None) for inp in node.inputs] + dtype = convert_dtype_to_mlx(op.dtype) + + def eye(*args): + # Replace args with compile-time constants when available for better performance + args = [ + arg if const_a is None else const_a + for arg, const_a in zip(args, const_args, strict=True) + ] + N, M, k = args + return mx.eye(int(N), int(M), int(k), dtype=dtype) + + return eye + + +def convert_dtype_to_mlx(dtype_str, auto_cast_unsupported=True): + """Convert PyTensor dtype strings to MLX dtype objects. + + MLX expects dtype objects rather than string literals for type conversion. + This function maps common dtype strings to their MLX equivalents. 
+ + Parameters + ---------- + dtype_str : str or MLX dtype + The dtype to convert + auto_cast_unsupported : bool + If True, automatically cast unsupported dtypes to supported ones with warnings + + Returns + ------- + MLX dtype object + """ + import warnings + + if isinstance(dtype_str, str): + if dtype_str == "bool": + return mx.bool_ + elif dtype_str == "int8": + return mx.int8 + elif dtype_str == "int16": + return mx.int16 + elif dtype_str == "int32": + return mx.int32 + elif dtype_str == "int64": + return mx.int64 + elif dtype_str == "uint8": + return mx.uint8 + elif dtype_str == "uint16": + return mx.uint16 + elif dtype_str == "uint32": + return mx.uint32 + elif dtype_str == "uint64": + return mx.uint64 + elif dtype_str == "float16": + return mx.float16 + elif dtype_str == "float32": + return mx.float32 + elif dtype_str == "float64": + if auto_cast_unsupported: + warnings.warn( + "MLX does not support float64 on GPU. Automatically casting to float32. " + "This may result in reduced precision. To avoid this warning, " + "explicitly use float32 in your code or set floatX='float32' in PyTensor config.", + UserWarning, + stacklevel=3, + ) + return mx.float32 + else: + return mx.float64 + elif dtype_str == "bfloat16": + return mx.bfloat16 + elif dtype_str == "complex64": + return mx.complex64 + elif dtype_str == "complex128": + if auto_cast_unsupported: + warnings.warn( + "MLX does not support complex128. Automatically casting to complex64. " + "This may result in reduced precision. To avoid this warning, " + "explicitly use complex64 in your code.", + UserWarning, + stacklevel=3, + ) + return mx.complex64 + else: + # Return the original even though it might fail + # This allows users to opt out of auto-casting if needed + return mx.complex64 # MLX doesn't have complex128, so fallback + # Return as is if it's already an MLX dtype or not a recognized string + return dtype_str + + +@mlx_funcify.register(MakeVector) +def mlx_funcify_MakeVector(op, **kwargs): + dtype = convert_dtype_to_mlx(op.dtype) + + def makevector(*x): + return mx.array(x, dtype=dtype) + + return makevector + + +@mlx_funcify.register(TensorFromScalar) +def mlx_funcify_TensorFromScalar(op, **kwargs): + def tensor_from_scalar(x): + return x # already an MLX array / scalar + + return tensor_from_scalar + + +@mlx_funcify.register(ScalarFromTensor) +def mlx_funcify_ScalarFromTensor(op, **kwargs): + def scalar_from_tensor(x): + "We can't not return a scalar in MLX without trigger evaluation" + return x + + return scalar_from_tensor + + +@mlx_funcify.register(Tri) +def mlx_funcify_Tri(op, node, **kwargs): + # node.inputs -> N, M, k + const_args = [getattr(inp, "data", None) for inp in node.inputs] + dtype = convert_dtype_to_mlx(op.dtype) + + def tri(*args): + # Replace args with compile-time constants when available + args = [ + arg if const_a is None else const_a + for arg, const_a in zip(args, const_args, strict=True) + ] + return mx.tri(*args, dtype=dtype) + + return tri + + +@mlx_funcify.register(AllocEmpty) +def mlx_funcify_AllocEmpty(op, **kwargs): + dtype = convert_dtype_to_mlx(op.dtype) + + def allocempty(*shape): + return mx.zeros(shape, dtype=dtype) + + return allocempty + + +@mlx_funcify.register(Alloc) +def mlx_funcify_Alloc(op, node, **kwargs): + def alloc(x, *shape): + try: + # Convert shape elements to Python ints for MLX compatibility + # MLX requires shape dimensions to be Python integers, not MLX arrays + shape_ints = tuple( + int(s.item()) if hasattr(s, "item") else int(s) for s in shape + ) + return 
mx.broadcast_to(x, shape_ints) + except ValueError as e: + if ( + "[eval] Attempting to eval an array during function transformations" + in str(e) + ): + # This is the MLX compilation limitation - provide helpful error + raise ValueError( + "MLX compilation limitation: Alloc operations with dynamic shapes " + "cannot be used inside compiled functions. This is because MLX " + "compilation forbids evaluating arrays to extract shape values. " + # Just a note! TODO: remove this once we have a better solution + "\n\nWorkarounds:" + "\n1. Avoid using Alloc with dynamic shapes in compiled contexts" + "\n2. Use static shapes when possible" + "\n3. Move Alloc operations outside compiled functions" + "\n\nOriginal error: " + str(e) + ) from e + else: + # Re-raise other ValueError exceptions + raise + + return alloc diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py new file mode 100644 index 0000000000..0bbc98cf81 --- /dev/null +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -0,0 +1,173 @@ +from functools import singledispatch + +import mlx.core as mx +import numpy as np + +from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx +from pytensor.scalar import Softplus +from pytensor.scalar.basic import ( + AND, + OR, + Add, + Cast, + Mul, + ScalarMaximum, + ScalarMinimum, +) +from pytensor.tensor.elemwise import CAReduce, DimShuffle +from pytensor.tensor.special import Softmax, SoftmaxGrad + + +@mlx_funcify.register(DimShuffle) +def mlx_funcify_DimShuffle(op, **kwargs): + def dimshuffle(x): + # Convert scalar to array if needed + if isinstance(x, int | float) or ( + isinstance(x, np.number) and not isinstance(x, np.ndarray) + ): + x = mx.array(x) + res = mx.transpose(x, op.transposition) + shape = list(res.shape[: len(op.shuffle)]) + for augm in op.augment: + shape.insert(augm, 1) + return mx.reshape(res, shape) + + return dimshuffle + + +# Second-level dispatch for scalar operations in CAReduce +@singledispatch +def mlx_funcify_CAReduce_scalar_op(scalar_op): + raise NotImplementedError( + f"MLX does not support CAReduce with scalar op {scalar_op}" + ) + + +@mlx_funcify.register(CAReduce) +def mlx_funcify_CAReduce(op, **kwargs): + # Dispatch to the appropriate scalar op handler + scalar_reduce_fn = mlx_funcify_CAReduce_scalar_op(op.scalar_op) + axis = op.axis + + def reduce(x): + return scalar_reduce_fn(x, axis) + + return reduce + + +@mlx_funcify_CAReduce_scalar_op.register(Add) +def _(scalar_op): + def sum_reduce(x, axis): + return mx.sum(x, axis=axis) + + return sum_reduce + + +@mlx_funcify_CAReduce_scalar_op.register(Mul) +def _(scalar_op): + def prod_reduce(x, axis): + return mx.prod(x, axis=axis) + + return prod_reduce + + +@mlx_funcify_CAReduce_scalar_op.register(AND) +def _(scalar_op): + def all_reduce(x, axis): + return x.all(axis=axis) + + return all_reduce + + +@mlx_funcify_CAReduce_scalar_op.register(OR) +def _(scalar_op): + def any_reduce(x, axis): + return mx.any(x, axis=axis) + + return any_reduce + + +@mlx_funcify_CAReduce_scalar_op.register(ScalarMaximum) +def _(scalar_op): + def max_reduce(x, axis): + return mx.max(x, axis=axis) + + return max_reduce + + +@mlx_funcify_CAReduce_scalar_op.register(ScalarMinimum) +def _(scalar_op): + def min_reduce(x, axis): + return mx.min(x, axis=axis) + + return min_reduce + + +@mlx_funcify.register(Softmax) +def mlx_funcify_Softmax(op, **kwargs): + axis = op.axis + + def softmax(x): + return mx.softmax(x, axis=axis) + + return softmax + + 
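(Illustrative aside, not part of the patch: the two-level dispatch above means an ordinary PyTensor reduction compiled with mode="MLX" should route its scalar op to the matching MLX reduction, e.g. Add -> mx.sum and Mul -> mx.prod. A minimal usage sketch, assuming mlx is installed and this backend is importable:

    import numpy as np
    import pytensor.tensor as pt
    from pytensor import function

    x = pt.matrix("x", dtype="float32")
    # Sum/Prod are CAReduce subclasses, so they hit the registrations above
    f = function([x], [x.sum(axis=0), x.prod(axis=1)], mode="MLX")

    x_val = np.random.default_rng(0).normal(size=(3, 4)).astype("float32")
    sum_res, prod_res = f(x_val)
    np.testing.assert_allclose(np.asarray(sum_res), x_val.sum(axis=0), rtol=1e-5)
    np.testing.assert_allclose(np.asarray(prod_res), x_val.prod(axis=1), rtol=1e-5)

End of aside.)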
+@mlx_funcify.register(SoftmaxGrad) +def mlx_funcify_SoftmaxGrad(op, **kwargs): + axis = op.axis + + def softmax_grad(dy, sm): + dy_times_sm = dy * sm + return dy_times_sm - mx.sum(dy_times_sm, axis=axis, keepdims=True) * sm + + return softmax_grad + + +@mlx_funcify.register(Softplus) +def mlx_funcify_Softplus(op, **kwargs): + def softplus(x): + return mx.where( + x < -37.0, + mx.exp(x), + mx.where( + x < 18.0, + mx.log1p(mx.exp(x)), + mx.where( + x < 33.3, + x + mx.exp(-x), + x, + ), + ), + ) + + return softplus + + +@mlx_funcify.register(Cast) +def mlx_funcify_Cast(op, **kwargs): + def cast(x): + dtype = convert_dtype_to_mlx(op.scalar_op.o_type.dtype) + try: + return x.astype(dtype) + except ValueError as e: + if "is not supported on the GPU" in str(e): + # MLX GPU limitation - try auto-casting with warning + import warnings + + warnings.warn( + f"MLX GPU limitation: {e}. Attempting automatic fallback casting.", + UserWarning, + stacklevel=2, + ) + # Get the auto-cast version + fallback_dtype = convert_dtype_to_mlx( + op.scalar_op.o_type.dtype, auto_cast_unsupported=True + ) + return x.astype(fallback_dtype) + else: + # Re-raise other ValueError exceptions + raise + + return cast diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py new file mode 100644 index 0000000000..9c7f94d491 --- /dev/null +++ b/pytensor/link/mlx/dispatch/math.py @@ -0,0 +1,394 @@ +from functools import singledispatch + +import mlx.core as mx + +from pytensor.link.mlx.dispatch import mlx_funcify +from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx +from pytensor.scalar.basic import ( + AND, + EQ, + GE, + GT, + LE, + LT, + NEQ, + OR, + Abs, + Add, + Cast, + Cos, + Exp, + IntDiv, + Invert, + IsNan, + Log, + Log1p, + Mul, + Neg, + Pow, + ScalarMaximum, + ScalarMinimum, + Sign, + Sin, + Sqr, + Sqrt, + Sub, + Switch, + TrueDiv, +) +from pytensor.scalar.math import Erfc, Erfcx, Sigmoid, Softplus +from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.math import Dot + + +@mlx_funcify.register(Dot) +def mlx_funcify_Dot(op, node=None, **kwargs): + def dot(x, y): + return mx.matmul(x, y) + + return dot + + +# Second-level dispatch for scalar operations in Elemwise +@singledispatch +def mlx_funcify_Elemwise_scalar_op(scalar_op): + """Simplified implementation for MLX scalar operations.""" + + # Try using the operation name directly (most common case) + op_name = getattr(scalar_op, "name", None) + if op_name is not None: + try: + mlx_func = getattr(mx, op_name) + # Handle variadic functions like Add + if hasattr(scalar_op, "inputs") and len(scalar_op.inputs) > 2: + + def variadic_func(*args): + result = args[0] + for arg in args[1:]: + result = mlx_func(result, arg) + return result + + return variadic_func + else: + return mlx_func + except AttributeError: + pass + + raise NotImplementedError(f"MLX does not support Elemwise scalar op {scalar_op}") + + +@mlx_funcify_Elemwise_scalar_op.register(Add) +def _(scalar_op): + def add(*args): + result = args[0] + for arg in args[1:]: + result = mx.add(result, arg) + return result + + return add + + +@mlx_funcify_Elemwise_scalar_op.register(Sub) +def _(scalar_op): + def sub(x, y): + return mx.subtract(x, y) + + return sub + + +@mlx_funcify_Elemwise_scalar_op.register(Mul) +def _(scalar_op): + def mul(*args): + result = args[0] + for arg in args[1:]: + result = mx.multiply(result, arg) + return result + + return mul + + +@mlx_funcify_Elemwise_scalar_op.register(TrueDiv) +def _(scalar_op): + def true_div(x, y): + return 
mx.divide(x, y) + + return true_div + + +@mlx_funcify_Elemwise_scalar_op.register(IntDiv) +def _(scalar_op): + def int_div(x, y): + return mx.floor_divide(x, y) + + return int_div + + +@mlx_funcify_Elemwise_scalar_op.register(Pow) +def _(scalar_op): + def pow(x, y): + return mx.power(x, y) + + return pow + + +@mlx_funcify_Elemwise_scalar_op.register(Exp) +def _(scalar_op): + def exp(x): + return mx.exp(x) + + return exp + + +@mlx_funcify_Elemwise_scalar_op.register(Log) +def _(scalar_op): + def log(x): + return mx.log(x) + + return log + + +@mlx_funcify_Elemwise_scalar_op.register(Log1p) +def _(scalar_op): + def log1p(x): + return mx.log1p(x) + + return log1p + + +@mlx_funcify_Elemwise_scalar_op.register(Sin) +def _(scalar_op): + def sin(x): + return mx.sin(x) + + return sin + + +@mlx_funcify_Elemwise_scalar_op.register(Cos) +def _(scalar_op): + def cos(x): + return mx.cos(x) + + return cos + + +@mlx_funcify_Elemwise_scalar_op.register(Sqrt) +def _(scalar_op): + def sqrt(x): + return mx.sqrt(x) + + return sqrt + + +@mlx_funcify_Elemwise_scalar_op.register(Sqr) +def _(scalar_op): + def sqr(x): + return mx.square(x) + + return sqr + + +@mlx_funcify_Elemwise_scalar_op.register(Abs) +def _(scalar_op): + def abs(x): + return mx.abs(x) + + return abs + + +@mlx_funcify_Elemwise_scalar_op.register(Neg) +def _(scalar_op): + def neg(x): + return mx.negative(x) + + return neg + + +@mlx_funcify_Elemwise_scalar_op.register(Sign) +def _(scalar_op): + def sign(x): + return mx.sign(x) + + return sign + + +@mlx_funcify_Elemwise_scalar_op.register(LE) +def _(scalar_op): + def le(x, y): + return mx.less_equal(x, y) + + return le + + +@mlx_funcify_Elemwise_scalar_op.register(LT) +def _(scalar_op): + def lt(x, y): + return mx.less(x, y) + + return lt + + +@mlx_funcify_Elemwise_scalar_op.register(GE) +def _(scalar_op): + def ge(x, y): + return mx.greater_equal(x, y) + + return ge + + +@mlx_funcify_Elemwise_scalar_op.register(GT) +def _(scalar_op): + def gt(x, y): + return mx.greater(x, y) + + return gt + + +@mlx_funcify_Elemwise_scalar_op.register(EQ) +def _(scalar_op): + def eq(x, y): + return mx.equal(x, y) + + return eq + + +@mlx_funcify_Elemwise_scalar_op.register(NEQ) +def _(scalar_op): + def neq(x, y): + return mx.not_equal(x, y) + + return neq + + +@mlx_funcify_Elemwise_scalar_op.register(Switch) +def _(scalar_op): + def switch(cond, x, y): + return mx.where(cond, x, y) + + return switch + + +@mlx_funcify_Elemwise_scalar_op.register(AND) +def _(scalar_op): + def bitwise_and(x, y): + return mx.bitwise_and(x, y) + + return bitwise_and + + +@mlx_funcify_Elemwise_scalar_op.register(OR) +def _(scalar_op): + def bitwise_or(x, y): + return mx.bitwise_or(x, y) + + return bitwise_or + + +@mlx_funcify_Elemwise_scalar_op.register(ScalarMaximum) +def _(scalar_op): + def maximum(x, y): + return mx.maximum(x, y) + + return maximum + + +@mlx_funcify_Elemwise_scalar_op.register(ScalarMinimum) +def _(scalar_op): + def minimum(x, y): + return mx.minimum(x, y) + + return minimum + + +@mlx_funcify_Elemwise_scalar_op.register(Cast) +def _(scalar_op): + def cast(x): + dtype = convert_dtype_to_mlx(scalar_op.o_type.dtype) + try: + return x.astype(dtype) + except ValueError as e: + if "is not supported on the GPU" in str(e): + # MLX GPU limitation - try auto-casting with warning + import warnings + + warnings.warn( + f"MLX GPU limitation: {e}. 
Attempting automatic fallback casting.", + UserWarning, + stacklevel=2, + ) + # Get the auto-cast version + fallback_dtype = convert_dtype_to_mlx( + scalar_op.o_type.dtype, auto_cast_unsupported=True + ) + return x.astype(fallback_dtype) + else: + # Re-raise other ValueError exceptions + raise + + return cast + + +@mlx_funcify_Elemwise_scalar_op.register(Sigmoid) +def _(scalar_op): + def sigmoid(x): + return mx.sigmoid(x) + + return sigmoid + + +@mlx_funcify_Elemwise_scalar_op.register(Invert) +def _(scalar_op): + def invert(x): + return mx.bitwise_invert(x) + + return invert + + +@mlx_funcify_Elemwise_scalar_op.register(IsNan) +def _(scalar_op): + def isnan(x): + return mx.isnan(x) + + return isnan + + +@mlx_funcify_Elemwise_scalar_op.register(Erfc) +def _(scalar_op): + def erfc(x): + return 1.0 - mx.erf(x) + + return erfc + + +@mlx_funcify_Elemwise_scalar_op.register(Erfcx) +def _(scalar_op): + def erfcx(x): + return mx.exp(x * x) * (1.0 - mx.erf(x)) + + return erfcx + + +@mlx_funcify_Elemwise_scalar_op.register(Softplus) +def _(scalar_op): + def softplus(x): + # Numerically stable implementation of log(1 + exp(x)) + # Following the same logic as the original PyTensor implementation + return mx.where( + x < -37.0, + mx.exp(x), + mx.where( + x < 18.0, mx.log1p(mx.exp(x)), mx.where(x < 33.3, x + mx.exp(-x), x) + ), + ) + + return softplus + + +@mlx_funcify.register(Elemwise) +def mlx_funcify_Elemwise(op, node, **kwargs): + # Dispatch to the appropriate scalar op handler + scalar_func = mlx_funcify_Elemwise_scalar_op(op.scalar_op) + + def elemwise(*inputs): + return scalar_func(*inputs) + + return elemwise diff --git a/pytensor/link/mlx/dispatch/shape.py b/pytensor/link/mlx/dispatch/shape.py new file mode 100644 index 0000000000..8e530d468d --- /dev/null +++ b/pytensor/link/mlx/dispatch/shape.py @@ -0,0 +1,42 @@ +import mlx.core as mx + +from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape + + +@mlx_funcify.register(Shape) +def mlx_funcify_Shape(op, **kwargs): + def shape(x): + return mx.array(x.shape, dtype=mx.int64) + + return shape + + +@mlx_funcify.register(SpecifyShape) +def mlx_funcify_SpecifyShape(op, node, **kwargs): + def specifyshape(x, *shape): + assert x.ndim == len(shape) + for actual, expected in zip(x.shape, shape, strict=True): + if expected is None: + continue + if actual != expected: + raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}") + return x + + return specifyshape + + +@mlx_funcify.register(Shape_i) +def mlx_funcify_Shape_i(op, node, **kwargs): + def shape_i(x): + return x.shape[op.i] + + return shape_i + + +@mlx_funcify.register(Reshape) +def mlx_funcify_Reshape(op, **kwargs): + def reshape(x, shp): + return mx.reshape(x, shp) + + return reshape diff --git a/pytensor/link/mlx/dispatch/signal/__init__.py b/pytensor/link/mlx/dispatch/signal/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pytensor/link/mlx/dispatch/signal/conv.py b/pytensor/link/mlx/dispatch/signal/conv.py new file mode 100644 index 0000000000..8f84ebb42f --- /dev/null +++ b/pytensor/link/mlx/dispatch/signal/conv.py @@ -0,0 +1,14 @@ +import mlx.core as mx + +from pytensor.link.mlx.dispatch import mlx_funcify +from pytensor.tensor.signal.conv import Conv1d + + +@mlx_funcify.register(Conv1d) +def mlx_funcify_Conv1d(op, node=None, **kwargs): + mode = op.mode + + def conv1d(data, kernel): + return mx.convolve(data, kernel, mode=mode) + + return conv1d diff --git 
a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py new file mode 100644 index 0000000000..ce14d08246 --- /dev/null +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -0,0 +1,105 @@ +from copy import deepcopy + +from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.tensor.subtensor import ( + AdvancedIncSubtensor, + AdvancedIncSubtensor1, + AdvancedSubtensor, + AdvancedSubtensor1, + IncSubtensor, + Subtensor, + indices_from_subtensor, +) +from pytensor.tensor.type_other import MakeSlice + + +@mlx_funcify.register(Subtensor) +def mlx_funcify_Subtensor(op, node, **kwargs): + idx_list = getattr(op, "idx_list", None) + + def subtensor(x, *ilists): + indices = indices_from_subtensor([int(element) for element in ilists], idx_list) + if len(indices) == 1: + indices = indices[0] + + return x.__getitem__(indices) + + return subtensor + + +@mlx_funcify.register(AdvancedSubtensor) +@mlx_funcify.register(AdvancedSubtensor1) +def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): + idx_list = getattr(op, "idx_list", None) + + def advanced_subtensor(x, *ilists): + indices = indices_from_subtensor(ilists, idx_list) + if len(indices) == 1: + indices = indices[0] + + return x.__getitem__(indices) + + return advanced_subtensor + + +@mlx_funcify.register(IncSubtensor) +@mlx_funcify.register(AdvancedIncSubtensor1) +def mlx_funcify_IncSubtensor(op, node, **kwargs): + idx_list = getattr(op, "idx_list", None) + + if getattr(op, "set_instead_of_inc", False): + + def mlx_fn(x, indices, y): + if not op.inplace: + x = deepcopy(x) + x[indices] = y + return x + + else: + + def mlx_fn(x, indices, y): + if not op.inplace: + x = deepcopy(x) + x[indices] += y + return x + + def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): + indices = indices_from_subtensor(ilist, idx_list) + if len(indices) == 1: + indices = indices[0] + + return mlx_fn(x, indices, y) + + return incsubtensor + + +@mlx_funcify.register(AdvancedIncSubtensor) +def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): + if getattr(op, "set_instead_of_inc", False): + + def mlx_fn(x, indices, y): + if not op.inplace: + x = deepcopy(x) + x[indices] = y + return x + + else: + + def mlx_fn(x, indices, y): + if not op.inplace: + x = deepcopy(x) + x[indices] += y + return x + + def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn): + return mlx_fn(x, ilist, y) + + return advancedincsubtensor + + +@mlx_funcify.register(MakeSlice) +def mlx_funcify_MakeSlice(op, **kwargs): + def makeslice(*x): + return slice(*x) + + return makeslice diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py new file mode 100644 index 0000000000..9a4d1ac2c1 --- /dev/null +++ b/pytensor/link/mlx/linker.py @@ -0,0 +1,79 @@ +from pytensor.link.basic import JITLinker + + +class MLXLinker(JITLinker): + """A `Linker` that JIT-compiles NumPy-based operations using Apple's MLX.""" + + def __init__(self, use_compile=True, *args, **kwargs): + super().__init__(*args, **kwargs) + self.gen_functors = [] + self.use_compile = use_compile + + def fgraph_convert(self, fgraph, **kwargs): + """Convert a PyTensor FunctionGraph to an MLX-compatible function. 
+ + Parameters + ---------- + fgraph : FunctionGraph + The function graph to convert + + Returns + ------- + callable + An MLX-compatible function + """ + from pytensor.link.mlx.dispatch import mlx_funcify + + return mlx_funcify( + fgraph, + **kwargs, + ) + + def jit_compile(self, fn): + import mlx.core as mx + + from pytensor.link.mlx.dispatch import mlx_typify + + if not self.use_compile: + # Skip compilation and just return the function with MLX typification + def fn_no_compile(*inputs): + return fn(*(mlx_typify(inp) for inp in inputs)) + + return fn_no_compile + + inner_fn = mx.compile(fn) + + def fn(*inputs, inner_fn=inner_fn): + return inner_fn(*(mlx_typify(inp) for inp in inputs)) + + return fn + + def create_thunk_inputs(self, storage_map): + """Create inputs for the MLX thunk. + + Parameters + ---------- + storage_map : dict + Map from variables to their storage + + Returns + ------- + list + The inputs for the thunk + """ + from numpy.random import Generator, RandomState + + from pytensor.link.mlx.dispatch import mlx_typify + + thunk_inputs = [] + for n in self.fgraph.inputs: + sinput = storage_map[n] + # Handle random number generators specially + if isinstance(sinput[0], RandomState | Generator): + new_value = mlx_typify( + sinput[0], dtype=getattr(sinput[0], "dtype", None) + ) + sinput[0] = new_value + thunk_inputs.append(sinput) + + return thunk_inputs diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index b8475e3157..18824a5b71 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -31,13 +31,15 @@ def conversion_func_register(*args, **kwargs): **kwargs, } return pytorch_funcify( - fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs + fgraph, + input_storage=input_storage, + storage_map=storage_map, + **built_kwargs, ) def jit_compile(self, fn): import torch - # flag that tend to help our graphs torch._dynamo.config.capture_dynamic_output_shape_ops = True from pytensor.link.pytorch.dispatch import pytorch_typify diff --git a/tests/link/mlx/test_basic.py b/tests/link/mlx/test_basic.py new file mode 100644 index 0000000000..8d6999e55f --- /dev/null +++ b/tests/link/mlx/test_basic.py @@ -0,0 +1,392 @@ +""" +Basic tests for the MLX backend. 
+""" + +from collections.abc import Callable, Iterable +from functools import partial + +import mlx.core as mx +import numpy as np +import pytest + +import pytensor +from pytensor import tensor as pt +from pytensor.compile.function import function +from pytensor.compile.mode import MLX, Mode +from pytensor.graph import RewriteDatabaseQuery +from pytensor.graph.basic import Variable +from pytensor.link.mlx import MLXLinker +from pytensor.link.mlx.dispatch.core import ( + mlx_funcify_Alloc, +) +from pytensor.tensor.basic import Alloc + + +optimizer = RewriteDatabaseQuery(include=["mlx"], exclude=MLX._optimizer.exclude) +mlx_mode = Mode(linker=MLXLinker(), optimizer=optimizer) +mlx_mode_no_compile = Mode(linker=MLXLinker(use_compile=False), optimizer=optimizer) +compile_mode = Mode(linker=MLXLinker(use_compile=True), optimizer=optimizer) +py_mode = Mode(linker="py", optimizer=None) + + +def compare_mlx_and_py( + graph_inputs: Iterable[Variable], + graph_outputs: Variable | Iterable[Variable], + test_inputs: Iterable, + *, + assert_fn: Callable | None = None, + must_be_device_array: bool = True, + mlx_mode=mlx_mode, + py_mode=py_mode, +): + """Function to compare python function output and mlx compiled output for testing equality + + The inputs and outputs are then passed to this function which then compiles the given function in both + mlx and python, runs the calculation in both and checks if the results are the same + + Parameters + ---------- + graph_inputs: + Symbolic inputs to the graph + outputs: + Symbolic outputs of the graph + test_inputs: iter + Numerical inputs for testing the function. + assert_fn: func, opt + Assert function used to check for equality between python and mlx. If not + provided uses np.testing.assert_allclose + must_be_device_array: Bool + Checks for instance of jax.interpreters.xla.DeviceArray. 
For testing purposes + if this device array is found it indicates if the result was computed by jax + + Returns + ------- + mlx_res + + """ + if assert_fn is None: + assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) + + if any(inp.owner is not None for inp in graph_inputs): + raise ValueError("Inputs must be root variables") + + pytensor_mlx_fn = function(graph_inputs, graph_outputs, mode=mlx_mode) + mlx_res = pytensor_mlx_fn(*test_inputs) + + if must_be_device_array: + if isinstance(mlx_res, list): + assert all(isinstance(res, mx.array) for res in mlx_res) + else: + assert isinstance(mlx_res, mx.array) + + pytensor_py_fn = function(graph_inputs, graph_outputs, mode=py_mode) + py_res = pytensor_py_fn(*test_inputs) + + if isinstance(graph_outputs, list | tuple): + for j, p in zip(mlx_res, py_res, strict=True): + assert_fn(j, p) + else: + assert_fn(mlx_res, py_res) + + return pytensor_mlx_fn, mlx_res + + +def test_scalar_from_tensor_matrix_indexing(): + """Test ScalarFromTensor with matrix element extraction.""" + # Matrix element extraction is a common real-world scenario + matrix = pt.matrix("matrix", dtype="float32") + element = matrix[0, 0] # Creates 0-d tensor + + f = pytensor.function([matrix], element, mode="MLX") + + test_matrix = np.array([[42.0, 1.0], [2.0, 3.0]], dtype=np.float32) + result = f(test_matrix) + + assert float(result) == 42.0 + assert isinstance(result, mx.array) + + +def test_scalar_from_tensor_reduction_operations(): + """Test ScalarFromTensor with reduction operations that produce scalars.""" + # Test vector sum reduction + vector = pt.vector("vector", dtype="float32") + sum_result = pt.sum(vector) + + f = pytensor.function([vector], sum_result, mode="MLX") + test_vector = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32) + result = f(test_vector) + + assert float(result) == 10.0 + + # Test matrix mean reduction + matrix = pt.matrix("matrix", dtype="float32") + mean_result = pt.mean(matrix) + + f2 = pytensor.function([matrix], mean_result, mode="MLX") + test_matrix = np.array([[2.0, 4.0], [6.0, 8.0]], dtype=np.float32) + result = f2(test_matrix) + + assert float(result) == 5.0 + + +def test_scalar_from_tensor_conditional_operations(): + """Test ScalarFromTensor with conditional operations.""" + x = pt.scalar("x", dtype="float32") + y = pt.scalar("y", dtype="float32") + + # Switch operation may create 0-d tensors + max_val = pt.switch(x > y, x, y) + + f = pytensor.function([x, y], max_val, mode="MLX") + + # Test both branches + result1 = f(5.0, 3.0) + assert float(result1) == 5.0 + + result2 = f(2.0, 7.0) + assert float(result2) == 7.0 + + +def test_scalar_from_tensor_multiple_dtypes(): + """Test ScalarFromTensor with different data types.""" + # Test different dtypes that might require scalar extraction + for dtype in ["float32", "int32", "int64"]: + x = pt.vector("x", dtype=dtype) + # Use max reduction to create 0-d tensor + max_val = pt.max(x) + + f = pytensor.function([x], max_val, mode="MLX", allow_input_downcast=True) + + if dtype.startswith("float"): + test_data = np.array([1.5, 3.7, 2.1], dtype=dtype) + expected = 3.7 + else: + test_data = np.array([10, 30, 20], dtype=dtype) + expected = 30 + + result = f(test_data) + assert abs(float(result) - expected) < 1e-5 + + +def test_scalar_from_tensor_pytensor_integration(): + """Test ScalarFromTensor in a complete PyTensor graph context. 
+ + This test uses symbolic variables (not constants) to ensure the MLX backend + actually executes the ScalarFromTensor operation rather than having it + optimized away during compilation. + """ + # Create a symbolic scalar input to actually test MLX execution + x = pt.scalar("x", dtype="int64") + + # Apply ScalarFromTensor - this creates a graph that forces execution + scalar_result = pt.scalar_from_tensor(x) + + # Create function and test with actual MLX backend execution + f = pytensor.function([x], scalar_result, mode="MLX") + result = f(42) + + assert result == 42 + assert isinstance(result, mx.array) + + +def test_alloc_with_different_shape_types(): + """Test Alloc works with different types of shape parameters. + + This addresses the TypeError that occurred when shape parameters + contained MLX arrays instead of Python integers. + """ + + # Create a mock node (we don't need a real node for this test) + class MockNode: + pass + + alloc_func = mlx_funcify_Alloc(Alloc(), MockNode()) + x = mx.array(5.0) + + # Test with Python ints + result = alloc_func(x, 3, 4) + assert result.shape == (3, 4) + assert float(result[0, 0]) == 5.0 + + # Test with MLX arrays (this used to fail) + result = alloc_func(x, mx.array(3), mx.array(4)) + assert result.shape == (3, 4) + assert float(result[0, 0]) == 5.0 + + # Test with mixed types + result = alloc_func(x, 3, mx.array(4)) + assert result.shape == (3, 4) + assert float(result[0, 0]) == 5.0 + + +def test_alloc_pytensor_integration(): + """Test Alloc in a PyTensor graph context.""" + # Test basic constant shape allocation + x = pt.scalar("x", dtype="float32") + result = pt.alloc(x, 3, 4) + + f = pytensor.function([x], result, mode="MLX") + output = f(5.0) + + assert output.shape == (3, 4) + assert float(output[0, 0]) == 5.0 + + +def test_alloc_compilation_limitation(): + """Test that Alloc operations with dynamic shapes provide helpful error in compiled contexts.""" + + # Create variables + x = pt.scalar("x", dtype="float32") + s1 = pt.scalar("s1", dtype="int64") + s2 = pt.scalar("s2", dtype="int64") + + # Create Alloc operation with dynamic shapes + result = pt.alloc(x, s1, s2) + + # Create function with non-compiled MLX mode + f = pytensor.function([x, s1, s2], result, mode=mlx_mode_no_compile) + + # Test that it works with concrete values (non-compiled context) + output = f(5.0, 3, 4) + assert output.shape == (3, 4) + assert np.allclose(output, 5.0) + + # Test that compilation fails with helpful error + compiled_f = pytensor.function([x, s1, s2], result, mode=compile_mode) + + with pytest.raises(ValueError) as exc_info: + compiled_f(5.0, 3, 4) + + error_msg = str(exc_info.value) + assert "MLX compilation limitation" in error_msg + assert "Alloc operations with dynamic shapes" in error_msg + assert "cannot be used inside compiled functions" in error_msg + assert "Workarounds:" in error_msg + assert "Avoid using Alloc with dynamic shapes in compiled contexts" in error_msg + assert "Use static shapes when possible" in error_msg + assert "Move Alloc operations outside compiled functions" in error_msg + + +def test_alloc_static_shapes_compilation(): + """Test that Alloc operations with static shapes work fine in compiled contexts.""" + # Create a scenario with static shapes that should work + x = pt.scalar("x", dtype="float32") + + # Use constant shape - this should work even in compilation + result = pt.alloc(x, 3, 4) # Static shapes + + # Test both compiled and non-compiled modes + f_normal = pytensor.function([x], result, mode=mlx_mode_no_compile) + 
f_compiled = pytensor.function([x], result, mode=compile_mode) + + # Both should work + output_normal = f_normal(5.0) + output_compiled = f_compiled(5.0) + + assert output_normal.shape == (3, 4) + assert output_compiled.shape == (3, 4) + assert np.allclose(output_normal, 5.0) + assert np.allclose(output_compiled, 5.0) + assert np.allclose(output_normal, output_compiled) + + +def test_mlx_float64_auto_casting(): + """Test MLX automatic casting of float64 to float32 with warnings.""" + import warnings + + # Test 1: Direct Cast operation with warning + x = pt.scalar("x", dtype="float32") + y = pt.cast(x, "float64") + + # Capture warnings + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always") + + f = pytensor.function([x], y, mode=mlx_mode, allow_input_downcast=True) + result = f(3.14) + + # Check that the operation succeeded + assert result.dtype == mx.float32 # Should be auto-cast to float32 + assert abs(float(result) - 3.14) < 1e-6 + + # Check that a warning was issued + warning_messages = [str(w.message) for w in warning_list] + dtype_warnings = [ + msg for msg in warning_messages if "float64" in msg and "float32" in msg + ] + assert ( + len(dtype_warnings) > 0 + ), f"Expected dtype warning, got warnings: {warning_messages}" + + +def test_mlx_float64_complex_operations(): + """Test float64 casting in more complex operations.""" + import warnings + + # Test with vector operations + x = pt.vector("x", dtype="float32") + y = pt.cast(x, "float64") + z = pt.exp(y) + pt.sin(y) # Multiple operations on float64 + + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always") + + f = pytensor.function([x], z, mode=mlx_mode, allow_input_downcast=True) + result = f([1.0, 2.0, 3.0]) + + # Should work and return float32 results + assert result.dtype == mx.float32 + assert result.shape == (3,) + + # Should have issued warnings + warning_messages = [str(w.message) for w in warning_list] + dtype_warnings = [ + msg + for msg in warning_messages + if "float64" in msg or "MLX GPU limitation" in msg + ] + assert len(dtype_warnings) > 0 + + +def test_mlx_float64_no_warning_when_disabled(): + """Test that auto-casting can be controlled.""" + import warnings + + from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx + + # Test that we can disable auto-casting + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always") + + # This should not issue warnings when auto_cast_unsupported=False + dtype = convert_dtype_to_mlx("float64", auto_cast_unsupported=False) + assert dtype == mx.float64 # Should return the original dtype + + # No warnings should be issued for proactive conversion when disabled + dtype_warnings = [ + str(w.message) for w in warning_list if "float64" in str(w.message) + ] + assert len(dtype_warnings) == 0 + + +def test_mlx_complex128_auto_casting(): + """Test automatic casting of complex128 to complex64.""" + import warnings + + from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx + + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always") + + # This should trigger a warning and return complex64 + dtype = convert_dtype_to_mlx("complex128", auto_cast_unsupported=True) + assert dtype == mx.complex64 + + # Should have issued a warning + warning_messages = [str(w.message) for w in warning_list] + complex_warnings = [ + msg + for msg in warning_messages + if "complex128" in msg and "complex64" in msg + ] + assert len(complex_warnings) > 0 
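The float64/complex128 tests above only pin down the observable behaviour of convert_dtype_to_mlx: unsupported dtypes are downcast with a warning, or passed through when auto_cast_unsupported=False. A minimal sketch of a helper with exactly that behaviour is shown below for orientation; the table name, warning wording, and function name suffix are illustrative assumptions, not the actual implementation in pytensor/link/mlx/dispatch/core.py.

import warnings

import mlx.core as mx

# Illustrative sketch (not the PR's implementation): dtypes MLX GPUs cannot
# handle are downcast to their nearest supported equivalent, with a warning.
_DOWNCASTS = {"float64": mx.float32, "complex128": mx.complex64}


def convert_dtype_to_mlx_sketch(dtype_str, auto_cast_unsupported=True):
    if auto_cast_unsupported and dtype_str in _DOWNCASTS:
        target = _DOWNCASTS[dtype_str]
        warnings.warn(
            f"MLX GPU limitation: casting {dtype_str} to {target}.",
            UserWarning,
        )
        return target
    # Everything else maps onto the MLX dtype of the same name, e.g.
    # "float32" -> mx.float32 (a real implementation needs a fuller mapping,
    # e.g. "bool" -> mx.bool_).
    return getattr(mx, dtype_str)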
diff --git a/tests/link/mlx/test_blockwise.py b/tests/link/mlx/test_blockwise.py new file mode 100644 index 0000000000..9b271186c9 --- /dev/null +++ b/tests/link/mlx/test_blockwise.py @@ -0,0 +1,64 @@ +import numpy as np + +import pytensor.tensor as pt +from pytensor.tensor import tensor +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.math import Dot +from tests.link.mlx.test_basic import compare_mlx_and_py + + +# Equivalent blockwise to matmul but with dumb signature +odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)") + + +# @pytest.mark.parametrize("matmul_op", (matmul, odd_matmul)) +# def test_matmul(matmul_op): +# rng = np.random.default_rng(14) +# a = tensor("a", shape=(2, 3, 5)) +# b = tensor("b", shape=(2, 5, 3)) +# test_values = [ +# rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (a, b) +# ] +# +# out = matmul_op(a, b) +# assert isinstance(out.owner.op, Blockwise) +# fn, _ = compare_mlx_and_py([a, b], [out], test_values) +# +## Check we are not adding any unnecessary stuff +# jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values)) +# jaxpr = jaxpr.replace("name=jax_funcified_fgraph", "name=matmul") +# expected_jaxpr = str(jax.make_jaxpr(jax.jit(jax.numpy.matmul))(*test_values)) +# assert jaxpr == expected_jaxpr + + +# conv1d +# (2, 100) +# (8, 100) +# mode = valid + + +def test_blockwise_conv1d(): + rng = np.random.default_rng(14) + a = tensor("a", shape=(2, 100)) + b = tensor("b", shape=(2, 8)) + + # a_test = np.broadcast_to(np.arange(100), (2, 100)) + a_test = rng.normal(size=(2, 100)) + b_test = rng.normal(size=(2, 8)) + # b_test = np.concatenate( + # [ + # np.ones((1, 8)), + # np.zeros((1, 8)), + # np.zeros((1, 8)), + # np.array([1, 0, 0, 0, 0, 0, 0, 0]).reshape(1, 8), + # np.array([1, 0, 0, 0, 0, 0, 0, 0]).reshape(1, 8), + # ], + # axis=0, + # ) + + test_values = [a_test, b_test] + + out = pt.signal.convolve1d(a, b, mode="valid") + + # assert isinstance(out.owner.op, Blockwise) + compare_mlx_and_py([a, b], [out], test_values, must_be_device_array=True) diff --git a/tests/link/mlx/test_elemwise.py b/tests/link/mlx/test_elemwise.py new file mode 100644 index 0000000000..44d0d58dcb --- /dev/null +++ b/tests/link/mlx/test_elemwise.py @@ -0,0 +1,50 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from tests.link.mlx.test_basic import compare_mlx_and_py, mx + + +@pytest.mark.parametrize("op", [pt.any, pt.all, pt.max, pt.min]) +def test_input(op) -> None: + x = pt.vector("x") + out = op(x > 0) + x_test = mx.array([1.0, 2.0, 3.0]) + + compare_mlx_and_py([x], out, [x_test]) + + +def test_elemwise_operations() -> None: + """Test elemwise operations (IntDiv, IsNan, Erfc, Erfcx, Softplus) in elemwise context""" + x = pt.vector("x") + y = pt.vector("y") + + # Test int_div in an elemwise expression + out_int_div = pt.int_div(x, y) + 1 + x_test = mx.array([10.0, 15.0, 20.0]) + y_test = mx.array([3.0, 4.0, 6.0]) + compare_mlx_and_py([x, y], out_int_div, [x_test, y_test]) + + # Test isnan in an elemwise expression + z = pt.vector("z") + out_isnan = pt.isnan(z).astype("float32") * 10 + z_test = mx.array([1.0, np.nan, 3.0]) + compare_mlx_and_py([z], out_isnan, [z_test]) + + # Test erfc in an elemwise expression + w = pt.vector("w") + out_erfc = pt.erfc(w) * 2.0 + w_test = mx.array([0.0, 0.5, 1.0]) + compare_mlx_and_py([w], out_erfc, [w_test]) + + # Test erfcx in an elemwise expression + v = pt.vector("v") + out_erfcx = pt.erfcx(v) + 0.1 + v_test = mx.array([0.0, 1.0, 2.0]) + compare_mlx_and_py([v], out_erfcx, 
[v_test]) + + # Test softplus in an elemwise expression + u = pt.vector("u") + out_softplus = pt.softplus(u) - 0.5 + u_test = mx.array([0.0, 1.0, -1.0]) + compare_mlx_and_py([u], out_softplus, [u_test]) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py new file mode 100644 index 0000000000..d35cb27654 --- /dev/null +++ b/tests/link/mlx/test_math.py @@ -0,0 +1,215 @@ +import numpy as np +import pytest + +import pytensor +import pytensor.tensor as pt +from pytensor.tensor.math import Argmax, Max +from tests.link.mlx.test_basic import compare_mlx_and_py, mx + + +def test_dot(): + x = pt.matrix("x") + y = pt.matrix("y") + + out = x.dot(y) + fn = pytensor.function([x, y], out, mode="MLX") + + seed = sum(map(ord, "test_mlx_dot")) + rng = np.random.default_rng(seed) + + test_x = rng.normal(size=(3, 2)) + test_y = rng.normal(size=(2, 4)) + + actual = fn(test_x, test_y) + assert isinstance(actual, mx.array) + expected = np.dot(test_x, test_y) + np.testing.assert_allclose(actual, expected, rtol=1e-6) + + +@pytest.mark.parametrize( + "op", + [ + pytest.param(pt.exp, id="exp"), + pytest.param(pt.log, id="log"), + pytest.param(pt.sin, id="sin"), + pytest.param(pt.cos, id="cos"), + pytest.param(pt.sigmoid, id="sigmoid"), + ], +) +def test_elemwise_one_input(op) -> None: + x = pt.vector("x") + out = op(x) + x_test = mx.array([1.0, 2.0, 3.0]) + compare_mlx_and_py([x], out, [x_test]) + + +def test_switch() -> None: + x = pt.vector("x") + y = pt.vector("y") + + out = pt.switch(x > 0, y, x) + + x_test = mx.array([-1.0, 2.0, 3.0]) + y_test = mx.array([4.0, 5.0, 6.0]) + + compare_mlx_and_py([x, y], out, [x_test, y_test]) + + +@pytest.mark.parametrize("op", [pt.sum, pt.prod]) +def test_input(op) -> None: + x = pt.vector("x") + y = pt.vector("y") + out = op([x, y, x + y]) + x_test = mx.array([1.0, 2.0, 3.0]) + y_test = mx.array([4.0, 5.0, 6.0]) + compare_mlx_and_py([x, y], out, [x_test, y_test]) + + +@pytest.mark.parametrize( + "op", + [ + pytest.param(pt.add, id="add"), + pytest.param(pt.sub, id="sub"), + pytest.param(pt.mul, id="mul"), + pytest.param(pt.power, id="power"), + pytest.param(pt.le, id="le"), + pytest.param(pt.lt, id="lt"), + pytest.param(pt.ge, id="ge"), + pytest.param(pt.gt, id="gt"), + pytest.param(pt.eq, id="eq"), + pytest.param(pt.neq, id="neq"), + pytest.param(pt.true_div, id="true_div"), + pytest.param(pt.int_div, id="int_div"), + ], +) +def test_elemwise_two_inputs(op) -> None: + x = pt.vector("x") + y = pt.vector("y") + out = op(x, y) + x_test = mx.array([1.0, 2.0, 3.0]) + y_test = mx.array([4.0, 5.0, 6.0]) + compare_mlx_and_py([x, y], out, [x_test, y_test]) + + +def test_int_div_specific() -> None: + """Test integer division with specific test cases""" + x = pt.vector("x") + y = pt.vector("y") + out = pt.int_div(x, y) + + # Test with integers that demonstrate floor division behavior + x_test = mx.array([7.0, 8.0, 9.0, -7.0, -8.0]) + y_test = mx.array([3.0, 3.0, 3.0, 3.0, 3.0]) + + compare_mlx_and_py([x, y], out, [x_test, y_test]) + + +def test_isnan() -> None: + """Test IsNan operation with various inputs including NaN values""" + x = pt.vector("x") + out = pt.isnan(x) + + # Test with mix of normal values, NaN, and infinity + x_test = mx.array([1.0, np.nan, 3.0, np.inf, -np.nan, 0.0, -np.inf]) + + compare_mlx_and_py([x], out, [x_test]) + + +def test_isnan_edge_cases() -> None: + """Test IsNan with edge cases""" + x = pt.scalar("x") + out = pt.isnan(x) + + # Test individual cases + test_cases = [0.0, np.nan, np.inf, -np.inf, 1e-10, 1e10] + + for test_val in test_cases: 
+ x_test = test_val + compare_mlx_and_py([x], out, [x_test]) + + +def test_erfc() -> None: + """Test complementary error function""" + x = pt.vector("x") + out = pt.erfc(x) + + # Test with various values including negative, positive, and zero + x_test = mx.array([0.0, 0.5, 1.0, -0.5, -1.0, 2.0, -2.0, 0.1]) + + compare_mlx_and_py([x], out, [x_test]) + + +def test_erfc_extreme_values() -> None: + """Test erfc with extreme values""" + x = pt.vector("x") + out = pt.erfc(x) + + # Test with larger values where erfc approaches 0 or 2 + x_test = mx.array([-3.0, -2.5, 2.5, 3.0]) + + # Use relaxed tolerance for extreme values due to numerical precision differences + from functools import partial + + relaxed_assert = partial(np.testing.assert_allclose, rtol=1e-3, atol=1e-6) + + compare_mlx_and_py([x], out, [x_test], assert_fn=relaxed_assert) + + +def test_erfcx() -> None: + """Test scaled complementary error function""" + x = pt.vector("x") + out = pt.erfcx(x) + + # Test with positive values where erfcx is most numerically stable + x_test = mx.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5]) + + compare_mlx_and_py([x], out, [x_test]) + + +def test_erfcx_small_values() -> None: + """Test erfcx with small values""" + x = pt.vector("x") + out = pt.erfcx(x) + + # Test with small values + x_test = mx.array([0.001, 0.01, 0.1, 0.2]) + + compare_mlx_and_py([x], out, [x_test]) + + +def test_softplus() -> None: + """Test softplus (log(1 + exp(x))) function""" + x = pt.vector("x") + out = pt.softplus(x) + + # Test with normal range values + x_test = mx.array([0.0, 1.0, 2.0, -1.0, -2.0, 10.0]) + + compare_mlx_and_py([x], out, [x_test]) + + +def test_softplus_extreme_values() -> None: + """Test softplus with extreme values to verify numerical stability""" + x = pt.vector("x") + out = pt.softplus(x) + + # Test with extreme values where different branches of the implementation are used + x_test = mx.array([-40.0, -50.0, 20.0, 30.0, 35.0, 50.0]) + + # Use relaxed tolerance for extreme values due to numerical precision differences + from functools import partial + + relaxed_assert = partial(np.testing.assert_allclose, rtol=1e-4, atol=1e-8) + + compare_mlx_and_py([x], out, [x_test], assert_fn=relaxed_assert) + + +@pytest.mark.xfail(reason="Argmax not implemented yet") +def test_mlx_max_and_argmax(): + # Test that a single output of a multi-output `Op` can be used as input to + # another `Op` + x = pt.dvector() + mx = Max([0])(x) + amx = Argmax([0])(x) + out = mx * amx + compare_mlx_and_py([x], [out], [np.r_[1, 2]]) diff --git a/tests/link/mlx/test_shape.py b/tests/link/mlx/test_shape.py new file mode 100644 index 0000000000..19c3dd220b --- /dev/null +++ b/tests/link/mlx/test_shape.py @@ -0,0 +1,109 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor.compile.ops import DeepCopyOp, ViewOp +from pytensor.configdefaults import config +from pytensor.tensor.shape import Shape, Shape_i, reshape +from pytensor.tensor.type import iscalar, vector +from tests.link.mlx.test_basic import compare_mlx_and_py + + +def test_mlx_shape_ops(): + x_np = np.zeros((20, 3)) + x = Shape()(pt.as_tensor_variable(x_np)) + + compare_mlx_and_py([], [x], [], must_be_device_array=False) + + x = Shape_i(1)(pt.as_tensor_variable(x_np)) + + compare_mlx_and_py([], [x], [], must_be_device_array=False) + + +def test_mlx_specify_shape(): + in_pt = pt.matrix("in") + x = pt.specify_shape(in_pt, (4, None)) + compare_mlx_and_py([in_pt], [x], [np.ones((4, 5)).astype(config.floatX)]) + + # When used to assert two arrays have similar shapes 
+ in_pt = pt.matrix("in") + shape_pt = pt.matrix("shape") + x = pt.specify_shape(in_pt, shape_pt.shape) + + compare_mlx_and_py( + [in_pt, shape_pt], + [x], + [np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)], + ) + + +def test_mlx_Reshape_constant(): + a = vector("a") + x = reshape(a, (2, 2)) + compare_mlx_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + + +def test_mlx_Reshape_various_shapes(): + """Test reshape with various different shapes to ensure robustness.""" + # 1D to 2D + a = vector("a") + x = reshape(a, (2, 3)) + compare_mlx_and_py([a], [x], [np.arange(6, dtype=config.floatX)]) + + # 2D to 1D + b = pt.matrix("b") + y = reshape(b, (6,)) + compare_mlx_and_py([b], [y], [np.arange(6, dtype=config.floatX).reshape(2, 3)]) + + # 2D to 3D + c = pt.matrix("c") + z = reshape(c, (2, 2, 3)) + compare_mlx_and_py([c], [z], [np.arange(12, dtype=config.floatX).reshape(4, 3)]) + + # 3D to 2D + d = pt.tensor3("d") + w = reshape(d, (3, 4)) + compare_mlx_and_py([d], [w], [np.arange(12, dtype=config.floatX).reshape(2, 2, 3)]) + + +def test_mlx_Reshape_negative_one(): + """Test reshape with -1 dimension (infer dimension).""" + a = vector("a") + # Use -1 to infer the second dimension + x = reshape(a, (2, -1)) + compare_mlx_and_py([a], [x], [np.arange(8, dtype=config.floatX)]) + + # Use -1 to infer the first dimension + y = reshape(a, (-1, 4)) + compare_mlx_and_py([a], [y], [np.arange(8, dtype=config.floatX)]) + + +def test_mlx_Reshape_concrete_shape(): + """MLX should compile when a concrete value is passed for the `shape` parameter.""" + a = vector("a") + x = reshape(a, a.shape) + compare_mlx_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + + x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2)) + compare_mlx_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + + +@pytest.mark.xfail(reason="`shape_pt` should be specified as a static argument") +def test_mlx_Reshape_shape_graph_input(): + a = vector("a") + shape_pt = iscalar("b") + x = reshape(a, (shape_pt, shape_pt)) + compare_mlx_and_py( + [a, shape_pt], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2] + ) + + +@pytest.mark.xfail(reason="ViewOp Op is not supported yet") +def test_mlx_compile_ops(): + x = DeepCopyOp()(pt.as_tensor_variable(1.1)) + compare_mlx_and_py([], [x], []) + + x_np = np.zeros((20, 1, 1)) + x = ViewOp()(pt.as_tensor_variable(x_np)) + + compare_mlx_and_py([], [x], []) diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py new file mode 100644 index 0000000000..cfc5a07baa --- /dev/null +++ b/tests/link/mlx/test_subtensor.py @@ -0,0 +1,247 @@ +import numpy as np +import pytest +from test_basic import compare_mlx_and_py + +import pytensor.tensor as pt +from pytensor.tensor import subtensor as pt_subtensor +from pytensor.tensor import tensor + + +mx = pytest.importorskip("mlx.core") + + +def test_mlx_Subtensor_basic(): + """Test basic subtensor operations with constant indices.""" + shape = (3, 4, 5) + x_pt = tensor("x", shape=shape, dtype="float32") + x_np = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + # Basic indexing with single elements + out_pt = x_pt[1, 2, 0] + assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + # Basic indexing with slices + out_pt = x_pt[1:, 1, :] + assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + out_pt = x_pt[:2, 1, :] + assert isinstance(out_pt.owner.op, 
pt_subtensor.Subtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + out_pt = x_pt[1:2, 1, :] + assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + # Negative indexing + out_pt = x_pt[-1, -1, -1] + assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + # Step slicing + out_pt = x_pt[::2, ::2, ::2] + assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + # Reverse indexing + out_pt = x_pt[::-1, ::-1, ::-1] + assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + +def test_mlx_AdvancedSubtensor(): + """Test advanced subtensor operations.""" + shape = (3, 4, 5) + x_pt = tensor("x", shape=shape, dtype="float32") + x_np = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + # Advanced indexing with array indices + out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2]) + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + # Multi-dimensional advanced indexing + out_pt = x_pt[[1, 2], [2, 3]] + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + # Mixed advanced and basic indexing + out_pt = x_pt[[1, 2], :] + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + out_pt = x_pt[[1, 2], :, [3, 4]] + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + +@pytest.mark.xfail(reason="MLX does not support boolean indexing yet") +def test_mlx_AdvancedSubtensor_boolean(): + """Test advanced subtensor operations with boolean indexing.""" + shape = (3, 4, 5) + x_pt = tensor("x", shape=shape, dtype="float32") + x_np = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + # Boolean indexing with constant mask + bool_mask = np.array([True, False, True]) + out_pt = x_pt[bool_mask] + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + +@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported") +def test_mlx_IncSubtensor_set(): + """Test set operations using IncSubtensor (set_instead_of_inc=True).""" + # Test data + x_np = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + x_pt = pt.constant(x_np) + + # Set single element + st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=np.float32)) + out_pt = pt_subtensor.set_subtensor(x_pt[1, 2, 3], st_pt) + assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) + assert out_pt.owner.op.set_instead_of_inc + compare_mlx_and_py([], [out_pt], []) + + +@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported") +def test_mlx_IncSubtensor_increment(): + """Test increment operations using IncSubtensor (set_instead_of_inc=False).""" + # Test data + x_np = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + x_pt = pt.constant(x_np) + + # Increment single element + st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=np.float32)) + out_pt = pt_subtensor.inc_subtensor(x_pt[1, 2, 3], st_pt) + assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) + assert not out_pt.owner.op.set_instead_of_inc + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_AdvancedIncSubtensor_set(): + """Test advanced set operations using AdvancedIncSubtensor.""" + rng = 
np.random.default_rng(213234) + + # Test data + x_np = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + x_pt = pt.constant(x_np) + + # Set with advanced indexing - this actually works in MLX! + st_pt = pt.as_tensor_variable(rng.uniform(-1, 1, size=(2, 4, 5)).astype(np.float32)) + out_pt = pt_subtensor.set_subtensor(x_pt[np.r_[0, 2]], st_pt) + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) + assert out_pt.owner.op.set_instead_of_inc + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_AdvancedIncSubtensor_increment(): + """Test advanced increment operations using AdvancedIncSubtensor.""" + rng = np.random.default_rng(213234) + + # Test data + x_np = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + x_pt = pt.constant(x_np) + + # Increment with advanced indexing - this actually works in MLX! + st_pt = pt.as_tensor_variable(rng.uniform(-1, 1, size=(2, 4, 5)).astype(np.float32)) + out_pt = pt_subtensor.inc_subtensor(x_pt[np.r_[0, 2]], st_pt) + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) + assert not out_pt.owner.op.set_instead_of_inc + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_AdvancedIncSubtensor1_operations(): + """Test AdvancedIncSubtensor1 operations (handled by IncSubtensor dispatcher).""" + rng = np.random.default_rng(213234) + + # Test data + x_np = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + x_pt = pt.constant(x_np) + + # Test set operation - this actually works in MLX! + st_pt = pt.as_tensor_variable(rng.uniform(-1, 1, size=(2, 4, 5)).astype(np.float32)) + indices = [1, 2] + + # Create AdvancedIncSubtensor1 manually for set operation + out_pt = pt_subtensor.advanced_set_subtensor1(x_pt, st_pt, indices) + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor1) + assert out_pt.owner.op.set_instead_of_inc + compare_mlx_and_py([], [out_pt], []) + + +@pytest.mark.xfail(reason="Inplace operations not yet supported in MLX mode") +def test_mlx_inplace_variants(): + """Test inplace variants of all subtensor operations.""" + # Test data + x_np = np.arange(12, dtype=np.float32).reshape((3, 4)) + x_pt = pt.constant(x_np) + + # Test inplace IncSubtensor (set) + st_pt = pt.as_tensor_variable(np.array([-1.0, -2.0], dtype=np.float32)) + out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], st_pt, inplace=True) + assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) + assert out_pt.owner.op.inplace + assert out_pt.owner.op.set_instead_of_inc + compare_mlx_and_py([], [out_pt], []) + + +@pytest.mark.xfail( + reason="MLX slice indices must be integers or None, dynamic slices not supported" +) +def test_mlx_MakeSlice(): + """Test MakeSlice operation.""" + # Test slice creation + start = pt.iscalar("start") + stop = pt.iscalar("stop") + step = pt.iscalar("step") + + # Create a slice using MakeSlice + slice_op = pt_subtensor.MakeSlice() + slice_pt = slice_op(start, stop, step) + + # Use simple constant array instead of arange + x_pt = pt.constant(np.arange(10, dtype=np.float32)) + out_pt = x_pt[slice_pt] + + compare_mlx_and_py([start, stop, step], [out_pt], [1, 8, 2]) + + +def test_mlx_subtensor_edge_cases(): + """Test edge cases and boundary conditions.""" + # Empty slices - use constant array + x_pt = pt.constant(np.arange(10, dtype=np.float32)) + out_pt = x_pt[5:5] # Empty slice + compare_mlx_and_py([], [out_pt], []) + + # Single element arrays + x_pt = pt.tensor(shape=(1,), dtype="float32", name="x") + x_np = np.array([42.0], dtype=np.float32) + out_pt = x_pt[0] + 
compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + # Large step sizes - use constant array + x_pt = pt.constant(np.arange(20, dtype=np.float32)) + out_pt = x_pt[::5] + compare_mlx_and_py([], [out_pt], []) + + # Negative steps - use constant array + x_pt = pt.constant(np.arange(10, dtype=np.float32)) + out_pt = x_pt[::-2] + compare_mlx_and_py([], [out_pt], []) + + +@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported") +def test_mlx_subtensor_with_variables(): + """Test subtensor operations with PyTensor variables as inputs.""" + # Test with variable arrays (not constants) + x_pt = pt.matrix("x", dtype="float32") + y_pt = pt.vector("y", dtype="float32") + + x_np = np.arange(12, dtype=np.float32).reshape((3, 4)) + y_np = np.array([-1.0, -2.0], dtype=np.float32) + + # Set operation with variables + out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], y_pt) + compare_mlx_and_py([x_pt, y_pt], [out_pt], [x_np, y_np])
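For context, the user-facing workflow that all of these test modules exercise is simply compiling a PyTensor graph with mode="MLX" and getting mlx.core arrays back. A minimal self-contained sketch follows; the graph and shapes are chosen arbitrarily for illustration.

import numpy as np
import mlx.core as mx

import pytensor
import pytensor.tensor as pt

# Build an ordinary PyTensor graph, then compile it with the MLX backend by
# selecting mode="MLX"; the compiled function returns mlx.core arrays.
x = pt.matrix("x", dtype="float32")
y = pt.matrix("y", dtype="float32")
out = pt.exp(x.dot(y)).sum()

f = pytensor.function([x, y], out, mode="MLX")
result = f(
    np.ones((3, 2), dtype=np.float32),
    np.ones((2, 4), dtype=np.float32),
)
assert isinstance(result, mx.array)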