Skip to content

Commit

Permalink
rename functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Antoine-Averland committed Dec 19, 2024
1 parent 6e39928 commit 4c4305b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
8 changes: 4 additions & 4 deletions smt/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def compute_relative_error(sm, xe=None, ye=None, kx=None):
return num / den


def compute_pva_error(sm, xe, ye):
def compute_pva(sm, xe, ye):
ye = ye.reshape((xe.shape[0], 1))
N = len(ye)
ye2 = sm.predict_values(xe)
Expand All @@ -119,18 +119,18 @@ def compute_pva_error(sm, xe, ye):
return pva


def compute_rmse_error(sm, xe, ye):
def compute_rmse(sm, xe, ye):
ye = ye.reshape((xe.shape[0], 1))
N = len(ye)
ye2 = sm.predict_values(xe)
rmse = np.sqrt(np.sum((ye2 - ye) ** 2) / N)
return rmse


def compute_q2_error(sm, xe, ye):
def compute_q2(sm, xe, ye):
ye = ye.reshape((xe.shape[0], 1))
N = len(ye)
square_rmse = compute_rmse_error(sm, xe, ye) ** 2
square_rmse = compute_rmse(sm, xe, ye) ** 2
ye_mean = np.mean(ye)
variance = np.sum((ye - ye_mean) ** 2) / N
Q2 = 1 - (square_rmse / variance)
Expand Down
12 changes: 6 additions & 6 deletions smt/utils/test/test_misc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import numpy as np

from smt.utils.misc import (
compute_q2_error,
compute_pva_error,
compute_rmse_error,
compute_q2,
compute_pva,
compute_rmse,
standardization,
)
from smt.problems import Sphere
Expand Down Expand Up @@ -51,7 +51,7 @@ def test_pva_error(self):
sm.set_training_values(xe, ye)
sm.train()

pva = compute_pva_error(sm, xe, ye)
pva = compute_pva(sm, xe, ye)
self.assertLess(pva, 0.7)

def test_rmse_error(self):
Expand All @@ -60,7 +60,7 @@ def test_rmse_error(self):
sm.set_training_values(xe, ye)
sm.train()

rmse = compute_rmse_error(sm, xe, ye)
rmse = compute_rmse(sm, xe, ye)
self.assertLess(rmse, 0.1)

def test_q2_error(self):
Expand All @@ -69,7 +69,7 @@ def test_q2_error(self):
sm.set_training_values(xe, ye)
sm.train()

q2 = compute_q2_error(sm, xe, ye)
q2 = compute_q2(sm, xe, ye)
self.assertAlmostEqual(q2, 1.0, delta=1e-3)


Expand Down

0 comments on commit 4c4305b

Please sign in to comment.