diff --git a/quantus/helpers/utils.py b/quantus/helpers/utils.py index 50645979..756211d0 100644 --- a/quantus/helpers/utils.py +++ b/quantus/helpers/utils.py @@ -382,6 +382,8 @@ def get_wrapped_model( model: ModelInterface A wrapped ModelInterface model. """ + if isinstance(model, ModelInterface): + return model if util.find_spec("tensorflow"): if isinstance(model, tf.keras.Model): return TensorFlowModel(