diff --git a/src/metatrain/experimental/soap_bpnn/trainer.py b/src/metatrain/experimental/soap_bpnn/trainer.py index c9e7dd004..2a2967fa6 100644 --- a/src/metatrain/experimental/soap_bpnn/trainer.py +++ b/src/metatrain/experimental/soap_bpnn/trainer.py @@ -103,15 +103,18 @@ def train( "user-supplied composition weights" ) cur_weight_dict = self.hypers["fixed_composition_weights"][target_name] - atomic_types = set() + atomic_types = [] num_species = len(cur_weight_dict) fixed_weights = torch.zeros(num_species, dtype=dtype, device=device) for ii, (key, weight) in enumerate(cur_weight_dict.items()): - atomic_types.add(key) + atomic_types.append(key) fixed_weights[ii] = weight - if not set(atomic_types) == model.atomic_types: + if ( + not set(atomic_types) + == (model.module if is_distributed else model).atomic_types + ): raise ValueError( "Supplied atomic types are not present in the dataset." )