
Advice needed: text summarization via a pre-trained model on a local computer #587

Open
JeeDevUser opened this issue Nov 24, 2024 · 7 comments

@JeeDevUser

Hello everyone, I intend to do text summarization using some (any) pretrained model, which I would load from my local computer.

If something like this is possible, I am interested in:

  • which model is suitable for this,
  • where I can download it from,
  • how to load it and perform the summarization.

Any help/online resource/... is welcome

@Craigacp
Collaborator

There aren't many pretrained TF models for summarization on Hugging Face, and I think those models tend to be saved in Keras format anyway, so they will need converting into TF SavedModel format before they can be used with TF Java. I'd probably start with something like this: https://huggingface.co/google/flan-t5-large. You could also use a decoder-only model (i.e. an LLM), though there aren't many of those in TF or Keras h5 format either (maybe Google's Gemma model?). You'll need to load the tokenizer (which for T5 is SentencePiece), tokenize the inputs, pass them into TF-Java to get the predicted output token, then loop it back around and keep predicting until you hit a termination condition (like </s> or a fixed token count).
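
Schematically, the generation loop you would write on top of TF-Java looks something like this (a sketch only; the Tokenizer and Model interfaces are hypothetical placeholders, not TF-Java or DJL types, and the EOS id is model-specific):

    import java.util.ArrayList;
    import java.util.List;

    // Sketch of the generation loop described above.
    public class GenerationLoopSketch {

        interface Tokenizer {                 // hypothetical stand-in
            int[] encode(String text);
            String decode(int[] ids);
        }

        interface Model {                     // hypothetical stand-in
            // id of the most likely next token given the sequence so far
            int predictNextToken(int[] tokenIds);
        }

        static final int EOS_ID = 1;          // assumption: model-specific
        static final int MAX_NEW_TOKENS = 50; // fixed-token-count fallback

        static String generate(Model model, Tokenizer tokenizer, String prompt) {
            List<Integer> tokens = new ArrayList<>();
            for (int id : tokenizer.encode(prompt)) {
                tokens.add(id);
            }
            for (int step = 0; step < MAX_NEW_TOKENS; step++) {
                int[] current = tokens.stream().mapToInt(Integer::intValue).toArray();
                int next = model.predictNextToken(current); // one session run
                if (next == EOS_ID) {
                    break;                                  // termination condition
                }
                tokens.add(next);
            }
            return tokenizer.decode(tokens.stream().mapToInt(Integer::intValue).toArray());
        }
    }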

To be honest, it might be simpler to use jlama or llama3.java with a pre-trained Llama 3 checkpoint. Those models are fairly good at summarization, and the libraries already have the tokenization and the token generation loop sorted, which you'd otherwise need to implement on top of TF-Java. I don't think either of them supports GPUs, so it depends how big your workload is in terms of batch size and latency.

@JeeDevUser
Author

@Craigacp, thanks for the advice!

If I understand correctly, I would have to use the SentencePiece tokenizer to tokenize the inputs (I would try the pretrained pegasus-xsum model instead, but it uses the same tokenizer, so it doesn't matter).

At the moment, I have generated the SavedModel format from the pegasus-xsum model (I got the .pb file, the /variables directory, etc.).

Now I wonder: does the TensorFlow Java API support SentencePiece somehow, or not?
Maybe the question doesn't make sense (I'm still new to this), but mainly: is there a way to use this kind of tokenizer for tokenization? I want to avoid any runtime dependency on Python, so I'm only interested in a pure Java implementation.

@JeeDevUser
Author

JeeDevUser commented Nov 25, 2024

...and, for example, DeepJavaLibrary (DJL) has a SentencePiece implementation.
I wonder, could I use it to tokenize the input before passing it to TF Java?

@Craigacp
Collaborator

You can put the sentencepiece op from tensorflow-text into a TF graph, but that is a bit of a pain to do from Java and will require you to understand how TF works on a deeper level. Otherwise DJL's wrapper should be fine, and there are others available. You'll still need to write the generative loop yourself, passing in the input tokens, the previously generated tokens and the key-value cache, then building a sampling mechanism.
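
Encoding and decoding through DJL's wrapper is only a few lines; a minimal sketch, assuming spiece.model is the SentencePiece file downloaded from the pegasus-xsum repo on Hugging Face:

    import java.io.IOException;
    import java.nio.file.Paths;

    import ai.djl.sentencepiece.SpProcessor;
    import ai.djl.sentencepiece.SpTokenizer;

    // Sketch: tokenize with DJL's SentencePiece wrapper before calling TF-Java.
    // The path below is an assumption about where spiece.model was saved.
    public class SentencePieceSketch {
        public static void main(String[] args) throws IOException {
            try (SpTokenizer tokenizer = new SpTokenizer(
                    Paths.get("d:/Install/TensorFlow/models/pegasus_xsum/spiece.model"))) {
                SpProcessor processor = tokenizer.getProcessor();
                int[] ids = processor.encode("This is an example text for summarization.");
                // ... feed ids into the TF session, collect generated ids, then:
                String roundTrip = processor.decode(ids);
                System.out.println(roundTrip);
            }
        }
    }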

@JeeDevUser
Author

Thanks @Craigacp, I made some progress by using the DJL SentencePiece implementation together with the TF SavedModel I created. I successfully tokenized the input text, but I got stuck with what you said:

You'll still need to write the generative loop yourself, passing in the input tokens, the previously generated tokens and the key-value cache, then building a sampling mechanism.

Here is the relevant part of the code:
         // 1. tokenize the input text (DJL SentencePiece):
         String text = "This is an example text for summarization.";
         int[] intTokens = spProcessor.encode(text);
         System.out.println("Tokens: " + Arrays.toString(intTokens));

         // convert the tokens into a TInt32 tensor:
         IntNdArray intData = NdArrays.ofInts(Shape.of(1, intTokens.length));
         for (int i = 0; i < intTokens.length; i++) {
            intData.setInt(intTokens[i], 0, i);
         }
         // so far, so good... now:
         try (TInt32 loopInputTensor = Tensor.of(TInt32.class, Shape.of(1, intTokens.length), intData::copyTo)) {
            // 2. load the TF SavedModel:
            String modelPath = "d:/Install/TensorFlow/models/pegasus_xsum/saved_model";
            SavedModelBundle model = SavedModelBundle.load(modelPath, "serve");
            System.out.println("The SavedModel loaded successfully!");

            // 3. get the names of the input and output tensors:
            SignatureDef signatureDef = model.metaGraphDef().getSignatureDefOrThrow("serving_default");
            String inputTensorName = signatureDef.getInputsOrThrow("input_ids").getName();
            String outputTensorName = signatureDef.getOutputsOrThrow("logits").getName();

            // 4. run the inference:
            try (Tensor outputTensor = model.session().runner()
                     .feed(inputTensorName, loopInputTensor)
                     .fetch(outputTensorName)
                     .run().get(0)) {
               // ---- I'M STUCK HERE, I DON'T KNOW WHAT TO DO... any help?!?
            }
         }
         

@Craigacp
Collaborator

The output of the model will be a probability distribution over tokens. The simplest thing is to do "greedy decoding", where you pick the most likely token (i.e. the one with the highest probability), then you append that token id to your input tokens and run inference again.
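
If greedy output ever gets too repetitive, the sampling mechanism I mentioned earlier just replaces that argmax with a random draw from the distribution. A minimal temperature-sampling sketch over a plain float[] of logits, not tied to TF-Java or any particular model:

    import java.util.Random;

    // Sketch of a temperature-sampling alternative to greedy argmax decoding.
    // logits is the float[] for the last position; temperature = 1.0 keeps the
    // raw distribution, values below 1.0 push the choice closer to argmax.
    public class SamplingSketch {

        static int sampleToken(float[] logits, double temperature, Random rng) {
            // softmax with temperature; subtract the max for numerical stability
            double max = Double.NEGATIVE_INFINITY;
            for (float l : logits) {
                max = Math.max(max, l);
            }
            double[] weights = new double[logits.length];
            double sum = 0.0;
            for (int i = 0; i < logits.length; i++) {
                weights[i] = Math.exp((logits[i] - max) / temperature);
                sum += weights[i];
            }
            // draw a token id from the normalized distribution
            double r = rng.nextDouble() * sum;
            double cumulative = 0.0;
            for (int i = 0; i < weights.length; i++) {
                cumulative += weights[i];
                if (r <= cumulative) {
                    return i;
                }
            }
            return weights.length - 1; // guard against rounding error
        }
    }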

@JeeDevUser
Author

JeeDevUser commented Nov 26, 2024

ok, here is what I am doing:

         try (TInt32 loopInputTensor = Tensor.of(TInt32.class, Shape.of(1, intTokens.length), intData::copyTo)) {
            // 2. load the TensorFlow SavedModel:
            String modelPath = "d:/Install/TensorFlow/models/pegasus_xsum/saved_model";
            SavedModelBundle model = SavedModelBundle.load(modelPath, "serve");
            System.out.println("The SavedModel loaded successfully!");

            // 3. get the names of the input and output tensors:
            SignatureDef signatureDef = model.metaGraphDef().getSignatureDefOrThrow("serving_default");
            String inputTensorName = signatureDef.getInputsOrThrow("input_ids").getName();
            String outputTensorName = signatureDef.getOutputsOrThrow("logits").getName();


            // 4. start the inference:
            boolean stopCondition = false;
            int maxIterations = 50; // max number of iterations
            int eosTokenId = 1; // ID of the EOS token (?)
            int[] currentInput = intTokens;

            while (!stopCondition && (maxIterations-- > 0)) {
               // create the tensor for the current input:
               IntNdArray currentData = NdArrays.ofInts(Shape.of(1, currentInput.length));
               for (int i = 0; i < currentInput.length; i++) {
                  currentData.setInt(currentInput[i], 0, i);
               }

               try (TInt32 currentInputTensor = Tensor.of(TInt32.class, Shape.of(1, currentInput.length), currentData::copyTo)) {
                  try (Tensor outputTensor = model.session().runner()
                           .feed(inputTensorName, currentInputTensor)
                           .fetch(outputTensorName)
                           .run().get(0)) {

                     // get the output logits from the model:
                     TFloat32 logits = (TFloat32) outputTensor;

                     // greedy decoding: argmax over the vocabulary at the last position:
                     long vocabSize = logits.shape().size(2); // vocabulary size
                     float maxLogit = Float.NEGATIVE_INFINITY;
                     int predictedToken = -1;

                     for (int i = 0; i < vocabSize; i++) {
                        float logit = logits.getFloat(0, currentInput.length - 1, i);
                        if (logit > maxLogit) {
                           maxLogit = logit;
                           predictedToken = i;
                        }
                     }

                     // Have we reached the EOS token?
                     if (predictedToken == eosTokenId) {
                        stopCondition = true;
                     } else {
                        // append the predicted token to the input:
                        currentInput = Arrays.copyOf(currentInput, currentInput.length + 1);
                        currentInput[currentInput.length - 1] = predictedToken;
                     }
                  }
               }
            }

            // decoding the final output:
            String summarizedText = spProcessor.decode(currentInput);
            System.out.println("Summarized Text: " + summarizedText);
         }

Question 1: is this a good way, or is something missing?
Question 2: at this statement:

                  try (Tensor outputTensor = model.session().runner()
                           .feed(inputTensorName, currentInputTensor)
                           .fetch(outputTensorName)
                           .run().get(0))

I am getting:

org.tensorflow.exceptions.TFInvalidArgumentException: You must feed a value for placeholder tensor 'serving_default_attention_mask' with dtype int32 and shape [?,?]
Any hint on how to do that?
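
(A possible direction, going by the error message itself: the signature seems to have a second int32 input named attention_mask with shape [?, ?]. Below is a sketch of building an all-ones mask, since there is no padding here so every position is a real token, and feeding it alongside input_ids; it reuses the variables from the loop above, so it is a fragment, not a complete program:)

    // hypothetical sketch: look up the mask input by the name from the error,
    // build a mask of ones matching the current input length, and feed it too
    String maskTensorName = signatureDef.getInputsOrThrow("attention_mask").getName();

    IntNdArray maskData = NdArrays.ofInts(Shape.of(1, currentInput.length));
    for (int i = 0; i < currentInput.length; i++) {
       maskData.setInt(1, 0, i); // 1 = attend to this position
    }

    try (TInt32 maskTensor = Tensor.of(TInt32.class, Shape.of(1, currentInput.length), maskData::copyTo);
         TInt32 currentInputTensor = Tensor.of(TInt32.class, Shape.of(1, currentInput.length), currentData::copyTo);
         Tensor outputTensor = model.session().runner()
                  .feed(inputTensorName, currentInputTensor)
                  .feed(maskTensorName, maskTensor)
                  .fetch(outputTensorName)
                  .run().get(0)) {
       // ... same greedy decoding as before
    }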
