Skip to content

Commit dd1efe1

Browse files
author
Kuhu Shukla
authored
Add JNI and Java bindings for list_contains (#7125)
Adds JNI and Java side bindings for `list_contains` that is being added as part of #7039. Authors: - Kuhu Shukla (@kuhushukla) Approvers: - Robert (Bobby) Evans (@revans2) - MithunR (@mythrocks) URL: #7125
1 parent fc40c52 commit dd1efe1

File tree

3 files changed

+143
-0
lines changed

3 files changed

+143
-0
lines changed

java/src/main/java/ai/rapids/cudf/ColumnView.java

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2323,6 +2323,37 @@ public static ColumnView makeStructView(ColumnView... columns) {
23232323
return makeStructView(columns[0].rows, columns);
23242324
}
23252325

2326+
/**
2327+
* Create a column of bool values indicating whether the specified scalar
2328+
* is an element of each row of a list column.
2329+
* Output `column[i]` is set to null if one or more of the following are true:
2330+
* 1. The key is null
2331+
* 2. The column vector list value is null
2332+
* 3. The list row does not contain the key, and contains at least
2333+
* one null.
2334+
* @param key the scalar to look up
2335+
* @return a Boolean ColumnVector with the result of the lookup
2336+
*/
2337+
public final ColumnVector listContains(Scalar key) {
2338+
assert type.equals(DType.LIST) : "column type must be a LIST";
2339+
return new ColumnVector(listContains(getNativeView(), key.getScalarHandle()));
2340+
}
2341+
2342+
/**
2343+
* Create a column of bool values indicating whether the list rows of the first
2344+
* column contain the corresponding values in the second column.
2345+
* 1. The key value is null
2346+
* 2. The column vector list value is null
2347+
* 3. The list row does not contain the key, and contains at least
2348+
* one null.
2349+
* @param key the ColumnVector with look up values
2350+
* @return a Boolean ColumnVector with the result of the lookup
2351+
*/
2352+
public final ColumnVector listContainsColumn(ColumnView key) {
2353+
assert type.equals(DType.LIST) : "column type must be a LIST";
2354+
return new ColumnVector(listContainsColumn(getNativeView(), key.getNativeView()));
2355+
}
2356+
23262357
/////////////////////////////////////////////////////////////////////////////
23272358
// INTERNAL/NATIVE ACCESS
23282359
/////////////////////////////////////////////////////////////////////////////
@@ -2558,6 +2589,22 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat
25582589

25592590
private static native long extractListElement(long nativeView, int index);
25602591

2592+
/**
2593+
* Native method for list lookup
2594+
* @param nativeView the column view handle of the list
2595+
* @param key the scalar key handle
2596+
* @return column handle of the resultant
2597+
*/
2598+
private static native long listContains(long nativeView, long key);
2599+
2600+
/**
2601+
* Native method for list lookup
2602+
* @param nativeView the column view handle of the list
2603+
* @param keyColumn the column handle of look up keys
2604+
* @return column handle of the resultant
2605+
*/
2606+
private static native long listContainsColumn(long nativeView, long keyColumn);
2607+
25612608
private static native long castTo(long nativeHandle, int type, int scale);
25622609

25632610
private static native long logicalCastTo(long nativeHandle, int type, int scale);

java/src/main/native/src/ColumnViewJni.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
#include <cudf/transform.hpp>
5757
#include <cudf/unary.hpp>
5858
#include <cudf/utilities/bit.hpp>
59+
#include <cudf/lists/contains.hpp>
5960
#include <cudf/lists/lists_column_view.hpp>
6061
#include <cudf/structs/structs_column_view.hpp>
6162
#include <map_lookup.hpp>
@@ -329,6 +330,40 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_extractListElement(JNIEnv
329330
CATCH_STD(env, 0);
330331
}
331332

333+
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listContains(JNIEnv *env, jclass,
334+
jlong column_view,
335+
jlong lookup_key) {
336+
JNI_NULL_CHECK(env, column_view, "column is null", 0);
337+
JNI_NULL_CHECK(env, lookup_key, "lookup scalar is null", 0);
338+
try {
339+
cudf::jni::auto_set_device(env);
340+
cudf::column_view *cv = reinterpret_cast<cudf::column_view *>(column_view);
341+
cudf::lists_column_view lcv(*cv);
342+
cudf::scalar *lookup_scalar = reinterpret_cast<cudf::scalar *>(lookup_key);
343+
344+
std::unique_ptr<cudf::column> ret = cudf::lists::contains(lcv, *lookup_scalar);
345+
return reinterpret_cast<jlong>(ret.release());
346+
}
347+
CATCH_STD(env, 0);
348+
}
349+
350+
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listContainsColumn(JNIEnv *env, jclass,
351+
jlong column_view,
352+
jlong lookup_key_cv) {
353+
JNI_NULL_CHECK(env, column_view, "column is null", 0);
354+
JNI_NULL_CHECK(env, lookup_key_cv, "lookup column is null", 0);
355+
try {
356+
cudf::jni::auto_set_device(env);
357+
cudf::column_view *cv = reinterpret_cast<cudf::column_view *>(column_view);
358+
cudf::lists_column_view lcv(*cv);
359+
cudf::column_view *lookup_cv = reinterpret_cast<cudf::column_view *>(lookup_key_cv);
360+
361+
std::unique_ptr<cudf::column> ret = cudf::lists::contains(lcv, *lookup_cv);
362+
return reinterpret_cast<jlong>(ret.release());
363+
}
364+
CATCH_STD(env, 0);
365+
}
366+
332367
JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv *env, jclass,
333368
jlong column_view,
334369
jlong delimiter) {

java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2899,6 +2899,67 @@ void testExtractListElements() {
28992899
}
29002900
}
29012901

