|
25 | 25 | xps = make_strategies_namespace(np, api_version="2022.12")
|
26 | 26 |
|
27 | 27 |
|
| 28 | +default_dtypes = [np.bool, np.int64, np.float64, np.complex128] |
| 29 | +kind_to_strat = { |
| 30 | + "b": xps.boolean_dtypes(), |
| 31 | + "i": xps.integer_dtypes(), |
| 32 | + "u": xps.unsigned_integer_dtypes(sizes=8), |
| 33 | + "f": xps.floating_dtypes(), |
| 34 | + "c": xps.complex_dtypes(), |
| 35 | +} |
| 36 | +scalar_dtype_strat = st.one_of(kind_to_strat.values()).map(np.dtype) |
| 37 | + |
| 38 | + |
28 | 39 | @given(shape=xps.array_shapes(), data=st.data())
|
29 | 40 | def test_full(shape, data):
|
30 | 41 | if data.draw(st.booleans(), label="pass kwargs?"):
|
31 |
| - kw = {} |
32 |
| - else: |
33 |
| - dtype = data.draw(st.none() | xps.scalar_dtypes(), label="dtype") |
| 42 | + dtype = data.draw(st.none() | scalar_dtype_strat, label="dtype") |
34 | 43 | kw = {"dtype": dtype}
|
35 |
| - _dtype = kw.get("dtype", None) or data.draw( |
36 |
| - st.sampled_from([np.bool, np.int64, np.float64, np.complex128]), label="_dtype" |
37 |
| - ) |
| 44 | + else: |
| 45 | + kw = {} |
| 46 | + _dtype = kw.get("dtype", None) or data.draw(scalar_dtype_strat, label="_dtype") |
38 | 47 | values_strat = xps.from_dtype(_dtype)
|
39 |
| - fill_value = data.draw( |
40 |
| - values_strat | values_strat.map(lambda v: np.asarray(v, dtype=_dtype)), |
41 |
| - label="fill_value", |
42 |
| - ) |
43 |
| - out = np.full(shape, fill_value, **kw) |
44 |
| - if kw.get("dtype", None) is None and not isinstance(fill_value, np.ndarray): |
45 |
| - if isinstance(fill_value, bool): |
46 |
| - assert out.dtype == np.bool |
47 |
| - elif isinstance(fill_value, int): |
48 |
| - assert out.dtype == np.int64 |
49 |
| - elif isinstance(fill_value, float): |
50 |
| - assert out.dtype == np.float64 |
| 48 | + if _dtype not in default_dtypes or data.draw( |
| 49 | + st.booleans(), label="fill_value is array?" |
| 50 | + ): |
| 51 | + if specified_dtype := kw.get("dtype", None): |
| 52 | + kind = specified_dtype.name[0] |
| 53 | + values_dtypes_strat = kind_to_strat[kind] |
51 | 54 | else:
|
52 |
| - assert isinstance(fill_value, complex) # sanity check |
53 |
| - assert out.dtype == np.complex128 |
54 |
| - else: |
55 |
| - assert out.dtype == _dtype |
| 55 | + values_dtypes_strat = st.just(_dtype) |
| 56 | + values_strat = values_dtypes_strat.flatmap( |
| 57 | + lambda d: values_strat.map(lambda v: np.asarray(v, dtype=d)) |
| 58 | + ) |
| 59 | + fill_value = data.draw(values_strat, label="fill_value") |
| 60 | + out = np.full(shape, fill_value, **kw) |
| 61 | + assert out.dtype == _dtype |
56 | 62 | assert out.shape == shape
|
57 | 63 | if cmath.isnan(fill_value):
|
58 | 64 | assert np.isnan(out).all()
|
|
0 commit comments