diff --git a/visualization.py b/visualization.py index 10b19c0..fcda18e 100644 --- a/visualization.py +++ b/visualization.py @@ -30,7 +30,7 @@ def merge(image, token_dict, patch_size=14, alpha=0.2, line_color=np.array([200, img = np.asarray(image, dtype=np.uint8).copy() h, w, _ = img.shape - patch_num_h, patch_num_w = w // patch_size, w // patch_size + patch_num_h, patch_num_w = h // patch_size, w // patch_size color_map = {} idx = token_dict["idx_token"].tolist()[0] @@ -137,4 +137,4 @@ def merge(image, token_dict, patch_size=14, alpha=0.2, line_color=np.array([200, token_dict2 = block2(ctm2(token_dict1)) img = merge(image, token_dict2, alpha=0.2, line_color=np.array([255, 255, 255])) - img.save("{}/stage3.jpg".format(output_file)) \ No newline at end of file + img.save("{}/stage3.jpg".format(output_file))