
Commit ce8216c

remove MPI

1 parent 33b64d3 commit ce8216c

4 files changed: +3 additions, -423 deletions

pytensor/tensor/io.py

Lines changed: 1 addition & 227 deletions
@@ -1,10 +1,9 @@
 import numpy as np

-from pytensor.graph.basic import Apply, Constant, Variable
+from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.op import Op
 from pytensor.link.c.type import Generic
 from pytensor.tensor.type import tensor
-from pytensor.utils import key_to_cmp


 class LoadFromDisk(Op):
@@ -92,229 +91,4 @@ def load(path, dtype, shape, mmap_mode=None):
     return LoadFromDisk(dtype, shape, mmap_mode)(path)


-##########################
-# MPI
-##########################
-
-try:
-    from mpi4py import MPI
-except ImportError:
-    mpi_enabled = False
-else:
-    comm = MPI.COMM_WORLD
-    mpi_enabled = True
-
-
-class MPIRecv(Op):
-    """
-    An operation to asynchronously receive an array to a remote host using MPI.
-
-    See Also
-    --------
-    MPIRecv
-    MPIWait
-
-    Notes
-    -----
-    Non-differentiable.
-
-    """
-
-    __props__ = ("source", "tag", "shape", "dtype")
-
-    def __init__(self, source, tag, shape, dtype):
-        self.source = source
-        self.tag = tag
-        self.shape = shape
-        self.dtype = np.dtype(dtype)  # turn "float64" into numpy.float64
-        self.static_shape = (None,) * len(shape)
-
-    def make_node(self):
-        return Apply(
-            self,
-            [],
-            [
-                Variable(Generic(), None),
-                tensor(self.dtype, shape=self.static_shape),
-            ],
-        )
-
-    def perform(self, node, inp, out):
-
-        data = np.zeros(self.shape, dtype=self.dtype)
-        request = comm.Irecv(data, self.source, self.tag)
-
-        out[0][0] = request
-        out[1][0] = data
-
-    def __str__(self):
-        return f"MPIRecv{{source: {int(self.source)}, tag: {int(self.tag)}, shape: {self.shape}, dtype: {self.dtype}}}"
-
-    def infer_shape(self, fgraph, node, shapes):
-        return [None, self.shape]
-
-    def do_constant_folding(self, fgraph, node):
-        return False
-
-
-class MPIRecvWait(Op):
-    """
-    An operation to wait on a previously received array using MPI.
-
-    See Also
-    --------
-    MPIRecv
-
-    Notes
-    -----
-    Non-differentiable.
-
-    """
-
-    __props__ = ("tag",)
-
-    def __init__(self, tag):
-        self.tag = tag
-
-    def make_node(self, request, data):
-        return Apply(
-            self,
-            [request, data],
-            [tensor(data.dtype, shape=data.type.shape)],
-        )
-
-    def perform(self, node, inp, out):
-
-        request = inp[0]
-        data = inp[1]
-
-        request.wait()
-
-        out[0][0] = data
-
-    def infer_shape(self, fgraph, node, shapes):
-        return [shapes[1]]
-
-    view_map = {0: [1]}
-
-
-class MPISend(Op):
-    """
-    An operation to asynchronously Send an array to a remote host using MPI.
-
-    See Also
-    --------
-    MPIRecv
-    MPISendWait
-
-    Notes
-    -----
-    Non-differentiable.
-
-    """
-
-    __props__ = ("dest", "tag")
-
-    def __init__(self, dest, tag):
-        self.dest = dest
-        self.tag = tag
-
-    def make_node(self, data):
-        return Apply(self, [data], [Variable(Generic(), None), data.type()])
-
-    view_map = {1: [0]}
-
-    def perform(self, node, inp, out):
-
-        data = inp[0]
-
-        request = comm.Isend(data, self.dest, self.tag)
-
-        out[0][0] = request
-        out[1][0] = data
-
-    def __str__(self):
-        return f"MPISend{{dest: {int(self.dest)}, tag: {int(self.tag)}}}"
-
-
-class MPISendWait(Op):
-    """
-    An operation to wait on a previously sent array using MPI.
-
-    See Also
-    --------
-    MPISend
-
-    Notes
-    -----
-    Non-differentiable.
-
-    """
-
-    __props__ = ("tag",)
-
-    def __init__(self, tag):
-        self.tag = tag
-
-    def make_node(self, request, data):
-        return Apply(self, [request, data], [Variable(Generic(), None)])
-
-    def perform(self, node, inp, out):
-        request = inp[0]
-        request.wait()
-        out[0][0] = True
-
-
-def isend(var, dest, tag):
-    """
-    Non blocking send.
-    """
-    return MPISend(dest, tag)(var)
-
-
-def send(var, dest, tag):
-    """
-    Blocking send.
-    """
-    return MPISendWait(tag)(*isend(var, dest, tag))
-
-
-def irecv(shape, dtype, source, tag):
-    """
-    Non-blocking receive.
-    """
-    return MPIRecv(source, tag, shape, dtype)()
-
-
-def recv(shape, dtype, source, tag):
-    """
-    Blocking receive.
-    """
-    return MPIRecvWait(tag)(*irecv(shape, dtype, source, tag))
-
-
-# Ordering keys for scheduling
-def mpi_send_wait_key(a):
-    """Wait as long as possible on Waits, Start Send/Recvs early."""
-    if isinstance(a.op, (MPIRecvWait, MPISendWait)):
-        return 1
-    if isinstance(a.op, (MPIRecv, MPISend)):
-        return -1
-    return 0
-
-
-def mpi_tag_key(a):
-    """Break MPI ties by using the variable tag - prefer lower tags first."""
-    if isinstance(a.op, (MPISend, MPIRecv, MPIRecvWait, MPISendWait)):
-        return a.op.tag
-    else:
-        return 0
-
-
-mpi_send_wait_cmp = key_to_cmp(mpi_send_wait_key)
-mpi_tag_cmp = key_to_cmp(mpi_tag_key)
-
-mpi_keys = (mpi_send_wait_key, mpi_tag_key)
-mpi_cmps = (mpi_send_wait_cmp, mpi_tag_cmp)
-
 __all__ = ["load"]
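
For context on what is being removed: MPIRecv/MPISend start a non-blocking transfer and return an mpi4py Request alongside the data buffer, while the *Wait Ops complete it, so a scheduler can overlap communication with computation. A minimal stand-alone mpi4py sketch of that same start/wait pattern (not part of this commit; the shape, tag values, and script name are illustrative). Run it with two ranks, e.g. `mpiexec -n 2 python sketch.py`:

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if rank == 0:
    data = np.arange(4, dtype="float64")
    request = comm.Isend(data, 1, 11)    # start send to rank 1, tag 11 (like MPISend)
    # ...independent computation could overlap with the transfer here...
    request.wait()                       # complete the send (like MPISendWait)
elif rank == 1:
    data = np.zeros(4, dtype="float64")  # preallocated buffer, as in MPIRecv.perform
    request = comm.Irecv(data, 0, 11)    # start receive from rank 0 (like MPIRecv)
    request.wait()                       # buffer is valid only after this (like MPIRecvWait)
    print(data)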

tests/link/test_link.py

Lines changed: 2 additions & 24 deletions
@@ -4,15 +4,13 @@
 import numpy as np

 import pytensor
-from pytensor.compile.mode import Mode
 from pytensor.graph import fg
 from pytensor.graph.basic import Apply, Constant, Variable, clone
 from pytensor.graph.op import Op
 from pytensor.graph.type import Type
 from pytensor.link.basic import Container, Linker, PerformLinker, WrapLinker
-from pytensor.link.c.basic import OpWiseCLinker
-from pytensor.tensor.type import matrix, scalar
-from pytensor.utils import cmp, to_return_values
+from pytensor.tensor.type import scalar
+from pytensor.utils import to_return_values


 def make_function(linker: Linker, unpack_single: bool = True, **kwargs) -> Callable:
@@ -219,26 +217,6 @@ def wrap(fgraph, i, node, th):
     assert o[0].data == 1.5


-def test_sort_schedule_fn():
-    from pytensor.graph.sched import make_depends, sort_schedule_fn
-
-    x = matrix("x")
-    y = pytensor.tensor.dot(x[:5] * 2, x.T + 1).T
-
-    def str_cmp(a, b):
-        return cmp(str(a), str(b))  # lexicographical sort
-
-    linker = OpWiseCLinker(schedule=sort_schedule_fn(str_cmp))
-    mode = Mode(linker=linker)
-    f = pytensor.function((x,), (y,), mode=mode)
-
-    nodes = f.maker.linker.make_all()[-1]
-    depends = make_depends()
-    for a, b in zip(nodes[:-1], nodes[1:]):
-        if not depends((b, a)):
-            assert str(a) < str(b)
-
-
 def test_container_deepcopy():
     # This is a test to a work around a NumPy bug.
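
The deleted test exercised sort_schedule_fn from pytensor.graph.sched, which produces a topological schedule that breaks ties between independent nodes with user-supplied comparators; the MPI ordering keys removed from io.py above were presumably consumed by the same mechanism, starting sends/receives early and waiting as late as possible. A self-contained sketch of that tie-breaking idea using only the standard library; all names here are illustrative, not the removed API:

from graphlib import TopologicalSorter

def sort_schedule(deps: dict[str, set[str]], key=str) -> list[str]:
    # Dependency order is mandatory; among the currently ready
    # (mutually independent) nodes, `key` breaks ties.
    ts = TopologicalSorter(deps)
    ts.prepare()
    order = []
    while ts.is_active():
        ready = sorted(ts.get_ready(), key=key)
        order.extend(ready)
        ts.done(*ready)
    return order

# "irecv" and "isend" have no prerequisites, so they are ready first and
# the key orders them lexicographically; each wait runs only after its start.
deps = {"recv_wait": {"irecv"}, "send_wait": {"isend"}, "dot": {"recv_wait"}}
print(sort_schedule(deps))  # ['irecv', 'isend', 'recv_wait', 'send_wait', 'dot']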

tests/tensor/_test_mpi_roundtrip.py

Lines changed: 0 additions & 64 deletions
This file was deleted.
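
The deleted file's contents are not shown in the diff, but judging by its name it round-tripped data between MPI ranks through the wrappers removed above. A hypothetical sketch of such a round-trip against the pre-commit API (the shape, tag, and assertion are illustrative, not the actual test body):

# Hypothetical round-trip via the removed blocking wrappers (pre-commit API).
import numpy as np
import pytensor
from pytensor.tensor.io import mpi_enabled, recv, send
from mpi4py import MPI

if not mpi_enabled:  # mpi4py not installed
    raise SystemExit("mpi4py required")

rank = MPI.COMM_WORLD.Get_rank()

if rank == 0:
    x = pytensor.tensor.matrix("x")
    f = pytensor.function([x], send(x, dest=1, tag=1))  # blocking send to rank 1
    f(np.ones((2, 2), dtype="float64"))
elif rank == 1:
    y = recv((2, 2), "float64", source=0, tag=1)        # blocking receive
    g = pytensor.function([], y)
    assert np.allclose(g(), np.ones((2, 2)))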
