Skip to content

Commit

Permalink
update mean iou
Browse files Browse the repository at this point in the history
  • Loading branch information
weiyueli7 committed Feb 25, 2024
1 parent ffcf6bf commit 452be4c
Showing 1 changed file with 29 additions and 17 deletions.
46 changes: 29 additions & 17 deletions evaluate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@
78: 'hair drier',
79: 'toothbrush'}

def detecting_objects():
def detecting_objects(DIR):
print("Loading model (yolov8m.pt)...")
model = YOLO("yolov8m.pt")
# DIR = f"img_generations/img_generations_templatev0.3_lmd_plus_demo_gpt-4/run0"
Expand Down Expand Up @@ -178,43 +178,51 @@ def evaluate_image(yolo_file_path, original_prompt):

target_objs = set()
for obj, cor in original_prompt:
# Extract the noun from the object description (e.g. "a person" -> "person")
obj = extract_noun(obj)
print(obj)
# Convert the object name to its corresponding class ID
target_objs.add(find_key_for_value(CLASSES, obj))

converted_detections = []
all_detected_objects = set()
for det in yolo_detections:
class_id = det[0]
all_detected_objects.add(class_id)
if class_id in target_objs: # Class ID for 'umbrella'
if class_id in target_objs:
bbox = convert_center_to_corner(det[1:], img_width, img_height)
converted_detections.append(bbox)

# Evaluate Object Count Accuracy
expected_count = len(original_prompt)
detected_count = len(converted_detections)
count_accuracy = detected_count == expected_count
# count_accuracy = detected_count == expected_count
count_accuracy = abs(detected_count - expected_count) / expected_count

# Evaluate Bounding Box Accuracy
iou_threshold = 0.9
accurate_boxes = 0
# iou_threshold = 0.9
all_ious = []
# accurate_boxes = 0
for det_box in converted_detections:
cur_ious = []
for _, org_box in original_prompt:
iou = calculate_iou(det_box, org_box)
if iou > iou_threshold:
accurate_boxes += 1
break
cur_ious.append(iou)
# print(cur_ious)
all_ious.append(max(cur_ious))
# print(all_ious)
# if iou > iou_threshold:
# accurate_boxes += 1
# break

bbox_accuracy = accurate_boxes / expected_count
# bbox_accuracy = accurate_boxes / expected_count

class_id_to_name = CLASSES
expected_objects = target_objs
extra_detected_objects = all_detected_objects.difference(expected_objects)
extra_detected_objects_names = [class_id_to_name[obj_id] for obj_id in extra_detected_objects]

# Results
return count_accuracy, bbox_accuracy, extra_detected_objects_names
return count_accuracy, np.mean(all_ious), extra_detected_objects_names



Expand All @@ -228,23 +236,27 @@ def extract_noun(text):
return nouns[0]

if __name__ == "__main__":
parser = argparse.ArgumentParser()

# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--lm", default='gpt-4', type=str)
parser.add_argument("--template_version", default='v0.3', type=str)
parser.add_argument("--prompt_type", default='demo', type=str)

args = parser.parse_args()

DIR = f"img_generations/img_generations_template{args.template_version}_lmd_plus_{args.prompt_type}_{args.lm}/run0"
# DIR = f"img_generations/img_generations_templatev0.3_lmd_plus_demo_gpt-4/run0"
detecting_objects()

prompts = json.load(open(f"cache/cache_demo_{args.template_version}_{args.lm}.json"))
# Detecting objects from synthetic images
detecting_objects(DIR)

# Evaluate the detected objects
prompts = json.load(open(f"cache/cache_demo_{args.template_version}_{args.lm}.json"))
for ind, (key, value) in enumerate(prompts.items()):
original_prompt = eval(value[0].split("Background prompt:")[0])
yolo_path = f"object_detection/{DIR}/{cur_time}/results_{ind}/labels/img_4.txt"
try:
print(evaluate_image(yolo_path, original_prompt))
eval_result = evaluate_image(yolo_path, original_prompt)
print(f"extra/miss ratio: {eval_result[0]}, mean_iou: {eval_result[1]}, extra_detected_objects: {eval_result[2]}")
print("=================================")
except:
pass

0 comments on commit 452be4c

Please sign in to comment.