Skip to content

Commit

Permalink
[shardformer] add tests to mistral (#5105)
Browse files Browse the repository at this point in the history
* [shardformer] add tests to mistral

fix

fix

* fix

fix

fix

fix

fix
  • Loading branch information
flybird11111 authored Nov 26, 2023
1 parent 2e04af1 commit c8420cd
Show file tree
Hide file tree
Showing 9 changed files with 282 additions and 0 deletions.
9 changes: 9 additions & 0 deletions colossalai/shardformer/policies/auto_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,15 @@ class PolicyLocation:
"transformers.models.falcon.modeling_falcon.FalconForQuestionAnswering": PolicyLocation(
file_name="falcon", class_name="FalconForQuestionAnsweringPolicy"
),
"transformers.models.mistral.modeling_mistral.MistralModel": PolicyLocation(
file_name="mistral", class_name="MistralModelPolicy"
),
"transformers.models.mistral.modeling_mistral.MistralForCausalLM": PolicyLocation(
file_name="mistral", class_name="MistralForCausalLMPolicy"
),
"transformers.models.mistral.modeling_mistral.MistralForSequenceClassification": PolicyLocation(
file_name="mistral", class_name="MistralForSequenceClassificationPolicy"
),
}

_INFER_POLICY_LIST = {
Expand Down
7 changes: 7 additions & 0 deletions colossalai/shardformer/policies/bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@


class BloomPolicy(Policy):
def __init__(self) -> None:
super().__init__()
import transformers
version_string = transformers.__version__
version_tuple = tuple(map(int, version_string.split('.')[:3]))
assert version_tuple <= (4, 33, 0), "The Bloom model should run on a transformers version not greater than 4.33.0."

def config_sanity_check(self):
pass

Expand Down
7 changes: 7 additions & 0 deletions colossalai/shardformer/policies/falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@


class FalconPolicy(Policy):
def __init__(self) -> None:
super().__init__()
import transformers
version_string = transformers.__version__
version_tuple = tuple(map(int, version_string.split('.')[:3]))
assert version_tuple <= (4, 33, 0), "The Falcon model should run on a transformers version not greater than 4.33.0."

def config_sanity_check(self):
pass

Expand Down
14 changes: 14 additions & 0 deletions colossalai/shardformer/policies/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ class MistralModelPolicy(MistralPolicy):
def __init__(self) -> None:
super().__init__()

def module_policy(self):
if self.pipeline_stage_manager:
warnings.warn("Mistral dosen't support pipeline parallelism now.")

return super().module_policy()

class MistralForCausalLMPolicy(MistralPolicy):
def module_policy(self):
from transformers import MistralForCausalLM
Expand All @@ -150,6 +156,10 @@ def module_policy(self):
]
)
}

if self.pipeline_stage_manager:
warnings.warn("Mistral dosen't support pipeline parallelism now.")

policy.update(new_item)

return policy
Expand All @@ -171,5 +181,9 @@ def module_policy(self):
]
)
}

if self.pipeline_stage_manager:
warnings.warn("Mistral dosen't support pipeline parallelism now.")

policy.update(new_item)
return policy
7 changes: 7 additions & 0 deletions colossalai/shardformer/policies/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@


class OPTPolicy(Policy):
def __init__(self) -> None:
super().__init__()
import transformers
version_string = transformers.__version__
version_tuple = tuple(map(int, version_string.split('.')[:3]))
assert version_tuple <= (4, 33, 0), "The OPT model should run on a transformers version not greater than 4.33.0."

def config_sanity_check(self):
pass

Expand Down
7 changes: 7 additions & 0 deletions colossalai/shardformer/policies/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@


class WhisperPolicy(Policy):
def __init__(self) -> None:
super().__init__()
import transformers
version_string = transformers.__version__
version_tuple = tuple(map(int, version_string.split('.')[:3]))
assert version_tuple <= (4, 33, 0), "The Whisper model should run on a transformers version not greater than 4.33.0."

def config_sanity_check(self):
pass

Expand Down
4 changes: 4 additions & 0 deletions tests/kit/model_zoo/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@
from .t5 import *
from .vit import *
from .whisper import *
try:
from .mistral import *
except ImportError:
print("This version of transformers doesn't support mistral.")
79 changes: 79 additions & 0 deletions tests/kit/model_zoo/transformers/mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

from transformers import MistralConfig

# ===============================
# Register single-sentence Mistral
# ===============================

def data_gen():
# Generated from following code snippet
#
# from transformers import AutoModelForCausalLM, AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
# input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement)
# tokenized_input = tokenizer([input], return_tensors="pt")
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)

def data_gen_for_lm():
# LM data gen
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
data = data_gen()
data["labels"] = data["input_ids"].clone()
return data

def data_gen_for_sequence_classification():
# sequence classification data gen
data = data_gen()
data["labels"] = torch.tensor([1], dtype=torch.int64)
return data

# define output transform function
output_transform_fn = lambda x: x

# define loss function
loss_fn_for_mistral_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn = lambda x: x.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()

config = MistralConfig(
hidden_size=256,
intermediate_size=256,
num_attention_heads=64,
num_hidden_layers=2,
vocab_size=50258
)

model_zoo.register(
name="transformers_mistral",
model_fn=lambda: transformers.MistralModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_mistral_model,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_mistral_for_casual_lm",
model_fn=lambda: transformers.MistralForCausalLM(config),
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_mistral_for_sequence_classification",
model_fn=lambda: transformers.MistralForSequenceClassification(config),
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_seq_classification,
model_attribute=ModelAttribute(has_control_flow=True),
)
148 changes: 148 additions & 0 deletions tests/test_shardformer/test_model/test_shard_mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import os

import pytest
import torch

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config
)

org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)

stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group

# unwrap model
mistral_model = unwrap_model(org_model, "MistralModel", "model")
shard_mistral_model = unwrap_model(sharded_model, "MistralModel", "model")

row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
col_layer_for_check = ["layers[0].self_attn.o_proj"]

# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config["precision"] == "fp32":
atol, rtol = 5e-5, 1e-4
else:
atol, rtol = 5e-3, 5e-3
row_layer_grads = get_grad_tensors_for_check(
mistral_model, shard_mistral_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
)
col_layer_grads = get_grad_tensors_for_check(
mistral_model, shard_mistral_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)

# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()

# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3

if org_model.__class__.__name__ == "MistralModel":
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config["precision"] == "fp32":
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_weight(
mistral_model, shard_mistral_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
)

# check grads
check_all_grad_tensors(grads_to_check)

torch.cuda.empty_cache()


@parameterize(
"test_config",
[
{
"tp_size": 4,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
],
)
def run_mistral_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_mistral")

for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()


def check_mistral(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_mistral_test()



@pytest.mark.skip("This test should be run on a version of transformers not less than 4.35.2.")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_mistral():
spawn(check_mistral, 4)


if __name__ == "__main__":
test_mistral()

0 comments on commit c8420cd

Please sign in to comment.