|
| 1 | +#include <functional> |
| 2 | +#include <optional> |
| 3 | +#include <sstream> |
| 4 | +#include <string_view> |
| 5 | +#include <tuple> |
| 6 | + |
| 7 | +#include <nanoarrow/nanoarrow.hpp> |
| 8 | +#include <nanobind/nanobind.h> |
| 9 | +#include <nanobind/stl/pair.h> |
| 10 | +#include <nanobind/stl/string.h> |
| 11 | + |
| 12 | +using namespace nanoarrow::literals; |
| 13 | +namespace nb = nanobind; |
| 14 | + |
| 15 | +static auto ReleaseArrowArray(void *ptr) noexcept -> void { |
| 16 | + auto array = static_cast<struct ArrowArray *>(ptr); |
| 17 | + if (array->release != nullptr) { |
| 18 | + ArrowArrayRelease(array); |
| 19 | + } |
| 20 | + |
| 21 | + delete array; |
| 22 | +} |
| 23 | + |
| 24 | +static auto ReleaseArrowSchema(void *ptr) noexcept -> void { |
| 25 | + auto schema = static_cast<struct ArrowSchema *>(ptr); |
| 26 | + if (schema->release != nullptr) { |
| 27 | + ArrowSchemaRelease(schema); |
| 28 | + } |
| 29 | + |
| 30 | + delete schema; |
| 31 | +} |
| 32 | + |
| 33 | +static auto CumSum(const struct ArrowArrayView *array_view, |
| 34 | + struct ArrowArray *out, bool skipna) { |
| 35 | + bool seen_na = false; |
| 36 | + std::stringstream ss{}; |
| 37 | + |
| 38 | + for (int64_t i = 0; i < array_view->length; i++) { |
| 39 | + const bool isna = ArrowArrayViewIsNull(array_view, i); |
| 40 | + if (!skipna && (seen_na || isna)) { |
| 41 | + seen_na = true; |
| 42 | + ArrowArrayAppendNull(out, 1); |
| 43 | + } else { |
| 44 | + if (!isna) { |
| 45 | + const auto std_sv = ArrowArrayViewGetStringUnsafe(array_view, i); |
| 46 | + ss << std::string_view{std_sv.data, |
| 47 | + static_cast<size_t>(std_sv.size_bytes)}; |
| 48 | + } |
| 49 | + const auto str = ss.str(); |
| 50 | + const ArrowStringView asv{str.c_str(), static_cast<int64_t>(str.size())}; |
| 51 | + NANOARROW_THROW_NOT_OK(ArrowArrayAppendString(out, asv)); |
| 52 | + } |
| 53 | + } |
| 54 | +} |
| 55 | + |
| 56 | +template <typename T> |
| 57 | +concept MinOrMaxOp = |
| 58 | + std::same_as<T, std::less<>> || std::same_as<T, std::greater<>>; |
| 59 | + |
| 60 | +template <auto Op> |
| 61 | + requires MinOrMaxOp<decltype(Op)> |
| 62 | +static auto CumMinOrMax(const struct ArrowArrayView *array_view, |
| 63 | + struct ArrowArray *out, bool skipna) { |
| 64 | + bool seen_na = false; |
| 65 | + std::optional<std::string> current_str{}; |
| 66 | + |
| 67 | + for (int64_t i = 0; i < array_view->length; i++) { |
| 68 | + const bool isna = ArrowArrayViewIsNull(array_view, i); |
| 69 | + if (!skipna && (seen_na || isna)) { |
| 70 | + seen_na = true; |
| 71 | + ArrowArrayAppendNull(out, 1); |
| 72 | + } else { |
| 73 | + if (!isna || current_str) { |
| 74 | + if (!isna) { |
| 75 | + const auto asv = ArrowArrayViewGetStringUnsafe(array_view, i); |
| 76 | + const nb::str pyval{asv.data, static_cast<size_t>(asv.size_bytes)}; |
| 77 | + |
| 78 | + if (current_str) { |
| 79 | + const nb::str pycurrent{current_str->data(), current_str->size()}; |
| 80 | + if (Op(pyval, pycurrent)) { |
| 81 | + current_str = |
| 82 | + std::string{asv.data, static_cast<size_t>(asv.size_bytes)}; |
| 83 | + } |
| 84 | + } else { |
| 85 | + current_str = |
| 86 | + std::string{asv.data, static_cast<size_t>(asv.size_bytes)}; |
| 87 | + } |
| 88 | + } |
| 89 | + |
| 90 | + struct ArrowStringView out_sv{ |
| 91 | + current_str->data(), static_cast<int64_t>(current_str->size())}; |
| 92 | + NANOARROW_THROW_NOT_OK(ArrowArrayAppendString(out, out_sv)); |
| 93 | + } else { |
| 94 | + ArrowArrayAppendEmpty(out, 1); |
| 95 | + } |
| 96 | + } |
| 97 | + } |
| 98 | +} |
| 99 | + |
| 100 | +class ArrowStringAccumulation { |
| 101 | +public: |
| 102 | + ArrowStringAccumulation(nb::object array_object, std::string accumulation, |
| 103 | + bool skipna) |
| 104 | + : skipna_(skipna) { |
| 105 | + if ((accumulation == "cumsum") || (accumulation == "cummin") || |
| 106 | + (accumulation == "cummax")) { |
| 107 | + accumulation_ = std::move(accumulation); |
| 108 | + } else { |
| 109 | + const auto error_message = |
| 110 | + std::string("Unsupported accumulation: ") + accumulation; |
| 111 | + throw nb::value_error(error_message.c_str()); |
| 112 | + } |
| 113 | + |
| 114 | + const auto obj = nb::getattr(array_object, "__arrow_c_stream__")(); |
| 115 | + const auto pycapsule_obj = nb::cast<nb::capsule>(obj); |
| 116 | + |
| 117 | + const auto stream = static_cast<struct ArrowArrayStream *>( |
| 118 | + PyCapsule_GetPointer(pycapsule_obj.ptr(), "arrow_array_stream")); |
| 119 | + if (stream == nullptr) { |
| 120 | + throw std::invalid_argument("Invalid Arrow Stream capsule provided!"); |
| 121 | + } |
| 122 | + |
| 123 | + if (stream->get_schema(stream, schema_.get()) != 0) { |
| 124 | + std::string error_msg{stream->get_last_error(stream)}; |
| 125 | + throw std::runtime_error("Could not read from arrow schema:" + error_msg); |
| 126 | + } |
| 127 | + struct ArrowSchemaView schema_view{}; |
| 128 | + NANOARROW_THROW_NOT_OK( |
| 129 | + ArrowSchemaViewInit(&schema_view, schema_.get(), nullptr)); |
| 130 | + |
| 131 | + switch (schema_view.type) { |
| 132 | + case NANOARROW_TYPE_STRING: |
| 133 | + case NANOARROW_TYPE_LARGE_STRING: |
| 134 | + case NANOARROW_TYPE_STRING_VIEW: |
| 135 | + break; |
| 136 | + default: |
| 137 | + const auto error_message = |
| 138 | + std::string("Expected a string-like array type, got: ") + |
| 139 | + ArrowTypeString(schema_view.type); |
| 140 | + throw std::invalid_argument(error_message); |
| 141 | + } |
| 142 | + |
| 143 | + ArrowArrayStreamMove(stream, stream_.get()); |
| 144 | + } |
| 145 | + |
| 146 | + std::pair<nb::capsule, nb::capsule> Accumulate(nb::object requested_schema) { |
| 147 | + struct ArrowSchemaView schema_view{}; |
| 148 | + NANOARROW_THROW_NOT_OK( |
| 149 | + ArrowSchemaViewInit(&schema_view, schema_.get(), nullptr)); |
| 150 | + auto uschema = nanoarrow::UniqueSchema{}; |
| 151 | + ArrowSchemaInit(uschema.get()); |
| 152 | + NANOARROW_THROW_NOT_OK(ArrowSchemaSetType(uschema.get(), schema_view.type)); |
| 153 | + |
| 154 | + // TODO: even though we are reading a stream we are returning an array |
| 155 | + // We should return a like sized stream of data in the future |
| 156 | + auto uarray_out = nanoarrow::UniqueArray{}; |
| 157 | + NANOARROW_THROW_NOT_OK( |
| 158 | + ArrowArrayInitFromSchema(uarray_out.get(), uschema.get(), nullptr)); |
| 159 | + |
| 160 | + NANOARROW_THROW_NOT_OK(ArrowArrayStartAppending(uarray_out.get())); |
| 161 | + |
| 162 | + nanoarrow::UniqueArray chunk{}; |
| 163 | + int errcode{}; |
| 164 | + |
| 165 | + while ((errcode = ArrowArrayStreamGetNext(stream_.get(), chunk.get(), |
| 166 | + nullptr) == 0) && |
| 167 | + chunk->release != nullptr) { |
| 168 | + struct ArrowArrayView array_view{}; |
| 169 | + NANOARROW_THROW_NOT_OK( |
| 170 | + ArrowArrayViewInitFromSchema(&array_view, schema_.get(), nullptr)); |
| 171 | + |
| 172 | + NANOARROW_THROW_NOT_OK( |
| 173 | + ArrowArrayViewSetArray(&array_view, chunk.get(), nullptr)); |
| 174 | + |
| 175 | + if (accumulation_ == "cumsum") { |
| 176 | + CumSum(&array_view, uarray_out.get(), skipna_); |
| 177 | + } else if (accumulation_ == "cummin") { |
| 178 | + CumMinOrMax<std::less{}>(&array_view, uarray_out.get(), skipna_); |
| 179 | + } else if (accumulation_ == "cummax") { |
| 180 | + CumMinOrMax<std::greater{}>(&array_view, uarray_out.get(), skipna_); |
| 181 | + } else { |
| 182 | + throw std::runtime_error("Unexpected branch"); |
| 183 | + } |
| 184 | + |
| 185 | + chunk.reset(); |
| 186 | + } |
| 187 | + |
| 188 | + NANOARROW_THROW_NOT_OK( |
| 189 | + ArrowArrayFinishBuildingDefault(uarray_out.get(), nullptr)); |
| 190 | + |
| 191 | + auto out_schema = new struct ArrowSchema; |
| 192 | + ArrowSchemaMove(uschema.get(), out_schema); |
| 193 | + nb::capsule schema_capsule{out_schema, "arrow_schema", &ReleaseArrowSchema}; |
| 194 | + |
| 195 | + auto out_array = new struct ArrowArray; |
| 196 | + ArrowArrayMove(uarray_out.get(), out_array); |
| 197 | + nb::capsule array_capsule{out_array, "arrow_array", &ReleaseArrowArray}; |
| 198 | + |
| 199 | + return std::pair<nb::capsule, nb::capsule>{schema_capsule, array_capsule}; |
| 200 | + } |
| 201 | + |
| 202 | +private: |
| 203 | + nanoarrow::UniqueArrayStream stream_; |
| 204 | + nanoarrow::UniqueSchema schema_; |
| 205 | + std::string accumulation_; |
| 206 | + bool skipna_; |
| 207 | +}; |
| 208 | + |
| 209 | +NB_MODULE(arrow_string_accumulations, m) { |
| 210 | + nb::class_<ArrowStringAccumulation>(m, "ArrowStringAccumulation") |
| 211 | + .def(nb::init<nb::object, std::string, bool>()) |
| 212 | + .def("__arrow_c_array__", &ArrowStringAccumulation::Accumulate, |
| 213 | + nb::arg("requested_schema") = nb::none()); |
| 214 | +} |
0 commit comments