Merge branch 'main' into cl-fix
snarayan21 authored Jul 29, 2024
2 parents 70d5154 + 5c7e99b commit 523fb61
Showing 9 changed files with 48 additions and 1 deletion.
19 changes: 19 additions & 0 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -376,6 +376,17 @@ def transform_config(
copied_config.ffn_config['moe_world_size'] = 1
return copied_config

def pre_register_edit(self, local_save_path: str):
    """Edit the model before registering with MLflow.

    This allows a subclass to modify the model before registering with
    MLflow. The base class implementation will make no modifications.

    Args:
        local_save_path (str): The path to the model to be transformed.
    """
    pass

def transform_model_pre_registration(
self,
model: PreTrainedModel,
@@ -602,6 +613,12 @@ def tensor_hook(
) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
)
with context_manager:
# Add the pip requirements directly to avoid mlflow
# attempting to run inference on the model
model_saving_kwargs['pip_requirements'] = [
'transformers',
'torch',
]
mlflow_logger.save_model(**model_saving_kwargs)

# Upload the license file generated by mlflow during the model saving.
@@ -618,6 +635,8 @@ def tensor_hook(
os.path.join(local_save_path, license_filename),
)

self.pre_register_edit(local_save_path,)

# Spawn a new process to register the model.
process = SpawnProcess(
target=_register_model_with_run_id_multiprocess,
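The new pre_register_edit hook gives subclasses a place to modify the saved checkpoint files after mlflow_logger.save_model runs but before the model is registered. A minimal sketch of an override, assuming a hypothetical subclass and file name (neither is part of this change):

import os

from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer


class PrunedCheckpointer(HuggingFaceCheckpointer):

    def pre_register_edit(self, local_save_path: str):
        # Drop a hypothetical extra file from the local save directory so it
        # is not uploaded when the model is registered with MLflow.
        extra_file = os.path.join(local_save_path, 'training_args.bin')
        if os.path.exists(extra_file):
            os.remove(extra_file)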
1 change: 1 addition & 0 deletions llmfoundry/command_utils/train.py
@@ -544,6 +544,7 @@ def train(cfg: DictConfig) -> Trainer:
dist_timeout=train_cfg.dist_timeout,
profiler=profiler,
compile_config=compile_config,
spin_dataloaders=train_cfg.spin_dataloaders,
)

# Optionally just save an HF checkpoint
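spin_dataloaders is now threaded from the training config into the Composer Trainer. A hedged usage sketch; the YAML path is illustrative, and the flag controls whether Composer fast-forwards ("spins") dataloaders when resuming:

from omegaconf import OmegaConf

from llmfoundry.command_utils.train import train

cfg = OmegaConf.load('yamls/pretrain/my_config.yaml')  # hypothetical config
cfg.spin_dataloaders = False  # forwarded to Trainer(spin_dataloaders=...)
trainer = train(cfg)  # builds the Trainer and runs training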
7 changes: 7 additions & 0 deletions llmfoundry/models/layers/attention.py
@@ -415,6 +415,7 @@ def __init__(
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
norm_eps: float = 1e-05,
fc_type: Optional[dict[str, Any]] = None,
device: Optional[str] = None,
bias: bool = True,
@@ -520,6 +521,7 @@ def __init__(
self.q_ln = build_norm(
name=norm_type.lower(),
normalized_shape=norm_size,
eps=norm_eps,
device=device,
)
if self.reuse_kv_layer_idx is None:
@@ -528,6 +530,7 @@ def __init__(
self.k_ln = build_norm(
name=norm_type.lower(),
normalized_shape=norm_size,
eps=norm_eps,
device=device,
)

@@ -796,6 +799,7 @@ def __init__(
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
norm_eps: float = 1e-05,
fc_type: Optional[dict[str, Any]] = None,
device: Optional[str] = None,
bias: bool = True,
@@ -814,6 +818,7 @@ def __init__(
softmax_scale=softmax_scale,
attn_pdrop=attn_pdrop,
norm_type=norm_type,
norm_eps=norm_eps,
fc_type=fc_type,
device=device,
bias=bias,
@@ -841,6 +846,7 @@ def __init__(
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
norm_eps: float = 1e-05,
fc_type: Optional[dict[str, Any]] = None,
device: Optional[str] = None,
bias: bool = True,
@@ -859,6 +865,7 @@ def __init__(
softmax_scale=softmax_scale,
attn_pdrop=attn_pdrop,
norm_type=norm_type,
norm_eps=norm_eps,
fc_type=fc_type,
device=device,
bias=bias,
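Each attention class now accepts norm_eps and forwards it to the q/k norms it builds. A rough sketch through the layer builder; the registry name and the extra kwargs are assumptions based on the hunks above, not verified against this commit:

from llmfoundry.models.layers.layer_builders import build_attention_layer

attn = build_attention_layer(
    name='multihead_attention',  # assumed registry name
    attn_kwargs={
        'd_model': 256,
        'n_heads': 8,
        'attn_impl': 'torch',  # avoids needing flash-attn for the sketch
        'qk_ln': True,         # build the q/k norms
        'norm_eps': 1e-6,      # newly plumbed through to build_norm
        'device': 'cpu',
    },
)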
7 changes: 7 additions & 0 deletions llmfoundry/models/layers/blocks.py
@@ -42,6 +42,7 @@ def __init__(
ffn_config: Optional[Dict] = None,
resid_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
norm_eps: float = 1e-05,
fc_type: Optional[dict[str, Any]] = None,
device: Optional[str] = None,
no_bias: bool = False,
@@ -84,6 +85,7 @@ def __init__(
fc_type=fc_type,
resid_pdrop=resid_pdrop,
norm_type=norm_type,
norm_eps=norm_eps,
device=device,
no_bias=no_bias,
)
@@ -99,6 +101,7 @@ def __init__(
self.norm_1 = build_norm(
name=norm_type.lower(),
normalized_shape=d_model,
eps=norm_eps,
device=device,
)
self.attn = build_attention_layer(
@@ -117,6 +120,7 @@ def __init__(
self.norm_2 = build_norm(
name=norm_type.lower(),
normalized_shape=d_model,
eps=norm_eps,
device=device,
)

@@ -260,6 +264,7 @@ def __init__(
fc_type: Optional[dict[str, Any]] = None,
resid_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
norm_eps: float = 1e-05,
device: Optional[str] = None,
no_bias: bool = False,
**kwargs: Any,
@@ -283,6 +288,7 @@ def __init__(
self.norm_1 = build_norm(
name=norm_type.lower(),
normalized_shape=d_model,
eps=norm_eps,
device=device,
)
self.attn = build_attention_layer(
@@ -302,6 +308,7 @@ def __init__(
self.norm_2 = build_norm(
name=norm_type.lower(),
normalized_shape=d_model,
eps=norm_eps,
device=device,
)
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
2 changes: 2 additions & 0 deletions llmfoundry/models/layers/layer_builders.py
@@ -26,10 +26,12 @@
def build_norm(
name: str,
normalized_shape: Union[int, List[int], torch.Size],
eps: Optional[float] = 1e-5,
device: Optional[str] = None,
):
kwargs = {
'normalized_shape': normalized_shape,
'eps': eps,
'device': device,
}

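With the extra parameter, callers can set the norm epsilon directly. A small usage sketch, with the norm name taken from the defaults used elsewhere in this diff:

import torch

from llmfoundry.models.layers.layer_builders import build_norm

norm = build_norm(
    name='low_precision_layernorm',
    normalized_shape=512,
    eps=1e-6,      # forwarded to the norm constructor via kwargs
    device='cpu',
)
out = norm(torch.randn(2, 8, 512))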
3 changes: 3 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -44,6 +44,7 @@ def __init__(
no_bias: bool = False,
embedding_fraction: float = 1.0,
norm_type: str = 'low_precision_layernorm',
norm_eps: float = 1e-05,
use_cache: bool = False,
init_config: Optional[Dict] = None,
fc_type: Union[str, Dict] = 'torch',
@@ -101,6 +102,7 @@ def __init__(
no_bias (bool): Whether to use bias in all layers.
embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
norm_type (str): choose type of norm to use
norm_eps (float): epsilon value for norm layer
use_cache (bool): Whether or not the model should return the last key/values attentions
init_config (Dict): A dictionary used to configure the model initialization:
init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
@@ -168,6 +170,7 @@ def __init__(
self.no_bias = no_bias
self.embedding_fraction = embedding_fraction
self.norm_type = norm_type
self.norm_eps = norm_eps
self.use_cache = use_cache
self.init_config = init_config if init_config is not None else copy.deepcopy(
init_config_defaults,
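The config change is additive: existing configs keep the previous 1e-05 epsilon, and norm_eps can now be set explicitly. An illustrative sketch (torch attention is chosen here only so the example does not require flash-attn; partial attn_config dicts are assumed to be merged with the defaults):

from llmfoundry.models.mpt import MPTConfig

config = MPTConfig(
    attn_config={'attn_impl': 'torch'},
    norm_type='low_precision_layernorm',
    norm_eps=1e-6,
)
print(config.norm_eps)  # 1e-06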
1 change: 1 addition & 0 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -426,6 +426,7 @@ def __init__(self, config: MPTConfig):
self.norm_f = build_norm(
name=config.norm_type.lower(),
normalized_shape=config.d_model,
eps=config.norm_eps,
device=config.init_device,
)

2 changes: 1 addition & 1 deletion llmfoundry/utils/config_utils.py
@@ -167,6 +167,7 @@ class TrainConfig:
# Dataloader
device_train_microbatch_size: Union[str, int, float] = 'auto'
global_train_batch_size: Optional[int] = None
spin_dataloaders: bool = True

# Eval dataloader
eval_subset_num_batches: int = -1
@@ -531,7 +532,6 @@ def process_init_device(model_cfg: Dict[str, Any], fsdp_config: Optional[Dict]):
fsdp_config['sync_module_states'] = True

# Set defaults for mixed initialization
fsdp_config.setdefault('use_orig_params', False)
fsdp_config.setdefault('load_monolith_rank0_only', True)

# Set ffn_config.device_mesh to fsdp_config.device_mesh
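With the use_orig_params default removed from the mixed-initialization path, that flag now follows the caller's FSDP config (or the upstream default). A hedged sketch of an explicit config dict; keys other than the two defaults shown above are assumptions:

fsdp_config = {
    'sharding_strategy': 'FULL_SHARD',  # assumed, unrelated to this change
    'use_orig_params': True,            # now set explicitly by the caller
    # 'sync_module_states' and 'load_monolith_rank0_only' are still
    # defaulted to True for mixed initialization, as shown above.
}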
7 changes: 7 additions & 0 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -388,6 +388,9 @@ def test_huggingface_conversion_callback_interval(
checkpointer_callback.transform_model_pre_registration = MagicMock(
wraps=checkpointer_callback.transform_model_pre_registration,
)
checkpointer_callback.pre_register_edit = MagicMock(
wraps=checkpointer_callback.pre_register_edit,
)
trainer = Trainer(
model=original_model,
device='gpu',
@@ -411,11 +414,14 @@
task='llm/v1/completions',
input_example=ANY,
metadata={},
pip_requirements=ANY,
)
assert checkpointer_callback.transform_model_pre_registration.call_count == 1
assert checkpointer_callback.pre_register_edit.call_count == 1
assert mlflow_logger_mock.register_model_with_run_id.call_count == 1
else:
assert checkpointer_callback.transform_model_pre_registration.call_count == 0
assert checkpointer_callback.pre_register_edit.call_count == 0
assert mlflow_logger_mock.save_model.call_count == 0
assert mlflow_logger_mock.register_model_with_run_id.call_count == 0

@@ -589,6 +595,7 @@ def _assert_mlflow_logger_calls(
'task': 'llm/v1/completions',
'input_example': default_input_example,
'metadata': {},
'pip_requirements': ANY,
}
mlflow_logger_mock.save_model.assert_called_with(**expectation)
assert mlflow_logger_mock.register_model_with_run_id.call_count == 1
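The new assertions rely on MagicMock(wraps=...), which spies on a method while preserving its behavior. A standalone sketch of the pattern (the class and path below are illustrative):

from unittest.mock import MagicMock


class DummyCheckpointer:

    def pre_register_edit(self, local_save_path: str):
        pass


cb = DummyCheckpointer()
cb.pre_register_edit = MagicMock(wraps=cb.pre_register_edit)
cb.pre_register_edit('/tmp/model')           # original method still runs
assert cb.pre_register_edit.call_count == 1  # and the call is recorded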
