Skip to content

Commit 27dded6

Browse files
xhochypitrou
andcommitted
ARROW-7493: [Python] Expose sum kernel in pyarrow.compute and support ChunkedArray inputs
This only exposes the `Sum` kernel, I will do more once this PR got review and is merged. Closes #6123 from xhochy/ARROW-7493 and squashes the following commits: f65ba06 <Antoine Pitrou> Test summing a ChunkedArray with 0 chunks 7c8d526 <Antoine Pitrou> Nits 8fa07b0 <Uwe L. Korn> Revert back to explicit json variable 1575a74 <Uwe L. Korn> Review d03af5f <Uwe L. Korn> s/int64/int ce2b974 <Uwe L. Korn> Don't use 3.6+ type annotations d0b25ed <Uwe L. Korn> ARROW-7493: Expose sum kernel in pyarrow.compute and support ChunkedArray inputs Lead-authored-by: Uwe L. Korn <[email protected]> Co-authored-by: Antoine Pitrou <[email protected]> Signed-off-by: Antoine Pitrou <[email protected]>
1 parent 38bf178 commit 27dded6

File tree

12 files changed

+205
-22
lines changed

12 files changed

+205
-22
lines changed

cpp/src/arrow/compute/kernels/aggregate.cc

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,33 @@ class ManagedAggregateState {
5050
};
5151

5252
Status AggregateUnaryKernel::Call(FunctionContext* ctx, const Datum& input, Datum* out) {
53-
if (!input.is_array()) return Status::Invalid("AggregateKernel expects Array datum");
54-
53+
if (!input.is_arraylike()) {
54+
return Status::Invalid("AggregateKernel expects Array or ChunkedArray datum");
55+
}
5556
auto state = ManagedAggregateState::Make(aggregate_function_, ctx->memory_pool());
56-
if (!state) return Status::OutOfMemory("AggregateState allocation failed");
57+
if (!state) {
58+
return Status::OutOfMemory("AggregateState allocation failed");
59+
}
5760

58-
auto array = input.make_array();
59-
RETURN_NOT_OK(aggregate_function_->Consume(*array, state->mutable_data()));
60-
RETURN_NOT_OK(aggregate_function_->Finalize(state->mutable_data(), out));
61+
if (input.is_array()) {
62+
auto array = input.make_array();
63+
RETURN_NOT_OK(aggregate_function_->Consume(*array, state->mutable_data()));
64+
} else {
65+
auto chunked_array = input.chunked_array();
66+
for (int i = 0; i < chunked_array->num_chunks(); i++) {
67+
auto tmp_state =
68+
ManagedAggregateState::Make(aggregate_function_, ctx->memory_pool());
69+
if (!tmp_state) {
70+
return Status::OutOfMemory("AggregateState allocation failed");
71+
}
72+
RETURN_NOT_OK(aggregate_function_->Consume(*chunked_array->chunk(i),
73+
tmp_state->mutable_data()));
74+
RETURN_NOT_OK(
75+
aggregate_function_->Merge(tmp_state->mutable_data(), state->mutable_data()));
76+
}
77+
}
6178

62-
return Status::OK();
79+
return aggregate_function_->Finalize(state->mutable_data(), out);
6380
}
6481

6582
std::shared_ptr<DataType> AggregateUnaryKernel::out_type() const {

cpp/src/arrow/compute/kernels/aggregate_test.cc

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,29 @@ void ValidateSum(FunctionContext* ctx, const Array& input, Datum expected) {
9696
DatumEqual<OutputType>::EnsureEqual(result, expected);
9797
}
9898

99+
template <typename ArrowType>
100+
void ValidateSum(FunctionContext* ctx, const std::shared_ptr<ChunkedArray>& input,
101+
Datum expected) {
102+
using OutputType = typename FindAccumulatorType<ArrowType>::Type;
103+
104+
Datum result;
105+
ASSERT_OK(Sum(ctx, input, &result));
106+
DatumEqual<OutputType>::EnsureEqual(result, expected);
107+
}
108+
99109
template <typename ArrowType>
100110
void ValidateSum(FunctionContext* ctx, const char* json, Datum expected) {
101111
auto array = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), json);
102112
ValidateSum<ArrowType>(ctx, *array, expected);
103113
}
104114