2902+
@Test
2903+
void testListContainsString() {
2904+
List<String> list1 = Arrays.asList("Héllo there", "thésé");
2905+
List<String> list2 = Arrays.asList("", "ARé some", "test strings");
2906+
List<String> list3 = Arrays.asList(null, "", "ARé some", "test strings", "thésé");
2907+
List<String> list4 = Arrays.asList(null, "", "ARé some", "test strings");
2908+
List<String> list5 = null;
2909+
try (ColumnVector v = ColumnVector.fromLists(new HostColumnVector.ListType(true,
2910+
new HostColumnVector.BasicType(true, DType.STRING)), list1, list2, list3, list4, list5);
2911+
ColumnVector expected = ColumnVector.fromBoxedBooleans(true, false, true, null, null);
2912+
ColumnVector result = v.listContains(Scalar.fromString("thésé"))) {
2913+
assertColumnsAreEqual(expected, result);
2914+
}
2915+
}
2916+
2917+
@Test
2918+
void testListContainsInt() {
2919+
List<Integer> list1 = Arrays.asList(1, 2, 3);
2920+
List<Integer> list2 = Arrays.asList(4, 5, 6);
2921+
List<Integer> list3 = Arrays.asList(7, 8, 9);
2922+
List<Integer> list4 = null;
2923+
try (ColumnVector v = ColumnVector.fromLists(new HostColumnVector.ListType(true,
2924+
new HostColumnVector.BasicType(true, DType.INT32)), list1, list2, list3, list4);
2925+
ColumnVector expected = ColumnVector.fromBoxedBooleans(false, false, true, null);
2926+
ColumnVector result = v.listContains(Scalar.fromInt(7))) {
2927+
assertColumnsAreEqual(expected, result);
2928+
}
2929+
}
2930+
2931+
@Test
2932+
void testListContainsStringCol() {
2933+
List<String> list1 = Arrays.asList("Héllo there", "thésé");
2934+
List<String> list2 = Arrays.asList("", "ARé some", "test strings");
2935+
List<String> list3 = Arrays.asList("FOO", "", "ARé some", "test");
2936+
List<String> list4 = Arrays.asList(null, "FOO", "", "ARé some", "test");
2937+
List<String> list5 = Arrays.asList(null, "FOO", "", "ARé some", "test");
2938+
List<String> list6 = null;
2939+
try (ColumnVector v = ColumnVector.fromLists(new HostColumnVector.ListType(true,
2940+
new HostColumnVector.BasicType(true, DType.STRING)), list1, list2, list3, list4, list5, list6);
2941+
ColumnVector expected = ColumnVector.fromBoxedBooleans(true, true, true, true, null, null);
2942+
ColumnVector result = v.listContainsColumn(
2943+
ColumnVector.fromStrings("thésé", "", "test", "test", "iotA", null))) {
2944+
assertColumnsAreEqual(expected, result);
2945+
}
2946+
}
2947+
2948+
@Test
2949+
void testListContainsIntCol() {
2950+
List<Integer> list1 = Arrays.asList(1, 2, 3);
2951+
List<Integer> list2 = Arrays.asList(4, 5, 6);
2952+
List<Integer> list3 = Arrays.asList(null, 8, 9);
2953+
List<Integer> list4 = Arrays.asList(null, 8, 9);
2954+
List<Integer> list5 = null;
2955+
try (ColumnVector v = ColumnVector.fromLists(new HostColumnVector.ListType(true,
2956+
new HostColumnVector.BasicType(true, DType.INT32)), list1, list2, list3, list4, list5);
2957+
ColumnVector expected = ColumnVector.fromBoxedBooleans(true, false, true, null, null);
2958+
ColumnVector result = v.listContainsColumn(ColumnVector.fromBoxedInts(3, 3, 8, 3, null))) {
2959+
assertColumnsAreEqual(expected, result);
2960+
}
2961+
}
2962+
29022963
@Test
29032964
void testStringSplitRecord() {
29042965
try (ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", "null", "", "ARé some", "test strings");

0 commit comments

Comments
 (0)