Skip to content

dtypes and array scalars #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torch_np/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ._dtypes import *
from ._scalar_types import *
from ._wrapper import *
from . import testing

Expand Down
48 changes: 30 additions & 18 deletions torch_np/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
import builtins
import torch

from . import _scalar_types


__all__ = ['dtype_from_torch', 'dtype', 'typecodes', 'issubdtype']


# Define analogs of numpy dtypes supported by pytorch.

class dtype:
Expand All @@ -22,6 +28,9 @@ def __init__(self, name):
_name = typecode_chars_dict[name]
elif name in dt_aliases_dict:
_name = dt_aliases_dict[name]
# the check must come last, so that 'name' is not a string
elif issubclass(name, _scalar_types.generic):
_name = name.name
else:
raise TypeError(f"data type '{name}' not understood")
self._name = _name
Expand All @@ -30,6 +39,10 @@ def __init__(self, name):
def name(self):
return self._name

@property
def type(self):
return _scalar_types._typemap[self._name]

@property
def typecode(self):
return _typecodes_from_dtype_dict[self._name]
Expand Down Expand Up @@ -104,21 +117,6 @@ def __repr__(self):
}


float16 = dtype("float16")
float32 = dtype("float32")
float64 = dtype("float64")
complex64 = dtype("complex64")
complex128 = dtype("complex128")
uint8 = dtype("uint8")
int8 = dtype("int8")
int16 = dtype("int16")
int32 = dtype("int32")
int64 = dtype("int64")
bool = dtype("bool")

intp = int64 # XXX
int_ = int64

# Map the torch-suppored subset dtypes to local analogs
# "quantized" types not available in numpy, skip
_dtype_from_torch_dict = {
Expand Down Expand Up @@ -183,6 +181,23 @@ def is_integer(dtyp):
return dtyp.typecode in typecodes['AllInteger']



def issubclass_(arg, klass):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This goes on my list of functions to deprecate and remove in NumPy 2.0 ...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is easy enough so may as well leave it here now. I'd say let's not bother with issubsctype, obj2sctype, etc., those definitely all need to go from NumPy.

try:
return issubclass(arg, klass)
except TypeError:
return False


def issubdtype(arg1, arg2):
# cf https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numerictypes.py#L356-L420
if not issubclass_(arg1, _scalar_types.generic):
arg1 = dtype(arg1).type
if not issubclass_(arg2, _scalar_types.generic):
arg2 = dtype(arg2).type
return issubclass(arg1, arg2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to not yet be complete I think? At least from a quick test:

In [25]: tnp.issubdtype(tnp.float32, tnp.inexact)
Out[25]: False

In [26]: np.issubdtype(np.float32, np.inexact)
Out[26]: True



# The casting below is defined *with dtypes only*, so no value-based casting!

# These two dicts are autogenerated with autogen/gen_dtypes.py,
Expand Down Expand Up @@ -216,6 +231,3 @@ def is_integer(dtyp):

########################## end autogenerated part


__all__ = ['dtype_from_torch', 'dtype', 'typecodes'] + dt_names + ['intp', 'int_']

6 changes: 6 additions & 0 deletions torch_np/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,9 @@ def allow_only_single_axis(axis):
if len(axis) != 1:
raise NotImplementedError("does not handle tuple axis")
return axis[0]


def to_tensors(*inputs):
"""Convert all ndarrays from `inputs` to tensors."""
return tuple([value.get() if isinstance(value, ndarray) else value
for value in inputs])
137 changes: 137 additions & 0 deletions torch_np/_scalar_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""Replicate the NumPy scalar type hierarchy
"""

import abc
import torch

class generic(abc.ABC):
@property
@abc.abstractmethod
def name(self):
return self.__class__.__name__

def __new__(self, value):
#
# Yes, a call to np.float32(4) produces a zero-dim array.
#
from . import _dtypes
from . import _ndarray

torch_dtype = _dtypes.torch_dtype_from(self.name)
tensor = torch.as_tensor(value, dtype=torch_dtype)
return _ndarray.ndarray._from_tensor_and_base(tensor, None)


##### these are abstract types

class number(generic):
pass


class integer(generic):
pass


class inexact(generic):
pass


class signedinteger(generic):
pass


class unsignedinteger(generic):
pass


class inexact(generic):
pass


class floating(generic):
pass


class complexfloating(generic):
pass


# ##### concrete types

# signed integers

class int8(signedinteger):
name = 'int8'


class int16(signedinteger):
name = 'int16'


class int32(signedinteger):
name = 'int32'


class int64(signedinteger):
name = 'int64'


# unsigned integers

class uint8(unsignedinteger):
name = 'uint8'


# floating point

class float16(floating):
name = 'float16'


class float32(floating):
name = 'float32'


class float64(floating):
name = 'float64'


class complex64(complexfloating):
name = 'complex64'


class complex128(complexfloating):
name = 'complex128'


class bool_(generic):
name = 'bool'


# name aliases : FIXME (OS, bitness)
intp = int64
int_ = int64


_typemap ={
'int8' : int8,
'int16' : int16,
'int32' : int32,
'int64' : int64,
'uint8' : uint8,
'float16': float16,
'float32': float32,
'float64': float64,
'complex64': complex64,
'complex128': complex128,
'bool': bool_
}


__all__ = list(_typemap.keys())
__all__.remove('bool')

__all__ += ['bool_', 'intp', 'int_']
__all__ += ['generic', 'number',
'signedinteger', 'unsignedinteger',
'inexact', 'floating', 'complexfloating']
3 changes: 2 additions & 1 deletion torch_np/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,9 @@ def arange(start=None, stop=None, step=1, dtype=None, *, like=None):
if dtype is None:
dtype = _dtypes.default_int_type()
dtype = result_type(start, stop, step, dtype)

torch_dtype = _dtypes.torch_dtype_from(dtype)
start, stop, step = _helpers.to_tensors(start, stop, step)

try:
return asarray(torch.arange(start, stop, step, dtype=torch_dtype))
except RuntimeError:
Expand Down
5 changes: 2 additions & 3 deletions torch_np/tests/test_function_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,15 @@ def test_arange_booleans(self):
with pytest.raises(TypeError):
np.arange(3, dtype="bool")

@pytest.mark.skip(reason='XXX: python scalars from array scalars')
@pytest.mark.parametrize("which", [0, 1, 2])
def test_error_paths_and_promotion(self, which):
args = [0, 1, 2] # start, stop, and step
args = [0, 10, 2] # start, stop, and step
args[which] = np.float64(2.) # should ensure float64 output
assert np.arange(*args).dtype == np.float64

# Cover stranger error path, test only to achieve code coverage!
args[which] = [None, []]
with pytest.raises(ValueError):
with pytest.raises((ValueError, RuntimeError)):
# Fails discovering start dtype
np.arange(*args)

Expand Down
2 changes: 1 addition & 1 deletion torch_np/tests/test_shape_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def test_stack():

# 0d input
for input_ in [(1, 2, 3),
### [np.int32(1), np.int32(2), np.int32(3)], # XXX: numpy scalars?
[np.int32(1), np.int32(2), np.int32(3)],
[np.array(1), np.array(2), np.array(3)]]:
assert_array_equal(stack(input_), [1, 2, 3])
# 1d input examples
Expand Down