Skip to content

Commit

Permalink
fix fastsam cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
drcege committed Nov 12, 2024
1 parent ee86360 commit a906daa
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
7 changes: 3 additions & 4 deletions data_juicer/ops/mapper/image_segment_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ def __init__(self, imgsz=1024, conf=0.05, iou=0.5, *args, **kwargs):
"""
super().__init__(*args, **kwargs)

self.model_key = prepare_model(model_type='fastsam',
model_path='FastSAM-x.pt')

self.imgsz = imgsz
self.conf = conf
self.iou = iou

self.model_key = prepare_model(model_type='fastsam',
model_path='FastSAM-x.pt')

def process_single(self, sample, rank=None, context=False):
# there is no image in this sample
if self.image_key not in sample or not sample[self.image_key]:
Expand All @@ -61,7 +61,6 @@ def process_single(self, sample, rank=None, context=False):
conf=self.conf,
iou=self.iou,
verbose=False)[0]
# breakpoint()
sample[Fields.bbox_tag].append(masks.boxes.xywh.cpu().numpy())

# match schema
Expand Down
4 changes: 3 additions & 1 deletion data_juicer/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,9 @@ def prepare_diffusion_model(pretrained_model_name_or_path, diffusion_type,


def prepare_fastsam_model(model_path, **model_params):
return ultralytics.FastSAM(model_path)
device = model_params.pop('device', 'cpu')
model = ultralytics.FastSAM(model_path).to(device)
return model


def prepare_fasttext_model(model_name='lid.176.bin', **model_params):
Expand Down

0 comments on commit a906daa

Please sign in to comment.