Change style for splitting on commas #3078

Merged 2 commits on Mar 7, 2024
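
In short, the new style stops aligning wrapped arguments under the opening parenthesis. Instead, the opening parenthesis ends its line, each argument sits on its own line at a single level of indentation, and a trailing comma precedes the closing parenthesis, which is dedented onto its own line. A minimal before/after sketch of the convention, mirroring the first hunk below:

    # Old style: continuation lines aligned under the opening parenthesis.
    args.add_argument('-b',
                      '--build_args',
                      action='append',
                      required=False,
                      help='List of build args to override globally')

    # New style: one argument per line, single indent, trailing comma,
    # closing parenthesis dedented onto its own line.
    args.add_argument(
        '-b',
        '--build_args',
        action='append',
        required=False,
        help='List of build args to override globally',
    )
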
12 changes: 7 additions & 5 deletions .github/bin/gen_docker_matrix.py
@@ -18,11 +18,13 @@ def _parse_args() -> Namespace:
"""
args = ArgumentParser(description='Process a Docker matrix YAML file.')
args.add_argument('yaml_file', type=FileType('r'), help='The YAML file to be processed.')
args.add_argument('-b',
'--build_args',
action='append',
required=False,
help='List of build args to override globally')
args.add_argument(
'-b',
'--build_args',
action='append',
required=False,
help='List of build args to override globally',
)

return args.parse_args()

1 change: 0 additions & 1 deletion .pre-commit-config.yaml
@@ -7,7 +7,6 @@ repos:
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]

- repo: https://github.com/google/yapf
rev: v0.32.0
hooks:
6 changes: 4 additions & 2 deletions composer/algorithms/alibi/alibi.py
@@ -103,8 +103,10 @@ def replacement_function(module: torch.nn.Module, module_index: int):
count = len(replaced_pairs)
if count == 0:
supported_modules = ''.join(sorted(['\n\t' + c.__module__ + '.' + c.__name__ for c in policy_registry.keys()]))
log.warning(f'ALiBi had no effect on the model! Support for ALiBi surgery '
f'is currently limited to the following classes: {supported_modules}')
log.warning(
f'ALiBi had no effect on the model! Support for ALiBi surgery '
f'is currently limited to the following classes: {supported_modules}',
)
else:
log.info(f' {count} instances of ALiBi added')

30 changes: 19 additions & 11 deletions composer/algorithms/alibi/attention_surgery_functions/_bert.py
@@ -11,8 +11,11 @@
from transformers.models.bert.modeling_bert import BertEmbeddings, BertSelfAttention
from transformers.models.roberta.modeling_roberta import RobertaEmbeddings, RobertaSelfAttention

from composer.algorithms.alibi.attention_surgery_functions.utils import (policy_registry, register_alibi,
zero_and_freeze_expand_position_embeddings)
from composer.algorithms.alibi.attention_surgery_functions.utils import (
policy_registry,
register_alibi,
zero_and_freeze_expand_position_embeddings,
)


@policy_registry.register(BertEmbeddings, RobertaEmbeddings)
@@ -22,9 +25,11 @@ def bert_embedding_converter(module: torch.nn.Module, module_index: int, max_seq
assert isinstance(module, (BertEmbeddings, RobertaEmbeddings))
del module_index # unused
new_module = copy.deepcopy(module)
zero_and_freeze_expand_position_embeddings(new_module,
max_sequence_length,
position_embedding_attribute='position_embeddings')
zero_and_freeze_expand_position_embeddings(
new_module,
max_sequence_length,
position_embedding_attribute='position_embeddings',
)

module_device = next(new_module.parameters()).device
new_module.register_buffer('position_ids', torch.arange(max_sequence_length).expand((1, -1)).to(module_device))
@@ -36,10 +41,12 @@ def bert_attention_converter(module: torch.nn.Module, module_index: int, max_seq
"""Adds ALiBi to Bert-style SelfAttention."""
assert isinstance(module, (BertSelfAttention, RobertaSelfAttention))
del module_index # unused
module = register_alibi(module=module,
n_heads=int(module.num_attention_heads),
max_token_length=max_sequence_length,
causal=False)
module = register_alibi(
module=module,
n_heads=int(module.num_attention_heads),
max_token_length=max_sequence_length,
causal=False,
)
setattr(module, 'forward', MethodType(forward, module))

return module
@@ -101,8 +108,9 @@ def forward(
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query':
raise NotImplementedError('ALiBi is not supported for BERT with position_embedding_type: {}'.format(
self.position_embedding_type))
raise NotImplementedError(
'ALiBi is not supported for BERT with position_embedding_type: {}'.format(self.position_embedding_type),
)
#### REMOVES THE FOLLOWING CODE ########
# seq_length = hidden_states.size()[1]
# position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
14 changes: 10 additions & 4 deletions composer/algorithms/alibi/attention_surgery_functions/_gpt2.py
@@ -7,8 +7,11 @@
import torch
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Model

from composer.algorithms.alibi.attention_surgery_functions.utils import (policy_registry, register_alibi,
zero_and_freeze_expand_position_embeddings)
from composer.algorithms.alibi.attention_surgery_functions.utils import (
policy_registry,
register_alibi,
zero_and_freeze_expand_position_embeddings,
)


@policy_registry.register(GPT2Model)
@@ -31,7 +34,8 @@ def gpt2_attention_converter(module: torch.nn.Module, module_index: int, max_seq
module=module,
n_heads=int(module.num_heads), #type: ignore num_heads member of GPT2Attention
max_token_length=max_sequence_length,
causal=True)
causal=True,
)
setattr(module, '_attn', MethodType(_attn, module))

module = enlarge_mask(module, max_sequence_length)
@@ -92,6 +96,8 @@ def enlarge_mask(module: torch.nn.Module, max_sequence_length: int) -> torch.nn.
torch.ones(
(max_sequence_length, max_sequence_length), # type: ignore
dtype=torch.uint8,
device=old_mask.device)).view(1, 1, max_sequence_length, max_sequence_length) # type: ignore
device=old_mask.device,
),
).view(1, 1, max_sequence_length, max_sequence_length) # type: ignore
setattr(module, 'bias', new_mask)
return module
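
For context, the enlarge_mask hunk above rebuilds GPT-2's causal attention buffer (registered as 'bias') at the larger sequence length. The wrapping call sits outside the visible hunk; assuming it is the usual lower-triangular construction used by Hugging Face GPT-2, a standalone sketch of what the buffer holds for a toy sequence length of 4:

    import torch

    max_sequence_length = 4  # toy value for illustration
    # Lower-triangular (causal) mask, reshaped to (1, 1, seq, seq) as in enlarge_mask.
    new_mask = torch.tril(
        torch.ones((max_sequence_length, max_sequence_length), dtype=torch.uint8),
    ).view(1, 1, max_sequence_length, max_sequence_length)
    print(new_mask[0, 0])
    # tensor([[1, 0, 0, 0],
    #         [1, 1, 0, 0],
    #         [1, 1, 1, 0],
    #         [1, 1, 1, 1]], dtype=torch.uint8)
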
43 changes: 28 additions & 15 deletions composer/algorithms/alibi/attention_surgery_functions/utils.py
@@ -19,8 +19,10 @@
class PolicyRegistry(Dict[Type[torch.nn.Module], AlibiReplacementFunction]):
"""A registry mapping for ALiBi surgery."""

def register(self,
*modules: Type[torch.nn.Module]) -> Callable[[AlibiReplacementFunction], AlibiReplacementFunction]:
def register(
self,
*modules: Type[torch.nn.Module],
) -> Callable[[AlibiReplacementFunction], AlibiReplacementFunction]:
"""This decorator registers mappings from torch module types to their ALiBi surgery functions.

To accommodate the specifics of composer's module surgery, our ALiBi implementation uses a
@@ -76,11 +78,13 @@ def _validate_signature(func: Callable):
parameters = signature.parameters
if len(parameters) != 3:
raise ValueError(
f'Each alibi surgery function must accept 3 arguments, {func} accepts {len(parameters)}')
f'Each alibi surgery function must accept 3 arguments, {func} accepts {len(parameters)}',
)
((_, module_param), (_, index_param), (max_seq_name, max_seq_param)) = parameters.items()
if module_param.annotation != torch.nn.Module:
raise TypeError(
f'The first argument of alibi surgery function {func} must be of type "torch.nn.Module"')
f'The first argument of alibi surgery function {func} must be of type "torch.nn.Module"',
)
if index_param.annotation != int:
raise TypeError(f'The second argument of alibi surgery function {func} must be of type "int"')
if max_seq_param.annotation != int:
@@ -93,7 +97,8 @@ def _register_module(module: Type[torch.nn.Module], func: Callable) -> None:
raise TypeError(f'Module {module.__name__} is not a subclass of `torch.nn.Module`.')
if module in self:
raise ValueError(
f'An AlibiReplacementFunction has already been registered for module {module.__name__}.')
f'An AlibiReplacementFunction has already been registered for module {module.__name__}.',
)
self[module] = func
return
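
As the PolicyRegistry docstring and _validate_signature above spell out, surgery functions are attached to module classes with a decorator and must accept exactly three annotated arguments. A minimal sketch of registering a converter for a hypothetical MyAttention module (illustrative only, not part of this PR), written in the new multi-line style:

    import torch

    from composer.algorithms.alibi.attention_surgery_functions.utils import policy_registry


    class MyAttention(torch.nn.Module):  # hypothetical module, for illustration only
        pass


    @policy_registry.register(MyAttention)
    def my_attention_converter(
        module: torch.nn.Module,
        module_index: int,
        max_sequence_length: int,
    ) -> torch.nn.Module:
        # The registry requires exactly these three annotated parameters; a real
        # converter would attach ALiBi buffers here (see register_alibi above).
        del module_index  # unused, mirroring the converters in this diff
        return module
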

@@ -123,22 +128,29 @@ def zero_and_freeze_expand_position_embeddings(
pos_embedding_module = attrgetter(position_embedding_attribute)(module)
old_weight = getattr(pos_embedding_module, 'weight')
if not isinstance(old_weight, torch.nn.Parameter):
raise TypeError(f'Module {module._get_name()}, position embedding {position_embedding_attribute}, '
f"'weight' attribute must be of type torch.nn.Module")
raise TypeError(
f'Module {module._get_name()}, position embedding {position_embedding_attribute}, '
f"'weight' attribute must be of type torch.nn.Module",
)
new_weight = torch.nn.Parameter(
torch.zeros((max_sequence_length, old_weight.shape[1]),
dtype=old_weight.dtype,
layout=old_weight.layout,
device=old_weight.device))
torch.zeros(
(max_sequence_length, old_weight.shape[1]),
dtype=old_weight.dtype,
layout=old_weight.layout,
device=old_weight.device,
),
)
new_weight.requires_grad = False
setattr(pos_embedding_module, 'weight', new_weight)

log.info(f' Position embedding expanded to sequence length {max_sequence_length}, zeroed, and frozen')

except AttributeError:
log.error(f'Unable to zero and freeze position embeddings. Module '
f'{module} may lack attribute {position_embedding_attribute}, or position '
f"embeddings may lack attribute 'weight'.")
log.error(
f'Unable to zero and freeze position embeddings. Module '
f'{module} may lack attribute {position_embedding_attribute}, or position '
f"embeddings may lack attribute 'weight'.",
)
raise


@@ -196,4 +208,5 @@ def get_slopes_power_of_2(n_heads):
else:
closest_power_of_2 = 2**math.floor(math.log2(n_heads))
return get_slopes_power_of_2(closest_power_of_2) + _get_alibi_head_slopes(
2 * closest_power_of_2)[0::2][:n_heads - closest_power_of_2]
2 * closest_power_of_2,
)[0::2][:n_heads - closest_power_of_2]
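
The hunk above only shows the branch for head counts that are not a power of two; the body of get_slopes_power_of_2 is truncated. Assuming it follows the ALiBi paper, where a power-of-two head count n gets the geometric sequence 2^(-8/n), 2^(-16/n), ..., the interleaving works out as in this sketch (a hypothetical reimplementation, for illustration only):

    import math


    def get_slopes_power_of_2(n_heads):
        # Assumed ALiBi-paper definition: geometric sequence starting at 2**(-8 / n_heads).
        start = 2**(-8 / n_heads)
        return [start**(i + 1) for i in range(n_heads)]


    n_heads = 6
    closest_power_of_2 = 2**math.floor(math.log2(n_heads))  # 4
    slopes = (
        get_slopes_power_of_2(closest_power_of_2) +
        get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:n_heads - closest_power_of_2]
    )
    # -> [0.25, 0.0625, 0.015625, 0.00390625, 0.5, 0.125]
    # i.e. the four slopes for 4 heads, plus every other slope from the
    # 8-head sequence to cover the remaining 2 heads.
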
99 changes: 60 additions & 39 deletions composer/algorithms/augmix/augmix.py
@@ -26,12 +26,14 @@
ImgT = TypeVar('ImgT', torch.Tensor, PillowImage)


def augmix_image(img: ImgT,
severity: int = 3,
depth: int = -1,
width: int = 3,
alpha: float = 1.0,
augmentation_set: List = augmentation_sets['all']) -> ImgT:
def augmix_image(
img: ImgT,
severity: int = 3,
depth: int = -1,
width: int = 3,
alpha: float = 1.0,
augmentation_set: List = augmentation_sets['all'],
) -> ImgT:
r"""Applies the AugMix (`Hendrycks et al, 2020 <http://arxiv.org/abs/1912.02781>`_) data augmentation.

This function works on a single image or batch of images. See :class:`.AugMix` and
@@ -69,8 +71,14 @@ def augmix_image(img: ImgT,
PIL.Image: AugMix'd image.
"""

def _augmix_pil_image(img_pil: PillowImage, severity: int, depth: int, width: int, alpha: float,
augmentation_set: List) -> PillowImage:
def _augmix_pil_image(
img_pil: PillowImage,
severity: int,
depth: int,
width: int,
alpha: float,
augmentation_set: List,
) -> PillowImage:
chain_weights = np.random.dirichlet([alpha] * width).astype(np.float32)
mixing_weight = np.float32(np.random.beta(alpha, alpha))
augmented_combination = np.zeros_like(img_pil, dtype=np.float32)
@@ -92,12 +100,14 @@ def _augmix_pil_image(img_pil: PillowImage, severity: int, depth: int, width: in
mixed = Image.fromarray(np.uint8(mixed))
return mixed

f_pil = functools.partial(_augmix_pil_image,
severity=severity,
depth=depth,
width=width,
alpha=alpha,
augmentation_set=augmentation_set)
f_pil = functools.partial(
_augmix_pil_image,
severity=severity,
depth=depth,
width=width,
alpha=alpha,
augmentation_set=augmentation_set,
)
return map_pillow_function(f_pil, img)
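
Because augmix_image is a plain function over PIL images (or tensors), it can be exercised outside of any training loop. A brief usage sketch; the input filename is hypothetical:

    from PIL import Image

    from composer.algorithms.augmix.augmix import augmix_image

    img = Image.open('example.jpg')  # hypothetical input image
    # Defaults match the signature above: severity=3, depth=-1, width=3, alpha=1.0.
    augmented = augmix_image(img, severity=3, width=3, alpha=1.0)
    augmented.save('example_augmix.jpg')
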


@@ -136,12 +146,14 @@ class AugmentAndMixTransform(torch.nn.Module):
:class:`.AugMix`.
"""

def __init__(self,
severity: int = 3,
depth: int = -1,
width: int = 3,
alpha: float = 1.0,
augmentation_set: str = 'all'):
def __init__(
self,
severity: int = 3,
depth: int = -1,
width: int = 3,
alpha: float = 1.0,
augmentation_set: str = 'all',
):
super().__init__()
if severity < 0 or severity > 10:
raise ValueError('AugMix severity value must satisfy 0 ≤ severity ≤ 10')
@@ -157,12 +169,14 @@ def __init__(self,

def forward(self, img: PillowImage) -> PillowImage:

return augmix_image(img=img,
severity=self.severity,
depth=self.depth,
width=self.width,
alpha=self.alpha,
augmentation_set=self.augmentation_set)
return augmix_image(
img=img,
severity=self.severity,
depth=self.depth,
width=self.width,
alpha=self.alpha,
augmentation_set=self.augmentation_set,
)


class AugMix(Algorithm):
@@ -239,12 +253,14 @@ class AugMix(Algorithm):
# TODO document each value of augmentation_set in more detail; i.e.,
# which augmentations are actually used

def __init__(self,
severity: int = 3,
depth: int = -1,
width: int = 3,
alpha: float = 1.0,
augmentation_set: str = 'all'):
def __init__(
self,
severity: int = 3,
depth: int = -1,
width: int = 3,
alpha: float = 1.0,
augmentation_set: str = 'all',
):
if severity < 0 or severity > 10:
raise ValueError('AugMix severity value must satisfy 0 ≤ severity ≤ 10')
if width < 1:
@@ -267,17 +283,22 @@ def match(self, event: Event, state: State) -> bool:
return state.dataloader.dataset not in self._transformed_datasets

def apply(self, event: Event, state: State, logger: Logger) -> None:
am = AugmentAndMixTransform(severity=self.severity,
depth=self.depth,
width=self.width,
alpha=self.alpha,
augmentation_set=self.augmentation_set)
am = AugmentAndMixTransform(
severity=self.severity,
depth=self.depth,
width=self.width,
alpha=self.alpha,
augmentation_set=self.augmentation_set,
)
assert isinstance(state.dataloader, torch.utils.data.DataLoader), 'dataloader type checked on match()'
dataset = state.dataloader.dataset
if not isinstance(dataset, VisionDataset):
raise TypeError(
textwrap.dedent(f"""\
textwrap.dedent(
f"""\
To use {type(self).__name__}, the dataset must be a
{VisionDataset.__qualname__}, not {type(dataset).__name__}"""))
{VisionDataset.__qualname__}, not {type(dataset).__name__}""",
),
)
add_vision_dataset_transform(dataset, am, is_tensor_transform=False)
self._transformed_datasets.add(dataset)
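
In training, AugMix is applied by passing the algorithm to Composer's Trainer, which calls apply() at the matching event and adds the dataset transform shown above. A minimal sketch, assuming model and train_dataloader are an existing ComposerModel and a DataLoader over a VisionDataset (placeholders, not defined in this PR):

    from composer import Trainer
    from composer.algorithms import AugMix

    trainer = Trainer(
        model=model,                        # placeholder ComposerModel
        train_dataloader=train_dataloader,  # placeholder DataLoader over a VisionDataset
        max_duration='10ep',
        algorithms=[AugMix(severity=3, depth=-1, width=3, alpha=1.0, augmentation_set='all')],
    )
    trainer.fit()
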