Commit

Merge branch 'main' into anna/asynceval
aspfohl authored Dec 18, 2023
2 parents cd2a31d + 06b9a1f commit f6393a9
Showing 25 changed files with 482 additions and 278 deletions.
41 changes: 41 additions & 0 deletions .github/workflows/smoketest.yaml
@@ -0,0 +1,41 @@
name: Smoketest
on:
  push:
    branches:
      - main
      - release/*
  pull_request:
    branches:
      - main
      - release/*
  workflow_dispatch:
# Cancel old runs when a new commit is pushed to the same branch if not on main or dev
concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/dev' }}
defaults:
  run:
    working-directory: .
jobs:
  smoketest:
    runs-on: ubuntu-20.04
    timeout-minutes: 10
    strategy:
      matrix:
        python_version:
          - "3.9"
          - "3.10"
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python_version }}
      - name: Setup
        run: |
          set -ex
          python -m pip install --upgrade 'pip<23' wheel
          python -m pip install --upgrade .
          python -m pip install pytest==7.2.1 pytest_codeblocks==0.16.1
      - name: Run checks
        run: |
          pytest tests/test_smoketest.py
20 changes: 20 additions & 0 deletions llmfoundry/__init__.py
@@ -4,6 +4,26 @@
import torch

try:
    import warnings

    # bitsandbytes is a very noisy library. A lot of it is print statements that we can't easily suppress,
    # but we can at least suppress a bunch of spurious warnings.
    warnings.filterwarnings('ignore',
                            category=UserWarning,
                            module='bitsandbytes')

    import logging

    from llmfoundry.utils.logging_utils import SpecificWarningFilter

    # Filter out Hugging Face warning for not using a pinned revision of the model
    hf_dynamic_modules_logger = logging.getLogger(
        'transformers.dynamic_module_utils')
    new_files_warning_filter = SpecificWarningFilter(
        'A new version of the following files was downloaded from')

    hf_dynamic_modules_logger.addFilter(new_files_warning_filter)

    # Before importing any transformers models, we need to disable transformers flash attention if
    # we are in an environment with flash attention version <2. Transformers hard errors on a not properly
    # gated import otherwise.
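
The SpecificWarningFilter imported above (and again in llmfoundry/data/finetuning/tasks.py below) is not shown in this diff. Judging from how it is constructed with a message prefix and attached via addFilter, it behaves like a standard logging.Filter that drops matching records. A minimal sketch of such a filter follows; this is an assumption, and the real implementation in llmfoundry/utils/logging_utils.py may differ in details:

import logging


class SpecificWarningFilter(logging.Filter):
    """Sketch: drops log records whose message starts with a given prefix."""

    def __init__(self, message_to_suppress: str):
        super().__init__()
        self.message_to_suppress = message_to_suppress

    def filter(self, record: logging.LogRecord) -> bool:
        # Returning False tells the logging machinery to discard the record.
        return not record.getMessage().startswith(self.message_to_suppress)
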
1 change: 0 additions & 1 deletion llmfoundry/data/denoising.py
@@ -527,7 +527,6 @@ def build_text_denoising_dataloader(
    )

    token_counting_func = get_tokens_per_batch_func(
        pad_token_id=tokenizer.pad_token_id,
        decoder_only=cfg.mixture_of_denoisers.decoder_only_format)

    return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
3 changes: 1 addition & 2 deletions llmfoundry/data/finetuning/dataloader.py
@@ -216,8 +216,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
        timeout=cfg.get('timeout', 0),
    )

    token_counting_func = get_tokens_per_batch_func(
        pad_token_id=tokenizer.pad_token_id)
    token_counting_func = get_tokens_per_batch_func()

    return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)

16 changes: 15 additions & 1 deletion llmfoundry/data/finetuning/tasks.py
@@ -43,6 +43,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
from streaming import StreamingDataset
from transformers import PreTrainedTokenizerBase

from llmfoundry.utils.logging_utils import SpecificWarningFilter

log = logging.getLogger(__name__)

__all__ = ['dataset_constructor']
@@ -236,7 +238,7 @@ def wrapper(func: Callable) -> Callable:

    def print_registered_tasks(self) -> None:
        tasks = sorted(self._task_preprocessing_registry.keys())
        print('\n'.join(tasks))
        log.info('\n'.join(tasks))

    def get_preprocessing_fn_from_dict(
            self, mapping: Union[Dict, DictConfig]
@@ -365,6 +367,15 @@ def build_from_hf(
            with dist.local_rank_zero_download_and_wait(signal_file_path):
                pass

        hf_tokenization_logger = logging.getLogger(
            'transformers.tokenization_utils_base')
        sequence_length_warning_filter = SpecificWarningFilter(
            'Token indices sequence length is longer than the specified maximum sequence length'
        )

        # We will trim examples later in the collate_fn, so we want to silence this warning from Hugging Face
        hf_tokenization_logger.addFilter(sequence_length_warning_filter)

        error: Optional[Exception] = None
        filtered_dataset = None
        try:
@@ -433,6 +444,9 @@ def filter_long_or_empty_examples(example: Dict) -> bool:
            log.error('Error during data prep')
            raise error
        log.debug('All ranks finished data prep')

        hf_tokenization_logger.removeFilter(sequence_length_warning_filter)

        assert filtered_dataset is not None
        return filtered_dataset

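
Note that the filter added before data prep is removed again afterwards, so later tokenization warnings are not silently swallowed. The same scoping pattern in isolation looks roughly like this (a sketch; tokenize_dataset is a hypothetical stand-in for the dataset .map() call that triggers the warning, and the try/finally is added here only for illustration, while the code above simply removes the filter once data prep completes):

import logging

from llmfoundry.utils.logging_utils import SpecificWarningFilter


def tokenize_dataset():
    # Hypothetical placeholder for the tokenization / dataset.map(...) step.
    return []


hf_tokenization_logger = logging.getLogger('transformers.tokenization_utils_base')
sequence_length_warning_filter = SpecificWarningFilter(
    'Token indices sequence length is longer than the specified maximum sequence length')

hf_tokenization_logger.addFilter(sequence_length_warning_filter)
try:
    dataset = tokenize_dataset()  # warnings about over-long sequences are suppressed here
finally:
    # Remove the filter so unrelated tokenization warnings still surface later.
    hf_tokenization_logger.removeFilter(sequence_length_warning_filter)
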
21 changes: 9 additions & 12 deletions llmfoundry/data/text_data.py
@@ -306,15 +306,13 @@ def build_text_dataloader(
    # and if tokenizing on the fly, we require that the tokenizer has a pad token.
    token_counting_func = None
    if tokenizer.pad_token_id is not None:
        token_counting_func = get_tokens_per_batch_func(
            pad_token_id=tokenizer.pad_token_id)
        token_counting_func = get_tokens_per_batch_func()

    return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)


def get_tokens_per_batch_func(pad_token_id: int,
                              decoder_only: bool = True
                             ) -> Callable[[Batch], int]:
def get_tokens_per_batch_func(
        decoder_only: bool = True) -> Callable[[Batch], int]:
    """Returns a callable that counts the number of tokens in a batch.
    Args:
@@ -327,25 +325,24 @@ def get_tokens_per_batch_func(pad_token_id: int,
"""

def get_num_samples_in_batch(batch: Batch) -> int:
if not isinstance(batch, Mapping) or 'input_ids' not in batch:
if not isinstance(batch, Mapping) or 'attention_mask' not in batch:
raise ValueError(
'get_tokens_per_batch_func() requires a batch with an input_ids key'
'get_tokens_per_batch_func() requires a batch with an attention_mask key'
)

if not decoder_only and 'decoder_input_ids' not in batch:
if not decoder_only and 'decoder_attention_mask' not in batch:
raise ValueError(
'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_input_ids key'
'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_attention_mask key'
)

# Count number of non padding tokens in batch
input_ids_tokens = int(
torch.sum(batch['input_ids'] != pad_token_id).item())
input_ids_tokens = int(torch.sum(batch['attention_mask']).item())

# For encoder decoder models only
decoder_input_ids_tokens = 0
if not decoder_only:
decoder_input_ids_tokens = int(
torch.sum(batch['decoder_input_ids'] != pad_token_id).item())
torch.sum(batch['decoder_attention_mask']).item())

return input_ids_tokens + decoder_input_ids_tokens

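
The corresponding changes in denoising.py and finetuning/dataloader.py above simply drop the pad_token_id argument, since the counter now sums the attention_mask instead of comparing input_ids against a pad token. A small worked example of the new counting logic (a sketch with a made-up batch, assuming the usual Hugging Face convention of mask value 1 for real tokens and 0 for padding):

import torch

# Two sequences padded to length 5: the first has 3 real tokens, the second has 2.
batch = {
    'input_ids': torch.tensor([[5, 6, 7, 0, 0],
                               [8, 9, 0, 0, 0]]),
    'attention_mask': torch.tensor([[1, 1, 1, 0, 0],
                                    [1, 1, 0, 0, 0]]),
}

# Summing the mask counts non-padding tokens directly: 3 + 2 = 5.
num_tokens = int(torch.sum(batch['attention_mask']).item())
print(num_tokens)  # 5
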
28 changes: 18 additions & 10 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -20,7 +20,7 @@
from composer.utils import dist
from omegaconf import DictConfig
from torch import nn
from transformers import (AutoConfig, AutoModelForCausalLM,
from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedModel,
                          PreTrainedTokenizerBase)

from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
@@ -102,20 +102,27 @@ def __init__(self, om_model_config: Union[DictConfig,
                'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
                + 'Please install flash_attn==2.3.2`.')

        requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'
        config = AutoConfig.from_pretrained(
            om_model_config.pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            use_auth_token=use_auth_token,
            attn_implementation=requested_attention_implementation,
            use_cache=
            False, # Necessary due to https://github.com/huggingface/transformers/issues/28056
        )

        # This is not how you are supposed to set this, but transformers currently only
        # supports enabling flash attention 2 when using the from_pretrained API.
        # We need to support it for both from_pretrained and from_config, so we have to
        # set the private attribute here. This will just skip all of transformers'
        # validation logic that it is ok to use flash attention 2, so we check
        # whether it is installed above, and whether the chosen config supports it here.
        # https://github.com/huggingface/transformers/issues/26878
        config._flash_attn_2_enabled = use_flash_attention_2
        # This is not ideal, however Hugging Face's _autoset_attn_implementation function
        # forces you to load the model in fp16/bf16 if you want to use flash attention. Rather than loading
        # the model and then casting it back to fp32, we are monkeypatching their check.
        # https://github.com/huggingface/transformers/issues/28052
        def _autoset_attn_implementation_monkeypatch(
                cls, config, *args, **kwargs): # type: ignore
            config._attn_implementation = requested_attention_implementation
            return config

        PreTrainedModel._autoset_attn_implementation = classmethod(
            _autoset_attn_implementation_monkeypatch)

        # set config overrides
        for k, v in om_model_config.get('config_overrides', {}).items():
@@ -184,7 +191,8 @@ def __init__(self, om_model_config: Union[DictConfig,
                trust_remote_code=trust_remote_code,
                use_auth_token=use_auth_token,
                load_in_8bit=load_in_8bit,
                config=config)
                config=config,
            )
        else:
            model = AutoModelForCausalLM.from_config(
                config,
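
The comment above explains the why; mechanically, the patch just rebinds a classmethod on PreTrainedModel so every subclass inherits the relaxed check. The general shape of that pattern, stripped of transformers specifics (a sketch with a made-up Base class, not library code):

class Base:

    @classmethod
    def _autoset(cls, config, *args, **kwargs):
        raise RuntimeError('library validation we want to skip')


def _autoset_monkeypatch(cls, config, *args, **kwargs):
    # Record the requested implementation without running the original validation.
    config['attn_implementation'] = 'flash_attention_2'
    return config


# Rebinding the attribute as a classmethod replaces the behavior for Base and
# for every subclass that does not override it.
Base._autoset = classmethod(_autoset_monkeypatch)

print(Base._autoset({}))  # {'attn_implementation': 'flash_attention_2'}
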
7 changes: 3 additions & 4 deletions llmfoundry/models/inference_api_wrapper/interface.py
@@ -39,8 +39,7 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):

    def get_metrics(self, is_train: bool = False):
        if is_train:
            raise NotImplementedError(
                'You cannot use inference wrappers for training')
            metrics = None
        else:
            metrics = self.eval_metrics

@@ -55,6 +54,7 @@ def rebatch(self, batch: Batch):
        return batch

    def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
        padding_tok = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id
        # If the batch mode is generate, we will generate a requested number of tokens using the underlying
        # model's generate function. Extra generation kwargs can be passed in via the batch. Strings will
        # be returned from eval_forward
@@ -80,8 +80,7 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
                            [output_logits,
                             next_logit_tensor.reshape(1, -1)])
                padding = torch.nn.functional.one_hot(
                    torch.full((seqlen - output_logits.shape[0],),
                               self.tokenizer.pad_token_id),
                    torch.full((seqlen - output_logits.shape[0],), padding_tok),
                    num_classes=self.tokenizer.vocab_size)
                output_logits = torch.cat([output_logits, padding])
                output_logits_batch.append(output_logits)
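
Two details change here: padding_tok falls back to the tokenizer's eos token when no pad token is set, and that id is what gets one-hot encoded to pad the per-example logits out to seqlen. The padding step in isolation behaves roughly like this (a sketch with made-up sizes; the dtype cast is added only to keep the standalone example self-consistent):

import torch

vocab_size = 10
seqlen = 6
padding_tok = 2  # e.g. an eos token id used as the fallback pad token

output_logits = torch.randn(4, vocab_size)  # logits accumulated so far (4 of 6 positions)

# One-hot rows for the pad token fill the remaining seqlen - 4 positions.
padding = torch.nn.functional.one_hot(
    torch.full((seqlen - output_logits.shape[0],), padding_tok),
    num_classes=vocab_size).to(output_logits.dtype)

output_logits = torch.cat([output_logits, padding])
print(output_logits.shape)  # torch.Size([6, 10])
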