Merge branch 'master' into gma/add_autotp_workflow
loadams authored Jan 18, 2024
2 parents f11a434 + 3110c38 commit 0ae3bdb
Showing 16 changed files with 337 additions and 60 deletions.
2 changes: 1 addition & 1 deletion accelerator/real_accelerator.py
@@ -98,7 +98,7 @@ def get_accelerator():
except ImportError as e:
raise ValueError(
f"HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.")
elif is_current_accelerator_supported():
elif accelerator_name not in SUPPORTED_ACCELERATOR_LIST:
raise ValueError(f'DS_ACCELERATOR must be one of {SUPPORTED_ACCELERATOR_LIST}. '
f'Value "{accelerator_name}" is not supported')
ds_set_method = "override"
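With this change, an explicit DS_ACCELERATOR override is validated directly against SUPPORTED_ACCELERATOR_LIST. A hedged sketch of the resulting behavior; "tpu" is only an example of a name outside the supported list:

```python
import os

os.environ["DS_ACCELERATOR"] = "tpu"  # not a supported accelerator name

from deepspeed.accelerator import get_accelerator

try:
    get_accelerator()
except ValueError as err:
    print(err)  # DS_ACCELERATOR must be one of [...]. Value "tpu" is not supported
```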
2 changes: 1 addition & 1 deletion deepspeed/inference/v2/engine_factory.py
@@ -116,7 +116,7 @@ def build_hf_engine(path: str,
policy = MixtralPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "falcon":
policy = FalconPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "phi-msft":
elif model_config.model_type == "phi":
policy = PhiPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "qwen":
policy = QwenPolicy(model_config, checkpoint_engine=checkpoint_engine)
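The factory now keys off the upstream Hugging Face model_type string "phi" instead of the legacy "phi-msft". A hedged way to check which string a checkpoint reports (assumes a transformers release with in-tree Phi support and access to the Hub):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("microsoft/phi-2")
print(config.model_type)  # expected "phi" with the in-tree implementation
```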
75 changes: 34 additions & 41 deletions deepspeed/inference/v2/model_implementations/phi/containers.py
@@ -11,41 +11,30 @@
# HF Phi-2 model looks like this:
PhiForCausalLM(
(transformer): PhiModel(
(embd): Embedding(
(wte): Embedding(51200, 2560)
(drop): Dropout(p=0.0, inplace=False)
)
(h): ModuleList(
(0-31): 32 x ParallelBlock(
(ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
(resid_dropout): Dropout(p=0.1, inplace=False)
(mixer): MHA(
(rotary_emb): RotaryEmbedding()
(Wqkv): Linear(in_features=2560, out_features=7680, bias=True)
(out_proj): Linear(in_features=2560, out_features=2560, bias=True)
(inner_attn): SelfAttention(
(drop): Dropout(p=0.0, inplace=False)
)
(inner_cross_attn): CrossAttention(
(drop): Dropout(p=0.0, inplace=False)
)
(model): PhiModel(
(embed_tokens): Embedding(51200, 2560)
(embed_dropout): Dropout(p=0.0, inplace=False)
(layers): ModuleList(
(0-31): 32 x PhiDecoderLayer(
(self_attn): PhiAttention(
(q_proj): Linear(in_features=2560, out_features=2560, bias=True)
(k_proj): Linear(in_features=2560, out_features=2560, bias=True)
(v_proj): Linear(in_features=2560, out_features=2560, bias=True)
(dense): Linear(in_features=2560, out_features=2560, bias=True)
(rotary_emb): PhiRotaryEmbedding()
)
(mlp): MLP(
(mlp): PhiMLP(
(activation_fn): NewGELUActivation()
(fc1): Linear(in_features=2560, out_features=10240, bias=True)
(fc2): Linear(in_features=10240, out_features=2560, bias=True)
(act): NewGELUActivation()
)
(input_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
(resid_dropout): Dropout(p=0.1, inplace=False)
)
)
(final_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
)
(lm_head): CausalLMHead(
(ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
(linear): Linear(in_features=2560, out_features=51200, bias=True)
)
(loss): CausalLMLoss(
(loss_fct): CrossEntropyLoss()
)
(lm_head): Linear(in_features=2560, out_features=51200, bias=True)
)
'''

@@ -54,8 +43,8 @@ class PhiTransformerContainer(LayerContainer):
"""
Transformer layer container for the Phi model.
"""
qkv_w: FusedQKVParameter
qkv_b: FusedQKVParameter
qkv_w: UnfusedQKVParameter
qkv_b: UnfusedQKVParameter
attn_out_w: AttentionOutputParameter
attn_out_b: AttentionOutputParameter
mlp_1_w: MLP1Parameter
@@ -66,16 +55,20 @@ class PhiTransformerContainer(LayerContainer):
ln_beta: NormParameter

PARAM_MAPPING = {
"mixer.Wqkv.weight": "qkv_w.params",
"mixer.Wqkv.bias": "qkv_b.params",
"mixer.out_proj.weight": "attn_out_w.params",
"mixer.out_proj.bias": "attn_out_b.params",
"self_attn.q_proj.weight": "qkv_w.q_params",
"self_attn.k_proj.weight": "qkv_w.k_params",
"self_attn.v_proj.weight": "qkv_w.v_params",
"self_attn.q_proj.bias": "qkv_b.q_params",
"self_attn.k_proj.bias": "qkv_b.k_params",
"self_attn.v_proj.bias": "qkv_b.v_params",
"self_attn.dense.weight": "attn_out_w.params",
"self_attn.dense.bias": "attn_out_b.params",
"mlp.fc1.weight": "mlp_1_w.params",
"mlp.fc1.bias": "mlp_1_b.params",
"mlp.fc2.weight": "mlp_2_w.params",
"mlp.fc2.bias": "mlp_2_b.params",
"ln.weight": "ln_gamma.params",
"ln.bias": "ln_beta.params",
"input_layernorm.weight": "ln_gamma.params",
"input_layernorm.bias": "ln_beta.params",
}


@@ -90,9 +83,9 @@ class PhiNonTransformerContainer(LayerContainer):
final_norm_beta: NormParameter

PARAM_MAPPING = {
"transformer.embd.wte.weight": "word_emb.params",
"lm_head.ln.weight": "final_norm_gamma.params",
"lm_head.ln.bias": "final_norm_beta.params",
"lm_head.linear.weight": "word_unembed_w.params",
"lm_head.linear.bias": "word_unembed_b.params",
"model.embed_tokens.weight": "word_emb.params",
"model.final_layernorm.weight": "final_norm_gamma.params",
"model.final_layernorm.bias": "final_norm_beta.params",
"lm_head.weight": "word_unembed_w.params",
"lm_head.bias": "word_unembed_b.params",
}
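The containers above now consume the separate q/k/v projections exported by the HF PhiAttention module rather than a pre-fused Wqkv matrix. A standalone torch sketch of what that fusion amounts to, using the Phi-2 shapes from the module dump above (DeepSpeed's UnfusedQKVParameter handles this internally; the snippet is illustrative only):

```python
import torch

hidden = 2560  # Phi-2 hidden size from the module dump above
q_w = torch.randn(hidden, hidden)
k_w = torch.randn(hidden, hidden)
v_w = torch.randn(hidden, hidden)

# Concatenating the three projections reproduces the shape of the old fused
# Wqkv weight: (7680, 2560).
qkv_w = torch.cat([q_w, k_w, v_w], dim=0)
print(qkv_w.shape)  # torch.Size([7680, 2560])
```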
14 changes: 7 additions & 7 deletions deepspeed/inference/v2/model_implementations/phi/model.py
@@ -47,11 +47,11 @@ def max_sequence_length(self) -> int:

@property
def num_layers(self) -> int:
return self._config.n_layer
return self._config.num_hidden_layers

@property
def model_dim(self) -> int:
return self._config.n_embd
return self._config.hidden_size

@property
def vocab_size(self) -> int:
@@ -63,16 +63,15 @@ def head_size(self) -> int:

@property
def n_heads(self) -> int:
return self._config.n_head
return self._config.num_attention_heads

@property
def intermediate_dim(self) -> int:
n_inner = getattr(self._config, "n_inner", None)
return n_inner if n_inner is not None else 4 * self.model_dim
return self._config.intermediate_size

@property
def n_heads_kv(self) -> int:
return getattr(self._config, "n_head_kv", None) or self.n_heads
return self._config.num_key_value_heads

@property
def activation_dtype(self) -> DtypeEnum:
@@ -97,7 +96,8 @@ def positional_embedding_type(self) -> PositionalEmbeddingType:

@property
def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
return RotateHalfConfig(rotate_dim=self._config.rotary_dim)
rotary_dim = int(self._config.partial_rotary_factor * self.head_size)
return RotateHalfConfig(rotate_dim=rotary_dim, theta_base=self._config.rope_theta)

"""
Forward implementations
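The rotary configuration is now derived from standard HF config fields instead of a dedicated rotary_dim attribute. A hedged worked example with Phi-2 numbers (partial_rotary_factor=0.4 is an assumption about that checkpoint's config; the rest follows from the architecture dump above):

```python
hidden_size = 2560
num_attention_heads = 32
head_size = hidden_size // num_attention_heads        # 80
partial_rotary_factor = 0.4                           # assumed Phi-2 config value

rotary_dim = int(partial_rotary_factor * head_size)   # 32 dims receive rotary embedding
print(head_size, rotary_dim)
```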
2 changes: 1 addition & 1 deletion deepspeed/inference/v2/model_implementations/phi/policy.py
@@ -22,7 +22,7 @@ def build_container_map(self) -> ContainerMap:
trans_container_cls = PhiTransformerContainer
transformer_containers = [trans_container_cls(self.model) for _ in range(self.model.num_layers)]

map.set_transformer_params(['transformer.h'], transformer_containers)
map.set_transformer_params(['model.layers'], transformer_containers)

map.set_non_transformer_params(PhiNonTransformerContainer(self.model))

10 changes: 6 additions & 4 deletions deepspeed/runtime/engine.py
@@ -2745,10 +2745,12 @@ def load_checkpoint(self,
load_module_only=load_module_only,
custom_load_fn=custom_load_fn)

load_zero_checkpoint = load_optimizer_states and load_path is not None and (self.zero_optimization()
or self.bfloat16_enabled())
load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled())
if load_zero_checkpoint:
success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
if load_optimizer_states and not load_module_only:
success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
else:
success = False
if not success:
self.optimizer._restore_from_bit16_weights()

@@ -2830,7 +2832,7 @@ def _load_checkpoint(self,
optim_checkpoint = None
if load_module_only:
deepspeed_states = ['module']
if self.optimizer is not None and self.fp16_enabled():
if self.optimizer is not None:
self.optimizer.refresh_fp32_params()
else:
has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
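With the reworked logic, a ZeRO or bf16 checkpoint can restore module weights without also loading optimizer state. A hedged usage sketch; model, optimizer, ds_config and the checkpoint path are placeholders:

```python
import deepspeed

# model, optimizer and ds_config are assumed to be defined elsewhere.
engine, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=ds_config)

# Restore module weights only; optimizer/scheduler state is skipped and the fp32
# master weights are rebuilt from the loaded low-precision module parameters.
load_path, client_state = engine.load_checkpoint(
    "checkpoints/run1",          # hypothetical checkpoint directory
    load_optimizer_states=False,
    load_lr_scheduler_states=False,
    load_module_only=True,
)
```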
15 changes: 14 additions & 1 deletion deepspeed/runtime/zero/mics.py
@@ -64,6 +64,7 @@ class MiCS_Init(Init):
def __init__(self,
module=None,
data_parallel_group=None,
sequence_data_parallel_group=None,
mem_efficient_linear=True,
remote_device=None,
pin_memory=False,
@@ -145,9 +146,21 @@ def __init__(self,
if not dist.is_initialized():
dist.init_distributed()
assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm"

if data_parallel_group is None and sequence_data_parallel_group is None:
ds_process_group = dist.get_world_group()
elif sequence_data_parallel_group is not None:
ds_process_group = sequence_data_parallel_group
elif data_parallel_group is not None:
ds_process_group = data_parallel_group
else: # both given
raise ValueError(
"Both 'data_parallel_group' and 'sequence_data_parallel_group' were specified. Please provide only one of these arguments."
)

self.mics_comm_groups = create_mics_comm_groups(
_ds_config.mics_shard_size,
data_parallel_group,
ds_process_group,
hierarchical_allgather=_ds_config.mics_hierarchial_params_gather,
mpu=mpu)

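MiCS_Init now accepts a sequence_data_parallel_group in addition to data_parallel_group, mirroring zero.Init. A hedged construction sketch; the group creation, config path and shard-size requirement are placeholders, not part of this commit:

```python
import torch
import torch.distributed as dist
import deepspeed

# A plain data-parallel group over all ranks; "ds_config.json" is assumed to set
# zero_optimization.mics_shard_size (path and values are placeholders).
dp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

with deepspeed.zero.MiCS_Init(data_parallel_group=dp_group,
                              config_dict_or_path="ds_config.json"):
    model = torch.nn.Linear(8192, 8192)
```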
60 changes: 57 additions & 3 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -312,6 +312,7 @@ def __init__(self, enabled=True, mem_efficient_linear=True, ds_config=None, dtyp
torch.half, torch.bfloat16, torch.float
], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.bfloat16, torch.float]"
self.wrapped_cls = set()
self.skip_init_depth = 0

self.quantized_initialization = None
if ds_config is not None and ds_config.weight_quantization_config and ds_config.weight_quantization_config.quantized_initialization:
@@ -435,6 +436,51 @@ def wrapped_apply(module: Module, fn_to_apply: Callable) -> None:

return wrapped_apply

def hook_for_skip_init(module):
# This function handles the logic of torch.nn.utils.skip_init, which expands to
# module_cls(*args, **kwargs).to_empty(device=final_device) with kwargs['device']='meta'.
# This hook runs between module_cls(*args, **kwargs) and to_empty(device=final_device).
def partition_after_empty_init(f):

@functools.wraps(f)
def wrapper(module, *args, **kwargs):
_module = f(module, *args, **kwargs)
# Post-hook for module._apply(empty_like...): once it has run, the module has
# finished its empty initialization on the real device. Since skip_init performs
# no computation or weight adjustment, _post_init_method can be called directly.
self._post_init_method(_module)
return _module

return wrapper

def post_wrapper_to_empty(f):
# After the to_empty() call, restore the wrapped _apply hooks and module.to_empty.
@functools.wraps(f)
def wrapper(*args, **kwargs):
res = f(*args, **kwargs)
# restore _apply hook
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
_disable_class_apply(subclass)
# self restore
module.to_empty = f
return res

return wrapper

def _enable_class_apply(cls):
cls._old_apply_of_skip_init_hook = cls._apply
cls._apply = partition_after_empty_init(cls._apply)

def _disable_class_apply(cls):
cls._apply = cls._old_apply_of_skip_init_hook

# add hooks for to_empty: apply_(empty_like)
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
_enable_class_apply(subclass)

# add a restore hook when exiting skip_init
module.to_empty = post_wrapper_to_empty(module.to_empty)

def partition_after(f):

@functools.wraps(f)
@@ -456,16 +502,25 @@ def wrapper(module, *args, **kwargs):
is_child_module = True
setattr(module, "_ds_child_entered", True)

f(module, *args, **kwargs)
init_on_meta = 'device' in kwargs and kwargs['device'] == 'meta'
if init_on_meta:
self.skip_init_depth += 1

f(module, *args, **kwargs)
if init_on_meta and self.skip_init_depth == 1:
# check and handle the logic of empty_init
hook_for_skip_init(module)
if is_child_module:
# child's __init__ is done, now we can run a single post_init on the child object
delattr(module, "_ds_child_entered")

print_rank_0(f'Running post_init for {module.__class__.__name__}', force=False)
self._post_init_method(module)
if self.skip_init_depth == 0:
self._post_init_method(module)

print_rank_0(f'After initializing followed by post init for {module.__class__.__name__}', force=False)
if init_on_meta:
self.skip_init_depth -= 1

return wrapper

@@ -512,7 +567,6 @@ def _init_subclass(cls, **kwargs):
self.patched = True

def unpatch_init_and_builtins(self):

if self.patched:

def _disable_class(cls):
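The new hooks make torch.nn.utils.skip_init usable under zero.Init: the module is first built on the meta device, then materialized via to_empty(), and only then partitioned. A hedged sketch of the user-facing pattern this enables (the config values are placeholders):

```python
import torch
import deepspeed

ds_config = {"train_micro_batch_size_per_gpu": 1,
             "zero_optimization": {"stage": 3}}  # minimal placeholder config

with deepspeed.zero.Init(config_dict_or_path=ds_config):
    # skip_init builds the module with device='meta' and then materializes it via
    # to_empty(); the hooks above defer partitioning until after that step.
    layer = torch.nn.utils.skip_init(torch.nn.Linear, 4096, 4096)
```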
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
@@ -232,7 +232,7 @@ def __init__(self,
f"Currently only supported using ZeRO-Offload with DeepSpeedCPUAdam. But current setting is ZeRO-Offload:{self.cpu_offload} and optimizer type {type(self.optimizer)}." \
f"Either disable fp16_master_weights_and_gradients or enable {self.zero_stage_string} Offload with DeepSpeedCPUAdam."

if self.reduce_scatter:
if self.reduce_scatter and self.partition_gradients:
valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32)
assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
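The dtype and predivide assertions above are now scoped to gradient partitioning (ZeRO stage 2), so a stage-1 setup is no longer subject to them. A hedged config sketch for reference; field names are standard DeepSpeed config keys, values are illustrative:

```python
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "communication_data_type": "bf16",
    "zero_optimization": {
        "stage": 1,
        "reduce_scatter": True,
    },
}
```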
7 changes: 7 additions & 0 deletions docs/_tutorials/automatic-tensor-parallelism.md
@@ -121,24 +121,29 @@ The following results were collected using V100 SXM2 32GB GPUs.
The following model families have been successfully tested with automatic tensor parallelism. Other models may work but have not been tested yet.

- albert
- baichuan
- bert
- bigbird_pegasus
- bloom
- camembert
- codegen
- codellama
- deberta_v2
- electra
- ernie
- esm
- falcon
- glm
- gpt-j
- gpt-neo
- gpt-neox
- longt5
- luke
- llama
- llama2
- m2m_100
- marian
- mistral
- mpt
- mvp
- nezha
@@ -147,10 +152,12 @@ The following model families have been successfully tested with automatic tensor
- pegasus
- perceiver
- plbart
- qwen
- reformer
- roberta
- roformer
- splinter
- starcode
- t5
- xglm
- xlm_roberta
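For reference, a hedged sketch of how automatic tensor parallelism is usually enabled for one of the families listed above (model name and tp_size are illustrative; kernel injection is left off so the AutoTP path applies):

```python
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m",
                                             torch_dtype=torch.float16)

# AutoTP path: no kernel injection, shard the model across 2 GPUs.
model = deepspeed.init_inference(model,
                                 tensor_parallel={"tp_size": 2},
                                 dtype=torch.float16,
                                 replace_with_kernel_inject=False)
```

The script would normally be launched with something like `deepspeed --num_gpus 2 script.py` so the tensor-parallel ranks are spawned.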
10 changes: 10 additions & 0 deletions op_builder/hpu/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) 2023 Habana Labs, Ltd. an Intel Company
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team'''

from .cpu_adam import CPUAdamBuilder
from .fused_adam import FusedAdamBuilder
from .no_impl import NotImplementedBuilder
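The new HPU op_builder package exposes these builders through the accelerator abstraction. A hedged sketch of resolving and JIT-loading one of them (assumes an HPU system where get_accelerator() returns the HPU accelerator; the string-based lookup is an assumption about the usual pattern):

```python
from deepspeed.accelerator import get_accelerator

# Resolve one of the builders above by class name and JIT-load the op.
fused_adam_op = get_accelerator().create_op_builder("FusedAdamBuilder").load()
```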