Skip to content

Commit

Permalink
SAM2 AMG example mIoU, perf numbers and more SAM2 model annotations (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
cpuhrsch authored Oct 30, 2024
1 parent 186e578 commit 4f1fc4c
Show file tree
Hide file tree
Showing 7 changed files with 345 additions and 83 deletions.
98 changes: 98 additions & 0 deletions examples/sam2_amg_server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,101 @@
```
curl -X POST http://127.0.0.1:5000/upload -F 'image=@/path/to/file.jpg' --output path/to/output.png
```

## Example script to collect rles

Start the server

```
python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --fast
```

Collect the rles

```
xargs -I {} curl -s -w "\n" -X POST http://<your_hostname>:<your_port>/upload_rle -F 'image=@{}' < image_paths > rle_masks
```

## mIoU scores on random subset of sav validation dataset

Experiments run on H100 and with batch size 1

| mode | mIoU | mask count mismatch | avg. ms per request |
| --- |--- | ------------------ | ----------------- |
| baseline | 1.0 | 0 | 786 |
| ao | 1.0 | 0 | 738 |
| fast | 0.95 | 190 | 563 |
| furious | 0 | 1000 | 204 |

mask count mismatch counts the number of requests where the number of masks differ from the baseline.
For example, the baseline may have chosen to segment an image into 18 masks, but the fast variant produces 17 or 19.
We exclude these examples from the mIoU calculation.

### 1. Create a random subset of 1000 images
```
find sav_val -type f > sav_val_image_paths
shuf -n 1000 sav_val_image_paths > sav_val_image_paths_shuf_1000
```

### 2. Use the baseline (https://github.com/facebookresearch/sam2) to generate rles

Make sure you've installed https://github.com/facebookresearch/sam2

Start server
```
python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --baseline
```

Generate and save rles (one line per json via `-w "\n"`)
```
$ time xargs -I {} curl -s -w "\n" -X POST http://<your_hostname>:<your_port>/upload_rle -F 'image=@{}' < sav_val_image_paths_shuf_1000 > results/sav_val_masks_baseline_shuf_1000
real 13m6.374s
user 0m3.349s
sys 0m4.137s
```

### 3. Start server with torchao variant of SAM2
Start server
```
python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname>
```

Generate and save rles (one line per json via `-w "\n"`)
```
$ time xargs -I {} curl -s -w "\n" -X POST http://<your_hostname>:<your_port>/upload_rle -F 'image=@{}' < sav_val_image_paths_shuf_1000 > results/sav_val_masks_shuf_1000
real 12m18.916s
user 0m3.506s
sys 0m4.350s
```

### 4. Start server with torchao variant of SAM2 and `--fast` optimizations
Start server
```
python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --fast
```

Generate and save rles (one line per json via `-w "\n"`)
```
$ time xargs -I {} curl -s -w "\n" -X POST http://<your_hostname>:<your_port>/upload_rle -F 'image=@{}' < sav_val_image_paths_shuf_1000 > results/sav_val_masks_fast_shuf_1000
real 9m23.912s
user 0m3.271s
sys 0m4.138s
```

### 5. Start server with torchao variant of SAM2 and `--fast` and `--furious` optimizations
Start server
```
python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --fast --furious
```

Generate and save rles (one line per json via `-w "\n"`)
```
$ time xargs -I {} curl -s -w "\n" -X POST http://<your_hostname>:<your_port>/upload_rle -F 'image=@{}' < sav_val_image_paths_shuf_1000 > results/sav_val_masks_fast_furious_shuf_1000
real 3m24.383s
user 0m3.583s
sys 0m4.519s
```
41 changes: 41 additions & 0 deletions examples/sam2_amg_server/compare_rle_lists.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import fire
import torch
import json
from sam2.utils.amg import rle_to_mask

"""
Script to calculate mIoU given two lists of rles from upload_rle endpoint
of server.
"""


def iou(mask1, mask2):
assert mask1.dim() == 2
assert mask2.dim() == 2
intersection = torch.logical_and(mask1, mask2)
union = torch.logical_or(mask1, mask2)
return (intersection.sum(dim=(-1, -2)) / union.sum(dim=(-1, -2)))


def main(path0, path1):
fail_count = 0
miou_sum = 0.0
miou_count = 0
with open(path0, 'r') as f0, open(path1, 'r') as f1:
for line0, line1 in zip(f0, f1):
masks0 = json.loads(line0)
masks1 = json.loads(line1)
if masks0.keys() != masks1.keys():
fail_count += 1
continue
for mask0, mask1 in zip(masks0.values(), masks1.values()):
mask0 = torch.from_numpy(rle_to_mask(mask0))
mask1 = torch.from_numpy(rle_to_mask(mask1))
miou_sum += iou(mask0, mask1).item()
miou_count += 1

print(f"fail_count: {fail_count} mIoU: {miou_sum / miou_count}")


if __name__ == "__main__":
fire.Fire(main)
1 change: 1 addition & 0 deletions examples/sam2_amg_server/dog_rle.json

Large diffs are not rendered by default.

113 changes: 104 additions & 9 deletions examples/sam2_amg_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@

# torch.set_float32_matmul_precision('high')


def iou(mask1, mask2):
assert mask1.dim() == 2
assert mask2.dim() == 2
intersection = torch.logical_and(mask1, mask2)
union = torch.logical_or(mask1, mask2)
return (intersection.sum(dim=(-1, -2)) / union.sum(dim=(-1, -2)))


def show_anns(anns):
if len(anns) == 0:
return
Expand All @@ -49,17 +58,44 @@ def show_anns(anns):
return torch.stack(ms)


def main(checkpoint_path, fast=False, furious=False, benchmark=False, verbose=False, points_per_batch=64):
def profiler_runner(path, fn, *args, **kwargs):
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA],
record_shapes=True) as prof:
result = fn(*args, **kwargs)
prof.export_chrome_trace(path)
return result


