Skip to content

Commit bfb2a05

Browse files
rokjorisvandenbosschebenibus
authored andcommitted
apacheGH-15483: [C++] Add a Fixed Shape Tensor canonical ExtensionType (apache#8510)
> [ARROW-1614](https://issues.apache.org/jira/browse/ARROW-1614): In an Arrow table, we would like to add support for a column that has values cells each containing a tensor value, with all tensors having the same dimensions. These would be stored as a binary value, plus some metadata to store type and shape/strides. * Closes: apache#15483 Lead-authored-by: Rok Mihevc <[email protected]> Co-authored-by: Rok <[email protected]> Co-authored-by: Joris Van den Bossche <[email protected]> Co-authored-by: Ben Harkins <[email protected]> Signed-off-by: Joris Van den Bossche <[email protected]>
1 parent f3d84f5 commit bfb2a05

File tree

6 files changed

+515
-0
lines changed

6 files changed

+515
-0
lines changed

cpp/src/arrow/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,7 @@ endif()
520520
if(ARROW_JSON)
521521
list(APPEND
522522
ARROW_SRCS
523+
extension/fixed_shape_tensor.cc
523524
json/options.cc
524525
json/chunked_builder.cc
525526
json/chunker.cc
@@ -856,6 +857,7 @@ endif()
856857

857858
if(ARROW_JSON)
858859
add_subdirectory(json)
860+
add_subdirectory(extension)
859861
endif()
860862

861863
if(ARROW_ORC)
+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
add_arrow_test(test
19+
SOURCES
20+
fixed_shape_tensor_test.cc
21+
PREFIX
22+
"arrow-fixed-shape-tensor")
23+
24+
arrow_install_all_headers("arrow/extension")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include <numeric>
19+
#include <sstream>
20+
21+
#include "arrow/extension/fixed_shape_tensor.h"
22+
23+
#include "arrow/array/array_nested.h"
24+
#include "arrow/array/array_primitive.h"
25+
#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep
26+
#include "arrow/util/int_util_overflow.h"
27+
#include "arrow/util/logging.h"
28+
#include "arrow/util/sort.h"
29+
30+
#include <rapidjson/document.h>
31+
#include <rapidjson/writer.h>
32+
33+
namespace rj = arrow::rapidjson;
34+
35+
namespace arrow {
36+
namespace extension {
37+
38+
bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
39+
if (extension_name() != other.extension_name()) {
40+
return false;
41+
}
42+
const auto& other_ext = static_cast<const FixedShapeTensorType&>(other);
43+
44+
auto is_permutation_trivial = [](const std::vector<int64_t>& permutation) {
45+
for (size_t i = 1; i < permutation.size(); ++i) {
46+
if (permutation[i - 1] + 1 != permutation[i]) {
47+
return false;
48+
}
49+
}
50+
return true;
51+
};
52+
const bool permutation_equivalent =
53+
((permutation_ == other_ext.permutation()) ||
54+
(permutation_.empty() && is_permutation_trivial(other_ext.permutation())) ||
55+
(is_permutation_trivial(permutation_) && other_ext.permutation().empty()));
56+
57+
return (storage_type()->Equals(other_ext.storage_type())) &&
58+
(this->shape() == other_ext.shape()) && (dim_names_ == other_ext.dim_names()) &&
59+
permutation_equivalent;
60+
}
61+
62+
std::string FixedShapeTensorType::Serialize() const {
63+
rj::Document document;
64+
document.SetObject();
65+
rj::Document::AllocatorType& allocator = document.GetAllocator();
66+
67+
rj::Value shape(rj::kArrayType);
68+
for (auto v : shape_) {
69+
shape.PushBack(v, allocator);
70+
}
71+
document.AddMember(rj::Value("shape", allocator), shape, allocator);
72+
73+
if (!permutation_.empty()) {
74+
rj::Value permutation(rj::kArrayType);
75+
for (auto v : permutation_) {
76+
permutation.PushBack(v, allocator);
77+
}
78+
document.AddMember(rj::Value("permutation", allocator), permutation, allocator);
79+
}
80+
81+
if (!dim_names_.empty()) {
82+
rj::Value dim_names(rj::kArrayType);
83+
for (std::string v : dim_names_) {
84+
dim_names.PushBack(rj::Value{}.SetString(v.c_str(), allocator), allocator);
85+
}
86+
document.AddMember(rj::Value("dim_names", allocator), dim_names, allocator);
87+
}
88+
89+
rj::StringBuffer buffer;
90+
rj::Writer<rj::StringBuffer> writer(buffer);
91+
document.Accept(writer);
92+
return buffer.GetString();
93+
}
94+
95+
Result<std::shared_ptr<DataType>> FixedShapeTensorType::Deserialize(
96+
std::shared_ptr<DataType> storage_type, const std::string& serialized_data) const {
97+
if (storage_type->id() != Type::FIXED_SIZE_LIST) {
98+
return Status::Invalid("Expected FixedSizeList storage type, got ",
99+
storage_type->ToString());
100+
}
101+
auto value_type =
102+
internal::checked_pointer_cast<FixedSizeListType>(storage_type)->value_type();
103+
rj::Document document;
104+
if (document.Parse(serialized_data.data(), serialized_data.length()).HasParseError() ||
105+
!document.HasMember("shape") || !document["shape"].IsArray()) {
106+
return Status::Invalid("Invalid serialized JSON data: ", serialized_data);
107+
}
108+
109+
std::vector<int64_t> shape;
110+
for (auto& x : document["shape"].GetArray()) {
111+
shape.emplace_back(x.GetInt64());
112+
}
113+
std::vector<int64_t> permutation;
114+
if (document.HasMember("permutation")) {
115+
for (auto& x : document["permutation"].GetArray()) {
116+
permutation.emplace_back(x.GetInt64());
117+
}
118+
if (shape.size() != permutation.size()) {
119+
return Status::Invalid("Invalid permutation");
120+
}
121+
}
122+
std::vector<std::string> dim_names;
123+
if (document.HasMember("dim_names")) {
124+
for (auto& x : document["dim_names"].GetArray()) {
125+
dim_names.emplace_back(x.GetString());
126+
}
127+
if (shape.size() != dim_names.size()) {
128+
return Status::Invalid("Invalid dim_names");
129+
}
130+
}
131+
132+
return fixed_shape_tensor(value_type, shape, permutation, dim_names);
133+
}
134+
135+
std::shared_ptr<Array> FixedShapeTensorType::MakeArray(
136+
std::shared_ptr<ArrayData> data) const {
137+
DCHECK_EQ(data->type->id(), Type::EXTENSION);
138+
DCHECK_EQ("arrow.fixed_shape_tensor",
139+
static_cast<const ExtensionType&>(*data->type).extension_name());
140+
return std::make_shared<ExtensionArray>(data);
141+
}
142+
143+
Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
144+
const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
145+
const std::vector<int64_t>& permutation, const std::vector<std::string>& dim_names) {
146+
if (!permutation.empty() && shape.size() != permutation.size()) {
147+
return Status::Invalid("permutation size must match shape size. Expected: ",
148+
shape.size(), " Got: ", permutation.size());
149+
}
150+
if (!dim_names.empty() && shape.size() != dim_names.size()) {
151+
return Status::Invalid("dim_names size must match shape size. Expected: ",
152+
shape.size(), " Got: ", dim_names.size());
153+
}
154+
const auto size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
155+
std::multiplies<>());
156+
return std::make_shared<FixedShapeTensorType>(value_type, static_cast<int32_t>(size),
157+
shape, permutation, dim_names);
158+
}
159+
160+
std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>& value_type,
161+
const std::vector<int64_t>& shape,
162+
const std::vector<int64_t>& permutation,
163+
const std::vector<std::string>& dim_names) {
164+
auto maybe_type = FixedShapeTensorType::Make(value_type, shape, permutation, dim_names);
165+
ARROW_DCHECK_OK(maybe_type.status());
166+
return maybe_type.MoveValueUnsafe();
167+
}
168+
169+
} // namespace extension
170+
} // namespace arrow
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include "arrow/extension_type.h"
19+
20+
namespace arrow {
21+
namespace extension {
22+
23+
class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
24+
public:
25+
using ExtensionArray::ExtensionArray;
26+
};
27+
28+
/// \brief Concrete type class for constant-size Tensor data.
29+
/// This is a canonical arrow extension type.
30+
/// See: https://arrow.apache.org/docs/format/CanonicalExtensions.html
31+
class ARROW_EXPORT FixedShapeTensorType : public ExtensionType {
32+
public:
33+
FixedShapeTensorType(const std::shared_ptr<DataType>& value_type, const int32_t& size,
34+
const std::vector<int64_t>& shape,
35+
const std::vector<int64_t>& permutation = {},
36+
const std::vector<std::string>& dim_names = {})
37+
: ExtensionType(fixed_size_list(value_type, size)),
38+
value_type_(value_type),
39+
shape_(shape),
40+
permutation_(permutation),
41+
dim_names_(dim_names) {}
42+
43+
std::string extension_name() const override { return "arrow.fixed_shape_tensor"; }
44+
45+
/// Number of dimensions of tensor elements
46+
size_t ndim() { return shape_.size(); }
47+
48+
/// Shape of tensor elements
49+
const std::vector<int64_t> shape() const { return shape_; }
50+
51+
/// Value type of tensor elements
52+
const std::shared_ptr<DataType> value_type() const { return value_type_; }
53+
54+
/// Permutation mapping from logical to physical memory layout of tensor elements
55+
const std::vector<int64_t>& permutation() const { return permutation_; }
56+
57+
/// Dimension names of tensor elements. Dimensions are ordered physically.
58+
const std::vector<std::string>& dim_names() const { return dim_names_; }
59+
60+
bool ExtensionEquals(const ExtensionType& other) const override;
61+
62+
std::string Serialize() const override;
63+
64+
Result<std::shared_ptr<DataType>> Deserialize(
65+
std::shared_ptr<DataType> storage_type,
66+
const std::string& serialized_data) const override;
67+
68+
/// Create a FixedShapeTensorArray from ArrayData
69+
std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;
70+
71+
/// \brief Create a FixedShapeTensorType instance
72+
static Result<std::shared_ptr<DataType>> Make(
73+
const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
74+
const std::vector<int64_t>& permutation = {},
75+
const std::vector<std::string>& dim_names = {});
76+
77+
private:
78+
std::shared_ptr<DataType> storage_type_;
79+
std::shared_ptr<DataType> value_type_;
80+
std::vector<int64_t> shape_;
81+
std::vector<int64_t> permutation_;
82+
std::vector<std::string> dim_names_;
83+
};
84+
85+
/// \brief Return a FixedShapeTensorType instance.
86+
ARROW_EXPORT std::shared_ptr<DataType> fixed_shape_tensor(
87+
const std::shared_ptr<DataType>& storage_type, const std::vector<int64_t>& shape,
88+
const std::vector<int64_t>& permutation = {},
89+
const std::vector<std::string>& dim_names = {});
90+
91+
} // namespace extension
92+
} // namespace arrow

0 commit comments

Comments
 (0)