Skip to content

Commit

Permalink
Merge branch 'main' into shashank/seq_id_flash_attn
Browse files Browse the repository at this point in the history
  • Loading branch information
ShashankMosaicML authored Nov 18, 2023
2 parents 88c6808 + 269ded6 commit 511a405
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 119 deletions.
1 change: 1 addition & 0 deletions scripts/inference/hf_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def main(args: Namespace) -> None:
if device is not None:
print(f'Placing model on {device=}...')
model.to(device)
model.to(model_dtype)
except Exception as e:
raise RuntimeError(
'Unable to load HF model. ' +
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
]

install_requires = [
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.4,<0.17',
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17,<0.18',
'accelerate>=0.20,<0.21', # for HF inference `device_map`
'transformers>=4.34.1,<4.35',
'mosaicml-streaming>=0.7.1,<0.8',
Expand Down Expand Up @@ -84,11 +84,11 @@
]

extra_deps['databricks'] = [
'mosaicml[databricks]',
'mosaicml[databricks]>=0.17,<0.18',
]

extra_deps['tensorboard'] = [
'mosaicml[tensorboard]>=0.16.1,<0.17',
'mosaicml[tensorboard]>=0.17,<0.18',
]

extra_deps['gpu'] = [
Expand Down
2 changes: 2 additions & 0 deletions tests/fixtures/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path
from unittest.mock import MagicMock, patch

from composer.utils import dist
from omegaconf import DictConfig
Expand All @@ -25,6 +26,7 @@ def tiny_ft_dataset_path(tmp_path: Path, dataset_size: int = 4) -> Path:


@fixture
@patch('os.cpu_count', MagicMock(return_value=None))
def tiny_ft_dataloader(tiny_ft_dataset_path: Path,
mpt_tokenizer: PreTrainedTokenizerBase,
max_seq_len: int = 128,
Expand Down
Loading

0 comments on commit 511a405

Please sign in to comment.