diff --git a/smt/applications/mfk.py b/smt/applications/mfk.py index 604e9ae7a..58131143e 100644 --- a/smt/applications/mfk.py +++ b/smt/applications/mfk.py @@ -296,6 +296,9 @@ def _new_train_init(self): _, self.cat_features = compute_X_cont( np.concatenate(xt, axis=0), self.design_space ) + self.X_offset[self.cat_features] *=0 + self.X_scale[self.cat_features] *=0 + self.X_scale[self.cat_features] +=1 nlevel = self.nlvl