115+
template <typename ArrowType>
116+
void ValidateSum(FunctionContext* ctx, const std::vector<std::string>& json,
117+
Datum expected) {
118+
auto array = ChunkedArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), json);
119+
ValidateSum<ArrowType>(ctx, array, expected);
120+
}
121+
105122
template <typename ArrowType>
106123
void ValidateSum(FunctionContext* ctx, const Array& array) {
107124
ValidateSum<ArrowType>(ctx, array, NaiveSum<ArrowType>(array));
@@ -123,6 +140,22 @@ TYPED_TEST(TestNumericSumKernel, SimpleSum) {
123140
ValidateSum<TypeParam>(&this->ctx_, "[0, 1, 2, 3, 4, 5]",
124141
Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
125142

143+
std::vector<std::string> chunks = {"[0, 1, 2, 3, 4, 5]"};
144+
ValidateSum<TypeParam>(&this->ctx_, chunks,
145+
Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
146+
147+
chunks = {"[0, 1, 2]", "[3, 4, 5]"};
148+
ValidateSum<TypeParam>(&this->ctx_, chunks,
149+
Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
150+
151+
chunks = {"[0, 1, 2]", "[]", "[3, 4, 5]"};
152+
ValidateSum<TypeParam>(&this->ctx_, chunks,
153+
Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
154+
155+
chunks = {};
156+
ValidateSum<TypeParam>(&this->ctx_, chunks,
157+
Datum(std::make_shared<ScalarType>())); // null
158+
126159
const T expected_result = static_cast<T>(14);
127160
ValidateSum<TypeParam>(&this->ctx_, "[1, null, 3, null, 3, null, 7]",
128161
Datum(std::make_shared<ScalarType>(expected_result)));

cpp/src/arrow/compute/kernels/sum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
namespace arrow {
2525

2626
class Array;
27+
class ChunkedArray;
2728
class DataType;
2829
class Status;
2930

cpp/src/arrow/testing/gtest_util.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ std::shared_ptr<ChunkedArray> ChunkedArrayFromJSON(const std::shared_ptr<DataTyp
231231
for (const std::string& chunk_json : json) {
232232
out_chunks.push_back(ArrayFromJSON(type, chunk_json));
233233
}
234-
return std::make_shared<ChunkedArray>(std::move(out_chunks));
234+
return std::make_shared<ChunkedArray>(std::move(out_chunks), type);
235235
}
236236

237237
std::shared_ptr<RecordBatch> RecordBatchFromJSON(const std::shared_ptr<Schema>& schema,

python/CMakeLists.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,12 @@ if(UNIX)
370370
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
371371
endif()
372372

373-
set(CYTHON_EXTENSIONS lib _fs _csv _json)
373+
set(CYTHON_EXTENSIONS
374+
lib
375+
_fs
376+
_csv
377+
_json
378+
_compute)
374379

375380
set(LINK_LIBS arrow_shared arrow_python_shared)
376381

python/pyarrow/_compute.pyx

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# cython: language_level = 3
19+
20+
from pyarrow.lib cimport (
21+
Array,
22+
wrap_datum,
23+
_context,
24+
check_status,
25+
ChunkedArray
26+
)
27+
from pyarrow.includes.libarrow cimport CDatum, Sum
28+
29+
30+
cdef _sum_array(array: Array):
31+
cdef CDatum out
32+
33+
with nogil:
34+
check_status(Sum(_context(), CDatum(array.sp_array), &out))
35+
36+
return wrap_datum(out)
37+
38+
39+
cdef _sum_chunked_array(array: ChunkedArray):
40+
cdef CDatum out
41+
42+
with nogil:
43+
check_status(Sum(_context(), CDatum(array.sp_chunked_array), &out))
44+
45+
return wrap_datum(out)
46+
47+
48+
def sum(array):
49+
"""
50+
Sum the values in a numerical (chunked) array.
51+
52+
Parameters
53+
----------
54+
array : pyarrow.Array or pyarrow.ChunkedArray
55+
56+
Returns
57+
-------
58+
sum : pyarrow.Scalar
59+
"""
60+
if isinstance(array, Array):
61+
return _sum_array(array)
62+
elif isinstance(array, ChunkedArray):
63+
return _sum_chunked_array(array)
64+
else:
65+
raise ValueError(
66+
"Only pyarrow.Array and pyarrow.ChunkedArray supported as"
67+
" an input, passed {}".format(type(array))
68+
)

python/pyarrow/compute.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import absolute_import
19+
20+
from pyarrow._compute import ( # noqa
21+
sum
22+
)

python/pyarrow/includes/libarrow.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
762762

763763
cdef cppclass CScalar" arrow::Scalar":
764764
shared_ptr[CDataType] type
765+
c_bool is_valid
765766

766767
cdef cppclass CInt8Scalar" arrow::Int8Scalar"(CScalar):
767768
int8_t value

python/pyarrow/lib.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ cdef extern from "Python.h":
3030
int PySlice_Check(object)
3131

3232

33+
cdef CFunctionContext* _context() nogil
3334
cdef int check_status(const CStatus& status) nogil except -1
3435

3536
cdef class Message:
@@ -429,6 +430,7 @@ cdef class ExtensionArray(Array):
429430

430431

431432
cdef wrap_array_output(PyObject* output)
433+
cdef wrap_datum(const CDatum& datum)
432434
cdef object box_scalar(DataType type,
433435
const shared_ptr[CArray]& sp_array,
434436
int64_t index)

python/pyarrow/scalar.pxi

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,7 @@ cdef class ScalarValue(Scalar):
10311031
def __hash__(self):
10321032
return hash(self.as_py())
10331033

1034+
10341035
cdef class UInt8Scalar(ScalarValue):
10351036
"""
10361037
Concrete class for uint8 scalars.
@@ -1041,7 +1042,7 @@ cdef class UInt8Scalar(ScalarValue):
10411042
Return this value as a Python int.
10421043
"""
10431044
cdef CUInt8Scalar* sp = <CUInt8Scalar*> self.sp_scalar.get()
1044-
return sp.value
1045+
return sp.value if sp.is_valid else None
10451046

10461047

10471048
cdef class Int8Scalar(ScalarValue):
@@ -1054,7 +1055,7 @@ cdef class Int8Scalar(ScalarValue):
10541055
Return this value as a Python int.
10551056
"""
10561057
cdef CInt8Scalar* sp = <CInt8Scalar*> self.sp_scalar.get()
1057-
return sp.value
1058+
return sp.value if sp.is_valid else None
10581059

10591060

10601061
cdef class UInt16Scalar(ScalarValue):
@@ -1067,7 +1068,7 @@ cdef class UInt16Scalar(ScalarValue):
10671068
Return this value as a Python int.
10681069
"""
10691070
cdef CUInt16Scalar* sp = <CUInt16Scalar*> self.sp_scalar.get()
1070-
return sp.value
1071+
return sp.value if sp.is_valid else None
10711072

10721073

10731074
cdef class Int16Scalar(ScalarValue):
@@ -1080,7 +1081,7 @@ cdef class Int16Scalar(ScalarValue):
10801081
Return this value as a Python int.
10811082
"""
10821083
cdef CInt16Scalar* sp = <CInt16Scalar*> self.sp_scalar.get()
1083-
return sp.value
1084+
return sp.value if sp.is_valid else None
10841085

10851086

10861087
cdef class UInt32Scalar(ScalarValue):
@@ -1093,7 +1094,7 @@ cdef class UInt32Scalar(ScalarValue):
10931094
Return this value as a Python int.
10941095
"""
10951096
cdef CUInt32Scalar* sp = <CUInt32Scalar*> self.sp_scalar.get()
1096-
return sp.value
1097+
return sp.value if sp.is_valid else None
10971098

10981099

10991100
cdef class Int32Scalar(ScalarValue):
@@ -1106,7 +1107,7 @@ cdef class Int32Scalar(ScalarValue):
11061107
Return this value as a Python int.
11071108
"""
11081109
cdef CInt32Scalar* sp = <CInt32Scalar*> self.sp_scalar.get()
1109-
return sp.value
1110+
return sp.value if sp.is_valid else None
11101111

11111112

11121113
cdef class UInt64Scalar(ScalarValue):
@@ -1119,7 +1120,7 @@ cdef class UInt64Scalar(ScalarValue):
11191120
Return this value as a Python int.
11201121
"""
11211122
cdef CUInt64Scalar* sp = <CUInt64Scalar*> self.sp_scalar.get()
1122-
return sp.value
1123+
return sp.value if sp.is_valid else None
11231124

11241125

11251126
cdef class Int64Scalar(ScalarValue):
@@ -1132,7 +1133,7 @@ cdef class Int64Scalar(ScalarValue):
11321133
Return this value as a Python int.
11331134
"""
11341135
cdef CInt64Scalar* sp = <CInt64Scalar*> self.sp_scalar.get()
1135-
return sp.value
1136+
return sp.value if sp.is_valid else None
11361137

11371138

11381139
cdef class FloatScalar(ScalarValue):
@@ -1145,7 +1146,7 @@ cdef class FloatScalar(ScalarValue):
11451146
Return this value as a Python float.
11461147
"""
11471148
cdef CFloatScalar* sp = <CFloatScalar*> self.sp_scalar.get()
1148-
return sp.value
1149+
return sp.value if sp.is_valid else None
11491150

11501151

11511152
cdef class DoubleScalar(ScalarValue):
@@ -1158,7 +1159,7 @@ cdef class DoubleScalar(ScalarValue):
11581159
Return this value as a Python float.
11591160
"""
11601161
cdef CDoubleScalar* sp = <CDoubleScalar*> self.sp_scalar.get()
1161-
return sp.value
1162+
return sp.value if sp.is_valid else None
11621163

11631164

11641165
cdef dict _scalar_classes = {

python/pyarrow/tests/test_compute.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import pytest
2121

2222
import pyarrow as pa
23+
import pyarrow.compute
2324

2425

2526
all_array_types = [
@@ -44,7 +45,7 @@
4445
]
4546

4647

47-
@pytest.mark.parametrize('arrow_type', [
48+
numerical_arrow_types = [
4849
pa.int8(),
4950
pa.int16(),
5051
pa.int64(),
@@ -53,10 +54,41 @@
5354
pa.uint64(),
5455
pa.float32(),
5556
pa.float64()
56-
])
57-
def test_sum(arrow_type):
57+
]
58+
59+
60+
@pytest.mark.parametrize('arrow_type', numerical_arrow_types)
61+
def test_sum_array(arrow_type):
5862
arr = pa.array([1, 2, 3, 4], type=arrow_type)
5963
assert arr.sum() == 10
64+
assert pa.compute.sum(arr) == 10
65+
66+
arr = pa.array([], type=arrow_type)
67+
assert arr.sum() == None # noqa: E711
68+
assert pa.compute.sum(arr) == None # noqa: E711
69+
70+
71+
@pytest.mark.parametrize('arrow_type', numerical_arrow_types)
72+
def test_sum_chunked_array(arrow_type):
73+
arr = pa.chunked_array([pa.array([1, 2, 3, 4], type=arrow_type)])
74+
assert pa.compute.sum(arr) == 10
75+
76+
arr = pa.chunked_array([
77+
pa.array([1, 2], type=arrow_type), pa.array([3, 4], type=arrow_type)
78+
])
79+
assert pa.compute.sum(arr) == 10
80+
81+
arr = pa.chunked_array([
82+
pa.array([1, 2], type=arrow_type),
83+
pa.array([], type=arrow_type),
84+
pa.array([3, 4], type=arrow_type)
85+
])
86+
assert pa.compute.sum(arr) == 10
87+
88+
arr = pa.chunked_array((), type=arrow_type)
89+
print(arr, type(arr))
90+
assert arr.num_chunks == 0
91+
assert pa.compute.sum(arr) == None # noqa: E711
6092

6193

6294
@pytest.mark.parametrize(('ty', 'values'), all_array_types)

0 commit comments

Comments
 (0)