diff --git a/icenet/model/predict.py b/icenet/model/predict.py index 49f0dab..6aeec42 100644 --- a/icenet/model/predict.py +++ b/icenet/model/predict.py @@ -6,11 +6,11 @@ import pandas as pd import tensorflow as tf -import icenet.model.models as models from icenet.data.loader import save_sample from icenet.data.dataset import IceNetDataSet from icenet.model.cli import predict_args +from icenet.model.networks.tensorflow import unet_batchnorm """ @@ -22,7 +22,7 @@ def predict_forecast( network_name: object, dataset_name: object = None, legacy_rounding: bool = False, - model_func: callable = models.unet_batchnorm, + model_func: callable = unet_batchnorm, n_filters_factor: float = 1 / 8, network_folder: object = None, output_folder: object = None, @@ -56,7 +56,8 @@ def predict_forecast( dataset_name = dataset_name if dataset_name else ds.identifier network_path = os.path.join( - network_folder, "{}.network_{}.{}.h5".format(network_name, dataset_name, + network_folder, "{}.network_{}.{}.h5".format(network_name, + dataset_name, seed)) logging.info("Loading model from {}...".format(network_path))