[shardformer] add tests to mistral (#5105)
* [shardformer] add tests to mistral fix fix
* fix fix fix fix fix
1 parent 2e04af1 · commit c8420cd
Showing 9 changed files with 282 additions and 0 deletions.
New file: Mistral model-zoo registration (79 additions & 0 deletions)
@@ -0,0 +1,79 @@
import torch
import transformers
from transformers import MistralConfig

from ..registry import ModelAttribute, model_zoo

# ===============================
# Register single-sentence Mistral
# ===============================


def data_gen():
    # Generated from the following code snippet:
    #
    # from transformers import AutoModelForCausalLM, AutoTokenizer
    # tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    # input = 'My favourite condiment is vinegar' (the last two words are repeated to satisfy the length requirement)
    # tokenized_input = tokenizer([input], return_tensors="pt")
    # input_ids = tokenized_input['input_ids']
    # attention_mask = tokenized_input['attention_mask']
    input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_lm():
    # causal LM data gen: the `labels` are the target tokens; since there is no padding, reuse `input_ids` as `labels`
    data = data_gen()
    data["labels"] = data["input_ids"].clone()
    return data


def data_gen_for_sequence_classification():
    # sequence classification data gen
    data = data_gen()
    data["labels"] = torch.tensor([1], dtype=torch.int64)
    return data


# define output transform function
output_transform_fn = lambda x: x

# define loss functions
loss_fn_for_mistral_model = lambda x: torch.nn.functional.mse_loss(
    x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn = lambda x: x.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()

config = MistralConfig(
    hidden_size=256,
    intermediate_size=256,
    num_attention_heads=64,
    num_hidden_layers=2,
    vocab_size=50258,
)
model_zoo.register(
    name="transformers_mistral",
    model_fn=lambda: transformers.MistralModel(config),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_mistral_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_mistral_for_causal_lm",
    model_fn=lambda: transformers.MistralForCausalLM(config),
    data_gen_fn=data_gen_for_lm,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_mistral_for_sequence_classification",
    model_fn=lambda: transformers.MistralForSequenceClassification(config),
    data_gen_fn=data_gen_for_sequence_classification,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_seq_classification,
    model_attribute=ModelAttribute(has_control_flow=True),
)
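With these entries registered, the shardformer test below looks them up by the "transformers_mistral" prefix. As a minimal smoke check of the registration (not part of this commit, and assuming the registry yields the same (model_fn, data_gen_fn, output_transform_fn, loss_fn, attribute) tuples that the test file unpacks), each entry can be built and run through one forward pass:

from tests.kit.model_zoo import model_zoo

for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in model_zoo.get_sub_registry(
    "transformers_mistral"
).items():
    model = model_fn()    # tiny Mistral built from the config above
    data = data_gen_fn()  # 8-token batch plus attention mask (and labels where needed)
    output = output_transform_fn(model(**data))
    print(name, float(loss_fn(output)))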
tests/test_shardformer/test_model/test_shard_mistral.py
148 changes: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
import os

import pytest
import torch

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
    build_model_from_hybrid_plugin,
    check_all_grad_tensors,
    check_loss,
    check_output_hidden_state,
    check_weight,
    get_grad_tensors_for_check,
    run_forward_backward_with_hybrid_plugin,
    unwrap_model,
)

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
        model_fn, loss_fn, test_config
    )

    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
    )

    stage_manager = booster.plugin.stage_manager
    tp_group = booster.plugin.tp_group

    # unwrap models
    mistral_model = unwrap_model(org_model, "MistralModel", "model")
    shard_mistral_model = unwrap_model(sharded_model, "MistralModel", "model")

    row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
    col_layer_for_check = ["layers[0].self_attn.o_proj"]

    # save gradient tensors for comparison between the original model and the sharded model before the optimizer step
    grads_to_check = {}
    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
        if test_config["precision"] == "fp32":
            atol, rtol = 5e-5, 1e-4
        else:
            atol, rtol = 5e-3, 5e-3
        row_layer_grads = get_grad_tensors_for_check(
            mistral_model, shard_mistral_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
        )
        col_layer_grads = get_grad_tensors_for_check(
            mistral_model, shard_mistral_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
        )
        grads_to_check.update(col_layer_grads)
        grads_to_check.update(row_layer_grads)

    # run the optimizer step
    org_optimizer.step()
    sharded_optimizer.step()

    # check last hidden state & loss
    if stage_manager is None or stage_manager.is_last_stage():
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-5, 1e-3
        else:
            atol, rtol = 5e-3, 5e-3

        if org_model.__class__.__name__ == "MistralModel":
            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

    # check weights
    if stage_manager is None or stage_manager.is_first_stage():
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-4, 1e-3
        else:
            atol, rtol = 5e-3, 5e-3
        check_weight(
            mistral_model, shard_mistral_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
        )

    # check grads
    check_all_grad_tensors(grads_to_check)

    torch.cuda.empty_cache()
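# Illustrative sketch, not part of the committed file: how a gradient that is sharded along
# `dim` across the tensor-parallel group could be gathered before being compared against the
# unsharded reference. The actual comparison logic lives in the `_utils` helpers imported above.
import torch.distributed as dist


def _gather_sharded_grad_sketch(shard: torch.Tensor, tp_group, dim: int) -> torch.Tensor:
    pieces = [torch.empty_like(shard) for _ in range(dist.get_world_size(tp_group))]
    dist.all_gather(pieces, shard.contiguous(), group=tp_group)
    return torch.cat(pieces, dim=dim)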
@parameterize(
    "test_config",
    [
        {
            "tp_size": 4,
            "pp_size": 1,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "fp32",
        },
        {
            "tp_size": 2,
            "pp_size": 1,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "fp32",
        },
        {
            "tp_size": 2,
            "pp_size": 1,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "zero_stage": 2,
            "precision": "fp16",
            "initial_scale": 1,
        },
    ],
)
def run_mistral_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_mistral")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()
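# Illustrative sketch, not part of the committed file: conceptually, a parameterize-style
# decorator like the one used above calls the wrapped function once per config dict
# (colossalai.testing.parameterize's real implementation may differ in its details).
import functools


def _parameterize_sketch(arg_name, values):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, arg_name: value})

        return wrapper

    return decorator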
def check_mistral(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_mistral_test()
@pytest.mark.skip("This test should be run with transformers >= 4.35.2.")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_mistral():
    spawn(check_mistral, 4)


if __name__ == "__main__":
    test_mistral()
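The unconditional skip above encodes a minimum-version requirement rather than a permanently disabled test. One possible alternative (an assumption, not something this commit does) is to gate the skip on the installed transformers version:

import pytest
import transformers
from packaging import version

requires_recent_transformers = pytest.mark.skipif(
    version.parse(transformers.__version__) < version.parse("4.35.2"),
    reason="Mistral sharding tests require transformers >= 4.35.2",
)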