Skip to content

Commit

Permalink
move filter_mp functions out
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
  • Loading branch information
fabianlim committed Jan 3, 2025
1 parent 83ae096 commit 614e7a8
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
import torch

# Local
from .models.utils import filter_mp_rules
from .utils import lora_adapters_switch_ddp_from_fsdp
from .utils import filter_mp_rules, lora_adapters_switch_ddp_from_fsdp


# consider rewriting register_foak_model_patch_rules into something
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from ..utils import filter_mp_rules
from .utils import (
KEY_MLP,
KEY_O,
KEY_QKV,
build_lora_fused_ops,
filter_mp_rules,
get_hidden_activation_fn_key,
trigger_fused_ops,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from ..utils import filter_mp_rules
from .utils import (
KEY_MLP,
KEY_O,
KEY_QKV,
build_lora_fused_ops,
filter_mp_rules,
get_hidden_activation_fn_key,
trigger_fused_ops,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from ..utils import filter_mp_rules
from .utils import (
KEY_MLP,
KEY_O,
KEY_QKV,
build_lora_fused_ops,
filter_mp_rules,
get_hidden_activation_fn_key,
trigger_fused_ops,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Standard
from functools import partial
from typing import Callable, List, Set, Type
from typing import Callable, List, Type
import os

# Third Party
from fms_acceleration.model_patcher import ModelPatcherRule, ModelPatcherTrigger
from fms_acceleration.model_patcher import ModelPatcherTrigger
from transformers import PretrainedConfig
import torch

Expand Down Expand Up @@ -203,22 +203,6 @@ def trigger_fused_ops(
return isinstance(module, attn_cls) and all(_is_loralayer(x) for x in _mods)


# helper function to filter rules
def filter_mp_rules(
rules: List[ModelPatcherRule],
filter_endswith: Set[str],
drop: bool = False,
):
if drop:
# this means if any of the filter terms appear, we drop
return [
r for r in rules if not any(r.rule_id.endswith(x) for x in filter_endswith)
]

# this means if any if the filter terms appear, we keep
return [r for r in rules if any(r.rule_id.endswith(x) for x in filter_endswith)]


# helper function to get the hidden activation function str
def get_hidden_activation_fn_key(config: PretrainedConfig):
for key in KEY_HIDDEN_ACTIVATIONS:
Expand Down
20 changes: 20 additions & 0 deletions plugins/fused-ops-and-kernels/src/fms_acceleration_foak/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from typing import List, Set

# Third Party
from accelerate.utils import set_module_tensor_to_device
from fms_acceleration.model_patcher import ModelPatcherRule
from transformers.modeling_utils import is_fsdp_enabled
import torch
import torch.distributed as dist
Expand Down Expand Up @@ -74,3 +78,19 @@ def _all_reduce_hook(grad):
# - this has to be done after all weight replacement happens
A.weight.register_hook(_all_reduce_hook)
B.weight.register_hook(_all_reduce_hook)


# helper function to filter rules
def filter_mp_rules(
rules: List[ModelPatcherRule],
filter_endswith: Set[str],
drop: bool = False,
):
if drop:
# this means if any of the filter terms appear, we drop
return [
r for r in rules if not any(r.rule_id.endswith(x) for x in filter_endswith)
]

# this means if any if the filter terms appear, we keep
return [r for r in rules if any(r.rule_id.endswith(x) for x in filter_endswith)]
2 changes: 1 addition & 1 deletion plugins/fused-ops-and-kernels/tests/test_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from fms_acceleration.model_patcher import ModelPatcherRule

# First Party
from fms_acceleration_foak.models.utils import filter_mp_rules
from fms_acceleration_foak.utils import filter_mp_rules


def test_filter_mp_rules():
Expand Down

0 comments on commit 614e7a8

Please sign in to comment.