Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: function name 'requires_agumentation' to 'requires_augmentation' #118

Merged
merged 4 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def requires_custom_loading(self):
return True

@property
def requires_agumentation(self):
def requires_augmentation(self):
return True

def augmentation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def requires_custom_loading(self):
return True

@property
def requires_agumentation(self):
def requires_augmentation(self):
# will skip the augmentation if _no_peft_model == True
return not self._no_peft_model

Expand Down
8 changes: 4 additions & 4 deletions plugins/accelerated-peft/tests/test_peft_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_configure_gptq_plugin():

# check flags and callbacks
assert framework.requires_custom_loading
assert framework.requires_agumentation
assert framework.requires_augmentation
assert len(framework.get_callbacks_and_ready_for_train()) == 0

# attempt to activate plugin with configuration pointing to wrong path
Expand Down Expand Up @@ -171,7 +171,7 @@ def test_configure_bnb_plugin():

# check flags and callbacks
assert framework.requires_custom_loading
assert framework.requires_agumentation
assert framework.requires_augmentation
assert len(framework.get_callbacks_and_ready_for_train()) == 0

# test valid combinatinos
Expand All @@ -187,7 +187,7 @@ def test_configure_bnb_plugin():
):
# check flags and callbacks
assert framework.requires_custom_loading
assert framework.requires_agumentation
assert framework.requires_augmentation
assert len(framework.get_callbacks_and_ready_for_train()) == 0

# test no_peft_model is true skips plugin.augmentation
Expand All @@ -202,7 +202,7 @@ def test_configure_bnb_plugin():
require_packages_check=False,
):
# check flags and callbacks
assert (not correct_value) == framework.requires_agumentation
assert (not correct_value) == framework.requires_augmentation

# attempt to activate plugin with configuration pointing to wrong path
# - raise with message that no plugins can be configured
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
assert self._pad_token_id is not None, "need to get pad token id"

@property
def requires_agumentation(self):
def requires_augmentation(self):
return True

def augmentation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, configurations: Dict[str, Dict]):
)

@property
def requires_agumentation(self):
def requires_augmentation(self):
return True

