Add trailing commas through ruff
b-chu committed Mar 7, 2024
1 parent 23def47 commit a00329f
Showing 28 changed files with 161 additions and 150 deletions.
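pyproject.toml below adds ruff's COM812 rule (missing-trailing-comma, from the flake8-commas plugin) to the select list, and the remaining files are reformatted to satisfy it, mostly by adding trailing commas and reflowing the affected multi-line signatures, calls, and literals. As a rough illustration of the style the rule enforces (a hypothetical function, not code from this commit):

# Hypothetical sketch of what COM812 enforces; save_checkpoint_v1/v2 are not Composer APIs.
# Before: no trailing comma, so appending a parameter later also rewrites the current last line.
def save_checkpoint_v1(
    folder: str = '{run_name}/checkpoints',
    filename: str = 'ep{epoch}-ba{batch}-rank{rank}.pt'
):
    ...

# After: the trailing comma keeps future additions to a single new diff line
# and signals formatters to keep one element per line.
def save_checkpoint_v2(
    folder: str = '{run_name}/checkpoints',
    filename: str = 'ep{epoch}-ba{batch}-rank{rank}.pt',
):
    ...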
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
@@ -7,7 +7,6 @@ repos:
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]

- repo: https://github.com/google/yapf
rev: v0.32.0
hooks:
12 changes: 7 additions & 5 deletions composer/callbacks/checkpoint_saver.py
@@ -265,9 +265,9 @@ def __init__(
self,
folder: Union[str, pathlib.Path] = '{run_name}/checkpoints',
filename: Union[str, pathlib.Path] = 'ep{epoch}-ba{batch}-rank{rank}.pt',
remote_file_name: Optional[Union[str,
pathlib.Path,
]] = '{run_name}/checkpoints/ep{epoch}-ba{batch}-rank{rank}.pt',
remote_file_name: Optional[Union[str, pathlib.Path]
] = ('{run_name}/checkpoints/'
'ep{epoch}-ba{batch}-rank{rank}.pt'),
latest_filename: Optional[Union[str, pathlib.Path]] = 'latest-rank{rank}.pt',
latest_remote_file_name: Optional[Union[str, pathlib.Path]] = '{run_name}/checkpoints/latest-rank{rank}.pt',
save_interval: Union[Time, str, int, Callable[[State, Event], bool]] = '1ep',
@@ -368,8 +368,10 @@ def epoch_checkpoint(self, state: State, logger: Logger):

def iteration_checkpoint(self, state: State, logger: Logger):
assert callable(self.save_interval)
if (self.save_interval(state, Event.ITERATION_CHECKPOINT) and
self.last_checkpoint_batch != state.timestamp.batch):
if (
self.save_interval(state, Event.ITERATION_CHECKPOINT) and
self.last_checkpoint_batch != state.timestamp.batch
):
self._save_checkpoint(
state,
logger,
9 changes: 1 addition & 8 deletions composer/callbacks/early_stopper.py
@@ -67,14 +67,7 @@ def __init__(
self,
monitor: str,
dataloader_label: str,
comp: Optional[Union[str,
Callable[[
Any,
Any,
],
Any,
],
]] = None,
comp: Optional[Union[str, Callable[[Any, Any], Any]]] = None,
min_delta: float = 0.0,
patience: Union[int, str, Time] = 1,
):
9 changes: 1 addition & 8 deletions composer/callbacks/threshold_stopper.py
@@ -59,14 +59,7 @@ def __init__(
dataloader_label: str,
threshold: float,
*,
comp: Optional[Union[str,
Callable[[
Any,
Any,
],
Any,
],
]] = None,
comp: Optional[Union[str, Callable[[Any, Any], Any]]] = None,
stop_on_batch: bool = False,
):
self.monitor = monitor
15 changes: 4 additions & 11 deletions composer/core/engine.py
@@ -185,17 +185,10 @@ def __init__(
self,
state: State,
logger: Logger,
algorithm_passes: Optional[Union[
passes.AlgorithmPass,
Tuple[passes.AlgorithmPass, int],
Sequence[Union[
passes.AlgorithmPass,
Tuple[
passes.AlgorithmPass,
int,
],
]],
]] = None,
algorithm_passes: Optional[Union[passes.AlgorithmPass,
Tuple[passes.AlgorithmPass, int],
Sequence[Union[passes.AlgorithmPass, Tuple[passes.AlgorithmPass, int]]],
]] = None,
):
self.logger = logger
self.state = state
6 changes: 4 additions & 2 deletions composer/core/state.py
@@ -801,8 +801,10 @@ def fsdp_device_mesh(self):

@property
def load_fsdp_monolith_rank0_only(self):
return self.fsdp_config is not None and self.fsdp_auto_wrap and self.fsdp_config[
'state_dict_type'] == 'full' and self.fsdp_config['load_monolith_rank0_only'] == True
return (
self.fsdp_config is not None and self.fsdp_auto_wrap and self.fsdp_config['state_dict_type'] == 'full' and
self.fsdp_config['load_monolith_rank0_only'] == True
)

def _get_integrations_state_dict(self) -> Dict[str, Any]:
"""Gets a dictionary of information about integrations to store in the state dict.
12 changes: 5 additions & 7 deletions composer/loggers/remote_uploader_downloader.py
@@ -313,13 +313,11 @@ def __init__(
self._exception_queue: Union[queue.Queue[Exception],
multiprocessing.JoinableQueue[Exception],
] = mp_ctx.JoinableQueue()
self._finished_cls: Union[
Callable[
[],
multiprocessing._EventType,
], # pyright: ignore[reportGeneralTypeIssues]
Type[threading.Event],
] = mp_ctx.Event
self._finished_cls: Union[Callable[[],
multiprocessing._EventType, # pyright: ignore[reportGeneralTypeIssues]
],
Type[threading.Event],
] = mp_ctx.Event
self._proc_class = mp_ctx.Process
else:
self._file_upload_queue = queue.Queue()
4 changes: 2 additions & 2 deletions composer/metrics/nlp.py
@@ -770,8 +770,8 @@ def compute(self):
n = self.num_generations

for k in self.pass_at_k:
pass_at_k = sum([self.estimator(n, int(c.item()), k)
for c in self.correct[complete]],) / complete.sum().item()
estimators = [self.estimator(n, int(c.item()), k) for c in self.correct[complete]]
pass_at_k = sum(estimators) / complete.sum().item()
results[f'pass@{k}'] = torch.tensor(pass_at_k)

if len(results) == 1: # backwards compatibility
6 changes: 4 additions & 2 deletions composer/profiler/profiler.py
@@ -105,8 +105,10 @@ def __init__(
torch_prof_filename: str = 'rank{rank}.{batch}.pt.trace.json',
torch_prof_remote_file_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json',
torch_prof_memory_filename: Optional[str] = None,
torch_prof_memory_remote_file_name: Optional[
str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.memory_trace.html',
torch_prof_memory_remote_file_name: Optional[str] = (
'{run_name}/torch_memory_traces/'
'rank{rank}.{batch}.pt.memory_trace.html'
),
torch_prof_overwrite: bool = False,
torch_prof_use_gzip: bool = False,
torch_prof_record_shapes: bool = False,
10 changes: 6 additions & 4 deletions composer/profiler/torch_profiler.py
@@ -192,8 +192,10 @@ def __init__(
filename: str = 'rank{rank}.{batch}.pt.trace.json',
remote_file_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json',
memory_filename: Optional[str] = None,
memory_remote_file_name: Optional[str
] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.trace.memory.html',
memory_remote_file_name: Optional[str] = (
'{run_name}/torch_memory_traces/'
'rank{rank}.{batch}.pt.trace.memory.html'
),
overwrite: bool = False,
use_gzip: bool = False,
record_shapes: bool = False,
@@ -312,8 +314,8 @@ def handler_fn(prof: torch.profiler.profiler.profile):
export_memory_timeline_html(
prof,
memory_trace_file_name,
torch.cuda.current_device(),
) # type: ignore
torch.cuda.current_device(), # type: ignore
)
log.debug(f'Uploaded memory trace to {self.memory_remote_file_name}')
if self.memory_remote_file_name is not None:
memory_trace_remote_file_name = format_name_with_dist_and_time(
5 changes: 3 additions & 2 deletions composer/trainer/dist_strategy.py
@@ -641,8 +641,9 @@ def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, nonwrapped_num
first_wrap_fn = checkpoint_wrapper if activation_checkpointing else (lambda module: module)
second_wrap_fn = (
lambda module: offload_wrapper(
first_wrap_fn(module) if activation_checkpointing else module,
) # type: ignore reportGeneralTypeIssues
first_wrap_fn(module)
if activation_checkpointing else module, # type: ignore reportGeneralTypeIssues
)
) if activation_cpu_offload else first_wrap_fn
else:
if not activation_checkpointing_reentrant:
14 changes: 8 additions & 6 deletions composer/trainer/trainer.py
@@ -1862,9 +1862,7 @@ def fit(
# Schedulers
schedulers: Optional[Union[ComposerScheduler,
LRScheduler,
Sequence[Union[ComposerScheduler,
LRScheduler,
]],
Sequence[Union[ComposerScheduler, LRScheduler]],
]] = None,
scale_schedule_ratio: float = 1.0,
step_schedulers_every_batch: Optional[bool] = None,
@@ -2424,8 +2422,10 @@ def _train_loop(self) -> None:
self.engine.run_event(Event.EPOCH_CHECKPOINT)

# Increment iteration
if (self.state._iteration_length is not None and
self.state.timestamp.epoch_in_iteration == self.state._iteration_length):
if (
self.state._iteration_length is not None and
self.state.timestamp.epoch_in_iteration == self.state._iteration_length
):
self.state.previous_timestamp = self.state.timestamp
self.state.timestamp = self.state.timestamp.to_next_iteration()
self.engine.run_event(Event.ITERATION_END)
@@ -2507,7 +2507,9 @@ def _train_batch(self, use_grad_scaling: bool) -> Dict[str, torch.Tensor]:
for metric in self.state.train_metrics.values():
metric.reset()

total_loss_dict = {'loss/train/total': self.state.device.tensor_to_device(torch.zeros(size=(1,)))}
total_loss_dict = {
'loss/train/total': self.state.device.tensor_to_device(torch.zeros(size=(1,))),
}
found_cuda_oom = 0 # int since bool BOR not supported on all torch.distributed backends
try:
assert self.state.scaler is not None
13 changes: 9 additions & 4 deletions composer/utils/eval_client/lambda_eval_client.py
@@ -33,10 +33,15 @@ def __init__(self) -> None:

def invoke(self, payload: List[List[List[Dict[str, str]]]]) -> List[List[List[bool]]]:
"""Invoke a batch of provided payloads for code evaluations."""
ret = [[[self.invoke_helper(test_case)
for test_case in generation_group]
for generation_group in prompt_group]
for prompt_group in payload]
ret = []
for prompt_group in payload:
ret_prompt_group = []
for generation_group in prompt_group:
ret_generation_group = []
for test_case in generation_group:
ret_generation_group.append(self.invoke_helper(test_case))
ret_prompt_group.append(ret_generation_group)
ret.append(ret_prompt_group)
return ret

def invoke_helper(self, payload: Dict[str, str]) -> bool:
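In the eval clients (here and in local_eval_client.py and mosaicml_lambda_eval_client.py below), the nested list comprehension in invoke is unrolled into explicit loops, presumably so the nesting stays readable once trailing commas and the resulting line breaks are applied. Both forms build the same prompt -> generation -> test-case structure; a small sketch with a stub invoke_helper (hypothetical values, not the real code-eval client):

# Stub standing in for the real invoke_helper, which runs one code-eval test case.
def invoke_helper(test_case: dict) -> bool:
    return test_case.get('output') == test_case.get('expected')

payload = [  # prompts -> generations per prompt -> test cases per generation
    [
        [{'expected': '4', 'output': '4'}, {'expected': '9', 'output': '8'}],
    ],
]

# Comprehension form (before this commit):
nested = [[[invoke_helper(t) for t in gen] for gen in prompt] for prompt in payload]

# Unrolled form (after this commit):
ret = []
for prompt_group in payload:
    ret_prompt_group = []
    for generation_group in prompt_group:
        ret_generation_group = []
        for test_case in generation_group:
            ret_generation_group.append(invoke_helper(test_case))
        ret_prompt_group.append(ret_generation_group)
    ret.append(ret_prompt_group)

assert nested == ret == [[[True, False]]]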
18 changes: 12 additions & 6 deletions composer/utils/eval_client/local_eval_client.py
@@ -24,10 +24,16 @@ class LocalEvalClient(EvalClient):

def invoke(self, payload: List[List[List[Dict[str, str]]]]) -> List[List[List[bool]]]:
"""Invoke a batch of provided payloads for code evaluations."""
return [[[self.invoke_helper(test_case)
for test_case in generation_group]
for generation_group in prompt_group]
for prompt_group in payload]
ret = []
for prompt_group in payload:
ret_prompt_group = []
for generation_group in prompt_group:
ret_generation_group = []
for test_case in generation_group:
ret_generation_group.append(self.invoke_helper(test_case))
ret_prompt_group.append(ret_generation_group)
ret.append(ret_prompt_group)
return ret

def invoke_helper(self, payload: Dict[str, str]) -> bool:
"""Invoke a provided dictionary payload to a multiprocessing subprocess that performs code eval."""
@@ -55,8 +61,8 @@ def update_offline_helper(
test_output: str,
entry_point: str,
language: str,
val: multiprocessing.Value,
): # type: ignore
val: multiprocessing.Value, # type: ignore
):
"""Helper function to evaluate test case in a subprocess.
This function compiles the code generation,
13 changes: 9 additions & 4 deletions composer/utils/eval_client/mosaicml_lambda_eval_client.py
@@ -71,8 +71,13 @@ def invoke(self, payload: List[List[List[Dict[str, str]]]]) -> List[List[List[bo
log.error(f'Failed to get code eval output with unexpected error. Error: {e}')
break

ret = [[[ret_helper[cum_tests[i] + j * num_tests[i] + k]
for k in range(num_tests[i])]
for j in range(num_beams)]
for i in range(len(payload))]
ret = []
for i in range(len(payload)):
ret_payload = []
for j in range(num_beams):
ret_num_beams = []
for k in range(num_tests[i]):
ret_num_beams.append(ret_helper[cum_tests[i] + j * num_tests[i] + k])
ret_payload.append(ret_num_beams)
ret.append(ret_payload)
return ret
4 changes: 2 additions & 2 deletions composer/utils/misc.py
@@ -62,7 +62,7 @@ def create_interval_scheduler(
interval_event = Event.BATCH_CHECKPOINT if checkpoint_events else Event.BATCH_END
else:
raise NotImplementedError(
f'Unknown interval: {time_interval.unit}. Must be TimeUnit.ITERATION, TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.'
f'Unknown interval: {time_interval.unit}. Must be TimeUnit.ITERATION, TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.',
)

last_batch_seen = -1
@@ -94,7 +94,7 @@ def check_interval(state: State, event: Event):
count = state.timestamp.get(state.max_duration.unit)
else:
raise NotImplementedError(
f'Unknown interval: {time_interval.unit}. Must be TimeUnit.ITERATION, TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.'
f'Unknown interval: {time_interval.unit}. Must be TimeUnit.ITERATION, TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.',
)

threshold_passed = math.floor(previous_count / time_interval.value) != math.floor(count / time_interval.value)
4 changes: 1 addition & 3 deletions composer/utils/module_surgery.py
@@ -169,9 +169,7 @@ def replace_module_classes(
)
replaced_pairs = {}
children_to_parents_and_names: OrderedDict[torch.nn.Module,
List[Tuple[torch.nn.Module,
str,
]],
List[Tuple[torch.nn.Module, str]],
] = collections.OrderedDict()
_add_children_recursive(module, children_to_parents_and_names)
indices = indices if indices is not None else {c: 0 for c in policies}
1 change: 1 addition & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ select = [
"LOG",
"PERF",
"PLE",
"COM812",
]

ignore = [
5 changes: 3 additions & 2 deletions tests/algorithms/test_alibi.py
@@ -54,8 +54,9 @@ def check_batch_reshaping(before, after, length):

assert k in after, 'No keys should be removed during sequence reshaping.'

assert after[k
].shape == input_ids_after_shape, 'All tensors should have the same size after sequence reshaping.'
assert after[k].shape == input_ids_after_shape, (
'All tensors should have the same size after sequence reshaping.'
)

b_numel = before[k].shape[0] * before[k].shape[1]
a_numel = after[k].shape[0] * after[k].shape[1]
6 changes: 4 additions & 2 deletions tests/algorithms/test_gradient_clipping.py
@@ -195,8 +195,10 @@ def test_gradient_clipping_algorithm_with_deepspeed_enabled(
engine.run_event(Event.INIT)

# Make sure deepspeed_config's gradient_clipping field is set properly.
assert 'gradient_clipping' in state.deepspeed_config and state.deepspeed_config['gradient_clipping'
] == clipping_threshold
assert (
'gradient_clipping' in state.deepspeed_config and
state.deepspeed_config['gradient_clipping'] == clipping_threshold
)

# Make sure apply_gradient_clipping is not called.
apply_gc_fn.assert_not_called()
10 changes: 6 additions & 4 deletions tests/algorithms/test_seq_length_warmup.py
@@ -25,8 +25,9 @@ def check_batch_truncation(before, after, length, preserve_end_of_sequence=False

assert k in after, 'No keys should be removed during sequence truncation.'

assert before[k].shape[0] == after[k].shape[0
], 'The batch size should not be changed during sequence truncation.'
assert before[k].shape[0] == after[k].shape[0], (
'The batch size should not be changed during sequence truncation.'
)

if before[k].ndim >= 2:

@@ -50,8 +51,9 @@ def check_batch_non_truncation(before, after, length):

assert k in after, 'No keys should be removed during sequence reshaping.'

assert after[k
].shape == input_ids_after_shape, 'All tensors should have the same size after sequence reshaping.'
assert after[k].shape == input_ids_after_shape, (
'All tensors should have the same size after sequence reshaping.'
)

b_numel = before[k].shape[0] * before[k].shape[1]
a_numel = after[k].shape[0] * after[k].shape[1]
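With COM812 enabled in pyproject.toml and ruff already wired into pre-commit with --fix (see the first two files above), the same rewrite can be reproduced locally. One way to do that, sketched under the assumption that a recent ruff CLI is on PATH; the helper name is made up for illustration:

# Hypothetical helper: apply ruff's missing-trailing-comma autofix to a checkout.
import subprocess

def apply_trailing_commas(path: str = '.') -> int:
    # 'ruff check --fix' applies available autofixes; '--select COM812' restricts
    # the run to the trailing-comma rule added to pyproject.toml in this commit.
    result = subprocess.run(['ruff', 'check', '--fix', '--select', 'COM812', path])
    return result.returncode

if __name__ == '__main__':
    raise SystemExit(apply_trailing_commas())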