Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RHOAIENG-7082: Add SHAP support for KServe explainer #7

Merged
merged 3 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,20 @@ spec:

The explanation request will be identical to the LIME explainer case.

## Configuration

The following environment variables can be used in the `InferenceService` to customize the explainer:

| Name | Description | Default |
|--------------------------------------------------------------------------|--------------------------------------------------------------------|---------------|
| `EXPLAINER_TYPE` | `LIME` or `SHAP`, the explainer to use. | `LIME` |
| `LIME_SAMPLES` | The number of samples to use in LIME | `200` |
| `LIME_RETRIES` | Number of LIME retries | `2` |
| `LIME_WLR` | Use LIME Weighted Linear Regression, `true` or `false` | `true` |
| `LIME_NORMALIZE_WEIGHTS` | Whether LIME should normalize the weights, `true` or `false` | `true` |
| `EXPLAINER_SHAP_BACKGROUND_QUEUE` | The number of observations to keep in memory for SHAP's background | `10` |
| `EXPLAINER_SHAP_BACKGROUND_DIVERSITY` | The number of synthetic samples to generate for diversity | `10` |

## Contributing

To get started with contributing to this project:
Expand Down
8 changes: 7 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.kie.trustyai</groupId>
<artifactId>trustyai-kserve</artifactId>
<version>1.0-SNAPSHOT</version>
<version>0.2-SNAPSHOT</version>

