Skip to content

Commit d0b25ed

Browse files
xhochy authored and pitrou committed
ARROW-7493: [Python] Expose sum kernel in pyarrow.compute and support ChunkedArray inputs
1 parent 91114cf commit d0b25ed

File tree

9 files changed

+171
-11
lines changed

9 files changed

+171
-11
lines changed

cpp/src/arrow/compute/kernels/aggregate.cc

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,29 @@ class ManagedAggregateState {
5050
};
5151

5252
// Runs the unary aggregate over `input` and hands the result to `out`.
//
// This span is a rendered diff: a code line that follows an "NN-" marker belongs
// to the removed (Array-only) revision, a code line that follows an "NN+" marker
// belongs to the added revision, and a fused marker such as "5456" is unchanged
// context.
//
// Added revision, step by step:
//  - If input.is_arraylike() is false, Status::Invalid is returned.
//  - The accumulator state is created with ctx->memory_pool(); a null state is
//    reported as Status::OutOfMemory.
//  - Array input: one Consume() call on the accumulator.
//  - Otherwise the input is read through input.chunked_array(): every chunk is
//    consumed into a freshly created temporary state, which is then passed to
//    Merge() together with the accumulator.
//    NOTE(review): presumably Merge(src, dst) folds src into dst -- confirm
//    against the AggregateFunction interface.
//  - Finalize() is called with the accumulator and `out`.
//  - Each Consume/Merge/Finalize call is guarded by RETURN_NOT_OK.
Status AggregateUnaryKernel::Call(FunctionContext* ctx, const Datum& input, Datum* out) {
53-
if (!input.is_array()) return Status::Invalid("AggregateKernel expects Array datum");
53+
if (input.is_arraylike()) {
54+
auto state = ManagedAggregateState::Make(aggregate_function_, ctx->memory_pool());
55+
if (!state) return Status::OutOfMemory("AggregateState allocation failed");
5456

55-
auto state = ManagedAggregateState::Make(aggregate_function_, ctx->memory_pool());
56-
if (!state) return Status::OutOfMemory("AggregateState allocation failed");
57-
58-
auto array = input.make_array();
59-
RETURN_NOT_OK(aggregate_function_->Consume(*array, state->mutable_data()));
60-
RETURN_NOT_OK(aggregate_function_->Finalize(state->mutable_data(), out));
57+
if (input.is_array()) {
58+
auto array = input.make_array();
59+
RETURN_NOT_OK(aggregate_function_->Consume(*array, state->mutable_data()));
60+
} else {
61+
auto chunked_array = input.chunked_array();
62+
for (int64_t i = 0; i < chunked_array->num_chunks(); i++) {
63+
// A new temporary state is allocated for every chunk (empty chunks included).
auto tmp_state =
64+
ManagedAggregateState::Make(aggregate_function_, ctx->memory_pool());
65+
if (!tmp_state) return Status::OutOfMemory("AggregateState allocation failed");
66+
RETURN_NOT_OK(aggregate_function_->Consume(*chunked_array->chunk(i),
67+
tmp_state->mutable_data()));
68+
RETURN_NOT_OK(
69+
aggregate_function_->Merge(tmp_state->mutable_data(), state->mutable_data()));
70+
}
71+
}
72+
RETURN_NOT_OK(aggregate_function_->Finalize(state->mutable_data(), out));
73+
} else {
74+
return Status::Invalid("AggregateKernel expects Array or ChunkedArray datum");
75+
}
6176

6277
return Status::OK();
6378
}

cpp/src/arrow/compute/kernels/aggregate_test.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,29 @@ void ValidateSum(FunctionContext* ctx, const Array& input, Datum expected) {
9696
DatumEqual<OutputType>::EnsureEqual(result, expected);
9797
}
9898

99+
template <typename ArrowType>
100+
void ValidateSum(FunctionContext* ctx, const std::shared_ptr<ChunkedArray>& input,
101+
Datum expected) {
102+
using OutputType = typename FindAccumulatorType<ArrowType>::Type;
103+
104+
Datum result;
105+
ASSERT_OK(Sum(ctx, input, &result));
106+
DatumEqual<OutputType>::EnsureEqual(result, expected);
107+
}
108+
99109
template <typename ArrowType>
100110
void ValidateSum(FunctionContext* ctx, const char* json, Datum expected) {
101111
auto array = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), json);
102112
ValidateSum<ArrowType>(ctx, *array, expected);
103113
}
104114

