Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding save_jit option to save scripted jit module #103

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions lume_model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def process_torch_module(
key: str = "",
file_prefix: Union[str, os.PathLike] = "",
save_modules: bool = True,
save_jit: bool = True,
):
"""Optionally saves the given torch module to file and returns the filename.

Expand All @@ -60,18 +61,22 @@ def process_torch_module(
filepath_prefix, filename_prefix = os.path.split(file_prefix)
prefixes = [ele for ele in [filename_prefix, base_key] if not ele == ""]
filename = "{}.pt".format(key)
jit_filename = "{}.jit".format(key)
if prefixes:
filename = "_".join((*prefixes, filename))
filepath = os.path.join(filepath_prefix, filename)
if save_modules:
torch.save(module, filepath)
if save_jit:
torch.jit.script(module)
return filename

def recursive_serialize(
v: dict[str, Any],
base_key: str = "",
file_prefix: Union[str, os.PathLike] = "",
save_models: bool = True,
save_jit: bool = True,
):
"""Recursively performs custom serialization for the given object.

Expand All @@ -93,12 +98,12 @@ def recursive_serialize(
v[key] = recursive_serialize(value, key)
elif torch is not None and isinstance(value, torch.nn.Module):
v[key] = process_torch_module(value, base_key, key, file_prefix,
save_models)
save_models, save_jit)
elif isinstance(value, list) and torch is not None and any(
isinstance(ele, torch.nn.Module) for ele in value):
v[key] = [
process_torch_module(value[i], base_key, f"{key}_{i}", file_prefix,
save_models)
save_models, save_jit)
for i in range(len(value))
]
else:
Expand Down Expand Up @@ -137,6 +142,7 @@ def json_dumps(
base_key="",
file_prefix: Union[str, os.PathLike] = "",
save_models: bool = True,
save_jit: bool = True,
):
"""Serializes variables before dumping with json.

Expand All @@ -149,7 +155,7 @@ def json_dumps(
Returns:
JSON formatted string.
"""
v = recursive_serialize(v.model_dump(), base_key, file_prefix, save_models)
v = recursive_serialize(v.model_dump(), base_key, file_prefix, save_models, save_jit)
v = json.dumps(v)
return v

Expand Down Expand Up @@ -322,14 +328,15 @@ def yaml(
base_key: str = "",
file_prefix: str = "",
save_models: bool = False,
save_jit: bool = False,
) -> str:
"""Serializes the object and returns a YAML formatted string defining the model.

Args:
base_key: Base key for serialization.
file_prefix: Prefix for generated filenames.
save_models: Determines whether models are saved to file.

save_jit: Determines whether the structure of the model is saved as TorchScript.
Returns:
YAML formatted string defining the model.
"""
Expand All @@ -349,6 +356,7 @@ def dump(
file: Union[str, os.PathLike],
base_key: str = "",
save_models: bool = True,
save_jit: bool=True,
):
"""Returns and optionally saves YAML formatted string defining the model.

Expand All @@ -364,6 +372,7 @@ def dump(
base_key=base_key,
file_prefix=file_prefix,
save_models=save_models,
save_jit=save_jit,
)
)

Expand Down
8 changes: 7 additions & 1 deletion lume_model/models/torch_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,15 @@ def yaml(
base_key: str = "",
file_prefix: str = "",
save_models: bool = False,
save_jit: bool=False,
) -> str:
"""Serializes the object and returns a YAML formatted string defining the TorchModule instance.

Args:
base_key: Base key for serialization.
file_prefix: Prefix for generated filenames.
save_models: Determines whether models are saved to file.
save_jit: Determines whether the structure of the model is saved as TorchScript.

Returns:
YAML formatted string defining the TorchModule instance.
Expand All @@ -111,13 +113,14 @@ def yaml(
if k not in ["self", "args", "model"]:
d[k] = getattr(self, k)
output = json.loads(
json.dumps(recursive_serialize(d, base_key, file_prefix, save_models))
json.dumps(recursive_serialize(d, base_key, file_prefix, save_models, save_jit))
)
model_output = json.loads(
self._model.to_json(
base_key=base_key,
file_prefix=file_prefix,
save_models=save_models,
save_jit=save_jit,
)
)
output["model"] = model_output
Expand All @@ -131,13 +134,15 @@ def dump(
file: Union[str, os.PathLike],
save_models: bool = True,
base_key: str = "",
save_jit: bool = True,
):
"""Returns and optionally saves YAML formatted string defining the model.

Args:
file: File path to which the YAML formatted string and corresponding files are saved.
base_key: Base key for serialization.
save_models: Determines whether models are saved to file.
save_jit: Determines whether the model is also saved as a scripted module via PyTorch's just-in-time (JIT) compiler.
"""
file_prefix = os.path.splitext(file)[0]
with open(file, "w") as f:
Expand All @@ -146,6 +151,7 @@ def dump(
save_models=save_models,
base_key=base_key,
file_prefix=file_prefix,
save_jit=save_jit,
)
)

Expand Down