diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py index 80181ec0..5c1872ab 100644 --- a/neural_lam/models/hi_lam_parallel.py +++ b/neural_lam/models/hi_lam_parallel.py @@ -16,8 +16,8 @@ class HiLAMParallel(BaseHiGraphModel): of Hi-LAM. """ - def __init__(self, args): - super().__init__(args) + def __init__(self, args, datastore): + super().__init__(args, datastore=datastore) # Processor GNNs # Create the complete edge_index combining all edges for processing