Skip to content

Commit

Permalink
Merge branch 'main' into fix_siqa
Browse files Browse the repository at this point in the history
  • Loading branch information
bmosaicml authored Dec 4, 2023
2 parents f9ea102 + 84b5d96 commit 2df7441
Show file tree
Hide file tree
Showing 73 changed files with 2,279 additions and 768 deletions.
8 changes: 8 additions & 0 deletions .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Require admin approval to modify all files in the root of the repository
# This includes setup.py, the README, and the CODEOWNERS file itself!
/* @mosaicml/composer-team-admins

# Require admin approval to change the CI build configuration
# All CI Changes should be reviewed for security
/.ci/ @mosaicml/composer-team-admins
/.github/ @mosaicml/composer-team-admins
6 changes: 3 additions & 3 deletions .github/mcp/mcp_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@
print(line, end='')

print('[GHA] Run completed. Waiting for run to finish...')
run = wait_for_run_status(run, status='completed')
run = wait_for_run_status(run, status=RunStatus.COMPLETED)

# Fail if command exited with non-zero exit code or timed out
assert run.status == RunStatus.COMPLETED
# Fail if command exited with non-zero exit code or timed out (didn't reach COMPLETED)
assert run.status == RunStatus.COMPLETED, f'Run did not complete: {run.status} ({run.reason})'
14 changes: 6 additions & 8 deletions .github/workflows/docker.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,17 @@ jobs:
GIT_SHA=$(echo ${{ github.sha }} | cut -c1-7)
echo "IMAGE_TAG=${GIT_SHA}" >> ${GITHUB_ENV}
if [ "${{ github.event_name }}" == "push" ]; then
echo "Triggered by push event."
PROD_REPO="mosaicml/llm-foundry"
IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest"
IMAGE_CACHE="${PROD_REPO}:${{matrix.name}}-buildcache"
elif [ "${{ github.event_name }}" == "pull_request" ]; then
if [ "${{ github.event_name }}" == "pull_request" ]; then
echo "Triggered by pull_request event."
STAGING_REPO="mosaicml/ci-staging"
IMAGE_TAG="${STAGING_REPO}:${{matrix.name}}-${GIT_SHA}"
IMAGE_CACHE="${STAGING_REPO}:${{matrix.name}}-buildcache"
else
echo "Triggered by unknown event: ${{ github.event_name }}"
exit 1
# Triggered by push or workflow_dispatch event
echo "Triggered by ${{ github.event_name }} event, releasing to prod"
PROD_REPO="mosaicml/llm-foundry"
IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest"
IMAGE_CACHE="${PROD_REPO}:${{matrix.name}}-buildcache"
fi
echo "IMAGE_TAG=${IMAGE_TAG}" >> ${GITHUB_ENV}
Expand Down
32 changes: 12 additions & 20 deletions llmfoundry/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
from llmfoundry.data.text_data import build_text_dataloader

LOADER_NAME_TO_FUNCTION = {
'text': build_text_dataloader,
'text_denoising': build_text_denoising_dataloader,
'finetuning': build_finetuning_dataloader,
}


def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size: int) -> DataSpec:
Expand All @@ -22,23 +28,9 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size (int): The size of the batches (number of examples)
that the dataloader will produce.
"""
if cfg.name == 'text':
return build_text_dataloader(
cfg,
tokenizer,
device_batch_size,
)
elif cfg.name == 'text_denoising':
return build_text_denoising_dataloader(
cfg,
tokenizer,
device_batch_size,
)
elif cfg.name == 'finetuning':
return build_finetuning_dataloader(
cfg,
tokenizer,
device_batch_size,
)
else:
raise ValueError(f'Not sure how to build dataloader with config: {cfg}')
if cfg.name not in LOADER_NAME_TO_FUNCTION:
allowed = ', '.join(LOADER_NAME_TO_FUNCTION.keys())
raise ValueError(f'Expected dataloader name to be one of {allowed}' +
f' but found name "{cfg.name}" in config: {cfg}')

return LOADER_NAME_TO_FUNCTION[cfg.name](cfg, tokenizer, device_batch_size)
39 changes: 30 additions & 9 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,25 +47,46 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:

__all__ = ['dataset_constructor']

_ALLOWED_RESPONSE_KEYS = {'response', 'completion'}
_ALLOWED_PROMPT_KEYS = {'prompt'}


def _tokenize_formatted_example(
example: Dict[str, Any],
tokenizer: PreTrainedTokenizerBase) -> Dict[str, List[int]]:
if ('prompt' not in example) or ('response' not in example):
"""Tokenize a formatted example and validate expected keys."""
example_keys = set(example.keys())
prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS)
response_keys = example_keys.intersection(_ALLOWED_RESPONSE_KEYS)

if len(prompt_keys) != 1:
raise KeyError(
f'Unable to tokenize example because {len(prompt_keys)} of the allowed prompt keys ' +\
f'were present in {example_keys=}. Please specify exactly one. {_ALLOWED_PROMPT_KEYS=}'
)

if len(response_keys) != 1:
raise KeyError(
'Unable to tokenize example because it has not been properly formatted. ' +\
'"prompt" and "response" are required keys but at least one was missing ' +\
f'from {example=}.'
f'Unable to tokenize example because {len(response_keys)} of the allowed response keys ' +\
f'were present in {example_keys=}. Please specify exactly one. {_ALLOWED_RESPONSE_KEYS=}'
)
if not isinstance(example['prompt'], str):

prompt_key = prompt_keys.pop()
response_key = response_keys.pop()
prompt = example[prompt_key]
response = example[response_key]

if not isinstance(prompt, str):
raise TypeError(
f'Unable to tokenize example because "prompt" was not a string. {example=}'
f'Unable to tokenize example because {prompt_key} was not a string. {example=}'
)
if not isinstance(example['response'], str):

if not isinstance(response, str):
raise TypeError(
f'Unable to tokenize example because "response" was not a string. {example=}'
f'Unable to tokenize example because {response_key} was not a string. {example=}'
)
return tokenizer(text=example['prompt'], text_target=example['response'])

return tokenizer(text=prompt, text_target=response)


class StreamingFinetuningDataset(StreamingDataset):
Expand Down
104 changes: 65 additions & 39 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def scaled_multihead_dot_product_attention(
multiquery: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:

if multiquery:
warnings.warn(
DeprecationWarning(
Expand Down Expand Up @@ -219,6 +218,9 @@ def flash_attn_fn(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
attention_mask_in_length: Optional[torch.Tensor] = None,
should_repeat_kv_for_gqa: Optional[bool] = True,
sliding_window_size: int = -1,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
try:
Expand Down Expand Up @@ -249,58 +251,65 @@ def flash_attn_fn(

past_key_value = (key, value)

if attn_bias is not None:
# clamp to 0 necessary for torch 2.0 compile()
_s_q = max(0, attn_bias.size(2) - query.size(1))
_s_k = max(0, attn_bias.size(3) - key.size(1))
attn_bias = attn_bias[:, :, _s_q:, _s_k:]

if attn_bias is not None:
raise NotImplementedError(f'attn_bias not implemented for flash attn.')

batch_size, seqlen = query.shape[:2]

if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
query_padding_mask = key_padding_mask[:, -query.size(1):]
if attention_mask_in_length is None:
if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
query_padding_mask = key_padding_mask[:, -query.size(1):]
unpadding_function = bert_padding.unpad_input
else:
key_padding_mask = attention_mask_in_length
query_padding_mask = attention_mask_in_length
unpadding_function = bert_padding.unpad_input_for_concatenated_sequences

query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
query, query_padding_mask)
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
key_unpad, _, cu_seqlens_k, max_seqlen_k = unpadding_function(
key, key_padding_mask)
key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
value_unpad, _, _, _ = unpadding_function(value, key_padding_mask)
value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

# multi-query case
if kv_n_heads == 1:
# Expanding a tensor does not allocate new memory, but only creates a new
# view on the existing tensor where a dimension of size one is expanded
# to a larger size by setting the stride to 0.
# - pytorch docs
#
# hopefully the kernels can utilize this and we're jot just wasting BW here
key_unpad = key_unpad.expand(key_unpad.size(0), n_heads,
key_unpad.size(-1))
value_unpad = value_unpad.expand(value_unpad.size(0), n_heads,
value_unpad.size(-1))
# grouped query case
elif kv_n_heads < n_heads:
# Each query belong to a group of kv heads of group size n_heads // kv_n_heads
# We repeat each kv head by the group size number to use the underlying MHA kernels

# since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d)
# we use .view to modify {key, value}_unpad appropriately
if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and (
not should_repeat_kv_for_gqa):
raise ValueError(
'For Grouped Query Attention or Multi Query Attention, should_repeat_kv_for_gqa should be set to True if not using Flash Attention v2.'
)

key_unpad = repeat_kv_for_gqa(
key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1),
n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1)
value_unpad = repeat_kv_for_gqa(
value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1),
n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1)
if should_repeat_kv_for_gqa:
# multi-query case
if kv_n_heads == 1:
# Expanding a tensor does not allocate new memory, but only creates a new
# view on the existing tensor where a dimension of size one is expanded
# to a larger size by setting the stride to 0.
# - pytorch docs
#
# hopefully the kernels can utilize this and we're jot just wasting BW here
key_unpad = key_unpad.expand(key_unpad.size(0), n_heads,
key_unpad.size(-1))
value_unpad = value_unpad.expand(value_unpad.size(0), n_heads,
value_unpad.size(-1))
# grouped query case
elif kv_n_heads < n_heads:
# Each query belong to a group of kv heads of group size n_heads // kv_n_heads
# We repeat each kv head by the group size number to use the underlying MHA kernels

# since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d)
# we use .view to modify {key, value}_unpad appropriately

key_unpad = repeat_kv_for_gqa(
key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1),
n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1)
value_unpad = repeat_kv_for_gqa(
value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1),
n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1)

dropout_p = dropout_p if training else 0.0

Expand Down Expand Up @@ -331,7 +340,8 @@ def flash_attn_fn(
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=reset_is_causal,
return_attn_probs=needs_weights)
return_attn_probs=needs_weights,
window_size=(sliding_window_size, sliding_window_size))
else:
raise RuntimeError(
'flash-attn==1.0.9 or flash-attn==2.3.2 is required.')
Expand Down Expand Up @@ -490,6 +500,7 @@ def __init__(
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
sliding_window_size: int = -1,
):
super().__init__()

Expand All @@ -500,6 +511,7 @@ def __init__(
self.d_model = d_model
self.n_heads = n_heads
self.kv_n_heads = kv_n_heads
self.sliding_window_size = sliding_window_size

self.head_dim = d_model // n_heads

Expand Down Expand Up @@ -569,6 +581,7 @@ def forward(
rotary_emb_w_meta_info: Optional[dict] = None,
is_causal: bool = True,
needs_weights: bool = False,
attention_mask_in_length: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
qkv = self.Wqkv(x)
Expand Down Expand Up @@ -626,6 +639,14 @@ def forward(
query = query.view(bsz, seqlen, self.d_model)
key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim)

extra_attn_kwargs = {}
if self.attn_impl == 'flash':
extra_attn_kwargs = {
'attention_mask_in_length': attention_mask_in_length,
'should_repeat_kv_for_gqa': not is_flash_v2_installed(),
'sliding_window_size': self.sliding_window_size,
}

context, attn_weights, past_key_value = self.attn_fn(
query,
key,
Expand All @@ -640,6 +661,7 @@ def forward(
dropout_p=self.attn_dropout_p,
training=self.training,
needs_weights=needs_weights,
**extra_attn_kwargs,
)

return self.out_proj(context), attn_weights, past_key_value
Expand All @@ -665,6 +687,7 @@ def __init__(
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
sliding_window_size: int = -1,
):
super().__init__(
d_model=d_model,
Expand All @@ -679,6 +702,7 @@ def __init__(
fc_type=fc_type,
device=device,
bias=bias,
sliding_window_size=sliding_window_size,
)


Expand All @@ -702,6 +726,7 @@ def __init__(
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
sliding_window_size: int = -1,
):
super().__init__(
d_model=d_model,
Expand All @@ -716,6 +741,7 @@ def __init__(
fc_type=fc_type,
device=device,
bias=bias,
sliding_window_size=sliding_window_size,
)


Expand Down
3 changes: 3 additions & 0 deletions llmfoundry/models/layers/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
'softmax_scale': None,
'prefix_lm': False,
'attn_uses_sequence_id': False,
'sliding_window_size': -1,
'alibi': False,
'alibi_bias_max': 8,
'rope': False,
Expand Down Expand Up @@ -113,6 +114,7 @@ def forward(
attention_mask: Optional[torch.ByteTensor] = None,
is_causal: bool = True,
output_attentions: bool = False,
attention_mask_in_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
torch.Tensor, torch.Tensor]]]:
a = self.norm_1(x)
Expand All @@ -124,6 +126,7 @@ def forward(
attention_mask=attention_mask,
is_causal=is_causal,
needs_weights=output_attentions,
attention_mask_in_length=attention_mask_in_length,
)
x = x + self.resid_attn_dropout(b)
m = x
Expand Down
Loading

0 comments on commit 2df7441

Please sign in to comment.