diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 59a529eb..c47c44ff 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -17,7 +17,9 @@ def load_dataset_stats(dataset_name, device="cpu"): def loads_file(fn): return torch.load( - os.path.join(static_dir_path, fn), map_location=device + os.path.join(static_dir_path, fn), + map_location=device, + weights_only=True, ) data_mean = loads_file("parameter_mean.pt") # (d_features,) @@ -42,7 +44,9 @@ def load_static_data(dataset_name, device="cpu"): def loads_file(fn): return torch.load( - os.path.join(static_dir_path, fn), map_location=device + os.path.join(static_dir_path, fn), + map_location=device, + weights_only=True, ) # Load border mask, 1. if node is part of border, else 0. @@ -116,7 +120,11 @@ def load_graph(graph_name, device="cpu"): graph_dir_path = os.path.join("graphs", graph_name) def loads_file(fn): - return torch.load(os.path.join(graph_dir_path, fn), map_location=device) + return torch.load( + os.path.join(graph_dir_path, fn), + map_location=device, + weights_only=True, + ) # Load edges (edge_index) m2m_edge_index = BufferList(