Skip to content

Commit 1575a74

Browse files
xhochypitrou
authored andcommitted
Review
1 parent d03af5f commit 1575a74

File tree

3 files changed

+30
-27
lines changed

3 files changed

+30
-27
lines changed

cpp/src/arrow/compute/kernels/aggregate.cc

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -50,31 +50,29 @@ class ManagedAggregateState {
5050
};
5151

5252
Status AggregateUnaryKernel::Call(FunctionContext* ctx, const Datum& input, Datum* out) {
53-
if (input.is_arraylike()) {
54-
auto state = ManagedAggregateState::Make(aggregate_function_, ctx->memory_pool());
55-
if (!state) return Status::OutOfMemory("AggregateState allocation failed");
53+
if (!input.is_arraylike()) {
54+
return Status::Invalid("AggregateKernel expects Array or ChunkedArray datum");
55+
}
56+
auto state = ManagedAggregateState::Make(aggregate_function_, ctx->memory_pool());
57+
if (!state) return Status::OutOfMemory("AggregateState allocation failed");
5658

57-
if (input.is_array()) {
58-
auto array = input.make_array();
59-
RETURN_NOT_OK(aggregate_function_->Consume(*array, state->mutable_data()));
60-
} else {
61-
auto chunked_array = input.chunked_array();
62-
for (int i = 0; i < chunked_array->num_chunks(); i++) {
63-
auto tmp_state =
64-
ManagedAggregateState::Make(aggregate_function_, ctx->memory_pool());
65-
if (!tmp_state) return Status::OutOfMemory("AggregateState allocation failed");
66-
RETURN_NOT_OK(aggregate_function_->Consume(*chunked_array->chunk(i),
67-
tmp_state->mutable_data()));
68-
RETURN_NOT_OK(
69-
aggregate_function_->Merge(tmp_state->mutable_data(), state->mutable_data()));
70-
}
71-
}
72-
RETURN_NOT_OK(aggregate_function_->Finalize(state->mutable_data(), out));
59+
if (input.is_array()) {
60+
auto array = input.make_array();
61+
RETURN_NOT_OK(aggregate_function_->Consume(*array, state->mutable_data()));
7362
} else {
74-
return Status::Invalid("AggregateKernel expects Array or ChunkedArray datum");
63+
auto chunked_array = input.chunked_array();
64+
for (int i = 0; i < chunked_array->num_chunks(); i++) {
65+
auto tmp_state =
66+
ManagedAggregateState::Make(aggregate_function_, ctx->memory_pool());
67+
if (!tmp_state) return Status::OutOfMemory("AggregateState allocation failed");
68+
RETURN_NOT_OK(aggregate_function_->Consume(*chunked_array->chunk(i),
69+
tmp_state->mutable_data()));
70+
RETURN_NOT_OK(
71+
aggregate_function_->Merge(tmp_state->mutable_data(), state->mutable_data()));
72+
}
7573
}
7674

77-
return Status::OK();
75+
return aggregate_function_->Finalize(state->mutable_data(), out);
7876
}
7977

8078
std::shared_ptr<DataType> AggregateUnaryKernel::out_type() const {

cpp/src/arrow/compute/kernels/aggregate_test.cc

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,16 +140,13 @@ TYPED_TEST(TestNumericSumKernel, SimpleSum) {
140140
ValidateSum<TypeParam>(&this->ctx_, "[0, 1, 2, 3, 4, 5]",
141141
Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
142142

143-
std::vector<std::string> json{"[0, 1, 2, 3, 4, 5]"};
144-
ValidateSum<TypeParam>(&this->ctx_, json,
143+
ValidateSum<TypeParam>(&this->ctx_, {"[0, 1, 2, 3, 4, 5]"},
145144
Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
146145

147-
json = {"[0, 1, 2]", "[3, 4, 5]"};
148-
ValidateSum<TypeParam>(&this->ctx_, json,
146+
ValidateSum<TypeParam>(&this->ctx_, {"[0, 1, 2]", "[3, 4, 5]"},
149147
Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
150148

151-
json = {"[0, 1, 2]", "[]", "[3, 4, 5]"};
152-
ValidateSum<TypeParam>(&this->ctx_, json,
149+
ValidateSum<TypeParam>(&this->ctx_, {"[0, 1, 2]", "[]", "[3, 4, 5]"},
153150
Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
154151

155152
const T expected_result = static_cast<T>(14);

python/pyarrow/_compute.pyx

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,14 @@ cdef _sum_chunked_array(array: ChunkedArray):
4848
def sum(array):
4949
"""
5050
Sum the values in a numerical (chunked) array.
51+
52+
Parameters
53+
----------
54+
array : pyarrow.Array or pyarrow.ChunkedArray
55+
56+
Returns
57+
-------
58+
sum : pyarrow.Scalar
5159
"""
5260
if isinstance(array, Array):
5361
return _sum_array(array)

0 commit comments

Comments
 (0)