Skip to content

Commit

Permalink
update mfk
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul-Saves committed Nov 13, 2023
1 parent 9772030 commit f8ae566
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions smt/applications/mfk.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@
from smt.utils.misc import standardization

from smt.surrogate_models.krg_based import compute_n_param
from smt.utils.design_space import ensure_design_space


class NestedLHS(object):
def __init__(self, nlevel, xlimits, random_state=None, design_space=None):
def __init__(self, nlevel, xlimits=None, random_state=None, design_space=None):
"""
Constructor where values of options can be passed in.
Expand All @@ -49,13 +50,26 @@ def __init__(self, nlevel, xlimits, random_state=None, design_space=None):
xlimits : ndarray
The interval of the domain in each dimension with shape (nx, 2)
design_space : DesignSpace
The design space with type and bounds for every design variable
random_state : Numpy RandomState object or seed number which controls random draws
"""
self.nlevel = nlevel
self.xlimits = xlimits
self.random_state = random_state
self.design_space = design_space
if xlimits == None and design_space == None:
raise ValueError(
"Either xlimits or design_space should be specified to have bounds for the sampling."
)
elif xlimits != None and design_space != None:
raise ValueError(
"Use either design_space for mixed inputs or xlimits for continuous one. Please avoid overspecification."
)
elif xlimits != None:
self.design_space = ensure_design_space(xlimits=xlimits)
else:
self.design_space = design_space

def __call__(self, nb_samples_hifi):
"""
Expand Down Expand Up @@ -86,7 +100,11 @@ def __call__(self, nb_samples_hifi):
raise ValueError("nt must be a list of decreasing integers")

doe = []
p0 = LHS(xlimits=self.xlimits, criterion="ese", random_state=self.random_state)
p0 = LHS(
xlimits=self.design_space.get_x_limits(),
criterion="ese",
random_state=self.random_state,
)
p0nt0 = p0(nt[0])
if self.design_space:
p0nt0, _ = self.design_space.correct_get_acting(p0nt0)
Expand All @@ -95,7 +113,9 @@ def __call__(self, nb_samples_hifi):

for i in range(1, self.nlevel):
p = LHS(
xlimits=self.xlimits, criterion="ese", random_state=self.random_state
xlimits=self.design_space.get_x_limits(),
criterion="ese",
random_state=self.random_state,
)
pnti = p(nt[i])
if self.design_space:
Expand Down

0 comments on commit f8ae566

Please sign in to comment.