Skip to content

Commit

Permalink
Add weights_only keyword to all torch.load calls
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Nov 14, 2024
1 parent 7112013 commit 2d187a8
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions neural_lam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ def load_dataset_stats(dataset_name, device="cpu"):

def loads_file(fn):
return torch.load(
os.path.join(static_dir_path, fn), map_location=device
os.path.join(static_dir_path, fn),
map_location=device,
weights_only=True,
)

data_mean = loads_file("parameter_mean.pt") # (d_features,)
Expand All @@ -42,7 +44,9 @@ def load_static_data(dataset_name, device="cpu"):

def loads_file(fn):
return torch.load(
os.path.join(static_dir_path, fn), map_location=device
os.path.join(static_dir_path, fn),
map_location=device,
weights_only=True,
)

# Load border mask, 1. if node is part of border, else 0.
Expand Down Expand Up @@ -116,7 +120,11 @@ def load_graph(graph_name, device="cpu"):
graph_dir_path = os.path.join("graphs", graph_name)

def loads_file(fn):
return torch.load(os.path.join(graph_dir_path, fn), map_location=device)
return torch.load(
os.path.join(graph_dir_path, fn),
map_location=device,
weights_only=True,
)

# Load edges (edge_index)
m2m_edge_index = BufferList(
Expand Down

0 comments on commit 2d187a8

Please sign in to comment.