Skip to content

Commit

Permalink
PtNDArrayEx.multiBoxDetection() implementation (#2769)
Browse files Browse the repository at this point in the history
* Implement PtNDArrayEx.multiBoxDetection

* MultiboxDetection - code cleanup

* MultiboxDetection - code cleanup

* MultiboxDetection - code cleanup

* MultiboxDetection - code cleanup

* format code

* Fix, add tests, and pass CI

---------

Co-authored-by: Zach Kimberg <[email protected]>
  • Loading branch information
juliangamble and zachgk authored Sep 28, 2023
1 parent 15fd0d0 commit 963332d
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 8 deletions.
15 changes: 15 additions & 0 deletions api/src/main/java/ai/djl/nn/Block.java
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,21 @@ default void freezeParameters(boolean freeze) {
}
}

/**
 * Freezes or unfreezes each parameter returned by {@code getParameters()} that the predicate
 * accepts.
 *
 * <p>Parameters for which the predicate returns false are left untouched.
 *
 * @param freeze true to mark the selected parameters as frozen, false to unfreeze them
 * @param pred returns true for every parameter whose frozen state should be updated
 * @see Parameter#freeze(boolean)
 */
default void freezeParameters(boolean freeze, Predicate<Parameter> pred) {
    for (Parameter candidate : getParameters().values()) {
        if (!pred.test(candidate)) {
            // not selected by the caller: keep its current frozen state
            continue;
        }
        candidate.freeze(freeze);
    }
}

/**
* Validates that actual layout matches the expected layout.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import ai.djl.Model;
import ai.djl.ndarray.types.DataType;
import ai.djl.nn.Parameter;
import ai.djl.nn.Parameter.Type;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
Expand Down Expand Up @@ -189,7 +190,9 @@ public Trainer newTrainer(TrainingConfig trainingConfig) {
}
if (wasLoaded) {
// Unfreeze parameters if training directly
block.freezeParameters(false);
block.freezeParameters(
false,
p -> p.getType() != Type.RUNNING_MEAN && p.getType() != Type.RUNNING_VAR);
}
for (Pair<Initializer, Predicate<Parameter>> pair : initializer) {
if (pair.getKey() != null && pair.getValue() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package ai.djl.pytorch.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDUtils;
Expand All @@ -24,6 +25,8 @@
import ai.djl.nn.recurrent.RNN;
import ai.djl.pytorch.jni.JniUtils;

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/** {@code PtNDArrayEx} is the PyTorch implementation of the {@link NDArrayEx}. */
Expand Down Expand Up @@ -760,7 +763,152 @@ public NDList multiBoxDetection(
float nmsThreshold,
boolean forceSuppress,
int nmsTopK) {
throw new UnsupportedOperationException("Not implemented");
assert (inputs.size() == 3);

NDArray clsProb = inputs.get(0);
NDArray locPred = inputs.get(1);
NDArray anchors = inputs.get(2).reshape(new Shape(-1, 4));

NDManager ndManager = array.getManager();

NDArray variances = ndManager.create(new float[] {0.1f, 0.1f, 0.2f, 0.2f});

assert (variances.size() == 4); // << "Variance size must be 4";
final int numClasses = (int) clsProb.size(1);
final int numAnchors = (int) clsProb.size(2);
final int numBatches = (int) clsProb.size(0);

final float[] pAnchor = anchors.toFloatArray();

// [id, prob, xmin, ymin, xmax, ymax]
// TODO Move to NDArray-based implementation
NDList batchOutputs = new NDList();
for (int nbatch = 0; nbatch < numBatches; ++nbatch) {
float[][] outputs = new float[numAnchors][6];
final float[] pClsProb = clsProb.get(nbatch).toFloatArray();
final float[] pLocPred = locPred.get(nbatch).toFloatArray();

for (int i = 0; i < numAnchors; ++i) {
// find the predicted class id and probability
float score = -1;
int id = 0;
for (int j = 1; j < numClasses; ++j) {
float temp = pClsProb[j * numAnchors + i];
if (temp > score) {
score = temp;
id = j;
}
}

if (id > 0 && score < threshold) {
id = 0;
}

// [id, prob, xmin, ymin, xmax, ymax]
outputs[i][0] = id - 1;
outputs[i][1] = score;
int offset = i * 4;
float[] pAnchorRow4 = new float[4];
pAnchorRow4[0] = pAnchor[offset];
pAnchorRow4[1] = pAnchor[offset + 1];
pAnchorRow4[2] = pAnchor[offset + 2];
pAnchorRow4[3] = pAnchor[offset + 3];
float[] pLocPredRow4 = new float[4];
pLocPredRow4[0] = pLocPred[offset];
pLocPredRow4[1] = pLocPred[offset + 1];
pLocPredRow4[2] = pLocPred[offset + 2];
pLocPredRow4[3] = pLocPred[offset + 3];
float[] outRowLast4 =
transformLocations(
pAnchorRow4,
pLocPredRow4,
clip,
variances.toFloatArray()[0],
variances.toFloatArray()[1],
variances.toFloatArray()[2],
variances.toFloatArray()[3]);
outputs[i][2] = outRowLast4[0];
outputs[i][3] = outRowLast4[1];
outputs[i][4] = outRowLast4[2];
outputs[i][5] = outRowLast4[3];
}

outputs =
Arrays.stream(outputs)
.filter(o -> o[0] >= 0)
.sorted(Comparator.comparing(o -> -o[1]))
.toArray(float[][]::new);

// apply nms
for (int i = 0; i < outputs.length; ++i) {
for (int j = i + 1; j < outputs.length; ++j) {
if (outputs[i][0] == outputs[j][0]) {
float[] outputsIRow4 = new float[4];
float[] outputsJRow4 = new float[4];
outputsIRow4[0] = outputs[i][2];
outputsIRow4[1] = outputs[i][3];
outputsIRow4[2] = outputs[i][4];
outputsIRow4[3] = outputs[i][5];
outputsJRow4[0] = outputs[j][2];
outputsJRow4[1] = outputs[j][3];
outputsJRow4[2] = outputs[j][4];
outputsJRow4[3] = outputs[j][5];
float iou = calculateOverlap(outputsIRow4, outputsJRow4);
if (iou >= nmsThreshold) {
outputs[j][0] = -1;
}
}
}
}
batchOutputs.add(ndManager.create(outputs));
} // end iter batch

NDArray pOutNDArray = NDArrays.stack(batchOutputs);
NDList resultNDList = new NDList();
resultNDList.add(pOutNDArray);
assert (resultNDList.size() == 1);
return resultNDList;
}

