-
Notifications
You must be signed in to change notification settings - Fork 4
dtypes and array scalars #13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from ._dtypes import * | ||
from ._scalar_types import * | ||
from ._wrapper import * | ||
from . import testing | ||
|
||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,12 @@ | |
import builtins | ||
import torch | ||
|
||
from . import _scalar_types | ||
|
||
|
||
__all__ = ['dtype_from_torch', 'dtype', 'typecodes', 'issubdtype'] | ||
|
||
|
||
# Define analogs of numpy dtypes supported by pytorch. | ||
|
||
class dtype: | ||
|
@@ -22,6 +28,9 @@ def __init__(self, name): | |
_name = typecode_chars_dict[name] | ||
elif name in dt_aliases_dict: | ||
_name = dt_aliases_dict[name] | ||
# the check must come last, so that 'name' is not a string | ||
elif issubclass(name, _scalar_types.generic): | ||
_name = name.name | ||
else: | ||
raise TypeError(f"data type '{name}' not understood") | ||
self._name = _name | ||
|
@@ -30,6 +39,10 @@ def __init__(self, name): | |
def name(self): | ||
return self._name | ||
|
||
@property | ||
def type(self): | ||
return _scalar_types._typemap[self._name] | ||
|
||
@property | ||
def typecode(self): | ||
return _typecodes_from_dtype_dict[self._name] | ||
|
@@ -104,21 +117,6 @@ def __repr__(self): | |
} | ||
|
||
|
||
float16 = dtype("float16") | ||
float32 = dtype("float32") | ||
float64 = dtype("float64") | ||
complex64 = dtype("complex64") | ||
complex128 = dtype("complex128") | ||
uint8 = dtype("uint8") | ||
int8 = dtype("int8") | ||
int16 = dtype("int16") | ||
int32 = dtype("int32") | ||
int64 = dtype("int64") | ||
bool = dtype("bool") | ||
|
||
intp = int64 # XXX | ||
int_ = int64 | ||
|
||
# Map the torch-suppored subset dtypes to local analogs | ||
# "quantized" types not available in numpy, skip | ||
_dtype_from_torch_dict = { | ||
|
@@ -183,6 +181,23 @@ def is_integer(dtyp): | |
return dtyp.typecode in typecodes['AllInteger'] | ||
|
||
|
||
|
||
def issubclass_(arg, klass): | ||
try: | ||
return issubclass(arg, klass) | ||
except TypeError: | ||
return False | ||
|
||
|
||
def issubdtype(arg1, arg2): | ||
# cf https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numerictypes.py#L356-L420 | ||
if not issubclass_(arg1, _scalar_types.generic): | ||
arg1 = dtype(arg1).type | ||
if not issubclass_(arg2, _scalar_types.generic): | ||
arg2 = dtype(arg2).type | ||
return issubclass(arg1, arg2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to not yet be complete I think? At least from a quick test: In [25]: tnp.issubdtype(tnp.float32, tnp.inexact)
Out[25]: False
In [26]: np.issubdtype(np.float32, np.inexact)
Out[26]: True |
||
|
||
|
||
# The casting below is defined *with dtypes only*, so no value-based casting! | ||
|
||
# These two dicts are autogenerated with autogen/gen_dtypes.py, | ||
|
@@ -216,6 +231,3 @@ def is_integer(dtyp): | |
|
||
########################## end autogenerated part | ||
|
||
|
||
__all__ = ['dtype_from_torch', 'dtype', 'typecodes'] + dt_names + ['intp', 'int_'] | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
"""Replicate the NumPy scalar type hierarchy | ||
""" | ||
|
||
import abc | ||
import torch | ||
|
||
class generic(abc.ABC): | ||
@property | ||
@abc.abstractmethod | ||
def name(self): | ||
return self.__class__.__name__ | ||
|
||
def __new__(self, value): | ||
# | ||
# Yes, a call to np.float32(4) produces a zero-dim array. | ||
# | ||
from . import _dtypes | ||
from . import _ndarray | ||
|
||
torch_dtype = _dtypes.torch_dtype_from(self.name) | ||
tensor = torch.as_tensor(value, dtype=torch_dtype) | ||
return _ndarray.ndarray._from_tensor_and_base(tensor, None) | ||
|
||
|
||
##### these are abstract types | ||
|
||
class number(generic): | ||
pass | ||
|
||
|
||
class integer(generic): | ||
pass | ||
|
||
|
||
class inexact(generic): | ||
pass | ||
|
||
|
||
class signedinteger(generic): | ||
pass | ||
|
||
|
||
class unsignedinteger(generic): | ||
pass | ||
|
||
|
||
class inexact(generic): | ||
pass | ||
|
||
|
||
class floating(generic): | ||
pass | ||
|
||
|
||
class complexfloating(generic): | ||
pass | ||
|
||
|
||
# ##### concrete types | ||
|
||
# signed integers | ||
|
||
class int8(signedinteger): | ||
name = 'int8' | ||
|
||
|
||
class int16(signedinteger): | ||
name = 'int16' | ||
|
||
|
||
class int32(signedinteger): | ||
name = 'int32' | ||
|
||
|
||
class int64(signedinteger): | ||
name = 'int64' | ||
|
||
|
||
# unsigned integers | ||
|
||
class uint8(unsignedinteger): | ||
name = 'uint8' | ||
|
||
|
||
# floating point | ||
|
||
class float16(floating): | ||
name = 'float16' | ||
|
||
|
||
class float32(floating): | ||
name = 'float32' | ||
|
||
|
||
class float64(floating): | ||
name = 'float64' | ||
|
||
|
||
class complex64(complexfloating): | ||
name = 'complex64' | ||
|
||
|
||
class complex128(complexfloating): | ||
name = 'complex128' | ||
|
||
|
||
class bool_(generic): | ||
name = 'bool' | ||
|
||
|
||
# name aliases : FIXME (OS, bitness) | ||
intp = int64 | ||
int_ = int64 | ||
|
||
|
||
_typemap ={ | ||
'int8' : int8, | ||
'int16' : int16, | ||
'int32' : int32, | ||
'int64' : int64, | ||
'uint8' : uint8, | ||
'float16': float16, | ||
'float32': float32, | ||
'float64': float64, | ||
'complex64': complex64, | ||
'complex128': complex128, | ||
'bool': bool_ | ||
} | ||
|
||
|
||
__all__ = list(_typemap.keys()) | ||
__all__.remove('bool') | ||
|
||
__all__ += ['bool_', 'intp', 'int_'] | ||
__all__ += ['generic', 'number', | ||
'signedinteger', 'unsignedinteger', | ||
'inexact', 'floating', 'complexfloating'] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This goes on my list of functions to deprecate and remove in NumPy 2.0 ...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This one is easy enough so may as well leave it here now. I'd say let's not bother with
issubsctype
,obj2sctype
, etc., those definitely all need to go from NumPy.