From 614e7a807753bd4256c6787864d5dbf4cb82dcb7 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Fri, 3 Jan 2025 15:56:13 +0000
Subject: [PATCH] move filter_mp functions out

Signed-off-by: Yu Chin Fabian Lim
---
 .../framework_plugin_fast_kernels.py          |  3 +--
 .../fms_acceleration_foak/models/granite.py   |  2 +-
 .../src/fms_acceleration_foak/models/llama.py |  2 +-
 .../fms_acceleration_foak/models/mistral.py   |  2 +-
 .../src/fms_acceleration_foak/models/utils.py | 20 ++-----------------
 .../src/fms_acceleration_foak/utils.py        | 20 +++++++++++++++++++
 .../tests/test_model_utils.py                 |  2 +-
 7 files changed, 27 insertions(+), 24 deletions(-)

diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
index 7948a98c..df21fd5c 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
@@ -23,8 +23,7 @@
 import torch
 
 # Local
-from .models.utils import filter_mp_rules
-from .utils import lora_adapters_switch_ddp_from_fsdp
+from .utils import filter_mp_rules, lora_adapters_switch_ddp_from_fsdp
 
 
 # consider rewriting register_foak_model_patch_rules into something
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py
index e4b58572..ee2f4206 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py
@@ -30,12 +30,12 @@
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from ..utils import filter_mp_rules
 from .utils import (
     KEY_MLP,
     KEY_O,
     KEY_QKV,
     build_lora_fused_ops,
-    filter_mp_rules,
     get_hidden_activation_fn_key,
     trigger_fused_ops,
 )
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py
index 94fab82f..ad2df4a7 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py
@@ -36,12 +36,12 @@
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from ..utils import filter_mp_rules
 from .utils import (
     KEY_MLP,
     KEY_O,
     KEY_QKV,
     build_lora_fused_ops,
-    filter_mp_rules,
     get_hidden_activation_fn_key,
     trigger_fused_ops,
 )
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py
index 64e65274..56dc48f0 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py
@@ -36,12 +36,12 @@
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from ..utils import filter_mp_rules
 from .utils import (
     KEY_MLP,
     KEY_O,
     KEY_QKV,
     build_lora_fused_ops,
-    filter_mp_rules,
     get_hidden_activation_fn_key,
     trigger_fused_ops,
 )
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py
index 2236f38d..375a2a6e 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py
@@ -1,10 +1,10 @@
 # Standard
 from functools import partial
-from typing import Callable, List, Set, Type
+from typing import Callable, List, Type
 import os
 
 # Third Party
-from fms_acceleration.model_patcher import ModelPatcherRule, ModelPatcherTrigger
+from fms_acceleration.model_patcher import ModelPatcherTrigger
 from transformers import PretrainedConfig
 import torch
 
@@ -203,22 +203,6 @@
     return isinstance(module, attn_cls) and all(_is_loralayer(x) for x in _mods)
 
 
-# helper function to filter rules
-def filter_mp_rules(
-    rules: List[ModelPatcherRule],
-    filter_endswith: Set[str],
-    drop: bool = False,
-):
-    if drop:
-        # this means if any of the filter terms appear, we drop
-        return [
-            r for r in rules if not any(r.rule_id.endswith(x) for x in filter_endswith)
-        ]
-
-    # this means if any if the filter terms appear, we keep
-    return [r for r in rules if any(r.rule_id.endswith(x) for x in filter_endswith)]
-
-
 # helper function to get the hidden activation function str
 def get_hidden_activation_fn_key(config: PretrainedConfig):
     for key in KEY_HIDDEN_ACTIVATIONS:
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/utils.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/utils.py
index 224f4975..7a9a8049 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/utils.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/utils.py
@@ -12,8 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Standard
+from typing import List, Set
+
 # Third Party
 from accelerate.utils import set_module_tensor_to_device
+from fms_acceleration.model_patcher import ModelPatcherRule
 from transformers.modeling_utils import is_fsdp_enabled
 import torch
 import torch.distributed as dist
@@ -74,3 +78,19 @@
         # - this has to be done after all weight replacement happens
         A.weight.register_hook(_all_reduce_hook)
         B.weight.register_hook(_all_reduce_hook)
+
+
+# helper function to filter rules
+def filter_mp_rules(
+    rules: List[ModelPatcherRule],
+    filter_endswith: Set[str],
+    drop: bool = False,
+):
+    if drop:
+        # this means if any of the filter terms appear, we drop
+        return [
+            r for r in rules if not any(r.rule_id.endswith(x) for x in filter_endswith)
+        ]
+
+    # this means if any of the filter terms appear, we keep
+    return [r for r in rules if any(r.rule_id.endswith(x) for x in filter_endswith)]
diff --git a/plugins/fused-ops-and-kernels/tests/test_model_utils.py b/plugins/fused-ops-and-kernels/tests/test_model_utils.py
index 1e8a53ef..d55aff05 100644
--- a/plugins/fused-ops-and-kernels/tests/test_model_utils.py
+++ b/plugins/fused-ops-and-kernels/tests/test_model_utils.py
@@ -2,7 +2,7 @@
 from fms_acceleration.model_patcher import ModelPatcherRule
 
 # First Party
-from fms_acceleration_foak.models.utils import filter_mp_rules
+from fms_acceleration_foak.utils import filter_mp_rules
 
 
 def test_filter_mp_rules():
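
For reference, below is a minimal usage sketch of filter_mp_rules at its new import path, fms_acceleration_foak.utils. It assumes the fused-ops-and-kernels plugin is installed; the _Rule stand-in class and the example rule ids are hypothetical and only stand in for real ModelPatcherRule objects, of which filter_mp_rules reads nothing but the rule_id attribute.

# Sketch only: exercises filter_mp_rules from its new location.
# _Rule and the rule ids below are hypothetical stand-ins for
# ModelPatcherRule instances; only rule_id is inspected by the helper.
from dataclasses import dataclass

from fms_acceleration_foak.utils import filter_mp_rules


@dataclass
class _Rule:
    rule_id: str


rules = [
    _Rule("granite-rms"),
    _Rule("granite-fused-lora-qkv"),
    _Rule("granite-rope"),
]

# default behaviour: keep only rules whose rule_id ends with a filter term
kept = filter_mp_rules(rules, filter_endswith={"qkv"})
assert [r.rule_id for r in kept] == ["granite-fused-lora-qkv"]

# drop=True inverts the filter: rules matching any term are removed
remaining = filter_mp_rules(rules, filter_endswith={"qkv"}, drop=True)
assert [r.rule_id for r in remaining] == ["granite-rms", "granite-rope"]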