Change style for splitting on commas
Splits imports in the following way, turning this:
```
from composer.algorithms.low_precision_groupnorm.low_precision_groupnorm import (LowPrecisionGroupNorm,
                                                                                 apply_low_precision_groupnorm)
```
into this:
```
from composer.algorithms.low_precision_groupnorm.low_precision_groupnorm import (
    LowPrecisionGroupNorm,
    apply_low_precision_groupnorm,
)
```
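
For reference, a minimal sketch of how this import style can be produced mechanically. It assumes isort is available and uses its black-compatible profile (vertical hanging indent plus trailing comma); this is an illustration only, not necessarily the tooling or configuration used in this commit.
```python
# Illustrative sketch (assumption: isort handles import wrapping; the commit does
# not state which tool, if any, enforces this style). The "black" profile enables
# parenthesized, one-name-per-line imports with a trailing comma.
import isort

src = (
    'from composer.algorithms.low_precision_groupnorm.low_precision_groupnorm '
    'import (LowPrecisionGroupNorm, apply_low_precision_groupnorm)\n'
)

# line_length=100 is an arbitrary choice here, small enough that the long import wraps.
print(isort.code(src, profile='black', line_length=100))
```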

Changes function arguments in the following way, turning this:
```
    freeze_percentage = _freeze_schedule(current_duration=current_duration,
                                         freeze_start=freeze_start,
                                         freeze_level=freeze_level)
```
into this:
```
    freeze_percentage = _freeze_schedule(
        current_duration=current_duration,
        freeze_start=freeze_start,
        freeze_level=freeze_level,
    )
```
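
Similarly, a hedged sketch of how the call-argument style could be enforced with a formatter, assuming yapf (the commit does not say which tool, if any, was used, and the style options below are assumptions that may not match this repository's configuration).
```python
# Illustrative sketch only: ask yapf to split every top-level comma-separated value
# onto its own line and to dedent the closing bracket, which yields the call style
# shown above. Note that yapf itself does not insert trailing commas; those come
# from the source (or another tool).
from yapf.yapflib.yapf_api import FormatCode

src = (
    'freeze_percentage = _freeze_schedule(current_duration=current_duration, '
    'freeze_start=freeze_start, freeze_level=freeze_level)\n'
)

formatted, _changed = FormatCode(
    src,
    style_config='{based_on_style: google, column_limit: 80, '
    'split_all_top_level_comma_separated_values: true, dedent_closing_brackets: true}',
)
print(formatted)
```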
b-chu committed Mar 1, 2024
1 parent cee4523 commit 3414873
Showing 243 changed files with 9,473 additions and 5,448 deletions.
12 changes: 7 additions & 5 deletions .github/bin/gen_docker_matrix.py
@@ -18,11 +18,13 @@ def _parse_args() -> Namespace:
     """
     args = ArgumentParser(description='Process a Docker matrix YAML file.')
     args.add_argument('yaml_file', type=FileType('r'), help='The YAML file to be processed.')
-    args.add_argument('-b',
-                      '--build_args',
-                      action='append',
-                      required=False,
-                      help='List of build args to override globally')
+    args.add_argument(
+        '-b',
+        '--build_args',
+        action='append',
+        required=False,
+        help='List of build args to override globally',
+    )

     return args.parse_args()

6 changes: 4 additions & 2 deletions composer/algorithms/alibi/alibi.py
@@ -103,8 +103,10 @@ def replacement_function(module: torch.nn.Module, module_index: int):
     count = len(replaced_pairs)
     if count == 0:
         supported_modules = ''.join(sorted(['\n\t' + c.__module__ + '.' + c.__name__ for c in policy_registry.keys()]))
-        log.warning(f'ALiBi had no effect on the model! Support for ALiBi surgery '
-                    f'is currently limited to the following classes: {supported_modules}')
+        log.warning(
+            f'ALiBi had no effect on the model! Support for ALiBi surgery '
+            f'is currently limited to the following classes: {supported_modules}',
+        )
     else:
         log.info(f' {count} instances of ALiBi added')

30 changes: 19 additions & 11 deletions composer/algorithms/alibi/attention_surgery_functions/_bert.py
@@ -11,8 +11,11 @@
 from transformers.models.bert.modeling_bert import BertEmbeddings, BertSelfAttention
 from transformers.models.roberta.modeling_roberta import RobertaEmbeddings, RobertaSelfAttention

-from composer.algorithms.alibi.attention_surgery_functions.utils import (policy_registry, register_alibi,
-                                                                         zero_and_freeze_expand_position_embeddings)
+from composer.algorithms.alibi.attention_surgery_functions.utils import (
+    policy_registry,
+    register_alibi,
+    zero_and_freeze_expand_position_embeddings,
+)


 @policy_registry.register(BertEmbeddings, RobertaEmbeddings)
@@ -22,9 +25,11 @@ def bert_embedding_converter(module: torch.nn.Module, module_index: int, max_seq
     assert isinstance(module, (BertEmbeddings, RobertaEmbeddings))
     del module_index  # unused
     new_module = copy.deepcopy(module)
-    zero_and_freeze_expand_position_embeddings(new_module,
-                                               max_sequence_length,
-                                               position_embedding_attribute='position_embeddings')
+    zero_and_freeze_expand_position_embeddings(
+        new_module,
+        max_sequence_length,
+        position_embedding_attribute='position_embeddings',
+    )

     module_device = next(new_module.parameters()).device
     new_module.register_buffer('position_ids', torch.arange(max_sequence_length).expand((1, -1)).to(module_device))
@@ -36,10 +41,12 @@ def bert_attention_converter(module: torch.nn.Module, module_index: int, max_seq
     """Adds ALiBi to Bert-style SelfAttention."""
     assert isinstance(module, (BertSelfAttention, RobertaSelfAttention))
     del module_index  # unused
-    module = register_alibi(module=module,
-                            n_heads=int(module.num_attention_heads),
-                            max_token_length=max_sequence_length,
-                            causal=False)
+    module = register_alibi(
+        module=module,
+        n_heads=int(module.num_attention_heads),
+        max_token_length=max_sequence_length,
+        causal=False,
+    )
     setattr(module, 'forward', MethodType(forward, module))

     return module
@@ -101,8 +108,9 @@ def forward(
     attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

     if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query':
-        raise NotImplementedError('ALiBi is not supported for BERT with position_embedding_type: {}'.format(
-            self.position_embedding_type))
+        raise NotImplementedError(
+            'ALiBi is not supported for BERT with position_embedding_type: {}'.format(self.position_embedding_type),
+        )
         #### REMOVES THE FOLLOWING CODE ########
         # seq_length = hidden_states.size()[1]
         # position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
14 changes: 10 additions & 4 deletions composer/algorithms/alibi/attention_surgery_functions/_gpt2.py
@@ -7,8 +7,11 @@
 import torch
 from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Model

-from composer.algorithms.alibi.attention_surgery_functions.utils import (policy_registry, register_alibi,
-                                                                         zero_and_freeze_expand_position_embeddings)
+from composer.algorithms.alibi.attention_surgery_functions.utils import (
+    policy_registry,
+    register_alibi,
+    zero_and_freeze_expand_position_embeddings,
+)


 @policy_registry.register(GPT2Model)
@@ -31,7 +34,8 @@ def gpt2_attention_converter(module: torch.nn.Module, module_index: int, max_seq
         module=module,
         n_heads=int(module.num_heads), #type: ignore num_heads member of GPT2Attention
         max_token_length=max_sequence_length,
-        causal=True)
+        causal=True,
+    )
     setattr(module, '_attn', MethodType(_attn, module))

     module = enlarge_mask(module, max_sequence_length)
@@ -92,6 +96,8 @@ def enlarge_mask(module: torch.nn.Module, max_sequence_length: int) -> torch.nn.
         torch.ones(
             (max_sequence_length, max_sequence_length), # type: ignore
             dtype=torch.uint8,
-            device=old_mask.device)).view(1, 1, max_sequence_length, max_sequence_length) # type: ignore
+            device=old_mask.device,
+        ),
+    ).view(1, 1, max_sequence_length, max_sequence_length) # type: ignore
     setattr(module, 'bias', new_mask)
     return module
43 changes: 28 additions & 15 deletions composer/algorithms/alibi/attention_surgery_functions/utils.py
@@ -19,8 +19,10 @@
 class PolicyRegistry(Dict[Type[torch.nn.Module], AlibiReplacementFunction]):
     """A registry mapping for ALiBi surgery."""

-    def register(self,
-                 *modules: Type[torch.nn.Module]) -> Callable[[AlibiReplacementFunction], AlibiReplacementFunction]:
+    def register(
+        self,
+        *modules: Type[torch.nn.Module],
+    ) -> Callable[[AlibiReplacementFunction], AlibiReplacementFunction]:
         """This decorator registers mappings from torch module types to their ALiBi surgery functions.
         To accommodate the specifics of composer's module surgery, our ALiBi implementation uses a
@@ -76,11 +78,13 @@ def _validate_signature(func: Callable):
             parameters = signature.parameters
             if len(parameters) != 3:
                 raise ValueError(
-                    f'Each alibi surgery function must accept 3 arguments, {func} accepts {len(parameters)}')
+                    f'Each alibi surgery function must accept 3 arguments, {func} accepts {len(parameters)}',
+                )
             ((_, module_param), (_, index_param), (max_seq_name, max_seq_param)) = parameters.items()
             if module_param.annotation != torch.nn.Module:
                 raise TypeError(
-                    f'The first argument of alibi surgery function {func} must be of type "torch.nn.Module"')
+                    f'The first argument of alibi surgery function {func} must be of type "torch.nn.Module"',
+                )
             if index_param.annotation != int:
                 raise TypeError(f'The second argument of alibi surgery function {func} must be of type "int"')
             if max_seq_param.annotation != int:
@@ -93,7 +97,8 @@ def _register_module(module: Type[torch.nn.Module], func: Callable) -> None:
                 raise TypeError(f'Module {module.__name__} is not a subclass of `torch.nn.Module`.')
             if module in self:
                 raise ValueError(
-                    f'An AlibiReplacementFunction has already been registered for module {module.__name__}.')
+                    f'An AlibiReplacementFunction has already been registered for module {module.__name__}.',
+                )
             self[module] = func
             return

@@ -123,22 +128,29 @@ def zero_and_freeze_expand_position_embeddings(
         pos_embedding_module = attrgetter(position_embedding_attribute)(module)
         old_weight = getattr(pos_embedding_module, 'weight')
         if not isinstance(old_weight, torch.nn.Parameter):
-            raise TypeError(f'Module {module._get_name()}, position embedding {position_embedding_attribute}, '
-                            f"'weight' attribute must be of type torch.nn.Module")
+            raise TypeError(
+                f'Module {module._get_name()}, position embedding {position_embedding_attribute}, '
+                f"'weight' attribute must be of type torch.nn.Module",
+            )
         new_weight = torch.nn.Parameter(
-            torch.zeros((max_sequence_length, old_weight.shape[1]),
-                        dtype=old_weight.dtype,
-                        layout=old_weight.layout,
-                        device=old_weight.device))
+            torch.zeros(
+                (max_sequence_length, old_weight.shape[1]),
+                dtype=old_weight.dtype,
+                layout=old_weight.layout,
+                device=old_weight.device,
+            ),
+        )
         new_weight.requires_grad = False
         setattr(pos_embedding_module, 'weight', new_weight)

         log.info(f' Position embedding expanded to sequence length {max_sequence_length}, zeroed, and frozen')

     except AttributeError:
-        log.error(f'Unable to zero and freeze position embeddings. Module '
-                  f'{module} may lack attribute {position_embedding_attribute}, or position '
-                  f"embeddings may lack attribute 'weight'.")
+        log.error(
+            f'Unable to zero and freeze position embeddings. Module '
+            f'{module} may lack attribute {position_embedding_attribute}, or position '
+            f"embeddings may lack attribute 'weight'.",
+        )
         raise


@@ -196,4 +208,5 @@ def get_slopes_power_of_2(n_heads):
     else:
         closest_power_of_2 = 2**math.floor(math.log2(n_heads))
         return get_slopes_power_of_2(closest_power_of_2) + _get_alibi_head_slopes(
-            2 * closest_power_of_2)[0::2][:n_heads - closest_power_of_2]
+            2 * closest_power_of_2,
+        )[0::2][:n_heads - closest_power_of_2]
99 changes: 60 additions & 39 deletions composer/algorithms/augmix/augmix.py
@@ -26,12 +26,14 @@
 ImgT = TypeVar('ImgT', torch.Tensor, PillowImage)


-def augmix_image(img: ImgT,
-                 severity: int = 3,
-                 depth: int = -1,
-                 width: int = 3,
-                 alpha: float = 1.0,
-                 augmentation_set: List = augmentation_sets['all']) -> ImgT:
+def augmix_image(
+    img: ImgT,
+    severity: int = 3,
+    depth: int = -1,
+    width: int = 3,
+    alpha: float = 1.0,
+    augmentation_set: List = augmentation_sets['all'],
+) -> ImgT:
     r"""Applies the AugMix (`Hendrycks et al, 2020 <http://arxiv.org/abs/1912.02781>`_) data augmentation.
     This function works on a single image or batch of images. See :class:`.AugMix` and
@@ -69,8 +71,14 @@ def augmix_image(img: ImgT,
         PIL.Image: AugMix'd image.
     """

-    def _augmix_pil_image(img_pil: PillowImage, severity: int, depth: int, width: int, alpha: float,
-                          augmentation_set: List) -> PillowImage:
+    def _augmix_pil_image(
+        img_pil: PillowImage,
+        severity: int,
+        depth: int,
+        width: int,
+        alpha: float,
+        augmentation_set: List,
+    ) -> PillowImage:
         chain_weights = np.random.dirichlet([alpha] * width).astype(np.float32)
         mixing_weight = np.float32(np.random.beta(alpha, alpha))
         augmented_combination = np.zeros_like(img_pil, dtype=np.float32)
@@ -92,12 +100,14 @@ def _augmix_pil_image(img_pil: PillowImage, severity: int, depth: int, width: in
         mixed = Image.fromarray(np.uint8(mixed))
         return mixed

-    f_pil = functools.partial(_augmix_pil_image,
-                              severity=severity,
-                              depth=depth,
-                              width=width,
-                              alpha=alpha,
-                              augmentation_set=augmentation_set)
+    f_pil = functools.partial(
+        _augmix_pil_image,
+        severity=severity,
+        depth=depth,
+        width=width,
+        alpha=alpha,
+        augmentation_set=augmentation_set,
+    )
     return map_pillow_function(f_pil, img)


@@ -136,12 +146,14 @@ class AugmentAndMixTransform(torch.nn.Module):
     :class:`.AugMix`.
     """

-    def __init__(self,
-                 severity: int = 3,
-                 depth: int = -1,
-                 width: int = 3,
-                 alpha: float = 1.0,
-                 augmentation_set: str = 'all'):
+    def __init__(
+        self,
+        severity: int = 3,
+        depth: int = -1,
+        width: int = 3,
+        alpha: float = 1.0,
+        augmentation_set: str = 'all',
+    ):
         super().__init__()
         if severity < 0 or severity > 10:
             raise ValueError('AugMix severity value must satisfy 0 ≤ severity ≤ 10')
@@ -157,12 +169,14 @@ def __init__(self,

     def forward(self, img: PillowImage) -> PillowImage:

-        return augmix_image(img=img,
-                            severity=self.severity,
-                            depth=self.depth,
-                            width=self.width,
-                            alpha=self.alpha,
-                            augmentation_set=self.augmentation_set)
+        return augmix_image(
+            img=img,
+            severity=self.severity,
+            depth=self.depth,
+            width=self.width,
+            alpha=self.alpha,
+            augmentation_set=self.augmentation_set,
+        )


 class AugMix(Algorithm):
@@ -239,12 +253,14 @@ class AugMix(Algorithm):
     # TODO document each value of augmentation_set in more detail; i.e.,
     # which augmentations are actually used

-    def __init__(self,
-                 severity: int = 3,
-                 depth: int = -1,
-                 width: int = 3,
-                 alpha: float = 1.0,
-                 augmentation_set: str = 'all'):
+    def __init__(
+        self,
+        severity: int = 3,
+        depth: int = -1,
+        width: int = 3,
+        alpha: float = 1.0,
+        augmentation_set: str = 'all',
+    ):
         if severity < 0 or severity > 10:
             raise ValueError('AugMix severity value must satisfy 0 ≤ severity ≤ 10')
         if width < 1:
@@ -267,17 +283,22 @@ def match(self, event: Event, state: State) -> bool:
         return state.dataloader.dataset not in self._transformed_datasets

     def apply(self, event: Event, state: State, logger: Logger) -> None:
-        am = AugmentAndMixTransform(severity=self.severity,
-                                    depth=self.depth,
-                                    width=self.width,
-                                    alpha=self.alpha,
-                                    augmentation_set=self.augmentation_set)
+        am = AugmentAndMixTransform(
+            severity=self.severity,
+            depth=self.depth,
+            width=self.width,
+            alpha=self.alpha,
+            augmentation_set=self.augmentation_set,
+        )
         assert isinstance(state.dataloader, torch.utils.data.DataLoader), 'dataloader type checked on match()'
         dataset = state.dataloader.dataset
         if not isinstance(dataset, VisionDataset):
             raise TypeError(
-                textwrap.dedent(f"""\
+                textwrap.dedent(
+                    f"""\
                     To use {type(self).__name__}, the dataset must be a
-                    {VisionDataset.__qualname__}, not {type(dataset).__name__}"""))
+                    {VisionDataset.__qualname__}, not {type(dataset).__name__}""",
+                ),
+            )
         add_vision_dataset_transform(dataset, am, is_tensor_transform=False)
         self._transformed_datasets.add(dataset)