115+
template <typename ArrowType>
116+
void ValidateSum(FunctionContext* ctx, const std::vector<std::string>& json,
117+
Datum expected) {
118+
auto array = ChunkedArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), json);
119+
ValidateSum<ArrowType>(ctx, array, expected);
120+
}
121+
105122
template <typename ArrowType>
106123
void ValidateSum(FunctionContext* ctx, const Array& array) {
107124
ValidateSum<ArrowType>(ctx, array, NaiveSum<ArrowType>(array));
@@ -123,6 +140,18 @@ TYPED_TEST(TestNumericSumKernel, SimpleSum) {
123140
ValidateSum<TypeParam>(&this->ctx_, "[0, 1, 2, 3, 4, 5]",
124141
Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
125142

143+
std::vector<std::string> json{"[0, 1, 2, 3, 4, 5]"};
144+
ValidateSum<TypeParam>(&this->ctx_, json,
145+
Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
146+
147+
json = {"[0, 1, 2]", "[3, 4, 5]"};
148+
ValidateSum<TypeParam>(&this->ctx_, json,
149+
Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
150+
151+
json = {"[0, 1, 2]", "[]", "[3, 4, 5]"};
152+
ValidateSum<TypeParam>(&this->ctx_, json,
153+
Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
154+
126155
const T expected_result = static_cast<T>(14);
127156
ValidateSum<TypeParam>(&this->ctx_, "[1, null, 3, null, 3, null, 7]",
128157
Datum(std::make_shared<ScalarType>(expected_result)));

cpp/src/arrow/compute/kernels/sum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
namespace arrow {
2525

2626
class Array;
27+
class ChunkedArray;
2728
class DataType;
2829
class Status;
2930

python/CMakeLists.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,12 @@ if(UNIX)
370370
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
371371
endif()
372372

373-
set(CYTHON_EXTENSIONS lib _fs _csv _json)
373+
set(CYTHON_EXTENSIONS
374+
lib
375+
_fs
376+
_csv
377+
_json
378+
_compute)
374379

375380
set(LINK_LIBS arrow_shared arrow_python_shared)
376381

python/pyarrow/_compute.pyx

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# cython: language_level = 3
19+
20+
from pyarrow.lib cimport (
21+
Array,
22+
wrap_datum,
23+
_context,
24+
check_status,
25+
ChunkedArray
26+
)
27+
from pyarrow.includes.libarrow cimport CDatum, Sum
28+
29+
from typing import Union
30+
31+
32+
cdef _sum_array(array: Array):
    # Apply the C++ Sum kernel to a single Array and return the wrapped Datum.
    cdef:
        CDatum result

    # The kernel call is made with the GIL released; check_status inspects the
    # returned Status.
    with nogil:
        check_status(Sum(_context(), CDatum(array.sp_array), &result))

    return wrap_datum(result)
39+
40+
41+
cdef _sum_chunked_array(array: ChunkedArray):
    # Apply the C++ Sum kernel to a ChunkedArray and return the wrapped Datum.
    cdef:
        CDatum result

    # The kernel call is made with the GIL released; check_status inspects the
    # returned Status.
    with nogil:
        check_status(Sum(_context(), CDatum(array.sp_chunked_array), &result))

    return wrap_datum(result)
48+
49+
50+
def sum(array: Union[Array, ChunkedArray]):
    """
    Compute the sum of the values of a numerical Array or ChunkedArray.

    Parameters
    ----------
    array : Array or ChunkedArray
        Input to aggregate.

    Returns
    -------
    The kernel output as produced by ``wrap_datum``.

    Raises
    ------
    ValueError
        If ``array`` is neither an Array nor a ChunkedArray.
    """
    # Dispatch on the concrete container type; each helper wraps the input in
    # the matching Datum before calling the kernel.
    if isinstance(array, Array):
        return _sum_array(array)
    if isinstance(array, ChunkedArray):
        return _sum_chunked_array(array)

    message = (
        "Only pyarrow.Array and pyarrow.ChunkedArray supported as"
        " an input, passed {}"
    ).format(type(array))
    raise ValueError(message)

python/pyarrow/compute.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import absolute_import
19+
20+
from pyarrow._compute import ( # noqa
21+
sum
22+
)

python/pyarrow/lib.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ cdef extern from "Python.h":
3030
int PySlice_Check(object)
3131

3232

33+
cdef CFunctionContext* _context() nogil
3334
cdef int check_status(const CStatus& status) nogil except -1
3435

3536
cdef class Message:
@@ -429,6 +430,7 @@ cdef class ExtensionArray(Array):
429430

430431

431432
cdef wrap_array_output(PyObject* output)
433+
cdef wrap_datum(const CDatum& datum)
432434
cdef object box_scalar(DataType type,
433435
const shared_ptr[CArray]& sp_array,
434436
int64_t index)

python/pyarrow/tests/test_compute.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import pytest
2121

2222
import pyarrow as pa
23+
import pyarrow.compute
2324

2425

2526
all_array_types = [
@@ -44,7 +45,7 @@
4445
]
4546

4647

47-
@pytest.mark.parametrize('arrow_type', [
48+
numerical_arrow_types = [
4849
pa.int8(),
4950
pa.int16(),
5051
pa.int64(),
@@ -53,10 +54,32 @@
5354
pa.uint64(),
5455
pa.float32(),
5556
pa.float64()
56-
])
57-
def test_sum(arrow_type):
57+
]
58+
59+
60+
@pytest.mark.parametrize('arrow_type', numerical_arrow_types)
def test_sum_array(arrow_type):
    """The Array.sum method and pyarrow.compute.sum agree on a plain Array."""
    values = pa.array([1, 2, 3, 4], type=arrow_type)
    expected = 10
    assert values.sum() == expected
    assert pa.compute.sum(values) == expected
65+
66+
67+
@pytest.mark.parametrize('arrow_type', numerical_arrow_types)
def test_sum_chunked_array(arrow_type):
    """pyarrow.compute.sum gives 10 for one chunk, two chunks, and an empty middle chunk."""
    chunk_layouts = [
        [[1, 2, 3, 4]],
        [[1, 2], [3, 4]],
        [[1, 2], [], [3, 4]],
    ]
    for layout in chunk_layouts:
        chunks = [pa.array(values, type=arrow_type) for values in layout]
        assert pa.compute.sum(pa.chunked_array(chunks)) == 10
6083

6184

6285
@pytest.mark.parametrize(('ty', 'values'), all_array_types)

python/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def initialize_options(self):
179179
'_csv',
180180
'_json',
181181
'_cuda',
182+
'_compute',
182183
'_flight',
183184
'_dataset',
184185
'_parquet',

0 commit comments

Comments
 (0)