|
| 1 | +// Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +// or more contributor license agreements. See the NOTICE file |
| 3 | +// distributed with this work for additional information |
| 4 | +// regarding copyright ownership. The ASF licenses this file |
| 5 | +// to you under the Apache License, Version 2.0 (the |
| 6 | +// "License"); you may not use this file except in compliance |
| 7 | +// with the License. You may obtain a copy of the License at |
| 8 | +// |
| 9 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +// |
| 11 | +// Unless required by applicable law or agreed to in writing, |
| 12 | +// software distributed under the License is distributed on an |
| 13 | +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +// KIND, either express or implied. See the License for the |
| 15 | +// specific language governing permissions and limitations |
| 16 | +// under the License. |
| 17 | + |
#include <limits>
#include <numeric>
#include <sstream>
| 20 | + |
| 21 | +#include "arrow/extension/fixed_shape_tensor.h" |
| 22 | + |
| 23 | +#include "arrow/array/array_nested.h" |
| 24 | +#include "arrow/array/array_primitive.h" |
| 25 | +#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep |
| 26 | +#include "arrow/util/int_util_overflow.h" |
| 27 | +#include "arrow/util/logging.h" |
| 28 | +#include "arrow/util/sort.h" |
| 29 | + |
| 30 | +#include <rapidjson/document.h> |
| 31 | +#include <rapidjson/writer.h> |
| 32 | + |
| 33 | +namespace rj = arrow::rapidjson; |
| 34 | + |
| 35 | +namespace arrow { |
| 36 | +namespace extension { |
| 37 | + |
| 38 | +bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const { |
| 39 | + if (extension_name() != other.extension_name()) { |
| 40 | + return false; |
| 41 | + } |
| 42 | + const auto& other_ext = static_cast<const FixedShapeTensorType&>(other); |
| 43 | + |
| 44 | + auto is_permutation_trivial = [](const std::vector<int64_t>& permutation) { |
| 45 | + for (size_t i = 1; i < permutation.size(); ++i) { |
| 46 | + if (permutation[i - 1] + 1 != permutation[i]) { |
| 47 | + return false; |
| 48 | + } |
| 49 | + } |
| 50 | + return true; |
| 51 | + }; |
| 52 | + const bool permutation_equivalent = |
| 53 | + ((permutation_ == other_ext.permutation()) || |
| 54 | + (permutation_.empty() && is_permutation_trivial(other_ext.permutation())) || |
| 55 | + (is_permutation_trivial(permutation_) && other_ext.permutation().empty())); |
| 56 | + |
| 57 | + return (storage_type()->Equals(other_ext.storage_type())) && |
| 58 | + (this->shape() == other_ext.shape()) && (dim_names_ == other_ext.dim_names()) && |
| 59 | + permutation_equivalent; |
| 60 | +} |
| 61 | + |
| 62 | +std::string FixedShapeTensorType::Serialize() const { |
| 63 | + rj::Document document; |
| 64 | + document.SetObject(); |
| 65 | + rj::Document::AllocatorType& allocator = document.GetAllocator(); |
| 66 | + |
| 67 | + rj::Value shape(rj::kArrayType); |
| 68 | + for (auto v : shape_) { |
| 69 | + shape.PushBack(v, allocator); |
| 70 | + } |
| 71 | + document.AddMember(rj::Value("shape", allocator), shape, allocator); |
| 72 | + |
| 73 | + if (!permutation_.empty()) { |
| 74 | + rj::Value permutation(rj::kArrayType); |
| 75 | + for (auto v : permutation_) { |
| 76 | + permutation.PushBack(v, allocator); |
| 77 | + } |
| 78 | + document.AddMember(rj::Value("permutation", allocator), permutation, allocator); |
| 79 | + } |
| 80 | + |
| 81 | + if (!dim_names_.empty()) { |
| 82 | + rj::Value dim_names(rj::kArrayType); |
| 83 | + for (std::string v : dim_names_) { |
| 84 | + dim_names.PushBack(rj::Value{}.SetString(v.c_str(), allocator), allocator); |
| 85 | + } |
| 86 | + document.AddMember(rj::Value("dim_names", allocator), dim_names, allocator); |
| 87 | + } |
| 88 | + |
| 89 | + rj::StringBuffer buffer; |
| 90 | + rj::Writer<rj::StringBuffer> writer(buffer); |
| 91 | + document.Accept(writer); |
| 92 | + return buffer.GetString(); |
| 93 | +} |
| 94 | + |
| 95 | +Result<std::shared_ptr<DataType>> FixedShapeTensorType::Deserialize( |
| 96 | + std::shared_ptr<DataType> storage_type, const std::string& serialized_data) const { |
| 97 | + if (storage_type->id() != Type::FIXED_SIZE_LIST) { |
| 98 | + return Status::Invalid("Expected FixedSizeList storage type, got ", |
| 99 | + storage_type->ToString()); |
| 100 | + } |
| 101 | + auto value_type = |
| 102 | + internal::checked_pointer_cast<FixedSizeListType>(storage_type)->value_type(); |
| 103 | + rj::Document document; |
| 104 | + if (document.Parse(serialized_data.data(), serialized_data.length()).HasParseError() || |
| 105 | + !document.HasMember("shape") || !document["shape"].IsArray()) { |
| 106 | + return Status::Invalid("Invalid serialized JSON data: ", serialized_data); |
| 107 | + } |
| 108 | + |
| 109 | + std::vector<int64_t> shape; |
| 110 | + for (auto& x : document["shape"].GetArray()) { |
| 111 | + shape.emplace_back(x.GetInt64()); |
| 112 | + } |
| 113 | + std::vector<int64_t> permutation; |
| 114 | + if (document.HasMember("permutation")) { |
| 115 | + for (auto& x : document["permutation"].GetArray()) { |
| 116 | + permutation.emplace_back(x.GetInt64()); |
| 117 | + } |
| 118 | + if (shape.size() != permutation.size()) { |
| 119 | + return Status::Invalid("Invalid permutation"); |
| 120 | + } |
| 121 | + } |
| 122 | + std::vector<std::string> dim_names; |
| 123 | + if (document.HasMember("dim_names")) { |
| 124 | + for (auto& x : document["dim_names"].GetArray()) { |
| 125 | + dim_names.emplace_back(x.GetString()); |
| 126 | + } |
| 127 | + if (shape.size() != dim_names.size()) { |
| 128 | + return Status::Invalid("Invalid dim_names"); |
| 129 | + } |
| 130 | + } |
| 131 | + |
| 132 | + return fixed_shape_tensor(value_type, shape, permutation, dim_names); |
| 133 | +} |
| 134 | + |
| 135 | +std::shared_ptr<Array> FixedShapeTensorType::MakeArray( |
| 136 | + std::shared_ptr<ArrayData> data) const { |
| 137 | + DCHECK_EQ(data->type->id(), Type::EXTENSION); |
| 138 | + DCHECK_EQ("arrow.fixed_shape_tensor", |
| 139 | + static_cast<const ExtensionType&>(*data->type).extension_name()); |
| 140 | + return std::make_shared<ExtensionArray>(data); |
| 141 | +} |
| 142 | + |
| 143 | +Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make( |
| 144 | + const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape, |
| 145 | + const std::vector<int64_t>& permutation, const std::vector<std::string>& dim_names) { |
| 146 | + if (!permutation.empty() && shape.size() != permutation.size()) { |
| 147 | + return Status::Invalid("permutation size must match shape size. Expected: ", |
| 148 | + shape.size(), " Got: ", permutation.size()); |
| 149 | + } |
| 150 | + if (!dim_names.empty() && shape.size() != dim_names.size()) { |
| 151 | + return Status::Invalid("dim_names size must match shape size. Expected: ", |
| 152 | + shape.size(), " Got: ", dim_names.size()); |
| 153 | + } |
| 154 | + const auto size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), |
| 155 | + std::multiplies<>()); |
| 156 | + return std::make_shared<FixedShapeTensorType>(value_type, static_cast<int32_t>(size), |
| 157 | + shape, permutation, dim_names); |
| 158 | +} |
| 159 | + |
| 160 | +std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>& value_type, |
| 161 | + const std::vector<int64_t>& shape, |
| 162 | + const std::vector<int64_t>& permutation, |
| 163 | + const std::vector<std::string>& dim_names) { |
| 164 | + auto maybe_type = FixedShapeTensorType::Make(value_type, shape, permutation, dim_names); |
| 165 | + ARROW_DCHECK_OK(maybe_type.status()); |
| 166 | + return maybe_type.MoveValueUnsafe(); |
| 167 | +} |
| 168 | + |
| 169 | +} // namespace extension |
| 170 | +} // namespace arrow |
0 commit comments