diff --git a/lume_model/base.py b/lume_model/base.py index 513bebc..58a9106 100644 --- a/lume_model/base.py +++ b/lume_model/base.py @@ -43,6 +43,7 @@ def process_torch_module( key: str = "", file_prefix: Union[str, os.PathLike] = "", save_modules: bool = True, + save_jit: bool = True, ): """Optionally saves the given torch module to file and returns the filename. @@ -60,11 +61,14 @@ def process_torch_module( filepath_prefix, filename_prefix = os.path.split(file_prefix) prefixes = [ele for ele in [filename_prefix, base_key] if not ele == ""] filename = "{}.pt".format(key) + jit_filename = "{}.jit".format(key) if prefixes: filename = "_".join((*prefixes, filename)) filepath = os.path.join(filepath_prefix, filename) if save_modules: torch.save(module, filepath) + if save_jit: + torch.jit.script(module) return filename def recursive_serialize( @@ -72,6 +76,7 @@ def recursive_serialize( base_key: str = "", file_prefix: Union[str, os.PathLike] = "", save_models: bool = True, + save_jit: bool = True, ): """Recursively performs custom serialization for the given object. @@ -93,12 +98,12 @@ def recursive_serialize( v[key] = recursive_serialize(value, key) elif torch is not None and isinstance(value, torch.nn.Module): v[key] = process_torch_module(value, base_key, key, file_prefix, - save_models) + save_models, save_jit) elif isinstance(value, list) and torch is not None and any( isinstance(ele, torch.nn.Module) for ele in value): v[key] = [ process_torch_module(value[i], base_key, f"{key}_{i}", file_prefix, - save_models) + save_models, save_jit) for i in range(len(value)) ] else: @@ -137,6 +142,7 @@ def json_dumps( base_key="", file_prefix: Union[str, os.PathLike] = "", save_models: bool = True, + save_jit: bool = True, ): """Serializes variables before dumping with json. @@ -149,7 +155,7 @@ def json_dumps( Returns: JSON formatted string. """ - v = recursive_serialize(v.model_dump(), base_key, file_prefix, save_models) + v = recursive_serialize(v.model_dump(), base_key, file_prefix, save_models, save_jit) v = json.dumps(v) return v @@ -322,6 +328,7 @@ def yaml( base_key: str = "", file_prefix: str = "", save_models: bool = False, + save_jit: bool = False, ) -> str: """Serializes the object and returns a YAML formatted string defining the model. @@ -329,7 +336,7 @@ def yaml( base_key: Base key for serialization. file_prefix: Prefix for generated filenames. save_models: Determines whether models are saved to file. - + save_jit: Determines whether the structure of the model is saved as TorchScript Returns: YAML formatted string defining the model. """ @@ -349,6 +356,7 @@ def dump( file: Union[str, os.PathLike], base_key: str = "", save_models: bool = True, + save_jit: bool=True, ): """Returns and optionally saves YAML formatted string defining the model. @@ -364,6 +372,7 @@ def dump( base_key=base_key, file_prefix=file_prefix, save_models=save_models, + save_jit=save_jit, ) ) diff --git a/lume_model/models/torch_module.py b/lume_model/models/torch_module.py index 507f9ef..180d769 100644 --- a/lume_model/models/torch_module.py +++ b/lume_model/models/torch_module.py @@ -95,6 +95,7 @@ def yaml( base_key: str = "", file_prefix: str = "", save_models: bool = False, + save_jit: bool=False, ) -> str: """Serializes the object and returns a YAML formatted string defining the TorchModule instance. @@ -102,6 +103,7 @@ def yaml( base_key: Base key for serialization. file_prefix: Prefix for generated filenames. save_models: Determines whether models are saved to file. + save_jit: Determines whether the structure of the model is saved as TorchScript Returns: YAML formatted string defining the TorchModule instance. @@ -111,13 +113,14 @@ def yaml( if k not in ["self", "args", "model"]: d[k] = getattr(self, k) output = json.loads( - json.dumps(recursive_serialize(d, base_key, file_prefix, save_models)) + json.dumps(recursive_serialize(d, base_key, file_prefix, save_models, save_jit)) ) model_output = json.loads( self._model.to_json( base_key=base_key, file_prefix=file_prefix, save_models=save_models, + save_jit=save_jit, ) ) output["model"] = model_output @@ -131,6 +134,7 @@ def dump( file: Union[str, os.PathLike], save_models: bool = True, base_key: str = "", + save_jit: bool = True, ): """Returns and optionally saves YAML formatted string defining the model. @@ -138,6 +142,7 @@ def dump( file: File path to which the YAML formatted string and corresponding files are saved. base_key: Base key for serialization. save_models: Determines whether models are saved to file. + save_jit : Whether the model is saved using just in time pytorch method """ file_prefix = os.path.splitext(file)[0] with open(file, "w") as f: @@ -146,6 +151,7 @@ def dump( save_models=save_models, base_key=base_key, file_prefix=file_prefix, + save_jit=save_jit, ) )