Skip to content

Commit

Permalink
Added doc string comment to predict_ensemble
Browse files Browse the repository at this point in the history
  • Loading branch information
jsschreck committed Sep 20, 2023
1 parent b50f01e commit c0dccdc
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions evml/keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ def save_model(self):

@classmethod
def load_model(cls, conf):
"""
Load a trained model using args from a configuration
"""
# Check if weights file exists
weights = os.path.join(conf["model"]["save_path"], "best.h5")
if not os.path.isfile(weights):
Expand All @@ -278,6 +281,7 @@ def load_model(cls, conf):
return model_class

def mae(self, y_true, y_pred):
""" Compute the MAE """
num_splits = y_pred.shape[-1]
if num_splits == 4:
mu, _, _, _ = tf.split(y_pred, num_splits, axis=-1)
Expand All @@ -288,6 +292,7 @@ def mae(self, y_true, y_pred):
return tf.keras.metrics.mean_absolute_error(y_true, mu)

def mse(self, y_true, y_pred):
""" Compute the MSE """
num_splits = y_pred.shape[-1]
if num_splits == 4:
mu, _, _, _ = tf.split(y_pred, num_splits, axis=-1)
Expand Down Expand Up @@ -320,6 +325,20 @@ def predict(self, x, scaler=None, batch_size=None):
return y_out

def predict_ensemble(self, x, weight_locations, batch_size=None, scaler=None, num_outputs=1):
"""
Predicts outcomes using an ensemble of trained Keras models.
Args:
x (numpy.ndarray): Input data for predictions.
weight_locations (list of str): List containing paths to saved Keras model weights.
batch_size (int, optional): Batch size for inference. Default is None.
scaler (object, optional): Scaler object for preprocessing input data. Default is None.
num_outputs (int, optional): Number of output predictions. Default is 1.
Returns:
numpy.ndarray: Ensemble predictions for the input data.
"""

num_models = len(weight_locations)

# Initialize output_shape based on the first model's prediction
Expand Down

0 comments on commit c0dccdc

Please sign in to comment.