fix mistakes and add more materials
Ingvarstep committed Oct 5, 2024
1 parent 485ad55 commit c6d1683
Showing 6 changed files with 2,746 additions and 2,196 deletions.
23 changes: 11 additions & 12 deletions README.md
@@ -1,13 +1,12 @@

-# ⭐GLiClass.c: Generalist and Lightweight Model for Sequence Classification in C
+# ⭐GLiClass.js: Generalist and Lightweight Model for Sequence Classification in JS

-GLiClass.c is a C-based inference engine for running GLiClass (Generalist and Lightweight Model for Sequence Classification) models. This is an efficient zero-shot classifier inspired by the [GLiNER](https://github.com/urchade/GLiNER) work. It demonstrates the same performance as a cross-encoder while being more compute-efficient, because classification is done in a single forward pass.
+GLiClass.js is a TypeScript-based inference engine for running GLiClass (Generalist and Lightweight Model for Sequence Classification) models. This is an efficient zero-shot classifier inspired by the [GLiNER](https://github.com/urchade/GLiNER) work. It demonstrates the same performance as a cross-encoder while being more compute-efficient, because classification is done in a single forward pass.

It can be used for topic classification, sentiment analysis and as a reranker in RAG pipelines.

<p align="center">
<img src="kg.png" style="position: relative; top: 5px;">
-<a href="https://www.knowledgator.com/"> Knowledgator</a>
+<a href="https://www.knowledgator.com/"> ☋ Knowledgator</a>
<span>&nbsp;&nbsp;•&nbsp;&nbsp;</span>
<a href="https://www.linkedin.com/company/knowledgator/">✔️ LinkedIn</a>
<span>&nbsp;&nbsp;•&nbsp;&nbsp;</span>
@@ -18,7 +17,7 @@ It can be used for topic classification, sentiment analysis and as a reranker in
<a href="https://huggingface.co/collections/knowledgator/gliclass-6661838823756265f2ac3848">🤗 GliClass Collection</a>
</p>

-## 🌟 Key Features
+## 💫 Key Features

- Flexible entity recognition without predefined categories
- Lightweight and fast inference
@@ -61,19 +60,19 @@ console.log(decoded);

#### ONNX settings API

-- modelPath: either a URL to a local model, as in the basic example, or the model itself as an array of binary data.
-- executionProvider: the same providers that ONNX web supports; currently `webgpu` (recommended), `cpu`, `wasm`, and `webgl` are allowed, but more can be added.
-- wasmPaths: path to the wasm binaries; this can be either a URL to the binaries, such as a CDN URL, or a local path to a folder with the binaries.
-- multiThread: whether to multithread at all; only relevant for the wasm and cpu execution providers.
-- multiThread: when choosing the wasm or cpu provider, multiThread lets you specify the number of cores you want to use.
-- fetchBinary: prefetches the binary from the default or provided wasm paths.
+- *modelPath*: either a URL to a local model, as in the basic example, or the model itself as an array of binary data.
+- *executionProvider*: the same providers that ONNX web supports; currently `webgpu` (recommended), `cpu`, `wasm`, and `webgl` are allowed, but more can be added.
+- *wasmPaths*: path to the wasm binaries; this can be either a URL to the binaries, such as a CDN URL, or a local path to a folder with the binaries.
+- *multiThread*: whether to multithread at all; only relevant for the wasm and cpu execution providers.
+- *multiThread*: when choosing the wasm or cpu provider, multiThread lets you specify the number of cores you want to use.
+- *fetchBinary*: prefetches the binary from the default or provided wasm paths.

## 🛠 Setup & Model Preparation

To use GLiClass models in a web environment, you need an ONNX format model. You can:

1. Search for pre-converted models on [HuggingFace](https://huggingface.co/onnx-community?search_models=gliclass)
-2. Convert a model yourself using the [official Python script](https://github.com/Knowledgator/GLiClass.c/blob/main/ONNX_CONVERTING/convert_to_onnx.py)
+2. Convert a model yourself using the Python script `convert_to_onnx.py`.
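
After a conversion run, it can help to check what ended up in the output directory. The sketch below assumes the file names used by `convert_to_onnx.py` in this commit (`model.onnx`, `model-int8-quantized.onnx`, `config.json`) and its default `onnx/` save path; the helper name `describe_export` is hypothetical and not part of the repository.

```python
import json
import os

# File names written by convert_to_onnx.py in this commit.
EXPECTED_FILES = ("model.onnx", "model-int8-quantized.onnx", "config.json")


def describe_export(save_dir):
    """Report which expected files exist in save_dir and load config.json if present.

    Hypothetical helper for illustration; not part of GLiClass.js.
    """
    present = {
        name: os.path.isfile(os.path.join(save_dir, name)) for name in EXPECTED_FILES
    }
    config = None
    if present["config.json"]:
        with open(os.path.join(save_dir, "config.json"), encoding="utf-8") as f:
            config = json.load(f)
    return {"files": present, "config": config}


if __name__ == "__main__":
    report = describe_export("onnx/")
    for name, exists in report["files"].items():
        print(f"{name}: {'found' if exists else 'missing'}")
    if report["config"] is not None:
        # Keys written by create_config() in convert_to_onnx.py.
        print("architecture_type:", report["config"].get("architecture_type"))
        print("prompt_first:", report["config"].get("prompt_first"))
```

The quantized file is only written when quantization is enabled, so a missing `model-int8-quantized.onnx` is not necessarily an error.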

### Converting to ONNX Format

100 changes: 100 additions & 0 deletions convert_to_onnx.py
@@ -0,0 +1,100 @@
import os
import argparse
import json

from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer

import torch
from onnxruntime.quantization import quantize_dynamic, QuantType


def str2bool(value) -> bool:
    # argparse's plain `type=bool` treats any non-empty string (even "False") as True,
    # so boolean flags are parsed explicitly here.
    if isinstance(value, bool):
        return value
    if value.lower() in ("true", "1", "yes", "y"):
        return True
    if value.lower() in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def get_original_logits(model, tokenized_inputs) -> list:
    # Run the PyTorch model once and keep its logits as a reference for the ONNX export.
    with torch.no_grad():
        model_output = model(**tokenized_inputs)
        logits = model_output.logits
        logits = logits.round(decimals=5)

    return logits.tolist()


def create_config(original_model_name, architecture_type, prompt_first, original_logits, save_path) -> None:
    # Store model metadata and the reference logits next to the exported model.
    data = {
        "original_model_name": original_model_name,
        "architecture_type": architecture_type,
        "prompt_first": prompt_first,
        "original_logits": original_logits
    }

    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default="knowledgator/gliclass-base-v1.0")
    parser.add_argument('--save_path', type=str, default='onnx/')
    parser.add_argument('--quantize', type=str2bool, default=True)
    parser.add_argument('--classification_type', type=str, default="multi-label")

    args = parser.parse_args()

    if args.classification_type not in ['single-label', "multi-label"]:
        raise NotImplementedError("This type is not supported yet")

    os.makedirs(args.save_path, exist_ok=True)

    onnx_save_path = os.path.join(args.save_path, "model.onnx")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print("Loading a model...")
    gliclass_model = GLiClassModel.from_pretrained(args.model_path)
    architecture_type = gliclass_model.config.architecture_type
    prompt_first = gliclass_model.config.prompt_first
    if architecture_type != 'uni-encoder':
        raise NotImplementedError("This architecture is not implemented for ONNX yet")

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    pipeline = ZeroShotClassificationPipeline(gliclass_model, tokenizer, classification_type=args.classification_type, device=device)

    # Sample text and labels used to trace the model during export.
    text = "ONNX is an open-source format designed to enable the interoperability of AI models across various frameworks and tools."
    labels = ['format', 'model', 'tool', 'cat']

    tokenized_inputs = pipeline.pipe.prepare_inputs(text, labels)

    all_inputs = (tokenized_inputs['input_ids'], tokenized_inputs['attention_mask'])
    input_names = ['input_ids', 'attention_mask']
    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
        "logits": {0: "position", 1: "batch_size"}
    }

    print("Converting...")
    torch.onnx.export(
        gliclass_model,             # Model
        all_inputs,                 # Inputs for export
        onnx_save_path,             # Output file name
        input_names=input_names,    # Input names
        output_names=["logits"],    # Output names
        dynamic_axes=dynamic_axes,  # Dynamic axes
        opset_version=14
    )

    if args.quantize:
        quantized_save_path = os.path.join(args.save_path, "model-int8-quantized.onnx")
        # Quantize the ONNX model
        print("Quantizing the model...")
        quantize_dynamic(
            onnx_save_path,                 # Input model
            quantized_save_path,            # Output model
            weight_type=QuantType.QUInt8    # Quantize weights to 8-bit unsigned integers
        )

    print("Creating configuration file...")
    # os.path.join also works when --save_path has no trailing slash.
    config_path = os.path.join(args.save_path, "config.json")
    create_config(
        original_model_name=args.model_path,
        architecture_type=architecture_type,
        prompt_first=prompt_first,
        original_logits=get_original_logits(gliclass_model, tokenized_inputs),
        save_path=config_path
    )

    print("Done")
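
The script stores `original_logits` in `config.json`, which suggests they are meant as a reference for checking the exported model. A minimal sketch of such a check follows, assuming the input names (`input_ids`, `attention_mask`) and output name (`logits`) used in the export call above. The helper names and the `1e-3` tolerance are assumptions for illustration, not part of the repository, and a quantized model may need a looser tolerance.

```python
def max_abs_diff(a, b):
    """Largest absolute element-wise difference between two equally shaped nested lists."""
    if isinstance(a, (list, tuple)):
        if not isinstance(b, (list, tuple)) or len(a) != len(b):
            raise ValueError("shape mismatch between logits")
        return max((max_abs_diff(x, y) for x, y in zip(a, b)), default=0.0)
    return abs(float(a) - float(b))


def logits_match(onnx_logits, original_logits, atol=1e-3):
    """True if every ONNX logit is within atol of the reference (atol is an assumed tolerance)."""
    return max_abs_diff(onnx_logits, original_logits) <= atol


def run_onnx_logits(model_path, input_ids, attention_mask):
    """Run the exported model on CPU and return its logits as nested lists.

    Imports are local so the comparison helpers above work without onnxruntime installed.
    """
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    feeds = {
        "input_ids": np.asarray(input_ids, dtype=np.int64),
        "attention_mask": np.asarray(attention_mask, dtype=np.int64),
    }
    (logits,) = session.run(["logits"], feeds)
    return logits.tolist()
```

To use it, feed `run_onnx_logits` the same tokenized sample that the conversion script used, then pass the result and the `original_logits` value from `config.json` to `logits_match`.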