Skip to content

Commit

Permalink
Merge branch 'main' into log_original_config
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Jul 30, 2024
2 parents 62919af + 9a62bfd commit a000b8b
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 14 deletions.
22 changes: 13 additions & 9 deletions .github/workflows/docker.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,6 @@ jobs:
base_image: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04-aws
dep_groups: "[gpu]"
steps:
- name: Maximize Build Space on Worker
uses: easimon/maximize-build-space@v4
with:
overprovision-lvm: true
remove-dotnet: true
remove-android: true
remove-haskell: true

- name: Checkout
uses: actions/checkout@v3
Expand All @@ -47,6 +40,13 @@ jobs:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_PASSWORD }}

- name: Login to GHCR
uses: docker/login-action@v2
with:
username: ${{ secrets.GHCR_USERNAME }}
password: ${{ secrets.GHCR_TOKEN }}
registry: ghcr.io

- name: Calculate Docker Image Variables
run: |
set -euxo pipefail
Expand All @@ -60,13 +60,17 @@ jobs:
if [ "${{ github.event_name }}" == "pull_request" ]; then
echo "Triggered by pull_request event."
STAGING_REPO="mosaicml/ci-staging"
IMAGE_TAG="${STAGING_REPO}:${{matrix.name}}-${GIT_SHA}"
GHCR_STAGING_REPO="ghcr.io/databricks-mosaic/ci-staging"
GHCR_IMAGE_TAG="${GHCR_STAGING_REPO}:${{matrix.name}}-${GIT_SHA}"
IMAGE_TAG="${STAGING_REPO}:${{matrix.name}}-${GIT_SHA},${GHCR_IMAGE_TAG}"
IMAGE_CACHE="${STAGING_REPO}:${{matrix.name}}-buildcache"
else
# Triggered by push or workflow_dispatch event
echo "Triggered by ${{ github.event_name }} event, releasing to prod"
PROD_REPO="mosaicml/llm-foundry"
IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest"
GHCR_PROD_REPO="ghcr.io/databricks-mosaic/llm-foundry"
GHCR_IMAGE_TAG="${GHCR_PROD_REPO}:${{matrix.name}}-${GIT_SHA},${GHCR_PROD_REPO}:${{matrix.name}}-latest"
IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest,${GHCR_IMAGE_TAG}"
IMAGE_CACHE="${PROD_REPO}:${{matrix.name}}-buildcache"
fi
Expand Down
24 changes: 20 additions & 4 deletions llmfoundry/models/layers/ffn.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,19 @@
}


def quickgelu_activation(input: torch.Tensor) -> torch.Tensor:
"""Applies GELU approximation that is fast but somewhat inaccurate.
Args:
input (torch.Tensor): Input tensor of shape(*), where * means any
number of dimensions
Returns:
torch.Tensor: Tensor with same shape as input tensor
"""
return input * torch.sigmoid(1.702 * input)


def resolve_ffn_act_fn(
config: Optional[dict] = None,
) -> Callable[[torch.Tensor], torch.Tensor]:
Expand All @@ -70,10 +83,13 @@ def resolve_ffn_act_fn(
config = _FFN_ACT_FN_DEFAULT
config = deepcopy(config)
name = config.pop('name')
if not hasattr(torch.nn.functional, name):
raise ValueError(f'Unrecognized activation function name ({name}).')
act = getattr(torch.nn.functional, name)
return partial(act, **config)
if name == 'quick_gelu':
return quickgelu_activation
else:
if not hasattr(torch.nn.functional, name):
raise ValueError(f'Unrecognized activation function name ({name}).')
act = getattr(torch.nn.functional, name)
return partial(act, **config)


_DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
'mlflow>=2.14.1,<2.15',
'accelerate>=0.25,<0.33', # for HF inference `device_map`
'transformers>=4.43.2,<4.44',
'mosaicml-streaming>=0.7.6,<0.8',
'mosaicml-streaming>=0.8.0,<0.9',
'torch>=2.3.0,<2.4',
'datasets>=2.19,<2.20',
'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data
Expand Down
73 changes: 73 additions & 0 deletions tests/models/layers/test_ffn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn

from llmfoundry.models.layers.ffn import quickgelu_activation
from llmfoundry.models.layers.layer_builders import build_ffn


@pytest.mark.gpu
def test_quickgelu_activation():
d_model = 32
expansion_ratio = 1
no_bias = True
ffn_config = {
'ffn_act_fn': {
'name': 'quick_gelu',
},
'ffn_type': 'mptmlp',
}
rank: int = dist.get_rank()
device_str = f'cuda:{rank}'
device: torch.device = torch.device(device_str)

ffn1 = build_ffn(
name=ffn_config['ffn_type'],
d_model=d_model,
expansion_ratio=expansion_ratio,
device=device_str,
bias=not no_bias,
ffn_kwargs=ffn_config,
)
assert (
ffn1.act == quickgelu_activation
), f'Expected quick_gelu activation function, got {ffn1.act}'

ffn_config = {
'ffn_act_fn': {
'name': 'gelu',
},
'ffn_type': 'mptmlp',
}
ffn2 = build_ffn(
name=ffn_config['ffn_type'],
d_model=d_model,
expansion_ratio=expansion_ratio,
device=device_str,
bias=not no_bias,
ffn_kwargs=ffn_config,
)

def num_params(model: nn.Module) -> int:
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
return sum([p.numel() for p in model_parameters])

ffn1_numparams = num_params(ffn1)
ffn2_numparams = num_params(ffn2)
assert (
ffn1_numparams == ffn2_numparams
), 'Only activation paths should have changed, re-check modeling!'

input_ = torch.rand(1, d_model, device=device)
output1 = ffn1(input_)
output2 = ffn2(input_)
assert (
output1.numel() == output2.numel()
), 'Only activation paths should have changed, re-check modeling!'
assert (
not torch.allclose(output1, output2)
), 'Functions are different, outputs should not match!'

0 comments on commit a000b8b

Please sign in to comment.