Skip to content

Commit

Permalink
Rust formatting (#540)
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul-Saves authored Mar 20, 2024
1 parent 817ede6 commit 4d80c52
Show file tree
Hide file tree
Showing 95 changed files with 440 additions and 405 deletions.
6 changes: 3 additions & 3 deletions smt/applications/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .vfm import VFM
from .moe import MOE, MOESurrogateModel
from .ego import EGO, Evaluator
from .mfk import MFK, NestedLHS
from .mfkpls import MFKPLS
from .mfkplsk import MFKPLSK
from .ego import EGO, Evaluator
from .moe import MOE, MOESurrogateModel
from .vfm import VFM

__all__ = [
"VFM",
Expand Down
4 changes: 2 additions & 2 deletions smt/applications/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
This package is distributed under New BSD license.
"""

from smt.surrogate_models import GEKPLS, KPLS, KPLSK, KRG, LS, MGP, QP
from smt.utils.options_dictionary import OptionsDictionary
from smt.surrogate_models import LS, QP, KPLS, KRG, KPLSK, GEKPLS, MGP

try:
from smt.surrogate_models import IDW, RBF, RMTC, RMTB
from smt.surrogate_models import IDW, RBF, RMTB, RMTC

COMPILED_AVAILABLE = True
except ImportError:
Expand Down
9 changes: 4 additions & 5 deletions smt/applications/ego.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,23 @@
"""

import numpy as np

from types import FunctionType

from scipy.stats import norm
import numpy as np
from scipy.optimize import minimize
from scipy.stats import norm

from smt.surrogate_models import KPLS, KRG, KPLSK, MGP, GEKPLS
from smt.applications.application import SurrogateBasedApplication
from smt.applications.mixed_integer import (
MixedIntegerContext,
MixedIntegerSamplingMethod,
)
from smt.sampling_methods import LHS
from smt.surrogate_models import GEKPLS, KPLS, KPLSK, KRG, MGP
from smt.utils.design_space import (
BaseDesignSpace,
DesignSpace,
)
from smt.sampling_methods import LHS


class Evaluator(object):
Expand Down
20 changes: 9 additions & 11 deletions smt/applications/mfk.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,32 @@
Adapted on January 2021 by Andres Lopez-Lopera to the new SMT version
"""

from copy import deepcopy
import warnings
from copy import deepcopy

import numpy as np
from scipy.linalg import solve_triangular
from scipy import linalg
from scipy.linalg import solve_triangular
from scipy.spatial.distance import cdist

from sklearn.cross_decomposition import PLSRegression as pls

from smt.sampling_methods import LHS
from smt.surrogate_models.krg_based import (
KrgBased,
MixIntKernelType,
compute_n_param,
)

from smt.sampling_methods import LHS
from smt.utils.design_space import ensure_design_space
from smt.utils.kriging import (
cross_distances,
componentwise_distance,
compute_X_cont,
cross_distances,
cross_levels,
differences,
gower_componentwise_distances,
cross_levels,
compute_X_cont,
)
from smt.utils.misc import standardization

from smt.surrogate_models.krg_based import compute_n_param
from smt.utils.design_space import ensure_design_space


class NestedLHS(object):
def __init__(self, nlevel, xlimits=None, design_space=None, random_state=None):
Expand Down
1 change: 0 additions & 1 deletion smt/applications/mfkpls.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
"""

import numpy as np

from sklearn.cross_decomposition import PLSRegression as pls
from sklearn.metrics.pairwise import check_pairwise_arrays

Expand Down
2 changes: 1 addition & 1 deletion smt/applications/mfkplsk.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
Adapted on January 2021 by Andres Lopez-Lopera to the new SMT version
"""

from smt.utils.kriging import componentwise_distance
from smt.applications import MFKPLS
from smt.utils.kriging import componentwise_distance


class MFKPLSK(MFKPLS):
Expand Down
8 changes: 5 additions & 3 deletions smt/applications/mixed_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@
This package is distributed under New BSD license.
"""

import warnings

import numpy as np
from smt.surrogate_models.surrogate_model import SurrogateModel

from smt.sampling_methods.sampling_method import SamplingMethod
from smt.utils.checks import ensure_2d_array
from smt.surrogate_models.krg_based import KrgBased, MixIntKernelType
from smt.surrogate_models.surrogate_model import SurrogateModel
from smt.utils.checks import ensure_2d_array
from smt.utils.design_space import (
BaseDesignSpace,
CategoricalVariable,
ensure_design_space,
)
import warnings


class MixedIntegerSamplingMethod(SamplingMethod):
Expand Down
3 changes: 1 addition & 2 deletions smt/applications/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
# TODO : documentation

import numpy as np

from sklearn.mixture import GaussianMixture
from scipy.stats import multivariate_normal
from sklearn.mixture import GaussianMixture

from smt.applications.application import SurrogateBasedApplication
from smt.surrogate_models.surrogate_model import SurrogateModel
Expand Down
47 changes: 24 additions & 23 deletions smt/applications/tests/test_ego.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,32 @@

import os
import unittest
import numpy as np
from multiprocessing import Pool
from sys import argv

import numpy as np

import smt.utils.design_space as ds
from smt.applications import EGO
from smt.applications.ego import Evaluator
from smt.utils.sm_test_case import SMTestCase
from smt.applications.mixed_integer import (
MixedIntegerContext,
MixedIntegerSamplingMethod,
)
from smt.problems import Branin, Rosenbrock
from smt.sampling_methods import FullFactorial
from multiprocessing import Pool
from smt.sampling_methods import LHS
from smt.sampling_methods import LHS, FullFactorial
from smt.surrogate_models import (
KRG,
GEKPLS,
KPLS,
MixIntKernelType,
KRG,
CategoricalVariable,
DesignSpace,
OrdinalVariable,
FloatVariable,
CategoricalVariable,
IntegerVariable,
MixIntKernelType,
OrdinalVariable,
)
from smt.applications.mixed_integer import (
MixedIntegerContext,
MixedIntegerSamplingMethod,
)
import smt.utils.design_space as ds
from smt.utils.sm_test_case import SMTestCase

try:
import matplotlib
Expand Down Expand Up @@ -1084,11 +1083,12 @@ def test_examples(self):

@staticmethod
def run_ego_example():
import matplotlib.pyplot as plt
import numpy as np

from smt.applications import EGO
from smt.surrogate_models import KRG
from smt.utils.design_space import DesignSpace
import matplotlib.pyplot as plt

def function_test_1d(x):
# function xsinx
Expand Down Expand Up @@ -1173,18 +1173,18 @@ def function_test_1d(x):

@staticmethod
def run_ego_mixed_integer_example():
import matplotlib.pyplot as plt
import numpy as np

from smt.applications import EGO
from smt.applications.mixed_integer import MixedIntegerContext
from smt.surrogate_models import MixIntKernelType
from smt.surrogate_models import KRG, MixIntKernelType
from smt.utils.design_space import (
DesignSpace,
CategoricalVariable,
DesignSpace,
FloatVariable,
IntegerVariable,
)
import matplotlib.pyplot as plt
from smt.surrogate_models import KRG

# Regarding the interface, the function to be optimized should handle
# categorical values as index values in the enumeration type specification.
Expand Down Expand Up @@ -1271,13 +1271,13 @@ def function_test_mixed_integer(X):

@staticmethod
def run_ego_parallel_example():
import matplotlib.pyplot as plt
import numpy as np

from smt.applications import EGO
from smt.applications.ego import Evaluator
from smt.surrogate_models import KRG, DesignSpace

import matplotlib.pyplot as plt

def function_test_1d(x):
# function xsinx
import numpy as np
Expand Down Expand Up @@ -1305,9 +1305,10 @@ class ParallelEvaluator(Evaluator):
def run(self, fun, x):
n_thread = 5
# Caveat: import are made here due to SMT documentation building process
import numpy as np
from sys import version_info
from multiprocessing.pool import ThreadPool
from sys import version_info

import numpy as np

if version_info.major == 2:
return fun(x)
Expand Down
14 changes: 8 additions & 6 deletions smt/applications/tests/test_mfk.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import unittest

import numpy as np

try:
Expand All @@ -16,14 +17,14 @@
except ImportError:
NO_MATPLOTLIB = True

from copy import deepcopy

from smt.applications.mfk import MFK, NestedLHS
from smt.problems import Sphere, TensorProduct
from smt.sampling_methods import LHS, FullFactorial

from smt.utils.sm_test_case import SMTestCase
from smt.utils.silence import Silence
from smt.utils.misc import compute_rms_error
from smt.applications.mfk import MFK, NestedLHS
from copy import deepcopy
from smt.utils.silence import Silence
from smt.utils.sm_test_case import SMTestCase

print_output = False

Expand Down Expand Up @@ -138,8 +139,9 @@ def test_mfk_derivs(self):

@staticmethod
def run_mfk_example():
import numpy as np
import matplotlib.pyplot as plt
import numpy as np

from smt.applications.mfk import MFK, NestedLHS

# low fidelity model
Expand Down
11 changes: 6 additions & 5 deletions smt/applications/tests/test_mfk_1fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
NO_MATPLOTLIB = True

import unittest

import numpy as np

from smt.applications.mfk import MFK
from smt.problems import TensorProduct
from smt.sampling_methods import LHS

from smt.utils.sm_test_case import SMTestCase
from smt.utils.silence import Silence
from smt.utils.misc import compute_rms_error
from smt.applications.mfk import MFK
from smt.utils.silence import Silence
from smt.utils.sm_test_case import SMTestCase

print_output = False

Expand Down Expand Up @@ -61,8 +61,9 @@ def test_mfk_1fidelity(self):

@staticmethod
def run_mfk_example_1fidelity():
import numpy as np
import matplotlib.pyplot as plt
import numpy as np

from smt.applications.mfk import MFK, NestedLHS

# Consider only 1 fidelity level
Expand Down
14 changes: 5 additions & 9 deletions smt/applications/tests/test_mfk_mfkpls_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import unittest

import numpy as np

try:
Expand All @@ -19,28 +20,23 @@

import numpy.linalg as npl

from smt.applications.mfkpls import MFKPLS
from smt.applications import NestedLHS
from smt.applications.mfk import MFK

from smt.applications.mfkpls import MFKPLS
from smt.applications.mixed_integer import (
MixedIntegerSamplingMethod,
)

from smt.applications import NestedLHS

from smt.sampling_methods import LHS

from smt.surrogate_models import (
KRG,
KPLS,
KRG,
MixIntKernelType,
)

from smt.utils.design_space import (
CategoricalVariable,
DesignSpace,
FloatVariable,
IntegerVariable,
CategoricalVariable,
)


Expand Down
5 changes: 3 additions & 2 deletions smt/applications/tests/test_mfk_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
https://doi.org/10.1080/00401706.2014.928233
"""

import unittest

import numpy as np

from smt.applications.mfk import MFK, NestedLHS
from smt.sampling_methods import LHS
import unittest

from smt.utils.sm_test_case import SMTestCase

print_output = True
Expand Down
Loading

0 comments on commit 4d80c52

Please sign in to comment.