@@ -82,10 +82,10 @@ def is_neg_zero(n: float) -> bool:
82
82
83
83
def assert_dtype (
84
84
func_name : str ,
85
+ * ,
85
86
in_dtype : Union [DataType , Sequence [DataType ]],
86
87
out_dtype : DataType ,
87
88
expected : Optional [DataType ] = None ,
88
- * ,
89
89
repr_name : str = "out.dtype" ,
90
90
):
91
91
"""
@@ -96,7 +96,7 @@ def assert_dtype(
96
96
97
97
>>> x = xp.arange(5, dtype=xp.uint8)
98
98
>>> out = xp.abs(x)
99
- >>> assert_dtype('abs', x.dtype, out.dtype)
99
+ >>> assert_dtype('abs', in_dtype= x.dtype, out_dtype= out.dtype)
100
100
101
101
is equivalent to
102
102
@@ -108,7 +108,7 @@ def assert_dtype(
108
108
>>> x1 = xp.arange(5, dtype=xp.uint8)
109
109
>>> x2 = xp.arange(5, dtype=xp.uint16)
110
110
>>> out = xp.add(x1, x2)
111
- >>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype)
111
+ >>> assert_dtype('add', in_dtype= [x1.dtype, x2.dtype], out_dtype= out.dtype)
112
112
113
113
is equivalent to
114
114
@@ -119,7 +119,7 @@ def assert_dtype(
119
119
>>> x = xp.arange(5, dtype=xp.int8)
120
120
>>> out = xp.sum(x)
121
121
>>> default_int = xp.asarray(0).dtype
122
- >>> assert_dtype('sum', x, out.dtype, default_int)
122
+ >>> assert_dtype('sum', in_dtype= x, out_dtype= out.dtype, expected= default_int)
123
123
124
124
"""
125
125
in_dtypes = in_dtype if isinstance (in_dtype , Sequence ) and not isinstance (in_dtype , str ) else [in_dtype ]
@@ -135,13 +135,18 @@ def assert_dtype(
135
135
assert out_dtype == expected , msg
136
136
137
137
138
- def assert_kw_dtype (func_name : str , kw_dtype : DataType , out_dtype : DataType ):
138
+ def assert_kw_dtype (
139
+ func_name : str ,
140
+ * ,
141
+ kw_dtype : DataType ,
142
+ out_dtype : DataType ,
143
+ ):
139
144
"""
140
145
Assert the output dtype is the passed keyword dtype, e.g.
141
146
142
147
>>> kw = {'dtype': xp.uint8}
143
- >>> out = xp.ones(5, ** kw)
144
- >>> assert_kw_dtype('ones', kw['dtype'], out.dtype)
148
+ >>> out = xp.ones(5, kw= kw)
149
+ >>> assert_kw_dtype('ones', kw_dtype= kw['dtype'], out_dtype= out.dtype)
145
150
146
151
"""
147
152
f_kw_dtype = dh .dtype_to_name [kw_dtype ]
@@ -222,17 +227,17 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty
222
227
223
228
def assert_shape (
224
229
func_name : str ,
230
+ * ,
225
231
out_shape : Union [int , Shape ],
226
232
expected : Union [int , Shape ],
227
- / ,
228
233
repr_name = "out.shape" ,
229
- ** kw ,
234
+ kw : dict = {} ,
230
235
):
231
236
"""
232
237
Assert the output shape is as expected, e.g.
233
238
234
239
>>> out = xp.ones((3, 3, 3))
235
- >>> assert_shape('ones', out.shape, (3, 3, 3))
240
+ >>> assert_shape('ones', out_shape= out.shape, expected= (3, 3, 3))
236
241
237
242
"""
238
243
if isinstance (out_shape , int ):
@@ -249,11 +254,10 @@ def assert_result_shape(
249
254
func_name : str ,
250
255
in_shapes : Sequence [Shape ],
251
256
out_shape : Shape ,
252
- / ,
253
257
expected : Optional [Shape ] = None ,
254
258
* ,
255
259
repr_name = "out.shape" ,
256
- ** kw ,
260
+ kw : dict = {} ,
257
261
):
258
262
"""
259
263
Assert the output shape is as expected.
@@ -262,7 +266,7 @@ def assert_result_shape(
262
266
in_shapes, to test against out_shape, e.g.
263
267
264
268
>>> out = xp.add(xp.ones((3, 1)), xp.ones((1, 3)))
265
- >>> assert_shape ('add', [(3, 1), (1, 3)], out.shape)
269
+ >>> assert_result_shape ('add', in_shape= [(3, 1), (1, 3)], out_shape= out.shape)
266
270
267
271
is equivalent to
268
272
@@ -281,21 +285,21 @@ def assert_result_shape(
281
285
282
286
def assert_keepdimable_shape (
283
287
func_name : str ,
288
+ * ,
284
289
in_shape : Shape ,
285
290
out_shape : Shape ,
286
291
axes : Tuple [int , ...],
287
292
keepdims : bool ,
288
- / ,
289
- ** kw ,
293
+ kw : dict = {},
290
294
):
291
295
"""
292
296
Assert the output shape from a keepdimable function is as expected, e.g.
293
297
294
298
>>> x = xp.asarray([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
295
299
>>> out1 = xp.max(x, keepdims=False)
296
300
>>> out2 = xp.max(x, keepdims=True)
297
- >>> assert_keepdimable_shape('max', x.shape, out1.shape, (0, 1), False)
298
- >>> assert_keepdimable_shape('max', x.shape, out2.shape, (0, 1), True)
301
+ >>> assert_keepdimable_shape('max', in_shape= x.shape, out_shape= out1.shape, axes= (0, 1), keepdims= False)
302
+ >>> assert_keepdimable_shape('max', in_shape= x.shape, out_shape= out2.shape, axes= (0, 1), keepdims= True)
299
303
300
304
is equivalent to
301
305
@@ -307,19 +311,26 @@ def assert_keepdimable_shape(
307
311
shape = tuple (1 if axis in axes else side for axis , side in enumerate (in_shape ))
308
312
else :
309
313
shape = tuple (side for axis , side in enumerate (in_shape ) if axis not in axes )
310
- assert_shape (func_name , out_shape , shape , ** kw )
314
+ assert_shape (func_name , out_shape = out_shape , expected = shape , kw = kw )
311
315
312
316
313
317
def assert_0d_equals (
314
- func_name : str , x_repr : str , x_val : Array , out_repr : str , out_val : Array , ** kw
318
+ func_name : str ,
319
+ * ,
320
+ x_repr : str ,
321
+ x_val : Array ,
322
+ out_repr : str ,
323
+ out_val : Array ,
324
+ kw : dict = {},
315
325
):
316
326
"""
317
327
Assert a 0d array is as expected, e.g.
318
328
319
329
>>> x = xp.asarray([0, 1, 2])
320
- >>> res = xp.asarray(x, copy=True)
330
+ >>> kw = {'copy': True}
331
+ >>> res = xp.asarray(x, **kw)
321
332
>>> res[0] = 42
322
- >>> assert_0d_equals('asarray', 'x[0]', x[0], 'x[0]', res[0])
333
+ >>> assert_0d_equals('asarray', x_repr= 'x[0]', x_val= x[0], out_repr= 'x[0]', out_val= res[0], kw=kw )
323
334
324
335
is equivalent to
325
336
@@ -338,20 +349,20 @@ def assert_0d_equals(
338
349
339
350
def assert_scalar_equals (
340
351
func_name : str ,
352
+ * ,
341
353
type_ : ScalarType ,
342
354
idx : Shape ,
343
355
out : Scalar ,
344
356
expected : Scalar ,
345
- / ,
346
357
repr_name : str = "out" ,
347
- ** kw ,
358
+ kw : dict = {} ,
348
359
):
349
360
"""
350
361
Assert a 0d array, convered to a scalar, is as expected, e.g.
351
362
352
363
>>> x = xp.ones(5, dtype=xp.uint8)
353
364
>>> out = xp.sum(x)
354
- >>> assert_scalar_equals('sum', int, (), int(out), 5)
365
+ >>> assert_scalar_equals('sum', type_int, out= (), out= int(out), expected= 5)
355
366
356
367
is equivalent to
357
368
@@ -372,13 +383,18 @@ def assert_scalar_equals(
372
383
373
384
374
385
def assert_fill (
375
- func_name : str , fill_value : Scalar , dtype : DataType , out : Array , / , ** kw
386
+ func_name : str ,
387
+ * ,
388
+ fill_value : Scalar ,
389
+ dtype : DataType ,
390
+ out : Array ,
391
+ kw : dict = {},
376
392
):
377
393
"""
378
394
Assert all elements of an array is as expected, e.g.
379
395
380
396
>>> out = xp.full(5, 42, dtype=xp.uint8)
381
- >>> assert_fill('full', 42, xp.uint8, out, 5 )
397
+ >>> assert_fill('full', fill_value= 42, dtype= xp.uint8, out=out, kw=dict(shape=5) )
382
398
383
399
is equivalent to
384
400
@@ -408,22 +424,27 @@ def _assert_float_element(at_out: Array, at_expected: Array, msg: str):
408
424
409
425
410
426
def assert_array_elements (
411
- func_name : str , out : Array , expected : Array , / , * , out_repr : str = "out" , ** kw
427
+ func_name : str ,
428
+ * ,
429
+ out : Array ,
430
+ expected : Array ,
431
+ out_repr : str = "out" ,
432
+ kw : dict = {},
412
433
):
413
434
"""
414
435
Assert array elements are (strictly) as expected, e.g.
415
436
416
437
>>> x = xp.arange(5)
417
438
>>> out = xp.asarray(x)
418
- >>> assert_array_elements('asarray', out, x)
439
+ >>> assert_array_elements('asarray', out=out, expected= x)
419
440
420
441
is equivalent to
421
442
422
443
>>> assert xp.all(out == x)
423
444
424
445
"""
425
446
dh .result_type (out .dtype , expected .dtype ) # sanity check
426
- assert_shape (func_name , out .shape , expected .shape , ** kw ) # sanity check
447
+ assert_shape (func_name , out_shape = out .shape , expected = expected .shape , kw = kw ) # sanity check
427
448
f_func = f"[{ func_name } ({ fmt_kw (kw )} )]"
428
449
if out .dtype in dh .float_dtypes :
429
450
for idx in sh .ndindex (out .shape ):
0 commit comments