Merge branch 'dev' into neptune-logger
AleksanderWWW authored Jan 18, 2024
2 parents 97b3894 + 09dbe18 commit 608c767
Showing 91 changed files with 1,525 additions and 654 deletions.
2 changes: 1 addition & 1 deletion .ci/FILE_HEADER
@@ -1,2 +1,2 @@
Copyright 2022 MosaicML Composer authors
Copyright 2024 MosaicML Composer authors
SPDX-License-Identifier: Apache-2.0
5 changes: 3 additions & 2 deletions .pre-commit-config.yaml
@@ -75,14 +75,15 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.1.13
rev: v1.5.4
hooks:
- id: insert-license
args:
- --license-filepath
- .ci/FILE_HEADER
- --comment-style
- "#"
- --allow-past-years
types: [python]
exclude: 'composer\/trainer\/activation_checkpointing.py'

@@ -110,7 +111,7 @@ repos:
types: [python]
pass_filenames: false
args: [--warnings]
additional_dependencies: ["pyright@1.1.256"]
additional_dependencies: ["pyright@1.1.310"]
- repo: https://github.com/trufflesecurity/trufflehog.git
rev: v3.40.0
hooks:
2 changes: 1 addition & 1 deletion composer/_version.py
@@ -3,4 +3,4 @@

"""The Composer Version."""

__version__ = '0.17.2'
__version__ = '0.18.0'
@@ -6,7 +6,8 @@
from composer.utils import MissingConditionalImportError

try:
from composer.algorithms.alibi.attention_surgery_functions import _bert, _gpt2 # pyright: reportUnusedImport=none
from composer.algorithms.alibi.attention_surgery_functions import _bert # pyright: ignore[reportUnusedImport]
from composer.algorithms.alibi.attention_surgery_functions import _gpt2 # pyright: ignore[reportUnusedImport]
from composer.algorithms.alibi.attention_surgery_functions.utils import policy_registry
except ImportError as e:
raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='transformers') from e
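
For context, the old trailing comment is a pyright settings directive, which pyright generally applies to the whole file, while the new per-line ignore form suppresses only the line it sits on, which is presumably why the import was split in two. A small illustrative sketch of the two styles (paths copied from the hunk above):

    # Old style: a settings comment, treated as a file-wide override by pyright.
    from composer.algorithms.alibi.attention_surgery_functions import _bert, _gpt2  # pyright: reportUnusedImport=none

    # New style: targeted, per-line suppressions.
    from composer.algorithms.alibi.attention_surgery_functions import _bert  # pyright: ignore[reportUnusedImport]
    from composer.algorithms.alibi.attention_surgery_functions import _gpt2  # pyright: ignore[reportUnusedImport]
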
15 changes: 10 additions & 5 deletions composer/algorithms/colout/colout.py
@@ -29,10 +29,12 @@
__all__ = ['ColOut', 'ColOutTransform', 'colout_batch']


def colout_batch(sample: Union[ImgT, Tuple[ImgT, ImgT]],
p_row: float = 0.15,
p_col: float = 0.15,
resize_target: Union[bool, str] = 'auto') -> Union[ImgT, Tuple[ImgT, ImgT]]:
def colout_batch(
sample: Union[ImgT, Tuple[ImgT, ImgT]],
p_row: float = 0.15,
p_col: float = 0.15,
resize_target: Union[bool,
str] = 'auto') -> Union[torch.Tensor, ImgT, Tuple[Tensor, Tensor], Tuple[ImgT, ImgT]]:
"""Applies ColOut augmentation to a batch of images and (optionally) targets,
dropping the same random rows and columns from all images and targets in a batch.
@@ -136,7 +138,10 @@ def __init__(self, p_row: float = 0.15, p_col: float = 0.15, resize_target: Unio
self.p_col = p_col
self.resize_target = resize_target

def __call__(self, sample: Union[ImgT, Tuple[ImgT, ImgT]]) -> Union[ImgT, Tuple[ImgT, ImgT]]:
def __call__(
self, sample: Union[ImgT,
Tuple[ImgT,
ImgT]]) -> Union[torch.Tensor, ImgT, Tuple[Tensor, Tensor], Tuple[ImgT, ImgT]]:
"""Drops random rows and columns from up to two images.
Args:
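
A short usage sketch of the functional form above (the batch shape is illustrative, and it assumes colout_batch is re-exported by composer.algorithms.colout as its __all__ suggests):

    import torch
    from composer.algorithms.colout import colout_batch

    images = torch.rand(8, 3, 224, 224)                       # NCHW image batch
    augmented = colout_batch(images, p_row=0.15, p_col=0.15)
    print(augmented.shape)                                     # height and width shrink by roughly 15%
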
17 changes: 9 additions & 8 deletions composer/algorithms/factorize/factorize_modules.py
@@ -327,8 +327,8 @@ def solution_for_rank(self, input: torch.Tensor, rank: int) -> LowRankSolution:

def apply_solution(self, solution: LowRankSolution):
self.latent_size = solution.rank
self.module0.out_channels = solution.rank
self.module1.in_channels = solution.rank
self.module0.out_channels = solution.rank # pyright: ignore[reportGeneralTypeIssues]
self.module1.in_channels = solution.rank # pyright: ignore[reportGeneralTypeIssues]
_apply_solution_to_module_parameters(solution, self.module0, self.module1, transpose=False)

@staticmethod
@@ -452,8 +452,8 @@ def solution_for_rank(self, input: torch.Tensor, rank: int) -> LowRankSolution:

def apply_solution(self, solution: LowRankSolution) -> None:
self.latent_size = solution.rank
self.module0.out_features = solution.rank
self.module1.in_features = solution.rank
self.module0.out_features = solution.rank # pyright: ignore[reportGeneralTypeIssues]
self.module1.in_features = solution.rank # pyright: ignore[reportGeneralTypeIssues]
_apply_solution_to_module_parameters(solution, self.module0, self.module1, transpose=True)

@staticmethod
@@ -471,9 +471,10 @@ def max_allowed_latent_channels(in_features: int, out_features: int) -> int:

@staticmethod
def from_linear(module: torch.nn.Linear, module_ix: int = -1, **kwargs) -> FactorizedLinear:
ret = FactorizedLinear(in_features=module.in_features,
out_features=module.out_features,
bias=((module.bias is not None) and (module.bias is not False)),
**kwargs)
ret = FactorizedLinear(
in_features=module.in_features,
out_features=module.out_features,
bias=(module.bias is not None and module.bias is not False), # pyright: ignore[reportUnnecessaryComparison]
**kwargs)
ret.reset_parameters()
return ret
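
As a usage sketch for the from_linear factory above (it assumes FactorizedLinear is importable from composer.algorithms.factorize; extra keyword arguments are forwarded to the constructor):

    import torch
    from composer.algorithms.factorize import FactorizedLinear

    dense = torch.nn.Linear(512, 256)
    factorized = FactorizedLinear.from_linear(dense)   # copies in/out features and the bias setting
    out = factorized(torch.randn(4, 512))
    print(out.shape)                                   # torch.Size([4, 256])
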
17 changes: 11 additions & 6 deletions composer/algorithms/gated_linear_units/gated_linear_units.py
@@ -36,7 +36,8 @@ def from_BertOutput(layer: torch.nn.Module,
non_gated_layer_bias: bool = False) -> BERTGatedFFOutput:
"""Defines a replacement policy from a :class:`transformers.models.bert.modeling_bert.BertOutput` to a :class:`composer.algorithms.gated_linear_units.gated_linear_unit_layers.BERTGatedFFOutput`"""
assert isinstance(
layer, BertOutput
layer,
BertOutput # pyright: ignore[reportUnboundVariable]
), 'The replacement policy requires an instance of transformers.models.bert.modeling_bert.BertOutput for the necessary fields to be defined.'
return BERTGatedFFOutput(
d_embed=layer.dense.out_features, #type: ignore dense.out_features member of BertOutput
@@ -85,16 +86,20 @@ def apply_gated_linear_units(model: torch.nn.Module,
unwrapped_model = model.model if isinstance(model, HuggingFaceModel) else model

# ensure that the model is an instance of a Hugging Face BertPreTrainedModel class, since our replacement policy is only defined for BERTs
if not isinstance(unwrapped_model, BertPreTrainedModel):
if not isinstance(unwrapped_model, BertPreTrainedModel): # pyright: ignore[reportUnboundVariable]
raise TypeError(
'Gated Linear Units only has a surgery policy defined for subclasses of transformers.BertPreTrainedModel')

# Early exit if nothing to replace
if module_surgery.count_module_instances(module=model, module_class=BertIntermediate) == 0:
if module_surgery.count_module_instances(
module=model, module_class=BertIntermediate) == 0: # pyright: ignore[reportUnboundVariable]
return

if act_fn is None:
intermediate_modules = {module for module in model.modules() if isinstance(module, BertIntermediate)}
intermediate_modules = {
module for module in model.modules()
if isinstance(module, BertIntermediate) # pyright: ignore[reportUnboundVariable]
} # pyright: ignore[reportUnboundVariable]
if len(intermediate_modules) == 0:
warnings.warn(
NoEffectWarning('No instances of BertIntermediate were found so Gated Linear Units will be skipped '
@@ -130,8 +135,8 @@ def from_bound_BertOutput(layer: torch.nn.Module, module_index: int) -> BERTGate

# prepare the replacement policy and perform replacement
policy: Dict[Type[torch.nn.Module], module_surgery.ReplacementFunction] = {
BertIntermediate: from_BertIntermediate,
BertOutput: from_bound_BertOutput
BertIntermediate: from_BertIntermediate, # pyright: ignore[reportUnboundVariable]
BertOutput: from_bound_BertOutput # pyright: ignore[reportUnboundVariable]
}
replaced_instances = module_surgery.replace_module_classes(module=model, optimizers=optimizers, policies=policy)
if len(replaced_instances) == 0:
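
The reportUnboundVariable suppressions exist because the transformers classes are imported inside a guard, so pyright cannot prove the names are bound at their use sites. A generic sketch of that pattern (not this module's exact code):

    try:
        from transformers.models.bert.modeling_bert import BertIntermediate
    except ImportError:
        transformers_available = False   # hypothetical flag; BertIntermediate stays unbound on this path
    else:
        transformers_available = True

    def count_intermediates(model):
        # Without a suppression, pyright reports BertIntermediate as possibly unbound here.
        return sum(isinstance(m, BertIntermediate) for m in model.modules())  # pyright: ignore[reportUnboundVariable]
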
4 changes: 2 additions & 2 deletions composer/algorithms/ghost_batchnorm/ghost_batchnorm.py
@@ -152,7 +152,7 @@ def __init__(self, base_batchnorm: _TORCH_BATCHNORM_BASE_CLASS, ghost_batch_size
super().__init__()
self.ghost_batch_size = ghost_batch_size
self.batchnorm = base_batchnorm
self.batchnorm._already_ghost_batchnormed = True # Mark to avoid rewrapping on duplicate calls
self.batchnorm._already_ghost_batchnormed = True # Mark to avoid rewrapping on duplicate calls # pyright: ignore[reportGeneralTypeIssues]

def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore
batch_size = input.shape[0]
@@ -161,7 +161,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore
raise ValueError(f'Worker batch size {batch_size} < ghost_batch_size {self.ghost_batch_size}')

nchunks: int = int(math.ceil(batch_size / self.ghost_batch_size))
has_momentum = self.batchnorm.momentum is not None
has_momentum: bool = hasattr(self.batchnorm, 'momentum')
original_momentum: float = self.batchnorm.momentum

if self.training and has_momentum:
@@ -86,8 +86,10 @@ def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, device=None
def forward(self, x):
module_device = x.device
downcast_x = _cast_if_autocast_enabled(x)
downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
downcast_weight = _cast_if_autocast_enabled(
self.weight) if self.weight is not None else self.weight # pyright: ignore[reportUnnecessaryComparison]
downcast_bias = _cast_if_autocast_enabled(
self.bias) if self.bias is not None else self.bias # pyright: ignore[reportUnnecessaryComparison]
with torch.autocast(enabled=False, device_type=module_device.type):
return F.group_norm(downcast_x, self.num_groups, downcast_weight, downcast_bias, self.eps)

@@ -111,11 +113,11 @@ def _to_LPGroupNorm(layer: torch.nn.Module, module_index: int) -> LPGroupNorm:
lp_groupnorm = LPGroupNorm(layer.num_groups, layer.num_channels, layer.eps, layer.affine)

with torch.no_grad():
if layer.weight is None:
if layer.weight is None: # pyright: ignore[reportUnnecessaryComparison]
lp_groupnorm.register_parameter('weight', None)
else:
lp_groupnorm.weight.copy_(layer.weight) # type: ignore
if layer.bias is None:
if layer.bias is None: # pyright: ignore[reportUnnecessaryComparison]
lp_groupnorm.register_parameter('bias', None)
else:
lp_groupnorm.bias.copy_(layer.bias) # type: ignore
@@ -109,8 +109,10 @@ def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=
def forward(self, x):
module_device = x.device
downcast_x = _cast_if_autocast_enabled(x)
downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
downcast_weight = _cast_if_autocast_enabled(
self.weight) if self.weight is not None else self.weight # pyright: ignore[reportUnnecessaryComparison]
downcast_bias = _cast_if_autocast_enabled(
self.bias) if self.bias is not None else self.bias # pyright: ignore[reportUnnecessaryComparison]
with torch.autocast(enabled=False, device_type=module_device.type):
return F.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)

@@ -141,11 +143,11 @@ def _to_LPLayerNorm(layer: torch.nn.Module, module_index: int) -> LPLayerNorm:
lp_layernorm = LPLayerNorm(layer.normalized_shape, layer.eps, layer.elementwise_affine)

with torch.no_grad():
if layer.weight is None:
if hasattr(layer, 'weight'):
lp_layernorm.register_parameter('weight', None)
else:
lp_layernorm.weight.copy_(layer.weight) # type: ignore
if layer.bias is None:
if layer.bias is None: # pyright: ignore[reportUnnecessaryComparison]
lp_layernorm.register_parameter('bias', None)
else:
lp_layernorm.bias.copy_(layer.bias) # type: ignore
@@ -160,12 +162,12 @@ def _to_FusedLayerNorm(layer: torch.nn.Module, module_index: int) -> APEXFusedLa
fused_layernorm = APEXFusedLayerNorm(normalized_shape=layer.normalized_shape, eps=layer.eps)

with torch.no_grad():
if layer.weight is None:
fused_layernorm.weight = None
if layer.weight is None: # pyright: ignore[reportUnnecessaryComparison]
fused_layernorm.weight = None # pyright: ignore[reportGeneralTypeIssues]
else:
fused_layernorm.weight.copy_(layer.weight)
if layer.bias is None:
fused_layernorm.bias = None
if layer.bias is None: # pyright: ignore[reportUnnecessaryComparison]
fused_layernorm.bias = None # pyright: ignore[reportGeneralTypeIssues]
else:
fused_layernorm.bias.copy_(layer.bias)

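
Both low-precision norm forwards follow the same pattern: optionally downcast the input and parameters, then run the functional op with autocast disabled so it executes in the chosen dtype. A simplified standalone sketch (Composer's real _cast_if_autocast_enabled helper is more involved):

    import torch
    import torch.nn.functional as F

    def cast_if_autocast_enabled(t):
        # Simplified stand-in for the helper used above: downcast only while autocast is active.
        if t is not None and torch.is_autocast_enabled():
            return t.to(torch.get_autocast_gpu_dtype())
        return t

    def low_precision_layer_norm(x, normalized_shape, weight=None, bias=None, eps=1e-5):
        x = cast_if_autocast_enabled(x)
        weight = cast_if_autocast_enabled(weight)
        bias = cast_if_autocast_enabled(bias)
        # Disable autocast so layer_norm runs exactly in the dtype chosen above.
        with torch.autocast(device_type=x.device.type, enabled=False):
            return F.layer_norm(x, normalized_shape, weight, bias, eps)
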
8 changes: 4 additions & 4 deletions composer/algorithms/sam/sam.py
@@ -50,7 +50,7 @@ def __init__(self,
defaults = {'rho': rho, 'epsilon': epsilon, **kwargs}
super(SAMOptimizer, self).__init__(self.base_optimizer.param_groups, defaults)

@torch.no_grad()
@torch.no_grad() # pyright: ignore[reportUntypedFunctionDecorator]
def sub_e_w(self):
for group in self.param_groups:
for p in group['params']:
@@ -59,7 +59,7 @@ def sub_e_w(self):
e_w = self.state[p]['e_w'] # retrieve stale e(w)
p.sub_(e_w) # get back to "w" from "w + e(w)"

@torch.no_grad()
@torch.no_grad() # pyright: ignore[reportUntypedFunctionDecorator]
def first_step(self):
grad_norm = self._grad_norm()
for group in self.param_groups:
@@ -71,7 +71,7 @@ def first_step(self):
p.add_(e_w) # climb to the local maximum "w + e(w)"
self.state[p]['e_w'] = e_w

@torch.no_grad()
@torch.no_grad() # pyright: ignore[reportUntypedFunctionDecorator]
def second_step(self):
for group in self.param_groups:
for p in group['params']:
@@ -80,7 +80,7 @@ def second_step(self):
p.sub_(self.state[p]['e_w']) # get back to "w" from "w + e(w)"
self.base_optimizer.step() # do the actual "sharpness-aware" update

@torch.no_grad()
@torch.no_grad() # pyright: ignore[reportUntypedFunctionDecorator]
def step(self, closure=None):
assert closure is not None, 'Sharpness Aware Minimization requires closure, but it was not provided'
closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass
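
For orientation, step() above requires a closure that re-runs the forward/backward pass. A hypothetical wiring sketch (the SAMOptimizer keyword names here are assumptions based on the defaults shown in the first hunk):

    import torch
    from composer.algorithms.sam.sam import SAMOptimizer

    model = torch.nn.Linear(10, 2)
    base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    sam = SAMOptimizer(base_optimizer=base_optimizer, rho=0.05, epsilon=1.0e-12)  # assumed keyword names

    def closure():
        base_optimizer.zero_grad()
        loss = model(torch.randn(4, 10)).sum()
        loss.backward()
        return loss

    sam.step(closure)  # the optimizer re-invokes the closure as part of the two-step SAM update
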
2 changes: 1 addition & 1 deletion composer/algorithms/squeeze_excite/squeeze_excite.py
@@ -109,7 +109,7 @@ class SqueezeExciteConv2d(torch.nn.Module):
def __init__(self, *args, latent_channels: float = 0.125, conv: Optional[torch.nn.Conv2d] = None, **kwargs):
super().__init__()
self.conv = torch.nn.Conv2d(*args, **kwargs) if conv is None else conv
self.conv._already_squeeze_excited = True # Mark to avoid rewrapping on duplicate calls
self.conv._already_squeeze_excited = True # Mark to avoid rewrapping on duplicate calls # pyright: ignore[reportGeneralTypeIssues]
self.se = SqueezeExcite2d(num_features=self.conv.out_channels, latent_channels=latent_channels)

def forward(self, input: torch.Tensor) -> torch.Tensor:
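
Here, as in the GhostBatchNorm change above, a sentinel attribute is set on the wrapped module so module surgery can skip it on later passes; pyright flags the assignment because the attribute is not declared on the torch class. A generic sketch of the rewrap guard (the wrapper type is a placeholder):

    import torch

    def maybe_wrap(conv: torch.nn.Conv2d) -> torch.nn.Module:
        if getattr(conv, '_already_squeeze_excited', False):
            return conv                             # already wrapped on an earlier pass
        conv._already_squeeze_excited = True        # sentinel; an undeclared member as far as pyright knows
        return torch.nn.Sequential(conv)            # stand-in for the real SqueezeExciteConv2d wrapper
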
4 changes: 2 additions & 2 deletions composer/callbacks/activation_monitor.py
@@ -146,8 +146,8 @@ def register_forward_hook(self, model: torch.nn.Module, logger: Logger, step: Op
def _register_forward_hook(self, logger: Logger, step: Optional[int], module: torch.nn.Module):
self.handles.append(module.register_forward_hook(partial(self.forward_hook, logger, step)))

def forward_hook(self, logger: Logger, step: Optional[int], module: torch.nn.Module, input: Sequence,
output: Sequence):
def forward_hook(self, logger: Logger, step: Optional[int], module: torch.nn.Module, input: Optional[Sequence],
output: Optional[Sequence]):
module_name = self.module_names[module]

if self.ignore_module_types is not None:
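
For reference, a minimal standalone sketch of the partial-bound forward-hook pattern this callback uses (the tag and step values are illustrative):

    from functools import partial
    from typing import Optional, Sequence

    import torch

    def forward_hook(tag: str, step: Optional[int], module: torch.nn.Module,
                     inputs: Optional[Sequence], outputs: Optional[Sequence]):
        # torch invokes hooks as hook(module, inputs, outputs); tag and step are pre-bound via partial.
        print(tag, step, type(module).__name__)

    model = torch.nn.Linear(4, 2)
    handle = model.register_forward_hook(partial(forward_hook, 'activations', 0))
    model(torch.randn(1, 4))
    handle.remove()
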