Memory Leak when calling ConcreteFunction #477

Open
lucaro opened this issue Nov 20, 2022 · 6 comments

lucaro commented Nov 20, 2022

System information

  • OS Platform and Distribution: Windows 10
  • TensorFlow version: 0.4.2

I have an application where I need to write a large amount of data as TFRecords to be used externally. To encode these records properly, I need to invoke tf.io.serializeTensor inside a ConcreteFunction. When running this repeatedly, memory usage increases until the program eventually runs out of memory and crashes. When inspecting the process with VisualVM, I can see that the memory does not build up in the JVM, so it must be some sort of memory leak when calling native code. I added a minimal example below. Did I improperly close something or is this indeed a bug?

Code to reproduce the issue

import java.util.Random;
import org.tensorflow.ConcreteFunction;
import org.tensorflow.Signature;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffer;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.io.SerializeTensor;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TString;

public class MemoryLeakTest {

  public static Signature serializeTensor(Ops tf) {
    Placeholder<TFloat32> input = tf.placeholder(TFloat32.class);
    SerializeTensor output = tf.io.serializeTensor(input);
    return Signature.builder().input("tensor", input).output("out", output).build();
  }

  public static void main(String[] args) {

    ConcreteFunction function = ConcreteFunction.create(MemoryLeakTest::serializeTensor);

    Random random = new Random(0);

    //run same operation for many tensors
    for (int loop = 0; loop < 10000; loop++) {

      //generate some tensor
      float[] arr = new float[512 * 512 * 512];
      for (int i = 0; i < arr.length; i++) {
        arr[i] = random.nextFloat();
      }
      FloatDataBuffer buf = DataBuffers.of(arr);
      TFloat32 inputTensor = TFloat32.tensorOf(Shape.of(512, 512, 512), buf);

      //serialize tensor
      TString outputTensor = (TString) function.call(inputTensor);
      DataBuffer<byte[]> outputBuffer = DataBuffers.ofObjects(byte[].class, 1);
      outputTensor.asBytes().read(outputBuffer);
      byte[] serialized = outputBuffer.getObject(0);

      System.out.println("Generated serialized tensor with " + serialized.length + " bytes");

      //close tensors
      inputTensor.close();
      outputTensor.close();
    }

    //close function in the end
    function.close();
  }

}
Craigacp (Collaborator) commented Nov 22, 2022

I think the example looks ok. Can you take a heap dump with VisualVM and see what Java objects it's allocating?

lucaro (Author) commented Nov 22, 2022

VisualVM shows no exceptionally high memory usage, remaining stable over time:
[screenshot: VisualVM memory monitor]

The largest memory use within the JVM is generated by the expected arrays:
[screenshot: VisualVM heap histogram]

At the time I took this snapshot, Windows reported over 60GB in use by that process:
[screenshot: Windows Task Manager showing the process memory usage]

On my machine, this example leaks roughly 10GB of memory per minute.

karllessard (Collaborator)

Just out of curiosity @lucaro, what happens if you don't copy the tensor values to a byte array but simply access the string in the tensor like this?

TString outputTensor = (TString) function.call(inputTensor);
String serialized = outputTensor.getObject();

System.out.println("Generated serialized tensor with " + serialized.length() + " bytes");

I just want to narrow down the possible source of the leak.

karllessard (Collaborator)

Also, if it helps unblock you until we find the problem, you can build a TensorProto directly instead of invoking tf.io.serializeTensor.

lucaro (Author) commented Nov 24, 2022

Getting the String rather than the byte[] doesn't change anything with respect to the leakage.

Thanks for pointing me to the TensorProto; that was what I was looking for initially. However, I did some tests, and what I get from serializing the TensorProto is not exactly the same as what I get by calling tf.io.serializeTensor. The tensor serialized via TensorProto produces an array that is 2 bytes shorter, specifically missing the leading two bytes. I extended the example below. I haven't yet been able to check whether the byte array generated via the TensorProto can also be read in, e.g., a Python-based environment, or whether these two bytes are actually important.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import org.tensorflow.ConcreteFunction;
import org.tensorflow.Signature;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffer;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.io.SerializeTensor;
import org.tensorflow.proto.framework.TensorProto;
import org.tensorflow.proto.framework.TensorShapeProto;
import org.tensorflow.proto.framework.TensorShapeProto.Dim;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TString;

public class MemoryLeakTest {

  public static Signature serializeTensor(Ops tf) {
    Placeholder<TFloat32> input = tf.placeholder(TFloat32.class);
    SerializeTensor output = tf.io.serializeTensor(input);
    return Signature.builder().input("tensor", input).output("out", output).build();
  }

  public static void main(String[] args) {

    ConcreteFunction function = ConcreteFunction.create(MemoryLeakTest::serializeTensor);

    Random random = new Random(0);

    //run same operation for many tensors
    for (int loop = 0; loop < 10000; loop++) {

      //generate some tensor
      float[] arr = new float[512 * 512 * 512];
      ArrayList<Float> farr = new ArrayList<>(512 * 512 * 512);
      for (int i = 0; i < arr.length; i++) {
        arr[i] = random.nextFloat();
        farr.add(arr[i]);
      }
      FloatDataBuffer buf = DataBuffers.of(arr);
      TFloat32 inputTensor = TFloat32.tensorOf(Shape.of(512, 512, 512), buf);
      
      //serialize tensor via native tensorflow
      TString outputTensor = (TString) function.call(inputTensor);
      DataBuffer<byte[]> outputBuffer = DataBuffers.ofObjects(byte[].class, 1);
      outputTensor.asBytes().read(outputBuffer);
      byte[] serialized = outputBuffer.getObject(0);

      //serialize tensor via TensorProto
      TensorProto proto = TensorProto.newBuilder()
          .setTensorShape(TensorShapeProto.newBuilder()
              .addDim(Dim.newBuilder().setSize(512).build())
              .addDim(Dim.newBuilder().setSize(512).build())
              .addDim(Dim.newBuilder().setSize(512).build())
              .build())
          .addAllFloatVal(farr)
          .build();
      
      System.out.println("Generated serialized tensor 1 with " + serialized.length + " bytes");

      byte[] serialized2 = proto.toByteArray();

      System.out.println("Generated serialized tensor 2 with " + serialized2.length + " bytes");

      if (Arrays.equals(serialized, serialized2)) {
        System.out.println("Arrays are the same");
      } else {
        System.out.println("Arrays are not the same");
      }

      //close tensors
      inputTensor.close();
      outputTensor.close();

    }

    //close function in the end
    function.close();

  }

}
Output:

Generated serialized tensor 1 with 536870937 bytes
Generated serialized tensor 2 with 536870935 bytes
Arrays are not the same


lucaro (Author) commented Nov 24, 2022

So the two missing bytes came from me forgetting to set the data type; after adding .setDtype(DataType.DT_FLOAT), the two arrays have the same length. They are still not identical, but the TensorProto-serialized bytes can be read again, so my use case is satisfied.
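
For reference, a minimal sketch of the corrected proto construction (assuming the same org.tensorflow.proto.framework classes and the 512 x 512 x 512 float shape used above; the class and method names here are only illustrative):

import java.util.List;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.proto.framework.TensorProto;
import org.tensorflow.proto.framework.TensorShapeProto;
import org.tensorflow.proto.framework.TensorShapeProto.Dim;

public class ProtoSerializeSketch {

  //builds a float TensorProto of shape [512, 512, 512]; setDtype is the call that was missing above
  static byte[] serializeViaProto(List<Float> values) {
    TensorProto proto = TensorProto.newBuilder()
        .setDtype(DataType.DT_FLOAT)
        .setTensorShape(TensorShapeProto.newBuilder()
            .addDim(Dim.newBuilder().setSize(512).build())
            .addDim(Dim.newBuilder().setSize(512).build())
            .addDim(Dim.newBuilder().setSize(512).build())
            .build())
        .addAllFloatVal(values)
        .build();
    return proto.toByteArray();
  }
}

The only change from the extended example above is the setDtype call on the builder.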
