diff --git a/smt/applications/mfk.py b/smt/applications/mfk.py index b6bf8538e..d42f7c173 100644 --- a/smt/applications/mfk.py +++ b/smt/applications/mfk.py @@ -37,7 +37,7 @@ class NestedLHS(object): - def __init__(self, nlevel, xlimits, random_state=None): + def __init__(self, nlevel, xlimits, random_state=None, design_space=None): """ Constructor where values of options can be passed in. @@ -55,6 +55,7 @@ def __init__(self, nlevel, xlimits, random_state=None): self.nlevel = nlevel self.xlimits = xlimits self.random_state = random_state + self.design_space = design_space def __call__(self, nb_samples_hifi): """ @@ -86,13 +87,21 @@ def __call__(self, nb_samples_hifi): doe = [] p0 = LHS(xlimits=self.xlimits, criterion="ese", random_state=self.random_state) - doe.append(p0(nt[0])) + p0nt0 = p0(nt[0]) + if self.design_space: + p0nt0, _ = self.design_space.correct_get_acting(p0nt0) + + doe.append(p0nt0) for i in range(1, self.nlevel): p = LHS( xlimits=self.xlimits, criterion="ese", random_state=self.random_state ) - doe.append(p(nt[i])) + pnti = p(nt[i]) + if self.design_space: + pnti, _ = self.design_space.correct_get_acting(pnti) + + doe.append(pnti) for i in range(1, self.nlevel)[::-1]: ind = []