auto-lint code
ivy-dev-bot committed Sep 20, 2024
1 parent 08ea81a commit 1b022f2
Showing 2 changed files with 32 additions and 21 deletions.
5 changes: 1 addition & 4 deletions ivy/functional/backends/tensorflow/module.py
@@ -4,7 +4,6 @@
 import os
 import tensorflow as tf
 import keras
-import numpy as np
 import functools
 from tensorflow.python.util import nest
 from typing import (
@@ -23,7 +22,7 @@
 from packaging.version import parse
 
 if TYPE_CHECKING:
-    import torch.nn as nn
+    pass
 
 
 if keras.__version__ >= "3.0.0":
@@ -32,8 +31,6 @@
 KerasVariable = tf.Variable
 
 
-
-
 def get_assignment_dict():
     # Traverse the call stack
     lhs = None
48 changes: 31 additions & 17 deletions ivy/stateful/utilities.py
@@ -15,7 +15,14 @@
 def _is_submodule(obj, kw):
     cls_str = {
         "torch": ("torch.nn.modules.module.Module",),
-        "keras": ("keras.engine.training.Model", "tf_keras.src.engine.training.Model", "keras.src.models.model.Model", "tf_keras.src.engine.base_layer.Layer", "keras.src.engine.base_layer.Layer", "keras.src.layers.layer.Layer"),
+        "keras": (
+            "keras.engine.training.Model",
+            "tf_keras.src.engine.training.Model",
+            "keras.src.models.model.Model",
+            "tf_keras.src.engine.base_layer.Layer",
+            "keras.src.engine.base_layer.Layer",
+            "keras.src.layers.layer.Layer",
+        ),
         "flax": ("flax.nnx.nnx.module.Module",),
     }[kw]
     try:
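The rest of `_is_submodule` is collapsed in this view. As a rough sketch of what the truncated body plausibly does (an assumption, not the repository's verbatim code), the helper walks the object's class MRO and matches fully qualified class names against the tuple selected by `kw`:

```python
# Hedged reconstruction of _is_submodule, for illustration only.
def _is_submodule(obj, kw):
    cls_str = {
        "torch": ("torch.nn.modules.module.Module",),
        "flax": ("flax.nnx.nnx.module.Module",),
    }[kw]
    try:
        # Match any ancestor class by its fully qualified name.
        return any(
            f"{cls.__module__}.{cls.__name__}" in cls_str
            for cls in type(obj).__mro__
        )
    except TypeError:
        return False
```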
@@ -53,14 +60,16 @@ def _retrive_layer(model, key):
 
 
 def _sync_models_torch_and_jax(model1: "nn.Module", model2: "FlaxModel"):
-    """
-    Synchronizes the parameters and buffers of the original and the translated model.
+    """Synchronizes the parameters and buffers of the original and the
+    translated model.
 
     Args:
     ----
         model1 (torch.nn.Module): The original PyTorch model.
         model2 (FlaxModel): The translated model, i.e. the ivy.Module converted to a flax.nnx.Module.
 
     Returns:
     -------
         None
     """

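The body of this helper is collapsed here. For intuition, a minimal sketch of the copy step it performs, assuming the translated `FlaxModel` mirrors torch's `named_parameters()`/`named_buffers()` API and that each Flax parameter stores its array in a `.value` attribute (both assumptions of this sketch):

```python
import jax.numpy as jnp


def _sync_models_torch_and_jax_sketch(model1, model2):
    # Copy every torch parameter into the matching Flax parameter.
    flax_params = dict(model2.named_parameters())
    for name, p_torch in model1.named_parameters():
        flax_params[name].value = jnp.asarray(p_torch.detach().cpu().numpy())
    # Buffers (e.g. BatchNorm running statistics) are copied the same way.
    flax_buffers = dict(model2.named_buffers())
    for name, b_torch in model1.named_buffers():
        flax_buffers[name].value = jnp.asarray(b_torch.detach().cpu().numpy())
```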
@@ -258,14 +267,16 @@ def _maybe_update_flax_layer_weights(layer, weight_name, new_weight):
 
 
 def _sync_models_torch_and_tf(model1: "nn.Module", model2: "KerasModel"):
-    """
-    Synchronizes the parameters and buffers of the original and the translated model.
+    """Synchronizes the parameters and buffers of the original and the
+    translated model.
 
     Args:
     ----
         model1 (torch.nn.Module): The original PyTorch model.
         model2 (KerasModel): The translated model, i.e. the ivy.Module converted to a keras.Model.
 
     Returns:
     -------
         None
     """

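The TensorFlow-side equivalent presumably assigns into `tf.Variable`s rather than Flax params. A hedged sketch under the same torch-style-API assumption:

```python
def _sync_models_torch_and_tf_sketch(model1, model2):
    # Assign each torch parameter/buffer into the matching tf.Variable.
    tf_params = dict(model2.named_parameters())
    for name, p_torch in model1.named_parameters():
        tf_params[name].assign(p_torch.detach().cpu().numpy())
    tf_buffers = dict(model2.named_buffers())
    for name, b_torch in model1.named_buffers():
        tf_buffers[name].assign(b_torch.detach().cpu().numpy())
```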
Expand Down Expand Up @@ -470,9 +481,8 @@ def _maybe_update_keras_layer_weights(layer, weight_name, new_weight):
def sync_models_torch_and_tf(
model_pt: "nn.Module", model_tf: Union["keras.Model", "KerasModel"]
):
"""
Synchronizes the weights and buffers between a PyTorch model (`torch.nn.Module`)
and a TensorFlow model (`keras.Model`).
"""Synchronizes the weights and buffers between a PyTorch model
(`torch.nn.Module`) and a TensorFlow model (`keras.Model`).
This function ensures that both models have identical parameters and buffers by
iterating through their submodules and synchronizing them. The TensorFlow model
@@ -481,15 +491,18 @@ def _compute_module_dict_tf(model, prefix=""):
     including `named_parameters()` and `named_buffers()`.
 
     Args:
     ----
         model_pt (torch.nn.Module): The PyTorch model to synchronize from.
         model_tf (keras.Model): The TensorFlow model to synchronize to, with submodules
             inheriting from the custom `KerasModel`/`KerasLayer` class.
 
     Returns:
     -------
         None
 
     Example:
     -------
     ```python
     import torch.nn as nn
     import keras
@@ -554,7 +567,7 @@ def _compute_module_dict_tf(model, prefix=""):
         return _module_dict
 
     try:
-        import torch
+        pass
     except ModuleNotFoundError as exc:
         raise ModuleNotFoundError(
             "`torch` was not found installed on your system. Please proceed "
@@ -569,7 +582,6 @@ def _compute_module_dict_tf(model, prefix=""):
             "to install it and restart your interpreter to see the changes."
         ) from exc
 
-
     if hasattr(model_tf, "named_parameters"):
         _sync_models_torch_and_tf(model_pt, model_tf)
     else:
@@ -587,9 +599,8 @@ def _compute_module_dict_tf(model, prefix=""):
 def sync_models_torch_and_jax(
     model_pt: "nn.Module", model_jax: Union["nnx.Module", "FlaxModel"]
 ):
-    """
-    Synchronizes the weights and buffers between a PyTorch model (`torch.nn.Module`)
-    and a Flax model (`flax.nnx.Module`).
+    """Synchronizes the weights and buffers between a PyTorch model
+    (`torch.nn.Module`) and a Flax model (`flax.nnx.Module`).
 
     This function ensures both models have identical parameters and buffers by
     iterating through their submodules and synchronizing them. The Flax model must
@@ -598,13 +609,17 @@
     including `named_parameters()` and `named_buffers()`.
 
     Args:
     ----
         model_pt (torch.nn.Module): The PyTorch model to synchronize from.
         model_jax (flax.nnx.Module): The Flax model to synchronize to, with submodules
             inheriting from the custom `FlaxModel` class.
 
     Returns:
     -------
         None
 
     Example:
     -------
     ```python
     import torch.nn as nn
     import jax.numpy as jnp
@@ -722,16 +737,15 @@ def sync_models(
     original_model: "nn.Module",
     translated_model: Union["keras.Model", "KerasModel", "nnx.Module", "FlaxModel"],
 ):
-    """
-    Synchronizes the weights and buffers between a native PyTorch model (`torch.nn.Module`)
-    and it's translated version in TensorFlow or Flax.
+    """Synchronizes the weights and buffers between a native PyTorch model
+    (`torch.nn.Module`) and its translated version in TensorFlow or Flax.
 
     Args:
     ----
         original_model (torch.nn.Module): The PyTorch model to synchronize from.
         translated_model (tf.keras.Model or nnx.Module): The target model to synchronize to,
             either a TensorFlow or Flax model.
     """
 
     if not _is_submodule(original_model, "torch"):
         raise ivy.utils.exceptions.IvyException(
             "sync_models expected an instance of `nn.Module` as the first argument. got {}".format(
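The remainder of `sync_models` is collapsed above. Reconstructing from the visible fragments, the function validates the source model and then dispatches to the framework-specific sync; the branch structure and second error message below are assumptions of this sketch, not the repository's verbatim code:

```python
def sync_models_sketch(original_model, translated_model):
    if not _is_submodule(original_model, "torch"):
        raise ivy.utils.exceptions.IvyException(
            "sync_models expected an instance of `nn.Module` as the first "
            "argument. got {}".format(type(original_model))
        )
    # Route on the translated model's framework, mirroring _is_submodule's table.
    if _is_submodule(translated_model, "keras"):
        sync_models_torch_and_tf(original_model, translated_model)
    elif _is_submodule(translated_model, "flax"):
        sync_models_torch_and_jax(original_model, translated_model)
    else:
        raise ivy.utils.exceptions.IvyException(
            "unsupported translated model type: {}".format(type(translated_model))
        )
```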
