3
3
from typing import Any , Dict , Optional , Sequence , Tuple , Union
4
4
5
5
from . import _array_module as xp
6
- from . import array_helpers as ah
7
6
from . import dtype_helpers as dh
8
7
from . import shape_helpers as sh
9
8
from . import stubs
@@ -88,6 +87,40 @@ def assert_dtype(
88
87
* ,
89
88
repr_name : str = "out.dtype" ,
90
89
):
90
+ """
91
+ Assert the output dtype is as expected.
92
+
93
+ If expected=None, we infer the expected dtype as in_dtype, to test
94
+ out_dtype, e.g.
95
+
96
+ >>> x = xp.arange(5, dtype=xp.uint8)
97
+ >>> out = xp.abs(x)
98
+ >>> assert_dtype('abs', x.dtype, out.dtype)
99
+
100
+ is equivalent to
101
+
102
+ >>> assert out.dtype == xp.uint8
103
+
104
+ Or for multiple input dtypes, the expected dtype is inferred from their
105
+ resulting type promotion, e.g.
106
+
107
+ >>> x1 = xp.arange(5, dtype=xp.uint8)
108
+ >>> x2 = xp.arange(5, dtype=xp.uint16)
109
+ >>> out = xp.add(x1, x2)
110
+ >>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype)
111
+
112
+ is equivalent to
113
+
114
+ >>> assert out.dtype == xp.uint16
115
+
116
+ We can also specify the expected dtype ourselves, e.g.
117
+
118
+ >>> x = xp.arange(5, dtype=xp.int8)
119
+ >>> out = xp.sum(x)
120
+ >>> default_int = xp.asarray(0).dtype
121
+ >>> assert_dtype('sum', x, out.dtype, default_int)
122
+
123
+ """
91
124
in_dtypes = in_dtype if isinstance (in_dtype , Sequence ) else [in_dtype ]
92
125
f_in_dtypes = dh .fmt_types (tuple (in_dtypes ))
93
126
f_out_dtype = dh .dtype_to_name [out_dtype ]
@@ -102,6 +135,14 @@ def assert_dtype(
102
135
103
136
104
137
def assert_kw_dtype (func_name : str , kw_dtype : DataType , out_dtype : DataType ):
138
+ """
139
+ Assert the output dtype is the passed keyword dtype, e.g.
140
+
141
+ >>> kw = {'dtype': xp.uint8}
142
+ >>> out = xp.ones(5, **kw)
143
+ >>> assert_kw_dtype('ones', kw['dtype'], out.dtype)
144
+
145
+ """
105
146
f_kw_dtype = dh .dtype_to_name [kw_dtype ]
106
147
f_out_dtype = dh .dtype_to_name [out_dtype ]
107
148
msg = (
@@ -111,33 +152,54 @@ def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
111
152
assert out_dtype == kw_dtype , msg
112
153
113
154
114
- def assert_default_float (func_name : str , dtype : DataType ):
115
- f_dtype = dh .dtype_to_name [dtype ]
155
+ def assert_default_float (func_name : str , out_dtype : DataType ):
156
+ """
157
+ Assert the output dtype is the default float, e.g.
158
+
159
+ >>> out = xp.ones(5)
160
+ >>> assert_default_float('ones', out.dtype)
161
+
162
+ """
163
+ f_dtype = dh .dtype_to_name [out_dtype ]
116
164
f_default = dh .dtype_to_name [dh .default_float ]
117
165
msg = (
118
166
f"out.dtype={ f_dtype } , should be default "
119
167
f"floating-point dtype { f_default } [{ func_name } ()]"
120
168
)
121
- assert dtype == dh .default_float , msg
169
+ assert out_dtype == dh .default_float , msg
122
170
123
171
124
- def assert_default_int (func_name : str , dtype : DataType ):
125
- f_dtype = dh .dtype_to_name [dtype ]
172
+ def assert_default_int (func_name : str , out_dtype : DataType ):
173
+ """
174
+ Assert the output dtype is the default int, e.g.
175
+
176
+ >>> out = xp.full(5, 42)
177
+ >>> assert_default_int('full', out.dtype)
178
+
179
+ """
180
+ f_dtype = dh .dtype_to_name [out_dtype ]
126
181
f_default = dh .dtype_to_name [dh .default_int ]
127
182
msg = (
128
183
f"out.dtype={ f_dtype } , should be default "
129
184
f"integer dtype { f_default } [{ func_name } ()]"
130
185
)
131
- assert dtype == dh .default_int , msg
186
+ assert out_dtype == dh .default_int , msg
187
+
188
+
189
+ def assert_default_index (func_name : str , out_dtype : DataType , repr_name = "out.dtype" ):
190
+ """
191
+ Assert the output dtype is the default index dtype, e.g.
132
192
193
+ >>> out = xp.argmax(xp.arange(5))
194
+ >>> assert_default_int('argmax', out.dtype)
133
195
134
- def assert_default_index ( func_name : str , dtype : DataType , repr_name = "out.dtype" ):
135
- f_dtype = dh .dtype_to_name [dtype ]
196
+ """
197
+ f_dtype = dh .dtype_to_name [out_dtype ]
136
198
msg = (
137
199
f"{ repr_name } ={ f_dtype } , should be the default index dtype, "
138
200
f"which is either int32 or int64 [{ func_name } ()]"
139
201
)
140
- assert dtype in (xp .int32 , xp .int64 ), msg
202
+ assert out_dtype in (xp .int32 , xp .int64 ), msg
141
203
142
204
143
205
def assert_shape (
@@ -148,6 +210,13 @@ def assert_shape(
148
210
repr_name = "out.shape" ,
149
211
** kw ,
150
212
):
213
+ """
214
+ Assert the output shape is as expected, e.g.
215
+
216
+ >>> out = xp.ones((3, 3, 3))
217
+ >>> assert_shape('ones', out.shape, (3, 3, 3))
218
+
219
+ """
151
220
if isinstance (out_shape , int ):
152
221
out_shape = (out_shape ,)
153
222
if isinstance (expected , int ):
@@ -168,6 +237,20 @@ def assert_result_shape(
168
237
repr_name = "out.shape" ,
169
238
** kw ,
170
239
):
240
+ """
241
+ Assert the output shape is as expected.
242
+
243
+ If expected=None, we infer the expected shape as the result of broadcasting
244
+ in_shapes, to test against out_shape, e.g.
245
+
246
+ >>> out = xp.add(xp.ones((3, 1)), xp.ones((1, 3)))
247
+ >>> assert_shape('add', [(3, 1), (1, 3)], out.shape)
248
+
249
+ is equivalent to
250
+
251
+ >>> assert out.shape == (3, 3)
252
+
253
+ """
171
254
if expected is None :
172
255
expected = sh .broadcast_shapes (* in_shapes )
173
256
f_in_shapes = " . " .join (str (s ) for s in in_shapes )
@@ -180,13 +263,28 @@ def assert_result_shape(
180
263
181
264
def assert_keepdimable_shape (
182
265
func_name : str ,
183
- out_shape : Shape ,
184
266
in_shape : Shape ,
267
+ out_shape : Shape ,
185
268
axes : Tuple [int , ...],
186
269
keepdims : bool ,
187
270
/ ,
188
271
** kw ,
189
272
):
273
+ """
274
+ Assert the output shape from a keepdimable function is as expected, e.g.
275
+
276
+ >>> x = xp.asarray([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
277
+ >>> out1 = xp.max(x, keepdims=False)
278
+ >>> out2 = xp.max(x, keepdims=True)
279
+ >>> assert_keepdimable_shape('max', x.shape, out1.shape, (0, 1), False)
280
+ >>> assert_keepdimable_shape('max', x.shape, out2.shape, (0, 1), True)
281
+
282
+ is equivalent to
283
+
284
+ >>> assert out1.shape == ()
285
+ >>> assert out2.shape == (1, 1)
286
+
287
+ """
190
288
if keepdims :
191
289
shape = tuple (1 if axis in axes else side for axis , side in enumerate (in_shape ))
192
290
else :
@@ -197,6 +295,19 @@ def assert_keepdimable_shape(
197
295
def assert_0d_equals (
198
296
func_name : str , x_repr : str , x_val : Array , out_repr : str , out_val : Array , ** kw
199
297
):
298
+ """
299
+ Assert a 0d array is as expected, e.g.
300
+
301
+ >>> x = xp.asarray([0, 1, 2])
302
+ >>> res = xp.asarray(x, copy=True)
303
+ >>> res[0] = 42
304
+ >>> assert_0d_equals('__setitem__', 'x[0]', x[0], 'x[0]', res[0])
305
+
306
+ is equivalent to
307
+
308
+ >>> assert res[0] == x[0]
309
+
310
+ """
200
311
msg = (
201
312
f"{ out_repr } ={ out_val } , but should be { x_repr } ={ x_val } "
202
313
f"[{ func_name } ({ fmt_kw (kw )} )]"
@@ -217,9 +328,21 @@ def assert_scalar_equals(
217
328
repr_name : str = "out" ,
218
329
** kw ,
219
330
):
331
+ """
332
+ Assert a 0d array, convered to a scalar, is as expected, e.g.
333
+
334
+ >>> x = xp.ones(5, dtype=xp.uint8)
335
+ >>> out = xp.sum(x)
336
+ >>> assert_scalar_equals('sum', int, (), int(out), 5)
337
+
338
+ is equivalent to
339
+
340
+ >>> assert int(out) == 5
341
+
342
+ """
220
343
repr_name = repr_name if idx == () else f"{ repr_name } [{ idx } ]"
221
344
f_func = f"{ func_name } ({ fmt_kw (kw )} )"
222
- if type_ is bool or type_ is int :
345
+ if type_ in [ bool , int ] :
223
346
msg = f"{ repr_name } ={ out } , but should be { expected } [{ f_func } ]"
224
347
assert out == expected , msg
225
348
elif math .isnan (expected ):
@@ -233,14 +356,37 @@ def assert_scalar_equals(
233
356
def assert_fill (
234
357
func_name : str , fill_value : Scalar , dtype : DataType , out : Array , / , ** kw
235
358
):
359
+ """
360
+ Assert all elements of an array is as expected, e.g.
361
+
362
+ >>> out = xp.full(5, 42, dtype=xp.uint8)
363
+ >>> assert_fill('full', 42, xp.uint8, out, 5)
364
+
365
+ is equivalent to
366
+
367
+ >>> assert xp.all(out == 42)
368
+
369
+ """
236
370
msg = f"out not filled with { fill_value } [{ func_name } ({ fmt_kw (kw )} )]\n { out = } "
237
371
if math .isnan (fill_value ):
238
- assert ah .all (ah .isnan (out )), msg
372
+ assert xp .all (xp .isnan (out )), msg
239
373
else :
240
- assert ah .all (ah .equal (out , ah .asarray (fill_value , dtype = dtype ))), msg
374
+ assert xp .all (xp .equal (out , xp .asarray (fill_value , dtype = dtype ))), msg
241
375
242
376
243
377
def assert_array (func_name : str , out : Array , expected : Array , / , ** kw ):
378
+ """
379
+ Assert array is (strictly) as expected, e.g.
380
+
381
+ >>> x = xp.arange(5)
382
+ >>> out = xp.asarray(x)
383
+ >>> assert_array('asarray', out, x)
384
+
385
+ is equivalent to
386
+
387
+ >>> assert xp.all(out == x)
388
+
389
+ """
244
390
assert_dtype (func_name , out .dtype , expected .dtype )
245
391
assert_shape (func_name , out .shape , expected .shape , ** kw )
246
392
f_func = f"[{ func_name } ({ fmt_kw (kw )} )]"
0 commit comments