Skip to content

Commit e7e2cfe

Browse files
pitroujcrist
authored andcommitted
Make random_state_data() faster (dask#2379)
* Make random_state_data() faster Also speed up test_random_state_data * Fix doctest
1 parent 85e2b4f commit e7e2cfe

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

dask/array/random.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class RandomState(object):
3737
>>> state = da.random.RandomState(1234) # a seed
3838
>>> x = state.normal(10, 0.1, size=3, chunks=(2,))
3939
>>> x.compute()
40-
array([ 10.06307943, 9.91493648, 10.0822082 ])
40+
array([ 10.01867852, 10.04812289, 9.89649746])
4141
4242
See Also:
4343
np.random.RandomState

dask/tests/test_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,10 @@ class Bar(object):
5959
assert foo((1, 2.0, b)) == (2, 1.0, b)
6060

6161

62-
@pytest.mark.slow
6362
def test_random_state_data():
6463
seed = 37
6564
state = np.random.RandomState(seed)
66-
n = 100000
65+
n = 10000
6766

6867
# Use an integer
6968
states = random_state_data(n, seed)
@@ -72,6 +71,7 @@ def test_random_state_data():
7271
# Use RandomState object
7372
states2 = random_state_data(n, state)
7473
for s1, s2 in zip(states, states2):
74+
assert s1.shape == (624,)
7575
assert (s1 == s2).all()
7676

7777
# Consistent ordering

dask/utils.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def random_state_data(n, random_state=None):
263263
Parameters
264264
----------
265265
n : int
266-
Number of tuples to return.
266+
Number of arrays to return.
267267
random_state : int or np.random.RandomState, optional
268268
If an int, is used to seed a new ``RandomState``.
269269
"""
@@ -272,9 +272,10 @@ def random_state_data(n, random_state=None):
272272
if not isinstance(random_state, np.random.RandomState):
273273
random_state = np.random.RandomState(random_state)
274274

275-
maxuint32 = np.iinfo(np.uint32).max
276-
return [(random_state.rand(624) * maxuint32).astype('uint32')
277-
for i in range(n)]
275+
random_data = random_state.bytes(624 * n * 4) # `n * 624` 32-bit integers
276+
l = list(np.frombuffer(random_data, dtype=np.uint32).reshape((n, -1)))
277+
assert len(l) == n
278+
return l
278279

279280

280281
def is_integer(i):

0 commit comments

Comments
 (0)