Skip to content

Commit

Permalink
Add an event handler for ColumnVector.close (rapidsai#13279)
Browse files Browse the repository at this point in the history
This adds an event handler for `ColumnVector` to be invoked when `.close()` is invoked.

I also made `getRefCount` public since it is useful for debugging and I couldn't think of a good reason to keep this hidden.

This event handler is used to keep track of what columns are spillable at a time and is part of NVIDIA/spark-rapids#7672.

Authors:
  - Alessandro Bellina (https://github.com/abellina)

Approvers:
  - Gera Shegalov (https://github.com/gerashegalov)
  - Jason Lowe (https://github.com/jlowe)
  - Jim Brennan (https://github.com/jbrennan333)

URL: rapidsai#13279
  • Loading branch information
abellina authored May 3, 2023
1 parent d0a7dec commit f6abfdd
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 3 deletions.
46 changes: 44 additions & 2 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -39,6 +39,23 @@
* to increment the reference count.
*/
public final class ColumnVector extends ColumnView {
/**
* Interface to handle events for this ColumnVector. Only invoked during
* close, hence `onClosed` is the only event.
*/
public interface EventHandler {
/**
* `onClosed` is invoked with the updated `refCount` during `close`.
* The last invocation of `onClosed` will be with `refCount=0`.
*
* @note the callback is invoked with this `ColumnVector`'s lock held.
*
* @param refCount - the updated ref count for this ColumnVector at the time
* of invocation
*/
void onClosed(int refCount);
}

private static final Logger log = LoggerFactory.getLogger(ColumnVector.class);

static {
Expand All @@ -47,6 +64,7 @@ public final class ColumnVector extends ColumnView {

private Optional<Long> nullCount = Optional.empty();
private int refCount;
private EventHandler eventHandler;

/**
* Wrap an existing on device cudf::column with the corresponding ColumnVector. The new
Expand Down Expand Up @@ -200,6 +218,27 @@ static ColumnVector fromViewWithContiguousAllocation(long columnViewAddress, Dev
return new ColumnVector(columnViewAddress, buffer);
}

/**
* Set an event handler for this vector. This method can be invoked with null
* to unset the handler.
*
* @param newHandler - the EventHandler to use from this point forward
* @return the prior event handler, or null if not set.
*/
public synchronized EventHandler setEventHandler(EventHandler newHandler) {
EventHandler prev = this.eventHandler;
this.eventHandler = newHandler;
return prev;
}

/**
* Returns the current event handler for this ColumnVector or null if no handler
* is associated.
*/
public synchronized EventHandler getEventHandler() {
return this.eventHandler;
}

/**
* This is a really ugly API, but it is possible that the lifecycle of a column of
* data may not have a clear lifecycle thanks to java and GC. This API informs the leak
Expand All @@ -217,6 +256,9 @@ public void noWarnLeakExpected() {
public synchronized void close() {
refCount--;
offHeap.delRef();
if (eventHandler != null) {
eventHandler.onClosed(refCount);
}
if (refCount == 0) {
offHeap.clean(false);
} else if (refCount < 0) {
Expand Down Expand Up @@ -272,7 +314,7 @@ public long getNullCount() {
/**
* Returns this column's current refcount
*/
synchronized int getRefCount() {
public synchronized int getRefCount() {
return refCount;
}

Expand Down
27 changes: 26 additions & 1 deletion java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
Expand Down Expand Up @@ -6676,4 +6676,29 @@ void testApplyBooleanMaskFromListOfStructure() {
assertColumnsAreEqual(expectedCv, actualCv);
}
}

@Test
public void testEventHandlerIsCalledForEachClose() {
final AtomicInteger onClosedWasCalled = new AtomicInteger(0);
try (ColumnVector cv = ColumnVector.fromInts(1,2,3,4)) {
cv.setEventHandler(refCount -> onClosedWasCalled.incrementAndGet());
}
assertEquals(1, onClosedWasCalled.get());
}

@Test
public void testEventHandlerIsNotCalledIfNotSet() {
final AtomicInteger onClosedWasCalled = new AtomicInteger(0);
try (ColumnVector cv = ColumnVector.fromInts(1,2,3,4)) {
assertNull(cv.getEventHandler());
}
assertEquals(0, onClosedWasCalled.get());

try (ColumnVector cv = ColumnVector.fromInts(1,2,3,4)) {
cv.setEventHandler(refCount -> onClosedWasCalled.incrementAndGet());
cv.setEventHandler(null);
}
assertEquals(0, onClosedWasCalled.get());
}

}

0 comments on commit f6abfdd

Please sign in to comment.