diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java index 0eaed72a709..8d32ccb4b6c 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java @@ -30,6 +30,7 @@ import java.io.InputStream; import java.nio.file.Files; import java.nio.file.Path; +import java.nio.file.Paths; import java.util.Arrays; import java.util.List; import java.util.Locale; @@ -686,7 +687,6 @@ static PaddingStrategy fromValue(String value) { /** The builder for creating huggingface tokenizer. */ public static final class Builder { - private Path tokenizerPath; private NDManager manager; private Map options; @@ -724,7 +724,7 @@ public Builder optTokenizerName(String tokenizerName) { * @return this builder */ public Builder optTokenizerPath(Path tokenizerPath) { - this.tokenizerPath = tokenizerPath; + options.putIfAbsent("tokenizerPath", tokenizerPath.toString()); return this; } @@ -894,9 +894,11 @@ public HuggingFaceTokenizer build() throws IOException { if (tokenizerName != null) { return managed(HuggingFaceTokenizer.newInstance(tokenizerName, options)); } - if (tokenizerPath == null) { + String path = options.get("tokenizerPath"); + if (path == null) { throw new IllegalArgumentException("Missing tokenizer path."); } + Path tokenizerPath = Paths.get(path); if (Files.isDirectory(tokenizerPath)) { Path tokenizerFile = tokenizerPath.resolve("tokenizer.json"); if (Files.exists(tokenizerFile)) { diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java index ef4015d94d3..2a98f63db65 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java @@ -64,6 +64,7 @@ public void testCrossEncoderTranslator() .optBlock(block) .optEngine("PyTorch") .optArgument("tokenizer", "bert-base-cased") + .optArgument("tokenizerPath", modelDir) .optOption("hasParameter", "false") .optTranslatorFactory(new CrossEncoderTranslatorFactory()) .build();