ShashankMosaicML committed Dec 1, 2023
2 parents 6c59dce + 8339cd3 commit 805313b
Showing 21 changed files with 927 additions and 223 deletions.
8 changes: 8 additions & 0 deletions .github/CODEOWNERS
@@ -0,0 +1,8 @@
# Require admin approval to modify all files in the root of the repository
# This includes setup.py, the README, and the CODEOWNERS file itself!
/* @mosaicml/composer-team-admins

# Require admin approval to change the CI build configuration
# All CI Changes should be reviewed for security
/.ci/ @mosaicml/composer-team-admins
/.github/ @mosaicml/composer-team-admins
6 changes: 3 additions & 3 deletions .github/mcp/mcp_pytest.py
@@ -130,7 +130,7 @@
print(line, end='')

print('[GHA] Run completed. Waiting for run to finish...')
run = wait_for_run_status(run, status='completed')
run = wait_for_run_status(run, status=RunStatus.COMPLETED)

# Fail if command exited with non-zero exit code or timed out
assert run.status == RunStatus.COMPLETED
# Fail if command exited with non-zero exit code or timed out (didn't reach COMPLETED)
assert run.status == RunStatus.COMPLETED, f'Run did not complete: {run.status} ({run.reason})'
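The assertion now compares against the `RunStatus` enum rather than the string `'completed'`, and it reports the run's status and reason on failure. A self-contained sketch of the same pattern — the enum and `Run` object below are stand-ins for illustration, not the mcli API:

```python
from dataclasses import dataclass
from enum import Enum


class RunStatus(Enum):
    # Stand-in statuses; the real set comes from the mcli SDK.
    RUNNING = 'running'
    COMPLETED = 'completed'
    FAILED = 'failed'


@dataclass
class Run:
    status: RunStatus
    reason: str = ''


def check_run_completed(run: Run) -> None:
    # Comparing enum members avoids silently passing on typos like 'complete'.
    assert run.status == RunStatus.COMPLETED, (
        f'Run did not complete: {run.status} ({run.reason})')


check_run_completed(Run(status=RunStatus.COMPLETED))
```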
32 changes: 12 additions & 20 deletions llmfoundry/data/dataloader.py
Expand Up @@ -11,6 +11,12 @@
from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
from llmfoundry.data.text_data import build_text_dataloader

LOADER_NAME_TO_FUNCTION = {
'text': build_text_dataloader,
'text_denoising': build_text_denoising_dataloader,
'finetuning': build_finetuning_dataloader,
}


def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size: int) -> DataSpec:
@@ -22,23 +28,9 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size (int): The size of the batches (number of examples)
that the dataloader will produce.
"""
if cfg.name == 'text':
return build_text_dataloader(
cfg,
tokenizer,
device_batch_size,
)
elif cfg.name == 'text_denoising':
return build_text_denoising_dataloader(
cfg,
tokenizer,
device_batch_size,
)
elif cfg.name == 'finetuning':
return build_finetuning_dataloader(
cfg,
tokenizer,
device_batch_size,
)
else:
raise ValueError(f'Not sure how to build dataloader with config: {cfg}')
if cfg.name not in LOADER_NAME_TO_FUNCTION:
allowed = ', '.join(LOADER_NAME_TO_FUNCTION.keys())
raise ValueError(f'Expected dataloader name to be one of {allowed}' +
f' but found name "{cfg.name}" in config: {cfg}')

return LOADER_NAME_TO_FUNCTION[cfg.name](cfg, tokenizer, device_batch_size)
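The if/elif chain becomes a lookup table, so supporting a new loader is a one-line registry addition and unknown names produce an error listing the allowed values. A self-contained sketch of the same dispatch pattern, using a stand-in builder and plain dicts in place of the omegaconf/llmfoundry types:

```python
from typing import Any, Callable, Dict


def build_text_dataloader(cfg: Dict[str, Any], tokenizer: Any,
                          device_batch_size: int) -> str:
    # Stand-in for llmfoundry.data.text_data.build_text_dataloader.
    return f'text dataloader with device batch size {device_batch_size}'


LOADER_NAME_TO_FUNCTION: Dict[str, Callable[..., Any]] = {
    'text': build_text_dataloader,
}


def build_dataloader(cfg: Dict[str, Any], tokenizer: Any,
                     device_batch_size: int) -> Any:
    if cfg['name'] not in LOADER_NAME_TO_FUNCTION:
        allowed = ', '.join(LOADER_NAME_TO_FUNCTION.keys())
        raise ValueError(f'Expected dataloader name to be one of {allowed}, '
                         f'but found name "{cfg["name"]}"')
    return LOADER_NAME_TO_FUNCTION[cfg['name']](cfg, tokenizer, device_batch_size)


print(build_dataloader({'name': 'text'}, tokenizer=None, device_batch_size=8))
```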
51 changes: 28 additions & 23 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -453,7 +453,7 @@ def _apply_prefix_mask(self, attn_bias: torch.Tensor,

def forward(
self,
input_ids: torch.LongTensor,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
prefix_mask: Optional[torch.ByteTensor] = None,
@@ -497,11 +497,6 @@ def forward(
'prefix_mask is a required argument when MPT is configured with prefix_lm=True.'
)

# Raise a not implemented error if input_embeds is not None (this is an arg in huggingface transformers and we need to support it for PEFT)
if inputs_embeds is not None:
raise NotImplementedError(
'inputs_embeds is not implemented for MPT.')

if self.training:
if self.attn_uses_sequence_id and sequence_id is None:
raise ValueError(
@@ -515,14 +510,25 @@ def forward(
'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.'
)

S = input_ids.size(1)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
'You cannot specify both input_ids and inputs_embeds.')
elif input_ids is not None:
S = input_ids.size(1)
x = self.wte(input_ids)
input_device = input_ids.device
elif inputs_embeds is not None:
S = inputs_embeds.size(1)
x = inputs_embeds
input_device = inputs_embeds.device
else:
raise ValueError('You must specify input_ids or inputs_embeds')

assert (
S <= self.config.max_seq_len
), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'

rotary_emb_w_meta_info = None
x = self.wte(input_ids)
if self.learned_pos_emb or self.rope:
past_position = 0
if past_key_values is not None:
@@ -552,7 +558,7 @@ def forward(
past_position,
S + past_position,
dtype=torch.long,
device=input_ids.device,
device=input_device,
).unsqueeze(0)
if attention_mask is not None:
# adjust the position indices to account for padding tokens
@@ -743,7 +749,7 @@ def get_decoder(self) -> MPTModel:

def forward(
self,
input_ids: torch.LongTensor,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
prefix_mask: Optional[torch.ByteTensor] = None,
@@ -760,11 +766,6 @@ def forward(
use_cache = (use_cache
if use_cache is not None else self.config.use_cache)

# if input_embeds is not none, raise a not implemented error
if inputs_embeds is not None:
raise NotImplementedError(
'inputs_embeds has to be None (for hf/peft support).')
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.transformer(
input_ids=input_ids,
past_key_values=past_key_values,
@@ -775,6 +776,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
inputs_embeds=inputs_embeds,
)

if self.lm_head is not None:
@@ -864,10 +866,6 @@ def prepare_inputs_for_generation(
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: Any,
) -> Dict[str, Any]:
if inputs_embeds is not None:
raise NotImplementedError(
'inputs_embeds is not implemented for MPT yet')

attention_mask = kwargs['attention_mask'].bool()
if attention_mask[:, -1].sum() != attention_mask.shape[0]:
raise NotImplementedError(
@@ -878,6 +876,7 @@ def prepare_inputs_for_generation(
else:
sequence_id = None

# only last token for inputs_ids if past is defined in kwargs
if past_key_values is not None:
input_ids = input_ids[:, -1].unsqueeze(-1)

@@ -891,14 +890,20 @@ def prepare_inputs_for_generation(
else:
prefix_mask = None

return {
'input_ids': input_ids,
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
model_inputs = {'input_ids': input_ids}

model_inputs.update({
'attention_mask': attention_mask,
'prefix_mask': prefix_mask,
'sequence_id': sequence_id,
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache', True),
}
})
return model_inputs

@staticmethod
def _reorder_cache(
@@ -989,7 +994,7 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:
add_bidirectional_mask_if_missing(batch)
# Note: prefix_mask is only used if model.prefix_lm is True
return self.model(
input_ids=batch['input_ids'],
input_ids=batch.get('input_ids', None),
attention_mask=batch.get('attention_mask', None),
prefix_mask=batch.get('bidirectional_mask', None),
sequence_id=batch.get('sequence_id', None),
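Net effect of the modeling changes: the model accepts either `input_ids` or `inputs_embeds` (never both), and `prepare_inputs_for_generation` only forwards `inputs_embeds` on the first generation step, before a KV cache exists. A minimal PyTorch sketch of the selection logic in isolation — toy sizes and a standalone function, not the MPT classes themselves:

```python
from typing import Optional, Tuple

import torch


def select_inputs(
    wte: torch.nn.Embedding,
    input_ids: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, int, torch.device]:
    """Mirror of the new mutually-exclusive handling of ids vs. embeddings."""
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError('You cannot specify both input_ids and inputs_embeds.')
    elif input_ids is not None:
        return wte(input_ids), input_ids.size(1), input_ids.device
    elif inputs_embeds is not None:
        return inputs_embeds, inputs_embeds.size(1), inputs_embeds.device
    raise ValueError('You must specify input_ids or inputs_embeds')


wte = torch.nn.Embedding(128, 16)
x, seq_len, device = select_inputs(wte, inputs_embeds=torch.randn(1, 4, 16))
print(x.shape, seq_len, device)  # torch.Size([1, 4, 16]) 4 cpu
```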
81 changes: 81 additions & 0 deletions llmfoundry/utils/builders.py
@@ -28,12 +28,14 @@
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from llmfoundry.callbacks import (EvalGauntlet, FDiffMetrics, GlobalLRScaling,
HuggingFaceCheckpointer, LayerFreezing,
MonolithicCheckpointSaver,
ScheduledGarbageCollector)
from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion,
DecoupledLionW, DecoupledLionW_8bit)
from llmfoundry.optim.scheduler import InverseSquareRootWithWarmupScheduler
@@ -42,6 +44,85 @@
log = logging.getLogger(__name__)


def build_evaluators(
eval_loader_config: Optional[Union[DictConfig, ListConfig]],
icl_tasks_config: Optional[Union[str, ListConfig]],
eval_gauntlet_config: Optional[Union[str, DictConfig]],
*,
tokenizer: PreTrainedTokenizerBase,
device_eval_batch_size: int,
icl_seq_len: int,
icl_subset_num_batches: Optional[int],
) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:

evaluators = []
if eval_loader_config is not None:
evaluators = build_eval_loaders(
eval_loader_config,
tokenizer,
device_eval_batch_size,
)

logger_keys = []
eval_gauntlet_callback = None
if icl_tasks_config is not None:
icl_evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet(
icl_tasks_config,
eval_gauntlet_config,
tokenizer,
device_eval_batch_size,
icl_seq_len,
icl_subset_num_batches,
)
evaluators.extend(icl_evaluators)

return evaluators, logger_keys, eval_gauntlet_callback


def build_eval_loaders(
eval_loader_config: Union[DictConfig, ListConfig],
tokenizer: PreTrainedTokenizerBase,
device_eval_batch_size: int,
) -> List[Evaluator]:
evaluators: List[Evaluator] = []
if isinstance(eval_loader_config, ListConfig):
eval_configs: ListConfig = eval_loader_config
is_multi_eval = True
else:
eval_configs = ListConfig([eval_loader_config])
is_multi_eval = False

for eval_config in eval_configs:
eval_dataloader = build_dataloader(eval_config, tokenizer,
device_eval_batch_size)
eval_loader: Evaluator = Evaluator(
label=f'eval/{eval_config.label}' if is_multi_eval else 'eval',
dataloader=eval_dataloader,
# Load the eval data to fail fast. metrics will get added
# later in add_metrics_to_eval_loaders, after the model is loaded
metric_names=[],
)
evaluators.append(eval_loader)
return evaluators


def add_metrics_to_eval_loaders(
evaluators: List[Evaluator],
metrics: Dict[str, Metric],
) -> List[Evaluator]:
metric_names = list(metrics.keys())
eval_loaders, other_evaluators = [], []
for evaluator in evaluators:
if evaluator.metric_names == []:
evaluator.metric_names = metric_names
eval_loaders.append(evaluator)
else:
other_evaluators.append(evaluator)

# Put the base eval_loaders first
return eval_loaders + other_evaluators


def build_icl_data_and_gauntlet(
icl_tasks_config: Union[str, ListConfig],
eval_gauntlet_config: Optional[Union[str, DictConfig]],
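The new builders split evaluator construction into two phases: `build_eval_loaders` creates `Evaluator` objects with empty `metric_names` so that bad eval data fails fast before the model exists, and `add_metrics_to_eval_loaders` fills the names in once the model's metrics are known. A small sketch of the second step, with a plain dataclass standing in for composer's `Evaluator`:

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Evaluator:
    # Stand-in for composer.core.Evaluator.
    label: str
    metric_names: List[str] = field(default_factory=list)


def add_metrics_to_eval_loaders(evaluators: List[Evaluator],
                                metrics: Dict[str, object]) -> List[Evaluator]:
    metric_names = list(metrics.keys())
    eval_loaders, other_evaluators = [], []
    for evaluator in evaluators:
        if evaluator.metric_names == []:
            # Loaders built before the model get the model's metrics now.
            evaluator.metric_names = metric_names
            eval_loaders.append(evaluator)
        else:
            other_evaluators.append(evaluator)
    # Base eval loaders come first; ICL evaluators keep their own metrics.
    return eval_loaders + other_evaluators


evaluators = [Evaluator('eval'), Evaluator('icl/task', ['InContextLearningAccuracy'])]
updated = add_metrics_to_eval_loaders(evaluators, {'LanguageCrossEntropy': object()})
print([(e.label, e.metric_names) for e in updated])
```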
11 changes: 6 additions & 5 deletions llmfoundry/utils/data_prep_utils.py
@@ -96,15 +96,16 @@ def __init__(

def __iter__(self):
for object_name in self.object_names:
object_name = object_name.strip('/')
output_filename = os.path.join(self.output_folder, object_name)
# Default output_filename, used for local paths.
output_filename = object_name

# Download objects if remote path.
if self.object_store is not None:
output_filename = os.path.join(self.output_folder,
object_name.strip('/'))
self.object_store.download_object(object_name=object_name,
filename=output_filename,
overwrite=True)
else:
# Inputs are local so we do not need to download them.
output_filename = object_name

with open(output_filename) as _txt_file:
txt = _txt_file.read()
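The rewritten loop defaults `output_filename` to the object name itself (already a readable local path) and only joins it under `output_folder` when an object store is configured. A sketch of that branch in isolation; the helper function name below is illustrative:

```python
import os
from typing import Any, Optional


def resolve_output_filename(object_name: str, output_folder: str,
                            object_store: Optional[Any]) -> str:
    # Default: the input is a local path and needs no download.
    output_filename = object_name
    if object_store is not None:
        # Remote objects are downloaded under output_folder first.
        output_filename = os.path.join(output_folder, object_name.strip('/'))
    return output_filename


print(resolve_output_filename('data/shard-0.txt', '/tmp/out', object_store=None))
print(resolve_output_filename('data/shard-0.txt', '/tmp/out', object_store=object()))
```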
58 changes: 58 additions & 0 deletions llmfoundry/utils/prompt_files.py
@@ -0,0 +1,58 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import os
from typing import List, Optional

PROMPTFILE_PREFIX = 'file::'


def load_prompts(prompts: List[str],
prompt_delimiter: Optional[str] = None) -> List[str]:
"""Loads a set of prompts, both free text and from file.
Args:
prompts (List[str]): List of free text prompts and prompt files
prompt_delimiter (Optional str): Delimiter for text file
If not provided, assumes the prompt file is a single prompt (non-delimited)
Returns:
List of prompt string(s)
"""
prompt_strings = []
for prompt in prompts:
if prompt.startswith(PROMPTFILE_PREFIX):
loaded_prompts = load_prompts_from_file(prompt, prompt_delimiter)
prompt_strings.extend(loaded_prompts)
else:
prompt_strings.append(prompt)
return prompt_strings


def load_prompts_from_file(prompt_path: str,
prompt_delimiter: Optional[str] = None) -> List[str]:
"""Load a set of prompts from a text fie.
Args:
prompt_path (str): Path for text file
prompt_delimiter (Optional str): Delimiter for text file
If not provided, assumes the prompt file is a single prompt (non-delimited)
Returns:
List of prompt string(s)
"""
if not prompt_path.startswith(PROMPTFILE_PREFIX):
raise ValueError(f'prompt_path must start with {PROMPTFILE_PREFIX}')

_, prompt_file_path = prompt_path.split(PROMPTFILE_PREFIX, maxsplit=1)
prompt_file_path = os.path.expanduser(prompt_file_path)
if not os.path.isfile(prompt_file_path):
raise FileNotFoundError(
f'{prompt_file_path=} does not match any existing files.')

with open(prompt_file_path, 'r') as f:
prompt_string = f.read()

if prompt_delimiter is None:
return [prompt_string]
return [i for i in prompt_string.split(prompt_delimiter) if i]
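Illustrative usage of the new helpers, assuming the llmfoundry.utils.prompt_files module above is importable; the temporary file and delimiter are only for the example:

```python
import os
import tempfile

from llmfoundry.utils.prompt_files import load_prompts

# Write a small delimited prompt file, then load it alongside a free-text
# prompt via the 'file::' prefix.
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    f.write('first prompt\n====\nsecond prompt')
    tmp_path = f.name

prompts = load_prompts(
    ['Write a haiku about GPUs.', f'file::{tmp_path}'],
    prompt_delimiter='\n====\n',
)
print(prompts)  # ['Write a haiku about GPUs.', 'first prompt', 'second prompt']
os.remove(tmp_path)
```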