Skip to content

Commit 977fceb

Browse files
committed
Add reshape() to torch
The copy keyword raises NotImplementedError for now
1 parent 0b4fcd9 commit 977fceb

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,16 @@ def where(condition: array, x1: array, x2: array, /) -> array:
459459
x1, x2 = _fix_promotion(x1, x2)
460460
return torch.where(condition, x1, x2)
461461

462+
# torch.reshape doesn't have the copy keyword
463+
def reshape(x: array,
464+
/,
465+
shape: Tuple[int, ...],
466+
copy: Optional[bool] = None,
467+
**kwargs) -> array:
468+
if copy is not None:
469+
raise NotImplementedError("torch.reshape doesn't yet support the copy keyword")
470+
return torch.reshape(x, shape, **kwargs)
471+
462472
# torch.arange doesn't support returning empty arrays
463473
# (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some
464474
# keyword argument combinations
@@ -659,8 +669,8 @@ def isdtype(
659669
'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder',
660670
'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all',
661671
'mean', 'std', 'var', 'concat', 'squeeze', 'flip', 'roll',
662-
'nonzero', 'where', 'arange', 'eye', 'linspace', 'full', 'ones',
663-
'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
672+
'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full',
673+
'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
664674
'broadcast_arrays', 'unique_all', 'unique_counts',
665675
'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
666676
'vecdot', 'tensordot', 'isdtype']

0 commit comments

Comments
 (0)