/**
 * Decodes one predicted box offset against its anchor box.
 *
 * <p>The anchor is read as {@code [left, top, right, bottom]}. The prediction is read as {@code
 * [dx, dy, dw, dh]}: the centre offsets are scaled by {@code vx}/{@code vy} and by the anchor
 * size, the size offsets are scaled by {@code vw}/{@code vh} and passed through {@code exp}.
 *
 * @param anchors the four anchor corner coordinates
 * @param locPred the four predicted offsets for this anchor
 * @param clip true to clamp every output coordinate into {@code [0, 1]}
 * @param vx variance applied to the x centre offset
 * @param vy variance applied to the y centre offset
 * @param vw variance applied to the width offset
 * @param vh variance applied to the height offset
 * @return a new array {@code [xmin, ymin, xmax, ymax]}
 */
private float[] transformLocations(
        final float[] anchors,
        final float[] locPred,
        final boolean clip,
        final float vx,
        final float vy,
        final float vw,
        final float vh) {
    // anchor: corner form -> width/height and centre
    final float left = anchors[0];
    final float top = anchors[1];
    final float right = anchors[2];
    final float bottom = anchors[3];
    final float anchorWidth = right - left;
    final float anchorHeight = bottom - top;
    final float anchorCenterX = (left + right) / 2.f;
    final float anchorCenterY = (top + bottom) / 2.f;

    // predicted centre, shifted relative to the anchor centre
    final float centerX = locPred[0] * vx * anchorWidth + anchorCenterX;
    final float centerY = locPred[1] * vy * anchorHeight + anchorCenterY;

    // predicted half extents; exp is evaluated in double precision and narrowed afterwards
    final float halfWidth = (float) (Math.exp(locPred[2] * vw) * anchorWidth / 2);
    final float halfHeight = (float) (Math.exp(locPred[3] * vh) * anchorHeight / 2);

    final float[] box = {
        centerX - halfWidth, centerY - halfHeight, centerX + halfWidth, centerY + halfHeight
    };
    if (clip) {
        for (int k = 0; k < box.length; ++k) {
            box[k] = Math.max(0f, Math.min(1f, box[k]));
        }
    }
    return box;
}

/**
 * Computes the intersection-over-union of two boxes given as {@code [xmin, ymin, xmax, ymax]}.
 *
 * @param a the first box
 * @param b the second box
 * @return intersection area divided by union area, or 0 when the union area is not positive
 */
private float calculateOverlap(final float[] a, final float[] b) {
    // extent of the overlap along each axis, floored at zero for disjoint boxes
    final float overlapWidth = Math.max(0f, Math.min(a[2], b[2]) - Math.max(a[0], b[0]));
    final float overlapHeight = Math.max(0f, Math.min(a[3], b[3]) - Math.max(a[1], b[1]));
    final float intersection = overlapWidth * overlapHeight;

    final float union =
            (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - intersection;
    if (union <= 0.f) {
        // degenerate boxes: avoid dividing by a zero or negative area
        return 0f;
    }
    return intersection / union;
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ public class TrainPikachuTest {

@Test
public void testDetection() throws IOException, MalformedModelException, TranslateException {
TestRequirements.engine("MXNet");
TestRequirements.nightly();

String[] args;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import ai.djl.nn.LambdaBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
Expand Down Expand Up @@ -123,10 +124,8 @@ private TrainingConfig setupTrainingConfig() {
}

private ZooModel<Image, DetectedObjects> getModel() throws IOException, ModelException {
// SSD-pikachu model only available in MXNet
// TODO: Add PyTorch model to model zoo
TestUtils.requiresEngine("MXNet");

TestUtils.requiresEngine(
ModelZoo.getModelZoo("ai.djl.zoo").getSupportedEngines().toArray(String[]::new));
Criteria<Image, DetectedObjects> criteria =
Criteria.builder()
.optApplication(Application.CV.OBJECT_DETECTION)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ public String getGroupId() {
public Set<String> getSupportedEngines() {
    // Engines this model zoo currently reports as supported.
    Set<String> engines = new HashSet<>();
    engines.add("MXNet");
    engines.add("PyTorch");
    // TODO TensorFlow support in the basic model zoo is still work in progress
    // engines.add("TensorFlow");
    return engines;
}
Expand Down

0 comments on commit 963332d

Please sign in to comment.