Add trailing commas through ruff
b-chu committed Mar 7, 2024
1 parent 3414873 commit 5ee552e
Showing 25 changed files with 136 additions and 125 deletions.
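Note: this commit enables ruff's COM812 rule (missing-trailing-comma, from the flake8-commas plugin) in pyproject.toml, and the pre-commit hook's --fix flag applies the fix automatically. Much of the line-wrapping churn in the hunks below appears to be the yapf formatter re-flowing signatures once the commas changed. A minimal sketch of what COM812 does (illustrative input, not from this repository):

    # before: COM812 flags the last element of a multi-line call
    result = compute(
        alpha,
        beta
    )

    # after `ruff --fix`: a trailing comma is appended
    result = compute(
        alpha,
        beta,
    )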
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
@@ -7,7 +7,6 @@ repos:
   hooks:
   - id: ruff
     args: [--fix, --exit-non-zero-on-fix]
-
 - repo: https://github.com/google/yapf
   rev: v0.32.0
   hooks:
6 changes: 3 additions & 3 deletions composer/callbacks/checkpoint_saver.py
@@ -265,9 +265,9 @@ def __init__(
         self,
         folder: Union[str, pathlib.Path] = '{run_name}/checkpoints',
         filename: Union[str, pathlib.Path] = 'ep{epoch}-ba{batch}-rank{rank}.pt',
-        remote_file_name: Optional[Union[str,
-                                         pathlib.Path,
-                                        ]] = '{run_name}/checkpoints/ep{epoch}-ba{batch}-rank{rank}.pt',
+        remote_file_name: Optional[Union[str, pathlib.Path]
+                                  ] = ('{run_name}/checkpoints/'
+                                       'ep{epoch}-ba{batch}-rank{rank}.pt'),
         latest_filename: Optional[Union[str, pathlib.Path]] = 'latest-rank{rank}.pt',
         latest_remote_file_name: Optional[Union[str, pathlib.Path]] = '{run_name}/checkpoints/latest-rank{rank}.pt',
         save_interval: Union[Time, str, int, Callable[[State, Event], bool]] = '1ep',
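The new remote_file_name default relies on Python's implicit concatenation of adjacent string literals inside parentheses, so the two-line split is purely cosmetic. A quick check (illustrative):

    s = ('{run_name}/checkpoints/'
         'ep{epoch}-ba{batch}-rank{rank}.pt')
    assert s == '{run_name}/checkpoints/ep{epoch}-ba{batch}-rank{rank}.pt'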
9 changes: 1 addition & 8 deletions composer/callbacks/early_stopper.py
@@ -67,14 +67,7 @@ def __init__(
         self,
         monitor: str,
         dataloader_label: str,
-        comp: Optional[Union[str,
-                             Callable[[
-                                 Any,
-                                 Any,
-                             ],
-                                      Any,
-                             ],
-        ]] = None,
+        comp: Optional[Union[str, Callable[[Any, Any], Any]]] = None,
         min_delta: float = 0.0,
         patience: Union[int, str, Time] = 1,
     ):
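The collapsed annotation still reads as: comp is either a comparison name or a two-argument callable returning a truthy value. A hedged usage sketch (argument values are illustrative, not from this diff):

    import operator

    from composer.callbacks import EarlyStopper

    # "higher is better" comparison for an accuracy metric
    stopper = EarlyStopper(
        monitor='MulticlassAccuracy',
        dataloader_label='eval',
        comp=operator.gt,
    )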
9 changes: 1 addition & 8 deletions composer/callbacks/threshold_stopper.py
@@ -59,14 +59,7 @@ def __init__(
         dataloader_label: str,
         threshold: float,
         *,
-        comp: Optional[Union[str,
-                             Callable[[
-                                 Any,
-                                 Any,
-                             ],
-                                      Any,
-                             ],
-        ]] = None,
+        comp: Optional[Union[str, Callable[[Any, Any], Any]]] = None,
         stop_on_batch: bool = False,
     ):
         self.monitor = monitor
6 changes: 1 addition & 5 deletions composer/core/engine.py
@@ -187,11 +187,7 @@ def __init__(
         logger: Logger,
         algorithm_passes: Optional[Union[passes.AlgorithmPass,
                                          Tuple[passes.AlgorithmPass, int],
-                                         Sequence[Union[passes.AlgorithmPass,
-                                                        Tuple[passes.AlgorithmPass,
-                                                              int,
-                                                        ],
-                                         ]],
+                                         Sequence[Union[passes.AlgorithmPass, Tuple[passes.AlgorithmPass, int]]],
                                         ]] = None,
     ):
         self.logger = logger
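The one-line Sequence[Union[...]] keeps the same accepted shapes: a single pass, one (pass, index) tuple, or a sequence mixing both. A hedged sketch (assumes `state`, `logger`, `my_pass`, and `other_pass` are already defined; the keyword names are taken from this signature):

    from composer.core.engine import Engine

    engine = Engine(state=state, logger=logger, algorithm_passes=[my_pass, (other_pass, 0)])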
6 changes: 4 additions & 2 deletions composer/core/state.py
@@ -779,8 +779,10 @@ def fsdp_device_mesh(self):

     @property
     def load_fsdp_monolith_rank0_only(self):
-        return self.fsdp_config is not None and self.fsdp_auto_wrap and self.fsdp_config[
-            'state_dict_type'] == 'full' and self.fsdp_config['load_monolith_rank0_only'] == True
+        return (
+            self.fsdp_config is not None and self.fsdp_auto_wrap and self.fsdp_config['state_dict_type'] == 'full' and
+            self.fsdp_config['load_monolith_rank0_only'] == True
+        )

     def _get_integrations_state_dict(self) -> Dict[str, Any]:
         """Gets a dictionary of information about integrations to store in the state dict.
4 changes: 2 additions & 2 deletions composer/metrics/nlp.py
@@ -767,8 +767,8 @@ def compute(self):
         n = self.num_generations

         for k in self.pass_at_k:
-            pass_at_k = sum([self.estimator(n, int(c.item()), k)
-                             for c in self.correct[complete]],) / complete.sum().item()
+            estimators = [self.estimator(n, int(c.item()), k) for c in self.correct[complete]]
+            pass_at_k = sum(estimators) / complete.sum().item()
             results[f'pass@{k}'] = torch.tensor(pass_at_k)

         if len(results) == 1:  # backwards compatibility
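For context, the estimator this loop averages is, by all appearances, the standard unbiased pass@k estimator from the HumanEval paper; a sketch of the usual numerically stable formulation (assumed, not shown in this diff):

    import numpy as np

    def pass_at_k_estimator(n: int, c: int, k: int) -> float:
        """Probability that at least 1 of k draws from n samples passes,
        given c of the n samples are correct: 1 - C(n-c, k) / C(n, k)."""
        if n - c < k:
            return 1.0
        return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))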
6 changes: 4 additions & 2 deletions composer/profiler/profiler.py
@@ -105,8 +105,10 @@ def __init__(
         torch_prof_filename: str = 'rank{rank}.{batch}.pt.trace.json',
         torch_prof_remote_file_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json',
         torch_prof_memory_filename: Optional[str] = None,
-        torch_prof_memory_remote_file_name: Optional[
-            str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.memory_trace.html',
+        torch_prof_memory_remote_file_name: Optional[str] = (
+            '{run_name}/torch_memory_traces/'
+            'rank{rank}.{batch}.pt.memory_trace.html'
+        ),
         torch_prof_overwrite: bool = False,
         torch_prof_use_gzip: bool = False,
         torch_prof_record_shapes: bool = False,
6 changes: 4 additions & 2 deletions composer/profiler/torch_profiler.py
@@ -192,8 +192,10 @@ def __init__(
         filename: str = 'rank{rank}.{batch}.pt.trace.json',
         remote_file_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json',
         memory_filename: Optional[str] = None,
-        memory_remote_file_name: Optional[str
-                                         ] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.trace.memory.html',
+        memory_remote_file_name: Optional[str] = (
+            '{run_name}/torch_memory_traces/'
+            'rank{rank}.{batch}.pt.trace.memory.html'
+        ),
         overwrite: bool = False,
         use_gzip: bool = False,
         record_shapes: bool = False,
8 changes: 4 additions & 4 deletions composer/trainer/trainer.py
@@ -1862,9 +1862,7 @@ def fit(
         # Schedulers
         schedulers: Optional[Union[ComposerScheduler,
                                    LRScheduler,
-                                   Sequence[Union[ComposerScheduler,
-                                                  LRScheduler,
-                                   ]],
+                                   Sequence[Union[ComposerScheduler, LRScheduler]],
                                   ]] = None,
         scale_schedule_ratio: float = 1.0,
         step_schedulers_every_batch: Optional[bool] = None,
@@ -2496,7 +2494,9 @@ def _train_batch(self, use_grad_scaling: bool) -> Dict[str, torch.Tensor]:
         for metric in self.state.train_metrics.values():
             metric.reset()

-        total_loss_dict = {'loss/train/total': self.state.device.tensor_to_device(torch.zeros(size=(1,)))}
+        total_loss_dict = {
+            'loss/train/total': self.state.device.tensor_to_device(torch.zeros(size=(1,))),
+        }
         found_cuda_oom = 0  # int since bool BOR not supported on all torch.distributed backends
         try:
             assert self.state.scaler is not None
13 changes: 9 additions & 4 deletions composer/utils/eval_client/lambda_eval_client.py
@@ -33,10 +33,15 @@ def __init__(self) -> None:

     def invoke(self, payload: List[List[List[Dict[str, str]]]]) -> List[List[List[bool]]]:
         """Invoke a batch of provided payloads for code evaluations."""
-        ret = [[[self.invoke_helper(test_case)
-                 for test_case in generation_group]
-                for generation_group in prompt_group]
-               for prompt_group in payload]
+        ret = []
+        for prompt_group in payload:
+            ret_prompt_group = []
+            for generation_group in prompt_group:
+                ret_generation_group = []
+                for test_case in generation_group:
+                    ret_generation_group.append(self.invoke_helper(test_case))
+                ret_prompt_group.append(ret_generation_group)
+            ret.append(ret_prompt_group)
         return ret

     def invoke_helper(self, payload: Dict[str, str]) -> bool:
14 changes: 10 additions & 4 deletions composer/utils/eval_client/local_eval_client.py
@@ -24,10 +24,16 @@ class LocalEvalClient(EvalClient):

     def invoke(self, payload: List[List[List[Dict[str, str]]]]) -> List[List[List[bool]]]:
         """Invoke a batch of provided payloads for code evaluations."""
-        return [[[self.invoke_helper(test_case)
-                  for test_case in generation_group]
-                 for generation_group in prompt_group]
-                for prompt_group in payload]
+        ret = []
+        for prompt_group in payload:
+            ret_prompt_group = []
+            for generation_group in prompt_group:
+                ret_generation_group = []
+                for test_case in generation_group:
+                    ret_generation_group.append(self.invoke_helper(test_case))
+                ret_prompt_group.append(ret_generation_group)
+            ret.append(ret_prompt_group)
+        return ret

     def invoke_helper(self, payload: Dict[str, str]) -> bool:
         """Invoke a provided dictionary payload to a multiprocessing subprocess that performs code eval."""
13 changes: 9 additions & 4 deletions composer/utils/eval_client/mosaicml_lambda_eval_client.py
@@ -71,8 +71,13 @@ def invoke(self, payload: List[List[List[Dict[str, str]]]]) -> List[List[List[bool]]]:
                 log.error(f'Failed to get code eval output with unexpected error. Error: {e}')
                 break

-        ret = [[[ret_helper[cum_tests[i] + j * num_tests[i] + k]
-                 for k in range(num_tests[i])]
-                for j in range(num_beams)]
-               for i in range(len(payload))]
+        ret = []
+        for i in range(len(payload)):
+            ret_payload = []
+            for j in range(num_beams):
+                ret_num_beams = []
+                for k in range(num_tests[i]):
+                    ret_num_beams.append(ret_helper[cum_tests[i] + j * num_tests[i] + k])
+                ret_payload.append(ret_num_beams)
+            ret.append(ret_payload)
         return ret
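All three eval clients unroll the same triply nested comprehension, and the loop form is behavior-preserving. A quick standalone check, with a stand-in for invoke_helper (illustrative, not the real client):

    payload = [[[{'a': '1'}, {'a': '2'}], [{'a': '3'}]]]
    f = bool  # stand-in for invoke_helper
    nested = [[[f(t) for t in gen] for gen in grp] for grp in payload]
    unrolled = []
    for grp in payload:
        out_grp = []
        for gen in grp:
            out_grp.append([f(t) for t in gen])
        unrolled.append(out_grp)
    assert nested == unrolled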
4 changes: 1 addition & 3 deletions composer/utils/module_surgery.py
@@ -169,9 +169,7 @@ def replace_module_classes(
         )
     replaced_pairs = {}
     children_to_parents_and_names: OrderedDict[torch.nn.Module,
-                                               List[Tuple[torch.nn.Module,
-                                                          str,
-                                               ]],
+                                               List[Tuple[torch.nn.Module, str]],
                                               ] = collections.OrderedDict()
     _add_children_recursive(module, children_to_parents_and_names)
     indices = indices if indices is not None else {c: 0 for c in policies}
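The mapping reads more easily on one line: each child module keys a list of (parent_module, attribute_name) pairs, since the same child can be registered under several parents. A hedged sketch of its shape (names illustrative):

    import collections

    import torch

    child = torch.nn.Linear(4, 4)
    parent = torch.nn.Sequential(child)  # registers `child` under the name '0'
    children_to_parents_and_names = collections.OrderedDict()
    children_to_parents_and_names.setdefault(child, []).append((parent, '0'))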
1 change: 1 addition & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ select = [
     "LOG",
     "PERF",
     "PLE",
+    "COM812",
 ]

 ignore = [
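With "COM812" in ruff's select list, the missing-trailing-comma rule is enforced repo-wide via the pre-commit hook shown at the top of this diff; the equivalent manual invocation would be something like `ruff check --fix --exit-non-zero-on-fix .` (assuming ruff's standard CLI).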
5 changes: 3 additions & 2 deletions tests/algorithms/test_alibi.py
@@ -54,8 +54,9 @@ def check_batch_reshaping(before, after, length):

         assert k in after, 'No keys should be removed during sequence reshaping.'

-        assert after[k
-                    ].shape == input_ids_after_shape, 'All tensors should have the same size after sequence reshaping.'
+        assert after[k].shape == input_ids_after_shape, (
+            'All tensors should have the same size after sequence reshaping.'
+        )

         b_numel = before[k].shape[0] * before[k].shape[1]
         a_numel = after[k].shape[0] * after[k].shape[1]
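One subtlety in this assert-with-parenthesized-message pattern: the parentheses must wrap only the message, never the whole condition-message pair. A minimal illustration (not from the diff):

    x = y = 1
    assert x == y, ('parenthesized message: the condition is still checked')
    # BAD: `assert (x == y, 'msg')` parenthesizes the whole pair, creating a
    # non-empty tuple that is always truthy, so the assert could never fail.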
6 changes: 4 additions & 2 deletions tests/algorithms/test_gradient_clipping.py
@@ -195,8 +195,10 @@ def test_gradient_clipping_algorithm_with_deepspeed_enabled(
     engine.run_event(Event.INIT)

     # Make sure deepspeed_config's gradient_clipping field is set properly.
-    assert 'gradient_clipping' in state.deepspeed_config and state.deepspeed_config['gradient_clipping'
-                                                                                   ] == clipping_threshold
+    assert (
+        'gradient_clipping' in state.deepspeed_config and
+        state.deepspeed_config['gradient_clipping'] == clipping_threshold
+    )

     # Make sure apply_gradient_clipping is not called.
     apply_gc_fn.assert_not_called()
10 changes: 6 additions & 4 deletions tests/algorithms/test_seq_length_warmup.py
@@ -25,8 +25,9 @@ def check_batch_truncation(before, after, length, preserve_end_of_sequence=False

         assert k in after, 'No keys should be removed during sequence truncation.'

-        assert before[k].shape[0] == after[k].shape[0
-                                                   ], 'The batch size should not be changed during sequence truncation.'
+        assert before[k].shape[0] == after[k].shape[0], (
+            'The batch size should not be changed during sequence truncation.'
+        )

         if before[k].ndim >= 2:

@@ -50,8 +51,9 @@ def check_batch_non_truncation(before, after, length):

         assert k in after, 'No keys should be removed during sequence reshaping.'

-        assert after[k
-                    ].shape == input_ids_after_shape, 'All tensors should have the same size after sequence reshaping.'
+        assert after[k].shape == input_ids_after_shape, (
+            'All tensors should have the same size after sequence reshaping.'
+        )

         b_numel = before[k].shape[0] * before[k].shape[1]
         a_numel = after[k].shape[0] * after[k].shape[1]
96 changes: 47 additions & 49 deletions tests/callbacks/callback_settings.py
@@ -104,55 +104,53 @@
 except ImportError:
     _NEPTUNE_INSTALLED = False

-_callback_kwargs: Dict[Type[Callback],
-                       Dict[str, Any],
-                      ] = {
-    Generate: {
-        'prompts': ['a', 'b', 'c'],
-        'interval': '1ba',
-        'batch_size': 2,
-        'max_new_tokens': 20,
-    },
-    RemoteUploaderDownloader: {
-        'bucket_uri': 'libcloud://.',
-        'backend_kwargs': {
-            'provider': 'local',
-            'container': '.',
-            'provider_kwargs': {
-                'key': '.',
-            },
-        },
-        'use_procs': False,
-        'num_concurrent_uploads': 1,
-    },
-    ThresholdStopper: {
-        'monitor': 'MulticlassAccuracy',
-        'dataloader_label': 'train',
-        'threshold': 0.99,
-    },
-    EarlyStopper: {
-        'monitor': 'MulticlassAccuracy',
-        'dataloader_label': 'train',
-    },
-    ExportForInferenceCallback: {
-        'save_format': 'torchscript',
-        'save_path': '/tmp/model.pth',
-    },
-    MLPerfCallback: {
-        'root_folder': '.',
-        'index': 0,
-    },
-    SpeedMonitor: {
-        'window_size': 1,
-    },
-    NeptuneLogger: {
-        'mode': 'debug',
-    },
-    composer.profiler.Profiler: {
-        'trace_handlers': [MagicMock()],
-        'schedule': composer.profiler.cyclic_schedule(),
-    },
-}
+_callback_kwargs: Dict[Type[Callback], Dict[str, Any]] = {
+    Generate: {
+        'prompts': ['a', 'b', 'c'],
+        'interval': '1ba',
+        'batch_size': 2,
+        'max_new_tokens': 20,
+    },
+    RemoteUploaderDownloader: {
+        'bucket_uri': 'libcloud://.',
+        'backend_kwargs': {
+            'provider': 'local',
+            'container': '.',
+            'provider_kwargs': {
+                'key': '.',
+            },
+        },
+        'use_procs': False,
+        'num_concurrent_uploads': 1,
+    },
+    ThresholdStopper: {
+        'monitor': 'MulticlassAccuracy',
+        'dataloader_label': 'train',
+        'threshold': 0.99,
+    },
+    EarlyStopper: {
+        'monitor': 'MulticlassAccuracy',
+        'dataloader_label': 'train',
+    },
+    ExportForInferenceCallback: {
+        'save_format': 'torchscript',
+        'save_path': '/tmp/model.pth',
+    },
+    MLPerfCallback: {
+        'root_folder': '.',
+        'index': 0,
+    },
+    SpeedMonitor: {
+        'window_size': 1,
+    },
+    NeptuneLogger: {
+        'mode': 'debug',
+    },
+    composer.profiler.Profiler: {
+        'trace_handlers': [MagicMock()],
+        'schedule': composer.profiler.cyclic_schedule(),
+    },
+}

 _callback_marks: Dict[
     Type[Callback],
4 changes: 2 additions & 2 deletions tests/callbacks/test_speed_monitor.py
@@ -63,8 +63,8 @@ def test_speed_monitor(flops_per_batch: bool):
     assert isinstance(trainer.state.dataloader, collections.abc.Sized)
     assert trainer.state.dataloader_label is not None
     assert trainer.state.dataloader_len is not None
-    expected_step_calls = (trainer.state.dataloader_len - len(speed_monitor.history_samples) +
-                           1) * int(trainer.state.timestamp.epoch)
+    calls_per_epoch = trainer.state.dataloader_len - len(speed_monitor.history_samples) + 1
+    expected_step_calls = calls_per_epoch * int(trainer.state.timestamp.epoch)
     assert len(in_memory_logger.data['throughput/batches_per_sec']) == expected_step_calls
     assert len(in_memory_logger.data['throughput/samples_per_sec']) == expected_step_calls
     assert len(in_memory_logger.data['throughput/device/batches_per_sec']) == expected_step_calls
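The split makes the expected-call arithmetic easy to check by hand. For example (hypothetical numbers, not from the test):

    dataloader_len, history_samples, epochs = 10, 3, 2
    calls_per_epoch = dataloader_len - history_samples + 1   # 8
    expected_step_calls = calls_per_epoch * epochs           # 16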
5 changes: 3 additions & 2 deletions tests/common/datasets.py
@@ -349,8 +349,9 @@ def __getitem__(self, index: int):
             self.x = torch.randn(self.size * self.batch_size, self.feature_size)
         if self.y is None:
             self.y = torch.randint(0, self.num_classes, size=(self.size * self.batch_size,), dtype=torch.long)
-        return self.x[index * self.batch_size:(index + 1) *
-                      self.batch_size], self.y[index * self.batch_size:(index + 1) * self.batch_size]
+        start_index = index * self.batch_size
+        end_index = start_index + self.batch_size
+        return self.x[start_index:end_index], self.y[start_index:end_index]


 def dummy_transformer_classifier_batch(vocab_size=10, num_classes=2):
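The named indices make the contract clearer: item i of the dataset is rows [i * batch_size, (i + 1) * batch_size) of the pre-generated tensors. A standalone sketch (illustrative sizes):

    import torch

    batch_size, size, feature_size = 4, 10, 8
    x = torch.randn(size * batch_size, feature_size)
    index = 2
    start_index = index * batch_size
    end_index = start_index + batch_size
    batch = x[start_index:end_index]    # rows 8..11
    assert batch.shape == (batch_size, feature_size)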
(Diff truncated: 4 of the 25 changed files are not rendered above.)
