Skip to content

Commit f65ba06

Browse files
committed
Test summing a ChunkedArray with 0 chunks
1 parent 7c8d526 commit f65ba06

File tree

5 files changed

+26
-11
lines changed

5 files changed

+26
-11
lines changed

cpp/src/arrow/compute/kernels/aggregate_test.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,10 @@ TYPED_TEST(TestNumericSumKernel, SimpleSum) {
152152
ValidateSum<TypeParam>(&this->ctx_, chunks,
153153
Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
154154

155+
chunks = {};
156+
ValidateSum<TypeParam>(&this->ctx_, chunks,
157+
Datum(std::make_shared<ScalarType>())); // null
158+
155159
const T expected_result = static_cast<T>(14);
156160
ValidateSum<TypeParam>(&this->ctx_, "[1, null, 3, null, 3, null, 7]",
157161
Datum(std::make_shared<ScalarType>(expected_result)));

cpp/src/arrow/testing/gtest_util.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ std::shared_ptr<ChunkedArray> ChunkedArrayFromJSON(const std::shared_ptr<DataTyp
231231
for (const std::string& chunk_json : json) {
232232
out_chunks.push_back(ArrayFromJSON(type, chunk_json));
233233
}
234-
return std::make_shared<ChunkedArray>(std::move(out_chunks));
234+
return std::make_shared<ChunkedArray>(std::move(out_chunks), type);
235235
}
236236

237237
std::shared_ptr<RecordBatch> RecordBatchFromJSON(const std::shared_ptr<Schema>& schema,

python/pyarrow/includes/libarrow.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
762762

763763
cdef cppclass CScalar" arrow::Scalar":
764764
shared_ptr[CDataType] type
765+
c_bool is_valid
765766

766767
cdef cppclass CInt8Scalar" arrow::Int8Scalar"(CScalar):
767768
int8_t value

python/pyarrow/scalar.pxi

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,7 @@ cdef class ScalarValue(Scalar):
10311031
def __hash__(self):
10321032
return hash(self.as_py())
10331033

1034+
10341035
cdef class UInt8Scalar(ScalarValue):
10351036
"""
10361037
Concrete class for uint8 scalars.
@@ -1041,7 +1042,7 @@ cdef class UInt8Scalar(ScalarValue):
10411042
Return this value as a Python int.
10421043
"""
10431044
cdef CUInt8Scalar* sp = <CUInt8Scalar*> self.sp_scalar.get()
1044-
return sp.value
1045+
return sp.value if sp.is_valid else None
10451046

10461047

10471048
cdef class Int8Scalar(ScalarValue):
@@ -1054,7 +1055,7 @@ cdef class Int8Scalar(ScalarValue):
10541055
Return this value as a Python int.
10551056
"""
10561057
cdef CInt8Scalar* sp = <CInt8Scalar*> self.sp_scalar.get()
1057-
return sp.value
1058+
return sp.value if sp.is_valid else None
10581059

10591060

10601061
cdef class UInt16Scalar(ScalarValue):
@@ -1067,7 +1068,7 @@ cdef class UInt16Scalar(ScalarValue):
10671068
Return this value as a Python int.
10681069
"""
10691070
cdef CUInt16Scalar* sp = <CUInt16Scalar*> self.sp_scalar.get()
1070-
return sp.value
1071+
return sp.value if sp.is_valid else None
10711072

10721073

10731074
cdef class Int16Scalar(ScalarValue):
@@ -1080,7 +1081,7 @@ cdef class Int16Scalar(ScalarValue):
10801081
Return this value as a Python int.
10811082
"""
10821083
cdef CInt16Scalar* sp = <CInt16Scalar*> self.sp_scalar.get()
1083-
return sp.value
1084+
return sp.value if sp.is_valid else None
10841085

10851086

10861087
cdef class UInt32Scalar(ScalarValue):
@@ -1093,7 +1094,7 @@ cdef class UInt32Scalar(ScalarValue):
10931094
Return this value as a Python int.
10941095
"""
10951096
cdef CUInt32Scalar* sp = <CUInt32Scalar*> self.sp_scalar.get()
1096-
return sp.value
1097+
return sp.value if sp.is_valid else None
10971098

10981099

10991100
cdef class Int32Scalar(ScalarValue):
@@ -1106,7 +1107,7 @@ cdef class Int32Scalar(ScalarValue):
11061107
Return this value as a Python int.
11071108
"""
11081109
cdef CInt32Scalar* sp = <CInt32Scalar*> self.sp_scalar.get()
1109-
return sp.value
1110+
return sp.value if sp.is_valid else None
11101111

11111112

11121113
cdef class UInt64Scalar(ScalarValue):
@@ -1119,7 +1120,7 @@ cdef class UInt64Scalar(ScalarValue):
11191120
Return this value as a Python int.
11201121
"""
11211122
cdef CUInt64Scalar* sp = <CUInt64Scalar*> self.sp_scalar.get()
1122-
return sp.value
1123+
return sp.value if sp.is_valid else None
11231124

11241125

11251126
cdef class Int64Scalar(ScalarValue):
@@ -1132,7 +1133,7 @@ cdef class Int64Scalar(ScalarValue):
11321133
Return this value as a Python int.
11331134
"""
11341135
cdef CInt64Scalar* sp = <CInt64Scalar*> self.sp_scalar.get()
1135-
return sp.value
1136+
return sp.value if sp.is_valid else None
11361137

11371138

11381139
cdef class FloatScalar(ScalarValue):
@@ -1145,7 +1146,7 @@ cdef class FloatScalar(ScalarValue):
11451146
Return this value as a Python float.
11461147
"""
11471148
cdef CFloatScalar* sp = <CFloatScalar*> self.sp_scalar.get()
1148-
return sp.value
1149+
return sp.value if sp.is_valid else None
11491150

11501151

11511152
cdef class DoubleScalar(ScalarValue):
@@ -1158,7 +1159,7 @@ cdef class DoubleScalar(ScalarValue):
11581159
Return this value as a Python float.
11591160
"""
11601161
cdef CDoubleScalar* sp = <CDoubleScalar*> self.sp_scalar.get()
1161-
return sp.value
1162+
return sp.value if sp.is_valid else None
11621163

11631164

11641165
cdef dict _scalar_classes = {

python/pyarrow/tests/test_compute.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ def test_sum_array(arrow_type):
6363
assert arr.sum() == 10
6464
assert pa.compute.sum(arr) == 10
6565

66+
arr = pa.array([], type=arrow_type)
67+
assert arr.sum() == None # noqa: E711
68+
assert pa.compute.sum(arr) == None # noqa: E711
69+
6670

6771
@pytest.mark.parametrize('arrow_type', numerical_arrow_types)
6872
def test_sum_chunked_array(arrow_type):
@@ -81,6 +85,11 @@ def test_sum_chunked_array(arrow_type):
8185
])
8286
assert pa.compute.sum(arr) == 10
8387

88+
arr = pa.chunked_array((), type=arrow_type)
89+
print(arr, type(arr))
90+
assert arr.num_chunks == 0
91+
assert pa.compute.sum(arr) == None # noqa: E711
92+
8493

8594
@pytest.mark.parametrize(('ty', 'values'), all_array_types)
8695
def test_take(ty, values):

0 commit comments

Comments
 (0)