-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial track_pybuda; 2313 rc * fix typo in ViLT tests * Modify test case paths * Modify clean up command to include onnx and tflite file formats * Fix ONNX download paths for ResNet and RetinaNet * Add NotImplemented error to Fuyu-8B model * Fix ONNX model paths * Add clean up for .h5 files * Remove .png files from clean up * Add wideresnet in model_demos * Add Xception in model_demos * Add GhostNet in model_demos * Fix model demos table * Fix WideResNet and Xception file paths * Stream image and label files * Patch Xception variant for GS silicon * Skip Fuyu-8B (WIP) * Remove commented code
- Loading branch information
Showing
69 changed files
with
2,921 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57 changes: 57 additions & 0 deletions
57
model_demos/cv_demos/beit/pytorch_beit_classify_16_224_hf.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# BeiT Model Demo | ||
|
||
import os | ||
|
||
import pybuda | ||
import requests | ||
from PIL import Image | ||
from pybuda._C.backend_api import BackendDevice | ||
from transformers import BeitForImageClassification, BeitImageProcessor | ||
|
||
|
||
def run_beit_classify_224_hf_pytorch(variant="microsoft/beit-base-patch16-224"): | ||
|
||
# Set PyBuda configuration parameters | ||
compiler_cfg = pybuda.config._get_global_compiler_config() | ||
available_devices = pybuda.detect_available_devices() | ||
|
||
compiler_cfg.enable_t_streaming = True | ||
if variant == "microsoft/beit-base-patch16-224": | ||
compiler_cfg.retain_tvm_python_files = True | ||
compiler_cfg.enable_tvm_constant_prop = True | ||
if available_devices[0] == BackendDevice.Grayskull: | ||
os.environ["PYBUDA_ENABLE_STABLE_SOFTMAX"] = "1" | ||
elif variant == "microsoft/beit-large-patch16-224": | ||
if available_devices[0] == BackendDevice.Grayskull: | ||
compiler_cfg.retain_tvm_python_files = True | ||
compiler_cfg.enable_tvm_constant_prop = True | ||
os.environ["PYBUDA_ENABLE_STABLE_SOFTMAX"] = "1" | ||
else: | ||
compiler_cfg.default_df_override = pybuda._C.DataFormat.Float16_b | ||
|
||
# Create PyBuda module from PyTorch model | ||
image_processor = BeitImageProcessor.from_pretrained(variant) | ||
model = BeitForImageClassification.from_pretrained(variant) | ||
tt_model = pybuda.PyTorchModule("pt_beit_classif_16_224", model) | ||
|
||
# Get sample image | ||
url = "http://images.cocodataset.org/val2017/000000039769.jpg" | ||
sample_image = Image.open(requests.get(url, stream=True).raw) | ||
|
||
# Preprocessing | ||
img_tensor = image_processor(sample_image, return_tensors="pt").pixel_values | ||
|
||
# Run inference on Tenstorrent device | ||
output_q = pybuda.run_inference(tt_model, inputs=([img_tensor])) | ||
output = output_q.get()[0].value().detach().float().numpy() | ||
|
||
# Postprocessing | ||
predicted_class_idx = output.argmax(-1).item() | ||
|
||
# Print output | ||
print("Predicted class:", predicted_class_idx) | ||
print(model.config.id2label[predicted_class_idx]) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_beit_classify_224_hf_pytorch() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
63 changes: 63 additions & 0 deletions
63
model_demos/cv_demos/efficientnet_lite/tflite_efficientnet_lite0_1x1.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# EfficientNet-Lite0 1x1 demo | ||
|
||
import os | ||
import shutil | ||
import tarfile | ||
|
||
import pybuda | ||
import requests | ||
import torch | ||
from pybuda import TFLiteModule | ||
from pybuda._C.backend_api import BackendDevice | ||
|
||
|
||
def run_efficientnet_lite0_1x1(): | ||
|
||
# Device specific configurations | ||
available_devices = pybuda.detect_available_devices() | ||
if available_devices: | ||
if available_devices[0] != BackendDevice.Wormhole_B0: | ||
raise NotImplementedError("Model not supported on Grayskull") | ||
|
||
# Set PyBuda configuration parameters | ||
compiler_cfg = pybuda.config._get_global_compiler_config() | ||
compiler_cfg.balancer_policy = "Ribbon" | ||
compiler_cfg.enable_t_streaming = True | ||
compiler_cfg.enable_tvm_constant_prop = True | ||
compiler_cfg.graph_solver_self_cut_type = "FastCut" | ||
compiler_cfg.default_df_override = pybuda.DataFormat.Float16 | ||
|
||
# Set PyBDUA environment variable | ||
os.environ["PYBUDA_OVERRIDE_DEVICE_YAML"] = "wormhole_b0_1x1.yaml" | ||
os.environ["PYBUDA_FORCE_CONV_MULTI_OP_FRACTURE"] = "1" | ||
|
||
# Download model weights | ||
MODEL = "efficientnet-lite0" | ||
url = f"https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/lite/{MODEL}.tar.gz" | ||
extract_to = "cv_demos/efficientnet_lite" | ||
file_name = url.split("/")[-1] | ||
response = requests.get(url, stream=True) | ||
with open(file_name, "wb") as f: | ||
f.write(response.content) | ||
with tarfile.open(file_name, "r:gz") as tar: | ||
tar.extractall(path=extract_to) | ||
os.remove(file_name) | ||
|
||
# Load model path | ||
tflite_path = f"cv_demos/efficientnet_lite/{MODEL}/{MODEL}-fp32.tflite" | ||
tt_model = TFLiteModule("tflite_efficientnet_lite0", tflite_path) | ||
|
||
# Run inference on Tenstorrent device | ||
input_shape = (1, 224, 224, 3) | ||
input_tensor = torch.rand(input_shape) | ||
|
||
output_q = pybuda.run_inference(tt_model, inputs=([input_tensor])) | ||
output = output_q.get()[0].value().detach().float().numpy() | ||
print(output) | ||
|
||
# Remove remanent files | ||
shutil.rmtree(extract_to + "/" + MODEL) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_efficientnet_lite0_1x1() |
63 changes: 63 additions & 0 deletions
63
model_demos/cv_demos/efficientnet_lite/tflite_efficientnet_lite4_1x1.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# EfficientNet-Lite4 1x1 demo | ||
|
||
import os | ||
import shutil | ||
import tarfile | ||
|
||
import pybuda | ||
import requests | ||
import torch | ||
from pybuda import TFLiteModule | ||
from pybuda._C.backend_api import BackendDevice | ||
|
||
|
||
def run_efficientnet_lite4_1x1(): | ||
|
||
# Device specific configurations | ||
available_devices = pybuda.detect_available_devices() | ||
if available_devices: | ||
if available_devices[0] != BackendDevice.Wormhole_B0: | ||
raise NotImplementedError("Model not supported on Grayskull") | ||
|
||
# Set PyBuda configuration parameters | ||
compiler_cfg = pybuda.config._get_global_compiler_config() | ||
compiler_cfg.balancer_policy = "Ribbon" | ||
compiler_cfg.enable_t_streaming = True | ||
compiler_cfg.enable_tvm_constant_prop = True | ||
compiler_cfg.graph_solver_self_cut_type = "FastCut" | ||
compiler_cfg.default_df_override = pybuda.DataFormat.Float16 | ||
|
||
# Set PyBDUA environment variable | ||
os.environ["PYBUDA_OVERRIDE_DEVICE_YAML"] = "wormhole_b0_1x1.yaml" | ||
os.environ["PYBUDA_FORCE_CONV_MULTI_OP_FRACTURE"] = "1" | ||
|
||
# Download model weights | ||
MODEL = "efficientnet-lite4" | ||
url = f"https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/lite/{MODEL}.tar.gz" | ||
extract_to = "cv_demos/efficientnet_lite" | ||
file_name = url.split("/")[-1] | ||
response = requests.get(url, stream=True) | ||
with open(file_name, "wb") as f: | ||
f.write(response.content) | ||
with tarfile.open(file_name, "r:gz") as tar: | ||
tar.extractall(path=extract_to) | ||
os.remove(file_name) | ||
|
||
# Load model path | ||
tflite_path = f"cv_demos/efficientnet_lite/{MODEL}/{MODEL}-fp32.tflite" | ||
tt_model = TFLiteModule("tflite_efficientnet_lite4", tflite_path) | ||
|
||
# STEP 3: Run inference on Tenstorrent device | ||
input_shape = (1, 320, 320, 3) | ||
input_tensor = torch.rand(input_shape) | ||
|
||
output_q = pybuda.run_inference(tt_model, inputs=([input_tensor])) | ||
output = output_q.get()[0].value().detach().float().numpy() | ||
print(output) | ||
|
||
# Remove remanent files | ||
shutil.rmtree(extract_to + "/" + MODEL) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_efficientnet_lite4_1x1() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Ghostnet | ||
|
||
import os | ||
import urllib | ||
|
||
import pybuda | ||
import requests | ||
import timm | ||
import torch | ||
from PIL import Image | ||
|
||
|
||
def run_ghostnet_timm(): | ||
# Set PyBuda configuration parameters | ||
compiler_cfg = pybuda.config._get_global_compiler_config() # load global compiler config object | ||
compiler_cfg.enable_t_streaming = True | ||
compiler_cfg.balancer_policy = "Ribbon" | ||
compiler_cfg.default_df_override = pybuda.DataFormat.Float16_b | ||
os.environ["PYBUDA_RIBBON2"] = "1" | ||
|
||
model = timm.create_model("ghostnet_100", pretrained=True) | ||
|
||
# Create PyBuda module from PyTorch model | ||
tt_model = pybuda.PyTorchModule("ghostnet_100_timm_pt", model) | ||
|
||
data_config = timm.data.resolve_data_config({}, model=model) | ||
transforms = timm.data.create_transform(**data_config) | ||
|
||
url = "https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg" | ||
img = Image.open(requests.get(url, stream=True).raw).convert("RGB") | ||
img_tensor = transforms(img).unsqueeze(0) | ||
|
||
# Run inference on Tenstorrent device | ||
output_q = pybuda.run_inference(tt_model, inputs=([img_tensor])) | ||
output = output_q.get()[0].value() | ||
|
||
top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5) | ||
|
||
# Get imagenet class mappings | ||
url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt" | ||
image_classes = urllib.request.urlopen(url) | ||
categories = [s.decode("utf-8").strip() for s in image_classes.readlines()] | ||
|
||
for i in range(top5_probabilities.size(1)): | ||
class_idx = top5_class_indices[0, i].item() | ||
class_prob = top5_probabilities[0, i].item() | ||
class_label = categories[class_idx] | ||
|
||
print(f"{class_label} : {class_prob}") | ||
|
||
|
||
if __name__ == "__main__": | ||
run_ghostnet_timm() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.