From d0d02c66734e6a8b9578a88bfe3c82dddb24b855 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Thu, 31 Oct 2024 01:49:16 -0700 Subject: [PATCH 1/2] feat: add complex dtype support `mean` --- src/array_api_stubs/_draft/statistical_functions.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/array_api_stubs/_draft/statistical_functions.py b/src/array_api_stubs/_draft/statistical_functions.py index 9d3563e26..bf038466a 100644 --- a/src/array_api_stubs/_draft/statistical_functions.py +++ b/src/array_api_stubs/_draft/statistical_functions.py @@ -113,7 +113,7 @@ def mean( Parameters ---------- x: array - input array. Should have a real-valued floating-point data type. + input array. Should have a floating-point data type. axis: Optional[Union[int, Tuple[int, ...]]] axis or axes along which arithmetic means must be computed. By default, the mean must be computed over the entire array. If a tuple of integers, arithmetic means must be computed over multiple axes. Default: ``None``. keepdims: bool @@ -125,17 +125,23 @@ def mean( if the arithmetic mean was computed over the entire array, a zero-dimensional array containing the arithmetic mean; otherwise, a non-zero-dimensional array containing the arithmetic means. The returned array must have the same data type as ``x``. .. note:: - While this specification recommends that this function only accept input arrays having a real-valued floating-point data type, specification-compliant array libraries may choose to accept input arrays having an integer data type. While mixed data type promotion is implementation-defined, if the input array ``x`` has an integer data type, the returned array must have the default real-valued floating-point data type. + While this specification recommends that this function only accept input arrays having a floating-point data type, specification-compliant array libraries may choose to accept input arrays having an integer data type. While mixed data type promotion is implementation-defined, if the input array ``x`` has an integer data type, the returned array must have the default real-valued floating-point data type. Notes ----- **Special Cases** - Let ``N`` equal the number of elements over which to compute the arithmetic mean. + Let ``N`` equal the number of elements over which to compute the arithmetic mean. For real-valued operands, - If ``N`` is ``0``, the arithmetic mean is ``NaN``. - If ``x_i`` is ``NaN``, the arithmetic mean is ``NaN`` (i.e., ``NaN`` values propagate). + + For complex floating-point operands, real-valued floating-point special cases must independently apply to the real and imaginary component operations involving real numbers. For example, let ``a = real(x_i)`` and ``b = imag(x_i)``, and + + - If ``N`` is ``0``, the arithmetic mean is ``NaN + NaN j``. + - If ``a`` is ``NaN``, the real component of the result is ``NaN``. + - Similarly, if ``b`` is ``NaN``, the imaginary component of the result is ``NaN``. """ From 095e8d0ad891d5972ff7704948be8549e87a38ae Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Wed, 11 Dec 2024 23:43:39 -0800 Subject: [PATCH 2/2] docs: add note concerning the use of `isnan` --- src/array_api_stubs/_draft/statistical_functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/statistical_functions.py b/src/array_api_stubs/_draft/statistical_functions.py index bf038466a..be438173a 100644 --- a/src/array_api_stubs/_draft/statistical_functions.py +++ b/src/array_api_stubs/_draft/statistical_functions.py @@ -137,11 +137,14 @@ def mean( - If ``N`` is ``0``, the arithmetic mean is ``NaN``. - If ``x_i`` is ``NaN``, the arithmetic mean is ``NaN`` (i.e., ``NaN`` values propagate). - For complex floating-point operands, real-valued floating-point special cases must independently apply to the real and imaginary component operations involving real numbers. For example, let ``a = real(x_i)`` and ``b = imag(x_i)``, and + For complex floating-point operands, real-valued floating-point special cases should independently apply to the real and imaginary component operations involving real numbers. For example, let ``a = real(x_i)`` and ``b = imag(x_i)``, and - If ``N`` is ``0``, the arithmetic mean is ``NaN + NaN j``. - If ``a`` is ``NaN``, the real component of the result is ``NaN``. - Similarly, if ``b`` is ``NaN``, the imaginary component of the result is ``NaN``. + + .. note:: + Array libraries, such as NumPy, PyTorch, and JAX, currently deviate from this specification in their handling of components which are ``NaN`` when computing the arithmetic mean. In general, consumers of array libraries implementing this specification should use :func:`~array_api.isnan` to test whether the result of computing the arithmetic mean over an array have a complex floating-point data type is ``NaN``, rather than relying on ``NaN`` propagation of individual components. """