@@ -296,6 +296,19 @@ def assert_keepdimable_shape(
296
296
def assert_0d_equals (
297
297
func_name : str , x_repr : str , x_val : Array , out_repr : str , out_val : Array , ** kw
298
298
):
299
+ """
300
+ Assert a 0d array is as expected, e.g.
301
+
302
+ >>> x = xp.asarray([0, 1, 2])
303
+ >>> res = xp.asarray(x, copy=True)
304
+ >>> res[0] = 42
305
+ >>> assert_0d_equals('__setitem__', 'x[0]', x[0], 'x[0]', res[0])
306
+
307
+ is equivalent to
308
+
309
+ >>> assert res[0] == x[0]
310
+
311
+ """
299
312
msg = (
300
313
f"{ out_repr } ={ out_val } , but should be { x_repr } ={ x_val } "
301
314
f"[{ func_name } ({ fmt_kw (kw )} )]"
@@ -316,9 +329,21 @@ def assert_scalar_equals(
316
329
repr_name : str = "out" ,
317
330
** kw ,
318
331
):
332
+ """
333
+ Assert a 0d array, convered to a scalar, is as expected, e.g.
334
+
335
+ >>> x = xp.ones(5, dtype=xp.uint8)
336
+ >>> out = xp.sum(x)
337
+ >>> assert_scalar_equals('sum', int, (), int(out), 5)
338
+
339
+ is equivalent to
340
+
341
+ >>> assert int(out) == 5
342
+
343
+ """
319
344
repr_name = repr_name if idx == () else f"{ repr_name } [{ idx } ]"
320
345
f_func = f"{ func_name } ({ fmt_kw (kw )} )"
321
- if type_ is bool or type_ is int :
346
+ if type_ in [ bool , int ] :
322
347
msg = f"{ repr_name } ={ out } , but should be { expected } [{ f_func } ]"
323
348
assert out == expected , msg
324
349
elif math .isnan (expected ):
@@ -332,6 +357,17 @@ def assert_scalar_equals(
332
357
def assert_fill (
333
358
func_name : str , fill_value : Scalar , dtype : DataType , out : Array , / , ** kw
334
359
):
360
+ """
361
+ Assert all elements of an array is as expected, e.g.
362
+
363
+ >>> out = xp.full(5, 42, dtype=xp.uint8)
364
+ >>> assert_fill('full', 42, xp.uint8, out, 5)
365
+
366
+ is equivalent to
367
+
368
+ >>> assert xp.all(out == 42)
369
+
370
+ """
335
371
msg = f"out not filled with { fill_value } [{ func_name } ({ fmt_kw (kw )} )]\n { out = } "
336
372
if math .isnan (fill_value ):
337
373
assert ah .all (ah .isnan (out )), msg
@@ -340,6 +376,18 @@ def assert_fill(
340
376
341
377
342
378
def assert_array (func_name : str , out : Array , expected : Array , / , ** kw ):
379
+ """
380
+ Assert array is (strictly) as expected, e.g.
381
+
382
+ >>> x = xp.arange(5)
383
+ >>> out = xp.asarray(x)
384
+ >>> assert_array('asarray', out, x)
385
+
386
+ is equivalent to
387
+
388
+ >>> assert xp.all(out == x)
389
+
390
+ """
343
391
assert_dtype (func_name , out .dtype , expected .dtype )
344
392
assert_shape (func_name , out .shape , expected .shape , ** kw )
345
393
f_func = f"[{ func_name } ({ fmt_kw (kw )} )]"
0 commit comments