Skip to content

ndarray dunders / binary ufuncs #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
__pycache__/*
autogen/__pycache__
torch_np/__pycache__/*
torch_np/tests/__pycache__/*
torch_np/tests/numpy_tests/core/__pycache__/*
torch_np/testing/__pycache__/*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
.coverage

213 changes: 213 additions & 0 deletions autogen/gen_ufuncs_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
from dump_namespace import grab_namespace, get_signature

import numpy as np

namespace = np

dct = grab_namespace(namespace)


# SKIP these (need special handling)
skip = {np.frexp, np.modf, # non-standard unary ufunc signatures
np.isnat,
np.invert, # bitwise NOT operator
np.spacing, # niche, does not have a direct equivalent
}

# np functions where torch names differ
torch_names = {np.radians : "deg2rad",
np.degrees : "rad2deg",
np.conjugate : "conj_physical",
np.fabs : "absolute", # FIXME: np.fabs raises form complex
np.rint : "round",
np.left_shift: "bitwise_left_shift",
np.right_shift: "bitwise_right_shift",
}


# np functions which do not have a torch equivalent
default_stanza = "torch.{torch_name}(x, out=out)"

stanzas = {np.cbrt : "torch.pow(x, 1/3, out=out)",

# XXX what on earth is np.positive
np.positive: "+x",

# these three do not have an out arg
np.isinf: "torch.isinf(x)",
np.isnan: "torch.isnan(x)",
np.isfinite: "torch.isfinite(x)",
}


# for these np functions, pytorch analog does not have the out= arg
needs_out = {np.isinf, np.isnan, np.isfinite, np.positive}
add_out_stanza = """
if out is not None:
out[...] = result
"""


header = """\
# this file is autogenerated via gen_ufuncs.py
# do not edit manually!

import torch

import _util
from _ndarray import asarray_replacer

"""

test_header = header + """\
import numpy as np
import torch

from _unary_ufuncs import *
from testing import assert_allclose
"""


template = """ """

test_template = """

def test_{np_name}():
assert_allclose(np.{np_name}(0.5),
{np_name}(0.5), atol=1e-14, check_dtype=False)

"""


###### UNARY UFUNCS ###################################

_all_list = []
main_text = header
test_text = test_header

_impl_list = []
_ufunc_list = []

for ufunc in dct['ufunc']:
if ufunc in skip:
continue

if ufunc.nin == 1:
# print(get_signature(ufunc))

torch_name = torch_names.get(ufunc)
if torch_name is None:
torch_name = ufunc.__name__

# print(ufunc.__name__, ' -- ', torch_name)

_impl_stanza = "{np_name} = deco_unary_ufunc(torch.{torch_name})"
_impl_stanza = _impl_stanza.format(np_name=ufunc.__name__,
torch_name=torch_name,)
_impl_list.append(_impl_stanza)

continue

torch_stanza = stanzas.get(ufunc)
if torch_stanza is None:
torch_stanza = default_stanza.format(torch_name=torch_name)

out_stanza= add_out_stanza if ufunc in needs_out else ""

main_text += template.format(np_name=ufunc.__name__,
torch_stanza=torch_stanza,
out_stanza=out_stanza)
test_text += test_template.format(np_name=ufunc.__name__)

_all_list.append(ufunc.__name__)


print("\n".join(_impl_list))
print("\n\n-----\n\n")

'''
main_text += "\n\n__all__ = %s" % _all_list


with open("_unary_ufuncs.py", "w") as f:
f.write(main_text)

with open("test_unary_ufuncs.py", "w") as f:
f.write(test_text)
'''

###### BINARY UFUNCS ###################################



test_header = header + """\
import numpy as np
import torch

from _binary_ufuncs import *
from testing import assert_allclose
"""


template = """


"""

test_template = """

def test_{np_name}():
assert_allclose(np.{np_name}(0.5, 0.6),
{np_name}(0.5, 0.6), atol=1e-7, check_dtype=False)

"""



skip = {np.divmod, # two outputs
}


torch_names = {np.power: "pow",
np.equal: "eq",
}


_all_list = []
main_text = header
test_text = test_header

_impl_list = []
_ufunc_list = []


for ufunc in dct['ufunc']:

if ufunc in skip:
continue

if ufunc.nin == 2:
## print(get_signature(ufunc))

torch_name = torch_names.get(ufunc)
if torch_name is None:
torch_name = ufunc.__name__

_impl_stanza = "{np_name} = deco_binary_ufunc(torch.{torch_name})"
_impl_stanza = _impl_stanza.format(np_name=ufunc.__name__,
torch_name=torch_name,)
_impl_list.append(_impl_stanza)

_ufunc_stanza = "{np_name} = deco_ufunc_from_impl(_ufunc_impl.{np_name})"
_ufunc_stanza = _ufunc_stanza.format(np_name=ufunc.__name__)
_ufunc_list.append(_ufunc_stanza)


print("\n".join(_impl_list))

print("\n\n")
print("\n".join(_ufunc_list))




2 changes: 1 addition & 1 deletion torch_np/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ._unary_ufuncs import *
from ._binary_ufuncs import *
from ._ndarray import can_cast, result_type, newaxis
from ._util import AxisError
from ._util import AxisError, UFuncTypeError
from ._getlimits import iinfo, finfo
from ._getlimits import errstate

Expand Down
Loading