diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 0585b8e..dda5ac0 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -34,12 +34,10 @@ def __init__( self.save_hyperparameters(ignore=["datastore"]) self.args = args self._datastore = datastore - # XXX: should be this be somewhere else? - split = "train" num_state_vars = datastore.get_num_data_vars(category="state") num_forcing_vars = datastore.get_num_data_vars(category="forcing") da_static_features = datastore.get_dataarray( - category="static", split=split + category="static", split=None ) da_state_stats = datastore.get_standardization_dataarray( category="state"