|
14 | 14 | * limitations under the License.
|
15 | 15 | */
|
16 | 16 |
|
| 17 | +#include <arrow/api.h> |
17 | 18 | #include <cudf/column/column_factories.hpp>
|
18 | 19 | #include <cudf/concatenate.hpp>
|
19 | 20 | #include <cudf/filling.hpp>
|
| 21 | +#include <cudf/interop.hpp> |
20 | 22 | #include <cudf/hashing.hpp>
|
21 | 23 | #include <cudf/reshape.hpp>
|
22 | 24 | #include <cudf/utilities/bit.hpp>
|
| 25 | +#include <cudf/detail/interop.hpp> |
23 | 26 | #include <cudf/lists/detail/concatenate.hpp>
|
24 | 27 | #include <cudf/lists/lists_column_view.hpp>
|
25 | 28 | #include <cudf/scalar/scalar_factories.hpp>
|
@@ -50,6 +53,78 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_sequence(JNIEnv *env, j
|
50 | 53 | CATCH_STD(env, 0);
|
51 | 54 | }
|
52 | 55 |
|
| 56 | +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromArrow(JNIEnv *env, jclass, |
| 57 | + jint j_type, |
| 58 | + jlong j_col_length, |
| 59 | + jlong j_null_count, |
| 60 | + jobject j_data_obj, |
| 61 | + jobject j_validity_obj, |
| 62 | + jobject j_offsets_obj) { |
| 63 | + try { |
| 64 | + cudf::jni::auto_set_device(env); |
| 65 | + cudf::type_id n_type = static_cast<cudf::type_id>(j_type); |
| 66 | + // not all the buffers are used for all types |
| 67 | + void const *data_address = 0; |
| 68 | + int data_length = 0; |
| 69 | + if (j_data_obj != 0) { |
| 70 | + data_address = env->GetDirectBufferAddress(j_data_obj); |
| 71 | + data_length = env->GetDirectBufferCapacity(j_data_obj); |
| 72 | + } |
| 73 | + void const *validity_address = 0; |
| 74 | + int validity_length = 0; |
| 75 | + if (j_validity_obj != 0) { |
| 76 | + validity_address = env->GetDirectBufferAddress(j_validity_obj); |
| 77 | + validity_length = env->GetDirectBufferCapacity(j_validity_obj); |
| 78 | + } |
| 79 | + void const *offsets_address = 0; |
| 80 | + int offsets_length = 0; |
| 81 | + if (j_offsets_obj != 0) { |
| 82 | + offsets_address = env->GetDirectBufferAddress(j_offsets_obj); |
| 83 | + offsets_length = env->GetDirectBufferCapacity(j_offsets_obj); |
| 84 | + } |
| 85 | + auto data_buffer = arrow::Buffer::Wrap(static_cast<const char *>(data_address), static_cast<int>(data_length)); |
| 86 | + auto null_buffer = arrow::Buffer::Wrap(static_cast<const char *>(validity_address), static_cast<int>(validity_length)); |
| 87 | + auto offsets_buffer = arrow::Buffer::Wrap(static_cast<const char *>(offsets_address), static_cast<int>(offsets_length)); |
| 88 | + |
| 89 | + cudf::jni::native_jlongArray outcol_handles(env, 1); |
| 90 | + std::shared_ptr<arrow::Array> arrow_array; |
| 91 | + switch (n_type) { |
| 92 | + case cudf::type_id::DECIMAL32: |
| 93 | + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting DECIMAL32 yet", 0); |
| 94 | + break; |
| 95 | + case cudf::type_id::DECIMAL64: |
| 96 | + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting DECIMAL64 yet", 0); |
| 97 | + break; |
| 98 | + case cudf::type_id::STRUCT: |
| 99 | + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting STRUCT yet", 0); |
| 100 | + break; |
| 101 | + case cudf::type_id::LIST: |
| 102 | + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting LIST yet", 0); |
| 103 | + break; |
| 104 | + case cudf::type_id::DICTIONARY32: |
| 105 | + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting DICTIONARY32 yet", 0); |
| 106 | + break; |
| 107 | + case cudf::type_id::STRING: |
| 108 | + arrow_array = std::make_shared<arrow::StringArray>(j_col_length, offsets_buffer, data_buffer, null_buffer, j_null_count); |
| 109 | + break; |
| 110 | + default: |
| 111 | + // this handles the primitive types |
| 112 | + arrow_array = cudf::detail::to_arrow_array(n_type, j_col_length, data_buffer, null_buffer, j_null_count); |
| 113 | + } |
| 114 | + auto name_and_type = arrow::field("col", arrow_array->type()); |
| 115 | + std::vector<std::shared_ptr<arrow::Field>> fields = {name_and_type}; |
| 116 | + std::shared_ptr<arrow::Schema> schema = std::make_shared<arrow::Schema>(fields); |
| 117 | + auto arrow_table = arrow::Table::Make(schema, std::vector<std::shared_ptr<arrow::Array>>{arrow_array}); |
| 118 | + std::unique_ptr<cudf::table> table_result = cudf::from_arrow(*(arrow_table)); |
| 119 | + std::vector<std::unique_ptr<cudf::column>> retCols = table_result->release(); |
| 120 | + if (retCols.size() != 1) { |
| 121 | + JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Must result in one column", 0); |
| 122 | + } |
| 123 | + return reinterpret_cast<jlong>(retCols[0].release()); |
| 124 | + } |
| 125 | + CATCH_STD(env, 0); |
| 126 | +} |
| 127 | + |
53 | 128 | JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeList(JNIEnv *env, jobject j_object,
|
54 | 129 | jlongArray handles,
|
55 | 130 | jlong j_type,
|
|
0 commit comments