diff --git a/scripts/GNN_topk_analysis.py b/scripts/GNN_topk_analysis.py index 9ae9078d7..9d2d3a787 100644 --- a/scripts/GNN_topk_analysis.py +++ b/scripts/GNN_topk_analysis.py @@ -54,7 +54,7 @@ def main() -> None: in_domain_objects = parse_in_domain_objects(input_dir) object_counts = {in_domain_object: 0 for in_domain_object in in_domain_objects} - GNN_rank_matrix = [[0 for _ in range(len(in_domain_objects))] for _ in range(len(in_domain_objects))] + gnn_rank_matrix = [[0 for _ in range(len(in_domain_objects))] for _ in range(len(in_domain_objects))] for feature_file_path in sorted(input_dir.glob("situation*")): with open(feature_file_path / 'description.yaml', encoding="utf-8") as description_file: description_yaml = yaml.safe_load(description_file) @@ -66,10 +66,10 @@ def main() -> None: for i in range(len(gnn_objects)): gnn_object = gnn_objects[i] if gnn_object == 'small_single_mug': gnn_object = 'mug' - GNN_rank_matrix[in_domain_objects.index(expected_object)][in_domain_objects.index(gnn_object)] += len(gnn_objects) - i - for object_id in range(len(GNN_rank_matrix)): - GNN_rank_matrix[object_id] = [label_count / object_counts[in_domain_objects[object_id]] for label_count in arr[object_id]] - df = pd.DataFrame(GNN_rank_matrix, in_domain_objects, in_domain_objects) + gnn_rank_matrix[in_domain_objects.index(expected_object)][in_domain_objects.index(gnn_object)] += len(gnn_objects) - i + for object_id in range(len(gnn_rank_matrix)): + gnn_rank_matrix[object_id] = [label_count / object_counts[in_domain_objects[object_id]] for label_count in arr[object_id]] + df = pd.DataFrame(gnn_rank_matrix, in_domain_objects, in_domain_objects) sn.set(font_scale=0.9) sn.color_palette('colorblind') sn.set_context('paper')