feat: dataclass args for accelerated MoE tuning (#390)
* feat: accelerated MoE dataclass and init

Signed-off-by: Will Johnson <[email protected]>

* fix: author's note

Signed-off-by: Will Johnson <[email protected]>

* feat: accelerated moe in acceleration framework

Signed-off-by: Will Johnson <[email protected]>

* feat: accelerated moe to sft_trainer

Signed-off-by: Will Johnson <[email protected]>

* feat: fmt, testing

Signed-off-by: Will Johnson <[email protected]>

* fix: rename accelerated moe to fast moe

Signed-off-by: Will Johnson <[email protected]>

* test: add testing for scatter moe on accel framework

Signed-off-by: Will Johnson <[email protected]>

* fix: model, dtype, assertions

Signed-off-by: Will Johnson <[email protected]>

* fix: post init check removed from FastMoe, experimental set to True

Signed-off-by: Will Johnson <[email protected]>

* fix: if non-iterable nested dataclass, still initialize

Signed-off-by: Will Johnson <[email protected]>

* test: add failing test for wrong ep_degree

Signed-off-by: Will Johnson <[email protected]>

* fix: actually expect failure

Signed-off-by: Will Johnson <[email protected]>

* test: make sure fast moe doesn't work with non-moe model

Signed-off-by: Will Johnson <[email protected]>

* fix: regex of new test

Signed-off-by: Will Johnson <[email protected]>

* comment: explain iterable unpacking

Signed-off-by: Will Johnson <[email protected]>

* docs: fast MOE in README

Signed-off-by: Will Johnson <[email protected]>

* docs: Add note for post-processing

Signed-off-by: Will Johnson <[email protected]>

* fix: Dockerfile

Signed-off-by: Will Johnson <[email protected]>

* test: fix params

Signed-off-by: Will Johnson <[email protected]>

* fix: file path

Signed-off-by: Will Johnson <[email protected]>

* fix: expand on docs, remove from Dockerfile, move iterable data to else statement

Signed-off-by: Will Johnson <[email protected]>

* lint

Signed-off-by: Will Johnson <[email protected]>

* fix: spelling

Signed-off-by: Will Johnson <[email protected]>

---------

Signed-off-by: Will Johnson <[email protected]>
willmj authored Jan 7, 2025
1 parent 3dc8ef7 commit 8851227
Showing 10 changed files with 270 additions and 12 deletions.
13 changes: 13 additions & 0 deletions README.md
@@ -744,6 +744,8 @@ The list of configurations for various `fms_acceleration` plugins:
- [attention_and_distributed_packing](./tuning/config/acceleration_configs/attention_and_distributed_packing.py):
- `--padding_free`: technique to process multiple examples in single batch without adding padding tokens that waste compute.
- `--multipack`: technique for *multi-gpu training* to balance out number of tokens processed in each device, to minimize waiting time.
- [fast_moe_config](./tuning/config/acceleration_configs/fast_moe.py) (experimental):
  - `--fast_moe`: trains MoE models with expert-parallel sharding, increasing throughput and decreasing memory usage.

Notes:
* `quantized_lora_config` requires that it be used along with LoRA tuning technique. See [LoRA tuning section](https://github.com/foundation-model-stack/fms-hf-tuning/tree/main?tab=readme-ov-file#lora-tuning-example) on the LoRA parameters to pass.
@@ -762,6 +764,17 @@ Notes:
* Notes on Multipack
- works only for *multi-gpu*.
- currently only includes the version of *multipack* optimized for linear attention implementations like *flash-attn*.
* Notes on Fast MoE
- `--fast_moe` is an integer value that sets the degree of expert parallel sharding (`ep_degree`); see the sketch below.
- `world_size` must be divisible by the `ep_degree`.
- Running fast moe modifies the model's state dict, so the saved checkpoint must be post-processed using [checkpoint utils](https://github.com/foundation-model-stack/fms-acceleration/blob/main/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py) before running inference (HF, vLLM, etc.).
- The typical use case for this script is to run:
```
python -m fms_acceleration_moe.utils.checkpoint_utils \
<checkpoint file> \
<output file> \
<original model>
```
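For illustration, a minimal sketch of how the `--fast_moe` flag maps onto the `FastMoe` dataclass, mirroring the unit tests added in this change; the assumption that the positional value fills `ep_degree` follows the mixed-type-list convention used by the other acceleration configs:
```python
# Minimal sketch: parse "--fast_moe" into the FastMoe dataclass
# (mirrors tests/acceleration/test_acceleration_dataclasses.py).
# Assumption: the positional integer fills ep_degree, per the
# mixed-type-list convention of the other acceleration configs.
import transformers

from tuning.config.acceleration_configs.fast_moe import FastMoe, FastMoeConfig

parser = transformers.HfArgumentParser(dataclass_types=FastMoeConfig)
(cfg,) = parser.parse_args_into_dataclasses(["--fast_moe", "2"])

assert isinstance(cfg.fast_moe, FastMoe)
print(cfg.fast_moe.ep_degree)  # expert-parallel sharding degree, expected 2
```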
Note: To pass the above flags via a JSON config, each of the flags expects the value to be a mixed type list, so the values must be a list. For example:
```json
8 changes: 8 additions & 0 deletions tests/acceleration/test_acceleration_dataclasses.py
@@ -28,6 +28,7 @@
MultiPack,
PaddingFree,
)
from tuning.config.acceleration_configs.fast_moe import FastMoe, FastMoeConfig
from tuning.config.acceleration_configs.fused_ops_and_kernels import (
FastKernelsConfig,
FusedLoraConfig,
@@ -88,6 +89,13 @@ def test_dataclass_parse_successfully():
)
assert isinstance(cfg.multipack, MultiPack)

# 5. Specifying "--fast_moe" will parse a FastMoe class
parser = transformers.HfArgumentParser(dataclass_types=FastMoeConfig)
(cfg,) = parser.parse_args_into_dataclasses(
["--fast_moe", "1"],
)
assert isinstance(cfg.fast_moe, FastMoe)


def test_two_dataclasses_parse_successfully_together():
"""Ensure that the two dataclasses can parse arguments successfully
159 changes: 156 additions & 3 deletions tests/acceleration/test_acceleration_framework.py
@@ -43,6 +43,7 @@
MultiPack,
PaddingFree,
)
from tuning.config.acceleration_configs.fast_moe import FastMoe, FastMoeConfig
from tuning.config.acceleration_configs.fused_ops_and_kernels import (
FastKernelsConfig,
FusedLoraConfig,
@@ -56,7 +57,7 @@
# for some reason the CI will raise an import error if we try to import
# these from tests.artifacts.testdata
TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join(
os.path.dirname(__file__), "../artifacts/testdata/twitter_complaints_json.json"
os.path.dirname(__file__), "../artifacts/testdata/json/twitter_complaints_json.json"
)
TWITTER_COMPLAINTS_TOKENIZED = os.path.join(
os.path.dirname(__file__),
@@ -87,6 +88,10 @@
# Third Party
from fms_acceleration_aadp import PaddingFreeAccelerationPlugin

if is_fms_accelerate_available(plugins="moe"):
# Third Party
from fms_acceleration_moe import ScatterMoEAccelerationPlugin


# There are more extensive unit tests in the
# https://github.com/foundation-model-stack/fms-acceleration
@@ -360,7 +365,7 @@ def test_framework_raises_due_to_invalid_arguments(
acceleration_configs_map,
ids=["bitsandbytes", "auto_gptq"],
)
def test_framework_intialized_properly_peft(
def test_framework_initialized_properly_peft(
quantized_lora_config, model_name_or_path, mock_and_spy
):
"""Ensure that specifying a properly configured acceleration dataclass
@@ -412,7 +417,7 @@ def test_framework_intialized_properly_peft(
"and foak plugins"
),
)
def test_framework_intialized_properly_foak():
def test_framework_initialized_properly_foak():
"""Ensure that specifying a properly configured acceleration dataclass
properly activates the framework plugin and runs the train successfully.
"""
@@ -477,6 +482,60 @@ def test_framework_intialized_properly_foak():
assert spy2["get_ready_for_train_calls"] == 1


@pytest.mark.skipif(
not is_fms_accelerate_available(plugins="moe"),
reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
)
def test_framework_initialized_properly_moe():
"""Ensure that specifying a properly configured acceleration dataclass
properly activates the framework plugin and runs the train successfully.
"""

with tempfile.TemporaryDirectory() as tempdir:

model_args = copy.deepcopy(MODEL_ARGS)
model_args.model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE"
model_args.torch_dtype = torch.bfloat16
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
train_args.save_strategy = "no"
train_args.bf16 = True
data_args = copy.deepcopy(DATA_ARGS)
data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT
data_args.response_template = "\n\n### Label:"
data_args.dataset_text_field = "output"

# initialize a config
moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))

# create mocked plugin class for spying
MockedPlugin1, spy = create_mock_plugin_class_and_spy(
"FastMoeMock", ScatterMoEAccelerationPlugin
)

# 1. mock a plugin class
# 2. register the mocked plugins
# 3. call sft_trainer.train
with build_framework_and_maybe_instantiate(
[
(["training.moe.scattermoe"], MockedPlugin1),
],
instantiate=False,
):
with instantiate_model_patcher():
sft_trainer.train(
model_args,
data_args,
train_args,
fast_moe_config=moe_config,
)

# spy inside the train to ensure that the moe plugin is called
assert spy["model_loader_calls"] == 1
assert spy["augmentation_calls"] == 0
assert spy["get_ready_for_train_calls"] == 1


@pytest.mark.skipif(
not is_fms_accelerate_available(plugins="aadp"),
reason="Only runs if fms-accelerate is installed along with \
@@ -661,6 +720,100 @@ def test_error_raised_with_fused_lora_enabled_without_quantized_argument():
)


@pytest.mark.skipif(
not is_fms_accelerate_available(plugins="moe"),
reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
)
def test_error_raised_with_undividable_fastmoe_argument():
"""
Ensure error is thrown when `--fast_moe` is passed and world_size
is not divisible by ep_degree
"""
with pytest.raises(
AssertionError, match="world size \\(1\\) not divisible by ep_size \\(3\\)"
):
with tempfile.TemporaryDirectory() as tempdir:

model_args = copy.deepcopy(MODEL_ARGS)
model_args.model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE"
model_args.torch_dtype = torch.bfloat16
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
train_args.save_strategy = "no"
train_args.bf16 = True
data_args = copy.deepcopy(DATA_ARGS)
data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT
data_args.response_template = "\n\n### Label:"
data_args.dataset_text_field = "output"

# initialize a config
moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=3))

# 1. mock a plugin class
# 2. register the mocked plugins
# 3. call sft_trainer.train
with build_framework_and_maybe_instantiate(
[
(["training.moe.scattermoe"], ScatterMoEAccelerationPlugin),
],
instantiate=False,
):
with instantiate_model_patcher():
sft_trainer.train(
model_args,
data_args,
train_args,
fast_moe_config=moe_config,
)


@pytest.mark.skipif(
not is_fms_accelerate_available(plugins="moe"),
reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
)
def test_error_raised_fast_moe_with_non_moe_model():
"""
Ensure error is thrown when `--fast_moe` is passed and model is not MoE
"""
with pytest.raises(
AttributeError,
match="'LlamaConfig' object has no attribute 'num_local_experts'",
):
with tempfile.TemporaryDirectory() as tempdir:

model_args = copy.deepcopy(MODEL_ARGS)
model_args.model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v0.3"
model_args.torch_dtype = torch.bfloat16
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
train_args.save_strategy = "no"
train_args.bf16 = True
data_args = copy.deepcopy(DATA_ARGS)
data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT
data_args.response_template = "\n\n### Label:"
data_args.dataset_text_field = "output"

# initialize a config
moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))

# 1. mock a plugin class
# 2. register the mocked plugins
# 3. call sft_trainer.train
with build_framework_and_maybe_instantiate(
[
(["training.moe.scattermoe"], ScatterMoEAccelerationPlugin),
],
instantiate=False,
):
with instantiate_model_patcher():
sft_trainer.train(
model_args,
data_args,
train_args,
fast_moe_config=moe_config,
)


@pytest.mark.skipif(
not is_fms_accelerate_available(plugins="foak"),
reason="Only runs if fms-accelerate is installed along with \
6 changes: 4 additions & 2 deletions tests/test_sft_trainer.py
@@ -362,6 +362,7 @@ def test_parse_arguments(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_copy)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert data_args.dataset_text_field == "output"
@@ -388,6 +389,7 @@ def test_parse_arguments_defaults(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_defaults)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert model_args.use_flash_attn is False
@@ -398,14 +400,14 @@ def test_parse_arguments_peft_method(job_config):
parser = sft_trainer.get_parser()
job_config_pt = copy.deepcopy(job_config)
job_config_pt["peft_method"] = "pt"
_, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
_, _, _, _, tune_config, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_pt
)
assert isinstance(tune_config, peft_config.PromptTuningConfig)

job_config_lora = copy.deepcopy(job_config)
job_config_lora["peft_method"] = "lora"
_, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
_, _, _, _, tune_config, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_lora
)
assert isinstance(tune_config, peft_config.LoraConfig)
1 change: 1 addition & 0 deletions tuning/config/acceleration_configs/__init__.py
@@ -15,5 +15,6 @@
# Local
from .acceleration_framework_config import AccelerationFrameworkConfig
from .attention_and_distributed_packing import AttentionAndDistributedPackingConfig
from .fast_moe import FastMoeConfig
from .fused_ops_and_kernels import FusedOpsAndKernelsConfig
from .quantized_lora_config import QuantizedLoraConfig
tuning/config/acceleration_configs/acceleration_framework_config.py
@@ -22,6 +22,7 @@

# Local
from .attention_and_distributed_packing import MultiPack, PaddingFree
from .fast_moe import FastMoe
from .fused_ops_and_kernels import FastKernelsConfig, FusedLoraConfig
from .quantized_lora_config import AutoGPTQLoraConfig, BNBQLoraConfig
from tuning.utils.import_utils import is_fms_accelerate_available
@@ -65,6 +66,7 @@ class AccelerationFrameworkConfig:
PACKAGE_PREFIX = "fms_acceleration_"

# each field will be a single-level use case dataclass

auto_gptq: Annotated[
AutoGPTQLoraConfig,
ConfigAnnotation(
@@ -89,6 +91,17 @@ class AccelerationFrameworkConfig:
),
] = None

fast_moe: Annotated[
FastMoe,
ConfigAnnotation(
path="training.moe",
key="scattermoe",
standalone=True,
experimental=True,
required_packages=["moe"],
),
] = None

fast_kernels: Annotated[
FastKernelsConfig,
ConfigAnnotation(
@@ -1,3 +1,17 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from dataclasses import dataclass

36 changes: 36 additions & 0 deletions tuning/config/acceleration_configs/fast_moe.py
@@ -0,0 +1,36 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from dataclasses import dataclass

# Local
from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass


@parsable_dataclass
@dataclass
class FastMoe:

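# degree of expert-parallel sharding; world_size must be divisible by this value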
ep_degree: int = 1


@dataclass
class FastMoeConfig:

fast_moe: FastMoe = None

def __post_init__(self):
# ensure nested dataclasses initialized
ensure_nested_dataclasses_initialized(self)
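
A minimal usage sketch for these dataclasses, mirroring how the new acceleration-framework tests construct them (direct construction only; nothing beyond what the diff above shows):
```python
# Sketch: construct the fast-MoE acceleration config directly, as the new
# tests in tests/acceleration/test_acceleration_framework.py do.
from tuning.config.acceleration_configs.fast_moe import FastMoe, FastMoeConfig

# ep_degree defaults to 1 (no expert-parallel sharding).
default_cfg = FastMoeConfig(fast_moe=FastMoe())
assert default_cfg.fast_moe.ep_degree == 1

# Shard experts across 2 devices; world_size must be divisible by ep_degree.
moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=2))
# This object is what gets passed as sft_trainer.train(..., fast_moe_config=moe_config).
```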
