Skip to content

Commit

Permalink
Bump to composer 0.17 (#736)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Nov 16, 2023
1 parent e796218 commit e730995
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 6 deletions.
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
]

install_requires = [
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.4,<0.17',
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17,<0.18',
'accelerate>=0.20,<0.21', # for HF inference `device_map`
'transformers>=4.34.1,<4.35',
'mosaicml-streaming>=0.7.1,<0.8',
Expand Down Expand Up @@ -84,11 +84,11 @@
]

extra_deps['databricks'] = [
'mosaicml[databricks]',
'mosaicml[databricks]>=0.17,<0.18',
]

extra_deps['tensorboard'] = [
'mosaicml[tensorboard]>=0.16.1,<0.17',
'mosaicml[tensorboard]>=0.17,<0.18',
]

extra_deps['gpu'] = [
Expand Down
55 changes: 52 additions & 3 deletions tests/test_mpt_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,15 @@ def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool,
@pytest.mark.gpu
@pytest.mark.parametrize('attn_impl', ['triton', 'torch'])
@pytest.mark.parametrize('use_alibi', [True, False])
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_mpt_generate_callback(attn_impl: str, use_alibi: bool,
tie_word_embeddings: bool,
build_tiny_mpt: Callable[...,
ComposerMPTCausalLM],
tiny_ft_dataloader: DataLoader):
device = get_device('gpu')

# build mpt model
model = build_tiny_mpt(
tie_word_embeddings=tie_word_embeddings,
tie_word_embeddings=True,
attn_config={
'attn_impl': attn_impl,
'attn_uses_sequence_id': False,
Expand Down Expand Up @@ -143,3 +141,54 @@ def test_mpt_generate_callback(attn_impl: str, use_alibi: bool,

generate.generate.assert_called_once()
trainer.logger.log_table.assert_called_once()


@pytest.mark.gpu
@pytest.mark.parametrize('attn_impl', ['triton', 'torch'])
@pytest.mark.parametrize('use_alibi', [True, False])
def test_mpt_generate_callback_not_tied(
        use_alibi: bool, attn_impl: str,
        build_tiny_mpt: Callable[..., ComposerMPTCausalLM],
        tiny_ft_dataloader: DataLoader):
    """Smoke-test ComposerGenerate on an MPT model with untied embeddings.

    Builds a tiny MPT with ``tie_word_embeddings=False``, trains it for a
    single batch with a generation callback attached, and verifies the
    callback both generated text and logged a table exactly once.
    """
    gpu_device = get_device('gpu')

    # Build the model with untied input/output embeddings and move it to GPU.
    composer_model = build_tiny_mpt(
        tie_word_embeddings=False,
        attn_config={
            'attn_impl': attn_impl,
            'attn_uses_sequence_id': False,
            'alibi': use_alibi,
        },
    )
    composer_model = gpu_device.module_to_device(composer_model)

    # Configure the generation callback; wrap its generate method in a Mock
    # so we can assert on the number of invocations after training.
    test_prompts = [
        'The best banana bread recipe is',
        '2+2=',
        'how much wood could a woodchuck chuck',
    ]
    interval_batches = 1
    gen_callback = ComposerGenerate(
        test_prompts,
        interval=f'{interval_batches}ba',
        max_new_tokens=5,
        batch_size=len(test_prompts),
        use_cache=True,
    )
    gen_callback.generate = Mock(wraps=gen_callback.generate, autospec=True)

    # Train for exactly one batch so the callback fires exactly once.
    trainer = Trainer(
        model=composer_model,
        train_dataloader=tiny_ft_dataloader,
        device=gpu_device,
        max_duration=f'{interval_batches}ba',
        callbacks=[gen_callback],
    )
    trainer.logger.log_table = Mock()
    trainer.fit()

    gen_callback.generate.assert_called_once()
    trainer.logger.log_table.assert_called_once()

0 comments on commit e730995

Please sign in to comment.