Speed up embedding tests (#1668)
dakinggg authored Nov 25, 2024
1 parent cd9535a commit bd113da
Showing 3 changed files with 29 additions and 28 deletions.
5 changes: 4 additions & 1 deletion llmfoundry/models/llm_embed/modeling_llm_embed.py
@@ -93,6 +93,7 @@ class ContrastiveModel(HuggingFaceModel):
config_overrides (Optional[Dict[str, Any]], optional): Overrides for the model configuration. Defaults to None.
load_in_8bit (bool, optional): Whether to load the model in 8-bit mode. Defaults to False.
loss_fn (str, optional): The loss function to use (either 'torch_crossentropy' or 'fused_crossentropy'). Defaults to 'fused_crossentropy'.
+ pretrained (bool, optional): Whether to use a pretrained model when using a Hugging Face architecture. Defaults to True.
**kwargs (Dict[str, Any]): Additional keyword arguments.
"""

@@ -109,9 +110,11 @@ def __init__(
config_overrides: Optional[dict[str, Any]] = None,
load_in_8bit: bool = False,
loss_fn: str = 'fused_crossentropy',
+ pretrained: bool = True,
**kwargs: dict[str, Any],
):
self.pretrained_model_name_or_path = pretrained_model_name_or_path
+ self.pretrained = pretrained
self.pretrained_lora_id_or_path = pretrained_lora_id_or_path
self.trust_remote_code = trust_remote_code
self.init_device = init_device
@@ -191,7 +194,7 @@ def construct_model(self):
model_class = registry.models.get('hf_causal_lm')
model_class = cast(type[ComposerHFCausalLM], model_class)
model = model_class.build_inner_model(
- pretrained=True,
+ pretrained=self.pretrained,
pretrained_model_name_or_path=self.
pretrained_model_name_or_path,
pretrained_lora_id_or_path=self.pretrained_lora_id_or_path,
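
The new `pretrained` flag is what lets the tests below skip downloading checkpoint weights and build a tiny, randomly initialized Hugging Face model instead. A minimal sketch of that usage pattern; the module path is inferred from the file path in the diff, and the tokenizer choice and override values here are illustrative assumptions, not part of the commit:

# Sketch only: mirrors the pretrained=False + config_overrides pattern used in
# the updated tests; import path, tokenizer, and override values are assumptions.
from transformers import AutoTokenizer

from llmfoundry.models.llm_embed.modeling_llm_embed import ContrastiveModel

tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')

model = ContrastiveModel(
    pretrained_model_name_or_path='facebook/opt-350m',
    pretrained=False,  # randomly initialized weights, no checkpoint download
    config_overrides={  # shrink the architecture so forward passes stay cheap
        'hidden_size': 2,
        'num_attention_heads': 2,
        'num_hidden_layers': 2,
    },
    tokenizer=tokenizer,
)
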
50 changes: 24 additions & 26 deletions tests/models/llm_embed/test_llm_embedding.py
@@ -1,7 +1,6 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

- from contextlib import nullcontext
from typing import Any, Optional
from unittest.mock import patch

@@ -96,23 +95,31 @@ def model(
def build_lm_config(is_hf: bool, attn_impl: Optional[str]) -> dict[str, Any]:
if is_hf:
assert attn_impl is None
- return {'pretrained_model_name_or_path': 'facebook/opt-350m'}
+ return {
+ 'pretrained_model_name_or_path': 'facebook/opt-350m',
+ 'pretrained': False,
+ 'config_overrides': {
+ 'hidden_size': 2,
+ 'num_attention_heads': 2,
+ 'num_hidden_layers': 2,
+ },
+ }
else:
assert attn_impl is not None
return {
'num_layers': 2,
- 'word_embed_proj_dim': 768,
- 'd_model': 768,
- 'n_heads': 12,
- 'vocab_size': 100352,
+ 'word_embed_proj_dim': 128,
+ 'd_model': 128,
+ 'n_heads': 2,
+ 'vocab_size': 4096,
'attn_config': {
'attn_impl': attn_impl,
},
}


def build_tokenizer_config(is_hf: bool) -> dict[str, Any]:
- return {'vocab_size': 50257 if is_hf else 100352}
+ return {'vocab_size': 50257 if is_hf else 4096}


@pytest.mark.gpu
@@ -126,24 +133,20 @@ def test_mpt_embedding_lm(
maybe_attn_impl = None if is_hf else attn_impl
lm_config = build_lm_config(is_hf, maybe_attn_impl)

- model = ContrastiveModel(**lm_config, tokenizer=mock_tokenizer).to(
- torch.bfloat16,
- ).to('cuda')
+ model = ContrastiveModel(**lm_config, tokenizer=mock_tokenizer).to('cuda')
+ msl = 32
model_inputs_batch = mock_tokenizer([['pair 1 a', 'pair 1 b'],
['pair 2 a', 'pair 2 b']],
padding='max_length',
truncation=True,
- max_length=128,
+ max_length=msl,
return_tensors='pt')
if isinstance(model_inputs_batch, dict):
model_inputs_batch = {
k: v.to('cuda') for k, v in model_inputs_batch.items()
}

- ctx = get_precision_context(
- 'amp_bf16',
- ) if maybe_attn_impl == 'flash' else nullcontext()
- with ctx:
+ with get_precision_context('amp_bf16'):
outputs = model(model_inputs_batch)

assert isinstance(outputs, dict)
@@ -156,7 +159,7 @@ def test_mpt_embedding_lm(
proj_dim = model.model.config.word_embed_proj_dim
assert last_hidden_state.shape == (
4,
- 128,
+ msl,
proj_dim,
) # 2 pairs * 2 texts per pair, 128 sequence length, word_embed_proj_dim dim
assert last_hidden_state.dtype == torch.bfloat16
@@ -194,26 +197,21 @@ def test_contrastive_loss(

with temporary_contrastive_streaming_dataset(ds_format) as data_dir:
lm_config = build_lm_config(is_hf, maybe_attn_impl)
- model = ContrastiveModel(**lm_config, tokenizer=mock_tokenizer).to(
- torch.bfloat16,
- ).to('cuda')
+ model = ContrastiveModel(**lm_config,
+ tokenizer=mock_tokenizer).to('cuda')

train_dataloader = build_dataloader(
dataloader_config(data_dir, 'local'),
mock_tokenizer,
2,
)

- precision = 'amp_bf16' if maybe_attn_impl == 'flash' else 'fp32'
- ctx = get_precision_context(
- 'amp_bf16',
- ) if attn_impl == 'flash' else nullcontext()
- with ctx:
+ with get_precision_context('amp_bf16',):
trainer = Trainer(
model=model,
train_dataloader=train_dataloader,
- precision=precision,
- max_duration='3ba',
+ precision='amp_bf16',
+ max_duration='1ba',
)
trainer.fit()

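
Most of the speedup comes from shrinking the test models (d_model 768 to 128, 12 heads to 2, vocab 100352 to 4096), the sequence length (128 to msl = 32), and the training run (3 batches to 1). A rough, illustrative estimate of the size difference; the helper below is an approximation written for this page, not code from the repository:

# Rough sketch: approximate parameter counts for the old and new MPT test
# configs, using embeddings + ~12 * d_model^2 per transformer block.
def approx_params(d_model: int, n_layers: int, vocab_size: int) -> int:
    embeddings = vocab_size * d_model
    per_block = 12 * d_model ** 2  # ~4*d_model^2 attention + ~8*d_model^2 MLP
    return embeddings + n_layers * per_block

old = approx_params(d_model=768, n_layers=2, vocab_size=100352)
new = approx_params(d_model=128, n_layers=2, vocab_size=4096)
print(f'old ~{old:,}, new ~{new:,}, roughly {old / new:.0f}x fewer parameters')
# old ~91,226,112, new ~917,504, roughly 99x fewer parameters
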
2 changes: 1 addition & 1 deletion tests/test_utils.py
@@ -114,7 +114,7 @@ def __init__(self) -> None:
self.eos_token: str = '</s>'
self.bos_token: str = '<s>'
self.unk_token: str = '<unk>'
- self._vocab_size: int = 30000
+ self._vocab_size: int = 128

def __len__(self) -> int:
return self._vocab_size
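
Shrinking the mock tokenizer's vocabulary from 30000 to 128 keeps it in line with the much smaller test models above, so anything a test sizes from len(tokenizer) stays small as well. A minimal illustrative stand-in for the pattern (hypothetical class name, not the mock tokenizer class defined in tests/test_utils.py):

# Illustrative sketch of a tiny mock tokenizer for fast tests; only the
# attributes visible in the diff above are reproduced.
class TinyMockTokenizer:

    def __init__(self) -> None:
        self.eos_token: str = '</s>'
        self.bos_token: str = '<s>'
        self.unk_token: str = '<unk>'
        self._vocab_size: int = 128  # a small vocab keeps derived tables small

    def __len__(self) -> int:
        return self._vocab_size
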