
Commit

Clean up the logs, bump datasets and transformers (#804)
dakinggg authored Dec 15, 2023
1 parent 15e79f3 commit 06b9a1f
Showing 11 changed files with 131 additions and 56 deletions.
20 changes: 20 additions & 0 deletions llmfoundry/__init__.py
@@ -4,6 +4,26 @@
import torch

try:
import warnings

# bitsandbytes is a very noisy library. A lot of it is print statements that we can't easily suppress,
# but we can at least suppress a bunch of spurious warnings.
warnings.filterwarnings('ignore',
category=UserWarning,
module='bitsandbytes')

import logging

from llmfoundry.utils.logging_utils import SpecificWarningFilter

# Filter out Hugging Face warning for not using a pinned revision of the model
hf_dynamic_modules_logger = logging.getLogger(
'transformers.dynamic_module_utils')
new_files_warning_filter = SpecificWarningFilter(
'A new version of the following files was downloaded from')

hf_dynamic_modules_logger.addFilter(new_files_warning_filter)

# Before importing any transformers models, we need to disable transformers flash attention if
# we are in an environment with flash attention version <2. Transformers hard errors on a not properly
# gated import otherwise.
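
Editorial note: the hunk above combines two different suppression mechanisms. `warnings.filterwarnings` acts on Python warnings raised through the `warnings` module, while `SpecificWarningFilter` is a `logging.Filter` that drops records emitted through a named logger. A minimal standalone sketch of how the two pieces compose (same module pattern and message as in the diff):

```python
import logging
import warnings

from llmfoundry.utils.logging_utils import SpecificWarningFilter

# Mechanism 1: the `warnings` module. Ignore UserWarnings whose originating
# module matches the pattern; warnings from other modules still surface.
warnings.filterwarnings('ignore', category=UserWarning, module='bitsandbytes')

# Mechanism 2: the `logging` module. A filter on a named logger drops any
# record whose message contains the given substring.
logging.getLogger('transformers.dynamic_module_utils').addFilter(
    SpecificWarningFilter(
        'A new version of the following files was downloaded from'))
```
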
16 changes: 15 additions & 1 deletion llmfoundry/data/finetuning/tasks.py
@@ -43,6 +43,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
from streaming import StreamingDataset
from transformers import PreTrainedTokenizerBase

from llmfoundry.utils.logging_utils import SpecificWarningFilter

log = logging.getLogger(__name__)

__all__ = ['dataset_constructor']
@@ -236,7 +238,7 @@ def wrapper(func: Callable) -> Callable:

def print_registered_tasks(self) -> None:
tasks = sorted(self._task_preprocessing_registry.keys())
print('\n'.join(tasks))
log.info('\n'.join(tasks))

def get_preprocessing_fn_from_dict(
self, mapping: Union[Dict, DictConfig]
@@ -365,6 +367,15 @@ def build_from_hf(
with dist.local_rank_zero_download_and_wait(signal_file_path):
pass

hf_tokenization_logger = logging.getLogger(
'transformers.tokenization_utils_base')
sequence_length_warning_filter = SpecificWarningFilter(
'Token indices sequence length is longer than the specified maximum sequence length'
)

# We will trim examples later in the collate_fn, so we want to silence this warning from Hugging Face
hf_tokenization_logger.addFilter(sequence_length_warning_filter)

error: Optional[Exception] = None
filtered_dataset = None
try:
@@ -433,6 +444,9 @@ def filter_long_or_empty_examples(example: Dict) -> bool:
log.error('Error during data prep')
raise error
log.debug('All ranks finished data prep')

hf_tokenization_logger.removeFilter(sequence_length_warning_filter)

assert filtered_dataset is not None
return filtered_dataset

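
Editorial note: in this file the filter is added before tokenization and removed once data prep finishes. If removal should also happen when tokenization raises, one option (not what the commit does, just a sketch using the same filter class) is a context manager:

```python
import logging
from contextlib import contextmanager

from llmfoundry.utils.logging_utils import SpecificWarningFilter


@contextmanager
def suppress_log_message(logger_name: str, message: str):
    """Temporarily drop log records containing `message` from the named logger."""
    logger = logging.getLogger(logger_name)
    log_filter = SpecificWarningFilter(message)
    logger.addFilter(log_filter)
    try:
        yield
    finally:
        # The filter is removed even if the wrapped block raises.
        logger.removeFilter(log_filter)


# Hypothetical usage around a noisy tokenization step:
with suppress_log_message(
        'transformers.tokenization_utils_base',
        'Token indices sequence length is longer than the specified maximum sequence length'):
    pass  # e.g. dataset.map(tokenize_fn, ...)
```
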
28 changes: 18 additions & 10 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -20,7 +20,7 @@
from composer.utils import dist
from omegaconf import DictConfig
from torch import nn
from transformers import (AutoConfig, AutoModelForCausalLM,
from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedModel,
PreTrainedTokenizerBase)

from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
@@ -102,20 +102,27 @@ def __init__(self, om_model_config: Union[DictConfig,
'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
+ 'Please install flash_attn==2.3.2`.')

requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'
config = AutoConfig.from_pretrained(
om_model_config.pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
attn_implementation=requested_attention_implementation,
use_cache=
False, # Necessary due to https://github.com/huggingface/transformers/issues/28056
)

# This is not how you are supposed to set this, but transformers currently only
# supports enabling flash attention 2 when using the from_pretrained API.
# We need to support it for both from_pretrained and from_config, so we have to
# set the private attribute here. This will just skip all of transformers'
# validation logic that it is ok to use flash attention 2, so we check
# whether it is installed above, and whether the chosen config supports it here.
# https://github.com/huggingface/transformers/issues/26878
config._flash_attn_2_enabled = use_flash_attention_2
# This is not ideal, however Hugging Face's _autoset_attn_implementation function
# forces you to load the model in fp16/bf16 if you want to use flash attention. Rather than loading
# the model and then casting it back to fp32, we are monkeypatching their check.
# https://github.com/huggingface/transformers/issues/28052
def _autoset_attn_implementation_monkeypatch(
cls, config, *args, **kwargs): # type: ignore
config._attn_implementation = requested_attention_implementation
return config

PreTrainedModel._autoset_attn_implementation = classmethod(
_autoset_attn_implementation_monkeypatch)

# set config overrides
for k, v in om_model_config.get('config_overrides', {}).items():
@@ -184,7 +191,8 @@ def __init__(self, om_model_config: Union[DictConfig,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
load_in_8bit=load_in_8bit,
config=config)
config=config,
)
else:
model = AutoModelForCausalLM.from_config(
config,
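
Editorial note: the monkeypatch above replaces `PreTrainedModel._autoset_attn_implementation` for the rest of the process so that the attention implementation chosen earlier is used without forcing a half-precision load. A standalone sketch of the technique, with an optional restore step added here for illustration (the commit itself leaves the patch in place); the `'flash_attention_2'` literal stands in for `requested_attention_implementation`:

```python
from transformers import PreTrainedModel

# Save the original classmethod descriptor so it can be restored later.
_original_autoset = PreTrainedModel.__dict__['_autoset_attn_implementation']


def _autoset_attn_implementation_monkeypatch(cls, config, *args, **kwargs):  # type: ignore
    # Skip transformers' dtype checks and pin the implementation chosen earlier.
    config._attn_implementation = 'flash_attention_2'
    return config


PreTrainedModel._autoset_attn_implementation = classmethod(
    _autoset_attn_implementation_monkeypatch)

# ... build the model via from_pretrained / from_config here ...

# Optional: undo the patch (not done in this commit).
PreTrainedModel._autoset_attn_implementation = _original_autoset
```
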
21 changes: 21 additions & 0 deletions llmfoundry/utils/logging_utils.py
@@ -0,0 +1,21 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import logging


class SpecificWarningFilter(logging.Filter):

    def __init__(self, message_to_suppress: str):
        """Filter out a specific warning message based on its content.

        This can be useful for filtering out specific warning messages from
        third party packages.

        Args:
            message_to_suppress (str): The warning message to suppress.
        """
        super().__init__()
        self.message_to_suppress = message_to_suppress

    def filter(self, record: logging.LogRecord) -> bool:
        return self.message_to_suppress not in record.getMessage()
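
Editorial note: a quick sanity check of how the new filter behaves, using a throwaway logger name for illustration:

```python
import logging

from llmfoundry.utils.logging_utils import SpecificWarningFilter

demo_logger = logging.getLogger('demo')  # illustrative logger name
demo_logger.addFilter(SpecificWarningFilter('sequence length is longer'))

demo_logger.warning('Token indices sequence length is longer than 2048')  # dropped
demo_logger.warning('an unrelated warning')  # still emitted
```
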
12 changes: 7 additions & 5 deletions scripts/eval/eval.py
@@ -26,6 +26,8 @@
build_tokenizer)
from llmfoundry.utils.config_utils import pop_config, process_init_device

log = logging.getLogger(__name__)


def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
num_retries: int) -> ComposerModel:
@@ -65,7 +67,7 @@ def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
if retries >= num_retries:
raise e
else:
print(
log.info(
f'Got exception {str(e)} while loading model {model_cfg.name}. {num_retries-retries} retries remaining'
)

@@ -89,7 +91,7 @@ def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
if retries >= num_retries:
raise e
else:
print(
log.info(
f'Got exception {str(e)} while loading model {model_cfg.name}. {num_retries-retries} retries remaining'
)

@@ -116,7 +118,7 @@ def evaluate_model(
icl_subset_num_batches: Optional[int],
):

print(f'Evaluating model: {model_cfg.model_name}', flush=True)
log.info(f'Evaluating model: {model_cfg.model_name}')
# Build tokenizer and model
tokenizer_cfg: Dict[str,
Any] = om.to_container(model_cfg.tokenizer,
@@ -200,7 +202,7 @@ def evaluate_model(
if torch.cuda.is_available():
torch.cuda.synchronize()
b = time.time()
print(f'Ran {model_cfg.model_name} eval in: {b-a} seconds')
log.info(f'Ran {model_cfg.model_name} eval in: {b-a} seconds')
return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df)


@@ -215,7 +217,7 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]:
must_exist=False,
default_value=None)
if eval_gauntlet_config:
print(
warnings.warn(
'Use of the key `model_gauntlet` is deprecated, please use the key `eval_gauntlet`'
)

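
Editorial note: messages moved from `print` to `log.info` are only visible once logging is configured at INFO level; a minimal sketch of how a caller might do that (the format string is just an example):

```python
import logging

# Without this (or an equivalent configuration), INFO-level records are hidden
# by the default WARNING threshold.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(name)s: %(message)s')
```
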
36 changes: 19 additions & 17 deletions scripts/train/train.py
@@ -34,6 +34,8 @@
process_init_device,
update_batch_size_info)

log = logging.getLogger(__name__)


def validate_config(cfg: DictConfig):
"""Validates compatible model and dataloader selection."""
@@ -138,17 +140,17 @@ def build_composer_peft_model(
+ f'Error encountered: {e}')

# 1) loads a hf model, 2) adds peft modules, 3) wraps it in a ComposerHFCausalLM.
print('Building Lora config...')
log.info('Building Lora config...')
lora_cfg = LoraConfig(**lora_args)

print('Building model from HuggingFace checkpoint...')
log.info('Building model from HuggingFace checkpoint...')
model = MPTForCausalLM.from_pretrained(pretrained_model_name_or_path,
trust_remote_code=True)
print('Model built!')
log.info('Model built!')

print('Adding Lora modules...')
log.info('Adding Lora modules...')
model = get_peft_model(model, lora_cfg)
print('Lora modules added!')
log.info('Lora modules added!')

model = ComposerHFCausalLM(model, tokenizer)

@@ -163,7 +165,7 @@ def print_trainable_parameters(model: torch.nn.Module) -> None:
all_param += param.numel()
if param.requires_grad:
trainable_params += param.numel()
print(
log.info(
f'trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}'
)

@@ -260,9 +262,9 @@ def main(cfg: DictConfig) -> Trainer:
must_exist=False,
default_value=None)
if eval_gauntlet_config is not None:
print(
'Use of the key `model_gauntlet` is deprecated, please use the key `eval_gauntlet`'
)
warnings.warn(
'Use of the key `model_gauntlet` is deprecated, please use the key `eval_gauntlet`',
DeprecationWarning)
icl_subset_num_batches: Optional[int] = pop_config(cfg,
'icl_subset_num_batches',
must_exist=False,
@@ -398,7 +400,7 @@ def main(cfg: DictConfig) -> Trainer:
autoresume_default = True

if cfg.get('autoresume') is None and autoresume_default:
print('As run_name, save_folder, and save_latest_filename are set, \
log.info('As run_name, save_folder, and save_latest_filename are set, \
changing autoresume default to True...')

autoresume: bool = pop_config(cfg,
@@ -514,7 +516,7 @@ def main(cfg: DictConfig) -> Trainer:
] if algorithm_configs else None

# Dataloaders
print('Building train loader...')
log.info('Building train loader...')
train_loader = build_dataloader(
train_loader_config,
tokenizer,
@@ -525,7 +527,7 @@ def main(cfg: DictConfig) -> Trainer:
mosaicml_logger.log_metrics({'data_validated': time.time()})

## Evaluation
print('Building eval loader...')
log.info('Building eval loader...')
eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len
evaluators, _, eval_gauntlet_callback = build_evaluators(
eval_loader_config,
@@ -541,7 +543,7 @@ def main(cfg: DictConfig) -> Trainer:
callbacks.append(eval_gauntlet_callback)

# Build Model
print('Initializing model...')
log.info('Initializing model...')
with init_context:
if lora_config is not None: # frozen model + trainable lora modules
model: ComposerHFCausalLM = build_composer_peft_model(
@@ -570,7 +572,7 @@ def main(cfg: DictConfig) -> Trainer:
evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics)

# Build the Trainer
print('Building trainer...')
log.info('Building trainer...')
trainer = Trainer(
run_name=run_name,
seed=seed,
@@ -609,7 +611,7 @@ def main(cfg: DictConfig) -> Trainer:
compile_config=compile_config,
)

print('Logging config')
log.info('Logging config')
log_config(logged_cfg)
torch.cuda.empty_cache()
gc.collect()
@@ -618,10 +620,10 @@ def main(cfg: DictConfig) -> Trainer:
if eval_first and trainer.state.timestamp.batch.value == 0:
trainer.eval()

print('Starting training...')
log.info('Starting training...')
trainer.fit()

print('Done.')
log.info('Done.')
return trainer


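
Editorial note: the `model_gauntlet` deprecation now goes through `warnings.warn(..., DeprecationWarning)`. Python hides `DeprecationWarning` by default outside of code running in `__main__`, so users who want to see (or fail on) it can opt in, for example:

```python
import warnings

# Surface every DeprecationWarning instead of the default policy.
warnings.simplefilter('always', DeprecationWarning)

# Or escalate to an error, e.g. in CI:
#   python -W error::DeprecationWarning scripts/train/train.py ...
```
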
6 changes: 3 additions & 3 deletions setup.py
@@ -48,11 +48,11 @@

install_requires = [
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17.1,<0.18',
'accelerate>=0.20,<0.21', # for HF inference `device_map`
'transformers>=4.34.1,<4.35',
'accelerate>=0.25,<0.26', # for HF inference `device_map`
'transformers>=4.36,<4.37',
'mosaicml-streaming>=0.7.1,<0.8',
'torch>=2.1,<2.1.1',
'datasets>=2.14.5,<2.15',
'datasets==2.15.0',
'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data
'sentencepiece==0.1.97',
'einops==0.5.0',
4 changes: 2 additions & 2 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -258,7 +258,7 @@ def test_callback_inits():
@pytest.mark.parametrize(
'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
[('3ba', '2ba', '4ba', 2, 2), ('1dur', '2ba', '1ep', 1, 2)])
@patch('os.cpu_count', MagicMock(return_value=None))
@patch('os.cpu_count', MagicMock(return_value=1))
def test_huggingface_conversion_callback_interval(
tmp_path: pathlib.Path, log_to_mlflow: bool, hf_save_interval: str,
save_interval: str, max_duration: str, expected_hf_checkpoints: int,
@@ -381,7 +381,7 @@ def test_huggingface_conversion_callback_interval(
@pytest.mark.parametrize(
'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
[('1ba', '1ba', '1ba', 1, 1)])
@patch('os.cpu_count', MagicMock(return_value=None))
@patch('os.cpu_count', MagicMock(return_value=1))
def test_huggingface_conversion_callback(
model: str,
tmp_path: pathlib.Path,
2 changes: 1 addition & 1 deletion tests/fixtures/data.py
@@ -26,7 +26,7 @@ def tiny_ft_dataset_path(tmp_path: Path, dataset_size: int = 4) -> Path:


@fixture
@patch('os.cpu_count', MagicMock(return_value=None))
@patch('os.cpu_count', MagicMock(return_value=1))
def tiny_ft_dataloader(tiny_ft_dataset_path: Path,
mpt_tokenizer: PreTrainedTokenizerBase,
max_seq_len: int = 128,