Skip to content

Commit

Permalink
Update webcam example
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Jan 7, 2024
1 parent 4d1bdd0 commit 7013eea
Showing 1 changed file with 14 additions and 22 deletions.
36 changes: 14 additions & 22 deletions 06_gpu_and_ml/obj_detection_webcam/webcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,32 +34,20 @@

from fastapi import FastAPI, Request, Response
from fastapi.staticfiles import StaticFiles
from modal import (
Image,
Mount,
Stub,
asgi_app,
method,
)
from modal import Image, Mount, Stub, asgi_app, build, method

# We need to install [transformers](https://github.com/huggingface/transformers)
# which is a package Huggingface uses for all their models, but also
# [Pillow](https://python-pillow.org/) which lets us work with images from Python,
# and a system font for drawing.
#
# This example uses the `facebook/detr-resnet-50` pre-trained model, which is downloaded
# once at image build time using the `download_model` function and saved into the image.
# 'Baking' models into the `modal.Image` at build time provided the fastest cold start.
# once at image build time using the `@build` hook and saved into the image. 'Baking'
# models into the `modal.Image` at build time provided the fastest cold start.

model_repo_id = "facebook/detr-resnet-50"


def download_model():
from huggingface_hub import snapshot_download

snapshot_download(repo_id=model_repo_id, cache_dir="/cache")


stub = Stub("example-webcam-object-detection")
image = (
Image.debian_slim()
Expand All @@ -70,7 +58,6 @@ def download_model():
"transformers",
)
.apt_install("fonts-freefont-ttf")
.run_function(download_model)
)


Expand All @@ -92,14 +79,23 @@ def download_model():
# web interface can render it on top of the webcam view.


with image.imports():
import torch
from huggingface_hub import snapshot_download
from PIL import Image, ImageColor, ImageDraw, ImageFont
from transformers import DetrForObjectDetection, DetrImageProcessor


@stub.cls(
cpu=4,
image=image,
)
class ObjectDetection:
def __enter__(self):
from transformers import DetrForObjectDetection, DetrImageProcessor
@build()
def download_model(self):
snapshot_download(repo_id=model_repo_id, cache_dir="/cache")

def __enter__(self):
self.feature_extractor = DetrImageProcessor.from_pretrained(
model_repo_id,
cache_dir="/cache",
Expand All @@ -112,14 +108,10 @@ def __enter__(self):
@method()
def detect(self, img_data_in):
# Based on https://huggingface.co/spaces/nateraw/detr-object-detection/blob/main/app.py
from PIL import Image, ImageColor, ImageDraw, ImageFont

# Read png from input
image = Image.open(io.BytesIO(img_data_in)).convert("RGB")

# Make prediction
import torch

inputs = self.feature_extractor(image, return_tensors="pt")
outputs = self.model(**inputs)
img_size = torch.tensor([tuple(reversed(image.size))])
Expand Down

0 comments on commit 7013eea

Please sign in to comment.