Skip to content

Commit cbc0394

Browse files
authored
Add JNI support for converting Arrow buffers to CUDF ColumnVectors (#7222)
This adds the JNI layer needed to build up Arrow column vectors, which are just references to off-heap Arrow buffers, and then convert those into CUDF ColumnVectors by directly copying the Arrow data to the GPU. The way this works is you create an ArrowColumnBuilder for each column you need. You call addBatch for each separate Arrow buffer you want to add into that column and then you call buildAndPutOnDevice() on the builder. That causes the Arrow pointers to be passed into CUDF, where an Arrow Table with 1 column is created; that Arrow Table gets passed into cudf::from_arrow, which returns a CUDF Table, and we grab the 1 column from that and return it. Note this only supports primitive types and Strings for now. List, Struct, Dictionary, and Decimal are not supported yet. Signed-off-by: Thomas Graves <[email protected]> Authors: - Thomas Graves (@tgravescs) Approvers: - Robert (Bobby) Evans (@revans2) - Jason Lowe (@jlowe) URL: #7222
1 parent 9631660 commit cbc0394

File tree

5 files changed

+574
-0
lines changed

5 files changed

+574
-0
lines changed

java/pom.xml

+7
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,12 @@
132132
<version>2.25.0</version>
133133
<scope>test</scope>
134134
</dependency>
135+
<dependency>
136+
<groupId>org.apache.arrow</groupId>
137+
<artifactId>arrow-vector</artifactId>
138+
<version>${arrow.version}</version>
139+
<scope>test</scope>
140+
</dependency>
135141
</dependencies>
136142

137143
<properties>
@@ -151,6 +157,7 @@
151157
<GPU_ARCHS>ALL</GPU_ARCHS>
152158
<native.build.path>${project.build.directory}/cmake-build</native.build.path>
153159
<slf4j.version>1.7.30</slf4j.version>
160+
<arrow.version>0.15.1</arrow.version>
154161
</properties>
155162

156163
<profiles>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
*
3+
* Copyright (c) 2021, NVIDIA CORPORATION.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*
17+
*/
18+
19+
package ai.rapids.cudf;
20+
21+
import java.nio.ByteBuffer;
22+
import java.util.ArrayList;
23+
24+
/**
25+
* Column builder from Arrow data. This builder takes in byte buffers referencing
26+
* Arrow data and allows efficient building of CUDF ColumnVectors from that Arrow data.
27+
* The caller can add multiple batches where each batch corresponds to Arrow data
28+
* and those batches get concatenated together after being converted to CUDF
29+
* ColumnVectors.
30+
* This currently only supports primitive types and Strings, Decimals and nested types
31+
* such as list and struct are not supported.
32+
*/
33+
public final class ArrowColumnBuilder implements AutoCloseable {
34+
private DType type;
35+
private final ArrayList<ByteBuffer> data = new ArrayList<>();
36+
private final ArrayList<ByteBuffer> validity = new ArrayList<>();
37+
private final ArrayList<ByteBuffer> offsets = new ArrayList<>();
38+
private final ArrayList<Long> nullCount = new ArrayList<>();
39+
private final ArrayList<Long> rows = new ArrayList<>();
40+
41+
public ArrowColumnBuilder(HostColumnVector.DataType type) {
42+
this.type = type.getType();
43+
}
44+
45+
/**
46+
* Add an Arrow buffer. This API allows you to add multiple if you want them
47+
* combined into a single ColumnVector.
48+
* Note, this takes all data, validity, and offsets buffers, but they may not all
49+
* be needed based on the data type. The buffer should be null if its not used
50+
* for that type.
51+
* This API only supports primitive types and Strings, Decimals and nested types
52+
* such as list and struct are not supported.
53+
* @param rows - number of rows in this Arrow buffer
54+
* @param nullCount - number of null values in this Arrow buffer
55+
* @param data - ByteBuffer of the Arrow data buffer
56+
* @param validity - ByteBuffer of the Arrow validity buffer
57+
* @param offsets - ByteBuffer of the Arrow offsets buffer
58+
*/
59+
public void addBatch(long rows, long nullCount, ByteBuffer data, ByteBuffer validity,
60+
ByteBuffer offsets) {
61+
this.rows.add(rows);
62+
this.nullCount.add(nullCount);
63+
this.data.add(data);
64+
this.validity.add(validity);
65+
this.offsets.add(offsets);
66+
}
67+
68+
/**
69+
* Create the immutable ColumnVector, copied to the device based on the Arrow data.
70+
* @return - new ColumnVector
71+
*/
72+
public final ColumnVector buildAndPutOnDevice() {
73+
int numBatches = rows.size();
74+
ArrayList<ColumnVector> allVecs = new ArrayList<>(numBatches);
75+
ColumnVector vecRet;
76+
try {
77+
for (int i = 0; i < numBatches; i++) {
78+
allVecs.add(ColumnVector.fromArrow(type, rows.get(i), nullCount.get(i),
79+
data.get(i), validity.get(i), offsets.get(i)));
80+
}
81+
if (numBatches == 1) {
82+
vecRet = allVecs.get(0);
83+
} else if (numBatches > 1) {
84+
vecRet = ColumnVector.concatenate(allVecs.toArray(new ColumnVector[0]));
85+
} else {
86+
throw new IllegalStateException("Can't build a ColumnVector when no Arrow batches specified");
87+
}
88+
} finally {
89+
// close the vectors that were concatenated
90+
if (numBatches > 1) {
91+
allVecs.forEach(cv -> cv.close());
92+
}
93+
}
94+
return vecRet;
95+
}
96+
97+
@Override
98+
public void close() {
99+
// memory buffers owned outside of this
100+
}
101+
102+
@Override
103+
public String toString() {
104+
return "ArrowColumnBuilder{" +
105+
"type=" + type +
106+
", data=" + data +
107+
", validity=" + validity +
108+
", offsets=" + offsets +
109+
", nullCount=" + nullCount +
110+
", rows=" + rows +
111+
'}';
112+
}
113+
}

java/src/main/java/ai/rapids/cudf/ColumnVector.java

+49
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import java.math.BigDecimal;
2727
import java.math.RoundingMode;
28+
import java.nio.ByteBuffer;
2829
import java.util.ArrayList;
2930
import java.util.List;
3031
import java.util.Optional;
@@ -310,6 +311,50 @@ public BaseDeviceMemoryBuffer getDeviceBufferFor(BufferType type) {
310311
return srcBuffer;
311312
}
312313

314+
/**
315+
* Ensures the ByteBuffer passed in is a direct byte buffer.
316+
* If it is not then it creates one and copies the data in
317+
* the byte buffer passed in to the direct byte buffer
318+
* it created and returns it.
319+
*/
320+
private static ByteBuffer bufferAsDirect(ByteBuffer buf) {
321+
ByteBuffer bufferOut = buf;
322+
if (bufferOut != null && !bufferOut.isDirect()) {
323+
bufferOut = ByteBuffer.allocateDirect(buf.remaining());
324+
bufferOut.put(buf);
325+
bufferOut.flip();
326+
}
327+
return bufferOut;
328+
}
329+
330+
/**
331+
* Create a ColumnVector from the Apache Arrow byte buffers passed in.
332+
* Any of the buffers not used for that datatype should be set to null.
333+
* The buffers are expected to be off heap buffers, but if they are not,
334+
* it will handle copying them to direct byte buffers.
335+
* This only supports primitive types. Strings, Decimals and nested types
336+
* such as list and struct are not supported.
337+
* @param type - type of the column
338+
* @param numRows - Number of rows in the arrow column
339+
* @param nullCount - Null count
340+
* @param data - ByteBuffer of the Arrow data buffer
341+
* @param validity - ByteBuffer of the Arrow validity buffer
342+
* @param offsets - ByteBuffer of the Arrow offsets buffer
343+
* @return - new ColumnVector
344+
*/
345+
public static ColumnVector fromArrow(
346+
DType type,
347+
long numRows,
348+
long nullCount,
349+
ByteBuffer data,
350+
ByteBuffer validity,
351+
ByteBuffer offsets) {
352+
long columnHandle = fromArrow(type.typeId.getNativeId(), numRows, nullCount,
353+
bufferAsDirect(data), bufferAsDirect(validity), bufferAsDirect(offsets));
354+
ColumnVector vec = new ColumnVector(columnHandle);
355+
return vec;
356+
}
357+
313358
/**
314359
* Create a new vector of length rows, where each row is filled with the Scalar's
315360
* value
@@ -615,6 +660,10 @@ public ColumnVector castTo(DType type) {
615660

616661
// JNI binding; its native counterpart is Java_ai_rapids_cudf_ColumnVector_sequence in
// ColumnVectorJni.cpp. Presumably returns a native column handle - confirm against the C++ side.
private static native long sequence(long initialValue, long step, int rows);
617662

663+
private static native long fromArrow(int type, long col_length,
664+
long null_count, ByteBuffer data, ByteBuffer validity,
665+
ByteBuffer offsets) throws CudfException;
666+
618667
// JNI binding that may throw CudfException. Presumably returns a native handle to a column of
// rowCount rows built from the scalar behind scalarHandle - TODO confirm against the native code.
private static native long fromScalar(long scalarHandle, int rowCount) throws CudfException;
619668

620669
private static native long makeList(long[] handles, long typeHandle, int scale, long rows)

java/src/main/native/src/ColumnVectorJni.cpp

+75
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@
1414
* limitations under the License.
1515
*/
1616

17+
#include <arrow/api.h>
1718
#include <cudf/column/column_factories.hpp>
1819
#include <cudf/concatenate.hpp>
1920
#include <cudf/filling.hpp>
21+
#include <cudf/interop.hpp>
2022
#include <cudf/hashing.hpp>
2123
#include <cudf/reshape.hpp>
2224
#include <cudf/utilities/bit.hpp>
25+
#include <cudf/detail/interop.hpp>
2326
#include <cudf/lists/detail/concatenate.hpp>
2427
#include <cudf/lists/lists_column_view.hpp>
2528
#include <cudf/scalar/scalar_factories.hpp>
@@ -50,6 +53,78 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_sequence(JNIEnv *env, j
5053
CATCH_STD(env, 0);
5154
}
5255

56+
// JNI entry point backing ColumnVector.fromArrow.
//
// Wraps the (optional) Java direct buffers as arrow::Buffer objects, builds a single Arrow
// array of the requested type from them, puts that array into a one-column Arrow table and
// converts the table with cudf::from_arrow. The resulting cudf column is released to the
// caller as a jlong handle.
//
// j_type         - native id of the cudf type of the column
// j_col_length   - number of rows
// j_null_count   - number of nulls
// j_data_obj     - direct ByteBuffer with the Arrow data, or null when unused
// j_validity_obj - direct ByteBuffer with the Arrow validity bits, or null when unused
// j_offsets_obj  - direct ByteBuffer with the Arrow offsets, or null when unused
//
// DECIMAL32, DECIMAL64, STRUCT, LIST and DICTIONARY32 are rejected with an illegal-argument
// exception. Error paths pass 0 as the return value to JNI_THROW_NEW / CATCH_STD.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromArrow(JNIEnv *env, jclass,
                                                                   jint j_type,
                                                                   jlong j_col_length,
                                                                   jlong j_null_count,
                                                                   jobject j_data_obj,
                                                                   jobject j_validity_obj,
                                                                   jobject j_offsets_obj) {
  try {
    cudf::jni::auto_set_device(env);
    cudf::type_id n_type = static_cast<cudf::type_id>(j_type);

    // Not all the buffers are used for all types: a null Java buffer becomes an empty
    // Arrow buffer (null address, zero length). The capacity is kept as jlong, the type
    // GetDirectBufferCapacity returns, instead of being narrowed to int.
    auto wrap_direct_buffer = [env](jobject j_buf) {
      void const *address = nullptr;
      jlong length = 0;
      if (j_buf != nullptr) {
        address = env->GetDirectBufferAddress(j_buf);
        length = env->GetDirectBufferCapacity(j_buf);
      }
      return arrow::Buffer::Wrap(static_cast<const char *>(address), length);
    };
    auto data_buffer = wrap_direct_buffer(j_data_obj);
    auto null_buffer = wrap_direct_buffer(j_validity_obj);
    auto offsets_buffer = wrap_direct_buffer(j_offsets_obj);

    std::shared_ptr<arrow::Array> arrow_array;
    switch (n_type) {
      case cudf::type_id::DECIMAL32:
        JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting DECIMAL32 yet", 0);
        break;
      case cudf::type_id::DECIMAL64:
        JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting DECIMAL64 yet", 0);
        break;
      case cudf::type_id::STRUCT:
        JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting STRUCT yet", 0);
        break;
      case cudf::type_id::LIST:
        JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting LIST yet", 0);
        break;
      case cudf::type_id::DICTIONARY32:
        JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting DICTIONARY32 yet", 0);
        break;
      case cudf::type_id::STRING:
        arrow_array = std::make_shared<arrow::StringArray>(j_col_length, offsets_buffer,
                                                           data_buffer, null_buffer, j_null_count);
        break;
      default:
        // this handles the primitive types
        arrow_array = cudf::detail::to_arrow_array(n_type, j_col_length, data_buffer,
                                                   null_buffer, j_null_count);
        break;
    }

    // cudf::from_arrow works on tables, so wrap the array in a one-column table.
    auto name_and_type = arrow::field("col", arrow_array->type());
    std::vector<std::shared_ptr<arrow::Field>> fields = {name_and_type};
    std::shared_ptr<arrow::Schema> schema = std::make_shared<arrow::Schema>(fields);
    auto arrow_table =
        arrow::Table::Make(schema, std::vector<std::shared_ptr<arrow::Array>>{arrow_array});
    std::unique_ptr<cudf::table> table_result = cudf::from_arrow(*(arrow_table));
    std::vector<std::unique_ptr<cudf::column>> retCols = table_result->release();
    if (retCols.size() != 1) {
      JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Must result in one column", 0);
    }
    // Ownership of the column passes to the Java side through the returned handle.
    return reinterpret_cast<jlong>(retCols[0].release());
  }
  CATCH_STD(env, 0);
}
127+
53128
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeList(JNIEnv *env, jobject j_object,
54129
jlongArray handles,
55130
jlong j_type,

0 commit comments

Comments
 (0)