def augmentation(
Expand Down
2 changes: 1 addition & 1 deletion plugins/framework/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ model, (peft_config,) = framework.augmentation(
)
```

We also provide `framework.requires_agumentation` to check if augumentation is required by the plugins.
We also provide `framework.requires_augmentation` to check if augumentation is required by the plugins.

Finally pass the model to train:

Expand Down
8 changes: 4 additions & 4 deletions plugins/framework/src/fms_acceleration/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,10 @@ def augmentation(
x in model_archs for x in plugin.restricted_model_archs
):
raise ValueError(
f"Model architectures in '{model_archs}' are supported for '{plugin_name}'."
f"Model architectures in '{model_archs}' are not supported for '{plugin_name}'."
)

if plugin.requires_agumentation:
if plugin.requires_augmentation:
model, modifiable_args = plugin.augmentation(
model, train_args, modifiable_args=modifiable_args
)
Expand All @@ -214,8 +214,8 @@ def requires_custom_loading(self):
return len(self.plugins_require_custom_loading) > 0

@property
def requires_agumentation(self):
return any(x.requires_agumentation for _, x in self.active_plugins)
def requires_augmentation(self):
return any(x.requires_augmentation for _, x in self.active_plugins)

def get_callbacks_and_ready_for_train(
self, model: torch.nn.Module = None, accelerator: Accelerator = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def requires_custom_loading(self):
return False

@property
def requires_agumentation(self):
def requires_augmentation(self):
return False

def model_loader(self, model_name: str, **kwargs):
Expand Down
10 changes: 5 additions & 5 deletions plugins/framework/src/fms_acceleration/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def create_plugin_cls(
restricted_models: Set = None,
require_pkgs: Set = None,
requires_custom_loading: bool = False,
requires_agumentation: bool = False,
agumentation: Callable = None,
requires_augmentation: bool = False,
augmentation: Callable = None,
model_loader: Callable = None,
):
"helper function to create plugin class"
Expand All @@ -174,11 +174,11 @@ def create_plugin_cls(
"restricted_model_archs": restricted_models,
"require_packages": require_pkgs,
"requires_custom_loading": requires_custom_loading,
"requires_agumentation": requires_agumentation,
"requires_augmentation": requires_augmentation,
}

if agumentation is not None:
attributes["augmentation"] = agumentation
if augmentation is not None:
attributes["augmentation"] = augmentation

if model_loader is not None:
attributes["model_loader"] = model_loader
Expand Down
34 changes: 17 additions & 17 deletions plugins/framework/tests/test_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_model_with_no_config_raises():

# create model and (incomplete) plugin with requires_augmentation = True
model_no_config = torch.nn.Module() # empty model
incomplete_plugin = create_plugin_cls(requires_agumentation=True)
incomplete_plugin = create_plugin_cls(requires_augmentation=True)

# register and activate 1 incomplete plugin, and:
# 1. test correct plugin registration and activation.
Expand Down Expand Up @@ -104,13 +104,13 @@ def test_single_plugin():
empty_plugin = create_plugin_cls()
incomplete_plugin = create_plugin_cls(
restricted_models={"CausalLM"},
requires_agumentation=True,
requires_augmentation=True,
)
plugin = create_plugin_cls(
restricted_models={"CausalLM"},
requires_agumentation=True,
requires_augmentation=True,
requires_custom_loading=True,
agumentation=dummy_augmentation,
augmentation=dummy_augmentation,
model_loader=dummy_custom_loader,
)
train_args = None # dummy for now
Expand Down Expand Up @@ -175,32 +175,32 @@ def test_two_plugins():

model = create_noop_model_with_archs(archs=["CausalLM"])
incomp_plugin1 = create_plugin_cls(
restricted_models={"CausalLM"}, requires_agumentation=True
restricted_models={"CausalLM"}, requires_augmentation=True
)
incomp_plugin2 = create_plugin_cls(requires_agumentation=True)
incomp_plugin2 = create_plugin_cls(requires_augmentation=True)
incomp_plugin3 = create_plugin_cls(
class_name="PluginNoop2", requires_agumentation=True
class_name="PluginNoop2", requires_augmentation=True
)
plugin1 = create_plugin_cls(
restricted_models={"CausalLM"},
requires_agumentation=True,
requires_augmentation=True,
requires_custom_loading=True,
agumentation=dummy_augmentation,
augmentation=dummy_augmentation,
model_loader=dummy_custom_loader,
)
plugin2 = create_plugin_cls(
class_name="PluginNoop2",
restricted_models={"CausalLM"},
requires_agumentation=True,
requires_augmentation=True,
requires_custom_loading=True,
agumentation=dummy_augmentation,
augmentation=dummy_augmentation,
model_loader=dummy_custom_loader,
)
plugin3_no_loader = create_plugin_cls(
class_name="PluginNoop2",
restricted_models={"CausalLM"},
requires_agumentation=True,
agumentation=dummy_augmentation,
requires_augmentation=True,
augmentation=dummy_augmentation,
)
train_args = None # dummy for now

Expand Down Expand Up @@ -299,8 +299,8 @@ def _hook(
for class_name in ["PluginDEF", "PluginABC"]:
plugin = create_plugin_cls(
class_name=class_name,
requires_agumentation=True,
agumentation=hook_builder(act_order=plugin_activation_order),
requires_augmentation=True,
augmentation=hook_builder(act_order=plugin_activation_order),
)
plugins_to_be_installed.append((class_name, plugin))

Expand All @@ -319,8 +319,8 @@ def test_plugin_registration_combination_logic():

plugin = create_plugin_cls(
restricted_models={"CausalLM"},
requires_agumentation=True,
agumentation=dummy_augmentation,
requires_augmentation=True,
augmentation=dummy_augmentation,
)

configuration_contents = {"existing1": {"key1": 1}, "existing2": {"key1": 1}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __init__(self, configurations: Dict[str, Dict]):
)

@property
def requires_agumentation(self):
def requires_augmentation(self):
return True

def augmentation(
Expand Down
2 changes: 1 addition & 1 deletion plugins/fused-ops-and-kernels/tests/test_foak_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_configure_gptq_foak_plugin():

# check flags and callbacks
assert framework.requires_custom_loading is False
assert framework.requires_agumentation
assert framework.requires_augmentation
assert len(framework.get_callbacks_and_ready_for_train()) == 0

# attempt to activate plugin with configuration pointing to wrong path
Expand Down
Loading