def main(checkpoint_path,
baseline=False,
fast=False,
furious=False,
unittest=False,
benchmark=False,
profile=None,
verbose=False,
points_per_batch=64,
port=5000,
host="127.0.0.1",
dry=False):
if verbose:
logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
logging.info(f"Running with fast set to {fast} and furious set to {furious}")
logging.info(f"Running with port {port} and host {host}")

if fast:
from torchao._models.sam2.build_sam import build_sam2
from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
else:
if baseline:
logging.info(f"Importing sam2 from outside of torchao. If this errors, install https://github.com/facebookresearch/sam2")
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.utils.amg import rle_to_mask
else:
from torchao._models.sam2.build_sam import build_sam2
from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from torchao._models.sam2.utils.amg import rle_to_mask

device = "cuda"
from pathlib import Path
Expand All @@ -70,7 +106,7 @@ def main(checkpoint_path, fast=False, furious=False, benchmark=False, verbose=Fa
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)

logging.info(f"Using {points_per_batch} points_per_batch")
mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch)
mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")

if furious:
torch.set_float32_matmul_precision('high')
Expand Down Expand Up @@ -107,6 +143,37 @@ def main(checkpoint_path, fast=False, furious=False, benchmark=False, verbose=Fa
logging.info(f"Running one iteration to compile.")
masks = mask_generator.generate(example_image)
logging.info(f"First iteration took {time.time() - t}s.")
if unittest:
logging.info(f"Running strict comparison to reference mask")
import json
ref_masks = json.loads(open("dog_rle.json").read())
ret_data = {}
for mask_id in range(len(masks)):
ret_data[f"mask_{mask_id}"] = masks[mask_id]["segmentation"]
v0_areas = []
v1_areas = []
miou_sum = 0.0
miou_count = 0
for k0 in ref_masks:
assert k0 in ret_data, f"Expected {k0} to be in return data"
from torchao._models.sam2.utils.amg import area_from_rle
v0_area = area_from_rle(ref_masks[k0])
v1_area = area_from_rle(ret_data[k0])
v0_areas.append(v0_area)
v1_areas.append(v1_area)
if v0_area != v1_area:
print(f"v0 area {v0_area} doesn't match v1 area {v1_area}")
v0_mask = torch.from_numpy(rle_to_mask(ref_masks[k0]))
v1_mask = torch.from_numpy(rle_to_mask(ret_data[k0]))
if not torch.allclose(v0_mask, v1_mask):
miou_sum += iou(v0_mask, v1_mask)
miou_count += 1
print(f"Masks don't match for key {k0}. IoU is {iou(v0_mask, v1_mask)}")
if miou_count == 0:
print("Masks exactly match reference.")
else:
print(f"mIoU is {miou_sum / miou_count}")

if benchmark:
logging.info(f"Running 3 warumup iterations.")
for _ in range(3):
Expand All @@ -121,7 +188,13 @@ def main(checkpoint_path, fast=False, furious=False, benchmark=False, verbose=Fa
max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory))
max_memory_allocated_bytes = max_memory_allocated_bytes >> 20
print(f"max_memory_allocated_bytes: {max_memory_allocated_bytes}MiB or {max_memory_allocated_percentage}%")
return

if profile is not None:
print(f"Saving profile under {profile}")
profiler_runner(profile, mask_generator.generate, example_image)

if dry:
return

app = FastAPI()

Expand All @@ -133,6 +206,25 @@ def main(checkpoint_path, fast=False, furious=False, benchmark=False, verbose=Fa
allow_methods=["*"],
allow_headers=["*"],
)

@app.post("/upload_rle")
async def upload_rle(image: UploadFile = File(...)):
# Save the uploaded image to a temporary location
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f"_{image.filename}")
with open(temp_file.name, "wb") as b:
shutil.copyfileobj(image.file, b)

# Read the image back into memory to send as response
example_image = cv2.imread(temp_file.name)
example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
t = time.time()
with torch.backends.cuda.sdp_kernel(enable_cudnn=True):
masks = mask_generator.generate(example_image)
print(f"Took {time.time() - t} to generate a mask for input image.")
ret_data = {}
for mask_id in range(len(masks)):
ret_data[f"mask_{mask_id}"] = masks[mask_id]["segmentation"]
return ret_data

@app.post("/upload")
async def upload_image(image: UploadFile = File(...)):
Expand All @@ -143,13 +235,16 @@ async def upload_image(image: UploadFile = File(...)):

# Read the image back into memory to send as response
example_image = cv2.imread(temp_file.name)
example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
t = time.time()
with torch.backends.cuda.sdp_kernel(enable_cudnn=True):
masks = mask_generator.generate(example_image)
print(f"Took {time.time() - t} to generate a mask for input image.")
# Save an example
plt.figure(figsize=(example_image.shape[1]/100., example_image.shape[0]/100.), dpi=100)
plt.imshow(example_image)
for i in range(len(masks)):
masks[i]["segmentation"] = rle_to_mask(masks[i]["segmentation"])
show_anns(masks)
plt.axis('off')
plt.tight_layout()
Expand All @@ -163,7 +258,7 @@ async def upload_image(image: UploadFile = File(...)):
return StreamingResponse(BytesIO(image_data), media_type="image/png")


uvicorn.run(app, host="127.0.0.1", port=5000, log_level="info")
uvicorn.run(app, host=host, port=port, log_level="info")

if __name__ == "__main__":
fire.Fire(main)
Loading

0 comments on commit 4f1fc4c

Please sign in to comment.