<properties>
<compiler-plugin.version>3.11.0</compiler-plugin.version>
Expand Down Expand Up @@ -46,6 +46,12 @@
<artifactId>explainability-connectors</artifactId>
<version>${trustyai.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-collections4</artifactId>
<version>4.4</version>
</dependency>


<dependency>
<groupId>io.quarkus</groupId>
Expand Down
17 changes: 13 additions & 4 deletions src/main/java/org/kie/trustyai/CommandLineArgs.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ public class CommandLineArgs {

@CommandLine.Option(names = "--http_port", description = "The HTTP port of the predictor")
private int httpPort;
@CommandLine.Option(names = "--predictor_protocol", defaultValue = "v1", description = "The predictor protocol version (v1 or v2)")
private String predictorProtocol;

public String getPredictorProtocol() {
return predictorProtocol;
}

public String getPredictorHost() {
return predictorHost;
Expand All @@ -27,14 +33,17 @@ public int getHttpPort() {
return httpPort;
}

public String getV1HTTPPredictorURI() {
public String getV1HTTPPredictorURI(String modelName) {
return "http://" + predictorHost + "/v1/models/" + modelName + ":predict";
}

public String getV2HTTPPredictorURI(String modelName) {

return "http://" +
ruivieira marked this conversation as resolved.
Show resolved Hide resolved
predictorHost +
"/v1/models/" +
"/v2/models/" +
modelName +
":predict";
"/infer";
}


}
17 changes: 8 additions & 9 deletions src/main/java/org/kie/trustyai/ConfigCommand.java
Original file line number Diff line number Diff line change
@@ -1,42 +1,41 @@
package org.kie.trustyai;

import io.quarkus.logging.Log;
import io.quarkus.runtime.Quarkus;
import io.quarkus.runtime.QuarkusApplication;
import io.quarkus.runtime.annotations.QuarkusMain;
import jakarta.inject.Inject;
import picocli.CommandLine;
import org.jboss.logging.Logger;

import java.util.Arrays;

@QuarkusMain
public class ConfigCommand implements QuarkusApplication {

private static final Logger LOGGER = Logger.getLogger(ConfigCommand.class.getName());

@Inject
CommandLineArgs cmdArgs;

@Override
public int run(String... args) {
LOGGER.debug("Starting application...");
Log.info("Starting application...");
final CommandLine commandLine = new CommandLine(cmdArgs);

Log.debug("Using command-line arguments: " + Arrays.toString(args));
try {
commandLine.parseArgs(args);
if (commandLine.isUsageHelpRequested()) {
commandLine.usage(System.out);
return 0;
}


LOGGER.debug("Configuration loaded successfully.");
Log.info("Configuration loaded successfully.");
} catch (CommandLine.ParameterException e) {
LOGGER.error("Error parsing command line: " + e.getMessage());
Log.error("Error parsing command line: " + e.getMessage());
commandLine.usage(System.err);
return 1;
}

Quarkus.waitForExit(); // Wait for Quarkus shutdown events
LOGGER.debug("Quarkus is waiting for exit...");
Log.info("Quarkus is waiting for exit...");
return 0;
}
}
39 changes: 20 additions & 19 deletions src/main/java/org/kie/trustyai/ConfigService.java
Original file line number Diff line number Diff line change
@@ -1,55 +1,56 @@
package org.kie.trustyai;

import io.quarkus.logging.Log;
import jakarta.annotation.PostConstruct;
import jakarta.enterprise.context.ApplicationScoped;
import org.eclipse.microprofile.config.inject.ConfigProperty;
import org.jboss.logging.Logger;


@ApplicationScoped
public class ConfigService {

private static final Logger LOGGER = Logger.getLogger(ConfigService.class.getName());



@ConfigProperty(name = "explainer.type", defaultValue = "LIME")
ExplainerType explainerType;
@ConfigProperty(name = "lime.samples", defaultValue = "200")
int limeSamples;
@ConfigProperty(name = "lime.retries", defaultValue = "2")
int limeRetries;
@ConfigProperty(name = "lime.wlr", defaultValue = "true")
boolean limeWLR;
@ConfigProperty(name = "lime.normalize.weights", defaultValue = "true")
boolean limeNormalizeWeights;
@ConfigProperty(name = "explainer.shap.background.queue", defaultValue = "10")
int queueSize;
@ConfigProperty(name = "explainer.shap.background.diversity", defaultValue = "10")
int diversitySize;

public int getLimeSamples() {
return limeSamples;
}

@ConfigProperty(name = "lime.samples", defaultValue = "200")
int limeSamples;

public int getLimeRetries() {
return limeRetries;
}

@ConfigProperty(name = "lime.retries", defaultValue = "2")
int limeRetries;

public boolean getLimeWLR() {
return limeWLR;
}

@ConfigProperty(name = "lime.wlr", defaultValue = "true")
boolean limeWLR;


public boolean getLimeNormalizeWeights() {
return limeNormalizeWeights;
}

@ConfigProperty(name = "lime.normalize.weights", defaultValue = "true")
boolean limeNormalizeWeights;
public int getQueueSize() {
return queueSize;
}

public int getDiversitySize() {
return diversitySize;
}

@PostConstruct
private void validateConfig() {
if (explainerType == null) {
LOGGER.error("Unknown explainer type configured. Falling back to LIME.");
Log.error("Unknown explainer type configured. Falling back to LIME.");
explainerType = ExplainerType.LIME;
}
}
Expand Down
23 changes: 18 additions & 5 deletions src/main/java/org/kie/trustyai/ExplainerFactory.java
Original file line number Diff line number Diff line change
@@ -1,34 +1,47 @@
package org.kie.trustyai;

import java.util.List;

import io.quarkus.logging.Log;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import org.kie.trustyai.explainability.local.LocalExplainer;
import org.kie.trustyai.explainability.local.lime.LimeConfig;
import org.kie.trustyai.explainability.local.lime.LimeExplainer;
import org.kie.trustyai.explainability.local.shap.ShapConfig;
import org.kie.trustyai.explainability.local.shap.ShapKernelExplainer;
import org.kie.trustyai.explainability.model.*;

import java.util.List;
import org.kie.trustyai.explainability.model.PredictionInput;
import org.kie.trustyai.explainability.model.SaliencyResults;

@Singleton
public class ExplainerFactory {

@Inject
ConfigService configService;

public LocalExplainer<SaliencyResults> getExplainer(ExplainerType type, List<PredictionInput> background) throws IllegalArgumentException {
@Inject
StreamingGeneratorManager streamingGeneratorManager;

public LocalExplainer<SaliencyResults> getExplainer(ExplainerType type) throws IllegalArgumentException {
return switch (type) {
case LIME -> {
final LimeConfig limeConfig = new LimeConfig()
.withNormalizeWeights(configService.getLimeNormalizeWeights())
.withSamples(configService.getLimeSamples())
.withRetries(configService.getLimeRetries())
.withUseWLRLinearModel(configService.getLimeWLR());
Log.info("Instating LIME explainer");
yield new LimeExplainer(limeConfig);
}
case SHAP -> {
ShapConfig shapConfig = ShapConfig.builder().withLink(ShapConfig.LinkType.IDENTITY).withBackground(background).build();
final int backgroundSize = configService.getQueueSize() + configService.getDiversitySize();
Log.debug("Requesting " + backgroundSize + " background samples from SHAP's streaming generator");
final List<PredictionInput> background = streamingGeneratorManager.getStreamingGenerator().generate(backgroundSize);
Log.debug("The background has a size of " + background.size());
final ShapConfig shapConfig = ShapConfig.builder().withRegularizer(5)
.withLink(ShapConfig.LinkType.IDENTITY)
.withBackground(background).build();
Log.info("Instantiating SHAP explainer");
yield new ShapKernelExplainer(shapConfig);
}
default -> throw new IllegalArgumentException("Unsupported explainer type: " + type);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
package org.kie.trustyai;

import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutionException;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.quarkus.logging.Log;
import jakarta.enterprise.inject.Default;
import jakarta.inject.Inject;
import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import org.jboss.logging.Logger;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealVector;
import org.kie.trustyai.connectors.kserve.v1.KServeV1HTTPPredictionProvider;
import org.kie.trustyai.connectors.kserve.v1.KServeV1RequestPayload;
import org.kie.trustyai.explainability.local.LocalExplainer;
import org.kie.trustyai.explainability.model.*;

import jakarta.inject.Inject;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.kie.trustyai.explainability.model.Prediction;
import org.kie.trustyai.explainability.model.PredictionInput;
import org.kie.trustyai.explainability.model.PredictionOutput;
import org.kie.trustyai.explainability.model.PredictionProvider;
import org.kie.trustyai.explainability.model.SaliencyResults;
import org.kie.trustyai.explainability.model.SimplePrediction;
import org.kie.trustyai.payloads.SaliencyExplanationResponse;

import java.util.List;
import java.util.concurrent.ExecutionException;

@Path("/v1/models/{modelName}:explain")
public class ExplainerEndpoint {

private static final Logger LOGGER = Logger.getLogger(ExplainerEndpoint.class.getName());
public class ExplainerV1Endpoint {

@Inject
ObjectMapper objectMapper;
Expand All @@ -38,37 +43,54 @@ public class ExplainerEndpoint {
@Inject
ExplainerFactory explainerFactory;

@Inject
StreamingGeneratorManager streamingGeneratorManager;

@POST
@Consumes(MediaType.APPLICATION_JSON)
public Response explainIncome(@PathParam("modelName") String modelName, KServeV1RequestPayload data)
public Response explain(@PathParam("modelName") String modelName, KServeV1RequestPayload data)
throws ExecutionException, InterruptedException {
final String predictorURI = cmdArgs.getV1HTTPPredictorURI();

LOGGER.debug("Using explainer type [" + configService.getExplainerType() + "]");
LOGGER.debug("Using predictor URI [" + predictorURI + "]");
Log.info("Using explainer type [" + configService.getExplainerType() + "]");
Log.info("Using V1 HTTP protocol");
final String predictorURI = cmdArgs.getV1HTTPPredictorURI(modelName);
final PredictionProvider provider = new KServeV1HTTPPredictionProvider(null, null, predictorURI, 1);
Log.info("Using predictor URI [" + predictorURI + "]");

final PredictionProvider provider = new KServeV1HTTPPredictionProvider(null, null, predictorURI);
final List<PredictionInput> input = data.toPredictionInputs();
final PredictionOutput output = provider.predictAsync(input).get().get(0);
final Prediction prediction = new SimplePrediction(input.get(0), output);
final int dimensions = input.get(0).getFeatures().size();

if (configService.getExplainerType() == ExplainerType.SHAP) {
if (Objects.isNull(streamingGeneratorManager.getStreamingGenerator())) {
Log.info("Initializing SHAP's Streaming Background Generator with dimension " + dimensions);
streamingGeneratorManager.initialize(dimensions);
}
final double[] numericData = new double[dimensions];
for (int i = 0; i < dimensions; i++) {
numericData[i] = input.get(0).getFeatures().get(i).getValue().asNumber();
}
final RealVector vectorData = new ArrayRealVector(numericData);
streamingGeneratorManager.getStreamingGenerator().update(vectorData);
}

final ExplainerType explainerType = configService.getExplainerType();

try {
final LocalExplainer<SaliencyResults> explainer = explainerFactory.getExplainer(explainerType, input);

final LocalExplainer<SaliencyResults> explainer = explainerFactory.getExplainer(explainerType);
Log.info("Sending explaining request to " + predictorURI);
final SaliencyResults results = explainer.explainAsync(prediction, provider).get();
final SaliencyExplanationResponse response = SaliencyExplanationResponse.fromSaliencyResults(results);

try {
String resultsJson = objectMapper.writeValueAsString(results);
return Response.ok(response, MediaType.APPLICATION_JSON).build();
} catch (Exception e) {
return Response.serverError().entity("Error serializing SaliencyResults to JSON: " + e.getMessage())
.build();
}
} catch (IllegalArgumentException e) {
return Response.serverError().entity("Explainer type not supported: " + explainerType).build();
return Response.serverError().entity("Error: " + e.getMessage()).build();
}

}
Expand Down
29 changes: 29 additions & 0 deletions src/main/java/org/kie/trustyai/StreamingGeneratorManager.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package org.kie.trustyai;

import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import org.kie.trustyai.explainability.local.shap.background.StreamingGenerator;
import org.kie.trustyai.statistics.MultivariateOnlineEstimator;
import org.kie.trustyai.statistics.distributions.gaussian.MultivariateGaussianParameters;
import org.kie.trustyai.statistics.estimators.WelfordOnlineEstimator;

@Singleton
public class StreamingGeneratorManager {

@Inject
ConfigService configService;

private StreamingGenerator streamingGenerator = null;

public synchronized void initialize(int dimensions) {
if (streamingGenerator == null && configService.getExplainerType() == ExplainerType.SHAP) {
final MultivariateOnlineEstimator<MultivariateGaussianParameters> estimator = new WelfordOnlineEstimator(dimensions);
streamingGenerator = new StreamingGenerator(dimensions, configService.getQueueSize(), configService.getDiversitySize(), estimator);
}
}

public StreamingGenerator getStreamingGenerator() {
return streamingGenerator;
}

}
Loading
Loading