diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java index a8a183de317c..451fc9676e70 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java @@ -123,7 +123,10 @@ public void zeroGradients() { NDManager systemManager = MxNDManager.getSystemManager(); for (NDArray array : systemManager.getManagedArrays()) { if (array.hasGradient()) { - array.getGradient().subi(array.getGradient()); + // To prevent memory leak we must close gradient after use. + try (NDArray gradient = array.getGradient()) { + gradient.subi(gradient); + } } } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java index d090e08decb6..c5995053f5b2 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java @@ -75,8 +75,9 @@ private void backward(NDArray target, NDArray grad, boolean keepGraph, boolean c public void zeroGradients() { NDManager systemManager = PtNDManager.getSystemManager(); for (NDArray array : systemManager.getManagedArrays()) { - if (array.hasGradient()) { - array.getGradient().subi(array.getGradient()); + // To prevent memory leak we must close gradient after use. + try (NDArray gradient = array.getGradient()) { + gradient.subi(gradient); } } }