
cannot import name 'get_dino_model' from 'utils' #90

Open
arthurwolf opened this issue Nov 20, 2024 · 2 comments

@arthurwolf

Hello.

I tried running omniparser.py and got this error:

❯ python omniparser.py
Traceback (most recent call last):
  File "/home/arthur/dev/ai/OmniParser/omniparser.py", line 1, in <module>
    from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor,  get_dino_model, get_yolo_model
ImportError: cannot import name 'get_dino_model' from 'utils' (/home/arthur/dev/ai/OmniParser/utils.py). Did you mean: 'get_yolo_model'?

What am I doing wrong?

(Note: the Gradio demo works fine.)

Thanks!

@fdciabdul

`get_dino_model` no longer exists in utils.py, so drop it from the import and use `get_yolo_model` instead (as the traceback itself suggests). Here is a working version of the script:

from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_yolo_model
import torch
from PIL import Image
from typing import Dict
import io
import base64


config = {
    'som_model_path': 'weights/icon_detect_v1_5/best.pt',
    'caption_model_name': 'blip2', 
    'caption_model_path': 'weights/icon_caption_blip2',
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'draw_bbox_config': {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 3,
        'thickness': 3,
    },
    'BOX_TRESHOLD': 0.05  # spelling follows the upstream OmniParser kwarg name
}

class Omniparser(object):
    def __init__(self, config: Dict):
        self.config = config
        
        self.som_model = get_yolo_model(model_path=config['som_model_path'])
        
        self.caption_model_processor = get_caption_model_processor(
            model_name=config.get('caption_model_name', "blip2"),
            model_name_or_path=config['caption_model_path']
        )
        self.caption_model_processor['model'].to(torch.float32)

    def parse(self, image_path: str):
        print('Parsing image:', image_path)
        
        ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
            image_path, display_img=False, output_bb_format='xyxy',
            goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold': 0.9}
        )
        text, ocr_bbox = ocr_bbox_rslt

        print("OCR Text:", text)
        print("OCR Bounding Boxes:", ocr_bbox)

        draw_bbox_config = self.config['draw_bbox_config']
        BOX_TRESHOLD = self.config['BOX_TRESHOLD']

        dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
            image_path, self.som_model, BOX_TRESHOLD=BOX_TRESHOLD, output_coord_in_ratio=False,
            ocr_bbox=ocr_bbox, draw_bbox_config=draw_bbox_config,
            caption_model_processor=self.caption_model_processor, ocr_text=text, use_local_semantics=False
        )

        image = Image.open(io.BytesIO(base64.b64decode(dino_labeled_img)))
        labeled_image_path = "labeled_image.png"
        image.save(labeled_image_path)
        print(f"Labeled image saved at {labeled_image_path}")

        print("Parsed Content List:", parsed_content_list)

        return_list = [
            {
                'from': 'omniparser',
                'shape': {'x': coord[0], 'y': coord[1], 'width': coord[2], 'height': coord[3]},
                'text': parsed_content_list[i].get('text', 'None'),  
                'type': 'text'
            }
            for i, (k, coord) in enumerate(label_coordinates.items())
            if i < len(parsed_content_list)
        ]
        return_list.extend(
            [
                {
                    'from': 'omniparser',
                    'shape': {'x': coord[0], 'y': coord[1], 'width': coord[2], 'height': coord[3]},
                    'text': 'None',
                    'type': 'icon'
                }
                for i, (k, coord) in enumerate(label_coordinates.items())
                if i >= len(parsed_content_list)
            ]
        )

        return [image, return_list]

parser = Omniparser(config)
image_path = 'mutasi-details.png'

import time
s = time.time()
image, parsed_content_list = parser.parse(image_path)
device = config['device']
print(f"Time taken for Omniparser on {device}: {time.time() - s}")

for item in parsed_content_list:
    print(item)

@arthurwolf (Author)

Thanks, I'll try it soon.

Any advice on how to speed up inference? I only need to pass it an image and get the data back (as JSON or any other format); nothing more, no labelled image, etc.

Aside from that, is there any other strategy to increase speed? I need to read the screen once a second, and it currently takes 10 seconds... Would reducing resolution help?

Thanks again.
