Skip to content

Commit cc693a4

Browse files
Merge pull request #1488 from IntelPython/device-are-can_cast_and_result_type
Device-aware `can_cast` and `result_type`
2 parents bb5ff39 + e6bc9f2 commit cc693a4

File tree

6 files changed

+356
-270
lines changed

6 files changed

+356
-270
lines changed

dpctl/tensor/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,12 @@
6464
from dpctl.tensor._manipulation_functions import (
6565
broadcast_arrays,
6666
broadcast_to,
67-
can_cast,
6867
concat,
6968
expand_dims,
70-
finfo,
7169
flip,
72-
iinfo,
7370
moveaxis,
7471
permute_dims,
7572
repeat,
76-
result_type,
7773
roll,
7874
squeeze,
7975
stack,
@@ -180,6 +176,7 @@
180176
sum,
181177
)
182178
from ._testing import allclose
179+
from ._type_utils import can_cast, finfo, iinfo, result_type
183180

184181
__all__ = [
185182
"Device",

dpctl/tensor/_data_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def _get_dtype(inp_dt, sycl_obj, ref_type=None):
120120

121121
__all__ = [
122122
"dtype",
123+
"_get_dtype",
123124
"isdtype",
124125
"bool",
125126
"int8",

dpctl/tensor/_manipulation_functions.py

Lines changed: 1 addition & 216 deletions
Original file line numberDiff line numberDiff line change
@@ -27,101 +27,14 @@
2727
import dpctl.utils as dputils
2828

2929
from ._copy_utils import _broadcast_strides
30-
from ._type_utils import _to_device_supported_dtype
30+
from ._type_utils import _supported_dtype, _to_device_supported_dtype
3131

3232
__doc__ = (
3333
"Implementation module for array manipulation "
3434
"functions in :module:`dpctl.tensor`"
3535
)
3636

3737

38-
class finfo_object:
39-
"""
40-
`numpy.finfo` subclass which returns Python floating-point scalars for
41-
`eps`, `max`, `min`, and `smallest_normal` attributes.
42-
"""
43-
44-
def __init__(self, dtype):
45-
_supported_dtype([dpt.dtype(dtype)])
46-
self._finfo = np.finfo(dtype)
47-
48-
@property
49-
def bits(self):
50-
"""
51-
number of bits occupied by the real-valued floating-point data type.
52-
"""
53-
return int(self._finfo.bits)
54-
55-
@property
56-
def smallest_normal(self):
57-
"""
58-
smallest positive real-valued floating-point number with full
59-
precision.
60-
"""
61-
return float(self._finfo.smallest_normal)
62-
63-
@property
64-
def tiny(self):
65-
"""an alias for `smallest_normal`"""
66-
return float(self._finfo.tiny)
67-
68-
@property
69-
def eps(self):
70-
"""
71-
difference between 1.0 and the next smallest representable real-valued
72-
floating-point number larger than 1.0 according to the IEEE-754
73-
standard.
74-
"""
75-
return float(self._finfo.eps)
76-
77-
@property
78-
def epsneg(self):
79-
"""
80-
difference between 1.0 and the next smallest representable real-valued
81-
floating-point number smaller than 1.0 according to the IEEE-754
82-
standard.
83-
"""
84-
return float(self._finfo.epsneg)
85-
86-
@property
87-
def min(self):
88-
"""smallest representable real-valued number."""
89-
return float(self._finfo.min)
90-
91-
@property
92-
def max(self):
93-
"largest representable real-valued number."
94-
return float(self._finfo.max)
95-
96-
@property
97-
def resolution(self):
98-
"the approximate decimal resolution of this type."
99-
return float(self._finfo.resolution)
100-
101-
@property
102-
def precision(self):
103-
"""
104-
the approximate number of decimal digits to which this kind of
105-
floating point type is precise.
106-
"""
107-
return float(self._finfo.precision)
108-
109-
@property
110-
def dtype(self):
111-
"""
112-
the dtype for which finfo returns information. For complex input, the
113-
returned dtype is the associated floating point dtype for its real and
114-
complex components.
115-
"""
116-
return self._finfo.dtype
117-
118-
def __str__(self):
119-
return self._finfo.__str__()
120-
121-
def __repr__(self):
122-
return self._finfo.__repr__()
123-
124-
12538
def _broadcast_shape_impl(shapes):
12639
if len(set(shapes)) == 1:
12740
return shapes[0]
@@ -681,127 +594,6 @@ def stack(arrays, axis=0):
681594
return res
682595

683596

684-
def can_cast(from_, to, casting="safe"):
685-
""" can_cast(from, to, casting="safe")
686-
687-
Determines if one data type can be cast to another data type according \
688-
to Type Promotion Rules.
689-
690-
Args:
691-
from (usm_ndarray, dtype): source data type
692-
to (dtype): target data type
693-
casting ({'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional):
694-
controls what kind of data casting may occur.
695-
696-
Returns:
697-
bool:
698-
Gives `True` if cast can occur according to the casting rule.
699-
"""
700-
if isinstance(to, dpt.usm_ndarray):
701-
raise TypeError("Expected dtype type.")
702-
703-
dtype_to = dpt.dtype(to)
704-
705-
dtype_from = (
706-
from_.dtype if isinstance(from_, dpt.usm_ndarray) else dpt.dtype(from_)
707-
)
708-
709-
_supported_dtype([dtype_from, dtype_to])
710-
711-
return np.can_cast(dtype_from, dtype_to, casting)
712-
713-
714-
def result_type(*arrays_and_dtypes):
715-
"""
716-
result_type(arrays_and_dtypes)
717-
718-
Returns the dtype that results from applying the Type Promotion Rules to \
719-
the arguments.
720-
721-
Args:
722-
arrays_and_dtypes (object):
723-
An arbitrary length sequence of arrays or dtypes.
724-
725-
Returns:
726-
dtype:
727-
The dtype resulting from an operation involving the
728-
input arrays and dtypes.
729-
"""
730-
dtypes = [
731-
X.dtype if isinstance(X, dpt.usm_ndarray) else dpt.dtype(X)
732-
for X in arrays_and_dtypes
733-
]
734-
735-
_supported_dtype(dtypes)
736-
737-
return np.result_type(*dtypes)
738-
739-
740-
def iinfo(dtype):
741-
"""iinfo(dtype)
742-
743-
Returns machine limits for integer data types.
744-
745-
Args:
746-
dtype (dtype, usm_ndarray):
747-
integer dtype or
748-
an array with integer dtype.
749-
750-
Returns:
751-
iinfo_object:
752-
An object with the following attributes
753-
* bits: int
754-
number of bits occupied by the data type
755-
* max: int
756-
largest representable number.
757-
* min: int
758-
smallest representable number.
759-
* dtype: dtype
760-
integer data type.
761-
"""
762-
if isinstance(dtype, dpt.usm_ndarray):
763-
dtype = dtype.dtype
764-
_supported_dtype([dpt.dtype(dtype)])
765-
return np.iinfo(dtype)
766-
767-
768-
def finfo(dtype):
769-
"""finfo(type)
770-
771-
Returns machine limits for floating-point data types.
772-
773-
Args:
774-
dtype (dtype, usm_ndarray): floating-point dtype or
775-
an array with floating point data type.
776-
If complex, the information is about its component
777-
data type.
778-
779-
Returns:
780-
finfo_object:
781-
an object have the following attributes
782-
* bits: int
783-
number of bits occupied by dtype.
784-
* eps: float
785-
difference between 1.0 and the next smallest representable
786-
real-valued floating-point number larger than 1.0 according
787-
to the IEEE-754 standard.
788-
* max: float
789-
largest representable real-valued number.
790-
* min: float
791-
smallest representable real-valued number.
792-
* smallest_normal: float
793-
smallest positive real-valued floating-point number with
794-
full precision.
795-
* dtype: dtype
796-
real-valued floating-point data type.
797-
798-
"""
799-
if isinstance(dtype, dpt.usm_ndarray):
800-
dtype = dtype.dtype
801-
_supported_dtype([dpt.dtype(dtype)])
802-
return finfo_object(dtype)
803-
804-
805597
def unstack(X, axis=0):
806598
"""unstack(x, axis=0)
807599
@@ -1229,10 +1021,3 @@ def tile(x, repetitions):
12291021
)
12301022
hev.wait()
12311023
return dpt.reshape(res, res_shape)
1232-
1233-
1234-
def _supported_dtype(dtypes):
1235-
for dtype in dtypes:
1236-
if dtype.char not in "?bBhHiIlLqQefdFD":
1237-
raise ValueError(f"Dpctl doesn't support dtype {dtype}.")
1238-
return True

0 commit comments

Comments
 (0)